In [1]:
import os
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
import torch
import torch.nn as nn
from torch import Tensor
import cv2
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from copy import deepcopy
from math import prod

# %matplotlib inline
# Select GPU 0 and release cached allocator memory, but only when CUDA is
# actually usable so the notebook still runs top-to-bottom on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    torch.cuda.empty_cache()

In [2]:
# Root of the dataset tree, relative to the notebook's working directory.
DATA_DIR = "../../data/dss/"
# Sub-folder that is handed to ImageFolder when the training set is built.
CHAR_DATA_DIR = os.path.join(DATA_DIR, "monkbrill/")

In [3]:
# Side length, in pixels, of the square images produced by the transform below.
image_size = 64
batch_size = 256
# One value per pixel of an image_size x image_size image.
latent_size = image_size * image_size
# (mean, std) pair for tt.Normalize; one entry because images are single-channel.
stats = ((0.5,), (0.5,))

train_ds = ImageFolder(
    CHAR_DATA_DIR,
    transform=tt.Compose([
        tt.Grayscale(num_output_channels=1),
        tt.RandomInvert(p=1),  # p=1: the inversion is applied to every image
        tt.Resize(image_size),
        tt.CenterCrop(image_size),
        tt.ToTensor(),
        tt.Normalize(mean=stats[0], std=stats[1]),
    ]),
)

train_dl = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [4]:
# for x in iter(DataLoader(train_ds, 1, shuffle=True, num_workers=3, pin_memory=True)):
#   print(x[0][0][0])
#   print([y.shape for y in x])
#   plt.imshow(x[0][0][0])
#   break

class CorruptCharGen():
  """Iterator that produces artificially corrupted images from a dataset.

  Each item is built from two consecutive single-sample batches drawn from an
  internal ``DataLoader``: the first image is the clean target, the second one
  is shifted by +1 and subtracted from it, and the result is clamped at -1.

  All positional and keyword arguments are forwarded to ``DataLoader`` (so the
  first positional argument is the dataset); ``batch_size`` is always forced
  to 1.  When the loader runs dry a fresh one is created, so iteration only
  stops once ``max_iter`` items have been produced since the last ``iter()``.
  """

  def __init__(self, *args, max_iter=2048, **kwargs):
    self.dl_args = args
    # batch_size is written last so it overrides any caller-supplied value.
    self.dl_kwargs = {**kwargs, "batch_size": 1}
    self.n_iter = 0
    self.max_iter = max_iter
    self.data_loader = None

  def __iter__(self):
    """Reset the item counter, drop the current loader and return ``self``."""
    self.n_iter = 0
    self.data_loader = None
    return self

  def _draw_pair(self):
    """Return two consecutive batches, restarting the loader once if it runs dry.

    Raises:
      ValueError: if even a freshly created loader cannot supply two samples
        (previously this case recursed until ``RecursionError``).
    """
    for _ in range(2):
      if self.data_loader is None:
        self.data_loader = iter(DataLoader(*self.dl_args, **self.dl_kwargs))
      try:
        return next(self.data_loader), next(self.data_loader)
      except StopIteration:
        # Loader exhausted (possibly after yielding one unpaired sample):
        # discard it and retry with a new one.
        self.data_loader = None
    raise ValueError("CorruptCharGen needs a dataset with at least 2 samples")

  def __next__(self):
    """Return ``(corrupted, base, label)`` for the next pair of samples.

    ``corrupted`` and ``base`` are reshaped to ``(latent_size, 1, 1)``;
    ``label`` is the label batch of the base sample as produced by the loader.

    Raises:
      StopIteration: once ``max_iter`` items have been produced.
    """
    # ">=" so that exactly max_iter items are produced (">" yielded one extra).
    if self.n_iter >= self.max_iter:
      raise StopIteration

    base_batch, subtr_batch = self._draw_pair()
    base_img_lab = base_batch[1]
    # [0] -> image batch, [0] -> only sample in the batch, [0] -> first channel
    base_img = base_batch[0][0][0]
    subtr_img = subtr_batch[0][0][0]

    # Subtract the second image (shifted by +1) and clip the result at -1.
    crpt_img = torch.clamp(base_img - (subtr_img + 1), min=-1)

    self.n_iter += 1

    return crpt_img.reshape((latent_size, 1, 1)), base_img.reshape((latent_size, 1, 1)), base_img_lab

  def gen_chars(self, num=1):
    """Draw ``num`` items and stack them field-wise.

    Returns:
      Tuple ``(corrupted, base, labels)`` where each entry is the
      ``torch.stack`` of the corresponding field over the ``num`` items.

    Raises:
      ValueError: if ``num`` is smaller than 1.
      StopIteration: if ``max_iter`` is reached before ``num`` items are drawn.
    """
    if num < 1:
      raise ValueError("num must be at least 1")
    samples = [next(self) for _ in range(num)]
    return tuple(torch.stack(field) for field in zip(*samples))
  
  # torch.stack([img[0] for img in crpt_imgs]), \
  #          torch.stack([])
  #          torch.stack([img[1] for img in crpt_imgs])
      
    
