In [2]:
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import os
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import zipfile
from io import BytesIO
from PIL import Image
import numpy as np
from torchvision.utils import save_image

In [3]:
batch_size = 16
latent_space = 100
_channels = 3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 50



In [11]:
transform = transforms.Compose([
    transforms.Resize(64),  
    transforms.CenterCrop(64),  
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  
])

dataset = datasets.CelebA(root='./CELEBA', split='train', transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)


SSLError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:992)')))

In [None]:
def show_images(images):
    grid = torchvision.utils.make_grid(images, nrow=8, padding=2, normalize=True)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

# Get a batch of images
data_iter = iter(dataloader)
images, _ = data_iter.next()

# Show images
show_images(images)


In [20]:
dataset = datasets.CelebA(root='path/to/data', split='train', transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)


## DESCRIMINATOR

In [23]:
class Critic(nn.Module):
    def __init__(self, input_channels, label_channels):
        super(Critic, self).__init__()

        self.input_channels = input_channels
        self.label_channels = label_channels

        self.convNEW = nn.Conv2d(input_channels + label_channels, 64, kernel_size=3, stride=2, padding=1)
        self.activationNEW = nn.LeakyReLU()

        self.conv1 = nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1)
        self.activation1 = nn.LeakyReLU()

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.activation2 = nn.LeakyReLU()

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.activation3 = nn.LeakyReLU()


        self.flatten = nn.Flatten(1,-1)

        self.linear1 = nn.Linear(2048, 512)
        self.activation4 = nn.LeakyReLU()
        
        self.linear2 = nn.Linear(512,1)

    def forward(self,x,label_input):
        x = torch.cat((x, label_input), dim=1)

        x = self.convNEW(x)
        x = self.activationNEW(x)

        x = self.conv1(x)
        x = self.activation1(x)
        
        x = self.conv2(x)
        x = self.activation2(x)

        x = self.conv3(x)
        x = self.activation3(x)
        
        x = self.flatten(x)
        
        x = self.linear1(x)
        x = self.activation4(x)
        
        x = self.linear2(x)

        return x

    

In [24]:
critic_input = torch.randn(32, 3, 64, 64)  # Batch of 32 images with 3 channels
label_input = torch.randn(32, 2, 64, 64)   # Batch of 32 labels with 2 channels

critic = Critic(input_channels=3, label_channels=2)
critic_output = critic(critic_input, label_input)
print(critic_output.shape)

torch.Size([32, 1])


In [38]:
# x.shape

# GENERATOR

In [76]:
import torch.nn as nn

# output_shape = ( input - 1 ) * stride + output_padding - 2 * padding + kernel_size

