based on:
https://arxiv.org/abs/1511.06434

## Sources

Fractionally strided convolutions (*not* the same as deconvolutions):
- https://datascience.stackexchange.com/questions/49299/what-is-fractionally-strided-convolution-layer 
- https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

In [3]:
import time
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt

def mnist_std_mean():
    dataset = datasets.MNIST(root="./datasets", train=True, download=True)
    return torch.std_mean(dataset.data / 255.0)

std, mean = mnist_std_mean()

t = transforms.Compose([transforms.ToTensor(),
                        transforms.Normalize(mean=mean, std=std)])

torch.manual_seed(777)
mnist_train = datasets.MNIST(root="./datasets", train=True, download=True, transform=t)
mnist_test = datasets.MNIST(root="./datasets", train=False, download=True, transform=t)

In [4]:
class DCGenerator(torch.nn.Module):
    def __init__(self, nodim=100):
        super(DCGenerator, self).__init__()
        self.nodim = nodim
        self.optim = None
        
        # noise vector needs to be reshaped to fit into the next layer
        
        # TODO: Design this network such that it outputs (1, 28, 28) MNIST images
        self.model = torch.nn.Sequential(
            # input shape: [batch_size, nodim, 1, 1] = [64, 100, 1, 1]
            torch.nn.ConvTranspose2d(nodim, 512, 5, stride=2),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            # output shape: [64, 512, 5, 5]
            
            
            torch.nn.ConvTranspose2d(512, 256, 5, stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            # output shape: [64, 256, 13, 13]
            
#             torch.nn.ConvTranspose2d(512, 256, 5, stride=2),
#             torch.nn.BatchNorm2d(256),
#             torch.nn.ReLU(),
            
            torch.nn.ConvTranspose2d(256, 128, 5, stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            # output shape: [64, 256, 29, 29]
            
            torch.nn.ConvTranspose2d(128, 1, 5, stride=2),
            torch.nn.Tanh()
            # output shape: [64, 1, 61, 61]
        )
    
    def set_optim(self, optim):
        self.optim = optim
    
    def forward(self, z):
        return self.model(z)
    
    def update(self, dis, z):
        self.optim.zero_grad()
        x = self(z)
        loss = dis.loss(x, torch.ones(x.shape[0], 1))
        loss.backward()
        self.optim.step()
        return loss

In [13]:
class DCDiscriminator(torch.nn.Module):
    def __init__(self, criterion):
        super(DCDiscriminator, self).__init__()
        self.optim = None
        self.criterion = criterion
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 28, 3, 1),
            torch.nn.LeakyReLU(0.2),
            
#             torch.nn.Conv2d(28, 64, 3, 1),
#             torch.nn.BatchNorm2d(64),
#             torch.nn.LeakyReLU(0.2),
            
#             torch.nn.Conv2d(64, 64, 3, 1),
#             torch.nn.BatchNorm2d(64),
#             torch.nn.LeakyReLU(0.2),
            
            torch.nn.Flatten(),
            torch.nn.Linear(18928, 1),
            torch.nn.Sigmoid()
        )
    
    def set_optim(self, optim):
        self.optim = optim
    
    def forward(self, x):
        return self.model(x)
    
    def loss(self, x, y):
        return self.criterion(x, y)
    
    def update(self, gen, x, z):
        self.optim.zero_grad()
        print("update1")
        true_loss = self.loss(self(x), torch.ones(x.shape[0], 1))
        print("update2")
        x2 = gen(z)
        print(x2.shape, x.shape)
        prob = self(x2)
        print("update3")
        false_loss = self.loss(prob, torch.zeros(z.shape[0], 1))
        true_loss.backward()
        false_loss.backward()
        self.optim.step()
        return true_loss + false_loss, prob

In [14]:
def gen_noise(batch_size, nodim):
    return torch.rand(batch_size, nodim, 1, 1)

def train(dis, gen, dataloader, epochs, k=1):
    stats = {'dis': {
        'losses': [],
        'probs': []
    }, 'gen': {
        'losses': [] 
    }}
    gen.model.train()
    dis.model.train() 
    
    data = iter(dataloader)
    disloss, probs = 0, 0
    print("test")
    for i in range(epochs):
        print(f"epoch: {i}/{epochs}")
        start = time.time()
        for _ in range(k):
            x, _ = data.next()
            noise = gen_noise(dataloader.batch_size, gen.nodim)
            print("test2")
            loss, p = dis.update(gen, x, noise)
            print("test3")
            disloss += loss
            probs += p
            
        noise = gen_noise(dataloader.batch_size, gen.nodim)
        genloss = gen.update(dis, noise)
        end = time.time()
        
        stats['dis']['probs'].append(float(torch.mean(probs.data) / k))
        stats['dis']['loss'].append(float(torch.mean(disloss.data) / k))
        stats['gen']['loss'].append(float(torch.mean(genloss.data)))
        print(f"epoch: {i}/{epochs}; prob: {round(stats['disc']['probs'][-1], 2)}, genloss: {round(stats['gen']['losses'][-1], 2)} " + 
        f"discloss: {round(stats['disc']['losses'][-1], 2)}, epoch_time: {round(end - start, 2)}""", end="\r")
    return stats

In [15]:
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
gen = DCGenerator(nodim=100)
dis = DCDiscriminator(criterion=torch.nn.BCELoss())

LR = 0.0002
BETAS = (0.5, 0.999)
gen.set_optim(torch.optim.Adam(gen.parameters(), lr=LR, betas=BETAS))
dis.set_optim(torch.optim.Adam(dis.parameters(), lr=LR, betas=BETAS))

In [16]:
EPOCHS=10
stats = train(dis, gen, dataloader, epochs=EPOCHS, k=1)

test
epoch: 0/10
test2
update1
update2
torch.Size([64, 1, 61, 61]) torch.Size([64, 1, 28, 28])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x97468 and 18928x1)

In [None]:
fig, (p1, p2, p3) = plt.subplots(3, 1, figsize=(16,26))
p1.plot(range(EPOCHS), stats["dis"]["losses"], color="blue")
p1.set_title("Discriminator Loss")
p2.plot(range(EPOCHS), stats["gen"]["losses"], color="green")
p2.set_title("Generator Loss")
p3.plot(range(EPOCHS), stats["dis"]["probs"], color="red")
p3.set_title("Discriminator Output")
plt.xlabel("epoch")
plt.ylabel("loss")