In [1]:
import argparse
import os
import sys
import numpy as np
import math
import time
import pickle
import cv2 as cv
import matplotlib
import matplotlib.pyplot as plt
import random
from cv2 import VideoWriter, VideoWriter_fourcc, imread

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.cuda.amp import autocast, GradScaler

import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision

import warnings

from ModelZoo import TemporalDiscriminator, StyleGanGenerator, StyleGanDiscriminator

BATCH_SIZE = 24

In [2]:
class PretrainModel(nn.Module):
    def __init__(self):
        super(PretrainModel, self).__init__()

        self.maxpool1 = nn.MaxPool2d(kernel_size=[3, 3], stride=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=[8, 8], stride=1, padding=0)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
        self.softmax = nn.Softmax(dim=2)
        self.relu = nn.ReLU()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=[7, 7], stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=[7, 7], stride=2, padding=3, bias=False)
        self.bn2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=[3, 3], stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=20000, kernel_size=[3, 3], stride=1, padding=1)

        self.inter1 = nn.Conv2d(in_channels=20000, out_channels=64, kernel_size=[3, 3], stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.inter2 = nn.Linear(64, 64)
        self.bn5 = nn.BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.inter3 = nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=[4, 4], padding=0)
        self.bn6 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.inter4 = nn.ConvTranspose2d(in_channels=128, out_channels=256, kernel_size=[4, 4], stride=2, padding=1)
        self.bn7 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        self.deconv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=[4, 4], stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.deconv2 = nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=[4, 4], stride=2, padding=1)
        self.bn9 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.deconv3 = nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=[4, 4], stride=2, padding=1)

    
    #@autocast()
    def forward(self, x):
        b = x.shape[0]

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.maxpool1(x)
        x = self.conv4(x)
        
        activation = x

        ## Inter Layer. will be discarded in real use, for bottlenecking reason only.
        x = self.inter1(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.inter2(x.view(b, 64))
        x = self.bn5(x)
        x = self.relu(x)
        x = self.inter3(x.view(b, 64, 1, 1))
        x = self.bn6(x)
        x = self.relu(x)
        x = self.upsample(x)
        x = self.inter4(x)
        x = self.bn7(x)
        x = self.upsample(x)

        ## Deconv Layers, will be finetuned with the reformer.
        ## Has input shape of (b, 256, 1024) ==> all the embeddings for one image
        ## first, reconstruct an image of size 16*16 from the embeddings, it will have 512 channels
        x = x.view(b, 16, 16, 1024)
        x = x.transpose(1, 3).transpose(2, 3)
        x = self.deconv1(x)
        x = self.bn8(x)
        x = self.deconv2(x)
        x = self.bn9(x)
        x = self.deconv3(x)

        return x, activation
    
    
def one_diff_loss(inp): # output shape per pixel: inp = (b, 16, 16, 20000)

    loss_function = nn.L1Loss()
    inp = inp.reshape(inp.shape[0], inp.shape[1]*inp.shape[2], inp.shape[3]) # to shape (b, 256, 20000)
    inp = F.softmax(inp, dim=2)
    
    print(torch.argmax(inp[0][0]), torch.max(inp[0][0]))
    x = torch.argmax(inp, dim=2).unsqueeze(2) #index of shape (b, 256, 1)
    
    onehot = torch.zeros_like(inp) # (b, 256, 20000)
    onehot.scatter_(dim=2, index=x, value=1, reduce='add')

    return loss_function(inp, onehot)


class Dataset(Dataset):

    def __init__(self, file_dir, transform=None):

        self.dir = file_dir
        self.transform = transform
        self.diction = {}
        
        idx = 0
        for filename in os.listdir(self.dir):
            if filename.endswith('jpg'):
                self.diction[idx] = filename
                idx += 1
                        
    def __len__(self):
        return len(self.diction)

    
    def __getitem__(self, idx):
        x = self.diction[idx]
        directory_x = self.dir + "\\" + str(x)
        x = cv.imread(directory_x) / 255
        if self.transform:
            x = self.transform(x)
        x = torch.Tensor(x)
        x = F.interpolate(x.unsqueeze(0), size=(128, 128)).squeeze(0)
        return x

    
def HWC2CHW(x):
    return np.array(x).transpose(2, 0, 1)

In [3]:
dataset = Dataset(file_dir=r"C:\Users\Leo's PC\Documents\SSTP Tests\SSTP\Vformer\test_frames", transform=HWC2CHW)
loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=0)

