In [2]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

In [3]:
class Discriminator(nn.Module):
    """Conditional 3D-convolutional discriminator.

    The video (3 channels) and the conditioning volume (6 channels) each go
    through their own strided ``Conv3d`` stem producing 32 channels. The two
    results are concatenated along the channel axis and reduced by a stack of
    ``Conv3d``/``BatchNorm3d``/``LeakyReLU`` stages to 2 raw output channels
    (no final activation is applied here).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Geometry shared by every strided convolution below; with
        # kernel 4 / stride 2 / padding 1 an even-sized axis is halved.
        halving = dict(kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))

        # Separate stems for the two inputs.
        self.layer_video = nn.Conv3d(in_channels=3, out_channels=32, **halving)
        self.layer_y = nn.Conv3d(in_channels=6, out_channels=32, **halving)

        # Trunk applied to the 64-channel concatenation of both stems.
        stages = [nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            stages += [
                nn.Conv3d(in_channels=in_ch, out_channels=out_ch, **halving),
                nn.BatchNorm3d(num_features=out_ch, eps=1e-03),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]

        # Un-padded (2, 4, 4) kernel: collapses a 2x4x4 volume to one position.
        stages.append(
            nn.Conv3d(in_channels=512, out_channels=2, kernel_size=(2, 4, 4), stride=(1, 1, 1), padding=(0, 0, 0))
        )

        self.discriminator = nn.Sequential(*stages)

    def forward(self, video, y):
        """Return the raw 2-channel score for ``video`` conditioned on ``y``.

        ``video`` needs 3 channels and ``y`` 6 channels on dim 1; their
        remaining dimensions must agree so the stem outputs can be
        concatenated. A 32x64x64 volume yields a ``(batch, 2, 1, 1, 1)`` result.
        """
        joint = torch.cat([self.layer_video(video), self.layer_y(y)], 1)
        return self.discriminator(joint)

In [4]:
class Generator(nn.Module):
    """Two-stream conditional video generator.

    A 100-dimensional latent code and a 6-dimensional condition vector are
    each projected by transposed convolutions and concatenated, once as a 3D
    volume (moving foreground stream) and once as a 2D map (static background
    stream). The foreground stream produces an RGB video plus a one-channel
    sigmoid mask; the background stream produces a single RGB image that is
    repeated over all 32 frames. The result is the mask-weighted blend of both.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Projections of the latent code (100 channels) for each stream.
        self.layer_3d_video = nn.ConvTranspose3d(in_channels=100, out_channels=256, kernel_size=(2, 4, 4))
        self.layer_2d_video = nn.ConvTranspose2d(in_channels=100, out_channels=256, kernel_size=4, stride=1, padding=0)

        # Projections of the condition vector (6 channels) for each stream.
        self.layer_3d_y = nn.ConvTranspose3d(in_channels=6, out_channels=256, kernel_size=(2, 4, 4))
        self.layer_2d_y = nn.ConvTranspose2d(in_channels=6, out_channels=256, kernel_size=4, stride=1, padding=0)

        # Geometry shared by all upsampling layers (each doubles its axes).
        up3d = dict(kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        up2d = dict(kernel_size=4, stride=2, padding=1)
        widths = ((512, 256), (256, 128), (128, 64))

        # Foreground trunk shared by the colour head and the mask head.
        trunk_3d = [nn.BatchNorm3d(num_features=512), nn.ReLU(inplace=True)]
        for in_ch, out_ch in widths:
            trunk_3d += [
                nn.ConvTranspose3d(in_channels=in_ch, out_channels=out_ch, **up3d),
                nn.BatchNorm3d(num_features=out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net_video = nn.Sequential(*trunk_3d)

        # Colour head: 3 channels squashed to [-1, 1].
        self.gen_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=3, **up3d),
            nn.Tanh(),
        )

        # Mask head: 1 channel squashed to [0, 1].
        self.mask_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=1, **up3d),
            nn.Sigmoid(),
        )

        # Background stream, the 2D counterpart of trunk + colour head.
        trunk_2d = [nn.BatchNorm2d(num_features=512), nn.ReLU(inplace=True)]
        for in_ch, out_ch in widths:
            trunk_2d += [
                nn.ConvTranspose2d(in_ch, out_ch, **up2d),
                nn.BatchNorm2d(num_features=out_ch),
                nn.ReLU(inplace=True),
            ]
        trunk_2d += [nn.ConvTranspose2d(64, 3, **up2d), nn.Tanh()]
        self.static_net = nn.Sequential(*trunk_2d)

    def forward(self, z, y):
        """Generate a ``(batch, 3, 32, 64, 64)`` video.

        ``z`` is reshaped to ``(-1, 100, ...)`` and ``y`` to ``(-1, 6, ...)``,
        so they must hold 100 and 6 values per sample respectively. The batch
        size is taken from the first dimension of ``z``.
        """
        n_samples = z.size(0)
        video_shape = (n_samples, 3, 32, 64, 64)

        # Seed volume (3D) and seed map (2D): latent and condition projections
        # concatenated along the channel axis.
        seed_3d = torch.cat(
            [self.layer_3d_video(z.view(-1, 100, 1, 1, 1)), self.layer_3d_y(y.view(-1, 6, 1, 1, 1))], 1
        )
        seed_2d = torch.cat(
            [self.layer_2d_video(z.view(-1, 100, 1, 1)), self.layer_2d_y(y.view(-1, 6, 1, 1))], 1
        )

        features = self.net_video(seed_3d)
        foreground = self.gen_net(features)

        # The single mask channel is shared by the three colour channels.
        mask = self.mask_net(features).expand(*video_shape)

        # The still background image is repeated along the frame axis.
        background = self.static_net(seed_2d).view(n_samples, 3, 1, 64, 64).expand(*video_shape)

        return foreground * mask + background * (1 - mask)


