## Import the Libraries

In [1]:
import os
import sys
import yaml
import torch
import argparse
import unittest
import torch.nn as nn
from torchview import draw_graph

## Utility functions

In [2]:
def config():
    """Load the project settings from ``../config.yml``.

    The path is relative to the current working directory.

    Returns:
        Whatever ``yaml.safe_load`` parses from the file.
    """
    with open("../config.yml", "r") as stream:
        settings = yaml.safe_load(stream)

    return settings

## Coupled Generators

In [None]:
class CoupledGenerators(nn.Module):
    """Two image generators that share a fully connected layer and a
    convolutional trunk, then branch into independent output heads.

    Pipeline: latent vector -> ``fullyConnectedLayer`` -> reshape to
    ``(batch, constant, image_size // 4, image_size // 4)`` ->
    ``sharedConvolution`` (upsamples by 2 twice) -> ``generator1`` and
    ``generator2``. Each head ends in ``Tanh``, so outputs lie in [-1, 1].

    Args:
        latent_space: Size of the latent vector fed to the linear layer.
        constant: Channel width of the shared trunk.
        image_size: Target side length. The generated side length is
            ``4 * (image_size // 4)``, i.e. ``image_size`` when it is
            divisible by 4.
    """

    def __init__(
        self,
        latent_space: int = 100,
        constant: int = 128,
        image_size: int = 32,
    ):
        super(CoupledGenerators, self).__init__()
        self.latent_space = latent_space
        self.constant = constant
        self.image_size = image_size

        self.kernel_size = 3
        self.stride_size = 1
        self.padding_size = 1

        self.negative_slope = 0.2
        self.scale_factor = 2

        # Side length of the feature map that enters the shared trunk; the
        # trunk upsamples by 2 twice, hence the division by 4.
        self.init_size = self.image_size // 4

        # The width must equal constant * init_size**2 so that it matches the
        # reshape in ``forward``. The former unparenthesised expression
        # (constant * image_size // 4 * image_size // 4) only agreed with it
        # when image_size was divisible by 4.
        self.fullyConnectedLayer = nn.Linear(
            in_features=self.latent_space,
            out_features=self.constant * self.init_size**2,
        )

        self.sharedConvolution = nn.Sequential(
            nn.BatchNorm2d(self.constant),
            nn.Upsample(scale_factor=self.scale_factor),
            nn.Conv2d(
                in_channels=self.constant,
                out_channels=self.constant,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
            ),
            nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True),
            nn.BatchNorm2d(self.constant),
            nn.Upsample(scale_factor=self.scale_factor),
        )

        # Heads are built in the same order as before (all of generator 1,
        # then all of generator 2) so seeded initialisation is unchanged.
        self.netG1Layers = self._build_head_layers()
        self.netG2Layers = self._build_head_layers()

        self.generator1 = nn.Sequential(*self.netG1Layers)
        self.generator2 = nn.Sequential(*self.netG2Layers)

    def _build_head_layers(self):
        """Return the layer list of one generator head.

        Layout: conv (constant -> constant // 2), [BatchNorm + LeakyReLU],
        conv (constant // 2 -> 3), Tanh. The BatchNorm/LeakyReLU pair stays
        wrapped in its own ``nn.Sequential`` so parameter names
        (``generatorN.0``, ``generatorN.1.0``, ``generatorN.2``) are the same
        as before. The trailing ``Tanh`` was previously constructed but never
        appended, which left the outputs unbounded.
        """
        half_width = self.constant // 2

        return [
            nn.Conv2d(
                in_channels=self.constant,
                out_channels=half_width,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
            ),
            nn.Sequential(
                nn.BatchNorm2d(num_features=half_width),
                nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True),
            ),
            nn.Conv2d(
                in_channels=half_width,
                out_channels=3,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
            ),
            nn.Tanh(),
        ]

    def forward(self, x):
        """Generate one image per head from a batch of latent vectors.

        Args:
            x: Tensor whose last dimension is ``latent_space``; it is reshaped
                using ``x.size(0)`` as the batch dimension.

        Returns:
            Tuple ``(image1, image2)`` produced by ``generator1`` and
            ``generator2`` from the same shared features.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be a torch.Tensor".capitalize())

        x = self.fullyConnectedLayer(x)
        x = x.view(x.size(0), self.constant, self.init_size, self.init_size)

        shared = self.sharedConvolution(x)

        image1 = self.generator1(shared)
        image2 = self.generator2(shared)

        return image1, image2


if __name__ == "__main__":
    # Smoke-test the coupled generators with the project configuration.
    # The YAML file is parsed once rather than on every lookup.
    settings = config()
    latent_space = settings["netG"]["latent_space"]
    batch_size = settings["dataloader"]["batch_size"]

    netG = CoupledGenerators(
        latent_space=latent_space,
        constant=settings["netG"]["constant"],
        image_size=settings["dataloader"]["image_size"],
    )

    image1, image2 = netG(torch.randn(batch_size, latent_space))

    assert (
        image1.size() == image2.size()
    ), "Image1 and Image2 must be the same size".capitalize()

    # The first label used to be misspelled "Imgae1".
    print("Image1 size # {}".format(image1.size()))
    print("Image2 size # {}".format(image2.size()))

    # Rendering the graph is best-effort: failures are reported on stdout and
    # do not abort the cell.
    try:
        draw_graph(
            model=netG, input_data=torch.randn(batch_size, latent_space)
        ).visual_graph.render(
            filename=os.path.join("../artifacts/files/", "coupleGenerator"),
            format="png",
        )
        print("Graph saved in ../artifacts/files/coupleGenerator.png")
    except Exception as e:
        print(f"Error during graph rendering: {e}")

## Coupled Discriminators

In [None]:
class CoupledDiscriminators(nn.Module):
    """Two discriminator heads on top of one shared convolutional trunk.

    The trunk is four stride-2 convolutions (kernel 3, padding 1), each
    followed by ``LeakyReLU(0.2)``; every convolution except the first is also
    followed by ``BatchNorm2d``. The channel width starts at
    ``channels * 5 + 1`` and doubles after each convolution. Each image is
    passed through the trunk, flattened, and scored by its own linear head.

    Args:
        channels: Number of channels of the input images.
        image_size: Side length of the (square) input images.
        constant: Stored on the instance for backward compatibility. The
            linear heads no longer size themselves from it; their input width
            is derived from the trunk's real output, which equals the old
            ``constant * (image_size // 16) ** 2`` for the defaults.
    """

    def __init__(
        self,
        channels: int = 3,
        image_size: int = 32,
        constant: int = 128,
    ):
        super(CoupledDiscriminators, self).__init__()

        self.in_channels = channels
        self.out_channels = self.in_channels * 5 + 1
        self.image_size = image_size
        self.constant = constant

        self.kernel_size = 3
        self.stride_size = 2
        self.padding_size = 1

        self.layers = []

        # Track channel width and spatial size locally while stacking layers,
        # so the instance attributes keep describing the first convolution.
        in_channels = self.in_channels
        out_channels = self.out_channels
        spatial_size = self.image_size

        for index in range(4):
            self.layers += [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride_size,
                    padding=self.padding_size,
                ),
            ]

            if index != 0:
                self.layers += [nn.BatchNorm2d(num_features=out_channels)]

            self.layers += [nn.LeakyReLU(negative_slope=0.2, inplace=True)]

            # Standard Conv2d output-size formula (dilation 1).
            spatial_size = (
                spatial_size + 2 * self.padding_size - self.kernel_size
            ) // self.stride_size + 1

            in_channels = out_channels
            out_channels = in_channels * 2

        self.sharedConvolution = nn.Sequential(*self.layers)

        # Flattened size of the trunk output for one image.
        self.num_features = in_channels * spatial_size**2

        self.discriminator1 = nn.Linear(
            in_features=self.num_features,
            out_features=1,
        )

        self.discriminator2 = nn.Linear(
            in_features=self.num_features,
            out_features=1,
        )

    def _shared_features(self, image: torch.Tensor) -> torch.Tensor:
        """Run one image batch through the shared trunk and flatten it."""
        features = self.sharedConvolution(image)

        return features.view(features.size(0), -1)

    def forward(self, image1: torch.Tensor, image2: torch.Tensor):
        """Score both image batches.

        Args:
            image1: Batch scored by ``discriminator1``.
            image2: Batch scored by ``discriminator2``.

        Returns:
            Tuple ``(validity1, validity2)``, each of shape ``(batch, 1)``.

        Raises:
            ValueError: If either input is not a ``torch.Tensor``.
        """
        if not (
            isinstance(image1, torch.Tensor) and isinstance(image2, torch.Tensor)
        ):
            raise ValueError("Both inputs must be PyTorch tensors".capitalize())

        # Each head scores its own image. Previously both heads were fed the
        # features of ``image1`` and ``image2`` was silently ignored.
        validity1 = self.discriminator1(self._shared_features(image1))
        validity2 = self.discriminator2(self._shared_features(image2))

        return validity1, validity2


if __name__ == "__main__":
    # Smoke-test the coupled discriminators with the project configuration.
    # The YAML file is parsed once rather than on every lookup.
    settings = config()
    batch_size = settings["dataloader"]["batch_size"]
    image_size = settings["dataloader"]["image_size"]
    constant = settings["netG"]["constant"]
    channels = 3

    # Build the network from the same values used for the random inputs below.
    # It used to be constructed with defaults, which only matched the inputs
    # when the configured image size happened to equal the default.
    netD = CoupledDiscriminators(
        channels=channels,
        image_size=image_size,
        constant=constant,
    )
    validity1, validity2 = netD(
        image1=torch.randn((batch_size, channels, image_size, image_size)),
        image2=torch.randn((batch_size, channels, image_size, image_size)),
    )

    assert (
        validity1.size() == validity2.size()
    ), "Validity1 and Validity2 must be the same size".capitalize()

    print("Validity1 size # {}".format(validity1.size()))
    print("Validity2 size # {}".format(validity2.size()))

    # Rendering the graph is best-effort: failures are reported on stdout and
    # do not abort the cell.
    try:
        input_data1 = torch.randn((batch_size, channels, image_size, image_size))
        input_data2 = torch.randn((batch_size, channels, image_size, image_size))

        # Written next to the generator graph; "../" matches the location
        # used for "../config.yml" and the generator's artifact.
        draw_graph(
            model=netD, input_data=(input_data1, input_data2)
        ).visual_graph.render(
            filename=os.path.join("../artifacts/files/", "coupleDiscriminator"),
            format="png",
        )
        print("Graph saved in ../artifacts/files/coupleDiscriminator.png")
    except Exception as e:
        print(f"Error during graph rendering: {e}")


## Adversarial Loss

In [None]:
class AdversarialLoss(nn.Module):
    """Mean-squared-error loss between predictions and targets.

    Args:
        name: Label stored on the instance as ``name``.
        reduction: Reduction mode forwarded to ``nn.MSELoss``.
    """

    def __init__(
        self, name: str = "Adversarial Loss".capitalize(), reduction: str = "mean"
    ):
        super().__init__()

        self.name, self.reduction = name, reduction
        self.MSEloss = nn.MSELoss(reduction=self.reduction)

    def forward(self, pred: torch.Tensor, actual: torch.Tensor):
        """Return the MSE loss of ``pred`` against ``actual``.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        both_are_tensors = isinstance(pred, torch.Tensor) and isinstance(
            actual, torch.Tensor
        )

        if not both_are_tensors:
            raise ValueError(
                "Both 'pred' and 'actual' should be torch.Tensor.".capitalize()
            )

        return self.MSEloss(pred, actual)
            
            
if __name__ == "__main__":
    # Demo: evaluate the loss on a hand-written five-element example and
    # print it with four decimals.
    loss = AdversarialLoss()

    actual = torch.Tensor([1.0, 0.0, 1.0, 1.0, 0.0])
    predicted = torch.Tensor([1.0, 0.0, 1.0, 1.0, 1.0])

    loss_value = loss(predicted, actual)
    print(f"Adversarial Loss: {loss_value:.4f}")


## Unit Tests

In [None]:
class UnitTest(unittest.TestCase):
    """Shape checks for the coupled generators and discriminators."""

    def setUp(self):
        """Build both networks from the project configuration."""
        # Parse the YAML file once per test instead of once per value.
        settings = config()

        self.image_size = settings["dataloader"]["image_size"]
        self.batch_size = settings["dataloader"]["batch_size"]

        self.latent_space = settings["netG"]["latent_space"]
        self.constant = settings["netG"]["constant"]

        self.netG = CoupledGenerators(
            latent_space=self.latent_space,
            constant=self.constant,
            image_size=self.image_size,
        )

        self.netD = CoupledDiscriminators(
            channels=3,
            constant=self.constant,
            image_size=self.image_size,
        )

    def test_coupleGenerator(self):
        """Both generated batches have the configured image shape."""
        self.Z = torch.randn(self.batch_size, self.latent_space)

        image1, image2 = self.netG(self.Z)

        self.assertEqual(
            image1.size(), image2.size(), "Image1 and Image2 must be the same size"
        )
        # Comparing the outputs only to each other cannot catch a shape that
        # is wrong in both, so also pin the absolute shape.
        self.assertEqual(
            tuple(image1.size()),
            (self.batch_size, 3, self.image_size, self.image_size),
            "Generated images must be (batch, 3, image_size, image_size)",
        )

    def test_coupledDiscriminator(self):
        """Both heads return one score per generated image."""
        self.Z = torch.randn(self.batch_size, self.latent_space)

        image1, image2 = self.netG(self.Z)

        validity1, validity2 = self.netD(image1=image1, image2=image2)

        self.assertEqual(
            validity1.size(),
            validity2.size(),
            "Validity1 and Validity2 must be the same size",
        )
        self.assertEqual(
            tuple(validity1.size()),
            (self.batch_size, 1),
            "Each discriminator must return one score per image",
        )


if __name__ == "__main__":
    # NOTE(review): this file looks like a notebook export. By default
    # unittest.main() parses sys.argv (which holds the kernel's own arguments
    # in a notebook) and calls sys.exit() when the run finishes. Passing an
    # explicit argv and exit=False lets the tests run from a cell without
    # either problem; the single argv entry is only the program name.
    unittest.main(argv=["first-arg-is-ignored"], exit=False)