In [56]:
import os

import torch
import torch.nn as nn
import torchvision.transforms as tt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image

In [57]:
ROOT = "./data/cats/train"
BATCH_SIZE = 200
LATENT_SIZE = 100
IMAGE_SIZE = 64  # G emits 3 x 64 x 64 images; real images are brought to the same size
SEED = 42

# Seed torch's global RNG: weight init, latent noise and any DataLoader shuffling.
torch.manual_seed(SEED)

# Per-channel mean/std normalization of the real images.
# NOTE(review): G ends in Tanh (outputs in [-1, 1]), but these stats map [0, 1] pixels
# to roughly [-1.9, 2.4] depending on the channel, so G cannot cover the full range of
# the real data. Normalize((0.5,) * 3, (0.5,) * 3) would match Tanh; if changed, the
# denormalization used when saving samples must be updated to the same values.
TFMS = tt.Compose([
    tt.Resize(IMAGE_SIZE),      # shorter side -> IMAGE_SIZE
    tt.CenterCrop(IMAGE_SIZE),  # square IMAGE_SIZE x IMAGE_SIZE so batches collate
    tt.ToTensor(),
    tt.Normalize((0.4819, 0.4325, 0.3845), (0.2602, 0.2519, 0.2537)),
])

In [58]:
train_ds = ImageFolder(ROOT, transform=TFMS)

# shuffle=True: ImageFolder lists files in sorted order, so without shuffling every
# epoch would feed D exactly the same batches in the same order.
# Pinned host memory only helps host->GPU copies, so enable it only when CUDA exists.
train_dl = DataLoader(
    train_ds,
    BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [59]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [60]:
def _disc_stage(in_ch, out_ch, stride):
    """Discriminator stage: 3x3 conv (padding 1) -> BatchNorm -> LeakyReLU(0.2).

    With stride 2 the stage halves height and width; with stride 1 it keeps them.
    Returned as a flat list so it can be splatted into an nn.Sequential.
    """
    return [
        nn.Conv2d(in_ch, out_ch, 3, stride, 1),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.2),
    ]


def _gen_stage(in_ch, out_ch):
    """Generator stage: 4x4 transposed conv (stride 2, padding 1) -> BatchNorm -> ReLU.

    Doubles height and width. Returned as a flat list for nn.Sequential.
    """
    return [
        nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    ]


# Discriminator: image -> sigmoid score in (0, 1) (trained with label 1 = real).
# Trailing comments give C x H x W for a 3 x 64 x 64 input.
D = nn.Sequential(
    *_disc_stage(3, 15, 1),     # 15 x 64 x 64
    *_disc_stage(15, 45, 2),    # 45 x 32 x 32
    *_disc_stage(45, 135, 2),   # 135 x 16 x 16
    *_disc_stage(135, 270, 2),  # 270 x 8 x 8
    *_disc_stage(270, 450, 2),  # 450 x 4 x 4
    nn.AdaptiveAvgPool2d(1),    # 450 x 1 x 1
    nn.Flatten(),
    nn.Linear(450, 25),
    nn.LeakyReLU(0.2),
    nn.Linear(25, 1),
    nn.Sigmoid(),
).to(device)

# Generator: LATENT_SIZE x 1 x 1 noise -> 3 x 64 x 64 image with values in [-1, 1] (Tanh).
G = nn.Sequential(
    nn.BatchNorm2d(LATENT_SIZE),
    *_gen_stage(LATENT_SIZE, 400),       # 400 x 2 x 2
    *_gen_stage(400, 200),               # 200 x 4 x 4
    *_gen_stage(200, 100),               # 100 x 8 x 8
    *_gen_stage(100, 50),                # 50 x 16 x 16
    *_gen_stage(50, 15),                 # 15 x 32 x 32
    nn.ConvTranspose2d(15, 3, 4, 2, 1),  # 3 x 64 x 64
    nn.Tanh(),
).to(device)

In [61]:
# D already ends in a Sigmoid, so plain binary cross-entropy on its probabilities is used.
loss_fn = nn.BCELoss()

# One Adam optimizer per network, same learning rate; D's optimizer is created first.
d_opt, g_opt = (torch.optim.Adam(net.parameters(), lr=1e-4) for net in (D, G))

In [62]:
# Full-size target tensors for BATCH_SIZE batches (g_fit uses real_labels as its target).
real_labels = torch.ones(BATCH_SIZE, 1).to(device)
fake_labels = torch.zeros(BATCH_SIZE, 1).to(device)


