In [None]:
%matplotlib inline
import os
import numpy as np
import torch
from torch import nn
import my_nntools_new as nt
from torch.nn import functional as F
import torch.utils.data as td
import torchvision as tv
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt

import itertools

import models as models
from utils import *

In [None]:
class NNClassifier(nt.NeuralNetwork):
    """Base network that stores a learning rate and provides the GAN loss terms.

    Args:
        lr: learning rate, stored on the instance as ``self.lr``.
    """

    def __init__(self, lr):
        super(NNClassifier, self).__init__()
        self.lr = lr

    # The loss helpers do not use instance state, so they are static: they can be
    # called as ``self.criterion_GAN(y, d)`` or ``NNClassifier.criterion_GAN(y, d)``.
    # (Previously they lacked ``self``, so an instance call shifted every argument.)
    @staticmethod
    def criterion_GAN(y, d):
        """Adversarial loss: mean squared error between prediction ``y`` and target ``d``."""
        return F.mse_loss(y, d)

    # NOTE(review): the reference CycleGAN setup weights the cycle loss by 10 and
    # the identity loss by 5; the weights below are the other way round — confirm
    # this is intentional.
    @staticmethod
    def criterion_cycle(y, d):
        """Cycle-consistency loss: mean absolute error between ``y`` and ``d``, times 5."""
        return F.l1_loss(y, d) * 5

    @staticmethod
    def criterion_identity(y, d):
        """Identity loss: mean absolute error between ``y`` and ``d``, times 10."""
        return F.l1_loss(y, d) * 10
    
class C_GAN(NNClassifier):
    """CycleGAN-style model: generators A->B and B->A plus one discriminator per domain.

    This class builds the four sub-networks, their Adam optimizers and two replay
    buffers, and exposes a translation-only ``forward``. It does not implement an
    optimisation step (see the commented-out draft that follows it in the notebook).

    Args:
        lr: learning rate for all three Adam optimizers. It is the first
            parameter so the call ``C_GAN(lr)`` binds the learning rate here.
            The default 2e-4 is the rate used in the CycleGAN paper.
        fine_tuning: accepted for compatibility; currently unused.
    """

    def __init__(self, lr=2e-4, fine_tuning=True):
        # The parent requires ``lr`` and stores it as ``self.lr``, which the
        # optimizers below read.
        super(C_GAN, self).__init__(lr)

        # NOTE(review): Generator, Discriminator, init_parameters and ReplayBuffer
        # are not defined in this notebook; presumably they come from `models`
        # or the `utils` star-import — confirm.
        self.G_A2B = Generator(3, 3)
        self.G_B2A = Generator(3, 3)
        self.D_A = Discriminator(3)
        self.D_B = Discriminator(3)

        # Initialise every sub-network (previously G_A2B was done twice and G_B2A never).
        for sub_net in (self.G_A2B, self.G_B2A, self.D_A, self.D_B):
            sub_net.apply(init_parameters)

        # Optimizers must reference this instance's parameters, not a global `net`
        # (which does not exist yet while the constructor runs).
        self.optimizer_D_A = torch.optim.Adam(self.D_A.parameters(), lr=self.lr)
        self.optimizer_D_B = torch.optim.Adam(self.D_B.parameters(), lr=self.lr)
        # One optimizer updates both generators jointly.
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.G_A2B.parameters(), self.G_B2A.parameters()),
            lr=self.lr)

        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

    def forward(self, real_a, real_b):
        """Translate one batch from each domain into the other.

        Args:
            real_a: batch from domain A, fed to ``G_A2B``.
            real_b: batch from domain B, fed to ``G_B2A``.

        Returns:
            ``(fake_a, fake_b)``: the B->A and A->B translations, each passed
            through ``0.5 * (x + 1)``, which maps [-1, 1] onto [0, 1]
            (presumably the generators end in tanh — confirm).
        """
        fake_b = 0.5 * (self.G_A2B(real_a) + 1.0)
        fake_a = 0.5 * (self.G_B2A(real_b) + 1.0)
        return fake_a, fake_b
        
#     def forward(self, real_a, real_b):
        
#         real_target = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
#         fake_target = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

#         ###### Generators A2B and B2A ######
#         optimizer_G.zero_grad()

#         # Identity loss
#         # G_A2B(b) should equal b if real b is fed
#         same_b = self.G_A2B(real_b)
#         loss_Idt_B = self.criterion_identity(same_b, real_b)
#         # G_B2A(a) should equal a if real a is fed
#         same_a = netG_B2A(a)
#         loss_Idt_A = self.criterion_identity(same_a, real_real_a)

