In [None]:
import os
import copy
import random
import pandas as pd
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.optim as optim
from torchvision.io import read_image, ImageReadMode
import torchvision.transforms.functional as TF
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder
from os import listdir
from pathlib import Path

# Function to load images and masks from directory
def load_dir(dir, size=(128, 128), subfolders=('infected', 'notinfected')):
    """Load grayscale image/mask pairs from the class subfolders of ``dir``.

    Every ``<name>.png`` inside ``dir/<subfolder>`` is paired with
    ``<name>_mask.png`` from the same folder. Both are read as single-channel
    images, resized to ``size`` and converted to ``float32`` tensors.

    Args:
        dir: Root directory holding one subfolder per class.
        size: Target ``(height, width)`` passed to ``T.Resize``.
        subfolders: Names of the class subfolders to scan; each image is
            labelled with the name of the subfolder it was found in.

    Returns:
        Three parallel lists: image tensors, mask tensors and string labels.

    Raises:
        Propagates the errors raised by ``listdir`` / ``read_image``, e.g. when
        a subfolder or an image's ``_mask.png`` companion is missing.
    """
    images = []
    masks = []
    labels = []
    resize_image = T.Resize(size)
    # Masks carry label values; nearest-neighbour resizing keeps them from
    # being interpolated into new in-between values at region borders.
    resize_mask = T.Resize(size, interpolation=T.InterpolationMode.NEAREST)
    for subfolder in subfolders:
        subdir_path = os.path.join(dir, subfolder)
        # listdir order is arbitrary; sort so sample order (and therefore the
        # dataset indices) is reproducible across runs and platforms.
        for file in sorted(listdir(subdir_path)):
            # Masks live next to their images and also end in ".png". Without
            # this guard every mask would be loaded as an image and its
            # non-existent "<name>_mask_mask.png" companion looked up.
            if not file.endswith(".png") or file.endswith("_mask.png"):
                continue

            # Swap only the trailing suffix (str.replace would also rewrite any
            # ".png" occurring earlier in the file name).
            stem = file[: -len(".png")]
            image_path = os.path.join(subdir_path, file)
            mask_path = os.path.join(subdir_path, stem + "_mask.png")

            image = read_image(image_path, ImageReadMode.GRAY)
            image = TF.convert_image_dtype(resize_image(image), dtype=torch.float32)

            mask = read_image(mask_path, ImageReadMode.GRAY)
            mask = TF.convert_image_dtype(resize_mask(mask), dtype=torch.float32)

            images.append(image)
            masks.append(mask)
            labels.append(subfolder)

    return images, masks, labels

# Custom Dataset class
class PCauseDataset(Dataset):
    """In-memory dataset of ``(image, mask, label)`` triples built by ``load_dir``.

    Args:
        path: Root directory handed to ``load_dir``.
        transform: Optional callable applied to both the image and its mask.
            Random transforms are replayed with identical random draws for the
            mask, so the pair stays spatially aligned.
    """

    def __init__(self, path, transform=None):
        self.images, self.masks, labels = load_dir(path)
        self.transform = transform

        # Encode the string folder labels as integer class ids.
        encLabel = LabelEncoder().fit_transform(labels)
        self.labels = torch.from_numpy(encLabel)

    def __getitem__(self, index):
        """Return the ``(image, mask, label)`` triple at ``index``."""
        image = self.images[index]
        mask = self.masks[index]
        label = self.labels[index]

        if self.transform is not None:
            # Two independent calls of a random transform would draw different
            # flip/rotation parameters for the image and the mask, which would
            # misalign them. Snapshot the global torch RNG and rewind it before
            # the mask so both calls see the same random draws.
            # NOTE(review): assumes the transform takes its randomness from
            # torch's default CPU generator (presumably true for the
            # torchvision transforms used here) — confirm for custom ones.
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)

        return image, mask, label

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

# Dataset roots, relative to the notebook's working directory.
train_path = "grayscale_dataset/train"
test_path = "grayscale_dataset/test"

# Training-time augmentation: random horizontal/vertical flips, a random
# rotation (degrees=90) and brightness/contrast/saturation/hue jitter.
# NOTE(review): PCauseDataset applies this same pipeline to the masks, so the
# ColorJitter step also perturbs mask intensities; saturation/hue presumably
# have no effect on the single-channel (ImageReadMode.GRAY) tensors that
# load_dir returns — confirm, and consider an image-only photometric step.
transform_augmentation = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomRotation(90),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
)

# Only the training split is augmented; the test split is used as loaded.
train_dataset = PCauseDataset(train_path, transform=transform_augmentation)
test_dataset = PCauseDataset(test_path)

# Batches of 16 loaded by two worker processes; only training order is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2,
)