In [4]:
from vformer_model import Conv, CombinedEmbedding, DeConv
test_model = nn.Sequential(Conv(dic_size=16384),
                           CombinedEmbedding(positional_encoding=False),
                           DeConv(input_dim=512))

In [5]:
# Loss function
MSE_loss = nn.L1Loss()
L1_loss = nn.L1Loss().cuda()


# Initialize generator and discriminator
model = test_model.cuda()
# D = StyleGanDiscriminator()


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = torch.nn.DataParallel(model)
# D = torch.nn.DataParallel(D)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))

scaler = GradScaler()

training_log = {'iteration':[], 'loss':[]}

In [6]:
warnings.filterwarnings("ignore", category=UserWarning)

for epoch in range(300):
    
    start_time = time.time()
    total_loss = 0
    
    for i, imgs in enumerate(loader):

        imgs = Variable(imgs).cuda()

        optimizer.zero_grad()

        out, activation = model(imgs)
        # print(torch.argmax(activation[0][0]), torch.max(activation[0][0]))
 
        gt_loss = MSE_loss(out, imgs)
        #ad_loss = MSE_loss(D(F.interpolate(out, size=(256, 256))), D(F.interpolate(imgs, size=(256, 256))).detach()*2)
        ad_loss = L1_loss(D(F.interpolate(out, size=(256, 256))), torch.ones(BATCH_SIZE, 1).cuda()*0.3)
        # print(gt_loss * 0.1, ad_loss * 0.5, one_diff_loss(activation) * 0.8)
        loss = gt_loss * 0.2 + ad_loss * 0.3 + one_diff_loss(activation.transpose(1, 3)) * 0.5
        loss.backward()
        optimizer.step()
        
        total_loss += loss
        
    print('Epoch {:d} | {:.2f} minutes | loss: {:.6f}'.format(epoch, (time.time() - start_time) / 60, total_loss/len(loader)))
    
    torchvision.utils.save_image(out[0], os.path.join(r"C:/Users/Leo's PC/Documents/SSTP Tests/SSTP/Vformer/out_samples", str(epoch) + '.jpg'))
    
    if epoch%10 == 0:
        with open(r"C:/Users/Leo's PC/Documents/SSTP Tests/SSTP/Vformer/model_checkpoints/" + str(epoch), 'wb') as checkpoint_file:
            torch.save({'model': model.state_dict()}, checkpoint_file)
        
    training_log['iteration'].append(epoch)
    training_log['loss'].append(total_loss/len(loader))

tensor(19832, device='cuda:0') tensor(0.0009, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0018, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0265, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0035, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0074, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0028, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0342, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0147, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.7121, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0018, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.0023, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(19832, device='cuda:0') tensor(0.861

KeyboardInterrupt: 

In [6]:
warnings.filterwarnings("ignore", category=UserWarning)

for epoch in range(300):
    
    start_time = time.time()
    total_loss = 0
    
    for i, imgs in enumerate(loader):

        imgs = Variable(imgs.half()).cuda()

        optimizer.zero_grad()

        with autocast():
            out = model(imgs)
        print(out.dtype)
        print(out.shape)
        loss = MSE_loss(out, imgs)
        scaler.scale(loss).backward()
        scaler.step(optimizer) 
        scaler.update()
        
        total_loss += loss
        
    print('Epoch {:d} | {:.2f} minutes | loss: {:.6f}'.format(epoch, (time.time() - start_time) / 60, total_loss/len(loader)*BATCH_SIZE))
    
    torchvision.utils.save_image(out[0], os.path.join(r"C:/Users/Leo's PC/Documents/SSTP Tests/SSTP/Vformer/out_samples", str(epoch) + '.jpg'))
    
    if epoch%10 == 0:
        with open(r"C:/Users/Leo's PC/Documents/SSTP Tests/SSTP/Vformer/model_checkpoints/" + str(epoch), 'wb') as checkpoint_file:
            torch.save({'model': model.state_dict()}, checkpoint_file)
        
    training_log['iteration'].append(epoch)
    training_log['loss'].append(total_loss/len(loader))

AssertionError: Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\envs\pd\lib\site-packages\torch\nn\parallel\parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\pd\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\pd\lib\site-packages\torch\nn\modules\container.py", line 117, in forward
    input = module(input)
  File "C:\ProgramData\Anaconda3\envs\pd\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Leo's PC\Documents\SSTP Tests\SSTP\Vformer\vformer_model.py", line 54, in forward
    assert torch.max(x) == 1 and torch.min(x) == 0, "input has to be one-hot encoded. Got max {}, min {}".format(torch.max(x), torch.min(x))
AssertionError: input has to be one-hot encoded. Got max nan, min nan
