In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm.notebook import tqdm
from CycleGAN_utils import *
from CycleGAN_models import *
import os
import numpy as np
import math
import itertools
import datetime
import time
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn.functional as F


# ---- Schedule -------------------------------------------------------------
epoch = 0                      # first epoch to run; non-zero selects the checkpoint-loading branch
n_epochs = 101                 # the epoch loop stops before this index
decay_epoch = 100              # handed to the LambdaLR helper together with n_epochs and epoch
batch_size = 1

# ---- Adam optimizer -------------------------------------------------------
lr = 2e-4
b1, b2 = 0.5, 0.999            # used as the Adam `betas` pair

# ---- Data / model geometry ------------------------------------------------
n_cpu = 8                      # NOTE(review): not referenced by the visible cells
img_height = img_width = 224
channels = 1
n_residual_blocks = 9          # second argument of GeneratorResNet

# ---- Logging / checkpointing ----------------------------------------------
sample_interval = 100          # a sample grid is saved when batches_done is a multiple of this
checkpoint_interval = 25       # checkpoints are saved on epochs divisible by this; -1 disables saving

# ---- Generator loss weights -----------------------------------------------
lambda_cyc = 10                # multiplies the cycle loss
lambda_id = 5                  # multiplies the identity loss

# First CUDA device when one is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# SECURITY NOTE(review): torch.load unpickles the file, and unpickling can run
# arbitrary code. Only point this at a dataset file from a trusted source.
dataset = torch.load("../datasets/224skulldataset.pt")

# Drop missing entries. Identity comparison is used so that an element type
# which overloads `!=` cannot change the outcome of the filter.
dataset = [sample for sample in dataset if sample is not None]

# Split the samples by their label (element 1): label 0 goes to NC, every
# other label goes to AD.
NC = []
AD = []
for data in dataset:
    if data[1] == 0:
        NC.append(data)
    else:
        AD.append(data)
        
        
def process_gan(dataset, s):
    """Extract one plane per sample and scale it by its own maximum.

    For every record in ``dataset`` the plane ``record[0][s][0]`` is taken,
    divided by its maximum value, and given a new leading axis of size 1.

    The division is done out of place, so the tensors stored in ``dataset``
    are left untouched (an in-place ``/=`` would write through the view into
    the caller's data and fails outright for integer tensors). A plane whose
    maximum is exactly 0 is divided by 1 instead, avoiding a 0/0 -> NaN result.

    Args:
        dataset: iterable of records whose element 0 is a tensor that can be
            indexed as ``[s][0]``.
            # assumes element 0 is laid out as (slice, channel, H, W) — TODO confirm
        s: index selecting which entry along the tensor's first axis to use.

    Returns:
        list of tensors, one per record, each the scaled plane with an extra
        leading dimension.
    """
    output = []
    for record in dataset:
        plane = record[0][s][0]
        peak = torch.max(plane)
        # Guard the degenerate case so an all-zero plane stays finite.
        divisor = peak if peak != 0 else torch.ones_like(peak)
        output.append(torch.unsqueeze(plane / divisor, 0))
    return output

        
# One list of scaled planes per index (0, 1, 2) along the first tensor axis,
# for each of the two label groups.
NCgan1 = process_gan(NC, 0)
NCgan2 = process_gan(NC, 1)
NCgan3 = process_gan(NC, 2)

ADgan1 = process_gan(AD, 0)
ADgan2 = process_gan(AD, 1)
ADgan3 = process_gan(AD, 2)

# Pair the i-th NC plane (domain "A") with the i-th AD plane (domain "B").
# zip stops at the shorter list, so an AD group larger than the NC group no
# longer raises IndexError; surplus samples in either group are left unpaired.
gan1 = [{"A": nc, "B": ad} for nc, ad in zip(NCgan1, ADgan1)]
gan2 = [{"A": nc, "B": ad} for nc, ad in zip(NCgan2, ADgan2)]
gan3 = [{"A": nc, "B": ad} for nc, ad in zip(NCgan3, ADgan3)]

