In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

2024-06-26 10:30:46.786380: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-26 10:30:46.786504: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-26 10:30:46.967526: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
class Critic(nn.Module):
    """WGAN critic: maps an image batch to one unbounded score per image.

    With a 64x64 input the spatial size halves at each strided conv
    (64 -> 32 -> 16 -> 8 -> 4). The final 4x4 / stride-2 / no-padding conv then
    collapses 4x4 to 1x1, so the output has shape (N, 1, 1, 1). No sigmoid is
    applied: the training loop uses the raw scores in a Wasserstein-style loss.

    Args:
        channels: number of input image channels.
        features: base channel width; later layers use 2x, 4x and 8x this value.
    """

    def __init__(self, channels, features):
        super().__init__()
        widths = [features, features * 2, features * 4, features * 8]

        # 64x64 -> 32x32: plain conv + LeakyReLU, no normalisation on the first layer.
        layers = [
            nn.Conv2d(channels, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # 32 -> 16 -> 8 -> 4, doubling the channel width at each step.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        # 4x4 -> 1x1 score map.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))

        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel, stride, padding):
        """Return a Conv2d followed by LeakyReLU(0.2); this critic uses no BatchNorm."""
        conv = nn.Conv2d(in_channels, out_channels, kernel, stride, padding)
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Score a batch of images; returns a tensor of shape (N, 1, 1, 1) for 64x64 input."""
        scores = self.disc(x)
        return scores
    
    
class Generator(nn.Module):
    """DCGAN-style generator: latent noise (N, z_dim, 1, 1) -> images (N, channels, 64, 64).

    The final Sigmoid keeps pixel values in [0, 1].

    Args:
        z_dim: number of latent channels in the 1x1 noise input.
        channels: number of output image channels.
        features: base channel width; the first block uses 16x this value and
            each following block halves it.
    """

    def __init__(self, z_dim, channels, features):
        super().__init__()
        # 1x1 -> 4x4 (stride 1, no padding).
        stages = [self._block(z_dim, features * 16, 4, 1, 0)]
        # 4 -> 8 -> 16 -> 32, halving the channel width each time.
        for mult_in, mult_out in ((16, 8), (8, 4), (4, 2)):
            stages.append(self._block(features * mult_in, features * mult_out, 4, 2, 1))
        # 32x32 -> 64x64 image, squashed into [0, 1].
        stages.append(nn.ConvTranspose2d(features * 2, channels, 4, 2, 1))
        stages.append(nn.Sigmoid())

        self.net = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel, stride, padding):
        """Return ConvTranspose2d -> ReLU -> BatchNorm2d.

        NOTE(review): BatchNorm is applied *after* the ReLU here, the reverse of the
        more common ConvT -> BN -> ReLU ordering. It is kept as-is so trained
        weights and behaviour are unchanged.
        """
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride, padding)
        return nn.Sequential(upsample, nn.ReLU(), nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Generate images from noise of shape (N, z_dim, 1, 1)."""
        images = self.net(x)
        return images

In [3]:
# Training configuration for the weight-clipped WGAN trained below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG (CPU and all CUDA devices) before any model/data construction so
# weight init, data shuffling and noise sampling are repeatable across runs.
# NOTE(review): some cuDNN kernels can still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

LR = 2e-4           # RMSprop learning rate, shared by generator and critic
BS = 128            # mini-batch size
IMGS = 64           # images are resized to IMGS x IMGS (both nets are built for 64x64)
CHANNELS = 1        # image channels (1 for MNIST)
ZDIM = 16           # latent noise dimension fed to the generator
EPOCHS = 100
FEATURES_DISC = 64  # base channel width of the critic
FEATURES_GEN = 64   # base channel width of the generator
CRITIC_ITER = 5     # critic updates per generator update
WEIGHT_CLIP = 0.01  # critic weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

In [4]:
# MNIST resized to IMGS x IMGS. ToTensor scales pixels to [0, 1], the same range
# the generator's final Sigmoid produces, so no Normalize step is applied.
transform = transforms.Compose([
    transforms.Resize(IMGS),
    transforms.ToTensor(),
])

# Alternatives:
#   - CIFAR-10: datasets.CIFAR10 with the same arguments, and set CHANNELS = 3.
#   - A subset of digits: filter the dataset by label (e.g. keep labels in (5, 1)).
dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)

