# **OracleGAN v2** - 
--------------
This study explores how to make a **lighter, faster [OracleGAN](https://www.kaggle.com/lapl04/oraclegan-pix2pix-for-time-series-image)**.

MobileNet v2 is adopted as the backbone of the Discriminator.
In addition, some of the Generator's convolution layers are replaced with Depthwise Separable Convolutions.
A 1x1 convolution is added to each of the Generator's skip connections.

# **Key Features of OracleGAN**
-------------------------------------
- [Time Step Image Dataset](#time_step_image_dataset)
- [Cost Function of Generator](#cost_of_generator)
- [Cost Function of Discriminator](#cost_of_discriminator)

# **Key Features of FastOracleGAN**
-------------------------------------
- [Depthwise Separable Convolution](#depthwise_separable_convolution)
- [MobileNet v2](#mobilenetv2)
- [Skip Convolution](#skip_convolution)

# **Key Features of OracleGAN v2**
-------------------------------------
- 

# **Goal**
- Predict future weather images using current weather images.

# Install additional libraries

IQA_pytorch is a library used to calculate the SSIM score.

In [None]:
!pip install IQA_pytorch #For SSIM Score

In [None]:
!pip install torchsummaryX

In [None]:
!pip install pytorch_msssim

# Import Libraries

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
from torch.optim import *
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
from IQA_pytorch import DISTS, utils
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from torchsummaryX import summary

import math
import time
import numpy as np
from PIL import Image
import cv2
import numpy as np
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from matplotlib import font_manager, rc
from IPython import display
import random
import glob
import os
from os import listdir
from os.path import isfile, join
import warnings
import sys
from tqdm import tqdm
import pickle
import gc
import random
import urllib.request

warnings.filterwarnings("ignore")

print("Version of Torch : {0}".format(torch.__version__))
print("Version of TorchVision : {0}".format(torchvision.__version__))

In [None]:
# Run the Python garbage collector and release PyTorch's cached CUDA memory
# before any model or data is created.
gc.collect()
torch.cuda.empty_cache()

In [None]:
%matplotlib inline

# Global matplotlib configuration for the notebook.
plt.rcParams['axes.unicode_minus'] = False
# Font file used for figure titles via the `fontproperties=fontprop` argument.
fontpath = "../input/koreanfont/NanumBrush.ttf"
fontprop = font_manager.FontProperties(fname=fontpath)

# Render animations as interactive JavaScript/HTML widgets.
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  
# Turn interactive mode off so figures are only drawn on explicit show/display calls.
plt.ioff()

# Define Hyperparameters
|Name of Hyperparameter|Explanation|
|-----|-----|
|USE_CUDA|whether to use GPU|
|DEBUG|whether to print specific logs|
|RANDOM_SEED|random seed of pytorch, random, numpy|
|start_epoch|epoch to resume from when continuing training from a checkpoint|
|all_epochs|Epochs|
|batch_size|Batch Size|
|lrG|the learning rate of Generator|
|lrD|the learning rate of Discriminator|
|beta1, beta2|the beta1 and beta2 of Generator and Discriminator|
|**L1lambda**|lambda of the pix2pix objective function|
|**GAMMA**|factor similar to the discount factor of DQN. (0<$\gamma$<1) (see the cost function of the OracleGAN Generator)|
|**INPUT_NUM**|the number of input images. (If INPUT_NUM is 1, this code works like the single-input version of FastOracleGAN)|
|**TIME_STEP**|the number of future images that are used to calculate the loss (see the cost function of the OracleGAN Generator)|



In [None]:
# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()

print("Device : {0}".format("GPU" if USE_CUDA else "CPU"))
device = torch.device("cuda" if USE_CUDA else "cpu")
cpu_device = torch.device("cpu")

# When True, log() prints its argument; when False, log() is silent.
DEBUG = False

# Seed applied to torch, numpy and random by random_seed().
RANDOM_SEED = 2004

# Train
# The discriminator-only phase trains on every `only_d_train_step`-th batch.
only_d_train_step = 5

# Epoch bookkeeping: the *_only_d values drive the discriminator-only phase,
# the others drive the joint generator/discriminator phase.
start_epoch_only_d = 0
all_epochs_only_d = 0  # 0 skips the discriminator-only phase in train_main()
start_epoch = 0
all_epochs = 1
batch_size = 13

# Adam settings shared by both optimizers (separate learning rates).
lrG = 0.0002
lrD = 0.0002
beta1 = 0.5
beta2 = 0.999

# Weight of the L1 reconstruction term in the generator loss.
L1lambda = 100
# GAMMA schedule used by gamma_updater(): start at START_GAMMA, stay there
# until iteration IDLE_GAMMA, then grow by ITER_GAMMA per iteration up to END_GAMMA.
START_GAMMA = 0
IDLE_GAMMA = 500
ITER_GAMMA = 0.0008
END_GAMMA = 0.8

INPUT_NUM = 4        # number of consecutive input frames fed to the generator
TIME_STEP = 4        # number of future frames per training sample
TEST_TIME_STEP = 6   # number of future frames per test sample
IMAGE_SIZE = 128     # size passed to torchvision Resize in `transformer`

# NOTE(review): `patch` is not referenced anywhere in the visible code, and it is
# derived from 256 although IMAGE_SIZE is 128 -- confirm whether it is still needed.
patch = (1,256//2**4,256//2**4)

# Path
DATASET1_PATH = '../input/the-cloudcast-dataset'

# Checkpoint
# When True (and OLD_PATH exists), apply_checkpoint() restores models,
# optimizers and loss histories from the files below.
USE_CHECKPOINT = False

OLD_PATH = '../input/fastoraclegan-multiple-input'
OLD_GENERATOR_MODEL = os.path.join(OLD_PATH, 'Generator.pth')
OLD_DISCRIMINATOR_MODEL = os.path.join(OLD_PATH, 'Discriminator.pth')
OLD_G_LOSS = os.path.join(OLD_PATH, 'gloss.txt')
OLD_D_LOSS = os.path.join(OLD_PATH, 'dloss.txt')

In [None]:
# NOTE(review): replay_memory is reset in train_main() but nothing in the visible
# code ever appends to it -- confirm whether it is still used.
replay_memory = []
# Current discount factor of the generator loss; updated during GAN training.
GAMMA = START_GAMMA

In [None]:
def gamma_updater(now_gamma, now_iter, fixed_gamma=None):
    """Return the discount factor to use after the current iteration.

    Args:
        now_gamma: The discount factor currently in use.
        now_iter: Number of GAN training iterations completed so far.
        fixed_gamma: If given (including 0), it is returned unchanged and the
            schedule is bypassed.

    Returns:
        ``fixed_gamma`` when it is not None. Otherwise ``now_gamma`` while
        ``now_iter`` is below ``IDLE_GAMMA``, and afterwards ``now_gamma``
        increased by ``ITER_GAMMA`` and capped at ``END_GAMMA``.
    """
    # Identity check: `!= None` would defer to a custom __ne__ implementation.
    if fixed_gamma is not None:
        return fixed_gamma

    # Warm-up period: keep the discount factor untouched.
    if now_iter < IDLE_GAMMA:
        return now_gamma

    return min(now_gamma + ITER_GAMMA, END_GAMMA)

In [None]:
def random_seed():
    """Seed every random number generator used by the notebook with RANDOM_SEED.

    Covers PyTorch (CPU, current CUDA device and all CUDA devices), NumPy and
    Python's ``random`` module, and switches cuDNN to deterministic mode with
    benchmarking disabled.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(RANDOM_SEED)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print('Random Seed : {0}'.format(RANDOM_SEED))


random_seed()

In [None]:
def log(text):
    """Print *text* when the module-level DEBUG flag is truthy; otherwise do nothing."""
    if not DEBUG:
        return
    print(text)

# Visualize Data
|Name of Function|Explanation|
|-----|-----|
|torch_tensor_to_plt|Convert torch image to matplotlib image|
|plt_image_animation|Show a video whose frames are drawn by update_func|

In [None]:
def torch_tensor_to_plt(img):
    """Convert the first image of a batched tensor to a matplotlib-ready array.

    Args:
        img: Tensor whose first axis is the batch axis and whose remaining
            three axes are (channels, height, width).

    Returns:
        A NumPy array of shape (height, width, channels) for ``img[0]``.
    """
    # .cpu() makes the conversion work for tensors living on the GPU as well;
    # .numpy() raises for CUDA tensors.
    first_image = img.detach().cpu().numpy()[0]
    # Channels-first -> channels-last, the layout plt.imshow expects.
    return np.transpose(first_image, (1, 2, 0))

In [None]:
def show_video_in_jupyter_nb(width, height, video_url):
    """Build an IPython HTML object embedding an MP4 video player.

    Args:
        width: Value placed in the ``width`` attribute of the ``<video>`` tag.
        height: Value placed in the ``height`` attribute of the ``<video>`` tag.
        video_url: Value placed in the ``src`` attribute of the ``<source>`` tag.

    Returns:
        An ``IPython.display.HTML`` instance wrapping the markup.
    """
    from IPython.display import HTML

    markup = """<video width="{}" height="{}" controls>
    <source src={} type="video/mp4">
    </video>""".format(width, height, video_url)
    return HTML(markup)

In [None]:
def plt_image_animation(frames, update_func):
    """Render an animation as an HTML5 video and display it in the notebook.

    Args:
        frames: Passed to ``FuncAnimation`` as its ``frames`` argument.
        update_func: Callable invoked by ``FuncAnimation`` to draw each frame.
    """
    figure, _axes = plt.subplots(figsize=(4, 4))
    plt.axis('off')

    movie = animation.FuncAnimation(figure, update_func, frames=frames)
    display.display(display.HTML(movie.to_html5_video()))

    # Close the figure so it is not shown again as a static image.
    plt.close()

In [None]:
# Preview: animate frames 0.npy .. 14.npy of the '2017M01' folder in grayscale.
plt_image_animation(15, lambda t : plt.imshow(np.load(join(DATASET1_PATH, '2017M01', '{0}.npy'.format(t))), cmap='gray'))

# Preprocess Dataset

In [None]:
# Preprocessing applied to every loaded frame: convert to a tensor, resize to
# IMAGE_SIZE, then normalize with mean 0.5 and std 0.5 (single channel).
transformer = transforms.Compose([transforms.ToTensor(),
                                  torchvision.transforms.Resize(IMAGE_SIZE),
                                  transforms.Normalize((0.5), (0.5)), #GrayScale
                                 ])


<a id="time_step_image_dataset"></a>
## **Time Step Image Dataset**

------------------------------------------------

OracleGAN calculates the loss between the predicted image and the real image not only 15 minutes ahead but also **15×TimeStep minutes ahead**.
 
So, the dataset needs to provide **multiple output** images per sample.

In [None]:
# Path of the first input frame of the most recently loaded sample (debug aid).
nowpath = ""


class TimeStepImageDataset(Dataset):
    """Sliding-window dataset over a folder of consecutively numbered frames.

    The folder ``date`` is expected to contain frames named ``0.npy``,
    ``1.npy``, ... (a ``TIMESTAMPS.npy`` file is ignored). Sample ``idx``
    consists of ``input_num`` input frames starting at ``idx`` followed by
    ``time_step`` target frames.

    Args:
        date: Directory holding the ``.npy`` frames.
        input_num: Number of consecutive input frames per sample.
        time_step: Number of consecutive target frames per sample.
        transform: Optional callable applied to every loaded frame.
    """

    def __init__(self, date, input_num, time_step, transform=None):
        self.date = date
        self.input_num = input_num
        self.time_step = time_step
        self.transformer = transform
        # Only the count of frame files is used; frames are opened by index.
        self.file = [
            path for path in glob.glob(join(self.date, '*'))
            if path.endswith('.npy') and not path.endswith('TIMESTAMPS.npy')
        ]

    def __len__(self):
        """Return the number of complete (input, target) windows."""
        window = self.input_num + self.time_step
        # N frames hold N - window + 1 windows; never report a negative
        # length for folders with too few frames.
        return max(0, len(self.file) - window + 1)

    def transform(self, image):
        """Apply the configured transform, or return *image* unchanged."""
        if self.transformer:
            return self.transformer(image)
        return image

    def _load_frame(self, frame_idx):
        """Load frame ``frame_idx`` from disk and return it as a tensor."""
        frame = self.transform(np.load(join(self.date, str(frame_idx) + '.npy')))
        # Without a transform np.load yields an ndarray, which cannot be stacked
        # by torch directly.
        if not torch.is_tensor(frame):
            frame = torch.as_tensor(frame)
        return frame

    def __getitem__(self, idx):
        """Return ``(X, Y)``: stacked input frames and stacked target frames.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        global nowpath

        if not 0 <= idx < len(self):
            raise IndexError(
                'index {0} is out of range for {1} samples'.format(idx, len(self))
            )

        first_path = join(self.date, str(idx) + '.npy')
        log(first_path)

        X = torch.stack(
            [self._load_frame(idx + offset) for offset in range(self.input_num)]
        )
        nowpath = first_path

        Y = torch.stack(
            [
                self._load_frame(idx + offset)
                for offset in range(self.input_num, self.input_num + self.time_step)
            ]
        )

        return X, Y

In [None]:
# Every entry of the dataset root is treated as one folder of frames.
DATASET1_DIRS = glob.glob(join(DATASET1_PATH, '*'))

# Shuffle so that the train/test split below is random (seeded by random_seed()).
random.shuffle(DATASET1_DIRS)

# The first 20 folders form the training set ...
traindatasetlist = [
    TimeStepImageDataset(folder, INPUT_NUM, TIME_STEP, transform=transformer)
    for folder in DATASET1_DIRS[:20]
]
train_dataset = torch.utils.data.ConcatDataset(traindatasetlist)

# ... and the remaining folders form the test set, with the longer test horizon.
testdatasetlist = [
    TimeStepImageDataset(folder, INPUT_NUM, TEST_TIME_STEP, transform=transformer)
    for folder in DATASET1_DIRS[20:]
]
test_dataset = torch.utils.data.ConcatDataset(testdatasetlist)

In [None]:
# Shuffled mini-batch loaders for training and for batched evaluation.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Single-sample test loaders: one in random order, one in dataset order.
test_dataloader_bs1_shuffle = DataLoader(test_dataset, batch_size=1, shuffle=True) 
test_dataloader_bs1_noshuffle = DataLoader(test_dataset, batch_size=1, shuffle=False) 

In [None]:
def ShowDatasetImage(x, y):
    """Plot the input frames and the target frames of one dataset sample.

    Two figures are shown: a row with every frame of ``x`` titled 'Input',
    and a row with every frame of ``y`` titled 'Real Weather Image', each
    target frame captioned with its lead time in 15-minute steps.

    Args:
        x: Tensor of input frames; the first axis indexes the frames.
        y: Tensor of target frames; the first axis indexes the frames.
    """
    # Take the frame counts from the tensors so that samples with a different
    # horizon (e.g. test samples) are displayed completely.
    n_inputs = x.shape[0]
    n_targets = y.shape[0]

    fig = plt.figure(figsize=(8, 2.5))
    plt.axis('off')
    plt.title('Input', fontproperties=fontprop)
    for i in range(1, n_inputs + 1):
        ax = fig.add_subplot(1, n_inputs, i)
        ax.axis('off')
        ax.imshow(torch_tensor_to_plt(x[i-1].unsqueeze(0)), cmap='gray')
    plt.show()

    fig = plt.figure(figsize=(8, 2.5))
    plt.title('Real Weather Image', fontproperties=fontprop)
    plt.axis('off')
    for i in range(1, n_targets + 1):
        ax = fig.add_subplot(1, n_targets, i)
        ax.axis('off')
        ax.imshow(torch_tensor_to_plt(y[i-1].unsqueeze(0)), cmap='gray')
        ax.set_title('after {0} minutes'.format(15*i), fontproperties=fontprop)
    plt.show()

In [None]:
# Display only the first training sample, if there is one.
for x, y in train_dataset:
    ShowDatasetImage(x, y)
    break

# Define Neural Networks and Optimizers
|Name|Sort|
|----|----|
|Generator|UNet|
|Discriminator|MobileNet v2|
|Optimizer of Generator|Adam|
|Optimizer of Discriminator|Adam|

<a id="depthwise_separable_convolution"></a>
## Depthwise Separable Convolution
------------------------
The Generator's second through sixth convolution layers are replaced with Depthwise Separable Convolutions.

In [None]:
class depthwise_separable_conv(nn.Module):
    """Depthwise separable convolution: a per-channel conv followed by a 1x1 conv.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        bias: Whether the pointwise (1x1) convolution has a bias term; the
            depthwise convolution never has one.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        # groups=in_channels gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        # The 1x1 convolution mixes the channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

In [None]:
class UNetDown(nn.Module):
    """Encoder stage of the U-Net generator: halves the spatial resolution.

    The stage is a stride-2 4x4 convolution, optionally followed by instance
    normalization, then LeakyReLU(0.2) and optional dropout.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        normalize: Whether to add ``InstanceNorm2d`` after the convolution.
        dropout: Dropout probability; a falsy value disables dropout.
        dsconv: Use a depthwise separable convolution instead of ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0, dsconv=True):
        super().__init__()

        # Both convolution types accept the same arguments.
        conv_cls = depthwise_separable_conv if dsconv else nn.Conv2d
        stages = [conv_cls(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]

        if normalize:
            stages.append(nn.InstanceNorm2d(out_channels))

        stages.append(nn.LeakyReLU(0.2))

        if dropout:
            stages.append(nn.Dropout(dropout))

        self.down = nn.Sequential(*stages)

    def forward(self, x):
        return self.down(x)

<a id="skip_convolution"></a>
## Skip Convolution
------------------------
**1x1 Convolutions are added to Generator's Skip.**

At first, I tried adding **Attention Blocks**.
However, they are too **heavy** for the purpose of FastOracleGAN. I also think the purpose of the Generator differs from the original purpose of UNet.
UNet was originally designed for semantic segmentation, where passing the input through undisturbed is important.
The Generator, however, has to transform the image. I expected that a skip connection which feeds the original image features directly to the decoder layers could **interfere** with that goal rather than help it.
**So I thought it was necessary to make the Generator, by design, rely more on up-sampling than on the skips, instead of leaving it to learned components such as Attention Blocks to decide what to focus on.** Therefore, I judged that most of what Attention Blocks would do in the Generator can be replaced by simply transforming each skip with a skip convolution.

In [None]:
class UNetUp(nn.Module):
    """Decoder stage of the U-Net generator: doubles the spatial resolution.

    The input is up-sampled by a stride-2 4x4 transposed convolution followed
    by instance normalization, LeakyReLU and optional dropout. The result is
    concatenated along the channel axis with the skip tensor, which is first
    passed through a bias-free 1x1 convolution when ``use_skip_conv`` is set.

    Args:
        in_channels: Number of channels of the tensor to up-sample.
        out_channels: Number of channels after up-sampling; also the channel
            count expected of the skip tensor when the skip convolution is used.
        dropout: Dropout probability; a falsy value disables dropout.
        use_skip_conv: Whether to apply the 1x1 convolution to the skip tensor.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0, use_skip_conv=True):
        super().__init__()

        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        if dropout:
            stages.append(nn.Dropout(dropout))
        self.up = nn.Sequential(*stages)

        self.use_skip_conv = use_skip_conv
        if use_skip_conv:
            # Skip Convolution
            self.skip_conv = nn.Conv2d(out_channels, out_channels, 1, bias=False)

    def forward(self, x, skip):
        upsampled = self.up(x)
        lateral = self.skip_conv(skip) if self.use_skip_conv else skip
        return torch.cat((upsampled, lateral), 1)

In [None]:
class GeneratorUNet(nn.Module):
    """U-Net generator with seven encoder stages and seven decoder stages.

    The first and the last encoder stage use regular convolutions; the five
    stages in between use depthwise separable convolutions. Every decoder
    stage except the final one concatenates its output with the matching
    encoder output. The final stage is a transposed convolution followed by
    Tanh.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the generated image.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder. The attribute names skip 'down7'; the gap is kept so the
        # parameter names stay the same.
        self.down1 = UNetDown(in_channels, 64, normalize=False, dsconv=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5, dsconv=False)

        # Decoder. Each stage after up2 receives twice the channels because
        # the previous stage concatenated its output with a skip tensor.
        self.up2 = UNetUp(1024 // 2, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6)
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        # Encode, remembering every intermediate result for the skips.
        skips = []
        hidden = x
        for encoder in encoders:
            hidden = encoder(hidden)
            skips.append(hidden)

        # Bottleneck stage (no skip is taken from it).
        hidden = self.down8(hidden)

        # Decode, pairing each stage with the encoder output of equal resolution.
        for decoder, skip in zip(decoders, reversed(skips)):
            hidden = decoder(hidden, skip)

        return self.up8(hidden)

<a id="mobilenetv2"></a>
## MobileNet v2
------------------------
MobileNet v2 is adopted for the Discriminator to make it better, lighter, and faster.

In [None]:
def _make_divisible(v, divisor, min_value=None):
    """Round *v* to a multiple of *divisor*, never going below *min_value*.

    Args:
        v: Value to round (typically a channel count).
        divisor: The result is based on the nearest multiple of this value.
        min_value: Lower bound of the result; defaults to *divisor*.

    Returns:
        The rounded value, bumped up by one *divisor* if rounding would
        otherwise reduce *v* by more than 10%.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Make sure that rounding down does not lose more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result


def conv_3x3_bn(inp, oup, stride):
    """Return a bias-free 3x3 convolution (padding 1) + BatchNorm + ReLU6 block.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.
    """
    stages = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*stages)


def conv_1x1_bn(inp, oup):
    """Return a bias-free 1x1 convolution + BatchNorm + ReLU6 block.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
    """
    stages = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*stages)


class InvertedResidual(nn.Module):
    """MobileNet v2 inverted residual block.

    The block expands the channels with a 1x1 convolution (skipped when
    ``expand_ratio`` is 1), filters them with a depthwise 3x3 convolution and
    projects them back with a linear 1x1 convolution. The input is added to
    the output when the stride is 1 and the channel counts match.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: Channel expansion factor of the hidden layer.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw: expand to the hidden width
            stages += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        stages += [
            # dw
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        if self.identity:
            return x + out
        return out


class Discriminator(nn.Module):
    """MobileNet v2 style discriminator with a real/fake head and a frame head.

    The network expects ``INPUT_NUM + 1`` input channels (the conditioning
    frames plus one candidate frame).

    Args:
        num_classes: Number of outputs of the frame (time step) classifier.
        patch: When True a shortened backbone is used and ``forward`` returns
            the final convolutional feature map as its first element; when
            False the full MobileNet v2 backbone is used and ``forward``
            returns the output of the real/fake linear head instead.
        width_mult: Channel width multiplier of the backbone.
    """

    def __init__(self, num_classes, patch=True, width_mult=1.):
        super(Discriminator, self).__init__()
        # setting of inverted residual blocks
        self.patch = patch
        if self.patch:
            self.cfgs = [
                # t, c, n, s
                [1,  16, 1, 1],
                [6,  24, 2, 2],
                [6,  32, 3, 2],
            ]
        else:
            # Was `elif self.patch:`, which could never run and left
            # `self.cfgs` undefined for patch=False.
            self.cfgs = [
                # t, c, n, s
                [1,  16, 1, 1],
                [6,  24, 2, 2],
                [6,  32, 3, 2],
                [6,  64, 4, 2],
                [6,  96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # building first layer
        input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8)
        layers = [conv_3x3_bn(INPUT_NUM+1, input_channel, 2)]
        # building inverted residual blocks
        block = InvertedResidual
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8)
            for i in range(n):
                # Only the first block of each group applies the group's stride.
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, t))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
        output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.d_classifier = nn.Linear(output_channel, 1)
        self.frame_classifier = nn.Linear(output_channel, num_classes)
        self.sigmoid = nn.Sigmoid()  # kept for compatibility; not applied in forward()
        self._initialize_weights()

    def forward(self, x):
        """Return ``(real_fake_output, frame_logits)`` for a batch *x*.

        NOTE(review): in patch mode the first element is the output of
        ``self.conv`` (a multi-channel feature map), not a single-channel
        patch map -- confirm that this is intended.
        """
        conv_feature = self.features(x)
        conv_feature = self.conv(conv_feature)

        x = self.avgpool(conv_feature)
        feature = x.view(x.size(0), -1)
        dresult = self.d_classifier(feature)  # real or fake
        frame = self.frame_classifier(feature)  # which time step?

        if self.patch:
            return (conv_feature, frame)
        else:
            return (dresult, frame)

    def _initialize_weights(self):
        """Initialize conv, batch-norm and linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel area * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

# Initiate Weights and Biases

In [None]:
def weights_init(m):
    """Re-initialize one module in place; meant to be used with ``Module.apply``.

    Modules whose exact type is ``nn.Conv2d`` get weights drawn from
    N(0, 0.02). Modules whose exact type is ``nn.BatchNorm2d`` get weights
    drawn from N(1, 0.02) and a zero bias. All other modules are left
    untouched.
    """
    module_type = type(m)
    if module_type is nn.Conv2d:
        m.weight.data.normal_(0.0, 0.02)
        return
    if module_type is nn.BatchNorm2d:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [None]:
# Build both networks on the selected device.
Generator = GeneratorUNet(in_channels=INPUT_NUM).to(device)
# NOTE(review): this rebinds the name `Discriminator` from the class to an
# instance, so running this line a second time calls the instance instead of
# the class. Later code refers to the instance by this name.
Discriminator = Discriminator(num_classes=TIME_STEP).to(device) 

# Apply weights_init to every sub-module; the results are assigned to variables
# that are not used afterwards.
summary_g = Generator.apply(weights_init)
summary_d = Discriminator.apply(weights_init)

In [None]:
# Summarize the generator using a random batch shaped like the real input.
summary(Generator, torch.rand((batch_size, INPUT_NUM, IMAGE_SIZE, IMAGE_SIZE)).to(device))

In [None]:
# Summarize the discriminator; it takes the input frames plus one candidate frame.
summary(Discriminator, torch.rand((batch_size, INPUT_NUM+1, IMAGE_SIZE, IMAGE_SIZE)).to(device))

In [None]:
# One Adam optimizer per network, sharing the beta values.
optimizerG = Adam(Generator.parameters(), lr=lrG, betas=(beta1, beta2))
optimizerD = Adam(Discriminator.parameters(), lr=lrD, betas=(beta1, beta2))

In [None]:
img_list = []
# Per-iteration loss values appended by the training loops.
G_loss = []
D_loss = []

# Target values used for fake and real samples in the adversarial losses.
FAKE_LABEL = 0.0
REAL_LABEL = 1.0

# Define Loss Functions

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of the cross-entropy loss.

    Every sample's cross-entropy ``ce`` is scaled by
    ``alpha * (1 - exp(-ce)) ** gamma``.

    Args:
        alpha: Constant scaling factor of the loss.
        gamma: Focusing exponent.
        logits: Stored on the instance but not used by ``forward``.
        reduce: When True ``forward`` returns the mean over all samples,
            otherwise the per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')

        # exp(-ce) is the probability the model assigns to the true class.
        true_class_prob = torch.exp(-per_sample_ce)
        focal = self.alpha * (1 - true_class_prob) ** self.gamma * per_sample_ce

        return torch.mean(focal) if self.reduce else focal

In [None]:
# Regression-style losses.
l1loss = nn.L1Loss()
l2loss = nn.MSELoss()
smoothl1loss = nn.SmoothL1Loss()

# Classification losses. bceloss applies the sigmoid itself, so it is fed raw
# discriminator outputs.
bceloss = nn.BCEWithLogitsLoss() #nn.BCELoss()
celoss = nn.CrossEntropyLoss()
focalloss = FocalLoss()

<a id="cost_of_generator"></a>
## **Cost Function of Generator**

------------------------------------------------

**$$ \mathbf{Loss_G(x, y) = \sum_{i=1}^{t}\gamma ^ {i-1}\times \left \{ \lambda _1 \times  E_{x,y_i}\left [ \left \| y_i - G^i(x) \right \|_1 \right ] + E_{x}\left [ log(1-D(G^i(x))) \right ] \right \} } $$**

$t$ is Time Step. $\gamma$ is discount factor(GAMMA). $\lambda _1$ is L1Lambda.

In [None]:
def generator_error(netG, netD, sketch, real, real_label, fake_label, gamma=0.0, b_size=batch_size):
    """Compute the discounted multi-step generator loss.

    The generator is applied ``TIME_STEP`` times. After every step the newest
    prediction (detached) is appended to the input window and, when more than
    one input frame is used, the oldest frame is dropped. The loss of step
    ``i`` (0-based) is ``L1lambda * L1(prediction, real[:, i])`` plus the
    adversarial BCE term, weighted by ``gamma ** i``.

    Args:
        netG: Generator network.
        netD: Discriminator network; called on ``sketch`` concatenated with
            the prediction along the channel axis.
        sketch: Input frames, indexed as (batch, frame, height, width).
        real: Target frames; ``real[:, i]`` is the target of step ``i``.
        real_label: Unused; kept for interface compatibility.
        fake_label: Unused; kept for interface compatibility.
        gamma: Discount factor applied to later steps.
        b_size: Unused; kept for interface compatibility.

    Returns:
        The accumulated loss tensor, or None when ``TIME_STEP`` is 0.
    """
    def G_error(iter_input, G_output, target, D_output):
        log(iter_input.shape)
        log(G_output.shape)
        log(target.shape)
        d_logits = D_output[0].view(-1)
        # The generator wants the discriminator to answer "real".
        adversarial = bceloss(d_logits, torch.full_like(d_logits, REAL_LABEL))
        return l1loss(G_output, target) * L1lambda + adversarial

    next_input = sketch
    error = None

    for ind in range(TIME_STEP):
        target = real[:, ind, :, :, :]
        iter_input = next_input[:, -1, :, :].unsqueeze(1).clone().detach()
        G_output = netG(next_input)

        # Gradients are not propagated through earlier time steps.
        prediction = G_output.clone().detach()
        if INPUT_NUM > 1:
            # Slide the window over the *previous* input so that earlier
            # predictions stay in it. The original rebuilt the window from
            # `sketch` at every step, which is inconsistent with the rollout
            # done by model_predict().
            next_input = torch.cat((next_input[:, 1:, :, :].clone().detach(), prediction), dim=1)
        else:
            next_input = prediction

        D_output = netD(torch.cat([sketch, G_output], dim=1))

        step_error = G_error(iter_input, G_output, target, D_output)
        if error is None:
            error = step_error
        else:
            error = error + (gamma ** ind) * step_error

        del G_output, D_output
        gc.collect()
        torch.cuda.empty_cache()

    return error

<a id="cost_of_discriminator"></a>
## **Cost Function of Discriminator**

------------------------------------------------

**$$ \mathbf{Loss_D(x, y) = E_x\left [ log(1-D(G(x))) \right ] + \frac{1}{t} \sum_{i=1}^{t} E_{y_i}\left [ log D(y_i) \right ]} $$**

$t$ is Time Step.

In [None]:
def discriminator_error_only_d(netD, sketch, real, real_label, fake_label, b_size=batch_size, avg=True):
    """Compute the discriminator loss on real samples only.

    For each of the ``TIME_STEP`` target frames the discriminator is called
    on ``sketch`` concatenated with that frame along the channel axis.

    Args:
        netD: Discriminator network; returns ``(real_fake_output, frame_logits)``.
        sketch: Input frames, indexed as (batch, frame, height, width).
        real: Target frames; ``real[:, i]`` is the frame of step ``i``.
        real_label: Per-sample real targets, used only when ``avg`` is False.
        fake_label: Unused; kept for interface compatibility.
        b_size: Batch size, used to build class labels when ``avg`` is False.
        avg: When True, the mean squared error towards ``REAL_LABEL`` is
            averaged over the time steps. Otherwise the focal loss of the
            frame head plus a BCE term is summed over the time steps.

    Returns:
        The accumulated loss (the float 0.0 when ``TIME_STEP`` is 0).
    """
    errD = 0.0
    for i in range(0, TIME_STEP):
        outputs_real = netD(torch.cat([sketch, real[:, i, :, :, :]], dim=1))
        real_scores = outputs_real[0].view(-1)

        if avg:
            # Build the target on the same device as the output. The original
            # moved the *loss* to the device instead of the target, which fails
            # with a device mismatch when running on the GPU.
            target = torch.full_like(real_scores, REAL_LABEL)
            errD += l2loss(real_scores, target) / TIME_STEP
        else:
            class_label = torch.full((b_size,), i, dtype=torch.long, device=device)
            errD += focalloss(outputs_real[1], class_label) + bceloss(real_scores, real_label)

        del outputs_real, real_scores
        gc.collect()
        torch.cuda.empty_cache()

    return errD

In [None]:
def discriminator_error_in_gan(netG, netD, sketch, real, real_label, fake_label, b_size=batch_size, avg=True):
    """Compute the discriminator loss on one generated and all real frames.

    The fake term is the BCE between the discriminator output for the
    generator's one-step prediction and ``FAKE_LABEL``. The real term is
    computed for each of the ``TIME_STEP`` target frames.

    Args:
        netG: Generator network; only used to produce the fake frame.
        netD: Discriminator network; returns ``(real_fake_output, frame_logits)``.
        sketch: Input frames, indexed as (batch, frame, height, width).
        real: Target frames; ``real[:, i]`` is the frame of step ``i``.
        real_label: Per-sample real targets, used only when ``avg`` is False.
        fake_label: Unused; kept for interface compatibility.
        b_size: Batch size, used to build class labels when ``avg`` is False.
        avg: When True, the real term is the mean squared error towards
            ``REAL_LABEL`` averaged over the time steps. Otherwise it is the
            focal loss of the frame head plus a BCE term, summed over the
            time steps.

    Returns:
        The total discriminator loss tensor.
    """
    # The generator is not trained here, so skip building its autograd graph.
    with torch.no_grad():
        output_g = netG(sketch)
    outputs_fake = netD(torch.cat([sketch, output_g.detach()], dim=1))
    log(outputs_fake[0].shape)
    errD = bceloss(outputs_fake[0], torch.full_like(outputs_fake[0], FAKE_LABEL))

    del output_g, outputs_fake
    gc.collect()
    torch.cuda.empty_cache()

    for i in range(0, TIME_STEP):
        outputs_real = netD(torch.cat([sketch, real[:, i, :, :, :]], dim=1))
        real_scores = outputs_real[0].view(-1)

        if avg:
            errD += l2loss(real_scores, torch.full_like(real_scores, REAL_LABEL)) / TIME_STEP
        else:
            class_label = torch.full((b_size,), i, dtype=torch.long, device=device)
            errD += focalloss(outputs_real[1], class_label) + bceloss(real_scores, real_label)

        del outputs_real, real_scores
        gc.collect()
        torch.cuda.empty_cache()

    return errD

# Apply Checkpoint

In [None]:
def apply_checkpoint(use_checkpoint=True):
    """Restore models, optimizers and loss histories from a previous run.

    The checkpoint is applied only when ``use_checkpoint`` is true and the
    directory ``OLD_PATH`` exists; otherwise training starts from scratch.
    On restore, ``start_epoch`` is set from the checkpoint and the
    discriminator-only phase is disabled.

    Args:
        use_checkpoint: Whether to try to load the checkpoint.
    """
    # G_loss / D_loss must be declared global; the original declared
    # `G_Loss, D_Loss`, so the loaded histories were bound to locals and lost.
    global Generator, Discriminator, optimizerG, optimizerD, G_loss, D_loss, start_epoch, all_epochs_only_d

    if os.path.isdir(OLD_PATH) and use_checkpoint:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        checkpoint = torch.load(OLD_GENERATOR_MODEL, map_location=device)
        start_epoch = checkpoint['epoch']
        Generator.load_state_dict(checkpoint['model_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizer_state_dict'])

        checkpoint = torch.load(OLD_DISCRIMINATOR_MODEL, map_location=device)
        start_epoch = checkpoint['epoch']
        Discriminator.load_state_dict(checkpoint['model_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizer_state_dict'])

        # NOTE: pickle can execute arbitrary code while loading; only use
        # loss files that come from a trusted source.
        with open(OLD_G_LOSS, 'rb') as f:
            G_loss = pickle.load(f)

        with open(OLD_D_LOSS, 'rb') as f:
            D_loss = pickle.load(f)

        all_epochs_only_d = 0

        print('Continue training. (Epoch : {0})'.format(start_epoch))
    else:
        print('Begin training newly.')

# Define Train Function

In [None]:
# Epoch currently being trained in the discriminator-only phase and in the
# GAN phase; updated by fit_only_d() and fit_gan().
nowepoch_only_d = 0
nowepoch = 0

In [None]:
# Only Discriminator Train

def fit_only_d(device, num_epochs_only_d=1):
    """Run the discriminator-only training phase.

    Epochs are numbered starting at ``start_epoch_only_d + 1``. Training
    stops early when an epoch reports a falsy result.

    Args:
        device: Device the batches are moved to.
        num_epochs_only_d: Number of epochs to train.
    """
    global nowepoch_only_d

    first_epoch = start_epoch_only_d + 1
    for epoch in range(first_epoch, first_epoch + num_epochs_only_d):
        nowepoch_only_d = epoch
        print("< EPOCH{0} >".format(epoch))
        finished_ok = train_one_epoch_only_d(
            device, train_dataloader, Discriminator, optimizerD, epoch, num_epochs_only_d
        )
        if not finished_ok:
            return
        
def train_one_epoch_only_d(device, dataloader, netD, optimizerD, epoch, num_epochs, iters=0):
    """Train the discriminator for one epoch on real samples only.

    Only every ``only_d_train_step``-th batch of ``dataloader`` is used.
    Each loss value is appended to the global ``D_loss`` list.

    Args:
        device: Device the batches are moved to.
        dataloader: Yields ``(sketch, real)`` batches.
        netD: Discriminator network.
        optimizerD: Optimizer of the discriminator.
        epoch: Current epoch number (used for logging only).
        num_epochs: Total number of epochs (used for logging only).
        iters: Starting value of the local iteration counter.

    Returns:
        Always True.
    """
    global nowpath, strange_error_num, strange_error_limit
    # Anomaly detection reports the forward operation that produced a failing backward call.
    with torch.autograd.set_detect_anomaly(True):
        for i, data in enumerate(dataloader):   
            # Skip all batches except every only_d_train_step-th one.
            if i%only_d_train_step != 0 :
                continue
            start = time.time()
            sketch, real = data
            sketch, real = sketch.to(device), real.to(device)
            
            # Concatenate the INPUT_NUM input frames along dim 1 so that the
            # frames become the channels of a single 4-D tensor.
            sketch_list = []
            for ind in range(0, INPUT_NUM):
                sketch_list.append(sketch[:, ind, :, :, :])
            sketch = torch.cat(sketch_list, dim=1)
            
            # The last batch may be smaller than batch_size, so read the size from the data.
            b_size = sketch.size(0)
            real_label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
            fake_label = torch.full((b_size,), FAKE_LABEL, dtype=torch.float, device=device)

            netD.train()
            netD.zero_grad()
            
            errD = discriminator_error_only_d(netD, sketch, real, real_label, fake_label, b_size=b_size)
            
            log('Complete calcuating of Discriminator')
            errD.backward()
            log('Complete backprogration of Discriminator')
            optimizerD.step()
            log('Complete stepping OptimizerD')

            
            # Drop references to the batch and free cached GPU memory.
            del b_size, real_label, fake_label, sketch, real
            gc.collect()
            torch.cuda.empty_cache()

            #Log
            if i % 1 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tTime: %.6f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errD.item(), time.time() - start))
                
            D_loss.append(errD.item())
            
            del errD
            gc.collect()
            torch.cuda.empty_cache()

            iters += 1
    return True

In [None]:
# GAN Train
def fit_gan(device, num_epochs=1):
    """Run the joint generator/discriminator training phase.

    Epochs are numbered starting at ``start_epoch + 1``. The iteration
    counter is carried over from one epoch to the next.

    Args:
        device: Device the batches are moved to.
        num_epochs: Number of epochs to train.
    """
    global nowepoch

    completed_iters = 0
    first_epoch = start_epoch + 1
    for epoch in range(first_epoch, first_epoch + num_epochs):
        nowepoch = epoch
        print("< EPOCH{0} >".format(epoch))
        completed_iters = train_one_epoch_gan(
            device,
            train_dataloader,
            Generator,
            Discriminator,
            optimizerG,
            optimizerD,
            epoch,
            num_epochs,
            iters=completed_iters,
        )
    

def train_one_epoch_gan(device, dataloader, netG, netD, optimizerG, optimizerD, epoch, num_epochs, iters=0):
    """Train the discriminator and the generator for one epoch.

    For every batch the discriminator is updated first, then the generator.
    The global discount factor ``GAMMA`` is updated once per batch, and the
    loss values are appended to the global ``G_loss`` and ``D_loss`` lists.

    Args:
        device: Device the batches are moved to.
        dataloader: Yields ``(sketch, real)`` batches.
        netG: Generator network.
        netD: Discriminator network.
        optimizerG: Optimizer of the generator.
        optimizerD: Optimizer of the discriminator.
        epoch: Current epoch number (used for logging only).
        num_epochs: Total number of epochs (used for logging only).
        iters: Number of iterations completed before this epoch.

    Returns:
        The iteration counter after this epoch.
    """
    global nowpath, strange_error_num, strange_error_limit, replay_memory, GAMMA
    # Anomaly detection reports the forward operation that produced a failing backward call.
    with torch.autograd.set_detect_anomaly(True):
        for i, data in enumerate(dataloader):   
            start = time.time()
            sketch, real = data
            
            sketch, real = sketch.to(device), real.to(device)
            # Concatenate the INPUT_NUM input frames along dim 1 so that the
            # frames become the channels of a single 4-D tensor.
            sketch_list = []
            for ind in range(0, INPUT_NUM):
                sketch_list.append(sketch[:, ind, :, :, :])
            sketch = torch.cat(sketch_list, dim=1)           
            
            
            # The last batch may be smaller than batch_size, so read the size from the data.
            b_size = sketch.size(0)
            real_label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
            fake_label = torch.full((b_size,), FAKE_LABEL, dtype=torch.float, device=device)
            
            #Train Discriminator
            netG.eval()
            netD.train()
            netD.zero_grad()
            
            errD = discriminator_error_in_gan(netG, netD, sketch, real, real_label, fake_label, b_size=b_size)
            
            log('Complete calcuating of Discriminator')
            errD.backward()
            log('Complete backprogration of Discriminator')
            optimizerD.step()
            log('Complete stepping OptimizerD')
        
            #Train Generator
            netG.train()
            netD.eval()
            netG.zero_grad()
            
            # The loss uses the current GAMMA; the schedule is advanced afterwards.
            errG = generator_error(netG, netD, sketch, real, real_label, fake_label, gamma=GAMMA, b_size=b_size)
            GAMMA = gamma_updater(GAMMA, iters)
 
            log('Complete calcuating of Generator')
            errG.backward()
            log('Complete backprogration of Genereator')
            optimizerG.step()
            log('Complete stepping OptimizerG')
            
            # Drop references to the batch and free cached GPU memory.
            del b_size, real_label, fake_label, sketch, real
            gc.collect()
            torch.cuda.empty_cache()

            #Log
            if i % 1 == 0:
                print('[%d/%d][%d/%d]    Loss_G: %.4f  Loss_D: %.4f  Gamma: %f  Time: %.6f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errG.item(), errD.item(), GAMMA, time.time() - start))
                
            

            G_loss.append(errG.item())
            D_loss.append(errD.item())
            
            del errG, errD
            gc.collect()
            torch.cuda.empty_cache()

            iters += 1
    return iters

# Train

In [None]:
def train_main():
    """Run the whole training schedule: discriminator-only phase, then GAN phase.

    Resets the module-level ``replay_memory`` and ``GAMMA`` globals, calls
    ``apply_checkpoint`` with the ``USE_CHECKPOINT`` flag, and then runs:

    1. ``fit_only_d`` for ``all_epochs_only_d`` epochs (skipped when that
       count is not positive), with the discriminator in train mode.
    2. ``fit_gan`` for ``all_epochs`` epochs (skipped when that count is not
       positive), with both networks in train mode.

    Both networks are left in eval mode when the function returns.
    """
    global replay_memory, GAMMA
    apply_checkpoint(use_checkpoint=USE_CHECKPOINT)
    replay_memory = []
    GAMMA = START_GAMMA

    # Phase 1: only the discriminator is trained.
    # The return value of .train()/.eval() is intentionally ignored; the
    # calls are made purely for their mode-switching side effect.
    Discriminator.train()

    if all_epochs_only_d > 0:
        print('-'*20)
        print('Train Only D')
        fit_only_d(device, num_epochs_only_d=all_epochs_only_d)
        print('-'*20)

    Discriminator.eval()

    # Phase 2: generator and discriminator are trained together.
    Generator.train()
    Discriminator.train()

    if all_epochs > 0:
        print('-'*20)
        print('Train G and D')
        fit_gan(device, num_epochs=all_epochs)
        print('-'*20)

    # Leave both networks in inference mode for the evaluation cells.
    Generator.eval()
    Discriminator.eval()

In [None]:
# Run the full training schedule (discriminator-only phase, then G and D together).
train_main()

# Test

1. Calculate the SSIM score for each time step.
2. Generate predicted images for the test set.
3. Generate a video that consists of the series of predicted images.

In [None]:
# Plot the generator loss values recorded in G_loss, in recording order.
plt.figure(figsize=(10,5))
plt.title('Loss of Generator')
# The curve needs a non-empty label: plt.legend() ignores empty labels, which
# would leave the legend empty. "train" matches the discriminator loss plot.
plt.plot(G_loss, label="train")
plt.xlabel("Iter")
plt.legend()
plt.show()

In [None]:
# Plot the discriminator loss values recorded in D_loss, in recording order.
d_loss_fig, d_loss_ax = plt.subplots(figsize=(10, 5))
d_loss_ax.set_title('Loss of Discriminator')
d_loss_ax.plot(D_loss, label="train")
d_loss_ax.set_xlabel("Iter")
d_loss_ax.legend()
plt.show()

In [None]:
def model_predict(model, time, input):
    """Roll the model forward autoregressively to predict `time` minutes ahead.

    One forward pass advances the prediction by one 15-minute step, so the
    model is applied ``time // 15`` times and each output is fed back as
    (part of) the next input.

    Args:
        model: Network to evaluate; it is switched to eval mode first.
        time: Forecast horizon in minutes. Must be a positive multiple of 15.
        input: Initial input tensor. When ``INPUT_NUM > 1`` the first entry
            along dim 1 is dropped and the new prediction is appended along
            dim 1 (assumes the model emits exactly one such entry per step
            -- TODO confirm against the generator definition).

    Returns:
        A detached copy of the prediction produced by the last step.

    Raises:
        ValueError: If `time` is zero, negative, or not a multiple of 15.
    """
    # Guard clause: negative multiples previously fell through and returned
    # None without predicting anything; they are rejected like 0 now.
    if time <= 0 or time % 15 != 0:
        raise ValueError('Please set the time to a multiple of 15.')

    model.eval()

    final_answer = None
    next_input = input
    # Pure inference: no autograd graph is needed for any of the passes.
    with torch.no_grad():
        for _ in range(time // 15):
            final_answer = model(next_input).detach().clone()
            if INPUT_NUM > 1:
                # Slide the window: drop the oldest entry, append the prediction.
                next_input = torch.cat((next_input[:, 1:, :, :], final_answer), dim=1)
            else:
                # Reuse the prediction instead of running the model a second time.
                next_input = final_answer
    return final_answer

In [None]:
# NOTE: this rebinds `SSIM` to the IQA_pytorch implementation. The notebook's
# import cell also imports an `SSIM` from pytorch_msssim, which is shadowed
# from this cell onwards.
from IQA_pytorch import SSIM, utils

# Tensor -> PIL image converter; its output is passed to utils.prepare_image below.
toPILImage = transforms.ToPILImage()
# SSIM metric instance configured for single-channel images.
ssim_model = SSIM(channels=1)

def one_time_step_ssim_score(dataloader, model, time_step, num=-1):
    """Print and return the mean SSIM of predictions `time_step` steps ahead.

    For every sample the model is rolled forward ``time_step`` steps with
    ``model_predict`` (15 minutes per step) and the result is compared with
    ``y[time_step - 1]`` using the module-level ``ssim_model``.

    Args:
        dataloader: Iterable yielding ``(x, y)`` pairs. Dim 0 of both is
            squeezed, so a batch size of 1 is presumably expected.
        model: Network handed to ``model_predict``; put into eval mode here.
        time_step: 1-based number of 15-minute steps to predict ahead.
        num: Maximum number of samples to score; ``-1`` scores every sample.

    Returns:
        The mean SSIM score over the scored samples.

    Raises:
        ZeroDivisionError: If `dataloader` yields no samples.
    """
    model.eval()
    score = 0
    total = 0
    # Evaluation only, so no autograd bookkeeping is needed.
    with torch.no_grad():
        # Iterate the loader that was passed in (the argument used to be
        # ignored in favour of a global test loader).
        for ind, (x, y) in enumerate(dataloader):
            x, y = x.squeeze(0).to(device), y.squeeze(0).to(device)
            # Join the first INPUT_NUM entries of x along dim 0.
            x = torch.cat([x[i, :, :, :] for i in range(INPUT_NUM)], dim=0)
            outputG = model_predict(model, time_step*15, x.unsqueeze(0))

            sketch = utils.prepare_image(toPILImage(outputG.squeeze(0))).to(device)
            real = utils.prepare_image(toPILImage(y[time_step-1])).to(device)

            score += ssim_model(sketch, real, as_loss=False).item()
            total += 1

            # Free per-sample tensors eagerly to keep memory usage flat.
            del x, y, outputG, sketch, real
            gc.collect()
            torch.cuda.empty_cache()

            if num != -1 and ind+1 >= num:
                break

    mean_score = score/total
    print("SSIM Score of the prediction {0} minutes later : {1}".format(time_step*15, mean_score))
    return mean_score

# Report the mean SSIM for every horizon from 1 to TEST_TIME_STEP steps
# (15 minutes per step), scoring at most 2000 samples per horizon.
for ind in range(1, TEST_TIME_STEP+1):
    one_time_step_ssim_score(test_dataloader_bs1_shuffle, Generator, ind, num=2000)

> **< SSIM Score of OracleGAN >**  *(check "[OracleGAN - Pix2Pix for Time Series Image](https://www.kaggle.com/lapl04/oraclegan-pix2pix-for-time-series-image)")*
>
> |Prediction|SSIM Score|
> |-------------------|----------------|
> |prediction 15 minutes later|0.8025623009409756|
> |prediction 30 minutes later|0.7757316670715809|
> |prediction 45 minutes later|0.7566662636697292|
> |prediction 60 minutes later|0.7422156925499439|
> |prediction 75 minutes later|0.7312239380329847|
> |prediction 90 minutes later|0.720590250596404|

> **< SSIM Score of normal Pix2Pix >**  *(check "[Pix2Pix (Compared to OracleNet)](https://www.kaggle.com/lapl04/pix2pix-compared-to-oraclenet)")*
>
> |Prediction|SSIM Score|
> |-------------------|----------------|
> |prediction 15 minutes later|0.8199273004531861|
> |prediction 30 minutes later|0.7006081487536431|
> |prediction 45 minutes later|0.6030721757411956|
> |prediction 60 minutes later|0.5150686911344529|

In [None]:
import zipfile

# "No series" check: every prediction is made directly from the inputs the
# loader provides; predictions are never fed back into the generator.
# Predicted and real images are written as PNGs, zipped, then deleted.

y_nums = 40       # the loop stops once more than this many samples were written
start_ind = 200   # number of leading batches to skip

ai_noseries_ls = []
real_noseries_ls = []

# Counts processed samples (previously named `iter`, which shadowed the builtin).
num_written = 0

# Inference only: avoid building autograd graphs for the generator passes.
with torch.no_grad():
    for ind, (x, y) in enumerate(test_dataloader_bs1_shuffle):
        if ind < start_ind:
            continue

        num_written += 1

        x, y = x.to(device), y[0].to(device)

        # Join the first INPUT_NUM entries of dim 1 along dim 1.
        x = torch.cat([x[:, i, :, :, :] for i in range(INPUT_NUM)], dim=1)

        outputg = Generator(x).to(cpu_device)

        # Presumably maps values from [-1, 1] to [0, 255] -- TODO confirm.
        outputg = outputg*127.5+127.5
        realimage = y*127.5+127.5

        ai_path = './AI_NOSERIES_Answer{0}.png'.format(ind+1)
        real_path = './Real_NOSERIES{0}.png'.format(ind+1)

        cv2.imwrite(ai_path, torch_tensor_to_plt(outputg)*30)
        cv2.imwrite(real_path, torch_tensor_to_plt(realimage.to(cpu_device))*30)

        ai_noseries_ls.append(ai_path)
        real_noseries_ls.append(real_path)

        # NOTE(review): the counter is incremented before this check, so
        # y_nums + 1 samples are written -- confirm whether that is intended.
        if num_written > y_nums:
            break

# The context manager closes each archive; no explicit close() is needed.
with zipfile.ZipFile("ai_noseries.zip", 'w') as my_zip:
    for png_path in ai_noseries_ls:
        my_zip.write(png_path)

with zipfile.ZipFile("real_noseries.zip", 'w') as my_zip:
    for png_path in real_noseries_ls:
        my_zip.write(png_path)

# The PNGs are only intermediate files; the zip archives are the output.
for png_path in (ai_noseries_ls + real_noseries_ls):
    os.remove(png_path)

print('NOSERIES Images are generated.')

In [None]:
import zipfile

# "Series" check: the generator is seeded once with inputs from the loader
# and afterwards fed its own previous outputs. Predicted and real images are
# written as PNGs and zipped; the PNGs are kept because the video cell below
# reads them.

y_nums = 40       # the loop stops once more than this many samples were written
start_ind = 200   # number of leading batches to skip

ai_series_ls = []
real_series_ls = []

next_input = None

# Counts processed samples (previously named `iter`, which shadowed the builtin).
num_written = 0

# Inference only: avoid building autograd graphs for the generator passes.
with torch.no_grad():
    for ind, (x, y) in enumerate(test_dataloader_bs1_noshuffle):
        if ind < start_ind:
            continue

        num_written += 1

        y = y[0].to(device)

        if ind == start_ind:
            # Seed the rollout: the loader's input is only needed for the
            # first processed sample, so it is only assembled here.
            x = x.to(device)
            next_input = torch.cat([x[:, i, :, :, :] for i in range(INPUT_NUM)], dim=1)

        outputg_series = Generator(next_input).to(cpu_device)
        if INPUT_NUM > 1:
            # Slide the window: drop the first entry of dim 1, append the prediction.
            next_input = torch.cat((next_input[:, 1:, :, :], outputg_series.to(device)), dim=1)
        else:
            next_input = outputg_series.clone().detach().to(device)

        # Presumably maps values from [-1, 1] to [0, 255] -- TODO confirm.
        outputg_series = outputg_series * 127.5 + 127.5
        realimage = y*127.5+127.5

        ai_path = './AI_SERIES_Answer{0}.png'.format(ind+1)
        real_path = './Real_SERIES{0}.png'.format(ind+1)

        cv2.imwrite(ai_path, torch_tensor_to_plt(outputg_series)*30)
        cv2.imwrite(real_path, torch_tensor_to_plt(realimage.to(cpu_device))*30)

        ai_series_ls.append(ai_path)
        real_series_ls.append(real_path)

        # NOTE(review): the counter is incremented before this check, so
        # y_nums + 1 samples are written -- confirm whether that is intended.
        if num_written > y_nums:
            break

# The context manager closes each archive; no explicit close() is needed.
with zipfile.ZipFile("ai_series.zip", 'w') as my_zip:
    for png_path in ai_series_ls:
        my_zip.write(png_path)

with zipfile.ZipFile("real_series.zip", 'w') as my_zip:
    for png_path in real_series_ls:
        my_zip.write(png_path)

print('SERIES Images are generated')

In [None]:
# Encode the saved PNG frames into two 3-fps videos: one with the generator's
# predictions and one with the real images.
#
# 'mp4v' is the FourCC that matches the .mp4 container. With 'DIVX' OpenCV's
# FFmpeg backend typically warns that the tag is unsupported for MP4 and
# falls back to 'mp4v' anyway.
# NOTE(review): the frame size is fixed at 128x128 -- confirm that the saved
# PNGs have this size, otherwise the writer will not store them correctly.
video_fourcc = cv2.VideoWriter_fourcc(*'mp4v')

for video_path, frame_paths in (('oraclegan_series.mp4', ai_series_ls),
                                ('real_series.mp4', real_series_ls)):
    writer = cv2.VideoWriter(video_path, video_fourcc, 3, (128, 128))
    try:
        for name in frame_paths:
            img = cv2.imread(name)
            writer.write(img)
    finally:
        # Always finalize the file, even if a frame fails to be written.
        writer.release()

print('Videos are generated')
print('video path : "./oraclegan_series.mp4" and "./real_series.mp4"')

In [None]:
# Delete every predicted and real series PNG listed by the image-generation cell.
for frame_path in [*ai_series_ls, *real_series_ls]:
    os.remove(frame_path)

# Save Checkpoint

In [None]:
# Save one checkpoint file per network. Each holds the two epoch counters,
# the network's weights and the state of its optimizer.
for net, optimizer, checkpoint_path in (
    (Generator, optimizerG, 'Generator.pth'),
    (Discriminator, optimizerD, 'Discriminator.pth'),
):
    checkpoint = {
        'epoch': nowepoch,
        'epoch_only_d': nowepoch_only_d,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)

In [None]:
# Dump both loss histories with pickle. The files contain binary pickle data
# even though they carry a .txt suffix.
for loss_path, loss_history in (('./gloss.txt', G_loss), ('./dloss.txt', D_loss)):
    with open(loss_path, 'wb') as f:
        pickle.dump(loss_history, f)

# **Conclusion**
------------------
**The time required for one training iteration with FastOracleGAN (about 3.36s) is less than half of that required by OracleGAN (about 7.5s).**