In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

img_size = 64
batch_size = 100

# Resize every digit to img_size and rescale pixel values from [0, 1] to [-1, 1],
# which matches the range of the generator's tanh output.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(img_size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

# Replace MNIST with KMNIST if so desired.
trainset = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# designed for use on Google Colab
!pip install ipython-autotime
%load_ext autotime

In [None]:
# ---- network sizes ----
noise_dim = 100        # length of the latent vector z fed to the generator
gen_f, dis_f = 64, 64  # base number of feature maps in G and D
nc = 1                 # number of image channels

# ---- training length ----
num_epochs = 2

# ---- stabilisation flags (each experiment cell re-sets these before building a Trainer) ----
add_noise, annealing = False, 0.99      # instance noise on D's inputs; multiplier applied to its std
mini_disc, md_dim, md_k = False, 1, 2   # minibatch discrimination: on/off, extra features, kernel size
stddev_layer = False                    # append a minibatch-stddev channel inside D
label_smooth = False                    # perturb the "real" targets with noise
unroll_steps = 0                        # extra look-ahead D updates before each G update

In [None]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# Show one shuffled training batch as a grid to sanity-check loading and normalisation.
batch_images, labels = next(iter(trainloader))
img = torchvision.utils.make_grid(batch_images, normalize=True)
plt.figure(figsize=(5, 5))
plt.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
plt.show()

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, meant to be used via `module.apply(weights_init)`.

    Any layer whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...) gets
    weights ~ N(0, 0.02). Any layer whose class name contains "BatchNorm" gets
    weights ~ N(1, 0.02) and zero bias. All other modules are left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Use the first CUDA GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
