In [None]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device Used:", device)

# Seed torch's global RNG so weight init, noise sampling and loader shuffling
# are reproducible across "Restart & Run All" (bit-exact only on CPU).
SEED = 0
torch.manual_seed(SEED)

batch_num = 100  # mini-batch size used by both data loaders

# MNIST Dataset
# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then computes
# (p - 0.5) / 0.5, mapping them to [-1, 1] -- the same range as the tanh
# output of the generator. MNIST has one channel, hence 1-element tuples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

# download=True is a no-op when the files are already present, so each
# dataset can be constructed on its own without relying on the other.
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# Data Loader: shuffle only the training data.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_num, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_num, shuffle=False)


class Generator(nn.Module):
    """MLP generator mapping latent vectors to flattened images in [-1, 1].

    Layer widths are g_input_dim -> 256 -> 512 -> 1024 -> g_output_dim, with
    LeakyReLU(0.2) after every hidden layer and tanh on the output.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        # Each hidden layer doubles the width of the previous one. The layers
        # are created in the order fc1..fc4 so parameter initialisation and
        # state_dict keys stay the same.
        width1 = 256
        width2 = 2 * width1
        width3 = 2 * width2
        self.fc1 = nn.Linear(g_input_dim, width1)
        self.fc2 = nn.Linear(width1, width2)
        self.fc3 = nn.Linear(width2, width3)
        self.fc4 = nn.Linear(width3, g_output_dim)

    def forward(self, x):
        """Map inputs whose last dimension is g_input_dim to g_output_dim values in [-1, 1]."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(layer(x), negative_slope=0.2)
        return torch.tanh(self.fc4(x))


class Discriminator(nn.Module):
    """MLP discriminator: flattened image -> probability that it is real.

    Hidden widths are 1024 -> 512 -> 256. Each hidden layer is followed by
    LeakyReLU(0.2) and dropout (p=0.3). The output is passed through a
    sigmoid, so it can be fed directly to ``nn.BCELoss``.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return one real-probability per input row (last dim must be d_input_dim).

        Dropout follows the module's train/eval mode. ``F.dropout`` defaults
        to ``training=True``, so without passing ``self.training`` dropout
        would stay active even after ``D.eval()``.
        """
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

# build network
h = 100  # dimensionality of the generator's latent noise vector z

# Flattened image size (height * width, 784 for 28x28 MNIST images).
# `.data` is the current name of the images tensor; `.train_data` is a
# deprecated alias for it.
input_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim=h, g_output_dim=input_dim).to(device)
D = Discriminator(input_dim).to(device)

# loss: binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

# optimizer: one Adam per network, same learning rate for both
lr = 0.0002
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)


def D_train(x):
    """Run one discriminator update on a batch of real images.

    The discriminator is trained towards 1 on the real batch ``x`` and
    towards 0 on an equally sized batch of generator samples.

    Args:
        x: batch of real images. Each sample is flattened to ``input_dim``
           features, e.g. ``(B, 1, 28, 28)`` from the MNIST loader.

    Returns:
        float: the summed real + fake BCE loss for this step.
    """
    D.zero_grad()

    # Real batch, labelled 1. The labels are sized from the batch itself, so
    # a short final batch cannot mismatch a hard-coded batch size.
    x_real = x.reshape(-1, input_dim).to(device)
    n = x_real.size(0)
    y_real = torch.ones(n, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    # Fake batch, labelled 0. detach() stops backprop at the generator: only
    # D is updated here, so computing gradients for G would be wasted work
    # and would populate G's .grad buffers.
    z = torch.randn(n, h).to(device)
    x_fake = G(z).detach()
    y_fake = torch.zeros(n, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()


def G_train(x):
    """Run one plain (non top-k) generator update.

    Samples ``batch_num`` latent vectors and trains G so that D labels every
    resulting fake as real, using the non-saturating GAN loss
    ``BCE(D(G(z)), 1)``.

    Args:
        x: the current real batch. Unused; kept so the signature matches
           ``D_train`` and the training loop can call either one.

    Returns:
        float: the generator BCE loss for this step.
    """
    G.zero_grad()
    z = torch.randn(batch_num, h).to(device)
    y = torch.ones(batch_num, 1, device=device)  # target "real" for every fake
    G_loss = criterion(D(G(z)), y)
    # backward() also reaches D's parameters, but only G_optimizer steps here,
    # and D_train calls D.zero_grad() before its own update.
    G_loss.backward()
    G_optimizer.step()
    return G_loss.item()

def G_traink(x, k):
    """Run one top-k generator update.

    Samples ``batch_num`` fakes and scores them with D. The generator loss
    ``BCE(D(G(z)), 1)`` is back-propagated through only ``|k|`` of them:

    * ``k > 0``: the ``k`` fakes D rates most "real" (top-k);
    * ``k < 0``: the ``|k|`` fakes D rates least "real" (bottom-k).

    If ``|k|`` exceeds the batch size, the whole batch is used.

    Args:
        x: the current real batch. Unused; kept so the signature matches the
           other training-step functions.
        k: signed number of fakes that contribute to the loss; must be non-zero.

    Returns:
        float: the generator BCE loss averaged over the selected fakes.

    Raises:
        ValueError: if ``k == 0`` (nothing would be selected and the loss
            would be NaN).
    """
    if k == 0:
        raise ValueError("k must be non-zero")

    G.zero_grad()
    z = torch.randn(batch_num, h).to(device)
    scores = D(G(z)).reshape(-1)

    # topk keeps the autograd graph, so G only receives gradient through the
    # selected samples. Clamp |k| so that asking for more than the batch
    # simply uses every sample.
    n_keep = min(abs(k), scores.numel())
    selected, _ = torch.topk(scores, n_keep, largest=k > 0)

    G_loss = criterion(selected, torch.ones_like(selected))
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()



# Per-epoch mean losses, consumed by the plotting cell.
g_loss, d_loss = [], []
n_epoch = 100
decay_rate = 0.99                # multiplicative decay applied to k after every batch
min_k = int(0.75 * batch_num)    # floor for k: never train G on fewer fakes than this
print("Min k val", min_k)

for epoch in range(1, n_epoch + 1):
    # NOTE(review): k restarts at the full batch every epoch and decays per
    # batch. Because int() truncates, int(k * 0.99) == k - 1 for k <= 100, so
    # with batch_num = 100 k steps down by one per batch and reaches min_k
    # after 25 batches. The Top-k GAN paper (Sinha et al., 2020) describes
    # annealing k once per epoch instead; confirm which schedule is intended.
    k = batch_num
    epoch_d, epoch_g = [], []
    for x, _ in train_loader:
        epoch_d.append(D_train(x))
        epoch_g.append(G_traink(x, k))
        k = max(min_k, int(k * decay_rate))

    mean_d = torch.mean(torch.FloatTensor(epoch_d))
    mean_g = torch.mean(torch.FloatTensor(epoch_g))
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, mean_d, mean_g))
    g_loss.append(mean_g.item())
    d_loss.append(mean_d.item())

print("Gen Losses", g_loss)
print("Dis Losses", d_loss)

In [None]:
import matplotlib.pyplot as plt

# x-axis shared by both curves: one point per training epoch.
epochval = list(range(1, n_epoch + 1))

# One figure per network, drawn with identical axis labels.
for fig_num, series, title in ((2, g_loss, 'Generator Loss'),
                               (3, d_loss, 'Discriminator Loss')):
    plt.figure(fig_num)
    plt.plot(epochval, series)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()