In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.io as io
from torch.optim import Adam, RMSprop
from torch.utils.data import DataLoader
from UCF101Loader import UCF101
from itertools import repeat

In [2]:
class LatentTemporalGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.dc0 = nn.ConvTranspose1d(100, 512, 1, 1, 0)
        self.dc1 = nn.ConvTranspose1d(512, 256, 4, 2, 1) 
        self.dc2 = nn.ConvTranspose1d(256, 128, 4, 2, 1)
        self.dc3 = nn.ConvTranspose1d(128, 128, 4, 2, 1)
        self.dc4 = nn.ConvTranspose1d(128, 100, 4, 2, 1)
        self.bn0 = nn.BatchNorm1d(512)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
    
    def forward(self, z0):
        h = z0.view(z0.size(0),-1, 1)
        h = F.relu(self.bn0(self.dc0(h)))
        h = F.relu(self.bn1(self.dc1(h)))
        h = F.relu(self.bn2(self.dc2(h)))
        h = F.relu(self.bn3(self.dc3(h)))
        zT = torch.tanh(self.dc4(h))
        return zT

In [3]:
class VideoGenerator(nn.Module):
    def __init__(self, conv_ch=512):
        super().__init__()
        self.ch = conv_ch
        self.latent_dim = 256 * 4*4
        
        self.zTGen = LatentTemporalGenerator()
        
        self.lz0 = nn.Linear(100, self.latent_dim)
        self.lzT = nn.Linear(100, self.latent_dim)
        self.dc1 = nn.ConvTranspose2d(conv_ch, conv_ch // 2, 4, 2, 1)
        self.dc2 = nn.ConvTranspose2d(conv_ch // 2, conv_ch // 4, 4, 2, 1)
        self.dc3 = nn.ConvTranspose2d(conv_ch // 4, conv_ch // 8, 4, 2, 1)
        self.dc4 = nn.ConvTranspose2d(conv_ch // 8, conv_ch // 16, 4, 2, 1)
        self.dc5 = nn.ConvTranspose2d(conv_ch // 16, 3, 3, 1, 1)
        
        self.bn0s = nn.BatchNorm1d(self.latent_dim)
        self.bn0f = nn.BatchNorm1d(self.latent_dim)
        self.bn1 = nn.BatchNorm2d(conv_ch // 2)
        self.bn2 = nn.BatchNorm2d(conv_ch // 4)
        self.bn3 = nn.BatchNorm2d(conv_ch // 8)
        self.bn4 = nn.BatchNorm2d(conv_ch // 16)
        
    def forward(self, z0):
        zT = self.zTGen(z0)
        B = z0.shape[0]
        #z0 [B, 100] -> [B, 1, 100] -> [B, 16, 100], [B*16, 100]
        z0 = z0.unsqueeze(1).repeat(1, 16, 1).contiguous().view(B*16, 100)
        #zT [B, 100, 16] -> [B, 16, 100] -> [B*16, 100]
        zT = zT.permute(0, 2, 1).contiguous().view(B*16, 100)
        
        n = z0.shape[0]
        h_z0 = (F.relu(self.bn0s(self.lz0(z0))).view(n, self.ch // 2, 4, 4))
        h_zT = (F.relu(self.bn0f(self.lzT(zT)))).view(n, self.ch // 2, 4, 4)
        h = torch.cat((h_z0, h_zT), 1)
        h = F.relu(self.bn1(self.dc1(h)))
        h = F.relu(self.bn2(self.dc2(h)))
        h = F.relu(self.bn3(self.dc3(h)))
        h = F.relu(self.bn4(self.dc4(h)))
        x = torch.tanh(self.dc5(h)).view(B, 16, 3, 64, 64)
        return x

In [4]:
class Discriminator(nn.Module):
    def __init__(self, sequence_first=True, mid_ch=64):
        super().__init__()
        self.sequence_first=True

        self.c0 = nn.Conv3d(3, mid_ch, 4, 2, 1)
        self.c1 = nn.Conv3d(mid_ch, mid_ch * 2, 4, 2, 1)
        self.c2 = nn.Conv3d(mid_ch * 2, mid_ch * 4, 4, 2, 1)
        self.c3 = nn.Conv3d(mid_ch * 4, mid_ch * 8, 4, 2, 1)
        self.bn0 = nn.BatchNorm3d(mid_ch)
        self.bn1 = nn.BatchNorm3d(mid_ch * 2)
        self.bn2 = nn.BatchNorm3d(mid_ch * 4)
        self.bn3 = nn.BatchNorm3d(mid_ch * 8)

    def forward(self, x):
        if self.sequence_first:
            x = x.permute(0, 2, 1, 3, 4)

        h = F.leaky_relu(self.c0(x))
        h = F.leaky_relu(self.bn1(self.c1(h)))
        h = F.leaky_relu(self.bn2(self.c2(h)))
        h = F.leaky_relu(self.bn3(self.c3(h)))
        h = h.view(h.size(0), -1)
        return torch.mean(h, 1)

In [5]:
def tensorToVideo(t, epoch, const=True):
    os.makedirs('Video {}/{}'.format('Constant' if const else 'Random', epoch), exist_ok=True)
    t = (t.permute(0, 1, 3, 4, 2)+1) * 255 / 2
    for i in range(t.shape[0]):
        io.write_video('Video {}/{}/{}.mp4'.format('Constant' if const else 'Random', epoch, i), t[i], 25)

In [11]:
epochs = 1000
batch_size = 64
n_disc = 5
n_steps = 1000
constant_noise = torch.randn([64, 100]).cuda()
c = 1


g = VideoGenerator().cuda()
d = Discriminator().cuda()
g_optim = RMSprop(g.parameters(), lr=5e-5)
d_optim = RMSprop(d.parameters(), lr=5e-5)
loader = DataLoader(UCF101(), batch_size=batch_size, shuffle=True, drop_last=True)

In [12]:
def data_gen(data_loader):
    for loader in repeat(data_loader):
        for data in loader:
            yield data

In [13]:
loader = data_gen(loader)

In [14]:
for i, (X, y) in enumerate(loader):
    i += 1
    print(i)
    z = torch.randn([64, 100]).cuda()
    pred_g = g(z)
    pred_dx = d(X.cuda())
    pred_dg = d(pred_g)
    loss_d = -torch.mean(pred_dx - pred_dg)
    loss_d.backward()

    d_optim.step()
    d_optim.zero_grad()
    g_optim.zero_grad()
    
    for p in d.parameters():
        p.data.clamp_(-c, c)
    
    print('Discriminator Loss: ', loss_d)
    
    if i % n_disc == 0:
        z = torch.randn([64, 100]).cuda()
        pred_g = g(z)
        with torch.no_grad():
            pred_dg = d(pred_g)
        loss_g = -torch.mean(pred_dg)
        loss_g.backward()
        
        g_optim.step()
        print('Generator Loss: ', loss_g)
        
        if i % 25 == 0:
            tensorToVideo(g(z), i, const=False)
            tensorToVideo(g(constant_noise), i)

1
Discriminator Loss:  tensor(0.0072, device='cuda:0', grad_fn=<NegBackward>)
2
Discriminator Loss:  tensor(-0.0003, device='cuda:0', grad_fn=<NegBackward>)
3
Discriminator Loss:  tensor(-0.0010, device='cuda:0', grad_fn=<NegBackward>)
4
Discriminator Loss:  tensor(-0.0016, device='cuda:0', grad_fn=<NegBackward>)
5
Discriminator Loss:  tensor(-0.0020, device='cuda:0', grad_fn=<NegBackward>)
Generator Loss:  tensor(-0.0185, device='cuda:0', grad_fn=<NegBackward>)
6
Discriminator Loss:  tensor(-0.0021, device='cuda:0', grad_fn=<NegBackward>)
7
Discriminator Loss:  tensor(-0.0024, device='cuda:0', grad_fn=<NegBackward>)
8
Discriminator Loss:  tensor(-0.0028, device='cuda:0', grad_fn=<NegBackward>)
9
Discriminator Loss:  tensor(-0.0030, device='cuda:0', grad_fn=<NegBackward>)
10
Discriminator Loss:  tensor(-0.0032, device='cuda:0', grad_fn=<NegBackward>)
Generator Loss:  tensor(-0.0184, device='cuda:0', grad_fn=<NegBackward>)
11
Discriminator Loss:  tensor(-0.0019, device='cuda:0', grad_fn

RuntimeError: The size of tensor a (8) must match the size of tensor b (64) at non-singleton dimension 0