<a href="https://colab.research.google.com/github/AntonYermilov/sound-visualization/blob/master/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import transforms
import numpy as np
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [2]:
# Mini-batch size shared by both data loaders and by the training functions.
bs = 100

# MNIST dataset: images are only converted to tensors, no further normalisation.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transform,
    download=True,
)
# The test split is not downloaded here; it is read from the same root directory.
test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transform,
    download=False,
)

# Data loaders (input pipeline): only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=bs,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=bs,
    shuffle=False,
)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:01, 9014970.73it/s]                            


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw


  0%|          | 0/28881 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 137026.50it/s]           
  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 2334121.66it/s]                            
0it [00:00, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 52475.08it/s]            

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw
Processing...
Done!





In [0]:
def gray_image_target_eval(x):
    """Map each row of ``x`` to a scalar: twice its mean along dim 1, minus 0.28.

    Args:
        x: Array-like with at least two dimensions; the mean is taken over
            dimension 1.

    Returns:
        ``x.mean(1) * 2 - 0.28``, i.e. ``x`` with dimension 1 reduced away.
    """
    row_mean = x.mean(1)
    doubled = row_mean * 2
    return doubled - 0.28

In [0]:
class Generator(nn.Module):
    """Fully connected generator that maps a latent vector to a flat output.

    The network has three LeakyReLU(0.2) hidden layers of width 256, 512 and
    1024, followed by a linear output layer of width ``g_output_dim``.

    Args:
        g_input_dim: Size of the latent input vector.
        g_output_dim: Size of the generated (flattened) output vector.
        target_evaluator: Callable applied to the raw output of the last
            linear layer (before ``tanh``); its result is returned alongside
            the generated sample.
    """

    def __init__(self, g_input_dim, g_output_dim, target_evaluator):
        super().__init__()
        self.target_evaluator = target_evaluator

        # Hidden widths double from one layer to the next.
        width_1 = 256
        width_2 = width_1 * 2
        width_3 = width_2 * 2

        self.fc1 = nn.Linear(g_input_dim, width_1)
        self.fc2 = nn.Linear(width_1, width_2)
        self.fc3 = nn.Linear(width_2, width_3)
        self.fc4 = nn.Linear(width_3, g_output_dim)

    def forward(self, x):
        """Return ``(tanh(raw), target_evaluator(raw))`` for latent input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)

        raw = self.fc4(hidden)
        # The evaluator sees the pre-activation output, not the tanh-squashed one.
        target = self.target_evaluator(raw)
        return torch.tanh(raw), target

In [0]:
class Discriminator(nn.Module):
    """Fully connected binary classifier that scores flattened inputs.

    Three LeakyReLU(0.2) hidden layers (1024 -> 512 -> 256 units), each
    followed by dropout with p=0.3, feed a single sigmoid output unit.

    Args:
        d_input_dim: Number of features in each flattened input vector.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``.

        Args:
            x: Tensor whose last dimension has size ``d_input_dim``.

        Returns:
            Tensor with a last dimension of size 1 holding sigmoid outputs.
        """
        # F.dropout defaults to training=True, which would keep dropout active
        # even after D.eval(). Tie it to the module's mode so that evaluation
        # is deterministic while training behaviour is unchanged.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [118]:
# Size of the latent noise vector fed to the generator.
z_dim = 32
# Number of pixels in one flattened image (height * width). `.data` is the
# current name of the image tensor; `.train_data` is a deprecated alias that
# emits a warning.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)
G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim, target_evaluator=gray_image_target_eval).to(device)
D = Discriminator(mnist_dim).to(device)



In [103]:
# Display the generator's layer structure.
G

Generator(
  (fc1): Linear(in_features=32, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [69]:
# Display the discriminator's layer structure.
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [0]:

# Learning rate shared by both Adam optimizers.
lr = 2e-4

# Binary cross-entropy between discriminator outputs and real/fake labels.
criterion = nn.BCELoss()

# One optimizer per network so each training step updates only its own weights.
G_optimizer, D_optimizer = (
    optim.Adam(net.parameters(), lr=lr) for net in (G, D)
)

In [0]:
def D_train(x):
    """Run one discriminator update on a real batch and an equal-sized fake batch.

    Args:
        x: Batch of real images; it is flattened to ``(batch, mnist_dim)``
            before being scored.

    Returns:
        float: Sum of the BCE losses on the real batch and on the fake batch.
    """
    D.zero_grad()

    # Train discriminator on real samples. Labels are sized from the batch
    # itself rather than from the global `bs`, so a short final batch cannot
    # cause a shape mismatch inside the loss.
    x_real = x.view(-1, mnist_dim).to(device)
    batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1).to(device)
    D_real_loss = criterion(D(x_real), y_real)

    # Train discriminator on fake samples. The generated batch is detached so
    # that backward() stops at the discriminator input and does not spend time
    # computing (and accumulating) gradients for G.
    z = torch.randn(batch_size, z_dim).to(device)
    x_fake, _ = G(z)
    x_fake = x_fake.detach()
    y_fake = torch.zeros(batch_size, 1).to(device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # Gradient backprop & optimize ONLY D's parameters.
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [0]:
def G_train(x):
    """Run one generator update and return the resulting loss as a float.

    The generator is rewarded when the discriminator labels its samples as
    real (target 1). ``x`` is accepted for symmetry with ``D_train`` but is
    not read; the batch size comes from the global ``bs``.
    """
    G.zero_grad()

    # Fresh latent noise and "real" labels for the generated batch.
    noise = torch.randn(bs, z_dim).to(device)
    real_labels = torch.ones(bs, 1).to(device)

    fake_images, _ = G(noise)
    G_loss = criterion(D(fake_images), real_labels)

    # Gradient backprop & optimize ONLY G's parameters.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [121]:
n_epoch = 100
for epoch in range(1, n_epoch + 1):
    # Per-epoch loss logs, reset at the start of every epoch.
    D_losses = []
    G_losses = []

    for batch_idx, (x, _) in enumerate(train_loader):
        # One discriminator step followed by one generator step per batch.
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    # Report the mean loss of each network over the epoch.
    mean_d_loss = torch.mean(torch.FloatTensor(D_losses))
    mean_g_loss = torch.mean(torch.FloatTensor(G_losses))
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, mean_d_loss, mean_g_loss))

[1/100]: loss_d: 1.135, loss_g: 2.474
[2/100]: loss_d: 0.776, loss_g: 2.611
[3/100]: loss_d: 0.724, loss_g: 2.494
[4/100]: loss_d: 0.795, loss_g: 2.124
[5/100]: loss_d: 0.711, loss_g: 2.320
[6/100]: loss_d: 0.717, loss_g: 2.261
[7/100]: loss_d: 0.820, loss_g: 2.016
[8/100]: loss_d: 0.753, loss_g: 2.096
[9/100]: loss_d: 0.764, loss_g: 1.923
[10/100]: loss_d: 0.762, loss_g: 1.969
[11/100]: loss_d: 0.756, loss_g: 2.017
[12/100]: loss_d: 0.780, loss_g: 1.959
[13/100]: loss_d: 0.746, loss_g: 2.041
[14/100]: loss_d: 0.775, loss_g: 1.995
[15/100]: loss_d: 0.747, loss_g: 2.132
[16/100]: loss_d: 0.677, loss_g: 2.289
[17/100]: loss_d: 0.670, loss_g: 2.305
[18/100]: loss_d: 0.675, loss_g: 2.195
[19/100]: loss_d: 0.697, loss_g: 2.124
[20/100]: loss_d: 0.693, loss_g: 2.103
[21/100]: loss_d: 0.712, loss_g: 2.079
[22/100]: loss_d: 0.723, loss_g: 2.177
[23/100]: loss_d: 0.662, loss_g: 2.222
[24/100]: loss_d: 0.624, loss_g: 2.475
[25/100]: loss_d: 0.604, loss_g: 2.418
[26/100]: loss_d: 0.600, loss_g: 2

In [0]:
# Generate one batch of samples without tracking gradients and save them as an
# image grid.
with torch.no_grad():
    test_z = torch.randn(bs, z_dim).to(device)
    generated, _ = G(test_z)

    # Reshape the flat vectors to single-channel 28x28 images for save_image.
    n_generated = generated.size(0)
    generated_images = generated.view(n_generated, 1, 28, 28)
    save_image(generated_images, './sample_' + '.png')

In [0]:
# Persist only the generator's parameters (not the whole module object).
generator_state = G.state_dict()
torch.save(generator_state, "G100.pt")