In [5]:
def init_weights(m) :
    """Re-initialise one module in place; meant for ``Module.apply``.

    * ``Conv3d`` / ``ConvTranspose2d`` / ``ConvTranspose3d``: weights drawn
      from N(0, 0.01), bias set to zero.
    * ``BatchNorm2d`` / ``BatchNorm3d``: scale drawn from N(1, 0.02), shift
      set to zero.
    * Any other module type is left untouched.

    Matching is done on the exact module type. Parameters that are absent
    (a layer built with ``bias=False`` or ``affine=False`` stores ``None``)
    are skipped instead of raising ``AttributeError``.
    """
    kind = type(m)

    if kind in (nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d) :
        weight_mean, weight_std = 0.0, 0.01
    elif kind in (nn.BatchNorm2d, nn.BatchNorm3d) :
        weight_mean, weight_std = 1.0, 0.02
    else :
        return

    # Weight first, then bias: same order as before so seeded runs draw the
    # same random numbers.
    if m.weight is not None :
        m.weight.data.normal_(weight_mean, weight_std)
    if m.bias is not None :
        m.bias.data.fill_(0)

In [6]:
# Resume from D.ckpt / G.ckpt in the working directory when True; otherwise
# build both networks from scratch and apply init_weights.
pre_train = True

batch_size = 64
video_size = 64
epoch_size = 10000

#check GPU
is_gpu = torch.cuda.is_available()
print(is_gpu)

# Tensor type used for every model, loss and batch in the notebook.
dtype = torch.cuda.FloatTensor if is_gpu else torch.FloatTensor

if pre_train :
    # Deserialize onto the CPU first and let .type(dtype) place the model:
    # without map_location a checkpoint written on a GPU machine cannot be
    # loaded when CUDA is unavailable.
    # NOTE(security): torch.load unpickles arbitrary Python objects; only
    # load checkpoint files that you produced yourself.
    load_on_cpu = lambda storage, location: storage
    D = torch.load('D.ckpt', map_location=load_on_cpu).type(dtype)
    G = torch.load('G.ckpt', map_location=load_on_cpu).type(dtype)
else :
    D = Discriminator().type(dtype)
    G = Generator().type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

# The discriminator emits raw scores, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss().type(dtype)

d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

True


