In [1]:
import torch
import torch.nn as nn

In [2]:
#conv, batch_norm, Leaky ReLU

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
         super(Discriminator, self).__init__()
         self.disc= nn.Sequential(
             #Input: N* channels_img * 64 * 64
             nn.Conv2d(
                 channels_img, features_d, kernel_size=4,stride=2,padding=1
             ),
             nn.LeakyReLU(0.2),
             self._block(features_d, features_d*2, 4,2,1),
             self._block(features_d*2, features_d*4, 4,2,1),
             self._block(features_d*4, features_d*8, 4,2,1),
             nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), #1*1
             nn.Sigmoid(),
         )
        
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self,x):
        return self.disc(x)
    

In [10]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            # nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)



In [11]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            
#test if things are going accordig to expectations
def test():
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"
    gen = Generator(noise_dim, in_channels, 8)
    z = torch.randn((N, noise_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
    print("Success")

In [12]:
test()

Success, tests passed!


# train

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [17]:
#Hyper Params
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
#from number in the paper
learning_rate= 2e-4
batch_size= 128
image_size= 64
channels_img= 1
z_dim= 100
num_epochs= 5
features_disc= 64
features_gen= 64

In [18]:
transforms= transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5 for _ in range(channels_img)],[0.5 for _ in range(channels_img)])
    ]
)

AttributeError: 'Compose' object has no attribute 'Compose'

In [None]:
dataset= dataset.MNIST(root="dataset/", train=True, transform= transforms,
                      download= True)
loader= DataLoader(dataset, batch_size= batch_size, shuffle=True)
gen= Generator(noise_dim)