In [2]:
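# Colab training setup (uncomment when running on Google Colab):
# mounts Google Drive and unpacks the cats dataset archive into ./data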

# from google.colab import drive
# drive.mount("/content/drive")

# import zipfile
# with zipfile.ZipFile("./drive/MyDrive/cats.zip") as f:
#   f.extractall("./data")

In [3]:
# torchvision imports
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
from torchvision.utils import save_image
# torch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [4]:
# CONSTANTS

# folder handed to ImageFolder, and the sizes shared by the models and the training code
ROOT = "./data/cats/train"
BATCH_SIZE = 100
LATENT_SIZE = 100

# per-channel mean and standard deviation used by tt.Normalize;
# calculated using ./normalizer.py over the whole cats dataset
MEAN = (
    0.4819,
    0.4325,
    0.3845,
)
DEVIATION = (
    0.2602,
    0.2519,
    0.2537,
)

In [5]:
# simple normalized transforms: convert each image to a tensor, then
# standardize every channel with the dataset-wide MEAN / DEVIATION
tfms = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(MEAN, DEVIATION),
])

# creating dataloader
train_ds = ImageFolder(ROOT, transform=tfms)
# shuffle=True   -> batches are re-drawn in a new random order every epoch
#                   instead of always following the folder order.
# drop_last=True -> the ragged final batch is dropped, so every batch holds
#                   exactly BATCH_SIZE images (the size the fixed label
#                   tensors used during training are built for).
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)

In [6]:
# pick the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [7]:
#Discriminator model: four conv blocks (Conv2d -> BatchNorm2d -> LeakyReLU) that each
#halve the spatial size (kernel 4, stride 2, padding 1) while roughly doubling the
#channel count, followed by global average pooling and a small two-layer classifier.
#The "channels height width" shape comments assume 3x64x64 inputs -- TODO confirm
#against the dataset (no resize is applied in the transforms).
D = nn.Sequential(
    nn.Conv2d(3,60,4,2,1),#60 32 32
    nn.BatchNorm2d(60),
    nn.LeakyReLU(0.2),

    nn.Conv2d(60,120,4,2,1),#120 16 16
    nn.BatchNorm2d(120),
    nn.LeakyReLU(0.2),

    nn.Conv2d(120,250,4,2,1),#250 8 8
    nn.BatchNorm2d(250),
    nn.LeakyReLU(0.2),
    
    nn.Conv2d(250,500,4,2,1),#500 4 4
    nn.BatchNorm2d(500),
    nn.LeakyReLU(0.2),

    nn.AdaptiveAvgPool2d(1),#500 1 1 (pools whatever spatial size remains down to 1x1)
    nn.Flatten(),#500*1*1 -> one 500-value vector per image

    nn.Linear(500,25),#25
    nn.LeakyReLU(0.2),
    nn.Linear(25,1),#1 score per image

    nn.Sigmoid()#squashes the score into a probability in (0, 1), as nn.BCELoss expects
).to(device)

#Generator model - mirrors the Discriminator with transposed convolutions.
#Takes latent input of shape (N, LATENT_SIZE, 1, 1) and upsamples it to (N, 3, 64, 64);
#the trailing comments give "channels, height/width" after each transposed conv.
#The layer order must stay as-is: the checkpoint loaded later in this notebook is
#keyed by each layer's position inside this Sequential.
#NOTE(review): Tanh bounds outputs to [-1, 1], while images normalized with
#MEAN/DEVIATION can fall outside that range -- confirm this is intended.
G = nn.Sequential(
    nn.ConvTranspose2d(LATENT_SIZE,600,4,1,0),#600 4 (kernel 4, stride 1, no padding: 1x1 -> 4x4)
    nn.BatchNorm2d(600),
    nn.ReLU(),
    nn.ConvTranspose2d(600,300,4,2,1),#300 8
    nn.BatchNorm2d(300),
    nn.ReLU(),
    nn.ConvTranspose2d(300,150,4,2,1),#150 16
    nn.BatchNorm2d(150),
    nn.ReLU(),
    nn.ConvTranspose2d(150,50,4,2,1),#50 32
    nn.BatchNorm2d(50),
    nn.ReLU(),
    nn.ConvTranspose2d(50,3,4,2,1),#3 64
    nn.Tanh()
).to(device)

In [8]:
# binary cross-entropy between the discriminator's probabilities and the 0/1 targets
loss_fn = nn.BCELoss()

# the discriminator learns 10x slower than the generator (1e-5 vs 1e-4)
# to give the generator a head start
d_opt = torch.optim.Adam(params=D.parameters(), lr=1e-5)
g_opt = torch.optim.Adam(params=G.parameters(), lr=1e-4)

In [9]:
#declaring labels used to calculate loss by comparing against preds
#(d_fit builds its own targets from the prediction shape; these fixed-size
# tensors are kept for the code that still expects BATCH_SIZE-sized labels)
real_labels = torch.ones(BATCH_SIZE,1).to(device)
fake_labels = torch.zeros(BATCH_SIZE,1).to(device)

#Discriminator training function
def d_fit(real_images):
    """Run one optimisation step of the discriminator.

    The discriminator is scored on a batch of real images (target 1) and on
    an equally sized batch of freshly generated images (target 0); the two
    BCE losses are summed and a single Adam step is taken on D only.

    Args:
        real_images: batch of real images as produced by the data loader.
            Any batch size is accepted; targets and the fake batch are
            sized to match it.

    Returns:
        tuple: ``(d_loss, real_prob, fake_prob)`` where ``d_loss`` is the
        summed loss as a detached 0-dim tensor (``.item()`` still works),
        and ``real_prob`` / ``fake_prob`` are Python floats holding D's mean
        output on the real and the fake batch respectively.
    """
    d_opt.zero_grad()

    real_images = real_images.to(device)
    batch_size = real_images.shape[0]

    #calculating loss on real images (targets are all ones, shaped like the preds)
    real_preds = D(real_images)
    real_loss = loss_fn(real_preds, torch.ones_like(real_preds))

    #calculating loss on fake images
    #The generator is only a data source here: running it under no_grad keeps
    #d_loss.backward() from back-propagating through G, whose gradients are
    #never used by this step.
    with torch.no_grad():
        noise = torch.randn(batch_size, LATENT_SIZE, 1, 1, device=device)
        fake_images = G(noise)
    fake_preds = D(fake_images)
    fake_loss = loss_fn(fake_preds, torch.zeros_like(fake_preds))

    #doing gradient descent on both losses
    d_loss = real_loss + fake_loss
    d_loss.backward()

    d_opt.step()

    #returning the loss (detached so the caller does not keep the autograd graph
    #alive) and the average probabilities of the batch on both real and fake images
    return d_loss.detach(), real_preds.mean().item(), fake_preds.mean().item()

In [10]:
#Generator Model training function
def g_fit():
    """Run one optimisation step of the generator.

    A batch of BATCH_SIZE images is generated from random latent input and
    scored by the discriminator; the generator is rewarded when those scores
    match the "real" labels.

    Returns:
        The generator's BCE loss for this step, as a 0-dim tensor.
    """
    g_opt.zero_grad()

    #sample latent input and let the discriminator judge the generated images
    noise = torch.randn(BATCH_SIZE,LATENT_SIZE,1,1).to(device)
    verdicts = D(G(noise))

    #the loss is low when the discriminator takes the fakes for real images
    g_loss = loss_fn(verdicts, real_labels)

    #back-propagate and update the generator only (g_opt holds G's parameters)
    g_loss.backward()
    g_opt.step()

    return g_loss

In [11]:
# fixed random latent batch, reused every epoch so the saved grids are comparable
gen = torch.randn(36,LATENT_SIZE,1,1).to(device)

def save_fake_image(num):
    """Generate images from the fixed latent batch ``gen`` and save them as a grid.

    The generator output is mapped back from normalized space using the
    MEAN / DEVIATION constants and written as a PNG with 6 images per row.

    Args:
        num: value used in the file name (``epoch_<num>.png``).
    """
    import os  # local import: only needed to create the output folder

    # inference only: no autograd graph is needed for saving pictures
    with torch.no_grad():
        images = G(gen)

        #denormalizing generated images: x * std + mean, broadcast over the
        #channel dimension instead of looping over images and channels
        mean = torch.tensor(MEAN, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
        std = torch.tensor(DEVIATION, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
        images = images * std + mean

    #saving images to make progress VIDEO
    # NOTE(review): other Drive paths in this notebook start with "./drive/MyDrive";
    # confirm that "./data/MyDrive" is the intended location.
    out_path = "./data/MyDrive/cat_gan/epoch_"+str(num)+".png"
    # save_image does not create missing folders, so make sure the target exists
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    save_image(images,out_path,nrow=6)

# save_fake_image(0)

In [12]:
# main training fit function to train both D,G alternately
def fit(epochs):
    """Train the discriminator and the generator alternately.

    For every full batch one discriminator step is followed by one generator
    step. When the discriminator's mean prediction on the real batch exceeds
    0.75 the generator gets one extra step. After each epoch the statistics
    of the last trained batch are printed and a grid of sample images is saved.

    Args:
        epochs: number of passes over ``train_dl``.

    Raises:
        RuntimeError: if an epoch contains no batch of exactly BATCH_SIZE images.
    """
    for epoch in range(epochs):
        print("Epoch: ", epoch+1)

        trained = False
        for images,_ in train_dl:
            #skip batches that are not exactly BATCH_SIZE (e.g. the last one):
            #the fixed-size label tensors only fit full batches
            if images.shape[0] != BATCH_SIZE:
                continue

            d_loss,real_prob,fake_prob = d_fit(images)
            g_loss = g_fit()

            #condition to give Generator a boost if its lagging behind;
            #evaluated only on the real_prob of the batch that was just trained
            if real_prob >0.75:
                g_loss = g_fit()

            trained = True

        if not trained:
            #without this guard the summary below would read undefined names
            raise RuntimeError(
                "train_dl produced no batch of exactly BATCH_SIZE images; nothing was trained"
            )

        print("Real Pred: ",round(real_prob,4),"Fake Pred: ",round(fake_prob,4))
        print("D Loss: ",round(d_loss.item(),4),"G Loss: ",round(g_loss.item(),4))

        #saving images per epoch
        with torch.no_grad():
            save_fake_image(epoch+1)

# trained 300 epoch on Colab
# fit(300)

In [13]:
# torch.save(G.state_dict(),"./drive/MyDrive/cat_model/CatGan.pth")

# Restore the previously trained generator weights; map_location places the
# tensors on the device selected earlier in the notebook.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
G.load_state_dict(
    torch.load(
        "./saved_models/CatGan.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [20]:
#final random cat images generating function
def generate(count=36, path="./data/cats/sample_gen.png"):
    """Generate a grid of random cat images and save it to disk.

    Args:
        count: number of images to generate (default 36).
        path: output image file (default "./data/cats/sample_gen.png").

    Fresh latent input is sampled on every call, so each call produces
    different images. The grid is saved with 6 images per row.
    """
    # inference only: no autograd graph is needed for sampling pictures
    # NOTE(review): G is not switched to eval() here, so its BatchNorm layers
    # use statistics of the sampled batch - confirm this is intended.
    with torch.no_grad():
        #random images every call
        images = G(torch.randn(count,LATENT_SIZE,1,1).to(device))

        #denorming: x * std + mean, broadcast over the channel dimension
        #using the same MEAN / DEVIATION the data was normalized with
        mean = torch.tensor(MEAN, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
        std = torch.tensor(DEVIATION, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
        images = images * std + mean

    #saving image to view
    save_image(images,path,nrow=6)

generate()