In [None]:
def _loss_value(loss) :
    """Return a scalar loss as a plain number for logging.

    Newer PyTorch returns 0-dim losses, for which ``loss.data[0]`` raises
    ``IndexError``; ``item()`` is used when present, with the legacy indexing
    kept as a fallback for releases that predate it.
    """
    if hasattr(loss, 'item') :
        return loss.item()
    return loss.data[0]


data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)

for epoch in range(1, epoch_size + 1) :

    # `step` rather than `iter`: the loop variable lives in the notebook's
    # global namespace and would otherwise shadow the builtin iter().
    for step, (video, y) in enumerate(data_loader) :
        # The last batch of an epoch may be smaller than batch_size.
        local_batch_size = video.size()[0]

        # The discriminator emits two scores per sample; both get the same target.
        real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
        fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))

        video_data = Variable(video).type(dtype)
        y = Variable(y).type(dtype)

        # Broadcast the per-sample condition vector over frames, height and
        # width so it can be fed to the discriminator next to the video.
        # assumes y is (batch, 6) and video is (batch, 3, 32, 64, 64) — TODO confirm against get_loader
        y_data = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(y, -1), -1), -1)
        y_data = y_data.expand(local_batch_size, 6, 32, 64, 64)

        # 1. Train Discriminator
        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        # 1-2. Fake Video
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        # detach(): the discriminator update needs no gradient w.r.t. the
        # generator, so do not backpropagate through G here. D's gradients
        # are unchanged by this.
        fake_videos = G(z, y).detach()
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2. Train Generator
        # Fresh noise, and this time gradients flow through D into G.
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)

        # The generator is rewarded when D labels its samples as real.
        g_loss = criterion(outputs, real_labels)

        # Gradients that this backward pass leaves on D are cleared by
        # D.zero_grad() before the next discriminator update.
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f' % (epoch, epoch_size, step, len(data_loader), _loss_value(d_loss), _loss_value(g_loss)))

    print('Model saving...')

    # Whole-module pickles, overwritten after every epoch.
    torch.save(D, 'D.ckpt')
    torch.save(G, 'G.ckpt')

Epoch [1/10000], Iter [0/49], d_loss: 1.2017, g_loss: 1.9545
Epoch [1/10000], Iter [1/49], d_loss: 2.5498, g_loss: 1.0514
Epoch [1/10000], Iter [2/49], d_loss: 1.4314, g_loss: 1.2753
Epoch [1/10000], Iter [3/49], d_loss: 1.3733, g_loss: 1.2508
Epoch [1/10000], Iter [4/49], d_loss: 1.3737, g_loss: 1.2907
Epoch [1/10000], Iter [5/49], d_loss: 1.4533, g_loss: 1.1530
Epoch [1/10000], Iter [6/49], d_loss: 1.5563, g_loss: 1.1968
Epoch [1/10000], Iter [7/49], d_loss: 1.4809, g_loss: 1.0170
Epoch [1/10000], Iter [8/49], d_loss: 1.4877, g_loss: 0.7220
Epoch [1/10000], Iter [9/49], d_loss: 1.3996, g_loss: 1.3135
Epoch [1/10000], Iter [10/49], d_loss: 1.1819, g_loss: 1.5108
Epoch [1/10000], Iter [11/49], d_loss: 1.0471, g_loss: 1.2364
Epoch [1/10000], Iter [12/49], d_loss: 1.3763, g_loss: 1.6865
Epoch [1/10000], Iter [13/49], d_loss: 1.1672, g_loss: 0.9161
Epoch [1/10000], Iter [14/49], d_loss: 1.0765, g_loss: 1.1071
Epoch [1/10000], Iter [15/49], d_loss: 1.1187, g_loss: 0.6412
Epoch [1/10000], I