#         # GAN loss
#         fake_b = self.G_A2B(real_a)
#         fake_pred = self.D_B(fake_b)
#         loss_GAN_A2B = self.criterion_GAN(fake_pred, real_target)

#         fake_a = self.G_B2A(real_B)
#         fake_pred = self.D_A(fake_a)
#         loss_GAN_B2A = self.criterion_GAN(fake_pred, real_target)

#         # Cycle loss
#         recovered_a = self.G_B2A(fake_b)
#         loss_cycle_ABA = self.criterion_cycle(recovered_a, real_a)

#         recovered_b = self.G_A2B(fake_a)
#         loss_cycle_BAB = self.criterion_cycle(recovered_b, real_b)

#         # Total loss
#         loss_G = loss_Idt_A + loss_Idt_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
#         loss_G.backward()
        
#         optimizer_G.step()
#         ###################################

#         ###### Discriminator A ######
#         optimizer_D_A.zero_grad()

#         # Real loss
#         real_pred = self.D_A(real_a)
#         loss_D_real = self.criterion_GAN(real_pred, real_target)

#         # Fake loss
#         fake_a = fake_a_buffer.push_and_pop(fake_a)
#         fake_pred = self.D_A(fake_a.detach())
#         loss_D_fake = self.criterion_GAN(fake_pred, fake_target)

#         # Total loss
#         loss_D_A = (loss_D_real + loss_D_fake)*0.5
#         loss_D_A.backward()

#         optimizer_D_A.step()
#         ###################################

#         ###### Discriminator B ######
#         optimizer_D_B.zero_grad()

#         # Real loss
#         real_pred = self.D_B(real_b)
#         loss_D_real = self.criterion_GAN(real_pred, real_target)
        
#         # Fake loss
#         fake_b = fake_b_buffer.push_and_pop(fake_b)
#         fake_pred = self.D_B(fake_b.detach())
#         loss_D_fake = self.criterion_GAN(fake_pred, fake_target)

#         # Total loss
#         loss_D_B = (loss_D_real + loss_D_fake)*0.5
#         loss_D_B.backward()

#         optimizer_D_B.step()
#         ###################################

In [None]:
class ClassificationStatsManager(nt.StatsManager):
    """Keeps running sums of the generator loss and of both discriminator losses.

    Despite its name, this manager is the one paired with the C_GAN experiment.
    ``summarize`` divides each running sum by ``self.number_update`` (presumably
    maintained by the parent ``StatsManager`` — confirm).
    """

    def __init__(self):
        super().__init__()

    def init(self):
        """Reset the parent's statistics, then zero the three loss sums."""
        super().init()
        self.running_loss_G = self.running_loss_D_A = self.running_loss_D_B = 0

    def accumulate(self, loss_G, loss_D_A, loss_D_B):
        """Pass the three losses to the parent, then add each to its running sum."""
        super().accumulate(loss_G, loss_D_A, loss_D_B)
        self.running_loss_G += loss_G
        self.running_loss_D_A += loss_D_A
        self.running_loss_D_B += loss_D_B

    def summarize(self):
        """Compute statistics based on accumulated ones"""
        updates = self.number_update
        return {
            'G loss': self.running_loss_G / updates,
            'D_A loss': self.running_loss_D_A / updates,
            'D_B loss': self.running_loss_D_B / updates,
        }

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Location of the Caltech-UCSD Birds data. The hard-coded default is kept, but it
# can be overridden with the CUB_ROOT_DIR environment variable so the notebook
# also runs on machines where the dataset lives elsewhere.
dataset_root_dir = os.environ.get("CUB_ROOT_DIR", "/datasets/ee285f-public/caltech_ucsd_birds/")

In [None]:
# Dataset loader: the train and val splits share every option except `mode`.
# NOTE(review): ImageDataset is not defined in this notebook; presumably it comes
# from the `utils` star-import — confirm.
shared_dataset_options = dict(image_size=(512, 512), unaligned=True)
train_set = ImageDataset(dataset_root_dir, mode='train', **shared_dataset_options)
val_set = ImageDataset(dataset_root_dir, mode='val', **shared_dataset_options)

In [None]:
# Build the model, move it to the selected device, and wrap it in an experiment
# with output_dir="CGAN".
# NOTE(review): `lr` is passed positionally; make sure C_GAN's first constructor
# parameter is the learning rate.
lr = 1e-3
net = C_GAN(lr).to(device)

stats_manager = ClassificationStatsManager()
exp1 = nt.Experiment(net, train_set, val_set, stats_manager, output_dir="CGAN")

In [None]:
# Run the experiment for a fixed number of epochs.
num_epochs = 20
exp1.run(num_epochs=num_epochs)