In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torchvision.utils as vutils


In [25]:
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("using device: ", device)

using device:  mps


In [26]:
# Hyperparameters

batch_size = 128        # images per optimisation step
image_size = 28         # MNIST images are 28x28; the networks below are written for this size
nz = 100                # length of the generator's latent noise vector
num_epochs = 50
learning_rate = 0.0002
beta1 = 0.5             # Adam beta1 (PyTorch's default is 0.9)
seed = 42               # makes shuffling, weight init and noise repeatable across runs

# Seed torch's global RNG (all devices) before anything random happens:
# DataLoader shuffling, network weight init and noise sampling all draw from it.
# Some GPU kernels may still be nondeterministic.
torch.manual_seed(seed)

# Data transformation: ToTensor yields values in [0, 1]; Normalize(0.5, 0.5)
# maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

# Load the MNIST training split (downloaded into ./data on first run).
dataset = torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     transform=transform,
                                     download=True)

dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [27]:
# Now define the two networks: a discriminator and a generator.

class Discriminator(nn.Module):
    """DCGAN-style discriminator: 1x28x28 image -> probability that it is real.

    Spatial path 28 -> 14 -> 7 -> 4 -> 4 -> 1. Every hidden convolution except
    the first is followed by BatchNorm2d; all hidden activations are
    LeakyReLU(0.2). ``forward`` returns a 1-D tensor of shape (batch,) whose
    values come from a Sigmoid, i.e. lie in [0, 1].
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def conv_stage(c_in, c_out, kernel, stride, with_norm=True):
            # Conv2d (padding=1) -> optional BatchNorm2d -> LeakyReLU, built in that order.
            stage = [nn.Conv2d(in_channels=c_in, out_channels=c_out,
                               kernel_size=kernel, stride=stride, padding=1)]
            if with_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            return stage

        self.main = nn.Sequential(
            *conv_stage(1, 64, 4, 2, with_norm=False),   # -> 64x14x14
            *conv_stage(64, 128, 4, 2),                  # -> 128x7x7
            *conv_stage(128, 256, 3, 2),                 # -> 256x4x4
            *conv_stage(256, 512, 3, 1),                 # -> 512x4x4
            # Output head: collapse the 4x4 map to a single logit, then squash it.
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0),  # -> 1x1x1
            nn.Sigmoid(),
        )

    def forward(self, input):
        # (batch, 1, 1, 1) -> (batch,)
        return self.main(input).view(-1)
    
# Generator network

class Generator(nn.Module):
    """DCGAN-style generator: latent noise vector -> 1x28x28 image in [-1, 1].

    Shape path: (batch, latent_dim) -> Linear -> 256x7x7 -> ConvT -> 128x14x14
    -> ConvT -> 64x28x28 -> ConvT -> 32x28x28 -> ConvT -> 1x28x28 -> Tanh.

    Args:
        latent_dim: length of the input noise vector. If omitted, the
            notebook-level hyperparameter ``nz`` is read at construction time,
            so ``Generator()`` keeps working unchanged. Passing it explicitly
            removes the hidden dependency on that global.
    """

    def __init__(self, latent_dim=None):
        super(Generator, self).__init__()
        if latent_dim is None:
            latent_dim = nz  # backward-compatible fallback to the global hyperparameter
        self.latent_dim = latent_dim

        self.main = nn.Sequential(
            # Project the noise vector, then reshape it into a 256x7x7 feature map
            nn.Linear(in_features=latent_dim, out_features=256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(dim=1, unflattened_size=(256, 7, 7)),

            # First transposed convolution: upsample 7x7 -> 14x14
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),  # 128x14x14
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Second transposed convolution: upsample 14x14 -> 28x28
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),  # 64x28x28
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # Third transposed convolution: refine at 28x28
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),  # 32x28x28
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            # Output layer: a single image channel
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),  # 1x28x28
            nn.Tanh()  # Output values in [-1, 1]
        )

    def forward(self, input):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, 1, 28, 28).

        In training mode the BatchNorm1d layer needs batch size > 1.
        """
        return self.main(input)
    

# Build both networks on the selected device (discriminator first, then generator).
netD, netG = Discriminator().to(device), Generator().to(device)


