In [1]:
import torch

from torch import nn
from torch.utils import data
from torch.nn import functional as F

from torchvision import transforms
from random import randint

import sys
import os
from os import path

import pandas as pd

from functools import partial, reduce
from tqdm import tqdm

import datetime

## Local Imports ##
if '../' not in sys.path:
    sys.path.insert(0, '../')
from models import helpers as model_helpers, model_definitions as custom_models
from datasets import helpers as dataset_helpers, datasets as custom_datasets

from train_single_script import create_arg_str

from VOC import DT_DEST_RGB_RANDOM, DT_DEST_RGB_SINGLE_CLASS 

In [2]:
# Input image shape: channels, width, height
C, W, H = 3, 128, 128

# Training script that the grid search launches as a separate process
TRAIN_SINGLE_PATH = './train_single_script.py'

# Use the GPU whenever torch reports CUDA support, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

HOME = path.expanduser('~')

# Data root for the current resolution: ../data/<W>x<H>, made absolute
DT_ROOT = path.abspath(path.join('..', 'data', f'{W}x{H}'))

# Polygon datasets (count / percentage variants)
POLYGON_COUNT_DIR = path.join(DT_ROOT, 'polygon_data_counts')
POLYGON_PERCENTAGE_DIR = path.join(DT_ROOT, 'polygon_data_percentage')

# RGB polygon datasets (plain and noised)
POLYGON_RGB_COUNT_DIR = path.join(DT_ROOT, 'polygon_RGB_counts')
POLYGON_RGB_NOISED_COUNT_DIR = path.join(DT_ROOT, 'polygon_rgb_noised_counts')

# Ellipse datasets (count / percentage variants)
ELLIPSE_COUNT_DIR = path.join(DT_ROOT, 'ellipse_data_counts')
ELLIPSE_PERCENTAGE_DIR = path.join(DT_ROOT, 'ellipse_data_percentage')

# This one lives under the user's home directory rather than DT_ROOT
VOC_SEGS_COUNTS_DIR = path.join(HOME, 'datasets', 'VOC_FORMS')

In [3]:
## Grid Search Params ##
# Search-mode settings
RANDOM_SEARCH = False
SEARCH_LEN = 4

# NOTE(review): the value returned by get_models() is replaced by the explicit
# list right below, so only the names listed here end up in MODELS. Confirm
# whether the call is still wanted (e.g. for its side effects).
MODELS = custom_models.get_models((C, W, H))
MODELS = [
    "UNET", "UNET_HALF",
    "RESNET_18",
    "STRIDE_4", "STRIDE_8",
    "MAX_POOL_4", "MAX_POOL_8",
    "SUM_POOL_4", "SUM_POOL_8",
]

# Datasets to sweep over. Disabled alternatives are kept for quick toggling:
#   DT_DEST_RGB_RANDOM
#   DT_DEST_RGB_SINGLE_CLASS("AEROPLANE")
#   VOC_SEGS_COUNTS_DIR
#   POLYGON_COUNT_DIR
DATASETS = [POLYGON_RGB_NOISED_COUNT_DIR]

# Optimiser, loss function and learning-rate axes of the grid
OPTIMS = ["ADAM"]
LOSS_FNS = ["L1LOSS"]
LRS = [0.01, 0.001, 0.0005]

In [4]:
# Materialise the grid as a list so it can be measured here and iterated
# more than once later (sanity pass + training pass).
grid = list(model_helpers.new_grid_search(MODELS, OPTIMS, LOSS_FNS, LRS))

# Derive the run count from the grid that will actually be iterated (once per
# dataset) instead of re-multiplying the input list lengths, so the figure
# cannot drift from what is launched.
print(f"Will train {len(grid) * len(DATASETS)} models")
print(f"{len(MODELS)} MODELS")
print(f"{len(DATASETS)} DATASETS")
print(f"{len(LOSS_FNS)} LOSS_FNS")
print(f"{len(OPTIMS)} OPTIMS")
print(f"Device: {DEVICE}")

# One tab-indented entry per line
models_str = '\t' + "\n\t".join(MODELS)
lrs_str = '\t' + "\n\t".join(map(str, LRS))

# basename(normpath(...)) returns the last path component for the platform's
# separator and tolerates a trailing separator; splitting on '/' did neither.
dts_str = '\t' + "\n\t".join(
    path.basename(path.normpath(dt)) for dt in DATASETS
)
print("MODELS:")
print(models_str)
print("LRS:")
print(lrs_str)
print("DATASETS:")
print(dts_str)

Will train 27 models
9 MODELS
1 DATASETS
1 LOSS_FNS
1 OPTIMS
Device: cuda
MODELS:
	UNET
	UNET_HALF
	RESNET_18
	STRIDE_4
	STRIDE_8
	MAX_POOL_4
	MAX_POOL_8
	SUM_POOL_4
	SUM_POOL_8
LRS:
	0.01
	0.001
	0.0005
DATASETS:
	polygon_rgb_noised_counts


