# Training of MoCoGAN
---
This notebook covers training the MoCoGAN network and fine-tuning it from a previously saved state.

## Defining global variables and imports
---
The cell below contains all of the imports needed for training and defines some global variables, such as the `cuda` flag that controls whether the GPU is used.

In [1]:
import os
import sys
import time
import math
import torch
import skvideo.io
import numpy as np

from glob import glob
from torch.autograd import Variable
from torch import nn, optim, device, manual_seed

from torch.utils.data import DataLoader
from torchvision.transforms import Lambda, Compose

## Try to avoid problems with dataloader.
"""
import torch.multiprocessing as mp

try:
    mp.set_start_method('forkserver')
    #mp.set_start_method('spawn')
except RuntimeError as errs:
    print(errs)
"""
##########################################

# Hyperparameters carried over from train.py.
img_size = 96        # height and width of the frames fed to the networks
nc = 3               # number of image channels
ndf = 64             # from dcgan
ngf = 64
d_E = 10             # size of the noise vector fed to the GRU (see gen_z)
hidden_size = 100    # guess
d_C = 50             # size of the per-video noise vector repeated on every frame (see gen_z)
d_M = d_E
nz = d_C + d_M       # size of the concatenated latent vector given to the generator
criterion = nn.BCEWithLogitsLoss()
categoriesCriterion = nn.CrossEntropyLoss()

T = 16  # Hyperparameter for taking #Frames into discriminator.

ngpu = 1
batch_size = 16
n_iter = 120000
pre_train = False

## Addition for training on UCF-101
n_epochs_saveV = 1
n_epochs_display = 1
n_epochs_check = 1
max_frame = 25
# Use the GPU only when one is actually present, so the notebook also runs
# (slowly) on a CPU-only machine instead of failing on the first .cuda() call.
cuda = torch.cuda.is_available()
#### End of additions


# Seed both RNGs used by the notebook (torch and numpy).
seed = 0
manual_seed(seed)
np.random.seed(seed)

## Import of Models
From the `models` module, let's import all of the models and create them; their previous state can be loaded later for fine-tuning.

Then the models are moved to the device chosen in the cell above.

In [2]:
sys.path.append("./mocogan/")
from models import Discriminator_I, Discriminator_V, Generator_I ,GRU, UCF_101

In [3]:
# Create the objects for Discriminator_(I|V), GRU and Generator_I.
# The `ngpu` constant from the configuration cell is used instead of a repeated literal.
gen_i = Generator_I(nc, ngf, nz, ngpu=ngpu, batch_size=batch_size)
gru = GRU(d_E, hidden_size, gpu=cuda)
dis_i = Discriminator_I(nc, ndf, ngpu=ngpu)
dis_v = Discriminator_V(nc, ndf, T=T, ngpu=ngpu)
gru.initWeight()

# Move the networks and both loss modules to the GPU when CUDA is enabled.
if cuda:
    for _module in (dis_i, dis_v, gen_i, gru, criterion, categoriesCriterion):
        _module.cuda()