Epoch [3/10000], Iter [35/49], d_loss: 1.1947, g_loss: 1.0959
Epoch [3/10000], Iter [36/49], d_loss: 1.1163, g_loss: 1.3388
Epoch [3/10000], Iter [37/49], d_loss: 1.4250, g_loss: 1.5614
Epoch [3/10000], Iter [38/49], d_loss: 1.4271, g_loss: 0.6682
Epoch [3/10000], Iter [39/49], d_loss: 1.1763, g_loss: 1.1814
Epoch [3/10000], Iter [40/49], d_loss: 1.1637, g_loss: 1.2033
Epoch [3/10000], Iter [41/49], d_loss: 1.2435, g_loss: 1.0797
Epoch [3/10000], Iter [42/49], d_loss: 1.1177, g_loss: 1.2219
Epoch [3/10000], Iter [43/49], d_loss: 1.4228, g_loss: 1.1813
Epoch [3/10000], Iter [44/49], d_loss: 1.3083, g_loss: 1.3583
Epoch [3/10000], Iter [45/49], d_loss: 1.2335, g_loss: 0.7920
Epoch [3/10000], Iter [46/49], d_loss: 1.2509, g_loss: 1.2170
Epoch [3/10000], Iter [47/49], d_loss: 1.3776, g_loss: 1.1370
Epoch [3/10000], Iter [48/49], d_loss: 1.7494, g_loss: 1.3499
Model saving...
Epoch [4/10000], Iter [0/49], d_loss: 1.5862, g_loss: 1.3309
Epoch [4/10000], Iter [1/49], d_loss: 1.7743, g_loss: 0

