In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.functional as F
import torch.nn as nn
# from torchvision.utils /import make_grid, save_image
from torchvision.utils import save_image, make_grid
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Available device is: {device}')
torch.manual_seed(42)

Available device is: cuda


<torch._C.Generator at 0x7981c2fcb4b0>

In [11]:
batch_size = 128
dz = 128
z_dim = dz
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(False)
z_fixed= 128
epochs = 20
lr= 5e-4
n_critic=1 
clip_value= 0.01
img_size= 28
channels= 1
img_shape = (channels, img_size, img_size)
import os
os.makedirs('wgan_dcgan', exist_ok= True)

In [12]:
## DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = MNIST(root='.', train= True, download= True, transform= transform)
train_loader = DataLoader(train_data, batch_size= batch_size, shuffle= True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 34.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 988kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.92MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.1MB/s]


In [17]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.init_size = img_size//4
        self.fc = nn.Linear(z_dim, 128*self.init_size*self.init_size)
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size= 4, stride= 2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, kernel_size= 4, stride= 2, padding= 1),
            nn.Tanh()
        )
        return

    def forward(self, z):
        x = self.fc(z)
        x = x.view(z.size(0), 128, self.init_size, self.init_size)
        return self.conv_block(x)

In [18]:
### critric
class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128*7*7, 1)
        )
        return

    def forward(self, x):
        return self.model(x)

In [20]:
generator = Generator().to(device)
critic = Critic().to(device)

## optimizers 
g_optim = optim.RMSprop(generator.parameters(), lr= lr)
c_optim = optim.RMSprop(critic.parameters(), lr= lr)


In [27]:
## tranining loop
## most critical for every loop
for epoch in range(1, epochs+1):
    for i, (real_img, _) in enumerate(train_loader):
        real_imgs = real_img.to(device)
        b_size = real_imgs.size(0)

        ## training critic
        for _ in range(n_critic):
            z= torch.randn(b_size, z_dim, device= device)
            fake_imgs = generator(z)

            loss_c = -torch.mean(critic(real_imgs)) + torch.mean(critic(fake_imgs.detach()))

            c_optim.zero_grad()
            loss_c.backward()
            c_optim.step()
            # weight clipping fo rLipschitz constraint
            for p in critic.parameters():
                p.data.clamp_(-clip_value, clip_value)

    ## training the generator
    z = torch.randn(b_size, z_dim, device= device)
    gen_imgs = generator(z)
    loss_g = -torch.mean(critic(gen_imgs))
    g_optim.zero_grad()
    loss_g.backward()
    g_optim.step()

    if epoch % 1 ==0:
        print(f'[Epoch {epoch}/{epochs} [Batch {i}/{len(train_loader)}'
              f'[Critic: {loss_c.item():.4f}] [Gen: {loss_g.item():.4f}]')


    # save samples every epoch
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, device= device)
        samples = generator(z)
        samples = samples * 0.5 + 0.5
        save_image(samples, f'wgan_dcgan/epoch_{epoch}.png', nrow= 8)
    generator.train()


