# 1.SM_Multigpu Distributed Training-ScriptMode-DALLE
---

본 모듈에서는 Amzaon SageMaker API을 효과적으로 이용하기 위해 multigpu-distributed 학습을 위한 PyTorch 프레임워크 자체 구현만으로 모델 훈련을 수행해 봅니다.

In [87]:
install_needed = True  # should only be True once
# install_needed = False

In [88]:
import sys
import IPython

if install_needed:
    print("installing deps and restarting kernel")
#     !{sys.executable} -m pip install -U split-folders tqdm albumentations crc32c wget
    !{sys.executable} -m pip install 'sagemaker[local]' --upgrade
    !{sys.executable} -m pip install -U bokeh smdebug sagemaker-experiments gdown
    !{sys.executable} -m pip install -U sagemaker torch torchvision
    !/bin/bash ./local/local_mode_setup.sh
    IPython.Application.instance().kernel.do_shutdown(True)

installing deps and restarting kernel
Collecting gdown
  Downloading gdown-3.13.0.tar.gz (9.3 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... [?25ldone
[?25h  Created wheel for gdown: filename=gdown-3.13.0-py3-none-any.whl size=9034 sha256=1c98df5e7792e54f808448fc3b612f45424ed7af261f51e8cb44866c45ba2107
  Stored in directory: /home/ec2-user/.cache/pip/wheels/6a/87/bd/09b16161b149fd6711ac76b5420d78ed58bd6a320e892117c3
Successfully built gdown
Installing collected packages: gdown
Successfully installed gdown-3.13.0
nvidia-docker2 already installed. We are good to go!
Stopping docker: [60G[[0;32m  OK  [0;39m]
Starting docker:	.[60G[[0;32m  OK  [0;39m]
SageMaker instance route table setup is ok. We are good to go.
SageMaker instance routing for Docker is ok. We are good to go!


## 2. 환경 설정

<p>Sagemaker 학습에 필요한 기본적인 package를 import 합니다. </p>
<p>boto3는 HTTP API 호출을 숨기는 편한 추상화 모델을 가지고 있고, Amazon EC2 인스턴스 및 S3 버켓과 같은 AWS 리소스와 동작하는 파이선 클래스를 제공합니다. </p>
<p>sagemaker python sdk는 Amazon SageMaker에서 기계 학습 모델을 교육 및 배포하기 위한 오픈 소스 라이브러리입니다.</p>

In [1]:
import joblib
import matplotlib.pyplot as plt
import sagemaker
# import splitfolders

import datetime
import glob
import os
import time
import warnings

from smexperiments.experiment import Experiment
from smexperiments.trial import Trial

# import wget
# import tarfile
import shutil

import boto3
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision

# from tqdm import tqdm
from time import strftime
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms

from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch

from sagemaker.debugger import (Rule,
                                rule_configs,
                                ProfilerConfig, 
                                FrameworkProfile, 
                                DetailedProfilingConfig, 
                                DataloaderProfilingConfig, 
                                PythonProfilingConfig)

warnings.filterwarnings('ignore')
%config InlineBackend.figure_format = 'retina'

In [2]:
role = get_execution_role()

In [3]:
sagemaker.__version__

'2.42.1'

In [4]:
def create_experiment(experiment_name):
    try:
        sm_experiment = Experiment.load(experiment_name)
    except:
        sm_experiment = Experiment.create(experiment_name=experiment_name,
                                          tags=[
                                              {
                                                  'Key': 'multigpu',
                                                  'Value': 'yes'
                                              },
                                              {
                                                  'Key': 'multinode',
                                                  'Value': 'yes'
                                              },
                                          ])

In [5]:
def create_trial(experiment_name, set_param, i_type, i_cnt, spot):
    create_date = strftime("%m%d-%H%M%s")
    
    spot = 's' if spot else 'd'
    i_tag = 'test'
    if i_type == 'ml.p3.16xlarge':
        i_tag = 'p3'
    elif i_type == 'ml.p3dn.24xlarge':
        i_tag = 'p3dn'
    elif i_type == 'ml.p4d.24xlarge':
        i_tag = 'p4d'    
        
    trial = "-".join([i_tag,str(i_cnt),spot])
       
    sm_trial = Trial.create(trial_name=f'{experiment_name}-{trial}-{create_date}',
                            experiment_name=experiment_name)

    job_name = f'{sm_trial.trial_name}'
    return job_name

In [6]:
bucket = 'bucket-exp-dalle-210410'
code_location = f's3://{bucket}/sm_codes'
output_path = f's3://{bucket}/vqgan_poc/output/' 

In [7]:
metric_definitions=[
     {'Name': 'train:lr', 'Regex': 'lr - (.*?),'},
     {'Name': 'train:Loss', 'Regex': 'loss -(.*?),'},
]

In [8]:
from sagemaker.debugger import Rule, ProfilerRule, rule_configs

rules=[ 
    Rule.sagemaker(rule_configs.loss_not_decreasing()),
    Rule.sagemaker(rule_configs.overfit()),
    ProfilerRule.sagemaker(rule_configs.ProfilerReport()),
]

In [9]:
hyperparameters = {
        't' : True,
        'base' : '/opt/ml/code/configs/faceshq_vqgan_test.yaml',
        'output_s3' : output_path,
#         'gpus' : 8
    }

experiment_name = 'vqgan-poc-exp1'
instance_type = 'ml.p4d.24xlarge'  # 'ml.p3.16xlarge', 'ml.p3dn.24xlarge', 'ml.p4d.24xlarge', 'local_gpu'
# instance_type = 'local_gpu'

instance_count = 2
do_spot_training = False
max_wait = None
max_run = 1*60*60

In [10]:
# !gdown https://drive.google.com/uc?id=1vF8Ht0VThpobtmShD52_INhpIgy6eEXq
# !gdown https://drive.google.com/uc?id=1kaIqFwTLD7Ml3ib9NQpjoUSD4FUD21-I

In [11]:
# !rm -rf dataset
# !mkdir dataset
# !unzip birds.zip -d dataset/
# !tar zxvf CUB_200_2011.tgz -C dataset/

In [12]:
if instance_type =='local_gpu':
    from sagemaker.local import LocalSession
    from pathlib import Path

    sagemaker_session = LocalSession()
    sagemaker_session.config = {'local': {'local_code': True}}
    s3_data_path = 'file:///home/ec2-user/SageMaker/dataset'
    source_dir = f'{Path.cwd()}/taming-transformers'
else:
    sess = boto3.Session()
    sagemaker_session = sagemaker.Session()
    sm = sess.client('sagemaker')
#     bucket_name = 'dataset-cyj-coco-210410'
#     s3_data_path = f's3://{bucket_name}/dataset1'
#     s3_data_path = 's3://dataset-cyj-us-east-1/CUB-BIRD'
    s3_data_path = 's3://dataset-cyj-us-east-1/conceptual_captions/validation'
    source_dir = 'taming-transformers'


In [78]:
image_uri = None
distribution = None
train_job_name = 'sagemaker'


train_job_name = 'smp-dist'
distribution = {}

# if hyperparameters.get('sagemakermp'):
#     distribution['smdistributed'] = { "modelparallel": {
#                                               "enabled":True,
#                                               "parameters": {
#                                                   "partitions": hyperparameters['num_partitions'],
#                                                   "microbatches": hyperparameters['num_microbatches'],
#                                                   "placement_strategy": hyperparameters['placement_strategy'],
#                                                   "pipeline": hyperparameters['pipeline'],
#                                                   "optimize": hyperparameters['optimize'],
#                                                   "ddp": hyperparameters['ddp'],
#                                               }
#                                           }
#                                       }


distribution["smdistributed"]={ 
                    "dataparallel": {
                        "enabled": False
                    }
            }
distribution["mpi"]={
                    "enabled": True,
                    "processes_per_host": 0, # Pick your processes_per_host
#                     "custom_mpi_options": "-verbose -x orte_base_help_aggregate=0 "
              }

if do_spot_training:
    max_wait = max_run

print("train_job_name : {} \ntrain_instance_type : {} \ntrain_instance_count : {} \nimage_uri : {} \ndistribution : {}".format(train_job_name, instance_type, instance_count, image_uri, distribution))    

train_job_name : smp-dist 
train_instance_type : ml.p4d.24xlarge 
train_instance_count : 2 
image_uri : None 
distribution : {'smdistributed': {'dataparallel': {'enabled': False}}, 'mpi': {'enabled': True, 'processes_per_host': 0}}


In [79]:
# image_uri = '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.8.1-gpu-py36-cu111-ubuntu18.04'

In [80]:
%%time

# all input configurations, parameters, and metrics specified in estimator 
# definition are automatically tracked
estimator = PyTorch(
    entry_point='main.py',
    source_dir=source_dir,
    role=role,
    sagemaker_session=sagemaker_session,
    framework_version='1.8.0',
    py_version='py36',
#     image_uri=image_uri,
    instance_count=instance_count,
    instance_type=instance_type,
    volume_size=1024,
    code_location = code_location,
    output_path=output_path,
    hyperparameters=hyperparameters,
    distribution=distribution,
#     disable_profiler=True,
    metric_definitions=metric_definitions,
#     rules=rules,
    max_run=max_run,
    use_spot_instances=do_spot_training,  # spot instance 활용
    max_wait=max_wait,
    subnets=['subnet-0b731e2124d43368d'],  # subnet-0c775b056a6e540ee  , 	subnet-05b7d4713e03d2bfe, subnet-0b731e2124d43368d
    security_group_ids=['sg-04e9a37dbd74e3ade'],
)

CPU times: user 200 µs, sys: 9 µs, total: 209 µs
Wall time: 217 µs


In [72]:
!sudo rm -rf ./taming-transformers/wandb
!sudo rm -rf ./taming-transformers/logs/*

In [73]:
# Configure FSx Input for your SageMaker Training job

from sagemaker.inputs import FileSystemInput

file_system_directory_path= '/ksmjfbmv'

file_system_id='fs-0ffed11a31906f7ee'

file_system_access_mode='rw'
file_system_type='FSxLustre'
train_fs = FileSystemInput(file_system_id=file_system_id,
                                    file_system_type=file_system_type,
                                    directory_path=file_system_directory_path,
                                    file_system_access_mode=file_system_access_mode)

In [74]:
# input_data = sagemaker.inputs.TrainingInput(
#         s3_data=s3_data_path,
#         distribution='ShardedByS3Key',
#         s3_data_type='S3Prefix',
#         input_mode='File',
#         shuffle_config=sagemaker.inputs.ShuffleConfig(123)
#         )

In [75]:
create_experiment(experiment_name)
job_name = create_trial(experiment_name, hyperparameters, instance_type, instance_count, do_spot_training)

# Now associate the estimator with the Experiment and Trial
estimator.fit(
    inputs={'training': train_fs}, 
#     inputs={'training': s3_data_path}, 
    job_name=job_name,
    experiment_config={
      'TrialName': job_name,
      'TrialComponentDisplayName': job_name,
    },
    wait=False,
)

INFO:sagemaker.image_uris:Defaulting to the only supported framework/algorithm version: latest.
INFO:sagemaker.image_uris:Ignoring unnecessary instance type: None.
INFO:sagemaker:Creating training-job with name: vqgan-poc-exp1-p4d-2-d-0602-03521622605930


In [170]:
# job_name=estimator.latest_training_job.name
job_name='vqgan-poc-exp1-p4d-2-d-0604-14311622817099'
# dalle-poc-exp4-p4d-2-d-0525-03071621912021 --> public
# dalle-poc-exp4-p4d-2-d-0525-03091621912148 --> another private

In [171]:
sagemaker_session = sagemaker.Session()

In [172]:
sagemaker_session.logs_for_job(job_name=job_name, wait=True)

2021-06-04 14:45:53 Starting - Preparing the instances for training
2021-06-04 14:45:53 Downloading - Downloading input data
2021-06-04 14:45:53 Training - Training image download completed. Training in progress.
2021-06-04 14:45:53 Uploading - Uploading generated training model
2021-06-04 14:45:53 Completed - Training job completed[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2021-06-04 14:40:41,134 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2021-06-04 14:40:41,215 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[35mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[35mbash: no job control in this shell[0m
[35m2021-06-04 14:40:42,931 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[35m2021-06-04 14:40:43,013 sagema

In [70]:
import glob

In [71]:
model_dir='test'

In [76]:
print(f"************** file : {glob.glob(model_dir+'/*')}")

************** file : []


In [84]:
!pip install g_mlp_pytorch

Collecting g_mlp_pytorch
  Downloading g_mlp_pytorch-0.0.16-py3-none-any.whl (5.2 kB)
Collecting einops>=0.3
  Downloading einops-0.3.0-py2.py3-none-any.whl (25 kB)
Installing collected packages: einops, g-mlp-pytorch
Successfully installed einops-0.3.0 g-mlp-pytorch-0.0.16


In [161]:
from omegaconf import OmegaConf

In [None]:
https://download.pytorch.org/models/vgg16-397923af.pth

In [176]:
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

In [177]:
import os, hashlib
import requests
from tqdm import tqdm

In [178]:
def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()

In [179]:
path = 'taming/modules/autoencoder/lpips/vgg.pth'

In [180]:
download('https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1', path)

8.19kB [00:00, 385kB/s]                    


In [181]:
md5 = md5_hash(path)

In [182]:
assert md5 == MD5_MAP["vgg_lpips"], md5

AssertionError: d507d7349b931f0638a25a48a722f98a

In [160]:
# def get_ckpt_path(name, root, check=False):
#     assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5


In [None]:
def resource_check():
    import subprocess
    result = subprocess.run(['df', '-h'], stdout=subprocess.PIPE)
    print(result.stdout.decode('utf-8'))

In [None]:
# !pip install wandb

In [356]:
import wandb
run = wandb.init(
    project="test",  # 'dalle_train_transformer' by default
#     resume=RESUME,
#     config=model_config,
#             dir=wandb_dir
)

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ec2-user/.netrc


In [47]:
def get_dir_size(path='.'):
    import os
    total = 0
    with os.scandir(path) as it:
        for entry in it:
            if entry.is_file():
                total += entry.stat().st_size
            elif entry.is_dir():
                total += get_dir_size(entry.path)
    return total

In [48]:
get_dir_size("./source_code")

11244848

In [82]:
import glob

In [None]:
/opt/ml/input/data/training/CUB_200_2011/images/162.Canada_Warbler/Canada_Warbler_0005_162389.jpg

In [101]:
test = glob.glob("../test-fsx/CUB_200_2011/images/*/*")

In [102]:
f = open("/home/ec2-user/SageMaker/images.txt", 'w+')
for line in test:
    print(str(line.split('images/')[1])+'\n')
    f.write(str(line.split('images/')[1])+'\n')
f.close()

114.Black_throated_Sparrow/Black_Throated_Sparrow_0054_107026.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0066_106974.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0070_107196.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0074_107113.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0039_107259.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0088_107220.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0099_106944.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0010_107375.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0086_106970.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0003_107035.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0008_107000.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0027_107278.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0084_107066.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0020_106971.jpg

114.Black_throated_Sparrow/Black_Throated_Sparrow_0091_107346.

In [137]:
!rm -rf logs

In [138]:
!aws s3 cp s3://bucket-exp-dalle-210410/vqgan_poc/output/vqgan-poc-exp1-p4d-1-d-0604-03111622776272/output/model.tar.gz ./

download: s3://bucket-exp-dalle-210410/vqgan_poc/output/vqgan-poc-exp1-p4d-1-d-0604-03111622776272/output/model.tar.gz to ./model.tar.gz


In [140]:
!tar -xvzf model.tar.gz

logs/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/tf/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/tf/events.out.tfevents.1622776960.algo-1.297.0
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/meta_tags.csv
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/media/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/metrics.csv
logs/2021-06-04T03-22-03_faceshq_vqgan_test/testtube/version_0/meta.experiment
logs/2021-06-04T03-22-03_faceshq_vqgan_test/images/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/images/val/
logs/2021-06-04T03-22-03_faceshq_vqgan_test/images/val/inputs_gs-000000_e-000000_b-000000.png
logs/2021-06-04T03-22-03_faceshq_vqgan_test/images/val/reconstructions_gs-000349_e-000000_b-000000.png
logs/2021-06-04T03-22-03_faceshq_vqgan_test/imag

In [81]:
os.environ

environ{'LESS_TERMCAP_mb': '\x1b[01;31m',
        'JAVA_LD_LIBRARY_PATH': '/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/server',
        'HOSTNAME': 'ip-172-16-7-91',
        'LESS_TERMCAP_md': '\x1b[01;38;5;208m',
        'LESS_TERMCAP_me': '\x1b[0m',
        'SHELL': '/bin/sh',
        'TERM': 'xterm-color',
        'HISTSIZE': '1000',
        'EC2_AMITOOL_HOME': '/opt/aws/amitools/ec2',
        'CONDA_SHLVL': '3',
        'CONDA_PROMPT_MODIFIER': '(pytorch_p36) ',
        'GSETTINGS_SCHEMA_DIR_CONDA_BACKUP': '',
        'PYTHON_INSTALL_LAYOUT': 'amzn',
        'LESS_TERMCAP_ue': '\x1b[0m',
        'USER': 'ec2-user',
        'LD_LIBRARY_PATH': '/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/lib/:/usr/local/cuda-10.1/lib64:/usr/local/cuda-10.1/extras/CUPTI/lib64:/usr/local/cuda-10.1/lib:/usr/local/cuda-10.1/efa/lib:/opt/amazon/efa/lib:/opt/amazon/efa/lib64:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/efa/l

In [114]:
ckptdir = 'test'
result = os.path.isfile(os.path.join(ckptdir, "last.ckpt"))

In [116]:
print(f"******* os.path.isfile : {result}")

******* os.path.isfile : False


In [110]:
os.path.isfile(os.path.join(ckptdir, "last.ckpt"))

False

In [132]:
ckptdir='./logs/2021-06-03T00-00-50_faceshq_vqgan_test/test/'

In [133]:
os.makedirs(ckptdir, exist_ok=True)

In [136]:
os.environ.get('LOCAL_RANK') is None

True

In [167]:
!mkdir sir_suc

In [168]:
!aws s3 cp s3://bucket-exp-dalle-210410/sm_codes/vqgan-poc-exp1-p4d-2-d-0604-14311622817099/source/sourcedir.tar.gz ./sir_suc/

download: s3://bucket-exp-dalle-210410/sm_codes/vqgan-poc-exp1-p4d-2-d-0604-14311622817099/source/sourcedir.tar.gz to sir_suc/sourcedir.tar.gz


In [169]:
!tar -xvzf sir_suc/sourcedir.tar.gz

assets/
assets/drin.jpg
assets/mountain.jpeg
assets/stormy.jpeg
assets/first_stage_squirrels.png
assets/lake_in_the_mountains.png
assets/faceshq.jpg
assets/sunset_and_ocean.jpg
assets/teaser.png
configs/
configs/coco_cond_stage.yaml
configs/imagenet_vqgan.yaml
configs/faceshq_transformer.yaml
configs/.ipynb_checkpoints/
configs/.ipynb_checkpoints/faceshq_transformer-checkpoint.yaml
configs/.ipynb_checkpoints/faceshq_vqgan-checkpoint.yaml
configs/.ipynb_checkpoints/imagenet_vqgan-checkpoint.yaml
configs/.ipynb_checkpoints/faceshq_vqgan_test-checkpoint.yaml
configs/faceshq_vqgan_test.yaml
configs/sflckr_cond_stage.yaml
configs/imagenetdepth_vqgan.yaml
configs/drin_transformer.yaml
configs/faceshq_vqgan.yaml
__pycache__/
__pycache__/main_ori.cpython-36.pyc
__pycache__/main.cpython-36.pyc
logs/
main-Copy1.py
.netrc
.git/
.git/logs/
.git/logs/HEAD
.git/logs/refs/
.git/logs/refs/heads/
.git/logs/refs/heads/master
.git/logs/refs/remotes/
.git/logs/refs/remotes/origin/
.git/logs/refs/remotes/o

In [156]:
os.path.join("aaa","bbb","ccc")

'aaa/bbb/ccc'

In [None]:
os.path.join(os.environ['OUTPUT_S3'], os.environ['JOB_NAME'] + "/temp")) 

In [162]:
resume = '../test-fsx/vqgan_poc/logs/2021-06-04T15-10-35_faceshq_vqgan_test/checkpoints'
logdir = './logs'

In [163]:
if os.path.isfile(resume):
    paths = resume.split("/")
    print(f"paths : {paths}")
    idx = len(paths)-paths[::-1].index("logs")+1
    print(f"idx : {idx}")
    logdir = "/".join(paths[:idx])
    print(f"logdir : {logdir}")
    ckpt = resume
    print(f"ckpt : {ckpt}")
else:
    assert os.path.isdir(resume), resume
    logdir = resume.rstrip("/")
    print(f"logdir : {logdir}")
    ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
    print(f"ckpt : {ckpt}")

logdir : ../test-fsx/vqgan_poc/logs/2021-06-04T15-10-35_faceshq_vqgan_test/checkpoints
ckpt : ../test-fsx/vqgan_poc/logs/2021-06-04T15-10-35_faceshq_vqgan_test/checkpoints/checkpoints/last.ckpt


In [None]:
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs+opt.base
_tmp = logdir.split("/")
nowname = _tmp[_tmp.index("logs")+1]