In [1]:
import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [3]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

Files already downloaded and verified


In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(110,768),
            nn.ReLU())
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(768, 384, 5, 2),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(384, 192, 5, 2),
            nn.BatchNorm2d(192),
            nn.ReLU())
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 96, 4, 2),
            nn.BatchNorm2d(96),
            nn.ReLU())
        self.tconv4 = nn.Sequential(
            nn.ConvTranspose2d(96, 48, 4, 1),
            nn.BatchNorm2d(48),
            nn.ReLU())
        self.tconv5 = nn.Sequential(
            nn.ConvTranspose2d(48, 3, 4, 2),
            nn.Tanh())
        self._initialize_weights()
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.tconv1(out[:,:,None,None])
        out = self.tconv2(out)
        out = self.tconv3(out)
        out = self.tconv4(out)
        out = self.tconv5(out)
        return out
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [5]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.upsample = nn.Upsample(scale_factor = 2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, 2, 1),
            nn.LeakyReLU(0.2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2))
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2))
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2))
        self.conv5 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2))
        self.conv6 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2))
        self.dropout = nn.Dropout(0.5)
        self.activation_noise_sd = 0.1
        self.fc = nn.Linear(8*8*512, 11)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(1)
        self._initialize_weights()

    def forward(self,x):
        if x.shape[-1] == 32:
            x = self.upsample(x)
        out = self.conv1(x)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv2(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv3(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv4(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv5(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.conv6(out)
        out = self.dropout(out + self.activation_noise_sd*torch.rand(out.shape).to(device))
        out = self.fc(out.reshape(-1,8*8*512))
        return self.sigmoid(out[:,0]), self.softmax(out[:,1:])
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [7]:
optimizer_d = optim.Adam(discriminator.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizer_g = optim.Adam(generator.parameters(), lr = 0.0002, betas = (0.5, 0.999))
criterion_source = nn.BCELoss()
criterion_class = nn.NLLLoss()
print("We shall train for {} epochs".format(50000/len(trainloader)))

We shall train for 100.0 epochs


In [8]:
d_real_source_accuracies_avg = []
d_real_class_accuracies_avg = []
d_real_source_losses_avg = []
d_real_class_losses_avg = []
d_fake_source_accuracies_avg = []
d_fake_class_accuracies_avg = []
d_fake_source_losses_avg = []
d_fake_class_losses_avg = []
g_source_accuracies_avg = []
g_class_accuracies_avg = []
g_source_losses_avg = []
g_class_losses_avg = []

for epoch in range(100):
    print("Epoch {}:".format(epoch))
    
    d_real_source_accuracies = []
    d_real_class_accuracies = []
    d_real_source_losses = []
    d_real_class_losses = []
    d_fake_source_accuracies = []
    d_fake_class_accuracies = []
    d_fake_source_losses = []
    d_fake_class_losses = []
    g_source_accuracies = []
    g_class_accuracies = []
    g_source_losses = []
    g_class_losses = []
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if batch_idx % 100 == 0:
            print("Batch Index: {}".format(batch_idx))
        # Train discriminator with real data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer_d.zero_grad()
        sources, classes = discriminator(inputs)
        source_targets = torch.ones(sources.shape).to(device)
        source_loss = criterion_source(sources, source_targets)
        class_loss = 2*criterion_class(classes, targets)
        d_real_loss = (source_loss + class_loss)/2
        d_real_loss.backward()
        d_real_source_accuracy = torch.mean((torch.where(sources > 0.5, torch.tensor(1.).to(device), torch.tensor(0.).to(device)) == source_targets).float()).item()*100
        d_real_class_accuracy = torch.mean((classes.max(1)[1] == targets).float()).item()*100
        d_real_source_accuracies.append(d_real_source_accuracy)
        d_real_class_accuracies.append(d_real_class_accuracy)
        d_real_source_losses.append(source_loss.item())
        d_real_class_losses.append(class_loss.item())
        
        # Train discriminator with fake data
        fake_targets = np.random.randint(0, 10, inputs.shape[0])
        gen_input_np = np.c_[np.random.randn(inputs.shape[0], 100), np.eye(10)[fake_targets]]
        fake_targets = torch.from_numpy(fake_targets).long().to(device)
        gen_input = torch.from_numpy(gen_input_np).float().to(device)
        fake_inputs = generator(gen_input)
        sources, classes = discriminator(fake_inputs.detach())
        source_targets = torch.zeros(sources.shape).to(device)
        source_loss = criterion_source(sources, source_targets) 
        class_loss = criterion_class(classes, fake_targets)
        d_fake_loss = (source_loss + class_loss)/2
        d_fake_loss.backward()
        optimizer_d.step()
        d_fake_source_accuracy = np.round(torch.mean((torch.where(sources > 0.5, torch.tensor(1.).to(device), torch.tensor(0.).to(device)) == source_targets).float()).item()*100)
        d_fake_class_accuracy = np.round(torch.mean((classes.max(1)[1] == fake_targets).float()).item()*100)
        d_fake_source_accuracies.append(d_fake_source_accuracy)
        d_fake_class_accuracies.append(d_fake_class_accuracy)
        d_fake_source_losses.append(source_loss.item())
        d_fake_class_losses.append(class_loss.item())
        
        # Train generator
        optimizer_g.zero_grad()
        sources, classes = discriminator(fake_inputs)
        source_targets = torch.ones(sources.shape).to(device)
        source_loss = criterion_source(sources, source_targets) 
        class_loss = criterion_class(classes, fake_targets)
        g_loss = source_loss + class_loss
        g_loss.backward()
        optimizer_g.step()
        g_source_accuracy = np.round(torch.mean((torch.where(sources > 0.5, torch.tensor(1.).to(device), torch.tensor(0.).to(device)) == source_targets).float()).item()*100)
        g_class_accuracy = np.round(torch.mean((classes.max(1)[1] == fake_targets).float()).item()*100)
        g_source_accuracies.append(g_source_accuracy)
        g_class_accuracies.append(g_class_accuracy)
        g_source_losses.append(source_loss.item())
        g_class_losses.append(class_loss.item())

    torch.save(generator.state_dict(), "generator100_large.pt")
    torch.save(discriminator.state_dict(), "discriminator100_large.pt")

    print("Discriminator source accuracy and loss on real data is {:.2f}% and {:.3f}.".format(np.mean(d_real_source_accuracies), np.mean(d_real_source_losses)))
    print("Discriminator class accuracy and loss on real data is {:.2f}% and {:.3f}.".format(np.mean(d_real_class_accuracies), np.mean(d_real_class_losses)))
    print("Discriminator source accuracy and loss on fake data is {:.2f}% and {:.3f}.".format(np.mean(d_fake_source_accuracies), np.mean(d_fake_source_losses)))
    print("Discriminator class accuracy and loss on fake data is {:.2f}% and {:.3f}.".format(np.mean(d_fake_class_accuracies), np.mean(d_fake_class_losses)))
    print("Generator source accuracy and loss is {:.2f}% and {:.3f}.".format(np.mean(g_source_accuracies), np.mean(g_source_losses)))
    print("Generator class accuracy and loss is {:.2f}% and {:.3f}.\n".format(np.mean(g_class_accuracies), np.mean(g_class_losses)))
    
    d_real_source_accuracies_avg.append(np.mean(d_real_source_accuracies))
    d_real_class_accuracies_avg.append(np.mean(d_real_class_accuracies))
    d_real_source_losses_avg.append(np.mean(d_real_source_losses))
    d_real_class_losses_avg.append(np.mean(d_real_class_losses))
    d_fake_source_accuracies_avg.append(np.mean(d_fake_source_accuracies))
    d_fake_class_accuracies_avg.append(np.mean(d_fake_class_accuracies))
    d_fake_source_losses_avg.append(np.mean(d_fake_source_losses))
    d_fake_class_losses_avg.append(np.mean(d_fake_class_losses))
    g_source_accuracies_avg.append(np.mean(g_source_accuracies))
    g_class_accuracies_avg.append(np.mean(g_class_accuracies))
    g_source_losses_avg.append(np.mean(g_source_losses))
    g_class_losses_avg.append(np.mean(g_class_losses))

Epoch 0:
Batch Index: 0
Batch Index: 100
Batch Index: 200
Batch Index: 300
Batch Index: 400
Discriminator source accuracy and loss on real data is 55.97% and 1.001.
Discriminator class accuracy and loss on real data is 16.79% and -0.332.
Discriminator source accuracy and loss on fake data is 56.61% and 0.878.
Discriminator class accuracy and loss on fake data is 11.01% and -0.110.
Generator source accuracy and loss is 27.10% and 1.827.
Generator class accuracy and loss is 11.24% and -0.112.

Epoch 1:
Batch Index: 0
Batch Index: 100
Batch Index: 200
Batch Index: 300
Batch Index: 400
Discriminator source accuracy and loss on real data is 66.05% and 0.769.
Discriminator class accuracy and loss on real data is 20.57% and -0.409.
Discriminator source accuracy and loss on fake data is 68.16% and 0.637.
Discriminator class accuracy and loss on fake data is 11.67% and -0.117.
Generator source accuracy and loss is 9.79% and 2.216.
Generator class accuracy and loss is 12.10% and -0.121.

Epoch 2

In [9]:
data = np.r_[np.array(d_real_source_accuracies_avg),
      np.array(d_real_class_accuracies_avg),
      np.array(d_real_source_losses_avg),
      np.array(d_real_class_losses_avg),
      np.array(d_fake_source_accuracies_avg),
      np.array(d_fake_class_accuracies_avg),
      np.array(d_fake_source_losses_avg),
      np.array(d_fake_class_losses_avg),
      np.array(g_source_accuracies_avg),
      np.array(g_class_accuracies_avg),
      np.array(g_source_losses_avg),
      np.array(g_class_losses_avg)]

In [10]:
pd.DataFrame(data).to_csv("data_large.csv")