In [1]:
%matplotlib inline
import os
import numpy as np
import torch
from torch import nn
import my_nntools as nt
from torch.nn import functional as F
import torch.utils.data as td
import torchvision as tv
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt

import itertools

import models as models
from utils import *

In [2]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Root folder of the dataset, relative to the notebook's working directory.
dataset_root_dir="./datasets/apple2orange/"

In [4]:
# Build the three dataset splits with identical settings; only `mode` differs.
# Construction order (train, val, test) matches the tuple below.
train_set, val_set, test_set = (
    ImageDataset(dataset_root_dir, image_size=256, unaligned=True, mode=split)
    for split in ('train', 'val', 'test')
)

In [5]:
class NNClassifier(nt.NeuralNetwork):
    """Base network that bundles the loss functions used by the model.

    It owns one cross-entropy, one mean-squared-error and one L1 loss module
    and exposes them through the ``criterion*`` methods. Every method takes a
    prediction ``y`` and a target ``d`` and returns the (possibly scaled) loss.
    """

    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.MSE_Loss = nn.MSELoss()
        self.L1_Loss = nn.L1Loss()

    def criterion(self, y, d):
        """Return the cross-entropy loss between ``y`` and ``d``."""
        return self.cross_entropy(y, d)

    def criterion_GAN(self, y, d):
        """Return the mean-squared-error loss between ``y`` and ``d``."""
        return self.MSE_Loss(y, d)

    def criterion_cycle(self, y, d):
        """Return the L1 loss between ``y`` and ``d`` scaled by 5."""
        # NOTE(review): CycleGAN setups commonly weight the cycle term by 10
        # and the identity term by 5; here the factors are the other way
        # round. Confirm this is intentional.
        cycle_weight = 5
        return self.L1_Loss(y, d) * cycle_weight

    def criterion_identity(self, y, d):
        """Return the L1 loss between ``y`` and ``d`` scaled by 10."""
        identity_weight = 10
        return self.L1_Loss(y, d) * identity_weight
    
class C_GAN(NNClassifier):
    """Two generators and two discriminators for A<->B image translation.

    Attributes:
        G_A2B, G_B2A: generators built with ``models.Generator(3, 3)``
            (presumably 3 input and 3 output channels -- confirm in models).
        D_A, D_B: discriminators built with ``models.Discriminator(3)``.
        fake_a_buffer, fake_b_buffer: ``ReplayBuffer`` instances; they are
            only created here and not used by this class itself.

    Args:
        fine_tuning: accepted for interface compatibility; not used here.
    """

    def __init__(self, fine_tuning=True):
        super(C_GAN, self).__init__()

        self.G_A2B = models.Generator(3, 3)
        self.G_B2A = models.Generator(3, 3)
        self.D_A = models.Discriminator(3)
        self.D_B = models.Discriminator(3)

        # Apply the custom initialisation to every sub-network. Looping over
        # all four avoids the earlier slip where G_A2B was initialised twice
        # and G_B2A was left with its default parameters.
        for sub_net in (self.G_A2B, self.G_B2A, self.D_A, self.D_B):
            sub_net.apply(init_parameters)

        self.fake_a_buffer = ReplayBuffer()
        self.fake_b_buffer = ReplayBuffer()

    def forward(self, real_a, real_b):
        """Translate ``real_a`` to domain B and ``real_b`` to domain A.

        Each generator output ``x`` is rescaled with ``0.5 * (x + 1.0)``.
        This maps [-1, 1] onto [0, 1]; it assumes the generators emit values
        in [-1, 1] (TODO confirm against models.Generator).

        Returns:
            Tuple ``(fake_a, fake_b)`` -- note the order: the A-domain fake
            (made from ``real_b``) comes first.
        """
        fake_b = 0.5*(self.G_A2B(real_a) + 1.0)
        fake_a = 0.5*(self.G_B2A(real_b) + 1.0)

        return fake_a,fake_b

In [6]:
class ClassificationStatsManager(nt.StatsManager):
    """Keeps running sums of the generator and both discriminator losses."""

    def __init__(self):
        super().__init__()

    def init(self):
        """Reset the base-class state and zero the three running sums."""
        super().init()
        self.running_loss_G = 0
        self.running_loss_D_A = 0
        self.running_loss_D_B = 0

    def accumulate(self, loss_G, loss_D_A, loss_D_B):
        """Forward the losses to the base class, then add them to the sums."""
        super().accumulate(loss_G, loss_D_A, loss_D_B)
        self.running_loss_G += loss_G
        self.running_loss_D_A += loss_D_A
        self.running_loss_D_B += loss_D_B

    def summarize(self):
        """Return each running sum divided by ``self.number_update``.

        ``number_update`` is not set in this class; presumably the base
        class maintains it (verify in my_nntools).
        """
        updates = self.number_update
        return {
            'G loss': self.running_loss_G / updates,
            'D_A loss': self.running_loss_D_A / updates,
            'D_B loss': self.running_loss_D_B / updates,
        }

