In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision.utils import make_grid
from torchvision.utils import save_image
from torchvision.io import read_image
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

import os

#!pip install opendatasets
import opendatasets as od


In [None]:
# --- Sizes shared by the data pipeline and the networks ---
image_size = 128    # image side length in pixels (the transforms cell resizes to 128x128)
batch_size = 128    # number of images per DataLoader batch
latent_size = 128   # channel count of the generator's noise input

# (mean, std) pair, one value per RGB channel; passed to T.Normalize and
# inverted by denorm().
stats = ((0.5,) * 3, (0.5,) * 3)

# NOTE: `lr` is not referenced by the other cells shown in this notebook;
# fit() receives lr_d / lr_g instead.
lr = 2e-4

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
# Fetch the Kaggle "abstract-art" dataset. As the captured output below shows,
# this prompts for Kaggle credentials and downloads into ./abstract-art.
od.download('https://www.kaggle.com/datasets/greg115/abstract-art')

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: slenser
Your Kaggle Key: ··········
Downloading abstract-art.zip to ./abstract-art


100%|██████████| 296M/296M [00:03<00:00, 82.6MB/s]





In [None]:
# Preprocessing and light augmentation applied to every training image.
# The target size comes from the config-cell `image_size` instead of a repeated
# literal 128, so changing the constant changes the pipeline consistently.
transforms = T.Compose([
    # Force an exact image_size x image_size square (aspect ratio is not preserved).
    T.Resize((image_size, image_size)),
    # No-op after the exact square resize above; retained from the original pipeline.
    T.CenterCrop(image_size),
    # Random flips: cheap augmentation.
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ToTensor(),
    # Shift/scale each channel with the (mean, std) pair in `stats`.
    T.Normalize(*stats),
])

In [None]:
def denorm(img_tensor):
    """Undo the normalisation applied by T.Normalize(*stats).

    Only the first channel's mean and std from the module-level ``stats`` are
    used, and they are applied uniformly to the whole input.

    Args:
        img_tensor: normalised value or tensor.

    Returns:
        ``img_tensor * std + mean``.
    """
    mean, std = stats
    return img_tensor * std[0] + mean[0]

In [None]:
# Read the images from the folder od.download() created. The download output
# shows it writes to ./abstract-art, so the same working-directory-relative
# path is used here instead of a hard-coded absolute /content/... path.
train_ds = ImageFolder(root='./abstract-art', transform=transforms)

# Shuffled mini-batches. Pinned host memory only helps host-to-GPU copies, so
# it is requested only when training on CUDA.
train_dl = DataLoader(
    train_ds,
    batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=(device == 'cuda'),
)


In [None]:
def show_image(train_dl):
    """Plot a grid of up to 32 images taken from the first batch of ``train_dl``.

    Args:
        train_dl: iterable yielding ``(images, labels)`` batches of normalised
            image tensors. Batches may live on the GPU (e.g. when the loader is
            wrapped in DeviceDataLoader); they are copied to the CPU before
            plotting because matplotlib cannot read CUDA tensors.

    Does nothing if the loader yields no batches.
    """
    for images, _ in train_dl:
        # Slice first so at most 32 images are transferred to host memory.
        preview = denorm(images[:32].detach().cpu())
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        ax.imshow(make_grid(preview, nrow=8).permute(1, 2, 0))
        break  # only the first batch is shown
        