# Module-level source of corrupted samples used by the training functions below.
# It iterates over its own deep copy of train_ds, one shuffled sample at a time.
ccg = CorruptCharGen(deepcopy(train_ds), shuffle=True, num_workers=0, pin_memory=True)

In [5]:
# next(ccg)

In [6]:
class Reshape(nn.Module):
  """Module wrapper around ``Tensor.reshape`` with input-relative dimensions.

  The target shape is given either as separate arguments or as one tuple.
  Integer entries are used as-is.  String entries such as ``"2*3"`` are
  resolved at call time to the product of the input's sizes along the listed
  axes (here ``input.shape[2] * input.shape[3]``).
  """

  def __init__(self, *shape):
    super().__init__()
    # Accept both Reshape(a, b, c) and Reshape((a, b, c)).
    self.shape = shape[0] if isinstance(shape[0], tuple) else shape

  def forward(self, input):
    target = []
    for dim in self.shape:
      if isinstance(dim, str):
        # Multiply the input sizes of every axis index named in the string.
        dim = prod(input.shape[int(axis)] for axis in dim.split("*"))
      target.append(dim)
    return input.reshape(tuple(target))


# Generator: five transposed convolutions that turn a flattened
# (latent_size x 1 x 1) input into a (1 x 64 x 64) image in the Tanh range.
generator = nn.Sequential(
    # in: latent_size x 1 x 1
    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(inplace=True),
    # out: 512 x 4 x 4

    # Three identical stride-2 stages, each doubling the spatial size:
    # 256 x 8 x 8 -> 128 x 16 x 16 -> 64 x 32 x 32
    *(
        layer
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64))
        for layer in (
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    ),

    nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
    # out: 1 x 64 x 64

    # Reshape(1, 1, '2*3')
)

# Smoke test: push one corrupted sample through the untrained generator.
c, b, l = ccg.gen_chars(1)

output = generator(c)

In [7]:
print(c.shape)
print(output.shape)

torch.Size([1, 4096, 1, 1])
torch.Size([1, 1, 64, 64])