In [5]:
# Timestamp taken once when this cell runs; it names the log file and is
# passed to every training run as its "id".
curr_time = datetime.datetime.now()
# Fields are not zero-padded (e.g. 2020-1-5_11-44-44), so the resulting
# names do not sort chronologically as plain strings.
CURR_TIME_STR = (
    f"{curr_time.year}-{curr_time.month}-{curr_time.day}_"
    f"{curr_time.hour}-{curr_time.minute}-{curr_time.second}"
)
OUT_FILE = path.join("logs", f"out_{CURR_TIME_STR}.log")
MAX_EPOCHS = 25
BS = 32

# Arguments shared by every training invocation; grid_search() adds the
# per-run ones (dataset, model, optim, loss_fn, lr, sanity).
# The "epochs" key used to appear twice in this literal; the duplicate was
# redundant (same value) and has been dropped.
BASE_ARGS = {
    "C": C,
    "H": H,
    "W": W,
    "bs": BS,
    "epochs": MAX_EPOCHS,
    "device": DEVICE,
    "id": CURR_TIME_STR,
}

print(f"Epochs: {MAX_EPOCHS}")
print(f"BS: {BS}")
print(f"Timestamp: {CURR_TIME_STR}")

Epochs: 25
BS: 32
Timestamp: 2020-1-5_11-44-44


In [6]:
def grid_search(dts, rows, sanity):
    """Run the training script once for every (dataset, grid row) pair.

    Each run is launched through the shell as a separate process, with its
    standard output appended to ``OUT_FILE``.

    Args:
        dts: Iterable of dataset identifiers, passed as the ``dataset``
            argument of each run.
        rows: Iterable of grid rows exposing ``model``, ``opt``, ``loss`` and
            ``lr`` attributes.
        sanity: Forwarded as the ``sanity`` argument of each run; also selects
            which progress messages are printed.

    Raises:
        RuntimeError: As soon as one command returns a non-zero status; the
            remaining runs are not started.
    """
    import shlex  # local import: only needed to quote shell words here

    # Materialise both iterables: the inner loop walks `rows` once per
    # dataset, which would silently yield nothing after the first dataset if
    # a generator were passed in.
    dts = list(dts)
    rows = list(rows)

    # The `>>` redirection fails (and so does every run) when the log
    # directory does not exist yet.
    log_dir = path.dirname(OUT_FILE)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Launch with the interpreter running this notebook, so the child process
    # sees the same environment; fall back to the previous hard-coded name.
    interpreter = shlex.quote(sys.executable or 'python3')
    log_target = shlex.quote(OUT_FILE)

    if sanity: print("Performing sanity check...")
    else     : print("Training...")
    for dt in tqdm(dts):
        for row in tqdm(rows):
            command = (
                f'{interpreter} {TRAIN_SINGLE_PATH} ' +
                create_arg_str({
                    **BASE_ARGS,
                    "dataset": dt,
                    "model"  : row.model,
                    "optim"  : row.opt,
                    "loss_fn": row.loss,
                    "lr"     : row.lr,
                    "sanity" : sanity,
                }) + f' >> {log_target}')
            status = os.system(command)
            # Fail fast: stop the whole sweep on the first failing run.
            if status != 0: raise RuntimeError(f'FAILED: {command}')
    if sanity: print("Sanity Check: All Passed!")
    else     : print("Done Training!")

In [None]:
if __name__ == "__main__":
    # First a sanity pass over the whole grid, then the real training pass.
    # An exception in the sanity pass prevents the training pass from starting.
    for sanity_flag in (True, False):
        grid_search(DATASETS, grid, sanity=sanity_flag)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/27 [00:00<?, ?it/s][A

Performing sanity check...



  4%|▎         | 1/27 [00:03<01:37,  3.76s/it][A
  7%|▋         | 2/27 [00:07<01:33,  3.75s/it][A
 11%|█         | 3/27 [00:11<01:30,  3.76s/it][A
 15%|█▍        | 4/27 [00:15<01:27,  3.79s/it][A
 19%|█▊        | 5/27 [00:19<01:23,  3.82s/it][A
 22%|██▏       | 6/27 [00:22<01:20,  3.83s/it][A
 26%|██▌       | 7/27 [00:26<01:15,  3.79s/it][A
 30%|██▉       | 8/27 [00:30<01:11,  3.77s/it][A
 33%|███▎      | 9/27 [00:34<01:07,  3.75s/it][A
 37%|███▋      | 10/27 [00:37<01:02,  3.70s/it][A
 41%|████      | 11/27 [00:41<00:58,  3.66s/it][A
 44%|████▍     | 12/27 [00:44<00:54,  3.64s/it][A
 48%|████▊     | 13/27 [00:48<00:50,  3.62s/it][A
 52%|█████▏    | 14/27 [00:51<00:46,  3.60s/it][A
 56%|█████▌    | 15/27 [00:55<00:43,  3.60s/it][A
 59%|█████▉    | 16/27 [00:58<00:38,  3.54s/it][A
 63%|██████▎   | 17/27 [01:02<00:34,  3.49s/it][A
 67%|██████▋   | 18/27 [01:05<00:31,  3.46s/it][A
 70%|███████   | 19/27 [01:09<00:27,  3.44s/it][A
 74%|███████▍  | 20/27 [01:12<00:23,  3

Sanity Check: All Passed!
Training...



  4%|▎         | 1/27 [19:50<8:35:44, 1190.18s/it][A