In [None]:
!nvidia-smi

Fri Jul 22 13:11:34 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
# Compute device used by every cell below: 'cuda' when PyTorch reports an
# available GPU, otherwise fall back to 'cpu'.
device= 'cuda' if torch.cuda.is_available() else 'cpu'

class Generator(nn.Module):
    """Transposed-convolution generator that maps a noise tensor to an image.

    The network is a stack of five ``ConvTranspose2d`` layers (all 4x4 kernels,
    no bias). The first four are each followed by ``BatchNorm2d`` + ``ReLU``;
    the last one is followed by ``Tanh``, so outputs lie in [-1, 1].

    Args:
        noise_dim: number of channels of the input noise tensor.
        z_dim: base width of the hidden feature maps; the hidden layers use
            ``z_dim*8``, ``z_dim*4``, ``z_dim*2`` and ``z_dim`` channels.
        img_channels: number of channels of the generated image.

    With the ``(batch, noise_dim, 1, 1)`` noise used elsewhere in this
    notebook the output is ``(batch, img_channels, 64, 64)``.
    """

    def __init__(self, noise_dim, z_dim, img_channels):
        super().__init__()

        # Hidden channel widths, widest first.
        widths = [z_dim * 8, z_dim * 4, z_dim * 2, z_dim]

        # Stem: stride-1, padding-0 projection of the noise tensor.
        layers = [
            nn.ConvTranspose2d(noise_dim, widths[0], kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
        ]

        # Three stride-2 stages, each halving the channel count.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Output head: final stride-2 stage straight into Tanh (no batch norm).
        layers += [
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.network(x)


class Critic(nn.Module):
    """Convolutional critic that scores images with an unbounded value.

    Four stride-2 4x4 convolutions (no bias) each followed by
    ``LeakyReLU(0.2)``; all but the first also have ``BatchNorm2d`` before the
    activation. A final stride-1, padding-0 4x4 convolution reduces to a
    single channel. There is no activation after the last layer.

    Args:
        img_channels: number of channels of the input image.
        z_dim: base width; successive stages use ``z_dim``, ``z_dim*2``,
            ``z_dim*4`` and ``z_dim*8`` channels.

    For the 64x64 images used in this notebook the output is
    ``(batch, 1, 1, 1)``.
    """

    def __init__(self, img_channels, z_dim):
        super().__init__()

        def downsample(c_in, c_out):
            # Shared stride-2 convolution used by every down-sampling stage.
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)

        # First stage has no batch norm.
        stages = [downsample(img_channels, z_dim), nn.LeakyReLU(0.2, inplace=True)]

        # Remaining stages double the channel count: z_dim -> 2x -> 4x -> 8x.
        for mult in (1, 2, 4):
            c_in, c_out = z_dim * mult, z_dim * mult * 2
            stages.extend([
                downsample(c_in, c_out),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Scoring head: collapse the remaining feature map to one channel.
        stages.append(nn.Conv2d(z_dim * 8, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw critic output for the batch ``x``."""
        return self.network(x)


# Build both networks and move them to the selected device.
# 100 noise channels in, 128 base feature-map width, 3-channel images.
gen = Generator(noise_dim=100, z_dim=128, img_channels=3)
gen = gen.to(device)
critic = Critic(img_channels=3, z_dim=128)
critic = critic.to(device)

# Debug probe: looks for 'Conv' in the top-level class name ("Generator"),
# which does not contain it, so this prints -1.
print(type(gen).__name__.find('Conv'))


def weight_init(m):
    """Re-initialise the weights of one module in place, chosen by class name.

    Intended for use with ``nn.Module.apply``. Modules whose class name
    contains ``'Conv'`` get weights drawn from N(0.0, 0.02); modules whose
    class name contains ``'BatchNorm'`` get weights drawn from N(1.0, 0.02).
    The two checks are independent. Any other module is left untouched, and
    biases are never modified.

    Args:
        m: the module to (possibly) re-initialise.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    if 'BatchNorm' in layer_kind:
        torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)


# Re-initialise every Conv*/BatchNorm* submodule of both networks.
gen.apply(weight_init)
critic.apply(weight_init)

# Per-sample preprocessing pipeline.
# NOTE(review): there is no Normalize step, so real images are not rescaled to
# the [-1, 1] range of the generator's Tanh output -- confirm this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([64, 64]),
    # Tile the single channel three times to match img_channels=3.
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    transforms.RandomErasing(
        p=0.5,
        scale=(0.02, 0.33),
        ratio=(0.3, 3.3),
        value=0,
        inplace=False,
    ),
    transforms.ColorJitter(
        brightness=0.25,
        contrast=0.25,
        saturation=0.25,
        hue=0.1,
    ),
])

# Hyper-parameters.
batch_size = 64
num_epochs = 10
lr = 3e-4

# One RMSprop optimizer per network, sharing the same learning rate.
opt_gen = optim.RMSprop(params=gen.parameters(), lr=lr)
opt_critic = optim.RMSprop(params=critic.parameters(), lr=lr)

# MNIST training split, downloaded into 'dataset/' when not already present.
data_set = MNIST(
    'dataset/',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches over the training split.
data_load = DataLoader(data_set, batch_size=batch_size, shuffle=True)



-1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [None]:
# WGAN training loop with weight clipping.
#
# For every real batch the critic is updated `n_critic` times, then the
# generator is updated once. After each epoch a grid of samples is written to
# disk.
noise_dim = 100     # must match the noise_dim the Generator was built with (100)
n_critic = 5        # critic updates per generator update
clip_value = 0.01   # critic weights are clamped to [-clip_value, clip_value]

print(len(data_load))
for epoch in range(num_epochs):
    print(f'Epoch: {epoch}')
    for i, (data, _) in enumerate(data_load):
        data = data.to(device)
        # Match the fake batch to the real one; the last batch of an epoch can
        # be smaller than batch_size because the loader does not drop it.
        cur_batch = data.size(0)

        if i >= 300 and i % 300 == 0:
            print(f"{i}/{len(data_load)}")

        # ---- Critic updates -------------------------------------------------
        for _ in range(n_critic):
            # The critic update must not backpropagate into the generator, so
            # the fakes are produced without building an autograd graph.
            with torch.no_grad():
                noise = torch.randn(cur_batch, noise_dim, 1, 1, device=device)
                fake = gen(noise)

            critic_fake = critic(fake).reshape(-1)
            critic_real = critic(data).reshape(-1)
            # Minimising this maximises mean(critic(real)) - mean(critic(fake)).
            critic_loss = -(torch.mean(critic_real) - torch.mean(critic_fake))

            critic.zero_grad()
            critic_loss.backward()
            opt_critic.step()

            # Weight clipping, applied after every critic step.
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-clip_value, clip_value)

        # ---- Generator update -----------------------------------------------
        # Fresh forward pass so the gradient flows through the just-updated
        # critic into the generator.
        noise = torch.randn(cur_batch, noise_dim, 1, 1, device=device)
        fake = gen(noise)
        output_gen = critic(fake).reshape(-1)
        gen_loss = -torch.mean(output_gen)

        gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

    # ---- End-of-epoch samples -----------------------------------------------
    # Sampling is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        random_img = torch.randn((batch_size, noise_dim, 1, 1), device=device)
        img = gen(random_img)
    print(img.shape)
    grid = make_grid(img)
    print(grid.shape)
    # Epochs sharing the same `epoch // 4` write to the same file name, so
    # later epochs in a group overwrite earlier ones.
    name = f"{epoch // 4}Wgan.jpg"
    save_image(grid, name)

938
Epoch: 0
300/938
600/938
900/938
torch.Size([64, 3, 64, 64])
torch.Size([3, 530, 530])
Epoch: 1
300/938
600/938
900/938
torch.Size([64, 3, 64, 64])
torch.Size([3, 530, 530])
Epoch: 2
300/938
600/938


In [None]:
# Generate one final batch of samples and save them as a single image grid.
# This is inference only, so the forward pass runs without autograd tracking
# and the noise is created directly on the target device.
with torch.no_grad():
    random_img = torch.randn((batch_size, 100, 1, 1), device=device)
    img = gen(random_img)
print(img.shape)
grid = make_grid(img)
print(grid.shape)
# NOTE(review): hard-coded absolute path; presumably requires Google Drive to
# be mounted under /content/drive before this cell runs -- confirm.
save_image(grid, "/content/drive/MyDrive/Colab Notebooks/WGAN.jpg")

torch.Size([64, 1, 64, 64])
torch.Size([3, 530, 530])


In [None]:
# Scratch cell: floor division of 5 by 3, which evaluates to 1.
5//3