In [1]:
import torch
from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import residual_GAN
from image_dataset import image_dataset
from torch.utils.data import RandomSampler
from torch_snippets import *
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [2]:
path_gen = r"resnet_300M_generator/"
path_disc = r"resnet_300M_discriminator/"
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine
# (without it, torch.load tries to recreate tensors on their original device).
discriminator = residual_GAN.Discriminator().to(device)
discriminator.load_state_dict(torch.load(path_disc + "discriminator.pth", map_location=device))
generator = residual_GAN.Generator(noise_dim=2048).to(device)
generator.load_state_dict(torch.load(path_gen + "generator.pth", map_location=device))

<All keys matched successfully>

# Data Pipeline and Utility Functions

In [3]:
def noise_generator(device, batch_size=64, dimension=2048):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        device: Device (e.g. ``'cuda'`` or ``'cpu'``) on which to create the tensor.
        batch_size: Number of latent vectors (rows) to draw.
        dimension: Length of each latent vector; expected to match the
            generator's ``noise_dim``.

    Returns:
        A float tensor of shape ``(batch_size, dimension)`` with N(0, 1)
        entries, located on ``device``.
    """
    # Allocate directly on the target device instead of sampling on the CPU and
    # copying, which avoids a host-to-device transfer on every training step.
    return torch.randn(batch_size, dimension, device=device)

In [4]:
# Resize every image to 128x128 and apply a random horizontal flip as light
# augmentation; ToTensor converts PIL/uint8 images to float tensors in [0, 1].
# NOTE(review): Normalize([0.5]*3, [0.5]*3) (mapping to [-1, 1]) is deliberately
# disabled - presumably the generator also emits [0, 1] images; confirm against
# residual_GAN before enabling it.
img_transforms = transforms.Compose(
    [
        transforms.Resize([128, 128]),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
    ]
)
folder = "../../Datasets/annotated_img/images/train2017"
train_dataset = image_dataset(folder, img_transforms)

batch_size = 64
num_samples = 64000
# Each pass over the loader draws a fresh random subset of num_samples indices
# without replacement (assuming the dataset holds at least that many images),
# i.e. 64000 / 64 = 1000 batches per epoch.
train_sampler = RandomSampler(
    train_dataset,
    replacement=False,
    num_samples=num_samples,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
print(len(train_dataloader))

# Training Routine

In [5]:
model_name = "resnet_300M_discriminator"
# exist_ok replaces the check-then-create `if not os.path.exists(...)` pattern:
# there is no race between the check and the mkdir, and the cell stays
# idempotent on re-run.
os.makedirs(model_name, exist_ok=True)

In [6]:
def train_discriminator(discriminator, real_data, fake_data, loss_fn, optimizer, real_factor=3):
    """Run one optimisation step of the discriminator on a real and a fake batch.

    Real samples are labelled 1 and fake samples 0. The real-sample loss is
    multiplied by ``real_factor`` before the two terms are summed and
    back-propagated, then a single optimizer step is taken.

    Args:
        discriminator: Module mapping a batch of images to logits.
        real_data: Batch of real images, on the discriminator's device.
        fake_data: Batch of generated images; should already be detached from
            the generator's graph so no gradients flow into the generator.
        loss_fn: Loss called as ``loss_fn(predictions, targets)``,
            e.g. ``torch.nn.BCEWithLogitsLoss``.
        optimizer: Optimizer over the discriminator's parameters.
        real_factor: Weight applied to the real-sample loss (default 3).

    Returns:
        ``(loss_real, loss_fake)``: the unweighted per-term losses, detached
        from the autograd graph and measured before the optimizer step
        (scalars for a mean- or sum-reduced ``loss_fn``).
    """
    optimizer.zero_grad()
    pred_real = discriminator(real_data)
    # Build targets with the prediction's own shape, dtype and device, so this
    # works for any discriminator output shape and needs no global `device`.
    loss_real = loss_fn(pred_real, torch.ones_like(pred_real))
    pred_fake = discriminator(fake_data)
    loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake))
    loss = loss_real * real_factor + loss_fake
    loss.backward()
    optimizer.step()
    # Detach so callers that hold on to the returned values don't keep the graph alive.
    return loss_real.detach(), loss_fake.detach()

In [7]:
# BCEWithLogitsLoss applies the sigmoid internally, so the discriminator's
# outputs are treated as raw logits.
# NOTE(review): confirm residual_GAN.Discriminator has no final sigmoid layer.
loss_fn = torch.nn.BCEWithLogitsLoss()
# Presumably a small learning rate because training resumes from the
# discriminator checkpoint loaded earlier - TODO confirm.
optimizer = torch.optim.AdamW(params=discriminator.parameters(), lr=1e-5)

# Training the Discriminator

In [8]:
N = len(train_dataloader)
noise_dim = 2048
num_epochs = 10
# Early stopping: stop once the epoch-mean of (loss_real + loss_fake) -- the
# unweighted sum, not the real_factor-weighted training objective -- stays
# below `threshold` for `num_counts` consecutive epochs.
threshold = 0.02
num_counts = 2
real_factor = 4  # weight on the real-image loss inside train_discriminator
checkpoint_path = model_name + "/discriminator.pth"
log = Report(num_epochs)
discriminator.to(device)
discriminator.train()
loss_total = 0
num_batchs = 0
count = 0

for epoch in range(num_epochs):
    loss_total = 0
    num_batchs = 0
    for idx, imgs in enumerate(train_dataloader):
        real_data = imgs.to(device)
        # The generator is not trained here. Sampling under no_grad avoids
        # building an autograd graph through it only to detach it afterwards.
        # Sizing the fake batch to the real one keeps the two loss terms
        # balanced even if the last real batch is short.
        with torch.no_grad():
            fake_data = generator(noise_generator(device, real_data.size(0), noise_dim))
        loss_real, loss_fake = train_discriminator(discriminator, real_data, fake_data,
                                                   loss_fn, optimizer, real_factor)
        total_loss = loss_real.item() + loss_fake.item()
        num_batchs += 1
        loss_total += total_loss
        log.record(epoch + (1 + idx) / N, total_loss=total_loss, loss_real=loss_real.item(),
                   loss_fake=loss_fake.item(), end='\r')
    log.report_avgs(epoch + 1)
    # Checkpoint after every epoch. A run that finishes all epochs without early
    # stopping, or is interrupted, then still keeps the last completed epoch's weights.
    torch.save(discriminator.state_dict(), checkpoint_path)
    mean_loss = loss_total / max(num_batchs, 1)  # guard against an empty loader
    if mean_loss < threshold:
        count += 1
        if count == num_counts:
            break
    else:
        count = 0
log.plot_epochs(['loss_real', 'loss_fake'])
    


EPOCH: 1.000  loss_real: 0.039  total_loss: 0.243  loss_fake: 0.204  (562.76s - 5064.88s remaining)
EPOCH: 2.000  loss_real: 0.005  total_loss: 0.022  loss_fake: 0.017  (1113.83s - 4455.30s remaining)
EPOCH: 3.000  loss_real: 0.003  total_loss: 0.013  loss_fake: 0.010  (1659.89s - 3873.08s remaining)
EPOCH: 3.214  total_loss: 0.014  loss_real: 0.014  loss_fake: 0.000  (1777.03s - 3752.01s remaining)

KeyboardInterrupt: 

In [9]:
# Manually persist the current discriminator weights, e.g. after the training
# cell was interrupted before reaching its own save.
checkpoint_file = model_name + "/discriminator.pth"
torch.save(discriminator.state_dict(), checkpoint_file)