<a href="https://colab.research.google.com/github/Olivia-Feldman/NUGAN-DISTGAN/blob/Olivia/DIST_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np




In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=False)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

In [4]:
print(train_dataset.data.shape)

torch.Size([60000, 28, 28])


In [5]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder,self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), nn.Linear(64, 12), nn.ReLU(True), nn.Linear(12, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True), nn.Linear(128, 28 * 28), nn.Tanh())

    def forward(self, x):
      
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.view(x.size(0),-1)
        return x

In [6]:
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

     # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [7]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)
    
    # forward method
    def forward(self, x): 
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))

In [8]:
lr = 0.0002 
z_dim = 100
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)
print(mnist_dim)


G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

# loss
criterion = nn.BCELoss() 

# optimizer
lr = 0.0002 
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

784




In [10]:
encoder = autoencoder()
#discriminator = Discriminator()
#generator = Generator()



In [11]:
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
discriminator = Discriminator(784)

In [12]:
adversarial_loss = torch.nn.BCELoss()

In [15]:
num_epochs = 30
for epoch in range(num_epochs):
    for  data in train_loader:
        imgs, _ = data
        imgs = imgs.view(imgs.size(0), -1)
        print(imgs.shape)
       
  
  
        # Adversarial ground truths
        valid = Variable(Tensor( imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure real images 
        real_imgs = Variable(imgs).type(Tensor)
      

        ## Train Generator##
        
        G_optimizer.zero_grad()

      # encoder to create encoded imgs for generator 
        encoded_imgs = encoder(real_imgs)
        print(encoded_imgs.shape)
        print(real_imgs.shape)
      

        # Loss measures generator's ability to fool the discriminator
        g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * torch.nn.L1Loss()(
            encoded_imgs, real_imgs
        )

        g_loss.backward()
        G_optimizer.step()

       
        #  Train Discriminator ###
     

        D_optimizer.zero_grad()

        # Sample noise as discriminator ground truth
        z = Variable(transforms.Tensor(np.random.normal(0, 1, (imgs.shape[0], 784))))

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(z), valid)
        fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()
        D_optimizer.step()

        print(
            "[Epoch %d/%d], [D loss: %f] ,[G loss: %f]"%(epoch,num_epochs, d_loss.item(), g_loss.item())
        )

       
        

torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.693908] ,[G loss: 0.943003]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.694470] ,[G loss: 0.941437]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.694241] ,[G loss: 0.940748]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.695587] ,[G loss: 0.945243]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.696926] ,[G loss: 0.943164]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.693272] ,[G loss: 0.944012]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.694642] ,[G loss: 0.943688]
torch.Size([100, 784])
torch.Size([100, 784])
torch.Size([100, 784])
[Epoch 0/30], [D loss: 0.693112] ,[G loss: 0.943503]
torch.Size([100, 784])
t

KeyboardInterrupt: ignored