[Epoch 1/20 [Batch 468/469[Critic: -0.4514] [Gen: 0.0808]
[Epoch 2/20 [Batch 468/469[Critic: -0.4302] [Gen: 0.0458]
[Epoch 3/20 [Batch 468/469[Critic: -0.4241] [Gen: -0.0419]
[Epoch 4/20 [Batch 468/469[Critic: -0.4668] [Gen: 0.0656]
[Epoch 5/20 [Batch 468/469[Critic: -0.5180] [Gen: 0.1051]
[Epoch 6/20 [Batch 468/469[Critic: -0.4410] [Gen: 0.0640]
[Epoch 7/20 [Batch 468/469[Critic: -0.4582] [Gen: 0.0368]
[Epoch 8/20 [Batch 468/469[Critic: -0.5003] [Gen: 0.0978]
[Epoch 9/20 [Batch 468/469[Critic: -0.3846] [Gen: 0.0895]
[Epoch 10/20 [Batch 468/469[Critic: -0.3881] [Gen: 0.0827]
[Epoch 11/20 [Batch 468/469[Critic: -0.4477] [Gen: 0.0117]
[Epoch 12/20 [Batch 468/469[Critic: -0.4478] [Gen: 0.0343]
[Epoch 13/20 [Batch 468/469[Critic: -0.4490] [Gen: 0.0415]
[Epoch 14/20 [Batch 468/469[Critic: -0.4025] [Gen: 0.0375]
[Epoch 15/20 [Batch 468/469[Critic: -0.4466] [Gen: 0.0518]
[Epoch 16/20 [Batch 468/469[Critic: -0.3661] [Gen: 0.0504]
[Epoch 17/20 [Batch 468/469[Critic: -0.3853] [Gen: 0.0372]
[Epoc

Question 1: Dataset size and Imbalance

In [28]:
dict_class = {int(x.split(' ')[0]): 0 for x in train_data.classes}
print(dict_class)
for _, y in train_data:
    dict_class[y] += 1
print(dict_class)
max_class = (-1, -1)
min_class = (-1, len(train_data))


{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0}
{0: 5923, 1: 6742, 2: 5958, 3: 6131, 4: 5842, 5: 5421, 6: 5918, 7: 6265, 8: 5851, 9: 5949}


In [36]:
## printing post transform training batch
for (real_imgs, _) in train_loader:
    # print(real_imgs)
    x = real_imgs
    mu = x.mean()
    sigma = x.std()
    mi = x.min()
    mx = x.min()
    p = (x.abs() <= 0.5).float().mean().item()
    # print('mean', torch.mean(real_imgs))
    print(f'mu {mu} sigma {sigma} min {mi} max {mx} p {p}')
    # print('signma', torch.)
    break

mu -0.7486591935157776 sigma 0.603895902633667 min -1.0 max -1.0 p 0.0515485480427742


In [29]:
for k, v in dict_class.items():
    if v > max_class[1]:
        max_class = (k, v)
    if v < min_class[1]:
        min_class = (k, v)
print(max_class,min_class)

total = sum(dict_class.values())
print('total items are:', total)
delta = max_class[1] - min_class[1]
print(delta)
imbalance = (delta/(0.1 * total) ) * 100
print('Total imbalance is: ', imbalance)

(1, 6742) (5, 5421)
total items are: 60000
1321
Total imbalance is:  22.01666666666667


Question 3: Model size and memory

In [40]:
# computing total number of learnable parameters
g_param_count = sum(p.numel() for p in generator.parameters() if p.requires_grad)
c_param_count = sum(p.numel() for p in critic.parameters() if p.requires_grad)
print(f'Total parameters in g: {g_param_count} and critic: {c_param_count}')
MB = 4*(g_param_count + c_param_count) / 1e6
print(MB)

Total parameters in g: 941633 and critic: 138817
4.3218


Question 4: One training cycle delta

In [45]:
def get_params_vector(model):
    return torch.cat([p.detach().view(-1) for p in model.parameters()])

In [44]:
## one training cycle rewritting
q = 5
for epoch in range(1, epochs +1):
    for i, (real_imgs, _) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)

        ## training the critic
        ld = 0
        critic_param_before = get_params_vector
        for _ in range(q):
            z = torch.randn(b_size, z_dim, device= device)
            fake_imgs = generator(z)
            loss_c = -torch.mean(critic(real_imgs)) +  torch.mean(critic(fake_imgs.detach()))
            ld += loss_c.item()
            c_optim.zero_grad()
            loss_c.backward()
            c_optim.step()
            print(f'ld {ld} loss: {loss_c.item()}')

        ## generator thing
        z = torch.randn(b_size, z_dim, device= device)
        fake_imgs = generator(z)
        loss_g = torch.mean(critic(fake_imgs))

        g_optim.zero_grad()
        loss_g.backward()
        g_optim.step()
            # break
        # break
    break
                                 
        