In [None]:
def to_device(data, device):
    """Move ``data`` to ``device``, recursing into lists and tuples.

    Args:
        data: an object exposing ``.to(device, non_blocking=...)``, or a
            (possibly nested) list/tuple of such objects.
        device: target device, forwarded unchanged to ``.to``.

    Returns:
        The moved object; any list or tuple container comes back as a list.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Thin wrapper that moves every batch of an underlying loader to a device.

    Attributes are the wrapped iterable (``dl``) and the target ``device``.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        """Number of batches, as reported by the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after to_device()."""
        yield from (to_device(batch, self.device) for batch in self.dl)

In [None]:
# Make the loader hand out batches that already live on `device`.
# Guarded so that re-running this cell does not stack a wrapper on top of an
# already wrapped loader.
if not isinstance(train_dl, DeviceDataLoader):
    train_dl = DeviceDataLoader(train_dl, device)

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing one probability per image.

  Five stride-2 Conv/BatchNorm/LeakyReLU(0.2) stages take the channels through
  3 -> 64 -> 128 -> 256 -> 512 -> 1024 while halving height and width each
  time. A final 4x4 convolution (stride 1, no padding) maps to one channel,
  which is flattened and passed through a sigmoid. With 128x128 inputs the
  output has shape (batch, 1).
  """

  def __init__(self):
    super().__init__()
    widths = (3, 64, 128, 256, 512, 1024)

    layers = []
    # Layers are appended in the same flat order as a hand-written Sequential,
    # so the state_dict keys (disc.0.weight, disc.1.weight, ...) are unchanged.
    for in_ch, out_ch in zip(widths[:-1], widths[1:]):
      layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False))
      layers.append(nn.BatchNorm2d(out_ch))
      layers.append(nn.LeakyReLU(0.2, inplace=True))

    # Collapse the remaining 4x4 map to a single value, then squash to (0, 1).
    layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0, bias=False))
    layers.append(nn.Flatten())
    layers.append(nn.Sigmoid())

    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    """Return the sigmoid score for each image in ``x``."""
    return self.disc(x)

class Generator(nn.Module):
  """Transposed-convolution generator mapping noise to 3-channel images.

  The input is expected as (batch, latent_size, 1, 1). A first 4x4 transposed
  convolution (stride 1, no padding) expands it to 1024 channels at 4x4; four
  stride-2 ConvTranspose/BatchNorm/ReLU stages then go
  1024 -> 512 -> 256 -> 128 -> 64 while doubling the spatial size, and a last
  stride-2 transposed convolution followed by tanh yields a
  (batch, 3, 128, 128) tensor with values in [-1, 1].
  """

  def __init__(self):
    super().__init__()
    widths = (1024, 512, 256, 128, 64)

    # Projection of the latent vector to the first feature map.
    layers = [
      nn.ConvTranspose2d(latent_size, widths[0], kernel_size=4, stride=1, padding=0, bias=False),
      nn.BatchNorm2d(widths[0]),
      nn.ReLU(True),
    ]

    # Upsampling stages, appended in the same flat order as a hand-written
    # Sequential so the state_dict keys (gen.0.weight, ...) are unchanged.
    for in_ch, out_ch in zip(widths[:-1], widths[1:]):
      layers += [
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(True),
      ]

    # Final upsampling to RGB; tanh bounds the output to [-1, 1].
    layers += [
      nn.ConvTranspose2d(widths[-1], 3, kernel_size=4, stride=2, padding=1, bias=False),
      nn.Tanh(),
    ]

    self.gen = nn.Sequential(*layers)

  def forward(self, x):
    """Generate images from the latent batch ``x``."""
    return self.gen(x)

In [None]:
# Build both networks and place them on the selected device.
# The right-hand side is evaluated left to right, so the discriminator is
# constructed before the generator.
D, G = Discriminator().to(device), Generator().to(device)