In [8]:
# Learning rate shared by all three optimizers.
# NOTE(review): the CycleGAN paper trains Adam at lr=2e-4 (its reference
# code also sets beta1=0.5); confirm 1e-3 with default betas is intended.
lr = 1e-3
net = C_GAN().to(device)

# One optimizer drives both generators jointly; each discriminator gets its own.
optimizer_G = torch.optim.Adam(
    itertools.chain(net.G_A2B.parameters(), net.G_B2A.parameters()),
    lr=lr,
)
optimizer_D_A, optimizer_D_B = (
    torch.optim.Adam(disc.parameters(), lr=lr) for disc in (net.D_A, net.D_B)
)

stats_manager = ClassificationStatsManager()
exp1 = nt.Experiment(
    net, train_set, val_set,
    optimizer_G, optimizer_D_A, optimizer_D_B,
    stats_manager,
    output_dir="CGAN",
    batch_size=1,
)

In [9]:
# Train with num_epochs=50. In the recorded output each epoch took roughly
# 365 s, so a full run is expensive to repeat.
exp1.run(num_epochs=50)

Start/Continue training from epoch 0


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1 [G loss: 9.5715, D_A loss: 0.2891, D_B loss: 0.3334] (Time: 368.74s)
Epoch 2 [G loss: 7.7709, D_A loss: 0.3028, D_B loss: 0.2818] (Time: 367.19s)
Epoch 3 [G loss: 7.1713, D_A loss: 0.3072, D_B loss: 0.2670] (Time: 370.86s)
Epoch 4 [G loss: 6.9971, D_A loss: 0.3026, D_B loss: 0.2375] (Time: 369.25s)
Epoch 5 [G loss: 6.4614, D_A loss: 0.2826, D_B loss: 0.2356] (Time: 364.44s)
Epoch 6 [G loss: 6.0706, D_A loss: 0.3054, D_B loss: 0.2256] (Time: 366.72s)
Epoch 7 [G loss: 5.9425, D_A loss: 0.3008, D_B loss: 0.1941] (Time: 365.77s)
Epoch 8 [G loss: 6.0823, D_A loss: 0.2816, D_B loss: 0.1770] (Time: 372.72s)
Epoch 9 [G loss: 5.8023, D_A loss: 0.2792, D_B loss: 0.1753] (Time: 366.25s)
Epoch 10 [G loss: 5.4601, D_A loss: 0.2620, D_B loss: 0.2011] (Time: 361.82s)
Epoch 11 [G loss: 5.4797, D_A loss: 0.2617, D_B loss: 0.2041] (Time: 368.84s)
Epoch 12 [G loss: 5.4490, D_A loss: 0.2338, D_B loss: 0.1852] (Time: 364.47s)
Epoch 13 [G loss: 5.3923, D_A loss: 0.2081, D_B loss: 0.1804] (Time: 365.

In [10]:
def GetResults(net, test_set, output_dir='output'):
    """Run ``net`` over ``test_set`` and save the generated images as PNGs.

    Each item from the loader is unpacked as ``(real_a, real_b)``, moved to
    ``net.device`` and passed through ``net``. The resulting ``fake_a`` is
    written to ``<output_dir>/A/NNNN.png`` and ``fake_b`` to
    ``<output_dir>/B/NNNN.png``, numbered from 0001. One progress line is
    printed per item.

    Args:
        net: model called as ``net(real_a, real_b)``, returning
            ``(fake_a, fake_b)``. It is switched to eval mode and left there.
        test_set: dataset yielding ``(real_a, real_b)`` pairs.
        output_dir: root folder for the ``A`` and ``B`` sub-folders. The
            default ``'output'`` matches the previously hard-coded location.
    """
    test_loader = td.DataLoader(test_set, batch_size=1, shuffle=False, drop_last=True, pin_memory=True)
    net.eval()

    # exist_ok=True creates missing folders without a separate existence check.
    dir_a = os.path.join(output_dir, 'A')
    dir_b = os.path.join(output_dir, 'B')
    os.makedirs(dir_a, exist_ok=True)
    os.makedirs(dir_b, exist_ok=True)

    total = len(test_loader)
    with torch.no_grad():
        for idx, (real_a, real_b) in enumerate(test_loader, start=1):
            real_a, real_b = real_a.to(net.device), real_b.to(net.device)
            fake_a, fake_b = net(real_a, real_b)

            # Save image files
            tv.utils.save_image(fake_a, os.path.join(dir_a, '%04d.png' % idx))
            tv.utils.save_image(fake_b, os.path.join(dir_b, '%04d.png' % idx))

            print('Generated images %04d of %04d' % (idx, total))

In [11]:
# Generate translated images for every test item and save them under output/A and output/B.
GetResults(net,test_set)

Generated images 0001 of 0006
Generated images 0002 of 0006
Generated images 0003 of 0006
Generated images 0004 of 0006
Generated images 0005 of 0006
Generated images 0006 of 0006
