In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


In [1]:
# Shell command: recursively delete the ./logs/ directory (no error if it is
# missing, because of -f). The SummaryWriters created further down write under
# logs/, so this removes whatever a previous run left there.
!rm -rf ./logs/

In [2]:
class Discriminator(nn.Module):
    """Convolutional discriminator: image batch in, one probability per image out.

    The network is five 4x4, stride-2 convolutions. The first four are each
    followed by LeakyReLU(0.2); the last one is followed by a Sigmoid. With
    64x64 inputs the spatial size shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1, so the
    output has shape [N, 1, 1, 1].

    Args:
        in_channel: number of channels of the input images.
        features: output channels of the first convolution; the following
            convolutions use 2x, 4x and 8x this width.
    """

    def __init__(self, in_channel, features):
        super().__init__()
        # Shared geometry of the four halving convolutions.
        halve = dict(kernel_size=4, stride=2, padding=1)
        widths = [features, features * 2, features * 4, features * 8]

        # First and last convolutions keep their bias; the three in between do not.
        layers = [nn.Conv2d(in_channel, widths[0], **halve), nn.LeakyReLU(0.2)]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, bias=False, **halve))
            layers.append(nn.LeakyReLU(0.2))
        # No padding here: a 4x4 map collapses to a single 1x1 score.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator's probability map for the image batch ``x``."""
        scores = self.disc(x)
        return scores


In [3]:
class Generator(nn.Module):
    """Transposed-convolution generator: latent vectors in, images out.

    A [N, noise_dim, 1, 1] input is first expanded to a 4x4 map and then
    doubled in size four times (4 -> 8 -> 16 -> 32 -> 64). Every layer but the
    last is followed by ReLU; the last is followed by Tanh, so outputs lie in
    [-1, 1] with shape [N, img_channel, 64, 64]. No layer has a bias term.

    Args:
        noise_dim: number of channels of the latent input.
        img_channel: number of channels of the generated images.
        features: base width; hidden layers use 16x, 8x, 4x and 2x this value.
    """

    def __init__(self, noise_dim, img_channel, features):
        super().__init__()
        # Shared geometry of the four size-doubling layers.
        double = dict(kernel_size=4, stride=2, padding=1, bias=False)
        widths = [features * 16, features * 8, features * 4, features * 2]

        # Projection from the 1x1 latent to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(noise_dim, widths[0], kernel_size=4, stride=1, padding=0, bias=False),
            nn.ReLU(),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.ConvTranspose2d(c_in, c_out, **double), nn.ReLU()]
        layers += [nn.ConvTranspose2d(widths[-1], img_channel, **double), nn.Tanh()]

        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        images = self.gen(x)
        return images



