In [1]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import os, math
import random
import numpy as np

In [2]:
# --- Run identity -----------------------------------------------------------
DATASET = 'mnist'   # Dataset name; becomes part of the checkpoint paths below
RUN_ID = 't1'       # Run name; checkpoints at benchmark/{DATASET}/checkpoints/{RUN_ID}/
# NOTE(review): an earlier comment placed the data at benchmark/{DATASET}/data/,
# but the MNIST loader in this notebook uses root='benchmark/datasets/' - confirm
# which location is intended.

# --- Checkpoint locations (one sub-directory per network) --------------------
GEN_PATH = f'./benchmark/{DATASET}/checkpoints/{RUN_ID}/generator/'
DISC_PATH = f'./benchmark/{DATASET}/checkpoints/{RUN_ID}/critic/'

# --- Checkpoint switches -----------------------------------------------------
# NOTE(review): none of these are read by the cells visible here; presumably
# they are consumed by later training cells - confirm.
LOAD = False
LOAD_EPOCH = 0  # The epoch checkpoint to load
SAVE = True
curr_epoch = 0

# --- Data loading ------------------------------------------------------------
BATCH_SIZE = 50  # Samples per batch handed to the DataLoader

Transforms: append a guide-line channel to each image, then apply a random rotation and scale

In [3]:
class AddAffine(nn.Module):
    """Append a one-channel line pattern to an image tensor.

    The extra channel is zero everywhere except on both corner-to-corner
    diagonals and along the top row (row 0), where it is set to 1. Because
    the pattern is appended as a channel of the same tensor, any spatial
    transform applied afterwards acts on the image and the pattern together.
    """

    def __init__(self):
        super().__init__()

    def forward(self, image):
        """Return ``image`` with the line-pattern channel appended.

        Args:
            image: 3D tensor indexed as (C, H, W).

        Returns:
            Tensor of shape (C + 1, H, W): the input channels followed by the
            pattern channel, which shares the input's dtype and device.
        """
        # Largest valid column / row index (named so the builtin `max` is
        # not shadowed).
        max_x = image.shape[2] - 1
        max_y = image.shape[1] - 1
        affine = torch.zeros_like(image[:1, :, :])

        # Nothing can be drawn on an empty spatial grid.
        if max_x < 0 or max_y < 0:
            return torch.concat((image, affine), axis=0)

        # A one-pixel-wide (or one-pixel-high) image has zero extent along
        # that axis; use a slope of 0 there instead of dividing by zero.
        rise_per_col = max_y / max_x if max_x else 0.0
        run_per_row = max_x / max_y if max_y else 0.0

        rows, cols = [], []

        # Walk the columns: one pixel per column on each diagonal, plus the
        # pixel on the top edge.
        for x in range(max_x + 1):
            rows.extend((round(rise_per_col * x), round(-rise_per_col * x + max_y), 0))
            cols.extend((x, x, x))

        # Walk the rows as well so the diagonals stay gap-free when the image
        # is taller than it is wide.
        for y in range(max_y + 1):
            rows.extend((y, y))
            cols.extend((round(run_per_row * y), round(-run_per_row * y + max_x)))

        # Paint every collected pixel with a single indexed assignment rather
        # than one tensor write per pixel.
        row_idx = torch.tensor(rows, dtype=torch.long, device=affine.device)
        col_idx = torch.tensor(cols, dtype=torch.long, device=affine.device)
        affine[0, row_idx, col_idx] = 1

        # Append affine to image along channels axis
        return torch.concat((image, affine), axis=0)

In [4]:
# Per-sample preprocessing:
#   1. convert the input to a tensor,
#   2. append the AddAffine line-pattern channel,
#   3. apply one random rotation / scale to all channels, so the pattern
#      channel is warped together with the image.
tf = tfms.Compose([
    tfms.ToTensor(),
    nn.Sequential(
        AddAffine(),
        tfms.RandomAffine(
            degrees=(-180, 180),     # any rotation angle
            translate=(0.0, 0.0),    # zero shift fraction on both axes
            scale=(0.75, 1.0),       # scale factor drawn from this range
            shear=None,
        ),
    ),
])

Load the MNIST training set and create the TensorBoard writers

In [5]:
# MNIST training split, downloaded on first use; every sample is passed
# through the `tf` pipeline when it is fetched.
data_train = datasets.MNIST(
    root='benchmark/datasets/',
    train=True,
    download=True,
    transform=tf,
)

# Shuffled batches of BATCH_SIZE; the final incomplete batch is dropped.
dataloader = DataLoader(
    data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# One TensorBoard writer per log directory: logs/test1, logs/test2, logs/test3.
writer_test1, writer_test2, writer_test3 = (
    SummaryWriter(f'logs/test{run_no}') for run_no in (1, 2, 3)
)

Test: log each channel of every batch to TensorBoard for visual inspection

In [6]:
# Visual sanity check of the input pipeline: for every batch, append a
# label plane to the transformed images and log channels 0, 1 and 2 of the
# result as separate image grids, one grid per TensorBoard writer.
step = 0
for batch_idx, (imgs, labels) in enumerate(dataloader):
    # Shape of a single extra plane per sample: (batch, 1, H, W).
    NOISE_SHAPE = (imgs.shape[0], 1, imgs.shape[2], imgs.shape[3])

    # Broadcast each sample's label over a full H x W plane so it can be
    # stacked with the image channels.
    labels = labels[:,None,None,None]
    labels = labels.expand(NOISE_SHAPE)

    data = torch.concat((imgs, labels), axis=1)

    # One grid per channel. NOTE(review): presumably channel 0 is the digit,
    # channel 1 the AddAffine pattern and channel 2 the label plane - this
    # assumes single-channel inputs; confirm. The label plane holds raw
    # label values rather than [0, 1] intensities - confirm how TensorBoard
    # renders it.
    img_grid_test1 = torchvision.utils.make_grid(data[:,0:1,:,:])
    img_grid_test2 = torchvision.utils.make_grid(data[:,1:2,:,:])
    img_grid_test3 = torchvision.utils.make_grid(data[:,2:3,:,:])
    writer_test1.add_image('test', img_grid_test1, global_step=step)
    writer_test2.add_image('test', img_grid_test2, global_step=step)
    writer_test3.add_image('test', img_grid_test3, global_step=step)
    step += 1

# Push any still-queued events to disk now instead of waiting for the
# writers' periodic background flush.
writer_test1.flush()
writer_test2.flush()
writer_test3.flush()