In [None]:
import torch
from torch import nn

## Generator

In [None]:
class Generatror(nn.Module):
    """Fully connected GAN generator: noise vector -> flattened image.

    NOTE: the (misspelled) class name is kept because later cells instantiate
    it as ``Generatror``.

    Architecture: four Linear -> BatchNorm1d -> ReLU blocks whose widths are
    ``hidden_dim``, 2x, 4x and 8x that, followed by a Linear projection to
    ``img_dim`` and a Sigmoid, so every output value lies in [0, 1].

    Because of BatchNorm1d, a forward pass in training mode needs a batch of
    more than one sample (a single sample raises ``ValueError``).

    Args:
        z_dim: size of the input noise vector.
        img_dim: number of output features (e.g. 28 * 28 = 784 for MNIST).
        hidden_dim: width of the first hidden block.
    """

    def __init__(self, z_dim: int, img_dim: int, hidden_dim: int) -> None:
        super().__init__()

        # Attribute names gen_block1..gen_block4 / last_block are part of the
        # state_dict keys, so they must not change.
        self.gen_block1 = self._hidden_block(z_dim, hidden_dim)
        self.gen_block2 = self._hidden_block(hidden_dim, 2 * hidden_dim)
        self.gen_block3 = self._hidden_block(2 * hidden_dim, 4 * hidden_dim)
        self.gen_block4 = self._hidden_block(4 * hidden_dim, 8 * hidden_dim)

        # Output head: project to image size and squash into [0, 1].
        self.last_block = nn.Sequential(
            nn.Linear(8 * hidden_dim, img_dim),
            nn.Sigmoid(),
        )

    @staticmethod
    def _hidden_block(in_features: int, out_features: int) -> nn.Sequential:
        """Linear -> BatchNorm1d -> ReLU, the repeating unit of the generator."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map a 2-D noise batch (batch, z_dim) to images of shape (batch, img_dim)."""
        stages = (
            self.gen_block1,
            self.gen_block2,
            self.gen_block3,
            self.gen_block4,
            self.last_block,
        )
        for stage in stages:
            x = stage(x)
        return x
        

In [None]:
# Build a generator for 28x28 (= 784-pixel) images from 10-dim noise and
# confirm its parameters live on the GPU.
gen = Generatror(z_dim=10, img_dim=28 * 28, hidden_dim=64).to('cuda')
next(gen.parameters()).is_cuda

In [None]:
# BatchNorm1d cannot compute batch statistics from a single sample in training
# mode, so this call is expected to fail. Catch only that specific error so any
# other problem (wrong device, shape typo, ...) still surfaces.
try:
    gen(torch.randn(1, 10, device="cuda"))
except ValueError as err:
    # e.g. "Expected more than 1 value per channel when training, got input size torch.Size([1, 64])"
    print("input more than 1 sample:", err)

In [None]:
from matplotlib import pyplot as plt

# Preview one sample from the (untrained) generator. Switch to eval mode so
# BatchNorm uses its running statistics instead of normalising the two samples
# against each other, and so the preview does not update those statistics.
# No gradients are needed just to plot.
gen.eval()
with torch.no_grad():
    generated_img = gen(torch.randn(2, 10, device='cuda'))[1]
gen.train()  # restore training mode for the cells that follow

generated_img = generated_img.detach().cpu().numpy()
plt.imshow(generated_img.reshape(28, 28), cmap='gray')


## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator that scores flattened images.

    Four Linear -> LeakyReLU(0.2) blocks shrink the width from
    ``hidden_dim * 8`` down to ``hidden_dim``. A final Linear layer then
    produces one raw logit per sample. There is no sigmoid, so pair it with a
    logits-based loss such as ``nn.BCEWithLogitsLoss``.

    Args:
        hidden_dim: width of the last hidden block; earlier blocks use 8x, 4x, 2x.
        img_dim: number of input features per image (e.g. 784 for 28x28 MNIST).
    """

    def __init__(self, hidden_dim, img_dim) -> None:
        super().__init__()

        widths = (img_dim, hidden_dim * 8, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        # Register disc_block1..disc_block4 under their original attribute names
        # (in the original order) so state_dict keys are unchanged.
        for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(
                self,
                f"disc_block{index}",
                nn.Sequential(nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)),
            )
        self.linear_classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return real-vs-fake logits, shape (batch, 1) for a 2-D input batch."""
        features = x
        for block in (self.disc_block1, self.disc_block2, self.disc_block3, self.disc_block4):
            features = block(features)
        return self.linear_classifier(features)

In [None]:
# Discriminator for 28x28 (= 784-pixel) images, placed on the GPU.
disc = Discriminator(hidden_dim=64, img_dim=28 * 28).to('cuda')

In [None]:
# One random flattened "image" with pixel values drawn uniformly from [0, 1).
rand_img = torch.rand(1, 784, device='cuda')

In [None]:
# Score the random image; the discriminator returns one raw logit per sample.
disc(rand_img)

## Loss and optimizers


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Hyper-parameters
criterion = nn.BCEWithLogitsLoss()  # the discriminator outputs raw logits
n_epochs = 200
z_dim = 10
img_dim = 28 * 28
hidden_dim = 64
display_step = 500
batch_size = 64
lr = 0.00001
device = 'cuda'

# Load MNIST as tensors. download=True only fetches the files when they are not
# already present in '.', so a fresh Restart & Run All works on a new machine
# (download=False raises if the dataset is missing).
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

# Fresh models and optimizers for training; these replace the `gen` / `disc`
# instances created in the demo cells above.
gen = Generatror(z_dim=z_dim, img_dim=img_dim, hidden_dim=hidden_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

disc = Discriminator(img_dim=img_dim, hidden_dim=hidden_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

In [None]:
# Seed the CUDA RNG, then draw fixed inputs for the loss sanity checks below.
# NOTE: `real_img` is Gaussian noise standing in for a real batch, not MNIST data.
torch.cuda.manual_seed_all(1000)
noise = torch.randn(batch_size, z_dim, device=device)
real_img = torch.randn(batch_size, img_dim, device=device)
real_img.shape


In [None]:
# Sanity check: the discriminator's weights should be on the GPU.
next(iter(disc.parameters())).is_cuda

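#### Discriminator loss

$$ L_{disc} = - \frac{1}{2N} \sum_{i=1}^N \left[ \log D(x_i) + \log (1 - D(G(z_i))) \right] $$

Here $D(\cdot)$ is the sigmoid of the discriminator's logit; `BCEWithLogitsLoss` applies the sigmoid internally. The factor $\frac{1}{2}$ comes from averaging the real-image and fake-image BCE terms.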

In [None]:

def get_disc_loss(disc, gen, criterion, real_img, batch_size, z_dim, device):
    """Compute the discriminator loss for one batch.

    Fake images G(z) are scored against target 0 and real images against
    target 1. The two criterion values are averaged.

    Args:
        disc: discriminator module returning logits.
        gen: generator module mapping noise to images.
        criterion: loss taking (logits, targets); the notebook uses nn.BCEWithLogitsLoss.
        real_img: batch of real (flattened) images fed to ``disc``.
        batch_size: number of fake images to generate.
        z_dim: dimension of each noise vector.
        device: device on which the noise is created.

    Returns:
        (fake_loss + real_loss) / 2 as returned by ``criterion``.
    """
    # Sample latent vectors z and generate a batch of fake images.
    z = torch.randn(batch_size, z_dim, device=device)
    fake_images = gen(z)

    # detach(): this loss updates only the discriminator, so no gradient may
    # flow back into the generator.
    fake_logits = disc(fake_images.detach())
    fake_term = criterion(fake_logits, torch.zeros_like(fake_logits))

    real_logits = disc(real_img)
    real_term = criterion(real_logits, torch.ones_like(real_logits))

    return (fake_term + real_term) / 2


In [None]:
# Sanity check: discriminator loss on the stand-in batch (positional arguments
# follow the signature: disc, gen, criterion, real_img, batch_size, z_dim, device).
get_disc_loss(disc, gen, criterion, real_img, batch_size, z_dim, device)

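## Generator loss

$$ L_{gen} = - \frac{1}{N} \sum_{i=1}^N \log D(G(z_i)) $$

This is the non-saturating generator loss. Generated images are scored against the "real" label 1, so the generator is rewarded for fooling the discriminator.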

In [None]:
def get_gen_loss(gen, disc, criterion, batch_size, z_dim, device):
    """Compute the (non-saturating) generator loss for one batch.

    Freshly generated fakes are scored by the discriminator against target 1
    ("real"), so minimising this loss pushes G toward fooling D. Gradients
    flow through ``disc`` into ``gen`` (nothing is detached).

    Args:
        gen: generator module mapping noise to images.
        disc: discriminator module returning logits.
        criterion: loss taking (logits, targets); the notebook uses nn.BCEWithLogitsLoss.
        batch_size: number of fake images to generate.
        z_dim: dimension of each noise vector.
        device: device on which the noise is created.

    Returns:
        The criterion value for the fake logits against all-ones targets.
    """
    fake_logits = disc(gen(torch.randn(batch_size, z_dim, device=device)))
    return criterion(fake_logits, torch.ones_like(fake_logits))

In [None]:
# Sanity check: generator loss (positional arguments follow the signature:
# gen, disc, criterion, batch_size, z_dim, device).
get_gen_loss(gen, disc, criterion, batch_size, z_dim, device)

## Training the GAN

In [None]:
# Take one batch from the loader to inspect tensor shapes.
batch_iterator = iter(dataloader)
images, labels = next(batch_iterator)
images[0].shape

In [None]:
# images[0] is (1, 28, 28); index away the channel axis to get a 28x28 array.
plt.imshow(images[0][0], cmap='gray')

In [None]:
bts = images.shape[0]  # batch size of this batch (64 with batch_size=64)
# Fully flattened vs. per-sample flattened vs. original shape.
images.reshape(-1).shape, images.reshape(bts, -1).shape, images.shape

## (torch.Size([50176]), torch.Size([64, 784]), torch.Size([64, 1, 28, 28]))

In [None]:
from tqdm.autonotebook import tqdm

In [None]:
# Train the GAN: for every batch, update the discriminator first, then the generator.
device = "cuda"
n_epochs = 200

for epoch in range(n_epochs):
    # Per-epoch running sums of the batch losses; reset every epoch so the
    # reported means are not contaminated by earlier epochs.
    mean_gen_loss = 0.0
    mean_disc_loss = 0.0

    # The end-of-epoch sampling below switches the generator to eval mode;
    # BatchNorm must be back in training mode before the next epoch trains.
    gen.train()

    for real, _ in tqdm(dataloader, desc=f"epoch {epoch}"):
        current_batch_size = len(real)

        # Flatten (B, 1, 28, 28) -> (B, 784) to match the fully connected models.
        real = real.view(current_batch_size, -1).to(device)

        ## Update the discriminator
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(
            gen=gen,
            disc=disc,
            criterion=criterion,
            real_img=real,
            batch_size=current_batch_size,
            z_dim=z_dim,
            device=device,
        )
        # No retain_graph: the fakes are detached inside get_disc_loss and the
        # generator loss below builds its own graph from fresh noise.
        disc_loss.backward()
        disc_opt.step()

        ## Update the generator
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(
            gen=gen,
            disc=disc,
            criterion=criterion,
            batch_size=current_batch_size,
            z_dim=z_dim,
            device=device,
        )
        gen_loss.backward()
        gen_opt.step()

        mean_disc_loss += disc_loss.item()
        mean_gen_loss += gen_loss.item()

    # Average the summed batch losses over the number of batches in the epoch.
    mean_disc_loss /= len(dataloader)
    mean_gen_loss /= len(dataloader)

    # Sample the fixed `noise` with BatchNorm in inference mode and without
    # building an autograd graph.
    gen.eval()
    with torch.no_grad():
        gen_fake = gen(noise)
    print(f"Epoch {epoch} : Generator Loss : {mean_gen_loss} Discriminator Loss : {mean_disc_loss}")
        
        
        

In [None]:
# Generate the final samples in inference mode explicitly, rather than relying
# on whatever mode the training cell left the generator in.
gen.eval()
with torch.no_grad():
    generated_imgs = gen(noise)
generated_imgs.shape

In [None]:
# Plot 12 of the generated images in a 2x6 grid.
fig, axes = plt.subplots(nrows=2, ncols=6, sharex=True, sharey=True, figsize=(24, 6))
for ax, img in zip(axes.flatten(), generated_imgs):
    img = img.detach().cpu().numpy()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.imshow(img.reshape((28, 28)), cmap='Greys_r')