In [4]:
def initialize_weights(model):
    """Apply DCGAN-style weight initialisation to ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
      Biases, if any, keep their default initialisation.
    * ``nn.BatchNorm2d``: scale drawn from N(1, 0.02) and shift set to 0.
      A batch-norm scale centred on 0 would shrink every normalised activation
      to almost nothing, so it must be centred on 1 instead.
    * ``nn.BatchNorm2d(affine=False)`` has no learnable parameters and is skipped.

    Args:
        model: any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)


In [5]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [6]:
# --- Optimisation ---
LEARNING_RATE = 0.0002  # used for both Adam optimisers
BATCH_SIZE = 128
NUM_EPOCHS = 1

# --- Data: images are resized to IMAGE_SIZE with CHANNELS_IMG channels (1 x 64 x 64) ---
IMAGE_SIZE = 64
CHANNELS_IMG = 1

# --- Model ---
NOISE_DIM = 100  # channels of the generator's latent input
FEATURES_DISC = 64  # base width passed to Discriminator
FEATURES_GEN = 64  # base width passed to Generator

In [7]:
# Preprocessing applied to every dataset image.
#
# NOTE: this cell rebinds the name `transforms` (imported at the top as the
# torchvision.transforms module) to the composed pipeline, because the dataset
# cell passes `transform=transforms`. The module is therefore reached through
# its fully-qualified name here; otherwise re-running this cell would fail with
# "'Compose' object has no attribute 'Compose'".
_tv_transforms = torchvision.transforms
transforms = _tv_transforms.Compose(
    [
        _tv_transforms.Resize(IMAGE_SIZE),  # 1x64x64
        _tv_transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel
        _tv_transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

In [8]:
# MNIST training split stored under dataset/ (downloaded when missing); every
# image goes through the preprocessing pipeline bound to `transforms`.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    download=True,
    transform=transforms,
)
# Shuffled mini-batches of BATCH_SIZE images.
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 104124098.39it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 29588591.55it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 26539358.37it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1173857.22it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [9]:
# Instantiate both networks, then move their parameters to the selected device.
gen = Generator(noise_dim=NOISE_DIM, img_channel=CHANNELS_IMG, features=FEATURES_GEN)
gen = gen.to(device)
disc = Discriminator(in_channel=CHANNELS_IMG, features=FEATURES_DISC)
disc = disc.to(device)

In [10]:
# Re-initialise the generator first, then the discriminator.
for net in (gen, disc):
    initialize_weights(net)

In [11]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# One Adam optimiser per network, both at LEARNING_RATE.
opt_gen = optim.Adam(params=gen.parameters(), lr=LEARNING_RATE)
opt_disc = optim.Adam(params=disc.parameters(), lr=LEARNING_RATE)

In [12]:
# 32 latent vectors sampled once; the training loop feeds this same batch to
# the generator every time it logs images.
_fixed_noise_shape = (32, NOISE_DIM, 1, 1)
fixed_noise = torch.randn(*_fixed_noise_shape).to(device)

In [13]:
# Separate TensorBoard runs for the real and the generated image grids.
writer_real = SummaryWriter(log_dir="logs/real")
writer_fake = SummaryWriter(log_dir="logs/fake")

In [14]:
# Counter used as the TensorBoard global_step for logged image grids.
step = 0

# Put both networks in training mode. The last expression's value (the
# discriminator itself) is what the notebook displays as this cell's output.
gen.train(mode=True)
disc.train(mode=True)

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (9): Sigmoid()
  )
)

In [None]:

# Adversarial training loop: for every batch, one discriminator update followed
# by one generator update. Every 100 batches the losses are printed and a grid
# of real / generated images is written to TensorBoard.
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # The last batch of an epoch can hold fewer than BATCH_SIZE images, so
        # size the fake batch from the real one to keep both halves of the
        # discriminator loss on the same number of samples.
        cur_batch_size = real.shape[0]

        # ---- Discriminator: push D(real) towards 1 and D(G(z)) towards 0 ----
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

        noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1, device=device)
        fake = gen(noise)
        # detach(): this update must not backpropagate into the generator.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # ---- Generator: push D(G(z)) towards 1 using the updated discriminator ----
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Always render the same latent batch so successive grids are comparable.
                fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

# Write any buffered image summaries to disk so TensorBoard can show them
# without waiting for the writers' periodic flush.
writer_real.flush()
writer_fake.flush()

Epoch [0/1] Batch 0/469                   Loss D: 0.6905, loss G: 0.6857


In [None]:
# IPython magic: load the `tensorboard` notebook extension.
%load_ext tensorboard

In [None]:
# Open TensorBoard inside the notebook on logs/fake, the directory written by
# writer_fake. The run written by writer_real lives in logs/real and is not
# covered by this logdir.
%tensorboard --logdir logs/fake

In [None]:
# Discriminator shape check: a random batch of 32 images of shape 3x64x64
# should come out as one score per image.
# (The batch is deliberately not called `input`, which would shadow the
# builtin of that name for the rest of the kernel session.)
sample_images = torch.randn(32, 3, 64, 64)
disc_check = Discriminator(3, 8)
print(disc_check(sample_images).shape)

torch.Size([32, 1, 1, 1])


In [None]:

# Push a random batch through the discriminator's layers one stage at a time
# and print the tensor shape after every stage.
features = 8
x = torch.randn(32, 3, 64, 64)
print(x.shape)

# (in_channels, out_channels, bias) of the four halving conv + LeakyReLU stages.
halving_stages = [
    (3, features, True),
    (features, features * 2, False),
    (features * 2, features * 4, False),
    (features * 4, features * 8, False),
]
for c_in, c_out, use_bias in halving_stages:
    conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=use_bias)
    act = nn.LeakyReLU(0.2)
    x = act(conv(x))
    print(x.shape)

# Final stage: unpadded conv down to 1x1, followed by a sigmoid.
conv = nn.Conv2d(features * 8, 1, kernel_size=4, stride=2, padding=0)
act = nn.Sigmoid()
x = act(conv(x))
print(x.shape)



torch.Size([32, 3, 64, 64])
torch.Size([32, 8, 32, 32])
torch.Size([32, 16, 16, 16])
torch.Size([32, 32, 8, 8])
torch.Size([32, 64, 4, 4])
torch.Size([32, 1, 1, 1])


In [None]:
# Generator shape check: 32 latent vectors of shape 100x1x1 should come out as
# 32 images of shape 3x64x64.
# (The batch is deliberately not called `input`, which would shadow the
# builtin of that name for the rest of the kernel session.)
sample_noise = torch.randn(32, 100, 1, 1)
gen_check = Generator(100, 3, 8)
print(gen_check(sample_noise).shape)

torch.Size([32, 3, 64, 64])


In [None]:
# Push a random latent batch through the generator's layers one stage at a
# time and print the tensor shape after every stage.
noise_dim = 100
batch_size = 32
features = 8
img_channel = 3
x = torch.randn(batch_size, noise_dim, 1, 1)  # batch_size x noise_dim x 1 x 1

# Stage 1: project the 1x1 latent to a 4x4 map.
deconv = nn.ConvTranspose2d(noise_dim, features * 16, kernel_size=4, stride=1, padding=0, bias=False)
act = nn.ReLU()
x = act(deconv(x))
print(x.shape)  # [32, 128, 4, 4]

# Stages 2-4: each doubles the spatial size and halves the channel count:
# [32, 64, 8, 8] -> [32, 32, 16, 16] -> [32, 16, 32, 32]
doubling_stages = [
    (features * 16, features * 8),
    (features * 8, features * 4),
    (features * 4, features * 2),
]
for c_in, c_out in doubling_stages:
    deconv = nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)
    act = nn.ReLU()
    x = act(deconv(x))
    print(x.shape)

# Stage 5: final doubling to image channels, followed by tanh.
deconv = nn.ConvTranspose2d(features * 2, img_channel, kernel_size=4, stride=2, padding=1, bias=False)
act = nn.Tanh()
x = act(deconv(x))
print(x.shape)  # [32, 3, 64, 64]

torch.Size([32, 128, 4, 4])
torch.Size([32, 64, 8, 8])
torch.Size([32, 32, 16, 16])
torch.Size([32, 16, 32, 32])
torch.Size([32, 3, 64, 64])
