Merge branch 'support_sparse_table_resize' into 'master'

add sparse resize

See merge request deep-learning/tensornet!13
gzm55 committed Jul 15, 2024
2 parents 0f8555c + 7a8bac3 commit c036438
Showing 8 changed files with 440 additions and 98 deletions.
17 changes: 0 additions & 17 deletions tools/merge_sparse/README.md

This file was deleted.

63 changes: 0 additions & 63 deletions tools/merge_sparse/utils.py

This file was deleted.

54 changes: 54 additions & 0 deletions tools/table_tools/README.md
@@ -0,0 +1,54 @@
## Merging sparse files

The generated sparse files are spread across many directories. Spark can be used to extract the sparse file contents and merge them into a single output directory.

```bash
spark-submit3 --executor-memory 8g --driver-memory 10g --py-files utils.py merge_sparse.py -i "/user/test/model_path/sparse_table/*/*/*bin.gz" -o "/user/test/model_merge_path" -f 'bin' -n 500 -b
```

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path
-o/--output | None | Output path
-f/--format | bin | Input file format
-n/--number | 20 | Output parallelism
-b/--bracket | False | Whether the output weights are wrapped in `[]`; the bracketed weights form a single column and columns are separated by `\t` (see the sketch below)
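
For reference, a minimal sketch of how a merged output line could be formatted with and without `--bracket`. This only illustrates the behaviour described above; the actual `output_line` helper in utils.py may differ.

```python
# Illustrative only -- not the actual output_line in utils.py.
def output_line(row, bracket):
    """Format one (sign, weights) row as a text line for the merged output."""
    sign, weights = row[0], row[1]
    if bracket:
        # bracketed weights form a single tab-separated column
        return "%s\t[%s]" % (sign, ",".join(str(w) for w in weights))
    # otherwise every weight becomes its own tab-separated column
    return "%s\t%s" % (sign, "\t".join(str(w) for w in weights))

# output_line(("12345", [0.1, 0.2]), True)   -> "12345\t[0.1,0.2]"
# output_line(("12345", [0.1, 0.2]), False)  -> "12345\t0.1\t0.2"
```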


## Resizing sparse table parallelism

Currently the parallelism of a generated sparse_table directory cannot be changed: if the parallelism used for loading differs from the one used for saving, data will be missing, so the cluster cannot be scaled up or down. This tool reads the original data with Spark and writes it back out following the file pattern of the requested parallelism.

Because hdfs3 is used to write the files, a packaged Python environment has to be uploaded; use the [env file](config/tn_tool_env.yaml).

```bash
spark-submit3 --conf spark.executor.memory=10g --conf spark.archives=hdfs://nn/user/test/cache/python.tar.gz#envs --conf spark.pyspark.driver.python=/home/test/micromamba/envs/tn_tool_env/bin/python --conf spark.pyspark.python=./envs/bin/python --py-files utils.py resize_sparse.py --input /user/test/model/* --output /user/test/resize --number 50
```

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path; its HDFS prefix is reused for writing the HDFS files. If no HDFS prefix is present, hdfs://ss-hadoop2 is used by default
-o/--output | None | Output path; handle_name/rank_number/block_num.gz files are written under it (see the layout sketch below)
-f/--format | bin | Input file format
-n/--number | 20 | Output parallelism
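
Purely as an illustration of the output layout described above (the handle name and the helper below are made up for the example; they are not part of the tools):

```python
# Illustrative only: the resize job writes one file per (handle, rank, block),
# i.e. <output>/<handle_name>/<rank_number>/<block_num>.gz
def sparse_output_path(output, handle_name, rank, block):
    return "%s/%s/%d/%d.gz" % (output, handle_name, rank, block)

# sparse_output_path("/user/test/resize", "embedding_0", 3, 5)
# -> "/user/test/resize/embedding_0/3/5.gz"
```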


## Resizing dense table parallelism

Similar to the sparse case.

```bash
spark-submit3 --conf spark.executor.memory=10g --conf spark.archives=hdfs://nn/user/test/cache/python.tar.gz#envs --conf spark.pyspark.driver.python=/home/test/micromamba/envs/tn_tool_env/bin/python --conf spark.pyspark.python=./envs/bin/python --py-files utils.py resize_dense.py --input /user/test/model/* --output /user/test/resize --number 50
```

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path; its HDFS prefix is reused for writing the HDFS files. If no HDFS prefix is present, hdfs://ss-hadoop2 is used by default
-o/--output | None | Output path; handle_name/rank_number files are written under it
-n/--number | 20 | Output parallelism
8 changes: 8 additions & 0 deletions tools/table_tools/config/tn_tool_env.yaml
@@ -0,0 +1,8 @@
name: tn_build
channels:
- conda-forge
dependencies:
- python=3.8
- nomkl
- openssl>=3
- hdfs3
@@ -6,8 +6,6 @@
 from pyspark import SparkContext, SparkConf
 from pyspark.sql import *
 from pyspark.sql.functions import *
-from pyspark.sql import functions as F
-from pyspark.sql.types import *
 from utils import *


@@ -30,22 +28,8 @@ def main(args):
         .getOrCreate()
 
     sc = spark.sparkContext
-
-    if args.format == 'txt':
-        get_handle_name_udf = udf(get_handle_name, StringType())
-        dims_df = sc.textFile(args.input)\
-            .map(lambda x: process_txt_line(x))\
-            .toDF(["key", "dims"])\
-            .withColumn("input_file_name",F.input_file_name())\
-            .withColumn("handle", get_handle_name_udf(col("input_file_name")))\
-            .drop("input_file_name")\
-            .filter(col("key") != "").dropDuplicates(['key','handle'])
-    elif args.format == 'bin':
-        dims_df = sc.binaryFiles(args.input)\
-            .mapPartitions(process_binary_partition)\
-            .toDF(['handle', 'key', 'dims'])
-
-    dims_df.dropDuplicates(['key','handle']).drop('handle').rdd.map(lambda x: output_line(x, args.bracket)).repartition(args.number).saveAsTextFile(args.output)
+    dims_df = load_sparse_table_to_df(sc, args.input, args.format)
+    dims_df.select("sign", "weights", "handle").dropDuplicates(['sign','handle']).drop('handle').rdd.map(lambda x: output_line(x, args.bracket)).repartition(args.number).saveAsTextFile(args.output)
 
 
 if __name__ == '__main__':
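
The format branching removed above has been folded into a shared loader in utils.py. For reference, a rough sketch of what such a `load_sparse_table_to_df` could look like, reconstructed from the deleted branches (the `sign`/`weights` column names and the helper internals here are assumptions; the real helper may differ):

```python
# Sketch only: reconstructed from the removed branches above; the real
# load_sparse_table_to_df lives in utils.py and may differ.
from pyspark.sql.functions import col, input_file_name, udf
from pyspark.sql.types import StringType

def load_sparse_table_to_df(sc, input_path, fmt):
    """Return a DataFrame with (handle, sign, weights) columns."""
    if fmt == 'txt':
        # get_handle_name / process_txt_line are the existing utils.py helpers
        get_handle_name_udf = udf(get_handle_name, StringType())
        return sc.textFile(input_path) \
            .map(process_txt_line) \
            .toDF(["sign", "weights"]) \
            .withColumn("file", input_file_name()) \
            .withColumn("handle", get_handle_name_udf(col("file"))) \
            .drop("file") \
            .filter(col("sign") != "")
    if fmt == 'bin':
        return sc.binaryFiles(input_path) \
            .mapPartitions(process_binary_partition) \
            .toDF(["handle", "sign", "weights"])
    raise ValueError("unsupported format: %s" % fmt)
```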
42 changes: 42 additions & 0 deletions tools/table_tools/resize_dense.py
@@ -0,0 +1,42 @@
#coding=utf-8
import sys
import argparse
import os
from pyspark import SparkContext, SparkConf
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import *
from utils import *
import math

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="dense table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    parser.add_argument("-n", "--number", type=int, help="output file parallelism", default=30)
    args = parser.parse_args()
    return args


def main(args):
    spark = SparkSession.builder \
        .appName("[spark][resize dense table]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext
    output_bc_value = sc.broadcast(args.output)
    # read each dense file whole and keep (file_name, parent_dir, content) before expanding to records
    dense_file_rdd = sc.wholeTextFiles(args.input).map(lambda x: (x[0].split("/")[-1], x[0].split("/")[-2], x[1])).flatMap(mapIndexToDenseRecord)

    # collect all dense records to the driver and re-split them into args.number partitions
    whole_data = dense_file_rdd.collect()
    res = process_whole_text(whole_data, args.number)

    res_rdd = sc.parallelize(res, args.number)
    res_rdd.foreachPartition(lambda p: write_dense_partition(p, output_bc_value))


if __name__ == '__main__':
    args = parse_args()
    main(args)
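
`mapIndexToDenseRecord`, `process_whole_text` and `write_dense_partition` live in utils.py and are not shown in this diff. As a rough mental model only (an assumption, not the actual implementation), resizing a dense table amounts to collecting each rank's slice of every variable, concatenating the slices in rank order, and re-splitting the values into the new number of files:

```python
# Hypothetical sketch of the re-split step (assumption; the real process_whole_text may differ).
from collections import defaultdict

def resize_dense_values(records, number):
    """records: iterable of (handle, rank, values) tuples, where values is the list of
    floats stored by the original rank. Returns (handle, new_rank, chunk) tuples."""
    by_handle = defaultdict(list)
    for handle, rank, values in sorted(records, key=lambda r: (r[0], int(r[1]))):
        by_handle[handle].extend(values)              # concatenate slices in rank order
    out = []
    for handle, values in by_handle.items():
        step = -(-len(values) // number)              # ceil division: values per new rank
        for new_rank in range(number):
            out.append((handle, new_rank, values[new_rank * step:(new_rank + 1) * step]))
    return out
```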
46 changes: 46 additions & 0 deletions tools/table_tools/resize_sparse.py
@@ -0,0 +1,46 @@
#coding=utf-8
import sys
import argparse
import os
from pyspark import SparkContext, SparkConf
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import *
from utils import *


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="sparse table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    parser.add_argument("-f", "--format", type=str, help="input file format, 'txt' or 'bin'")
    parser.add_argument("-n", "--number", type=int, help="output file parallelism", default=30)
    args = parser.parse_args()
    return args


def main(args):
    spark = SparkSession.builder \
        .appName("[spark][resize sparse table]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext
    output_bc_value = sc.broadcast(args.output)
    format_bc_value = sc.broadcast(args.format)
    number_bc_value = sc.broadcast(args.number)

    # collect the handle (table) names once and broadcast them to the executors
    handle_names = fetch_hanlds(args.input)
    handle_names_bc_value = sc.broadcast(handle_names)

    dims_df = load_sparse_table_to_df(sc, args.input, args.format)

    # route every sign to its target (rank, block) partition and write each partition out
    dims_df.rdd.map(lambda row: (get_sign_partition_key(row[0], args.number), row)).partitionBy(args.number * BLOCK_NUM)\
        .foreachPartition(lambda p: resize_partition(p, output_bc_value, format_bc_value, number_bc_value, handle_names_bc_value))


if __name__ == '__main__':
    args = parse_args()
    main(args)
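
`get_sign_partition_key`, `BLOCK_NUM` and `resize_partition` come from utils.py and are not part of this diff. Purely as an illustration of how a sign could be routed to the flat partition id expected by `partitionBy(args.number * BLOCK_NUM)` (this is an assumed mapping, not the utils.py logic):

```python
# Illustrative only: an assumed mapping from a sign to (rank, block) and then to a
# flat partition id in [0, node_num * BLOCK_NUM); utils.py may hash differently.
BLOCK_NUM = 8  # assumed value; the real constant is defined in utils.py

def get_sign_partition_key(sign, node_num):
    sign = int(sign)
    rank = sign % node_num                        # which rank's directory receives this sign
    block = (sign // node_num) % BLOCK_NUM        # which block file inside that rank
    return rank * BLOCK_NUM + block               # one partition per (rank, block) pair
```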