def d_fit(real_images):
    """Run one discriminator update on a batch of real images.

    D is pushed towards 1 on ``real_images`` and towards 0 on an equally sized batch
    sampled from G. Only D's optimizer is stepped.

    Args:
        real_images: a batch of real images as yielded by ``train_dl``.

    Returns:
        ``(d_loss, real_preds, fake_preds)``, all detached from the autograd graph:
        the summed BCE loss and D's outputs on the real and the fake batch.
    """
    real_images = real_images.to(device)
    # The last loader batch can be smaller than BATCH_SIZE, so size everything from it.
    n = real_images.size(0)

    # g_fit backpropagates through D; discard those gradients before D's own update.
    d_opt.zero_grad()

    real_preds = D(real_images)
    real_loss = loss_fn(real_preds, torch.ones_like(real_preds))

    # detach(): D's loss must not backpropagate into G. That would be wasted work, and
    # the gradients it leaves on G's parameters would leak into the next G update
    # unless something clears them first.
    noise = torch.randn(n, LATENT_SIZE, 1, 1, device=device)
    fake_images = G(noise).detach()
    fake_preds = D(fake_images)
    fake_loss = loss_fn(fake_preds, torch.zeros_like(fake_preds))

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_opt.step()

    return d_loss.detach(), real_preds.detach(), fake_preds.detach()

In [63]:
def g_fit():
    """Run one generator update.

    Samples BATCH_SIZE images from G and steps G so that D scores them as real
    (target ``real_labels``, i.e. 1). Only G's optimizer is stepped.

    Returns:
        The generator's BCE loss, detached from the autograd graph.
    """
    # d_fit may have backpropagated into G; start from clean generator gradients.
    g_opt.zero_grad()

    noise = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1, device=device)
    g_loss = loss_fn(D(G(noise)), real_labels)
    g_loss.backward()
    g_opt.step()

    # The backward pass above also filled D's .grad with the generator's objective;
    # clear it so it cannot leak into the next discriminator update.
    d_opt.zero_grad()

    return g_loss.detach()

In [64]:
# Fixed latent batch, sampled once, so saved grids are comparable across epochs.
gen = torch.randn(36, LATENT_SIZE, 1, 1).to(device)


def save_fake_image(num, out_dir="./data/cats/gans_data"):
    """Render G on the fixed noise ``gen`` and save the samples as a 6 x 6 grid.

    Before saving, G's output is mapped through the inverse of the per-channel
    normalization (x * std + mean).

    Args:
        num: number used in the file name ``epoch_<num>.png``.
        out_dir: destination directory; created if it does not exist.
    """
    # Sampling only: no autograd graph is needed (callers may not wrap this in no_grad).
    with torch.no_grad():
        images = G(gen)
        # NOTE(review): these must stay equal to the mean/std in TFMS's Normalize.
        mean = torch.tensor([0.4819, 0.4325, 0.3845], device=images.device).view(1, 3, 1, 1)
        std = torch.tensor([0.2602, 0.2519, 0.2537], device=images.device).view(1, 3, 1, 1)
        images = images * std + mean  # broadcast over (N, C, H, W); no in-place writes

    os.makedirs(out_dir, exist_ok=True)
    save_image(images, os.path.join(out_dir, f"epoch_{num}.png"), nrow=6)


save_fake_image(0)

In [65]:
epochs = 5

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")

    # Running sums of per-batch metrics; averaged at the end of the epoch because the
    # last batch alone is a very noisy signal for GAN training.
    d_loss_sum = g_loss_sum = real_score_sum = fake_score_sum = 0.0
    n_batches = 0

    for images, _ in train_dl:
        d_loss, real_preds, fake_preds = d_fit(images)
        g_loss = g_fit()

        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()
        real_score_sum += real_preds.mean().item()
        fake_score_sum += fake_preds.mean().item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("train_dl produced no batches; nothing to train on")

    print(
        f"  d_loss={d_loss_sum / n_batches:.4f}  g_loss={g_loss_sum / n_batches:.4f}  "
        f"D(real)={real_score_sum / n_batches:.3f}  D(fake)={fake_score_sum / n_batches:.3f}"
    )

    with torch.no_grad():
        save_fake_image(epoch + 1)

Epoch:  1