class Minibatch_Disc(nn.Module):
    """Minibatch discrimination layer (after Salimans et al., 2016).

    Takes a (B, C, H, W) feature map with C == d_in. It returns a
    (B, C + d_out, H, W) map: the first C channels are the input unchanged, and
    the last d_out channels hold similarity features. Each d_in-dim vector (the
    channel vector at one spatial position) is projected through T to d_out
    kernels of size k_size. Each feature is then the sum of exp(-L1 distance)
    to every other such vector among the B*H*W vectors in the batch.

    Args:
        d_in: number of input channels.
        d_out: number of similarity features (extra channels) to append.
        k_size: dimension of each projection kernel.
    """
    def __init__(self, d_in, d_out, k_size):
        super(Minibatch_Disc, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.k_size = k_size

        # Learnable projection tensor T, initialised from N(0, 1).
        self.T = nn.Parameter(torch.empty(d_in, d_out, k_size))
        nn.init.normal_(self.T, 0, 1)

    def forward(self, x):
        b, _, h, w = x.size()
        # Channels-last, so each row is the d_in-dim channel vector at one (sample, y, x).
        # A plain x.view(-1, d_in) on NCHW data would slice memory across channels and
        # positions and scramble the feature map when reshaped back.
        x_vector = x.permute(0, 2, 3, 1).reshape(-1, self.d_in)
        mat = x_vector.mm(self.T.view(self.d_in, -1))
        mat = mat.view(-1, self.d_out, self.k_size).unsqueeze(0)
        # Pairwise L1 distances between all rows, per kernel: shape (N, N, d_out).
        norm = torch.abs(mat - mat.permute(1, 0, 2, 3)).sum(3)
        norm = torch.exp(-norm)
        # Subtract 1 to drop each row's similarity with itself (exp(0) == 1).
        features = norm.sum(0) - 1
        out = torch.cat([x_vector, features], 1)
        # Back to (B, C + d_out, H, W) with the original channels in their original place.
        return out.view(b, h, w, -1).permute(0, 3, 1, 2)

class Stddev(nn.Module):
    """Minibatch standard-deviation layer.

    Computes the per-element standard deviation across the batch dimension and
    averages it into one scalar. That scalar is appended to a (N, C, H, W) input
    as a constant extra channel, giving (N, C + 1, H, W).
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        n, _, h, w = x.size()
        batch_spread = torch.std(x, dim=0).mean()
        extra_channel = batch_spread.expand(n, 1, h, w)
        return torch.cat((x, extra_channel), dim=1)
        

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: (N, noise_dim, 1, 1) latent codes -> (N, nc, 64, 64) images.

    The first transposed convolution projects the 1x1 code to a 4x4 map. Each later
    one doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64). Hidden stages use
    BatchNorm + LeakyReLU(0.2); the output goes through tanh, so pixels lie in [-1, 1].
    """
    def __init__(self):
        super(Generator, self).__init__()
        ch8, ch4, ch2, ch1 = gen_f * 8, gen_f * 4, gen_f * 2, gen_f

        self.decon1 = nn.ConvTranspose2d(noise_dim, ch8, 4, 1, 0)  # 1x1 -> 4x4
        self.norm1 = nn.BatchNorm2d(ch8)

        self.decon2 = nn.ConvTranspose2d(ch8, ch4, 4, 2, 1)
        self.norm2 = nn.BatchNorm2d(ch4)

        self.decon3 = nn.ConvTranspose2d(ch4, ch2, 4, 2, 1)
        self.norm3 = nn.BatchNorm2d(ch2)

        self.decon4 = nn.ConvTranspose2d(ch2, ch1, 4, 2, 1)
        self.norm4 = nn.BatchNorm2d(ch1)

        self.decon_final = nn.ConvTranspose2d(ch1, nc, 4, 2, 1)

    def forward(self, input):
        hidden = input
        stages = (
            (self.decon1, self.norm1),
            (self.decon2, self.norm2),
            (self.decon3, self.norm3),
            (self.decon4, self.norm4),
        )
        for deconv, batch_norm in stages:
            hidden = F.leaky_relu(batch_norm(deconv(hidden)), 0.2)
        return torch.tanh(self.decon_final(hidden))
    
class Discriminator(nn.Module):
    """DCGAN-style discriminator for (N, nc, 64, 64) images.

    Four stride-2 convolutions reduce the input to a (dis_f*8) x 4 x 4 map. It then
    optionally passes through a minibatch-discrimination or minibatch-stddev layer.
    A final 4x4 convolution with a sigmoid gives one probability per image, shape
    (N, 1, 1, 1).

    The notebook-level flags `mini_disc` and `stddev_layer` are read at construction.
    `add_noise` and `annealing` are read on every forward pass.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        # Current std of the instance noise; multiplied by `annealing` after each noisy forward.
        self.std = 0.3

        self.conv1 = nn.Conv2d(nc, dis_f, 4, 2, 1)

        self.conv2 = nn.Conv2d(dis_f, dis_f*2, 4, 2, 1)
        self.norm2 = nn.BatchNorm2d(dis_f*2)

        self.conv3 = nn.Conv2d(dis_f*2, dis_f*4, 4, 2, 1)
        self.norm3 = nn.BatchNorm2d(dis_f*4)

        self.conv4 = nn.Conv2d(dis_f*4, dis_f*8, 4, 2, 1)
        self.norm4 = nn.BatchNorm2d(dis_f*8)

        # The optional layers add channels, so conv_final's input width depends on them.
        if mini_disc:
            self.conv_final = nn.Conv2d(dis_f*8 + md_dim, 1, 4, 1, 0)
            self.mbd = Minibatch_Disc(dis_f*8, md_dim, md_k)
        elif stddev_layer:
            self.stddev = Stddev()
            self.conv_final = nn.Conv2d(dis_f*8 + 1, 1, 4, 1, 0)
        else:
            self.conv_final = nn.Conv2d(dis_f*8, 1, 4, 1, 0)

    def add_noise(self, x, mean, std):
        """Return `x` plus i.i.d. Gaussian noise N(mean, std**2).

        The noise is drawn with the same shape, dtype and device as `x`.
        """
        return x + torch.randn_like(x) * std + mean

    def forward(self, input):
        # `add_noise` here is the module-level flag, not the method of the same name.
        if add_noise:
            x = self.add_noise(input, 0, self.std)
            # NOTE(review): annealing runs on every forward call. That means several times
            # per training iteration, and more when unrolling. Confirm this decay rate is intended.
            self.std *= annealing
        else:
            x = input

        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.norm2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.norm3(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.norm4(self.conv4(x)), 0.2)

        if mini_disc:
            x = self.mbd(x)
        elif stddev_layer:
            x = self.stddev(x)
        x = torch.sigmoid(self.conv_final(x))
        return x

In [None]:
from copy import deepcopy

class Trainer():
    """Builds a fresh Generator/Discriminator pair and trains them adversarially on `trainloader`.

    The notebook-level flags pick the stabilisation tricks. `label_smooth` is read
    at construction and `unroll_steps` in `train()`. `add_noise`, `mini_disc` and
    `stddev_layer` are read inside the networks.
    """
    def __init__(self):
        self.netG = Generator().to(device)
        self.netG.apply(weights_init)

        self.netD = Discriminator().to(device)
        self.netD.apply(weights_init)

        self.optimizerD = optim.Adam(self.netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

        self.criterion = nn.BCELoss()

        # Target vectors for a full batch; sliced to the actual batch size where they are used.
        self.real_label = torch.full((batch_size,), 1, dtype=torch.float, device=device)
        if label_smooth:
            # Perturb the real targets once with N(0, 0.1^2) noise, then clamp into [0, 1].
            # A BCELoss target above 1 makes the loss unbounded below as D(x) -> 1.
            self.real_label += torch.normal(torch.zeros(self.real_label.shape, device=device),
                              torch.full(self.real_label.shape, 0.1, device=device))
            self.real_label.clamp_(0.0, 1.0)
        self.fake_label = torch.full((batch_size,), 0, dtype=torch.float, device=device)
        # label smoothing is not usually applied to negative labels for mathematical reasons
        # by uncommenting the lines below, one may observe why
        #if label_smooth:
        #    noise = torch.normal(torch.zeros(self.real_label.shape, device=device),
        #                      torch.full(self.real_label.shape, 0.2, dtype=torch.float, device=device))
        #    self.fake_label += torch.where(noise > 0, noise, torch.zeros(self.real_label.shape, device=device))

    def _discriminator_step(self, real):
        """Do one D update on `real` plus an equally sized freshly generated fake batch.

        Returns (errD, D_real, D_fake, fake). errD is the summed loss tensor,
        D_real and D_fake are the mean D outputs, and fake still carries G's graph.
        """
        b_size = real.size(0)
        self.netD.zero_grad()

        output = self.netD(real).view(-1)
        errD_real = self.criterion(output, self.real_label[:b_size])
        errD_real.backward()
        D_real = output.mean().item()

        noise = torch.randn(b_size, noise_dim, 1, 1, device=device)
        fake = self.netG(noise)

        # detach(): this loss must not produce gradients for G.
        output = self.netD(fake.detach()).view(-1)
        errD_fake = self.criterion(output, self.fake_label[:b_size])
        errD_fake.backward()
        D_fake = output.mean().item()
        self.optimizerD.step()

        return errD_real + errD_fake, D_real, D_fake, fake

    def unroll_loop(self):
        """Do one extra (look-ahead) D update on a freshly drawn real batch.

        Returns a (D_real, D_fake) tuple of mean discriminator outputs.
        """
        # NOTE: iter(trainloader) builds a new iterator (new shuffle, new worker processes)
        # on every call, which makes unrolling slow.
        real = next(iter(trainloader))[0].to(device)
        _, D_real, D_fake, _ = self._discriminator_step(real)
        return (D_real, D_fake)

    def train(self):
        """Train for `num_epochs` epochs over `trainloader`.

        Returns (G_losses, D_losses, stddev_z, stddev_x). These are the per-iteration
        generator loss, discriminator loss, and mean per-pixel batch stddev of the
        generated and real images.
        """
        G_losses = []
        D_losses = []
        stddev_z = []
        stddev_x = []

        for epoch in range(num_epochs):
            for i, data in enumerate(trainloader, 0):
                real = data[0].to(device)
                b_size = real.size(0)

                # Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                errD, D_real, D_fake1, fake = self._discriminator_step(real)
                stddev_fake = fake.std(dim=0).mean().item()
                stddev_real = real.std(dim=0).mean().item()

                # Update G network: maximize log(D(G(z)))
                self.netG.zero_grad()

                if unroll_steps > 0:
                    # Snapshot D's weights/buffers AND optimizer state. The look-ahead
                    # updates should only shape this generator step; they must not leak
                    # into D's parameters or its Adam moment estimates.
                    backup_dict = deepcopy(self.netD.state_dict())
                    backup_optim = deepcopy(self.optimizerD.state_dict())
                    for step in range(unroll_steps):
                        self.unroll_loop()

                output = self.netD(fake).view(-1)
                # fake labels treated as real for generator cost
                errG = self.criterion(output, self.real_label[:b_size])
                errG.backward()
                D_fake2 = output.mean().item()
                self.optimizerG.step()

                if unroll_steps > 0:
                    self.netD.load_state_dict(backup_dict)
                    self.optimizerD.load_state_dict(backup_optim)
                    del backup_dict, backup_optim

                # Average of D(G(z)) before and after this iteration's D update.
                D_fake = (D_fake1 + D_fake2) / 2
                # Output training stats
                if i % 10 == 0:
                    print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f   std_z: %.4f / std_x: %.4f'
                    % (epoch, num_epochs, i, len(trainloader), errD.item(), errG.item(), D_real, D_fake, stddev_fake, stddev_real))

                # Save Losses for plotting later
                G_losses.append(errG.item())
                D_losses.append(errD.item())
                stddev_z.append(stddev_fake)
                stddev_x.append(stddev_real)
        return (G_losses, D_losses, stddev_z, stddev_x)

    def gen_sample(self, noise=None):
        """Generate images from `noise`.

        When `noise` is None, fresh N(0, 1) noise of shape
        (batch_size, noise_dim, 1, 1) is used instead.
        """
        if noise is None:
            noise = torch.randn(batch_size, noise_dim, 1, 1, device=device)
        return self.netG(noise)

In [None]:
def plot(G_l, D_l, std_z, std_x):
    """Draw two new figures against iteration index.

    The first shows the G and D losses; the second shows the batch stddev of
    generated (G(z)) versus real (x) images.
    """
    panels = (
        ("Loss", ((G_l, "G"), (D_l, "D"))),
        ("stddev", ((std_z, "G(z)"), (std_x, "x"))),
    )
    for y_label, curves in panels:
        fig, ax = plt.subplots()
        for values, name in curves:
            ax.plot(values, label=name)
        ax.set_xlabel("iterations")
        ax.set_ylabel(y_label)
        ax.legend()

def show_sample(img):
    """Show a (N, C, H, W) image batch as one normalised grid on the current matplotlib figure."""
    grid = torchvision.utils.make_grid(img, normalize=True, padding=3)
    grid_hwc = grid.detach().cpu().permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
    plt.imshow(grid_hwc)

In [None]:
# Per-run histories, filled by the experiment cells below.
G_losses = []
D_losses = []
stddevs_f = []
stddevs_r = []

In [None]:
# Vanilla baseline: every stabilisation trick is switched off.
add_noise = False
annealing = 0.99

mini_disc = False
md_dim = 1
md_k = 2

# Must be False for the baseline; the stddev-layer variant has its own cell below.
stddev_layer = False

label_smooth = False

unroll_steps = 0

for _ in range(25):
    vanilla_trainer = Trainer()
    v_G_loss, v_D_loss, v_stddev_f, v_stddev_r = vanilla_trainer.train()
    plot(v_G_loss, v_D_loss, v_stddev_f, v_stddev_r)

    G_losses.append(v_G_loss)
    D_losses.append(v_D_loss)
    stddevs_f.append(v_stddev_f)
    stddevs_r.append(v_stddev_r)

    plt.figure(figsize=(5, 5))
    show_sample(vanilla_trainer.gen_sample())

plt.show()

In [None]:
# Remove suspicious / contaminated runs from the overall results. A run counts as
# diverged when its final D or G loss is >= divergence_threshold.
divergence_threshold = 10
keep = [run for run in range(len(D_losses))
        if not (D_losses[run][-1] >= divergence_threshold
                or G_losses[run][-1] >= divergence_threshold)]
print("Dropping %d of %d runs" % (len(D_losses) - len(keep), len(D_losses)))

# Rebuild plain lists (not numpy arrays) so the experiment cells can keep calling .append().
# This also handles the "nothing to drop" case and drops every offending run, not just one.
G_losses = [G_losses[run] for run in keep]
D_losses = [D_losses[run] for run in keep]
stddevs_f = [stddevs_f[run] for run in keep]
stddevs_r = [stddevs_r[run] for run in keep]

In [None]:
# (number of runs, iterations per run)
print(np.shape(G_losses))

# Average each history over runs (axis 0) and plot the mean curves.
averaged = [np.average(series, axis=0) for series in (G_losses, D_losses, stddevs_f, stddevs_r)]
plot(*averaged)
plt.show()

In [None]:
# Do train the classifier before using this - it's at the bottom of the page
# for reasons of cleaner organisation
# map_location lets a classifier saved on a GPU be loaded on a CPU-only runtime.
# NOTE: this unpickles a whole nn.Module - only load files you created yourself
# (recent PyTorch versions may additionally require weights_only=False here).
classifier = torch.load('./classifier', map_location=device)
# insert name of trainer to use here
trainer = vanilla_trainer

num_classes = 10

# show the first one
sample = trainer.gen_sample()
show_sample(sample)
plt.show()
with torch.no_grad():
    pred = torch.max(classifier(sample), 1)
print(pred[1].cpu().numpy())

# evaluate 100 batches
#   score: mean of the classifier's highest logit per image (a raw logit, not a probability)
#   diff:  mean absolute deviation of the predicted-class histogram from uniform
scores, diff = [], []
with torch.no_grad():
    for _ in range(100):
        sample = trainer.gen_sample()
        pred = torch.max(classifier(sample), 1)
        # bincount with minlength also counts classes that were never predicted (as 0).
        # A mode-collapsed generator is therefore penalised for every missing digit.
        distr = torch.bincount(pred[1], minlength=num_classes).cpu().numpy()
        diff.append(np.mean(abs((distr / np.sum(distr)) - 1 / num_classes)))
        scores.append(np.mean(pred[0].cpu().numpy()))

print("Avg. score: " + str(np.mean(scores)))
print("Avg. dev. from uniform: " + str(np.mean(diff)))

## Stabilisation experiments

In [None]:
# Minibatch-stddev experiment: only `stddev_layer` is switched on.
add_noise, annealing = False, 0.99
mini_disc, md_dim, md_k = False, 1, 2
stddev_layer = True
label_smooth = False
unroll_steps = 0

for i in range(25):
    stddev_trainer = Trainer()
    run_results = stddev_trainer.train()
    s_G_loss, s_D_loss, s_stddev_f, s_stddev_r = run_results

    for history, series in zip((G_losses, D_losses, stddevs_f, stddevs_r), run_results):
        history.append(series)

    plot(*run_results)
    plt.figure(figsize=(5, 5))
    show_sample(stddev_trainer.gen_sample())
plt.show()

In [None]:
# Instance-noise experiment: annealed Gaussian noise is added to every discriminator input.
add_noise, annealing = True, 0.99
mini_disc, md_dim, md_k = False, 1, 2
stddev_layer = False
label_smooth = False
unroll_steps = 0

for _ in range(25):
    noise_trainer = Trainer()
    run_results = noise_trainer.train()
    n_G_loss, n_D_loss, n_stddev_f, n_stddev_r = run_results

    for history, series in zip((G_losses, D_losses, stddevs_f, stddevs_r), run_results):
        history.append(series)

    plot(*run_results)
    plt.figure(figsize=(5, 5))
    show_sample(noise_trainer.gen_sample())
plt.show()

In [None]:
# Minibatch-discrimination experiment: 2 similarity features with kernel size 4.
add_noise, annealing = False, 0.99
mini_disc, md_dim, md_k = True, 2, 4
stddev_layer = False
label_smooth = False
unroll_steps = 0

for _ in range(25):
    mini_trainer = Trainer()
    run_results = mini_trainer.train()
    m_G_loss, m_D_loss, m_stddev_f, m_stddev_r = run_results

    for history, series in zip((G_losses, D_losses, stddevs_f, stddevs_r), run_results):
        history.append(series)

    plot(*run_results)
    plt.figure(figsize=(5, 5))
    show_sample(mini_trainer.gen_sample())

plt.show()

In [None]:
# Label-smoothing experiment: only the real targets are perturbed.
add_noise, annealing = False, 0.99
mini_disc, md_dim, md_k = False, 1, 2
stddev_layer = False
label_smooth = True
unroll_steps = 0

for _ in range(25):
    smooth_trainer = Trainer()
    run_results = smooth_trainer.train()
    sm_G_loss, sm_D_loss, sm_stddev_f, sm_stddev_r = run_results

    for history, series in zip((G_losses, D_losses, stddevs_f, stddevs_r), run_results):
        history.append(series)

    plot(*run_results)
    plt.figure(figsize=(5, 5))
    show_sample(smooth_trainer.gen_sample())
plt.show()

In [None]:
# Unrolled-GAN experiment: two look-ahead discriminator updates before every G update.
add_noise, annealing = False, 0.99
mini_disc, md_dim, md_k = False, 1, 2
stddev_layer = False
label_smooth = False
unroll_steps = 2

# Warning: This one takes a long time. Consider making a sandwich.
for _ in range(25):
    unroll_trainer = Trainer()
    run_results = unroll_trainer.train()
    u_G_loss, u_D_loss, u_stddev_f, u_stddev_r = run_results

    for history, series in zip((G_losses, D_losses, stddevs_f, stddevs_r), run_results):
        history.append(series)

    plot(*run_results)
    plt.figure(figsize=(5, 5))
    show_sample(unroll_trainer.gen_sample())
plt.show()

In [None]:
# Combined experiment: instance noise + minibatch discrimination + label smoothing
# + one unrolling step, all at once.
add_noise, annealing = True, 0.99
mini_disc, md_dim, md_k = True, 1, 2
stddev_layer = False
label_smooth = True
unroll_steps = 1

# experimental trainer throwing all of the methods above on top of each other
# rarely works as well as it should
for _ in range(5):
    all_trainer = Trainer()
    run_results = all_trainer.train()
    a_G_loss, a_D_loss, a_stddev_f, a_stddev_r = run_results

    for history, series in zip((G_losses, D_losses, stddevs_f, stddevs_r), run_results):
        history.append(series)

    plot(*run_results)
    plt.figure(figsize=(5, 5))
    show_sample(all_trainer.gen_sample())
plt.show()

## Training the Scoring Classifier

In [None]:
class MNISTClassifier(nn.Module):
    """Small CNN that scores generated digits; returns (N, 10) class logits.

    The layer sizes assume (N, nc, 64, 64) inputs. Spatial size goes
    64 -> 61 (conv 4) -> 30 (pool) -> 33 (transposed conv 4) -> 16 (pool),
    hence nc * 16 channels * 16 * 16 = nc * 16 * 256 features.
    """
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.con = nn.Conv2d(nc, nc * 4, 4)
        self.pool = nn.MaxPool2d(2, 2)

        self.con2 = nn.ConvTranspose2d(nc * 4, nc * 16, 4)
        self.flat = nn.Linear(nc * 16 * 256, 100)

        self.lin1 = nn.Linear(100, 75)
        self.lin2 = nn.Linear(75, 10)

    def forward(self, input):
        x = self.pool(self.con(input))
        x = self.pool(self.con2(x))
        # Flatten per sample. The old view(-1, 16 * 256) ignored `nc`, so for nc > 1
        # it split each sample into several rows and no longer matched self.flat.
        x = torch.flatten(x, 1)
        x = F.relu(self.flat(x))
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x

In [None]:
# Two epochs of plain SGD-with-momentum training on the (resized) MNIST training set.
net = MNISTClassifier().to(device)
crit = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = crit(net(inputs), labels)
        loss.backward()
        optimizer.step()

        # Report the current minibatch loss every 100 iterations.
        if i % 100 == 0:
            print('[%d, %d] loss: %.3f' % (epoch, i, loss.item()))

print('Finished Training')

In [None]:
# This is designed for use on Colab's local filesystem: it pickles the whole module
# (not just its state_dict) next to the notebook.
# If using this elsewhere, do provide a more suitable method of saving.
torch.save(obj=net, f='./classifier')

In [None]:
# Sanity-check the classifier on real training batches, using the same metrics as
# the GAN evaluation: mean top logit, and deviation of the class histogram from uniform.
from itertools import islice

num_classes = 10
scores, diff = [], []

# Testing the classifier on some training data. A single iterator over 100 batches
# avoids rebuilding the DataLoader iterator (and its workers) for every batch.
with torch.no_grad():
    for img, _ in islice(trainloader, 100):
        pred = torch.max(net(img.to(device)), 1)
        # minlength keeps all classes (including unpredicted ones), so every batch gives
        # a 10-element histogram. Reducing it to one number avoids ragged arrays in np.mean.
        distr = torch.bincount(pred[1], minlength=num_classes).cpu().numpy()
        diff.append(np.mean(abs((distr / np.sum(distr)) - 1 / num_classes)))
        scores.append(np.mean(pred[0].cpu().numpy()))

# It generally works!
print("Avg. score: " + str(np.mean(scores)))
print("Avg. deviation from uniform: " + str(np.mean(diff)))