In [28]:
# Initialize weights
def weights_init(m):
    """DCGAN-style initialisation, meant to be used as ``net.apply(weights_init)``.

    Modules are matched by class name:
      * names containing "Conv" or "Linear": weight ~ N(0, 0.02), bias = 0.
      * names containing "BatchNorm": weight ~ N(1, 0.02), bias = 0.

    A matched module whose ``weight``/``bias`` is missing or ``None`` is
    skipped for that parameter instead of raising. Examples are layers built
    with ``bias=False``, ``BatchNorm(affine=False)``, or a custom container
    whose class name happens to contain "Conv". ``apply`` still visits the
    container's children individually.

    Args:
        m: the module currently being visited by ``nn.Module.apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0  # BatchNorm scale starts near identity
    else:
        return

    # nn.init functions run under no_grad themselves, so no `.data` access is needed.
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    if bias is not None:
        nn.init.constant_(bias, 0)

# Apply the custom initialisation to both networks (discriminator first).
netD.apply(weights_init)
netG.apply(weights_init)

# Binary cross-entropy on the discriminator's sigmoid probabilities; it serves
# as the loss for both the discriminator and the generator updates.
criterion = nn.BCELoss()

# Target values fed to the BCE loss for "real" and "fake" samples.
real_label, fake_label = 1, 0

# One Adam optimizer per network, with identical settings.
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    for net in (netD, netG)
)

In [29]:
# Ensure the directory exists for saving images
os.makedirs('mid_run_samples', exist_ok=True)

# Fixed noise for generating samples
fixed_noise = torch.randn(12, nz, device=device)

# Training loop
for epoch in range(num_epochs):
    # Save generated images at the start of each epoch
    with torch.no_grad():
        fake_images = netG(fixed_noise).detach().cpu()
    vutils.save_image(
        fake_images,
        f'mid_run_samples/output_epoch_{epoch}.png',
        normalize=True,
        nrow=4
    )

    # Progress bar for batches
    for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        ############################
        # (1) Update D network
        ############################
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full(
            (b_size,), real_label, dtype=torch.float, device=device
        )
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Train with all-fake batch
        noise = torch.randn(b_size, nz, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network
        ############################
        for _ in range(3):  # Perform 3 generator iterations so generator and discriminator are more balanced
            netG.zero_grad()
            label.fill_(real_label)  # Fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            # Generate new fake data for the next generator iteration
            noise = torch.randn(b_size, nz, device=device)
            fake = netG(noise)

        # Print training stats
        if i % 200 == 0:
            print(
                f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}'
            )

Epoch 1/50:   0%|          | 1/469 [00:00<02:21,  3.30it/s]

[0/50][0/469] Loss_D: 1.7309 Loss_G: 0.1947 D(x): 0.4639 D(G(z)): 0.4994/0.8380


Epoch 1/50:   3%|▎         | 14/469 [00:03<01:48,  4.18it/s]


KeyboardInterrupt: 

In [None]:
# Next: reuse the trained discriminator as a feature extractor, then train a classifier on its features.

class FeatureExtractor(nn.Module):
    """Expose a discriminator's convolutional trunk as a feature extractor.

    Keeps every child of ``discriminator.main`` except the last two. For the
    ``Discriminator`` in this notebook those two are the final 1-channel Conv2d
    and the Sigmoid. The features are therefore the 512x4x4 LeakyReLU
    activations, i.e. 8192 values per 1x28x28 image.

    The layers are *shared* with ``discriminator``, not copied, so training this
    module also changes the discriminator. Deep-copy the discriminator first if
    it must stay untouched. The trunk contains BatchNorm layers, so call
    ``.eval()`` if features should not depend on batch statistics.

    Args:
        discriminator: a module whose ``main`` attribute is an ``nn.Sequential``.
    """

    def __init__(self, discriminator):
        super(FeatureExtractor, self).__init__()
        # Drop the output head: the final convolution and the sigmoid.
        self.features = nn.Sequential(*list(discriminator.main.children())[:-2])

    def forward(self, x):
        """Return the trunk's activations flattened to shape (batch, num_features)."""
        return torch.flatten(self.features(x), start_dim=1)
        

In [36]:
class Animal:
    """Plain record of an animal's name and species."""

    def __init__(self, name, species):
        # Both constructor arguments are stored unchanged as public attributes.
        self.name, self.species = name, species


dog = Animal(name="Buddy", species="Dog")
print(dog.name, "\n", dog.species)

Buddy 
 Dog