# Optimizer settings: one Adam optimizer per network, all sharing the same hyperparameters.
lr = 0.0002
betas = (0.5, 0.999)
weight_decay = 0.00001
optim_Di = optim.Adam(dis_i.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
optim_Dv = optim.Adam(dis_v.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
optim_Gi = optim.Adam(gen_i.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
optim_GRU = optim.Adam(gru.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

## Dataloader
---
In the cells below, the dataloader is defined together with all of the transformations that are applied to the videos before they are collected into a batch.
Since torchvision does not provide transformations for videos, a repository called `torch_videovision` has been forked and modified to support transformations on numpy arrays.

In [4]:
from torch_videovision.videotransforms.video_transforms import ColorJitter, RandomTemporalCrop, RandomHorizontalFlip, RandomCrop, Resize
from torch_videovision.videotransforms.volume_transforms import ToTensor, ClipToTensor, TransposeChannels
from torch_videovision.videotransforms.tensor_transforms import Normalize, SpatialRandomCrop

  data = yaml.load(f.read()) or {}


In [5]:
# Resolve every path relative to the notebook's working directory.
# os.getcwd() replaces the IPython-only `!pwd` shell escape so the cell is plain,
# portable Python.
current_path = os.getcwd()
resized_path = os.path.join(current_path, "mocogan", 'resized_data')
files = glob(os.path.join(resized_path, "*", "*"))

dict_dir = os.path.join(current_path, "mocogan", "ucfTrainTestlist", "classInd.txt")
raw_path = os.path.join(current_path, "mocogan", "raw_data")
raw_files = glob(os.path.join(raw_path, "*", "*"))

# Normalisation constants; save_video() uses the same two values to undo the normalisation.
stdDev = 0.5
medium = 0.5

# Per-video augmentation pipeline, applied in the order listed.
transformation = Compose([
    TransposeChannels(),
    RandomTemporalCrop(max_frame),
    SpatialRandomCrop((img_size * 2, img_size * 3)),
    RandomHorizontalFlip(),
    TransposeChannels(reverse=True),
    Resize((img_size, img_size), interpolation='bilinear'),
    ColorJitter(brightness=0.2, contrast=0.7, saturation=0.5, hue=0.1),  # Max Values for Color Jitter
    ClipToTensor(div_255=True),
    Normalize(medium, stdDev),
])

dataset = UCF_101(raw_path, dict_dir, supportedExtensions=["avi"], transform=transformation)
# drop_last=True keeps every batch at exactly batch_size samples, which the
# fixed-size .view() calls in the training loop rely on.
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=8, shuffle=True,
                        pin_memory=True, drop_last=True)

## Training
---

In the cells below, some utility functions are defined first (noise generation for the GAN, backpropagation helpers, checkpointing and video saving), then the training loop itself.

In [6]:
'''Utilities'''
def trim(video):
    """Return a random window of ``T`` consecutive entries along axis 1 of ``video``.

    Args:
        video: 4-D array-like exposing ``.shape``; axis 1 is the one that is cropped.

    Returns:
        ``video[:, start:start + T, :, :]`` for a uniformly drawn ``start``.

    Raises:
        ValueError: if axis 1 holds fewer than ``T`` entries.
    """
    n_frames = video.shape[1]
    if n_frames < T:
        raise ValueError(f"cannot trim {n_frames} frames down to T={T}")
    # Valid start indices are 0 .. n_frames - T inclusive; randint's upper bound
    # is exclusive, hence the +1. (The previous bound, n_frames - (T + 1), raised
    # for inputs of T or T + 1 frames and never picked the last two windows.)
    start = np.random.randint(0, n_frames - T + 1)
    return video[:, start:start + T, :, :]

# for input noises to generate fake video
# note that noises are trimmed randomly from n_frames to T for efficiency
def trim_noise(noise):
    """Return a random window of ``T`` consecutive steps along dim 1 of ``noise``.

    Args:
        noise: 5-D tensor; dim 1 is the one that is cropped.

    Returns:
        ``noise[:, start:start + T, :, :, :]`` for a uniformly drawn ``start``.

    Raises:
        ValueError: if dim 1 holds fewer than ``T`` steps.
    """
    n_frames = noise.size(1)
    if n_frames < T:
        raise ValueError(f"cannot trim {n_frames} noise steps down to T={T}")
    # Valid start indices are 0 .. n_frames - T inclusive; randint's upper bound
    # is exclusive, hence the +1. (The previous bound, n_frames - (T + 1), raised
    # for inputs of T or T + 1 steps and never picked the last two windows.)
    start = np.random.randint(0, n_frames - T + 1)
    return noise[:, start:start + T, :, :, :]


''' calc grad of models '''

def bp_i(inputs, y, retain=False):
    """Back-propagate the image-discriminator loss for one batch.

    Runs ``dis_i`` on ``inputs``, compares its output against a constant target
    ``y`` with ``criterion`` and calls ``backward()`` on the resulting loss.
    No optimizer step is taken here.

    Args:
        inputs: batch passed to ``dis_i``; its first dimension is the batch size.
        y: scalar target value used for every sample in the batch.
        retain: forwarded to ``backward(retain_graph=...)``.

    Returns:
        Tuple ``(loss, mean_output)``: the loss as a Python float and the mean
        of the detached discriminator outputs as a 0-dim tensor.
    """
    target_device = torch.device("cuda" if cuda else "cpu")
    # One float32 target per sample, created directly on the training device.
    label = torch.full((inputs.size(0),), float(y), dtype=torch.float32, device=target_device)

    outputs = dis_i(inputs)
    err = criterion(outputs, label)
    err.backward(retain_graph=retain)

    # The previous `err.data[0]` alternative could never be selected (the size
    # comparison guarding it was never equal), so .item() is used directly.
    return err.item(), outputs.detach().mean()

def bp_v(inputs, labels, y, retain=False):
    """Back-propagate the video-discriminator loss for one batch.

    Runs ``dis_v`` on ``inputs``; the loss is ``criterion`` on the first output
    against the target ``y`` plus ``categoriesCriterion`` on the second output
    against ``labels``. Calls ``backward()`` on the sum; no optimizer step is
    taken here.

    Args:
        inputs: batch passed to ``dis_v``; its first dimension is the batch size.
        labels: targets handed to ``categoriesCriterion``.
        y: either a scalar target applied to every sample, or a sequence/tensor
            with one target per sample.
        retain: forwarded to ``backward(retain_graph=...)``.

    Returns:
        Tuple ``(loss, mean_output)``: the summed loss as a Python float and the
        mean of the detached first discriminator output as a 0-dim tensor.

    Raises:
        ValueError: if a per-sample ``y`` does not have one entry per input.
    """
    batch = inputs.size(0)
    target_device = torch.device("cuda" if cuda else "cpu")

    # Decide explicitly between a scalar and a per-sample target instead of
    # relying on fill_() raising RuntimeError.
    per_sample = (y.dim() > 0) if torch.is_tensor(y) else (np.ndim(y) > 0)
    if per_sample:
        # Honour the `cuda` flag here too (this path used to call .cuda() unconditionally).
        label = torch.as_tensor(y, dtype=torch.float32, device=target_device)
        if label.size(0) != batch:
            raise ValueError(
                f"expected {batch} per-sample targets, got {label.size(0)}"
            )
    else:
        label = torch.full((batch,), float(y), dtype=torch.float32, device=target_device)

    outputs, categories = dis_v(inputs)

    # Real/fake loss plus the category loss.
    err = criterion(outputs, label) + categoriesCriterion(categories, labels)
    err.backward(retain_graph=retain)

    # The previous `err.data[0]` alternative could never be selected, so .item()
    # is used directly.
    return err.item(), outputs.detach().mean()


''' gen input noise for fake video '''

def gen_z(n_frames, batch_size = batch_size):
    """Build the latent input for ``n_frames`` frames of ``batch_size`` fake videos.

    A noise vector of size ``d_C`` is drawn once per video and repeated on every
    frame; a second noise vector of size ``d_E`` is fed to ``gru`` to obtain one
    vector per frame. Both parts are concatenated along the last dimension.

    Returns:
        Tensor viewed as ``(batch_size, n_frames, nz, 1, 1)``.
    """
    # Draw order matters for reproducibility: per-video noise first, GRU noise second.
    content = Variable(torch.randn(batch_size, d_C))
    # (batch_size, d_C) -> (batch_size, n_frames, d_C)
    content = content.unsqueeze(1).repeat(1, n_frames, 1)
    motion_seed = Variable(torch.randn(batch_size, d_E))

    if cuda:
        content = content.cuda()
        motion_seed = motion_seed.cuda()

    gru.initHidden(batch_size)
    # gru output is (seq_len, batch_size, ...) per the original author's note;
    # swap the first two dims so the batch comes first.
    motion = gru(motion_seed, n_frames).transpose(1, 0)

    latent = torch.cat((motion, content), 2)  # => (batch_size, n_frames, nz)
    return latent.view(batch_size, n_frames, nz, 1, 1)

''' prepare for train '''
def timeSince(since):
    """Format the wall-clock time elapsed since ``since`` as ``'<d>d <h>h <m>m <s>s'``.

    Args:
        since: start timestamp in seconds, as returned by ``time.time()``.

    Returns:
        String with whole days, hours, minutes and (truncated) seconds.
    """
    minute, hour = 60, 60**2
    day = 24 * hour

    elapsed = time.time() - since
    days = math.floor(elapsed / day)
    # Each unit is the total count of that unit minus what the larger units already cover.
    hours = math.floor(elapsed / hour) - days*24
    minutes = math.floor(elapsed / minute) - hours*60 - days*24*60
    seconds = elapsed - minutes*minute - hours*hour - days*day
    return '%dd %dh %dm %ds' % (days, hours, minutes, seconds)

# Directory that receives the model and optimizer checkpoints.
trained_path = os.path.join(current_path, "mocogan", 'trained_models')
def checkpoint(model, optimizer, epoch):
    """Save ``model`` and ``optimizer`` state dicts for ``epoch`` under ``trained_path``.

    Files are named ``<ModelClassName>_epoch-<epoch>.model`` and
    ``<ModelClassName>_epoch-<epoch>.state``.

    Args:
        model: network whose ``state_dict()`` is written to the ``.model`` file.
        optimizer: optimizer whose ``state_dict()`` is written to the ``.state`` file.
        epoch: integer used in the file name.
    """
    # Create the target directory on demand so a fresh checkout does not lose a
    # whole epoch of training to a missing folder.
    os.makedirs(trained_path, exist_ok=True)
    filename = os.path.join(trained_path, '%s_epoch-%d' % (model.__class__.__name__, epoch))
    torch.save(model.state_dict(), filename + '.model')
    torch.save(optimizer.state_dict(), filename + '.state')

def save_video(fake_video, epoch):
    """Write one generated clip to ``mocogan/generated_videos/fakeVideo_epoch-<epoch>.mp4``.

    Args:
        fake_video: numpy array holding the normalised clip, in the layout
            expected by ``skvideo.io.vwrite``.
        epoch: integer used in the file name.
    """
    # Undo the normalisation and rescale to the 0-255 range.
    outputdata = (fake_video * stdDev + medium) * 255
    # Clip before the cast: out-of-range values would otherwise wrap around in uint8.
    outputdata = np.clip(outputdata, 0, 255).astype(np.uint8)
    dir_path = os.path.join(current_path, 'mocogan', 'generated_videos')
    # Make sure the output directory exists before the writer is invoked.
    os.makedirs(dir_path, exist_ok=True)
    file_path = os.path.join(dir_path, 'fakeVideo_epoch-%d.mp4' % (epoch))
    skvideo.io.vwrite(file_path, outputdata)


In [7]:
''' train models '''
def _checkpoint_all(epoch):
    """Checkpoint both discriminators, the generator and the GRU for ``epoch``."""
    for model, optimizer in ((dis_i, optim_Di), (dis_v, optim_Dv),
                             (gen_i, optim_Gi), (gru, optim_GRU)):
        checkpoint(model, optimizer, epoch)


def train():
    """Run the adversarial training loop for ``n_iter`` passes over ``dataloader``.

    For every batch the video and image discriminators are updated first, then
    the generator and GRU gradients are computed. At the end of each epoch the
    losses are printed, a sample video is written and checkpoints are saved,
    according to the ``n_epochs_*`` settings.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
        KeyboardInterrupt: re-raised after a sample video (when one exists) and
            all checkpoints have been written, so an interrupted run actually
            stops instead of silently continuing.
    """
    start_time = time.time()

    print(f"Starting training: CUDA is { 'On' if cuda == True else 'Off'}")

    # Loop-invariant settings.
    updateEvery_Generator_I_GRU = 2
    fake_label = 0
    data_len = len(dataloader)  # number of batches per epoch
    if data_len == 0:
        raise RuntimeError("dataloader yields no batches; nothing to train on")

    fake_videos = None  # most recent generated batch; None until the first batch has run

    for epoch in range(1, n_iter+1):
        try:
            for data_i, (real_videos, labels) in enumerate(dataloader):
                print(f"\r--------Batch {data_i}/{data_len}---------", end = "")

                # ---- prepare real images ----
                # Frame axis of real_videos is dim 2: (batch_size, nc, frames, img_size, img_size).
                if cuda:
                    real_videos = real_videos.cuda()
                    labels = labels.cuda()

                real_img = real_videos[:, :, np.random.randint(0, T), :, :]

                # ---- prepare fake images ----
                # note that n_frames is sampled from video length distribution
                if (len(dataset.videoLengths) > 0):
                    randomVideo = list(dataset.videoLengths)[np.random.randint(0, len(dataset.videoLengths))]
                    n_frames = dataset.videoLengths[randomVideo]
                else:  # Use this for first iterations, when dataset.videoLengths is not yet updated.
                    n_frames = T + 2 + np.random.randint(0, real_videos.size()[2])
                # A recorded length shorter than T + 2 would make trim_noise() fail.
                n_frames = max(n_frames, T + 2)

                Z = gen_z(n_frames, batch_size)  # Z.size() => (batch_size, n_frames, nz, 1, 1)
                # trim => (batch_size, T, nz, 1, 1)
                Z = trim_noise(Z)
                # generate videos
                Z = Z.contiguous().view(batch_size*T, nz, 1, 1)

                fake_videos = gen_i(Z, labels)
                fake_videos = fake_videos.view(batch_size, T, nc, img_size, img_size)
                # transpose => (batch_size, nc, T, img_size, img_size)
                fake_videos = fake_videos.transpose(2, 1)
                # img sampling
                fake_img = fake_videos[:, :, np.random.randint(0, T), :, :]

                # ---- train discriminators ----
                # video
                dis_v.zero_grad()
                randomStartFrameIdx = np.random.randint(0, real_videos.size()[2] - T - 1)

                croppedRealVideos = real_videos[:, :, randomStartFrameIdx: randomStartFrameIdx + T, :, :]
                # 0.9 rather than 1.0 is used as the target for real samples.
                err_Dv_real, Dv_real_mean = bp_v(croppedRealVideos, labels, 0.9)
                err_Dv_fake, Dv_fake_mean = bp_v(fake_videos.detach(), labels, fake_label)
                err_Dv = err_Dv_real + err_Dv_fake
                optim_Dv.step()
                # image
                dis_i.zero_grad()
                err_Di_real, Di_real_mean = bp_i(real_img, 0.9)
                err_Di_fake, Di_fake_mean = bp_i(fake_img.detach(), fake_label)
                err_Di = err_Di_real + err_Di_fake
                optim_Di.step()

                # ---- train generators ----
                gen_i.zero_grad()
                gru.zero_grad()
                # video. notice retain=True for back prop twice
                err_Gv, _ = bp_v(fake_videos, labels, 0.9, retain=True)
                # images
                err_Gi, _ = bp_i(fake_img, 0.9)

                # NOTE(review): this gate uses the epoch number, so the generator and
                # GRU are only stepped during even epochs (never during odd ones).
                # If "every second batch" was intended, gate on data_i instead — confirm.
                if epoch % updateEvery_Generator_I_GRU == 0:
                    optim_Gi.step()
                    optim_GRU.step()

                # Cool down the Hardware
                time.sleep(1/2)

        except KeyboardInterrupt:
            # Persist what we have, then let the interrupt stop the run.
            if fake_videos is not None:
                save_video(fake_videos[0].data.cpu().numpy().transpose(1, 2, 3, 0), epoch)
            _checkpoint_all(epoch)
            raise

        if epoch % n_epochs_display == 0:
            print('[%d/%d] (%s) Loss_Di: %.4f Loss_Dv: %.4f Loss_Gi: %.4f Loss_Gv: %.4f Di_real_mean %.4f Di_fake_mean %.4f Dv_real_mean %.4f Dv_fake_mean %.4f'
                  % (epoch, n_iter, timeSince(start_time), err_Di, err_Dv, err_Gi, err_Gv, Di_real_mean, Di_fake_mean, Dv_real_mean, Dv_fake_mean))

        if epoch % n_epochs_saveV == 0:
            save_video(fake_videos[0].data.cpu().numpy().transpose(1, 2, 3, 0), epoch)

        if epoch % n_epochs_check == 0:
            _checkpoint_all(epoch)

        # Cool down the Hardware
        time.sleep(5)
          

## Loading Previous State
---
If desired, the following cell can be used to load the previously saved state of a trained model.

In [8]:
''' use pre-trained models '''

def load(loadEpoch=1):
    """Restore all four networks and their optimizers from ``trained_path``.

    Args:
        loadEpoch: epoch number embedded in the checkpoint file names
            (``<Name>_epoch-<loadEpoch>.model`` / ``.state``). Pass ``None`` to
            load files without the ``_epoch-N`` suffix. Defaults to 1, the value
            that used to be hard-coded.
    """
    addString = f"_epoch-{loadEpoch}" if loadEpoch is not None else ""
    # Map stored tensors onto the device this run uses, so checkpoints written on
    # a GPU can also be restored on a CPU-only machine.
    target_device = torch.device("cuda" if cuda else "cpu")

    targets = (
        ("Discriminator_I", dis_i, optim_Di),
        ("Discriminator_V", dis_v, optim_Dv),
        ("Generator_I", gen_i, optim_Gi),
        ("GRU", gru, optim_GRU),
    )

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    # Models are restored first, then the optimizers, as before.
    for name, model, _ in targets:
        model_file = os.path.join(trained_path, f"{name}{addString}.model")
        model.load_state_dict(torch.load(model_file, map_location=target_device))

    for name, _, optimizer in targets:
        state_file = os.path.join(trained_path, f"{name}{addString}.state")
        optimizer.load_state_dict(torch.load(state_file, map_location=target_device))

## Finally, start training


In [9]:
# Restore the saved model/optimizer state, then start (resume) training.
# NOTE(review): `pre_train` is defined in the configuration cell but is not consulted
# here; load() runs unconditionally, so this cell presumably fails when no checkpoint
# files exist yet — confirm whether load() should be gated on `pre_train`.
load()
train() 

Starting training: CUDA is On
--------Batch 831/832---------[1/120000] (0d 0h 17m 36s) Loss_Di: 1.1064 Loss_Dv: 5.3351 Loss_Gi: 0.6931 Loss_Gv: 10.9461 Di_real_mean 1.0000 Di_fake_mean 0.0000 Dv_real_mean 2.2075 Dv_fake_mean -7.1006
--------Batch 831/832---------[2/120000] (0d 0h 35m 21s) Loss_Di: 1.1064 Loss_Dv: 5.9322 Loss_Gi: 0.6931 Loss_Gv: 8.1312 Di_real_mean 1.0000 Di_fake_mean 0.0000 Dv_real_mean 1.4950 Dv_fake_mean -3.8291
--------Batch 831/832---------[3/120000] (0d 0h 53m 3s) Loss_Di: 1.1064 Loss_Dv: 7.1383 Loss_Gi: 0.6931 Loss_Gv: 10.7796 Di_real_mean 1.0000 Di_fake_mean 0.0000 Dv_real_mean 2.2963 Dv_fake_mean -6.8874
--------Batch 831/832---------[4/120000] (0d 1h 10m 49s) Loss_Di: 1.1070 Loss_Dv: 6.8129 Loss_Gi: 0.6930 Loss_Gv: 6.6681 Di_real_mean 1.0000 Di_fake_mean 0.0012 Dv_real_mean 2.5522 Dv_fake_mean -0.8968
--------Batch 831/832---------[5/120000] (0d 1h 28m 33s) Loss_Di: 1.1064 Loss_Dv: 5.9951 Loss_Gi: 0.6931 Loss_Gv: 10.6014 Di_real_mean 1.0000 Di_fake_mean 0.0000

RuntimeError: DataLoader worker (pid(s) 20854, 20855, 20856, 20857, 20858, 20859, 20860, 20861) exited unexpectedly