# batch_size comes from the configuration cell; it is deliberately not
# re-assigned here so that a change made there actually takes effect.
# NOTE(review): num_workers is fixed at 4; the n_cpu setting is not used here.
dataloader1 = DataLoader(gan1, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader2 = DataLoader(gan2, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader3 = DataLoader(gan3, batch_size=batch_size, shuffle=True, num_workers=4)

In [3]:
# Loss histories: train_gan appends one generator value and one discriminator
# value per batch.
G_losses, D_losses = [], []
        
def sample_images(batches_done, dataloader):
    """Save a grid of real and translated images from one batch of ``dataloader``.

    Takes the first batch of a fresh iterator over ``dataloader``, translates
    A -> B with ``G_AB`` and B -> A with ``G_BA``, and writes real_A, fake_B,
    real_B and fake_A as a single image to
    ``ganimages/<dataset_name>/<batches_done>.png``.

    Uses the notebook globals ``G_AB``, ``G_BA``, ``Tensor`` and
    ``dataset_name``. Both generators are switched to eval mode and are left
    in that mode when the function returns.

    Args:
        dataloader: iterable yielding dict batches with keys "A" and "B".
        batches_done: value used as the output file name.
    """
    imgs = next(iter(dataloader))
    G_AB.eval()
    G_BA.eval()
    # Pure inference: no_grad avoids building autograd graphs for these
    # forward passes, which would otherwise hold activations for no reason.
    with torch.no_grad():
        real_A = imgs["A"].type(Tensor)
        fake_B = G_AB(real_A)
        real_B = imgs["B"].type(Tensor)
        fake_A = G_BA(real_B)

    # Each group is normalised and tiled on its own before being combined.
    real_A = make_grid(real_A, nrow=4, normalize=True)
    real_B = make_grid(real_B, nrow=4, normalize=True)
    fake_A = make_grid(fake_A, nrow=4, normalize=True)
    fake_B = make_grid(fake_B, nrow=4, normalize=True)

    image_grid = torch.stack((real_A, fake_B, real_B, fake_A), 0)
    save_image(image_grid, "ganimages/%s/%s.png" % (dataset_name, batches_done), normalize=False)
    
def train_gan(dataloader, epoch):
    """Train the notebook-level CycleGAN models on ``dataloader``.

    Runs epochs ``epoch`` through ``n_epochs - 1``. Every batch performs one
    generator update (adversarial + weighted cycle + weighted identity loss)
    followed by one update of each discriminator; the discriminators see
    fakes returned by the replay buffers rather than the raw generator output.
    The three learning-rate schedulers are stepped once per epoch.

    Relies on notebook globals: the models (``G_AB``, ``G_BA``, ``D_A``,
    ``D_B``), their optimizers and schedulers, the loss criteria, the replay
    buffers, ``Tensor``, ``dataset_name`` and the configuration constants.

    Side effects:
        * appends one value per batch to ``G_losses`` and ``D_losses``;
        * calls ``sample_images`` whenever ``batches_done`` is a multiple of
          ``sample_interval``;
        * writes model checkpoints under ``saved_models/<dataset_name>/`` on
          epochs divisible by ``checkpoint_interval`` (unless it is -1);
        * shows a loss plot covering the batches processed by this call.

    Args:
        dataloader: sized iterable yielding dict batches with keys "A" and "B".
        epoch: index of the first epoch to run.
    """
    # The loss lists are shared across calls; remember where this call starts
    # so the final plot is not mixed with curves from earlier calls.
    run_start = len(G_losses)
    prev_time = time.time()
    for epoch in range(epoch, n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = batch["A"].type(Tensor)
            real_B = batch["B"].type(Tensor)

            # Adversarial ground truths, created directly with the inputs'
            # dtype and device (no numpy array / host-to-device copy per batch).
            target_shape = (real_A.size(0), *D_A.output_shape)
            valid = torch.ones(target_shape, dtype=real_A.dtype, device=real_A.device)
            fake = torch.zeros(target_shape, dtype=real_A.dtype, device=real_A.device)

            # ------------------
            #  Train Generators
            # ------------------

            # sample_images leaves the generators in eval mode, so training
            # mode is re-asserted on every batch.
            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image already in its target
            # domain is compared against that same image.
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss: generators are scored against the "valid" target.
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss: A -> B -> A and B -> A -> B reconstructions.
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples); detached
            # so this backward pass does not reach the generators.
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Approximate time left: remaining batch count multiplied by the
            # wall-clock duration of the batch that just finished.
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            if i % 100 == 0:

                print(
                    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                    % (
                        epoch,
                        n_epochs,
                        i,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_GAN.item(),
                        loss_cycle.item(),
                        loss_identity.item(),
                        time_left,
                    )
                )

            # If at sample interval save image
            if batches_done % sample_interval == 0:
                sample_images(batches_done, dataloader)

            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())

        # Update learning rates (once per epoch)
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (dataset_name, epoch))

    # Plot only the losses recorded during this call.
    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses[run_start:],label="G")
    plt.plot(D_losses[run_start:],label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [4]:
input_shape = (channels, img_height, img_width)

# Initialize generators and discriminators on the device chosen in the
# configuration cell, so the notebook also runs where CUDA is unavailable.
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

# Optimizers: a single Adam instance drives both generators, while each
# discriminator has its own.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers; the multiplier comes from the project's
# LambdaLR helper, built from n_epochs, the starting epoch and decay_epoch.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

# Tensor type used to cast batches; follows `device` instead of assuming CUDA.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Losses
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_cycle = torch.nn.L1Loss().to(device)
criterion_identity = torch.nn.L1Loss().to(device)


if epoch != 0:
    # Resume: load the checkpoints saved for `epoch`. map_location places the
    # weights on `device` regardless of where they were saved from.
    # NOTE(review): `dataset_name` is only assigned in a later cell, so this
    # branch depends on that cell having been run first — confirm/fix ordering.
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch), map_location=device))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch), map_location=device))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (dataset_name, epoch), map_location=device))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch), map_location=device))
else:
    # Fresh start: apply the project's weight initialiser to every model.
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

