In [2]:
import argparse
import datetime
import functools
import os
import pprint
import random
import sys

import dateutil.tz
import numpy as np
import PIL
import PIL.Image
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

In [3]:
from storygen.config import cfg, cfg_from_file

In [4]:
!pip3 install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [5]:
# Seed every RNG this notebook imports so a Restart & Run All is reproducible.
# numpy's global RNG is seeded too, since it may be used by the data pipeline
# (NOTE(review): confirm whether storygen.pororo_data samples with np.random).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if cfg.CUDA:
    print('CUDA Flag enabled: ', cfg.CUDA)
    torch.cuda.manual_seed_all(SEED)

# `timestamp` is not used by the cells in this notebook; kept for any code that reads it.
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
# Run output directory, passed to the trainer in the training cell.
output_dir = './output/%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME)

# Number of GPUs = number of comma-separated ids in cfg.GPU_ID.
# The training loaders multiply their batch sizes by this value.
num_gpu = len(cfg.GPU_ID.split(','))
print("number of GPUs: ", num_gpu)

CUDA Flag enabled:  True
number of GPUs:  1


In [6]:
def video_transform(video, image_transform):
    """Apply `image_transform` to each frame of `video` and stack the results.

    Args:
        video: iterable of frames accepted by `image_transform`. With the
            pipeline built below, each frame must be an array that
            `PIL.Image.fromarray` accepts.
        image_transform: callable mapping one frame to a tensor.

    Returns:
        The per-frame tensors stacked on a new leading (frame) axis and then
        permuted with (1, 0, 2, 3), which moves the frame axis to dim 1.
        Assuming each transformed frame is (C, H, W), the result is (C, T, H, W).
    """
    frames = [image_transform(frame) for frame in video]
    return torch.stack(frames).permute(1, 0, 2, 3)


if cfg.TRAIN.FLAG:
    print('TRAIN FLAG ENABLED:', cfg.TRAIN.FLAG)

# Per-frame preprocessing: array -> PIL image -> resize to IMSIZE x IMSIZE ->
# tensor in [0, 1] -> Normalize(0.5, 0.5) per channel, i.e. mapped to [-1, 1].
# Built unconditionally because the validation loader also uses these
# transforms, not only training.
image_transforms = transforms.Compose([
    PIL.Image.fromarray,
    transforms.Resize((cfg.IMSIZE, cfg.IMSIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE(review): video_len and n_channels are not read by the cells shown here;
# kept in case other code depends on them.
video_len = 5
n_channels = 3

# Pre-bind image_transforms so callers pass only the video
# (presumably how StoryDataset invokes it — confirm).
video_transforms = functools.partial(video_transform, image_transform=image_transforms)

TRAIN FLAG ENABLED: True


In [7]:
import storygen.pororo_data as data
import storygen.train as gan_train

In [9]:
dir_path = "./pororo_data/"

# frames_counter.npy stores a single pickled object (presumably a dict) in a
# 0-d object array; .item() unwraps it.
# NOTE(review): allow_pickle=True lets the file run arbitrary code on load —
# only point this at trusted data.
counter_file = os.path.join(dir_path, 'frames_counter.npy')
counter = np.load(counter_file, allow_pickle=True).item()
print("The number of frames: ", len(counter))

# The same directory is passed as both the data folder and the cache location.
base = data.VideoFolderDataset(
    dir_path,
    counter=counter,
    cache=dir_path,
    min_len=4,
    mode="train",
)
# Presumably story-level (multi-frame) and single-image views over the same
# training clips — confirm in storygen.pororo_data.
storydataset = data.StoryDataset(base, dir_path, video_transforms)
imagedataset = data.ImageDataset(base, dir_path, image_transforms)

The number of frames:  183
Total number of clips 10191


In [19]:
## Train
# Effective batch size = configured batch size x number of GPUs.
num_workers = int(cfg.WORKERS)
imageloader = torch.utils.data.DataLoader(
    imagedataset,
    batch_size=cfg.TRAIN.IM_BATCH_SIZE * num_gpu,
    drop_last=True, shuffle=True, num_workers=num_workers)
print("imageloader length: ", len(imageloader))
storyloader = torch.utils.data.DataLoader(
    storydataset,
    batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu,
    drop_last=True, shuffle=True, num_workers=num_workers)
print("storyloader length: ", len(storyloader))

## Validation
VAL_BATCH_SIZE = 20
# Validation reads from the training directory; the split is presumably
# selected by mode="val" inside VideoFolderDataset — confirm.
val_dir_path = dir_path
# Keyword arguments mirror the training-set construction so the two calls
# cannot silently drift apart if the constructor's parameter order changes.
base_val = data.VideoFolderDataset(
    val_dir_path, counter=counter, cache=val_dir_path, min_len=4, mode="val")
valdataset = data.StoryDataset(base_val, val_dir_path, video_transforms)
# shuffle=False keeps validation batches in a fixed order.
# NOTE(review): drop_last=True discards up to VAL_BATCH_SIZE - 1 validation samples.
valloader = torch.utils.data.DataLoader(
    valdataset, batch_size=VAL_BATCH_SIZE,
    drop_last=True, shuffle=False, num_workers=num_workers)
print("Validation loader length: ", len(valloader))

imageloader length:  159
storyloader length:  159
Total number of clips 2320
Validation loader length:  116


In [13]:
# NOTE(review): `gan_train` comes from `import storygen.train as gan_train`.
# Calling it only works if that name resolves to a callable (e.g. the
# `storygen` package rebinds its `train` attribute to a trainer class or
# function). If it resolves to the plain submodule, this line raises
# TypeError: 'module' object is not callable — confirm.
algo = gan_train(cfg, output_dir, ratio=1.0)
training_stage = cfg.STAGE
algo.train(imageloader, storyloader, valloader, training_stage)

159