In [None]:
def train_discriminator(real_images, opt_d):
    """Run one discriminator update on a real batch plus a same-sized fake batch.

    Uses the module-level networks ``D`` and ``G`` and the module-level
    ``latent_size``.

    Args:
        real_images: batch of real images, already on the device ``D`` lives on.
        opt_d: optimizer over the discriminator's parameters.

    Returns:
        Tuple ``(loss, real_score, fake_score)`` of Python floats: the summed
        BCE loss, and the mean discriminator output on the real and on the
        fake batch.
    """
    opt_d.zero_grad()

    n_images = real_images.size(0)
    on_device = real_images.device

    # Real images should be classified as 1.
    real_preds = D(real_images)
    real_targets = torch.ones(n_images, 1, device=on_device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Draw as many fakes as there are real images. The first dimension of the
    # noise is the batch size; it must not be tied to latent_size, which only
    # matched by coincidence when both constants were 128.
    latent = torch.randn(n_images, latent_size, 1, 1, device=on_device)
    # Only D is updated in this step, so no autograd graph through G is needed.
    with torch.no_grad():
        fake_images = G(latent)

    # Generated images should be classified as 0.
    fake_preds = D(fake_images)
    fake_targets = torch.zeros(fake_images.size(0), 1, device=on_device)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_score, fake_score

In [None]:
def train_generator(opt_g, n_samples=None):
    """Run one generator update against the current discriminator.

    Uses the module-level networks ``G`` and ``D`` and the module-level
    ``latent_size``, ``batch_size`` and ``device``.

    Args:
        opt_g: optimizer over the generator's parameters.
        n_samples: number of fake images to generate for this step. Defaults
            to the module-level ``batch_size``.

    Returns:
        The generator's BCE loss as a Python float.
    """
    if n_samples is None:
        n_samples = batch_size

    opt_g.zero_grad()

    # The first dimension of the noise is the number of images, not
    # latent_size; the two only matched by coincidence when both were 128.
    latent = torch.randn(n_samples, latent_size, 1, 1, device=device)
    fake_images = G(latent)

    # The generator is rewarded when D labels its images as real (target 1).
    preds = D(fake_images)
    targets = torch.ones(fake_images.size(0), 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    loss.backward()
    opt_g.step()

    return loss.item()

In [None]:
# Colab-only: mount Google Drive at /content/gdrive so that the paths under
# that mount point used by sample_dir and by fit() are available.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Output directory for the image grids written by save_sample(); it sits under
# the Google Drive mount point.
sample_dir = "/content/gdrive/My Drive/genn"

In [None]:
def save_sample(index, fixed_latent, show=True):
    """Generate images from ``fixed_latent`` and save them as one PNG grid.

    Uses the module-level generator ``G`` and writes ``img{index}.png`` into
    the module-level ``sample_dir``.

    Args:
        index: number embedded in the output file name.
        fixed_latent: latent batch fed to the generator.
        show: when true, also plot the first 32 generated images.
    """
    # save_image fails if the target folder is missing, so create it on demand.
    os.makedirs(sample_dir, exist_ok=True)

    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        # Map the tanh output range back to [0, 1] once, for both the saved
        # file and the on-screen preview.
        fake_images = denorm(G(fixed_latent)).cpu()

    fake_fname = "img{0}.png".format(index)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)

    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        ax.imshow(make_grid(fake_images[:32], nrow=8).permute(1, 2, 0))

In [None]:
# A fixed batch of 128 noise vectors, created once so that the sample grids
# saved during training are all generated from the same inputs.
fixed_latent = torch.randn((128, latent_size, 1, 1), device=device)


In [None]:
def fit(epochs, lr_d, lr_g, start_idx=1):
    """Train the module-level ``D`` and ``G`` for ``epochs`` passes over ``train_dl``.

    Each batch runs one discriminator step followed by one generator step.
    After every epoch the metrics of the last batch are printed and a sample
    grid is saved from ``fixed_latent``; every tenth epoch the generator's
    state_dict is also written to Google Drive.

    Args:
        epochs: number of passes over ``train_dl``.
        lr_d: Adam learning rate for the discriminator.
        lr_g: Adam learning rate for the generator.
        start_idx: offset added to the zero-based epoch to number the saved
            sample images.
    """
    torch.cuda.empty_cache()

    adam_betas = (0.5, 0.999)
    opt_d = torch.optim.Adam(D.parameters(), lr=lr_d, betas=adam_betas)
    opt_g = torch.optim.Adam(G.parameters(), lr=lr_g, betas=adam_betas)

    for epoch in range(epochs):
        epoch_number = epoch + 1

        for real_images, _ in tqdm(train_dl):
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            loss_g = train_generator(opt_g)

        if epoch_number % 10 == 0:
            # NOTE: the file name carries the zero-based epoch (9, 19, ...) and
            # the spelling "Generaotr"; both are kept unchanged so checkpoint
            # paths stay the same.
            checkpoint_path = f"/content/gdrive/My Drive/Generaotr{epoch}.pt"
            torch.save(G.state_dict(), checkpoint_path)

        # These values come from the final batch of the epoch only.
        print(
            f"Epoch: [{epoch_number}/{epochs}], loss_d: {loss_d:.4f}, loss_g: {loss_g:.4f}, "
            f"real_score: {real_score:.4f}, fake_score: {fake_score:.4f}"
        )

        save_sample(epoch + start_idx, fixed_latent, show=False)
        

In [None]:
# Training schedule passed to fit().
epochs = 300

# Adam learning rates: the generator's is ten times the discriminator's.
lr_d = 1e-4
lr_g = 1e-3

In [None]:
# Start training; sample images are numbered from 1.
fit(epochs=epochs, lr_d=lr_d, lr_g=lr_g, start_idx=1)