In [6]:
# Run 1: train on dataloader1, starting from epoch 0.
epoch = 0
# train_gan and sample_images read `dataset_name` as a global to build their
# output paths, so it has to be set before training is started.
dataset_name = '224CycleGAN1'
# Create the output directories for sample grids and model checkpoints.
os.makedirs("ganimages/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

train_gan(dataloader1, epoch)

[Epoch 0/101] [Batch 0/476] [D loss: 0.272010] [G loss: 1.141037, adv: 0.247313, cycle: 0.061556, identity: 0.055633] ETA: 6:25:27.476296
[Epoch 0/101] [Batch 100/476] [D loss: 0.345722] [G loss: 1.333307, adv: 0.357637, cycle: 0.067353, identity: 0.060429] ETA: 4:20:58.052982
[Epoch 0/101] [Batch 200/476] [D loss: 0.283500] [G loss: 1.282477, adv: 0.298982, cycle: 0.065568, identity: 0.065562] ETA: 4:21:49.529376


KeyboardInterrupt: 

In [None]:
# Run 2: train on dataloader2, starting from epoch 0.
# NOTE(review): nothing in this cell re-creates G_AB/G_BA/D_A/D_B, their
# optimizers, LR schedulers or replay buffers, and G_losses/D_losses are not
# reset. Unless the model-initialisation cell is re-run first, this run
# continues from whatever state the previous train_gan call left behind —
# confirm that is intended.
epoch = 0
# Global read by train_gan / sample_images when building output paths.
dataset_name = '224CycleGAN2'
# Create the output directories for sample grids and model checkpoints.
os.makedirs("ganimages/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

train_gan(dataloader2, epoch)

In [None]:
# Run 3: train on dataloader3, starting from epoch 0.
# NOTE(review): nothing in this cell re-creates G_AB/G_BA/D_A/D_B, their
# optimizers, LR schedulers or replay buffers, and G_losses/D_losses are not
# reset. Unless the model-initialisation cell is re-run first, this run
# continues from whatever state the previous train_gan call left behind —
# confirm that is intended.
epoch = 0
# Global read by train_gan / sample_images when building output paths.
dataset_name = '224CycleGAN3'
# Create the output directories for sample grids and model checkpoints.
os.makedirs("ganimages/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

train_gan(dataloader3, epoch)