Epoch [6/10000], Iter [20/49], d_loss: 0.9140, g_loss: 1.1123
Epoch [6/10000], Iter [21/49], d_loss: 1.0793, g_loss: 1.9039
Epoch [6/10000], Iter [22/49], d_loss: 1.1869, g_loss: 0.8947
Epoch [6/10000], Iter [23/49], d_loss: 1.4859, g_loss: 1.6486
Epoch [6/10000], Iter [24/49], d_loss: 1.3638, g_loss: 0.8205
Epoch [6/10000], Iter [25/49], d_loss: 1.4671, g_loss: 1.4419
Epoch [6/10000], Iter [26/49], d_loss: 1.4344, g_loss: 0.8112
Epoch [6/10000], Iter [27/49], d_loss: 1.2135, g_loss: 2.0748
Epoch [6/10000], Iter [28/49], d_loss: 1.1594, g_loss: 0.8115
Epoch [6/10000], Iter [29/49], d_loss: 1.2792, g_loss: 1.0872
Epoch [6/10000], Iter [30/49], d_loss: 1.1315, g_loss: 2.2075
Epoch [6/10000], Iter [31/49], d_loss: 1.5968, g_loss: 0.8096
Epoch [6/10000], Iter [32/49], d_loss: 1.0872, g_loss: 0.9424
Epoch [6/10000], Iter [33/49], d_loss: 1.0028, g_loss: 1.9012
Epoch [6/10000], Iter [34/49], d_loss: 1.5248, g_loss: 0.6817
Epoch [6/10000], Iter [35/49], d_loss: 1.6329, g_loss: 1.5915
Epoch [6

Epoch [9/10000], Iter [5/49], d_loss: 1.2425, g_loss: 1.4492
Epoch [9/10000], Iter [6/49], d_loss: 1.1202, g_loss: 2.3892
Epoch [9/10000], Iter [7/49], d_loss: 1.1916, g_loss: 1.1528
Epoch [9/10000], Iter [8/49], d_loss: 1.2584, g_loss: 2.3835
Epoch [9/10000], Iter [9/49], d_loss: 1.2864, g_loss: 1.4995
Epoch [9/10000], Iter [10/49], d_loss: 1.2032, g_loss: 1.3545
Epoch [9/10000], Iter [11/49], d_loss: 1.0507, g_loss: 1.8968
Epoch [9/10000], Iter [12/49], d_loss: 1.0679, g_loss: 1.6259
Epoch [9/10000], Iter [13/49], d_loss: 1.1286, g_loss: 2.1873
Epoch [9/10000], Iter [14/49], d_loss: 1.0196, g_loss: 1.4619
Epoch [9/10000], Iter [15/49], d_loss: 0.7766, g_loss: 2.9123
Epoch [9/10000], Iter [16/49], d_loss: 1.5330, g_loss: 0.4492
Epoch [9/10000], Iter [17/49], d_loss: 1.8128, g_loss: 2.1561
Epoch [9/10000], Iter [18/49], d_loss: 1.2209, g_loss: 1.5078
Epoch [9/10000], Iter [19/49], d_loss: 0.8215, g_loss: 1.2074
Epoch [9/10000], Iter [20/49], d_loss: 1.1606, g_loss: 1.7072
Epoch [9/1000

Epoch [11/10000], Iter [38/49], d_loss: 1.0921, g_loss: 1.7704
Epoch [11/10000], Iter [39/49], d_loss: 1.0796, g_loss: 1.0131
Epoch [11/10000], Iter [40/49], d_loss: 1.0634, g_loss: 2.1107
Epoch [11/10000], Iter [41/49], d_loss: 1.4253, g_loss: 0.6679
Epoch [11/10000], Iter [42/49], d_loss: 1.4897, g_loss: 2.6200
Epoch [11/10000], Iter [43/49], d_loss: 1.5401, g_loss: 0.7530
Epoch [11/10000], Iter [44/49], d_loss: 1.0619, g_loss: 1.5044
Epoch [11/10000], Iter [45/49], d_loss: 0.9745, g_loss: 1.4018
Epoch [11/10000], Iter [46/49], d_loss: 1.3025, g_loss: 1.2713
Epoch [11/10000], Iter [47/49], d_loss: 1.5474, g_loss: 1.9405
Epoch [11/10000], Iter [48/49], d_loss: 1.2697, g_loss: 1.4735
Model saving...
Epoch [12/10000], Iter [0/49], d_loss: 1.0181, g_loss: 0.6894
Epoch [12/10000], Iter [1/49], d_loss: 1.3486, g_loss: 0.9546
Epoch [12/10000], Iter [2/49], d_loss: 1.2641, g_loss: 1.1136
Epoch [12/10000], Iter [3/49], d_loss: 1.1527, g_loss: 1.2760
Epoch [12/10000], Iter [4/49], d_loss: 1.20

Epoch [14/10000], Iter [21/49], d_loss: 1.0604, g_loss: 1.8979
Epoch [14/10000], Iter [22/49], d_loss: 1.1333, g_loss: 1.7987
Epoch [14/10000], Iter [23/49], d_loss: 1.1745, g_loss: 2.0530
Epoch [14/10000], Iter [24/49], d_loss: 1.1447, g_loss: 2.3172
Epoch [14/10000], Iter [25/49], d_loss: 0.9817, g_loss: 1.5379
Epoch [14/10000], Iter [26/49], d_loss: 1.3231, g_loss: 3.2766
Epoch [14/10000], Iter [27/49], d_loss: 1.6522, g_loss: 0.7110
Epoch [14/10000], Iter [28/49], d_loss: 1.5508, g_loss: 2.4520
Epoch [14/10000], Iter [29/49], d_loss: 1.2583, g_loss: 1.3969
Epoch [14/10000], Iter [30/49], d_loss: 1.2158, g_loss: 1.4186
Epoch [14/10000], Iter [31/49], d_loss: 1.1482, g_loss: 1.6413
Epoch [14/10000], Iter [32/49], d_loss: 1.0547, g_loss: 1.4811
Epoch [14/10000], Iter [33/49], d_loss: 0.9697, g_loss: 1.5391
Epoch [14/10000], Iter [34/49], d_loss: 0.9535, g_loss: 1.9462
Epoch [14/10000], Iter [35/49], d_loss: 1.0434, g_loss: 1.2975
Epoch [14/10000], Iter [36/49], d_loss: 1.2749, g_loss:

Epoch [17/10000], Iter [4/49], d_loss: 0.9969, g_loss: 1.7317
Epoch [17/10000], Iter [5/49], d_loss: 0.9162, g_loss: 1.7734
Epoch [17/10000], Iter [6/49], d_loss: 0.8702, g_loss: 2.6593
Epoch [17/10000], Iter [7/49], d_loss: 1.1971, g_loss: 0.8283
Epoch [17/10000], Iter [8/49], d_loss: 1.1688, g_loss: 3.4786
Epoch [17/10000], Iter [9/49], d_loss: 0.9750, g_loss: 1.8067
Epoch [17/10000], Iter [10/49], d_loss: 0.9382, g_loss: 2.5852
Epoch [17/10000], Iter [11/49], d_loss: 0.7361, g_loss: 1.9282
Epoch [17/10000], Iter [12/49], d_loss: 0.6603, g_loss: 1.5424
Epoch [17/10000], Iter [13/49], d_loss: 0.8253, g_loss: 2.2674
Epoch [17/10000], Iter [14/49], d_loss: 1.0175, g_loss: 2.9737
Epoch [17/10000], Iter [15/49], d_loss: 1.4044, g_loss: 1.0936
Epoch [17/10000], Iter [16/49], d_loss: 1.2063, g_loss: 4.0890
Epoch [17/10000], Iter [17/49], d_loss: 1.2537, g_loss: 1.7438
Epoch [17/10000], Iter [18/49], d_loss: 0.8421, g_loss: 2.0212
Epoch [17/10000], Iter [19/49], d_loss: 0.6983, g_loss: 3.207

Epoch [19/10000], Iter [36/49], d_loss: 1.4621, g_loss: 1.2536
Epoch [19/10000], Iter [37/49], d_loss: 1.7650, g_loss: 4.0449
Epoch [19/10000], Iter [38/49], d_loss: 1.4299, g_loss: 1.0183
Epoch [19/10000], Iter [39/49], d_loss: 0.9831, g_loss: 2.0585
Epoch [19/10000], Iter [40/49], d_loss: 0.9375, g_loss: 2.4183
Epoch [19/10000], Iter [41/49], d_loss: 0.8523, g_loss: 1.7635
Epoch [19/10000], Iter [42/49], d_loss: 0.7988, g_loss: 1.7663
Epoch [19/10000], Iter [43/49], d_loss: 0.8673, g_loss: 2.0926
Epoch [19/10000], Iter [44/49], d_loss: 0.7801, g_loss: 2.1867
Epoch [19/10000], Iter [45/49], d_loss: 0.9295, g_loss: 2.7408
Epoch [19/10000], Iter [46/49], d_loss: 0.7501, g_loss: 2.1433
Epoch [19/10000], Iter [47/49], d_loss: 0.9449, g_loss: 2.5237
Epoch [19/10000], Iter [48/49], d_loss: 0.8836, g_loss: 1.2965
Model saving...
Epoch [20/10000], Iter [0/49], d_loss: 1.1367, g_loss: 2.5771
Epoch [20/10000], Iter [1/49], d_loss: 0.8778, g_loss: 1.5024
Epoch [20/10000], Iter [2/49], d_loss: 0.

Epoch [22/10000], Iter [19/49], d_loss: 0.8723, g_loss: 2.4062
Epoch [22/10000], Iter [20/49], d_loss: 0.6298, g_loss: 2.7374
Epoch [22/10000], Iter [21/49], d_loss: 0.8381, g_loss: 4.0151
Epoch [22/10000], Iter [22/49], d_loss: 1.1041, g_loss: 1.8091
Epoch [22/10000], Iter [23/49], d_loss: 0.9732, g_loss: 4.6450
Epoch [22/10000], Iter [24/49], d_loss: 1.2931, g_loss: 0.9877
Epoch [22/10000], Iter [25/49], d_loss: 1.1817, g_loss: 4.6087
Epoch [22/10000], Iter [26/49], d_loss: 0.4104, g_loss: 3.4659
Epoch [22/10000], Iter [27/49], d_loss: 0.5576, g_loss: 2.1555
Epoch [22/10000], Iter [28/49], d_loss: 0.8764, g_loss: 3.7297
Epoch [22/10000], Iter [29/49], d_loss: 0.8105, g_loss: 1.9215
Epoch [22/10000], Iter [30/49], d_loss: 0.7980, g_loss: 3.7947
Epoch [22/10000], Iter [31/49], d_loss: 0.7148, g_loss: 1.8758
Epoch [22/10000], Iter [32/49], d_loss: 1.1384, g_loss: 4.4537
Epoch [22/10000], Iter [33/49], d_loss: 1.4098, g_loss: 1.3738
Epoch [22/10000], Iter [34/49], d_loss: 1.6381, g_loss:

Epoch [25/10000], Iter [2/49], d_loss: 0.4056, g_loss: 3.3177
Epoch [25/10000], Iter [3/49], d_loss: 0.4546, g_loss: 2.4241
Epoch [25/10000], Iter [4/49], d_loss: 0.4145, g_loss: 3.3774
Epoch [25/10000], Iter [5/49], d_loss: 0.3252, g_loss: 3.0876
Epoch [25/10000], Iter [6/49], d_loss: 0.5077, g_loss: 3.0872
Epoch [25/10000], Iter [7/49], d_loss: 0.7547, g_loss: 4.5921
Epoch [25/10000], Iter [8/49], d_loss: 0.9210, g_loss: 2.1741
Epoch [25/10000], Iter [9/49], d_loss: 0.9241, g_loss: 5.8683
Epoch [25/10000], Iter [10/49], d_loss: 0.7707, g_loss: 3.4803
Epoch [25/10000], Iter [11/49], d_loss: 0.4057, g_loss: 4.4918
Epoch [25/10000], Iter [12/49], d_loss: 0.7997, g_loss: 3.4071
Epoch [25/10000], Iter [13/49], d_loss: 0.4238, g_loss: 2.6354
Epoch [25/10000], Iter [14/49], d_loss: 0.4864, g_loss: 3.8283
Epoch [25/10000], Iter [15/49], d_loss: 0.7289, g_loss: 1.4393
Epoch [25/10000], Iter [16/49], d_loss: 1.1731, g_loss: 6.6494
Epoch [25/10000], Iter [17/49], d_loss: 0.9457, g_loss: 4.4582


In [7]:
# Scratch directory for the per-frame PNGs that make_gif turns into a GIF.
path_dir = './testvideo'

if not os.path.exists(path_dir) :
    os.mkdir(path_dir)

# One sample video per class (6 classes, 32 frames each).
for classes in range(6) :
    z = Variable(torch.randn(1, 100) * 0.01).type(dtype)

    # One-hot condition vector of shape (1, 6); the generator reshapes it
    # itself, so no transpose is needed.
    label = torch.zeros(1, 6)
    label.scatter_(1, torch.LongTensor([[classes]]), 1)
    label = Variable(label).type(dtype)

    # Run the generator once per class instead of once per frame: z and
    # label do not change inside the frame loop.
    fake_video = torch.squeeze(G(z, label))

    frame_paths = []
    for i in range(32) :
        frame_path = path_dir + '/test' + str(classes) + '_' + str(i+1) + '.png'
        # The path is passed positionally because torchvision renamed this
        # keyword from `filename` to `fp`.
        # NOTE(review): the generator ends in tanh, so values lie in [-1, 1];
        # save_image is called without normalization here — confirm that is intended.
        torchvision.utils.save_image(fake_video.data[:, i, :, :], frame_path)
        frame_paths.append(frame_path)

    make_gif(root=path_dir, output='output' + str(classes) + '.gif', fps=16)

    # Remove this class's frames before the next class is rendered.
    for frame_path in frame_paths :
        os.remove(frame_path)


[MoviePy] Building file output0.gif with imageio


 97%|█████████▋| 32/33 [00:00<00:00, 96.41it/s]



[MoviePy] Building file output1.gif with imageio


 97%|█████████▋| 32/33 [00:00<00:00, 118.24it/s]



[MoviePy] Building file output2.gif with imageio


 97%|█████████▋| 32/33 [00:00<00:00, 116.23it/s]



[MoviePy] Building file output3.gif with imageio


 97%|█████████▋| 32/33 [00:00<00:00, 115.26it/s]



[MoviePy] Building file output4.gif with imageio


 97%|█████████▋| 32/33 [00:00<00:00, 117.72it/s]



[MoviePy] Building file output5.gif with imageio


 97%|█████████▋| 32/33 [00:00<00:00, 117.66it/s]
