In [2]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

## Dataset and data loader

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
bs = 64

train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize([.5], [.5])
])

train_set = datasets.ImageFolder("data/train/", transform=train_transform)
trainloader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=4)

## Arsitektur and configuration

In [5]:
import torch
from torch import nn
from jcopdl.layers import linear_block

In [6]:
%%writefile model_gan.py

import torch
from torch import nn
from jcopdl.layers import linear_block

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            linear_block(784, 512, activation='lrelu'),
            linear_block(512, 256, activation='lrelu'),
            linear_block(256, 128, activation='lrelu'),
            linear_block(128, 1, activation='sigmoid')
        )
    
    def forward(self, x):
        return self.fc(x)
    
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            linear_block(100, 128, activation='lrelu'),
            linear_block(128, 256, activation='lrelu', batch_norm=True),
            linear_block(256, 512, activation='lrelu', batch_norm=True),
            linear_block(512, 1024, activation='lrelu', batch_norm=True),
            linear_block(1024, 784, activation='tanh')
        )
    
    def forward(self, x):
        return self.fc(x)
    
    def generate(self, n, device):
        z = torch.randn((n, 100), device=device)
        return self.fc(z)

Overwriting model_gan.py


## Training preparation

In [7]:
from model_gan import Discriminator, Generator

In [8]:
D = Discriminator().to(device)
G = Generator().to(device)

criterion = nn.BCELoss()

d_optimizer = optim.AdamW(D.parameters(), lr=0.0002)
g_optimizer = optim.AdamW(G.parameters(), lr=0.0002)

## Training

In [9]:
import os
from torchvision.utils import save_image

In [10]:
os.makedirs('output/GAN/', exist_ok=True)
os.makedirs('model/GAN/', exist_ok=True)

In [11]:
max_epochs = 500

for epoch in range(max_epochs):
    D.train()
    G.train()
    for real_img, _ in trainloader:
        n_data = real_img.shape[0]
        
        real_img = real_img.to(device)
        fake_img = G.generate(n_data, device)
        
        real = torch.ones((n_data,1), device=device)
        fake = torch.zeros((n_data,1), device=device)
        
        d_optimizer.zero_grad()
        
        output = D(real_img)
        d_real_loss = criterion(output, real)
        
        output = D(fake_img.detach())
        d_fake_loss = criterion(output, fake)
        
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()
        
        g_optimizer.zero_grad()
        
        output = D(fake_img)
        g_loss = criterion(output, real)
        g_loss.backward()
        g_optimizer.step()
        
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:5} | D_loss: {d_loss/2:.5f} | G_loss: {g_loss:.5f}')
        
    if epoch % 15 == 0:
        G.eval()
        epoch = str(epoch).zfill(3)
        fake_img = G.generate(64, device)
        save_image(fake_img.view(-1, 1, 28, 28), f'output/GAN/ {epoch}.jpg', nrow=8, normalize=True)
        
        torch.save(G, 'model/GAN/generator.pth')
        torch.save(D, 'model/GAN/discriminator.pth')

Epoch:     0 | D_loss: 0.03149 | G_loss: 9.79497
Epoch:     5 | D_loss: 0.21948 | G_loss: 3.99216
Epoch:    10 | D_loss: 0.23204 | G_loss: 5.14459
Epoch:    15 | D_loss: 0.29811 | G_loss: 3.26552
Epoch:    20 | D_loss: 0.44037 | G_loss: 1.65701
Epoch:    25 | D_loss: 0.39387 | G_loss: 1.53519
Epoch:    30 | D_loss: 0.44636 | G_loss: 1.36021
Epoch:    35 | D_loss: 0.56285 | G_loss: 1.41897
Epoch:    40 | D_loss: 0.59337 | G_loss: 1.53987
Epoch:    45 | D_loss: 0.64477 | G_loss: 1.14708
Epoch:    50 | D_loss: 0.66756 | G_loss: 1.25807
Epoch:    55 | D_loss: 0.58304 | G_loss: 0.98072
Epoch:    60 | D_loss: 0.59256 | G_loss: 1.11923
Epoch:    65 | D_loss: 0.49924 | G_loss: 1.24293
Epoch:    70 | D_loss: 0.66185 | G_loss: 1.01618
Epoch:    75 | D_loss: 0.49263 | G_loss: 1.32328
Epoch:    80 | D_loss: 0.58178 | G_loss: 0.94945
Epoch:    85 | D_loss: 0.55581 | G_loss: 1.24072
Epoch:    90 | D_loss: 0.59582 | G_loss: 0.93969
Epoch:    95 | D_loss: 0.76758 | G_loss: 1.23837
Epoch:   100 | D_los

KeyboardInterrupt: 