In [1]:
import os
import shutil
import subprocess
import json
from functools import reduce

In [2]:
# Job to package; all job files live under <working directory>/<job_name>/.
job_name = "wresnet50_2"
current_path = os.getcwd()
job_path = os.path.join(current_path, job_name)

# Inputs read from the job folder: the models directory and structure.json.
input_path = os.path.join(job_path, "input")
models_path, structure_json_path = (
    os.path.join(input_path, leaf) for leaf in ("models", "structure.json")
)

# Templates are resolved relative to the working directory, not the job.
template_path = os.path.join(current_path, "templates")
sam_app_path = os.path.join(template_path, "sam-app")

# Generated files are written next to the job's input folder.
output_path = os.path.join(job_path, "output")

### Preprocess

In [3]:
def check_and_new_folder(path):
    """Make ``path`` a freshly created, empty directory.

    Whatever currently sits at ``path`` is removed first: a directory is
    deleted recursively, a regular file or symlink is unlinked. Missing
    parent directories are created as needed.

    Args:
        path: Location of the directory to (re)create.

    Raises:
        OSError: If the existing entry cannot be removed or the directory
            cannot be created.
    """
    # Check for links first: shutil.rmtree refuses to operate on a symlink,
    # and os.path.exists() reports False for a dangling one.
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    # makedirs (rather than mkdir) so a missing parent does not abort the run.
    os.makedirs(path)

In [4]:
def copy_files(src, dest, name_lists):
    """Copy the named files from directory ``src`` into directory ``dest``.

    Each file keeps its name. Only file contents are copied
    (``shutil.copyfile``); an existing file in ``dest`` is overwritten.

    Args:
        src: Directory that holds the files.
        dest: Existing directory that receives the copies.
        name_lists: Iterable of file names relative to ``src``.
    """
    for filename in name_lists:
        source_file = os.path.join(src, filename)
        target_file = os.path.join(dest, filename)
        shutil.copyfile(source_file, target_file)

def copy_tree(src, dst, symlinks=False, ignore=None):
    """Copy every entry directly inside ``src`` into the existing ``dst``.

    Sub-directories are copied recursively with ``shutil.copytree`` (which
    receives ``symlinks`` and ``ignore``); everything else is copied with
    ``shutil.copy2``.

    Args:
        src: Directory whose contents are copied.
        dst: Existing destination directory.
        symlinks: Forwarded to ``shutil.copytree`` for sub-directories.
        ignore: Forwarded to ``shutil.copytree`` for sub-directories; it is
            not applied to the top-level entries of ``src``.
    """
    for entry in os.listdir(src):
        source_item = os.path.join(src, entry)
        target_item = os.path.join(dst, entry)
        if not os.path.isdir(source_item):
            shutil.copy2(source_item, target_item)
            continue
        shutil.copytree(source_item, target_item, symlinks, ignore)

In [5]:
def get_stage_info(path):
    """Load the stage description stored in a job's ``structure.json``.

    Args:
        path: Either the directory that contains ``structure.json`` or the
            path of the JSON file itself.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Resolve the file from the argument instead of relying on a notebook
    # global, so the function works for any job folder it is handed.
    json_path = os.path.join(path, 'structure.json') if os.path.isdir(path) else path
    with open(json_path, 'r') as json_file:
        return json.load(json_file)

In [6]:
# Read the job's stage description (structure.json) into stage_dict.
stage_dict = get_stage_info(input_path)
# Start from an empty output folder: anything already at output_path is deleted.
check_and_new_folder(output_path)
# Seed the fresh output folder with the contents of the sam-app template.
copy_tree(sam_app_path, output_path + "/")

### Prepare master and workers

In [7]:
output_functions_path = os.path.join(output_path, 'lambda_functions')
worker_template_path = os.path.join(template_path, 'worker.py')
master_path = os.path.join(output_functions_path, 'master')

# Distribute the model files. Names are parsed as underscore-separated fields
# followed by '.json'; the last field is the integer function id, and the
# first three fields name the worker folder.
for name in os.listdir(models_path):
    # Match on the extension: a substring test would also accept names such
    # as 'json_notes.txt', which the suffix stripping below cannot handle.
    if not name.endswith('.json'):
        continue
    path = os.path.join(models_path, name)
    name_body = name[:-len('.json')]
    split_name = name_body.split('_')
    func_id = split_name[-1]
    if int(func_id) == 0:
        # Function id 0 goes to the master folder together with structure.json.
        # Create the folder if needed: copying to a missing directory would
        # silently produce a plain file called 'master' instead.
        os.makedirs(master_path, exist_ok=True)
        shutil.copy2(path, master_path)
        shutil.copy2(structure_json_path, master_path)
    else:
        worker_path = os.path.join(output_functions_path,
                                   'from{}To{}Worker{}'.format(
                                       split_name[0], split_name[1], split_name[2]))
        # exist_ok keeps the cell re-runnable without rebuilding the output folder.
        os.makedirs(worker_path, exist_ok=True)
        shutil.copy(path, worker_path)
        shutil.copy(worker_template_path, worker_path)

# Every function folder (master and workers) gets the shared support files.
for name in os.listdir(output_functions_path):
    path = os.path.join(output_functions_path, name)
    if os.path.isdir(path):
        copy_files(template_path, path, ['requirements.txt', '__init__.py', 'utils.py'])

### Generate template.yaml

In [8]:
def dump_line_list(line_list, file_obj):
    """Write each entry of ``line_list`` to ``file_obj``, then one newline.

    Entries are written as given; no separator is inserted between them.

    Args:
        line_list: Iterable of strings to write in order.
        file_obj: Writable text file object.
    """
    emit = file_obj.writelines
    for chunk in line_list:
        emit(chunk)
    # Trailing newline terminates the section just written.
    emit("\n")

In [9]:
prefix_path = os.path.join(template_path, 'template_yaml_prefix')
postfix_path = os.path.join(template_path, 'template_yaml_postfix')
output_template_path = os.path.join(output_path, 'template.yaml')

# template.yaml is assembled from three sections:
# the prefix file, one function entry per worker, and the postfix file.
line_list_list = []
with open(prefix_path, 'r') as prefix_file:
    line_list_list.append(prefix_file.readlines())

body_list = []
# os.listdir returns entries in arbitrary order; sorting makes the generated
# file identical from run to run.
for name in sorted(os.listdir(models_path)):
    # Match on the extension so unrelated files containing 'json' are skipped.
    if not name.endswith('.json'):
        continue
    name_body = name[:-len('.json')]
    split_name = name_body.split('_')
    func_id = split_name[-1]
    # Function id 0 gets no entry here; only the other ids are emitted.
    if int(func_id) == 0:
        continue
    identity = 'from{}To{}Worker{}'.format(split_name[0], split_name[1], split_name[2])
    body_list.extend([
        '  {}:\n'.format(identity),
        '    Type: AWS::Serverless::Function\n',
        '    Properties:\n',
        '      FunctionName: {}\n'.format(identity),
        '      CodeUri: lambda_functions/{}/\n'.format(identity),
        '      Handler: worker.lambda_handler\n',
        '\n',
    ])

body_list.append('\n')
line_list_list.append(body_list)
with open(postfix_path, 'r') as postfix_file:
    line_list_list.append(postfix_file.readlines())

# The context manager closes the file even if writing a section fails.
with open(output_template_path, 'w') as output_obj:
    for line_list in line_list_list:
        dump_line_list(line_list, output_obj)