In [8]:
# Discriminator: four stride-2 convolution stages followed by a 4x4 convolution
# that collapses the feature map to one value per image, squashed by a sigmoid.
discriminator = nn.Sequential(
    # in: 1 x 64 x 64

    # Each stage halves the spatial size:
    # 64 x 32 x 32 -> 128 x 16 x 16 -> 256 x 8 x 8 -> 512 x 4 x 4
    *(
        layer
        for in_ch, out_ch in ((1, 64), (64, 128), (128, 256), (256, 512))
        for layer in (
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
    ),

    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1

    nn.Flatten(),
    nn.Sigmoid(),
)

In [9]:
def denorm(x):
    """Hook applied to image tensors before they are saved; currently the identity.

    NOTE(review): the training images are normalized with mean/std 0.5 and the
    generator ends in Tanh, so an identity denorm leaves that value range
    untouched when saving - confirm that this is intended.
    """
    return x

sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

def save_samples(index, latent_tensors, show=True):
    """Run the generator on ``latent_tensors`` and save the result as an image grid.

    Args:
        index: integer used for the zero-padded file name
            (``generated-images-NNNN.png`` inside ``sample_dir``).
        latent_tensors: batch passed unchanged to ``generator``.
        show: when true, also draw the grid on a new matplotlib figure.
    """
    # Previews never need gradients, so skip building the autograd graph.
    with torch.no_grad():
        fake_images = denorm(generator(latent_tensors))
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # Display the same (denormed) tensor that was written to disk.
        ax.imshow(make_grid(fake_images.cpu(), nrow=8).permute(1, 2, 0))

In [10]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 2070'

In [11]:
device = 0  # device index; not referenced by the training code visible in this notebook

def train_discriminator(real_images, opt_d):
    """Run one optimisation step of the discriminator.

    The discriminator is scored with an MSE loss against a target of 1 for
    ``real_images`` and 0 for generator outputs produced from a fresh batch of
    ``batch_size`` corrupted samples drawn from ``ccg``.

    Args:
        real_images: batch of real images fed to ``discriminator``.
        opt_d: optimizer holding the discriminator parameters.

    Returns:
        Tuple ``(loss, real_score, fake_score)`` of Python floats: the averaged
        real/fake loss and the mean discriminator output on each batch.
    """
    # Clear discriminator gradients
    opt_d.zero_grad()

    # Pass real images through discriminator
    real_preds = discriminator(real_images)
    real_loss = F.mse_loss(real_preds, torch.ones_like(real_preds))
    real_score = torch.mean(real_preds).item()

    # Generate fake images. The generator is not updated in this step, so its
    # forward pass runs without autograd: the backward pass below then stops at
    # the discriminator input instead of also traversing the whole generator.
    corrupted, _, _ = ccg.gen_chars(batch_size)
    with torch.no_grad():
        fake_images = generator(corrupted)

    # Pass fake images through discriminator
    fake_preds = discriminator(fake_images)
    fake_loss = F.mse_loss(fake_preds, torch.zeros_like(fake_preds))
    fake_score = torch.mean(fake_preds).item()

    print(f"discriminator losses: {(real_loss.item(), fake_loss.item())}")

    # Update discriminator weights
    loss = (real_loss + fake_loss) / 2
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [12]:
def mae(pred, true):
  """Mean absolute error between ``pred`` and ``true``, divided by ``latent_size``."""
  abs_err = F.l1_loss(pred, true, reduction="mean")
  return abs_err / latent_size

def train_generator(opt_g, fool_weight=2.5, sim_weight=150):
    """Run one optimisation step of the generator.

    A fresh batch of ``batch_size`` corrupted samples is drawn from ``ccg`` and
    restored by ``generator``.  The loss combines two MSE terms:

    * fool loss - discriminator output on the restored images vs. a target of 1;
    * similarity loss - restored images vs. the clean base images, both halved.

    Args:
        opt_g: optimizer holding the generator parameters.
        fool_weight: multiplier of the fool loss (default 2.5).
        sim_weight: multiplier of the similarity loss (default 150).

    Returns:
        The weighted total loss as a Python float.
    """
    # Clear generator gradients
    opt_g.zero_grad()

    # Generate fake images
    corrupted, base, _ = ccg.gen_chars(batch_size)
    print(f"fixable images shape: {corrupted.shape}")
    fake_images = generator(corrupted)

    # Try to fool the discriminator
    fool_preds = discriminator(fake_images)
    fool_loss = F.mse_loss(fool_preds, torch.ones_like(fool_preds))

    # Compare with the clean images in their flattened (N, latent_size, 1, 1) layout.
    sim_loss = F.mse_loss((fake_images / 2).reshape(base.shape), base / 2)

    # Update generator weights
    loss = fool_weight * fool_loss + sim_weight * sim_loss
    loss.backward()
    opt_g.step()

    loss_value = loss.item()
    print(f"generator loss: {loss_value}")

    return loss_value

In [13]:
# fixed_latent = torch.randn(batch_size, latent_size, 1, 1)
# Fixed batch of 64 corrupted inputs (plus their clean images and labels); the
# corrupted part is re-rendered after every epoch by fit().
fixed_corrupted, fixed_base, fixed_labels = ccg.gen_chars(64)

def fit(epochs, lr, start_idx=1):
    """Train discriminator and generator alternately on ``train_dl``.

    Args:
        epochs: number of full passes over ``train_dl``.
        lr: learning rate used for both Adam optimizers.
        start_idx: offset added to the epoch number in the sample file names.

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)`` - four lists with one
        entry per epoch, each taken from the last batch of that epoch.
    """
    torch.cuda.empty_cache()

    # Per-epoch history (values of the last batch only).
    gen_losses, disc_losses = [], []
    real_history, fake_history = [], []

    # Both optimizers use the same Adam configuration.
    adam_betas = (0.5, 0.999)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            # Reset ccg's item counter and loader before each training step.
            iter(ccg)
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            loss_g = train_generator(opt_g)

        gen_losses.append(loss_g)
        disc_losses.append(loss_d)
        real_history.append(real_score)
        fake_history.append(fake_score)

        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))

        # Render the fixed batch so progress can be compared across epochs.
        save_samples(epoch+start_idx, fixed_corrupted, show=False)

    return gen_losses, disc_losses, real_history, fake_history

