In [2]:
# Launch train.py in a separate process via IPython's `!` shell escape,
# passing --model 4 and --checkpoints ./checkpoints on its command line.
!python train.py --model 4 --checkpoints ./checkpoints

^C


In [3]:
import argparse
import os
import random
from shutil import copyfile
import cv2
import torch
from torch.utils.data import DataLoader

from src.config import Config
from src.dataset import Dataset
from src.metrics import PSNR, EdgeAccuracy
from src.models import EdgeModel, InpaintingModel, RefineModel
from src.utils import getVGGModel, Progbar, stitch_images, create_dir

In [12]:
def load_config(mode=None, argv=None):
    """Load the run configuration from ``<path>/config.yml`` and apply mode overrides.

    Makes sure the checkpoints directory and its ``config.yml`` exist (copying
    ``./config.yml.example`` when the file is missing). It then loads the file
    with ``Config`` and overrides fields according to ``mode``.

    Args:
        mode: 1 = train, 2 = test, 3 = eval. Any other value (including None)
            returns the loaded config without mode overrides.
        argv: Option list to parse, e.g. ``['--path', './ckpt', '--model', '2']``.
            Defaults to an empty list, so the notebook kernel's own
            ``sys.argv`` is never parsed and every option takes its default.

    Returns:
        The ``Config`` instance with MODE/MODEL set (and, in test mode,
        INPUT_SIZE plus any given input/mask/edge/output paths).
    """
    parser = argparse.ArgumentParser()
    # --checkpoints is an alias of --path, so either spelling selects the
    # checkpoints directory.
    parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints')
    # Always has a value (default 4), so no None/falsy fallback is needed below.
    parser.add_argument('--model', type=int, choices=[1, 2, 3, 4], default=4)

    # test mode
    if mode == 2:
        parser.add_argument('--input', type=str, help='path to the input images directory or an input image')
        parser.add_argument('--mask', type=str, help='path to the masks directory or a mask file')
        parser.add_argument('--edge', type=str, help='path to the edges directory or an edge file')
        parser.add_argument('--output', type=str, help='path to the output directory')

    args = parser.parse_args([] if argv is None else argv)
    config_path = os.path.join(args.path, 'config.yml')
    print(config_path)

    # Create the checkpoints directory if it doesn't exist (no-op otherwise).
    os.makedirs(args.path, exist_ok=True)

    # Copy the config template if there is no config.yml yet.
    if not os.path.exists(config_path):
        copyfile('./config.yml.example', config_path)

    configmain = Config(config_path)

    if mode == 1:  # train mode
        configmain.MODE = 1
        configmain.MODEL = args.model

    elif mode == 2:  # test mode
        configmain.MODE = 2
        configmain.MODEL = args.model
        configmain.INPUT_SIZE = 0

        # Only override the paths that were actually supplied.
        overrides = (
            ('TEST_FLIST', args.input),
            ('TEST_MASK_FLIST', args.mask),
            ('TEST_EDGE_FLIST', args.edge),
            ('RESULTS', args.output),
        )
        for field, value in overrides:
            if value is not None:
                setattr(configmain, field, value)

    elif mode == 3:  # eval mode
        configmain.MODE = 3
        configmain.MODEL = args.model

    return configmain


# Load the training-mode configuration (mode 1) from ./checkpoints/config.yml.
config = load_config(mode=1)

./checkpoints\config.yml


In [13]:
# Dataset file lists (.flist). Plain `=` assignments are required here: the
# YAML-style form `NAME: 'value'` is only a Python annotation and binds nothing.
TRAIN_FLIST = './checkpoints/datasets/train.flist'
VAL_FLIST = './checkpoints/datasets/val.flist'
TEST_FLIST = './checkpoints/datasets/test.flist'

TRAIN_EDGE_FLIST = './checkpoints/datasets/places2_edges_train.flist'
VAL_EDGE_FLIST = './checkpoints/datasets/places2_edges_val.flist'
TEST_EDGE_FLIST = './checkpoints/datasets/places2_edges_test.flist'

# NOTE(review): all three mask lists point at masks_train.flist. Confirm that
# validation/test are meant to reuse the training masks.
TRAIN_MASK_FLIST = './checkpoints/datasets/masks_train.flist'
VAL_MASK_FLIST = './checkpoints/datasets/masks_train.flist'
TEST_MASK_FLIST = './checkpoints/datasets/masks_train.flist'

# The Dataset(...) calls in the next cell read these paths from `config`, not
# from the names above, so copy them onto the loaded config object.
config.TRAIN_FLIST = TRAIN_FLIST
config.VAL_FLIST = VAL_FLIST
config.TEST_FLIST = TEST_FLIST
config.TRAIN_EDGE_FLIST = TRAIN_EDGE_FLIST
config.VAL_EDGE_FLIST = VAL_EDGE_FLIST
config.TEST_EDGE_FLIST = TEST_EDGE_FLIST
config.TRAIN_MASK_FLIST = TRAIN_MASK_FLIST
config.VAL_MASK_FLIST = VAL_MASK_FLIST
config.TEST_MASK_FLIST = TEST_MASK_FLIST

In [14]:
# Training split: augmented, and wrapped in a shuffled DataLoader below.
train_dataset = Dataset(
    config,
    config.TRAIN_FLIST,
    config.TRAIN_EDGE_FLIST,
    config.TRAIN_MASK_FLIST,
    augment=True,
    training=True,
)

# Validation split: no augmentation; used to draw SAMPLE_SIZE-sized samples.
val_dataset = Dataset(
    config,
    config.VAL_FLIST,
    config.VAL_EDGE_FLIST,
    config.VAL_MASK_FLIST,
    augment=False,
    training=True,
)
sample_iterator = val_dataset.create_iterator(config.SAMPLE_SIZE)

# drop_last=True discards a final partial batch, so every batch is BATCH_SIZE long.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

In [None]:
# Sanity-check the loader on a single batch instead of iterating (and printing
# every raw batch of) a full epoch.
first_batch = next(iter(train_loader))

# NOTE(review): assumes a batch is a list/tuple of tensors (DataLoader's default
# collation of a tuple-returning Dataset). Confirm against src.dataset.Dataset.
if isinstance(first_batch, (list, tuple)):
    for index, item in enumerate(first_batch):
        print(index, getattr(item, 'shape', type(item)))
else:
    print(getattr(first_batch, 'shape', type(first_batch)))