In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from PIL import Image
from typing import Any, Dict
import matplotlib.pyplot as plt

In [None]:
# Select the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU. The choice is printed so it is visible in
# the notebook output.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Custom dataset class for loading the images
class WasteImageDataset(Dataset):
    """Unlabelled image dataset backed by the files of a single directory.

    Only regular files directly inside ``directory`` whose name ends in
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are used. The file list
    is sorted so that a given index always refers to the same image, no matter
    in which order the operating system lists the directory.
    """

    # File-name suffixes (compared in lower case) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, directory, transform=None):
        """Initializes the dataset object
        Args:
            directory (str): Directory path of the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.directory = directory
        self.transform = transform
        # sorted(): os.listdir() returns entries in arbitrary order.
        # isfile(): skip sub-directories that merely carry an image-like suffix.
        self.images = [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        ]

    def __len__(self):
        """Returns the number of images in the dataset"""
        return len(self.images)

    def __getitem__(self, idx):
        """Fetches the image at index `idx` and applies transformations if any

        Returns:
            The RGB image, or the output of ``transform`` when one was given.
        """
        img_path = self.images[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been decoded into the converted copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

# Per-image preprocessing: resize to 64x64, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Dataset over the image folder, served in shuffled mini-batches of 32
dataset = WasteImageDataset(directory='./processed_images_cropped', transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Make sure the folder that receives the generated samples exists
os.makedirs(name='output_images_18', exist_ok=True)

In [None]:
# Function to initialize model weights
def weights_init(model):
    """Applies initial weights to certain layers in a model.

    Intended to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module. Layers are selected by class name:

    * names containing ``Conv``: weight drawn from N(0.0, 0.02)
    * names containing ``BatchNorm``: weight drawn from N(1.0, 0.02) and
      bias set to 0

    Any other module is left untouched, as are matching modules that have no
    learnable ``weight``/``bias`` (for example ``BatchNorm2d(affine=False)``).

    Args:
        model (nn.Module): The (sub-)module to initialize in place.
    """
    classname = model.__class__.__name__
    weight = getattr(model, 'weight', None)
    if classname.find('Conv') != -1:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(model, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Generator model
class Generator(nn.Module):
    """Maps a latent space vector (z) to data-space.

    The network is a stack of six transposed convolutions. Given an input of
    shape ``(N, latent_vector_size, 1, 1)`` it produces a tensor of shape
    ``(N, num_channels, 64, 64)`` whose values lie in ``[-1, 1]`` (final Tanh).
    """

    def __init__(self, **kwargs: Any):
        """Creates a new instance of Generator.

        Args:
            **kwargs: An optional set of key/value pair arguments:
                * latent_vector_size: The number of channels of the latent
                input. Defaults to 100.
                * num_features: The base channel width. The hidden layers use
                8x, 8x, 8x, 4x and 2x this many channels. It does not change
                the spatial size of the output. Defaults to 64.
                * num_channels: The number of channels of the output data.
                Defaults to 3 (e.g. RGB images).
        """

        super().__init__()

        self.latent_vector_size = kwargs.get('latent_vector_size', 100)
        self.num_features = kwargs.get('num_features', 64)
        self.num_channels = kwargs.get('num_channels', 3)

        nf = self.num_features

        # Shapes in the comments are channels x (height x width) for the
        # default arguments and a 1x1 latent input.
        layers = [
            # layer-1  100x(1x1) -> 512x(4x4)
            *self._up_block(self.latent_vector_size, nf * 8, kernel_size=4, stride=1, padding=0),
            # layer-2  512x(4x4) -> 512x(8x8)
            *self._up_block(nf * 8, nf * 8, kernel_size=4, stride=2, padding=1),
            # layer-3  512x(8x8) -> 512x(16x16)
            *self._up_block(nf * 8, nf * 8, kernel_size=4, stride=2, padding=1),
            # layer-4  512x(16x16) -> 256x(32x32)
            *self._up_block(nf * 8, nf * 4, kernel_size=4, stride=2, padding=1),
            # layer-5  256x(32x32) -> 128x(32x32)
            # (3x3 kernel with stride 1 keeps the resolution and only mixes channels)
            *self._up_block(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1),
            # layer-6  128x(32x32) -> 3x(64x64), no batch norm on the output layer
            nn.ConvTranspose2d(
                in_channels=nf * 2,
                out_channels=self.num_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.Tanh(),
        ]
        # Unpacking a flat list keeps the sub-module indices (and therefore the
        # state_dict keys "main.0.weight", "main.1.weight", ...) sequential.
        self.main = nn.Sequential(*layers)

    @staticmethod
    def _up_block(in_channels, out_channels, kernel_size, stride, padding):
        """Returns [ConvTranspose2d (no bias), BatchNorm2d, LeakyReLU(0.2)] as a list."""
        return [
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

    # forward propagation
    def forward(self, input):
        """Runs ``input`` through the layer stack and returns the result."""
        return self.main(input)

In [None]:
# Discriminator model
class Discriminator(nn.Module):
    """
    A binary classification network that takes an image as input and outputs
    a scalar probability that the input image is real (as opposed to fake).

    For an input of shape ``(N, num_channels, 64, 64)`` the output has shape
    ``(N, 1, 1, 1)`` with values in ``[0, 1]`` (final Sigmoid).
    """

    def __init__(self, **kwargs: Any):
        """Creates a new instance of Discriminator.

        Args:
            **kwargs: An optional set of key/value pair arguments:
                * num_features: The base channel width. The convolution layers
                use 1x, 2x, 4x and 8x this many channels. Defaults to 64.
                * num_channels: The number of channels of the input data.
                Defaults to 3 (e.g. RGB images).
        """
        super().__init__()

        self.num_features = kwargs.get('num_features', 64)
        self.num_channels = kwargs.get('num_channels', 3)

        nf = self.num_features

        # discriminator network
        # Shapes in the comments are channels x (height x width) for the
        # default arguments and a 64x64 input.
        layers = [
            # layer-1 3x(64x64) -> 64x(32x32), no batch norm on the input layer
            nn.Conv2d(
                in_channels=self.num_channels,
                out_channels=nf,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # layer-2 64x(32x32) -> 128x(16x16)
            *self._down_block(nf, nf * 2),
            # layer-3 128x(16x16) -> 256x(8x8)
            *self._down_block(nf * 2, nf * 4),
            # layer-4 256x(8x8) -> 512x(4x4)
            *self._down_block(nf * 4, nf * 8),
            # layer-5 512x(4x4) -> 1x(1x1)
            nn.Conv2d(
                in_channels=nf * 8,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sigmoid(),
        ]
        # Unpacking a flat list keeps the sub-module indices (and therefore the
        # state_dict keys "main.0.weight", "main.2.weight", ...) unchanged.
        self.main = nn.Sequential(*layers)

    @staticmethod
    def _down_block(in_channels, out_channels):
        """Returns [4x4 stride-2 Conv2d (no bias), BatchNorm2d, LeakyReLU(0.2)] as a list."""
        return [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

    def forward(self, input):
        """Runs ``input`` through the layer stack and returns the result."""
        return self.main(input)

In [None]:
# Build the Generator and the Discriminator and move each one to `device`
netG = Generator(
    latent_vector_size=100,
    num_features=64,
    num_channels=3,
)
netG = netG.to(device)

netD = Discriminator(
    num_features=64,
    num_channels=3,
)
netD = netD.to(device)

# Re-initialize the Conv / BatchNorm parameters of both networks;
# Module.apply visits every sub-module and calls weights_init on it.
netG.apply(weights_init)
netD.apply(weights_init)

In [None]:
# Step sizes used by Adam for the two networks (kept as separate names so
# they can be tuned independently)
lr_generator = 2e-4
lr_discriminator = 2e-4

# Binary cross-entropy between predicted probabilities and the 0/1 labels
criterion = nn.BCELoss()

# One Adam optimizer per network, both with beta1 = 0.5 and beta2 = 0.999
optimizerG = optim.Adam(
    params=netG.parameters(),
    lr=lr_generator,
    betas=(0.5, 0.999),
)
optimizerD = optim.Adam(
    params=netD.parameters(),
    lr=lr_discriminator,
    betas=(0.5, 0.999),
)

In [None]:
# Function to show generated images
def show_generated_img():
    """Generates and displays an image using the Generator.

    Draws one standard-normal latent vector, runs it through the module-level
    ``netG`` without tracking gradients and shows the result with matplotlib.
    The latent size is read from ``netG.latent_vector_size`` so the function
    keeps working when the generator is built with a non-default size.
    ``netG`` is used in whatever train/eval mode it is currently in.
    """
    with torch.no_grad():
        noise = torch.randn(1, netG.latent_vector_size, 1, 1, device=device)
        # No detach() needed: nothing is recorded for autograd under no_grad().
        fake = netG(noise).cpu()
    # normalize=True rescales the pixel values into [0, 1] for display.
    grid = utils.make_grid(fake, padding=2, normalize=True)
    # make_grid returns channels-first (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.show()

In [None]:
# Variables to store losses (one entry per training iteration)
G_losses = []
D_losses = []

# Main training loop
num_epochs = 9000

# Take the latent size from the generator instead of repeating the literal
# 100, so the loop stays correct if netG is built with another size.
latent_size = netG.latent_vector_size
# The number of batches per epoch does not change while training.
num_batches = len(dataloader)

for epoch in range(num_epochs):
    for i, real_images in enumerate(dataloader):
        # ------------------------------------------------------------------
        # Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # ------------------------------------------------------------------
        netD.zero_grad()
        real_data = real_images.to(device)
        batch_size = real_data.size(0)  # the last batch of an epoch may be smaller
        labels = torch.full((batch_size,), 1, dtype=torch.float, device=device)

        # Forward pass real batch through D (target label 1)
        output = netD(real_data).view(-1)
        errD_real = criterion(output, labels)
        errD_real.backward()

        # Generate fake image batch with G
        noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
        fake_images = netG(noise)
        labels.fill_(0)

        # Classify all fake batch with D (target label 0). detach() stops this
        # backward pass from propagating into the generator.
        output = netD(fake_images.detach()).view(-1)
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        # Gradients of the real and the fake pass have accumulated in D.
        errD = errD_real + errD_fake
        optimizerD.step()

        # ------------------------------------------------------------------
        # Update G: maximize log(D(G(z)))
        # ------------------------------------------------------------------
        netG.zero_grad()
        labels.fill_(1)  # Fake labels are real for generator cost
        # Score the same fakes again with the discriminator that was just updated.
        output = netD(fake_images).view(-1)
        errG = criterion(output, labels)
        errG.backward()
        optimizerG.step()

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        if i % 50 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{num_batches}] Loss_D: {errD.item()} Loss_G: {errG.item()}')

        if i % 250 == 0:
            show_generated_img()

        # Save a sample grid every 500 iterations and at the very last iteration.
        if (i % 500 == 0) or ((epoch == num_epochs - 1) and (i == num_batches - 1)):
            with torch.no_grad():
                # Re-run G on this iteration's noise, now with the updated weights.
                fake = netG(noise).cpu()
            utils.save_image(fake, f'output_images_18/fake_samples_epoch_{epoch}_iter_{i}.png', normalize=True)


In [None]:
# Plot the per-iteration generator and discriminator losses on one set of axes
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="Generator Loss")
ax.plot(D_losses, label="Discriminator Loss")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()