In [14]:
# Image('./generated/generated-images-0060.png')
# Save the clean and the corrupted version of the fixed batch into sample_dir,
# next to the per-epoch generator outputs. The batch dimension is inferred (-1)
# so it follows whatever size was requested from ccg.gen_chars.
save_image(denorm(fixed_base.reshape(-1, 1, image_size, image_size)), os.path.join(sample_dir, "base.png"), nrow=8)
save_image(denorm(fixed_corrupted.reshape(-1, 1, image_size, image_size)), os.path.join(sample_dir, "corrupted.png"), nrow=8)
print(fixed_labels)

tensor([[17],
        [ 4],
        [ 7],
        [11],
        [ 3],
        [ 5],
        [18],
        [ 7],
        [ 4],
        [13],
        [ 7],
        [19],
        [ 0],
        [ 3],
        [ 2],
        [ 6],
        [18],
        [23],
        [18],
        [13],
        [10],
        [19],
        [26],
        [12],
        [20],
        [ 0],
        [22],
        [13],
        [ 6],
        [13],
        [ 2],
        [ 8],
        [ 1],
        [11],
        [11],
        [ 7],
        [13],
        [ 9],
        [13],
        [23],
        [13],
        [23],
        [21],
        [11],
        [ 6],
        [19],
        [ 1],
        [ 6],
        [15],
        [10],
        [20],
        [20],
        [ 2],
        [ 5],
        [20],
        [10],
        [ 5],
        [ 5],
        [ 0],
        [20],
        [22],
        [10],
        [13],
        [21]])


In [15]:
# Train for 100 epochs with learning rate 1e-2. `history` holds the four per-epoch
# lists returned by fit: generator loss, discriminator loss, real score, fake score.
history = fit(100, 1e-2)

  0%|          | 0/22 [00:00<?, ?it/s]

discriminator losses: (0.26326775550842285, 0.23197446763515472)
fixable images shape: torch.Size([256, 4096, 1, 1])


KeyboardInterrupt: 

In [None]:
# torch.save(discriminator.state_dict(), "./trained/char/discriminator")
# torch.save(generator.state_dict(), "./trained/char/generator")

NameError: name 'opt_d' is not defined