# Define the model architecture
class Generator(nn.Module):
    """Label-conditioned generator that upsamples a 4x4 map to 128x128.

    ``inBlock`` turns a spatially 1x1 noise tensor into a 4x4 feature map. The
    class label is embedded, projected to one 4x4 plane and concatenated as an
    extra channel. ``net`` then doubles the resolution five times
    (4 -> 128) and ends with ``Tanh``.

    Args:
        in_channels: Channels of the input noise tensor.
        mid_channels: Width after ``inBlock``; it is floor-divided by 2, 4, 8
            and 16 for the upsampling stages, so a multiple of 16 is expected.
        out_channels: Channels of the generated image.
        num_classes: Number of label ids the embedding table accepts
            (``forward`` expects ids in ``[0, num_classes)``).
        embed_dim: Length of each label embedding vector.
    """

    def __init__(self, in_channels, mid_channels, out_channels, num_classes=3, embed_dim=50):
        super().__init__()

        self.inBlock = GenBlock(in_channels, mid_channels, kernel_size=4, stride=1, padding=0)  # 4x4
        self.net = nn.Sequential(
            # "+ 1" accounts for the label plane concatenated in forward().
            GenBlock(mid_channels + 1, mid_channels // 2, kernel_size=4, stride=2, padding=1),  # 8x8
            GenBlock(mid_channels // 2, mid_channels // 4, kernel_size=4, stride=2, padding=1),  # 16x16
            GenBlock(mid_channels // 4, mid_channels // 8, kernel_size=4, stride=2, padding=1),  # 32x32
            GenBlock(mid_channels // 8, mid_channels // 16, kernel_size=4, stride=2, padding=1),  # 64x64
            nn.ConvTranspose2d(mid_channels // 16, out_channels, kernel_size=4, stride=2, padding=1),  # 128x128
            # NOTE(review): Tanh outputs lie in [-1, 1]; confirm the real images
            # shown to the discriminator are scaled to the same range.
            nn.Tanh())

        self.embedding = nn.Embedding(num_classes, embed_dim)
        # Projects each label embedding onto a single 4x4 plane that matches
        # the spatial size of inBlock's output.
        self.linear = nn.Linear(embed_dim, 4 * 4)

    def forward(self, x, label):
        """Generate images from noise ``x`` conditioned on integer ``label``.

        ``x`` must be spatially 1x1 so that ``inBlock`` yields a 4x4 map that
        can be concatenated with the 4x4 label plane. ``label`` holds one
        class id per sample.
        """
        y = self.embedding(label)
        y = self.linear(y).view(-1, 1, 4, 4)
        x = self.inBlock(x)
        x = torch.cat((x, y), dim=1)
        return self.net(x)

class GenBlock(nn.Module):
    """Upsampling unit: ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2) -> Dropout(0.5).

    The transposed convolution is created without a bias term; a per-channel
    bias would be removed by the BatchNorm's mean subtraction anyway.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        normalized = self.norm(self.conv(x))
        activated = self.relu(normalized)
        return self.dropout(activated)

class Discriminator(nn.Module):
    """Two-headed discriminator: a real/fake score plus class probabilities.

    Five stride-2 convolution stages (kernel 4, padding 1) halve the spatial
    size each time. The linear heads are sized for a 4x4 final feature map,
    i.e. a 128x128 input.

    Args:
        in_channels: Channels of the input image.
        mid_channels: Width of the first stage; doubled at each later stage
            (up to ``mid_channels * 16``).
        num_classes: Size of the classification head.
            NOTE(review): presumably meant to match the generator's label
            embedding size — confirm.
    """

    def __init__(self, in_channels, mid_channels, num_classes=3):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            DisBlock(mid_channels, mid_channels * 2, 4, 2, 1),
            DisBlock(mid_channels * 2, mid_channels * 4, 4, 2, 1),
            DisBlock(mid_channels * 4, mid_channels * 8, 4, 2, 1),
            DisBlock(mid_channels * 8, mid_channels * 16, 4, 2, 1)
        )

        # Flattened size of the (mid_channels * 16, 4, 4) feature map.
        features = mid_channels * 16 * 4 * 4
        self.Slinear = nn.Linear(features, 1)
        self.Clinear = nn.Linear(features, num_classes)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Score a batch of images.

        Returns:
            ``(y1, y2)``: ``y1`` has shape ``(batch,)`` and holds sigmoid
            real/fake probabilities. ``y2`` has shape ``(batch, num_classes)``
            and holds softmax class probabilities.
            NOTE(review): these are probabilities, not logits; if the training
            loop uses ``nn.CrossEntropyLoss`` (which applies log-softmax
            itself), pass the pre-softmax scores instead — confirm.
        """
        x = self.net(x)
        x = torch.flatten(x, start_dim=1)
        y1 = self.sigmoid(self.Slinear(x)).view(-1)
        y2 = self.softmax(self.Clinear(x))
        return y1, y2

class DisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(DisBlock, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)

