In [23]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [24]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Architecture: Linear(image_size -> hidden_dim) -> LeakyReLU ->
    Linear(hidden_dim -> 1) -> Sigmoid, so every input vector is mapped to a
    single score squashed into [0, 1].

    Args:
        image_size: Length of each input vector, i.e. the number of pixels of a
            flattened image (28 * 28 for MNIST in this notebook).
        hidden_dim: Width of the hidden layer. Defaults to 128, the value that
            was previously hard-coded.
        negative_slope: Negative slope of the LeakyReLU activation. Defaults to
            0.1, the value that was previously hard-coded.
    """

    def __init__(self, image_size, hidden_dim=128, negative_slope=0.1):
        super().__init__()
        # The attribute name and layer order are unchanged so previously saved
        # state_dicts (keys like "discriminator.0.weight") still load.
        self.discriminator = nn.Sequential(
            nn.Linear(in_features=image_size, out_features=hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(in_features=hidden_dim, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of inputs.

        Args:
            x: Tensor whose last dimension is ``image_size``. Presumably a
                (batch, image_size) batch of flattened images. nn.Linear acts
                on the last dimension, so image-shaped (batch, 1, 28, 28) input
                must be flattened by the caller.

        Returns:
            Tensor with the last dimension replaced by 1, values in [0, 1].
        """
        return self.discriminator(x)

In [25]:
class Generator(nn.Module):
    """Fully connected GAN generator.

    Architecture: Linear(z_dim -> hidden_dim) -> LeakyReLU ->
    Linear(hidden_dim -> image_size) -> Tanh. The Tanh keeps outputs in
    [-1, 1], the same range that Normalize(0.5, 0.5) produces from
    ToTensor's [0, 1] pixels.

    Args:
        z_dim: Length of the noise vector fed to the generator.
        image_size: Length of each generated vector (a flattened image;
            28 * 28 for MNIST in this notebook).
        hidden_dim: Width of the hidden layer. Defaults to 256, the value that
            was previously hard-coded.
        negative_slope: Negative slope of the LeakyReLU activation. Defaults to
            0.1, the value that was previously hard-coded.
    """

    def __init__(self, z_dim, image_size, hidden_dim=256, negative_slope=0.1):
        super().__init__()
        # The attribute name and layer order are unchanged so previously saved
        # state_dicts (keys like "generator.0.weight") still load.
        self.generator = nn.Sequential(
            nn.Linear(in_features=z_dim, out_features=hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(in_features=hidden_dim, out_features=image_size),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate flattened images from noise.

        Args:
            x: Tensor whose last dimension is ``z_dim``. Presumably a
                (batch, z_dim) noise batch; confirm against the sampling code.

        Returns:
            Tensor with the last dimension replaced by ``image_size``, values
            in [-1, 1].
        """
        return self.generator(x)


In [26]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [27]:
# Training configuration.
lr = 3e-4             # learning rate
z_dim = 64            # length of the generator's noise input vector
image_size = 28 * 28  # MNIST images are 28x28 pixels, flattened to 784 values
batch_size = 64       # samples per DataLoader batch
num_epochs = 50       # number of training epochs
seed = 0              # RNG seed for reproducible runs

# Seed PyTorch's RNG (all devices) before the models are built, so weight
# initialization and other torch-RNG draws (e.g. DataLoader shuffling) repeat
# exactly on Restart & Run All.
torch.manual_seed(seed)

In [28]:
# Build both networks from the shared hyperparameters, then move their
# parameters to `device`. The discriminator is constructed first, so it
# consumes the RNG draws for weight initialization before the generator does.
discriminator = Discriminator(image_size=image_size)
discriminator = discriminator.to(device)
generator = Generator(z_dim=z_dim, image_size=image_size)
generator = generator.to(device)

In [29]:
# Preprocessing pipeline for MNIST images:
#   ToTensor  : PIL image -> float tensor scaled to [0, 1]
#   Normalize : (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1], the same range
#               as the generator's Tanh output.
# The result is bound to `transforms`, which shadows the
# `torchvision.transforms` module alias imported at the top. The pipeline is
# therefore built through the fully qualified module: calling
# `transforms.Compose` here made every re-run of this cell fail with
# AttributeError ('Compose' object has no attribute 'Compose').
tv_transforms = torchvision.transforms
transforms = tv_transforms.Compose(
    [
        tv_transforms.ToTensor(),
        tv_transforms.Normalize((0.5,), (0.5,)),  # one (mean, std) pair per channel
    ]
)

In [30]:
# MNIST training split (train=True is torchvision's default). It is downloaded
# into ./dataset/ on the first run and each sample passes through `transforms`.
dataset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,
    transform=transforms,
    download=True,
)

# Reshuffled every epoch; each batch holds `batch_size` (image, label) pairs.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST\raw\train-images-idx3-ubyte.gz


31.0%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100.0%


Extracting dataset/MNIST\raw\train-images-idx3-ubyte.gz to dataset/MNIST\raw


3.5%


Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST\raw\train-labels-idx1-ubyte.gz


102.8%


Extracting dataset/MNIST\raw\train-labels-idx1-ubyte.gz to dataset/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST\raw\t10k-images-idx3-ubyte.gz


93.5%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting dataset/MNIST\raw\t10k-labels-idx1-ubyte.gz to dataset/MNIST\raw