class Generator(nn.Module):
    def __init__(self, input_size, output_channels,label_dim):
        super(Generator, self).__init__()
        self.input_size = input_size
        self.channels = output_channels
        self.label_dim = label_dim

        self.conv2 = nn.ConvTranspose2d(input_size + label_dim, 512, 4, 2, 0, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(512)
        self.activation2 = nn.ReLU(True)

        self.conv3 = nn.ConvTranspose2d(512, 256, 3, 2, 1,output_padding = 1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(256)
        self.activation3 = nn.ReLU(True)

        self.conv4 = nn.ConvTranspose2d(256, 128, 3, 2, 1,output_padding = 1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(128)
        self.activation4 = nn.ReLU(True)

        self.conv5 = nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = 1,bias=False)
        self.batchnorm5 = nn.BatchNorm2d(64)
        self.activation5 = nn.ReLU(True)

        self.conv6 = nn.ConvTranspose2d(64, 3, 3, 2, 1, output_padding = 1,bias=False)
        self.batchnorm6 = nn.BatchNorm2d(self.channels)
        self.activation6 = nn.Tanh()

    def forward(self, x, label_input):
        x = torch.cat((x, label_input), dim=1)

        batch_size = x.size(0)
        x = x.view(batch_size, x.shape[1], 1, 1)


        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.activation2(x)

        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = self.activation3(x)

        x = self.conv4(x)
        x = self.batchnorm4(x)
        x = self.activation4(x)

        x = self.conv5(x)
        x = self.batchnorm5(x)
        x = self.activation5(x)

        x = self.conv6(x)
        x = self.batchnorm6(x)
        x = self.activation6(x)

       
        return x
    

In [77]:
# generator_input = torch.randn(8, 100)  # Batch size of 64, input size of 100
# generator = Generator(input_size=100,output_channels=3)
# output = generator(generator_input)
# print(output.shape)
generator_input = torch.randn(32, 100)  # Batch of 32 latent vectors
label_input_gen = torch.randn(32, 2)   # Batch of 32 labels with 2 dimensions

generator = Generator(input_size=100, output_channels=3,label_dim=2)
generator_output = generator(generator_input, label_input_gen)
print(generator_output.shape)

torch.Size([32, 3, 64, 64])


In [78]:
class DCGAN(nn.Module):
    def __init__(self, critic,generator, latent_dim, device, gp_weight=10, lr_gen=0.0002, lr_disc=0.0002, betas=(0.5, 0.999),ratio = 3):
        super(DCGAN, self).__init__()
        self.generator = generator.to(device)
        self.critic = critic.to(device)
        self.latent_dim = latent_dim
        self.device = device
        self.gp_weight = gp_weight

        self.optim_gen = optim.Adam(self.generator.parameters(), lr=lr_gen, betas=betas)
        self.optim_crit = optim.Adam(self.critic.parameters(), lr=lr_disc, betas=betas)
        
        self.criterion = nn.BCELoss()
        self.loss_fn = nn.BCELoss()

        self.ratio = ratio

        self.c_loss_metric = []
        self.c_wass_loss_metric = []
        self.c_gp_metric = []
        self.g_loss_metric = []
    

    def gradient_penalty(self, batch_size, real_images, fake_images):
        alpha = torch.rand(batch_size, 1, 1, 1, device=self.device)
        interpolated = alpha * real_images + (1 - alpha) * fake_images
        interpolated.requires_grad_(True)

        interpolated_predictions = self.critic(interpolated)

        gradients = torch.autograd.grad(
            outputs=interpolated_predictions,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_predictions, device=self.device),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradients = gradients.view(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
        return gradient_penalty
    


    def train_step(self, real_images):
        batch_size = real_images.shape[0]
        real_images = real_images.to(self.device)
        one_hot_labels = one_hot_labels.to(self.device)

        image_one_hot_labels = one_hot_labels.view(batch_size, one_hot_labels.size(1), 1, 1)
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, 64, 64)


        for _ in range(self.critic_steps):
            random_latent_vectors = torch.randn(batch_size, self.latent_dim, device=self.device)

            fake_images = self.generator(random_latent_vectors, one_hot_labels)
            fake_images = torch.cat((fake_images, image_one_hot_labels), dim=1)

            fake_predictions = self.critic(fake_images)
            real_predictions = self.critic(torch.cat((real_images, image_one_hot_labels), dim=1))

            c_wass_loss = fake_predictions.mean() - real_predictions.mean()
            c_gp = self.gradient_penalty(real_images, fake_images, image_one_hot_labels)
            c_loss = c_wass_loss + self.gp_weight * c_gp

            self.optim_crit.zero_grad()
            c_loss.backward()
            self.optim_crit.step()


        random_latent_vectors = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_images = self.generator(random_latent_vectors, one_hot_labels)
        fake_images = torch.cat((fake_images, image_one_hot_labels), dim=1)

        fake_predictions = self.critic(fake_images)
        g_loss = -fake_predictions.mean()

        self.optim_gen.zero_grad()
        g_loss.backward()
        self.optim_gen.step()

        return {
            'c_loss': c_loss.item(),
            'c_wass_loss': c_wass_loss.item(),
            'c_gp': c_gp.item(),
            'g_loss': g_loss.item()
        }

    

    def generate_images(self, num_images=1, save_path=None,epoch=None):
        self.generator.eval()
        if epoch is None:
            epoch = "_"	
            
        if os.path.exists(save_path) == False:
            os.makedirs(save_path)

        with torch.no_grad():
            noise = torch.randn(num_images, self.latent_dim).to(self.device)
            generated_images = self.generator(noise)
            
            if save_path is not None:
                for i, image in enumerate(generated_images):
                    if os.path.exists(save_path) == False:
                        os.makedirs(save_path)
                    save_image(image, f"{save_path}/EPOCH_{epoch}_image_{i}.png")
            
            return generated_images

: 

In [49]:

descriminator = Critic(input_channels=_channels)
generator = Generator(input_size=latent_space,output_channels=_channels)

DCGAN = DCGAN(descriminator, generator, latent_space, device,gp_weight=10, lr_gen=2e-3, lr_disc=2e-3, betas=(0.5, 0.999))

In [50]:
DCGAN.train_step(_dataloader.__next__())

penatly tensor(0.9682, grad_fn=<MeanBackward0>)
penatly tensor(0.1044, grad_fn=<MeanBackward0>)
penatly tensor(5.4209, grad_fn=<MeanBackward0>)


{'c_loss': -134.77979882558188,
 'c_wass_loss': -156.42507188611975,
 'c_gp': 2.1645278483629227,
 'g_loss': 1093.273193359375}

In [13]:
def weights_init_glorot(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


descriminator = Critic(input_channels=_channels).apply(weights_init_glorot)
generator = Generator(input_size=latent_space,output_channels=_channels).apply(weights_init_glorot)

DCGAN = DCGAN(descriminator, generator, latent_space, device, lr_gen=2e-3, lr_disc=2e-3, betas=(0.5, 0.999))

In [31]:
_dataloader = image_batch_generator(zip_file_path, 4, transform=transform)

In [32]:
DCGAN.train_step(_dataloader.__next__())


(5.06100606918335, -0.13016650080680847, 0.5191172361373901)

In [None]:
# import shutil
# import os

# # Specify the path to the folder you want to delete
# folder_path = '/kaggle/working/generated_images'

# # Check if the folder exists
# if os.path.exists(folder_path):
#     # Remove the folder and all its contents
#     shutil.rmtree(folder_path)

In [16]:
for i in range(num_epochs):
    x=0
    _dataloader = image_batch_generator(zip_file_path, 32, transform=transform)
    for idx,j in enumerate(_dataloader):
        x = DCGAN.train_step(j)
        print(f"epoch :{i} batch: {idx}",end ="\r")
    print(f"epoch {i}: {x}")
    if i%15 == 0:
        torch.save(DCGAN.state_dict(), 'model_state_dict.pth')
    DCGAN.generate_images(num_images=1, save_path="generated_images",epoch=i)

  images.append(torch.tensor(img))


epoch 0: {'d_loss': 1.3815033435821533, 'g_loss': 6.308757305145264}
epoch 1: {'d_loss': 1.3254845142364502, 'g_loss': 3.7382590770721436}
epoch 2: {'d_loss': 1.0476481914520264, 'g_loss': 4.312037944793701}
epoch 3: {'d_loss': 0.5357766151428223, 'g_loss': 4.355941295623779}
epoch :4 batch: 77

KeyboardInterrupt: 

In [None]:
# Save only the model's state dict
# torch.save(DCGAN.state_dict(), 'model_state_dict.pth')
# os.remove("/kaggle/working/model_state_dict.pth")
# torch.save(DCGAN.state_dict(), 'model_state_dict.pth')

In [None]:
DCGAN.generate_images(num_images=1, save_path="generated_images",epoch=4001)

In [None]:
sum(i.numel() for i in DCGAN.parameters())