In [1]:
import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from resnet20 import ResNetCIFAR

In [2]:
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [3]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

Files already downloaded and verified


In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(110,384),
            nn.ReLU())
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(384, 192, 6, 2),
            nn.BatchNorm2d(192),
            nn.ReLU())
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(192, 96, 5, 2),
            nn.BatchNorm2d(96),
            nn.ReLU())
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(96, 3, 4, 2),
            nn.Tanh())
        self._initialize_weights()
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.tconv1(out[:,:,None,None])
        out = self.tconv2(out)
        out = self.tconv3(out)
        return out
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [5]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, 2, 1),
            nn.LeakyReLU(0.2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2))
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2))
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2))
        self.conv5 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2))
        self.conv6 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2))
        self.dropout = nn.Dropout(0.5)
        self.activation_noise_sd = 0.1
        self.fc = nn.Linear(4*4*512, 11)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(1)
        self._initialize_weights()

    def forward(self,x):
        out = self.conv1(x)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv2(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv3(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv4(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv5(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv6(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.fc(out.reshape(-1,4*4*512))
        return self.sigmoid(out[:,0]), self.softmax(out[:,1:])
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [7]:
optimizer_d = optim.Adam(discriminator.parameters(), lr = 0.0001, betas = (0.5, 0.999))
optimizer_g = optim.Adam(generator.parameters(), lr = 0.0001, betas = (0.5, 0.999))
criterion_source = nn.BCELoss()
criterion_class = nn.NLLLoss()

In [8]:
generator.load_state_dict(torch.load("generator100_original.pt"))
discriminator.load_state_dict(torch.load("discriminator100_original.pt"))

<All keys matched successfully>

In [9]:
def reverse_norm_fake_image(img):
    return img*0.5 + 0.5

def norm_orig_image(img):
    return (img - torch.tensor([0.4914, 0.4822, 0.4465])[None,:,None,None])/torch.tensor([0.2023, 0.1994, 0.2010])[None,:,None,None]

In [10]:
net = ResNetCIFAR(num_layers=20)
net = net.to(device)
net.load_state_dict(torch.load("net_before_pruning.pt"))

<All keys matched successfully>

In [11]:
def calculate_inception_score(preds, splits = 10):
    preds = np.exp(preds)
    preds /= preds.sum(axis = 1)[:, None]
    scores = []
    for i in range(splits):
        split = preds[i*preds.shape[0]//splits : (i+1)*preds.shape[0]//splits, :]
        kl = split*(np.log(split) - np.log(np.mean(split, axis = 0)))
        scores.append(np.exp(np.mean(np.sum(kl, 1))))
    return np.mean(scores), np.std(scores)

In [12]:
max_inception_score = 0
for epoch in range(100, 160):
    print("Epoch {}:".format(epoch))
    
    d_real_source_accuracies = []
    d_real_class_accuracies = []
    d_real_source_losses = []
    d_real_class_losses = []
    d_fake_source_accuracies = []
    d_fake_class_accuracies = []
    d_fake_source_losses = [] 
    d_fake_class_losses = []
    g_source_accuracies = []
    g_class_accuracies = []
    g_source_losses = []
    g_class_losses = []
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if batch_idx % 100 == 0:
            print("Batch Index: {}".format(batch_idx))
        # Train discriminator with real data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer_d.zero_grad()
        sources, classes = discriminator(inputs)
        source_targets = torch.ones(sources.shape).to(device)
        source_loss = criterion_source(sources, source_targets)
        class_loss = 10*criterion_class(classes, targets)
        d_real_loss = (source_loss + class_loss)/2
        d_real_loss.backward()
        # optimizer_d.step()
        d_real_source_accuracy = torch.mean((torch.where(sources > 0.5, torch.tensor(1.).to(device), torch.tensor(0.).to(device)) == source_targets).float()).item()*100
        d_real_class_accuracy = torch.mean((classes.max(1)[1] == targets).float()).item()*100
        d_real_source_accuracies.append(d_real_source_accuracy)
        d_real_class_accuracies.append(d_real_class_accuracy)
        d_real_source_losses.append(source_loss.item())
        d_real_class_losses.append(class_loss.item())
        
        # Train discriminator with fake data
        fake_targets = np.random.randint(0, 10, inputs.shape[0])
        gen_input_np = np.c_[np.random.randn(inputs.shape[0], 100), np.eye(10)[fake_targets]]
        fake_targets = torch.from_numpy(fake_targets).long().to(device)
        gen_input = torch.from_numpy(gen_input_np).float().to(device)
        fake_inputs = generator(gen_input)
        # optimizer_d.zero_grad()
        sources, classes = discriminator(fake_inputs.detach())
        source_targets = torch.zeros(sources.shape).to(device)
        source_loss = criterion_source(sources, source_targets) 
        class_loss = criterion_class(classes, fake_targets)
        d_fake_loss = (source_loss + class_loss)/2
        d_fake_loss.backward()
        optimizer_d.step()
        d_fake_source_accuracy = np.round(torch.mean((torch.where(sources > 0.5, torch.tensor(1.).to(device), torch.tensor(0.).to(device)) == source_targets).float()).item()*100)
        d_fake_class_accuracy = np.round(torch.mean((classes.max(1)[1] == fake_targets).float()).item()*100)
        d_fake_source_accuracies.append(d_fake_source_accuracy)
        d_fake_class_accuracies.append(d_fake_class_accuracy)
        d_fake_source_losses.append(source_loss.item())
        d_fake_class_losses.append(class_loss.item())
        
        # Train generator
        optimizer_g.zero_grad()
        sources, classes = discriminator(fake_inputs)
        source_targets = torch.ones(sources.shape).to(device)
        source_loss = criterion_source(sources, source_targets) 
        class_loss = criterion_class(classes, fake_targets)
        g_loss = source_loss + class_loss
        g_loss.backward()
        optimizer_g.step()
        g_source_accuracy = np.round(torch.mean((torch.where(sources > 0.5, torch.tensor(1.).to(device), torch.tensor(0.).to(device)) == source_targets).float()).item()*100)
        g_class_accuracy = np.round(torch.mean((classes.max(1)[1] == fake_targets).float()).item()*100)
        g_source_accuracies.append(g_source_accuracy)
        g_class_accuracies.append(g_class_accuracy)
        g_source_losses.append(source_loss.item())
        g_class_losses.append(class_loss.item())

    print("Discriminator source accuracy and loss on real data is {:.2f}% and {:.3f}.".format(np.mean(d_real_source_accuracies), np.mean(d_real_source_losses)))
    print("Discriminator class accuracy and loss on real data is {:.2f}% and {:.3f}.".format(np.mean(d_real_class_accuracies), np.mean(d_real_class_losses)))
    print("Discriminator source accuracy and loss on fake data is {:.2f}% and {:.3f}.".format(np.mean(d_fake_source_accuracies), np.mean(d_fake_source_losses)))
    print("Discriminator class accuracy and loss on fake data is {:.2f}% and {:.3f}.".format(np.mean(d_fake_class_accuracies), np.mean(d_fake_class_losses)))
    print("Generator source accuracy and loss is {:.2f}% and {:.3f}.".format(np.mean(g_source_accuracies), np.mean(g_source_losses)))
    print("Generator class accuracy and loss is {:.2f}% and {:.3f}.".format(np.mean(g_class_accuracies), np.mean(g_class_losses)))

    fakes = []
    for i in range(100):
        fake_targets = np.random.randint(0, 10, 100)
        gen_input_np = np.c_[np.random.randn(100, 100), np.eye(10)[fake_targets]]
        gen_input = torch.from_numpy(gen_input_np).float().to(device)
        fake_inputs = generator(gen_input)
        fakes.append((reverse_norm_fake_image(fake_inputs.detach().cpu()), fake_targets))

    net.eval()
    preds = np.empty((0,10))
    all_targets = np.empty((0,100))
    for batch_idx, (inputs, targets) in enumerate(fakes):
        inputs = norm_orig_image(inputs).to(device)
        output = net(inputs)
        preds = np.r_[preds, output.detach().cpu().numpy()]
        all_targets = np.r_[all_targets, targets[None, :]]

    inception_score = calculate_inception_score(preds)[0]

    print("Resnet Accuracy: {}".format(np.mean(np.argmax(preds, axis = 1) == all_targets.flatten())))
    print("Inception Score: {}\n".format(inception_score))

    if inception_score > max_inception_score:
        max_inception_score = inception_score
        torch.save(generator.state_dict(), "generator160_small.pt")
        torch.save(discriminator.state_dict(), "discriminator160_small.pt")

Epoch 100:
Batch Index: 0
Batch Index: 100
Batch Index: 200
Batch Index: 300
Batch Index: 400
Discriminator source accuracy and loss on real data is 64.26% and 0.638.
Discriminator class accuracy and loss on real data is 58.62% and -5.843.
Discriminator source accuracy and loss on fake data is 63.16% and 0.642.
Discriminator class accuracy and loss on fake data is 95.06% and -0.948.
Generator source accuracy and loss is 36.75% and 0.958.
Generator class accuracy and loss is 95.05% and -0.948.
Resnet Accuracy: 0.4963
Inception Score: 4.247519170711181

Epoch 101:
Batch Index: 0
Batch Index: 100
Batch Index: 200
Batch Index: 300
Batch Index: 400
Discriminator source accuracy and loss on real data is 60.48% and 0.670.
Discriminator class accuracy and loss on real data is 59.13% and -5.886.
Discriminator source accuracy and loss on fake data is 59.93% and 0.671.
Discriminator class accuracy and loss on fake data is 95.52% and -0.952.
Generator source accuracy and loss is 39.91% and 0.895.