# shuffle=True so every epoch presents the batches in a fresh random order
# instead of the fixed on-disk order.
train_loader = DataLoader(dataset, batch_size=BS, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 12880761.74it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 364964.38it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3370908.83it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1450916.13it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






In [5]:
# Build both networks (generator first, then critic) and move them to `device`.
# nn.Module.to() moves parameters in place and returns the module itself.
gen = Generator(ZDIM, CHANNELS, FEATURES_GEN).to(device)
critic = Critic(CHANNELS, FEATURES_DISC).to(device)

# NOTE(review): `loss_fn` is never used by the training loop in this notebook,
# which uses the raw-score critic/generator losses. Kept so the name still exists.
loss_fn = nn.BCELoss()

# RMSprop with the shared learning rate for both networks.
optim_gen = torch.optim.RMSprop(gen.parameters(), lr=LR)
optim_critic = torch.optim.RMSprop(critic.parameters(), lr=LR)

# Separate TensorBoard runs for real and generated image grids.
writer_real, writer_fake = SummaryWriter("logs/real"), SummaryWriter("logs/fake")
step = 0  # TensorBoard global step, advanced each time image grids are logged

In [None]:
# WGAN training with weight clipping: CRITIC_ITER critic updates per generator update.
for epoch in tqdm(range(EPOCHS)):
    for index, (image, _) in enumerate(train_loader):
        gen.train()
        critic.train()

        # Transfer the real batch once; every critic iteration below reuses it.
        real = image.to(device)
        # Size the noise batch to the real one: without drop_last the final batch
        # can be smaller than BS.
        cur_bs = real.size(0)

        # ---- Critic updates ----
        for _ in range(CRITIC_ITER):
            noise = torch.randn(cur_bs, ZDIM, 1, 1, device=device)
            fake = gen(noise)

            # detach(): the critic update must not backpropagate into the generator.
            fake_output = critic(fake.detach()).reshape(-1)
            real_output = critic(real).reshape(-1)
            # Maximise mean(critic(real)) - mean(critic(fake)) by minimising its negation.
            loss_critic = -(torch.mean(real_output) - torch.mean(fake_output))
            critic.zero_grad()
            loss_critic.backward()
            optim_critic.step()

            # Weight clipping, done outside autograd instead of via the legacy `.data`.
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # ---- Generator update (reuses the last, non-detached fake batch) ----
        fake_output = critic(fake).reshape(-1)
        gen_loss = -torch.mean(fake_output)
        gen.zero_grad()
        gen_loss.backward()
        optim_gen.step()

        if index % 100 == 0:
            # Previously referenced an undefined `dis_loss`, which raised NameError here.
            print(
                f"Epoch [{epoch}/{EPOCHS}] Batch {index}/{len(train_loader)} "
                f"Loss D: {loss_critic.item():.4f}, Loss G: {gen_loss.item():.4f}"
            )

            with torch.no_grad():
                gen.eval()
                noise = torch.randn(BS, ZDIM, 1, 1, device=device)
                fake = gen(noise)
                img_grid_real = torchvision.utils.make_grid(image[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                # Generated images get their own tag (was mislabelled "Real").
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1

    # End-of-epoch preview: top row = first 5 real images of the last batch,
    # bottom row = first 5 images of the most recent generated batch.
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    for row, batch in enumerate((image, fake)):
        for col in range(5):
            # (C, H, W) -> (H, W, C); squeeze(-1) drops a single channel so
            # grayscale images render as 2-D arrays (no-op for RGB).
            pic = batch[col].detach().cpu().permute(1, 2, 0).squeeze(-1).numpy()
            axes[row, col].imshow(pic, cmap="gray")
            axes[row, col].axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch
            