In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import numpy as np
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import cv2
from torchvision import models, transforms
from sklearn.metrics.pairwise import cosine_similarity
from torchvision.models import vgg19
import matplotlib.pyplot as plt


In [None]:
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------

# Compute device: use the GPU when PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Paths (placeholders — point them at real directories before running).
TRAIN_DIR = "Train directory"
VAL_DIR = "Validation directory"
TEST_DIR = "Test directory"  # Change to your test images directory
SAVE_FOLDER = "save folder directory"  # Change to your save folder
GRAPH_SAVE_FOLDER = "Graph directory"
CHECKPOINT_DISC = "iscp.pth.tar"
CHECKPOINT_GEN = "enp.pth.tar"

# Data pipeline.
IMAGE_SIZE = 256
CHANNELS_IMG = 3
BATCH_SIZE = 3
NUM_WORKERS = 2

# Optimisation.
LEARNING_RATE = 2e-4
NUM_EPOCHS = 250
L1_LAMBDA = 100  # weight of the L1 reconstruction term in the generator loss
LAMBDA_GP = 10  # NOTE(review): not referenced by the visible training code

# Checkpoint behaviour.
LOAD_MODEL = False
SAVE_MODEL = True


In [None]:
# Data Augmentation
#
# Anything that changes geometry must be applied identically to the input and
# its target, otherwise the pair is no longer pixel-aligned. Such transforms
# therefore live in ``both_transform``: ``additional_targets`` makes
# albumentations reuse the parameters sampled for ``image`` on ``image0``.
# (Previously the random horizontal flip was in the input-only pipeline, so
# half of the training pairs had a flipped input but an unflipped target.)
both_transform = A.Compose(
    [
        A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Input only: photometric jitter (does not move pixels), then scale
# [0, 255] -> [-1, 1] and convert the HWC array to a CHW tensor.
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# Target only: no augmentation, same normalisation and tensor conversion.
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)


In [None]:
# Dataset classes
class MapDataset(Dataset):
    """Paired-image dataset: each file holds the input (left) and target (right) side by side.

    Both halves are resized together with ``both_transform``. The input then goes
    through ``transform_only_input`` and the target through ``transform_only_mask``.

    ``__getitem__`` returns ``(input_image, target_image)`` as produced by those
    transforms. They are normalised with mean = std = 0.5, i.e. values in [-1, 1].

    Args:
        root_dir: directory containing the paired image files.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Sorted so the index -> file mapping is reproducible (os.listdir order is
        # arbitrary). Sub-directories (e.g. .ipynb_checkpoints) would crash
        # Image.open, so only regular files are kept.
        self.list_files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.list_files[index])
        # Close the file handle promptly, and force 3 channels so grayscale,
        # palette or RGBA files fit the slicing and the 3-channel Normalize.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        # Split at the horizontal midpoint instead of a hard-coded 256 columns,
        # so pairs of any width work (identical for 512-pixel-wide files).
        half_width = image.shape[1] // 2
        input_image = image[:, :half_width, :]
        target_image = image[:, half_width:, :]

        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_mask(image=target_image)["image"]

        return input_image, target_image

In [None]:
class TestDataset(Dataset):
    """Paired-image dataset for the test directory (same file format as ``MapDataset``).

    Each file holds the input (left half) and target (right half) side by side.
    The same transforms are applied as in training, including the random
    augmentations in ``both_transform`` and ``transform_only_input``.

    NOTE(review): this class is behaviourally identical to ``MapDataset``; it could
    simply be declared as ``class TestDataset(MapDataset): pass`` to avoid the copy.

    Args:
        root_dir: directory containing the paired image files.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Sorted for a reproducible index -> file mapping; skip sub-directories,
        # which Image.open cannot read.
        self.list_files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.list_files[index])
        # Close the file promptly and force 3 channels (grayscale/RGBA/palette safe).
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        # Split at the midpoint rather than a hard-coded 256 columns.
        half_width = image.shape[1] // 2
        input_image = image[:, :half_width, :]
        target_image = image[:, half_width:, :]

        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_mask(image=target_image)["image"]

        return input_image, target_image

In [None]:
class CNNBlock(nn.Module):
    """Discriminator stage: 4x4 conv (reflect padding, no bias) -> BatchNorm -> LeakyReLU(0.2).

    With ``stride=2`` the spatial size is halved. With ``stride=1`` it shrinks by
    one pixel (kernel 4, padding 1).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        # Kept as a single Sequential named ``conv`` so state_dict keys stay conv.0 / conv.1.
        self.conv = nn.Sequential(convolution, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        return self.conv(x)

In [None]:
def add_noise(tensor, std):
    """Return a new tensor: ``tensor`` plus zero-mean Gaussian noise of standard deviation ``std``.

    The noise has the same shape, dtype and device as ``tensor``; the input is not modified.
    """
    return torch.randn_like(tensor).mul(std).add(tensor)


In [None]:
class Discriminator(nn.Module):
    """Conditional patch discriminator over (input, candidate output) pairs.

    ``x`` and ``y`` are concatenated on the channel axis and mapped to a grid of
    raw logits (no sigmoid, so pair it with ``BCEWithLogitsLoss``). With the
    default ``features`` and 256x256 inputs the grid is 30x30.

    Args:
        in_channels: channels of each of ``x`` and ``y``; the first conv sees twice this.
        features: output channels of the successive stages. Every stage except the
            last uses stride 2; the last uses stride 1.
        noise_std: std of Gaussian noise added to both inputs, only while the module
            is in training mode (``self.training``); evaluation stays deterministic.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512), noise_std=0.05):
        super().__init__()
        self.noise_std = noise_std  # Standard deviation of the noise
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Choose the stride by position, not by comparing channel counts, so
            # repeated values in ``features`` don't turn earlier stages into stride 1.
            layers.append(CNNBlock(in_channels, feature, stride=1 if index == last_index else 2))
            in_channels = feature

        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        # Instance noise is a training-time regulariser only.
        if self.training:
            x = add_noise(x, self.noise_std)  # Add noise to the input image
            y = add_noise(y, self.noise_std)  # Add noise to the target image
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x


In [None]:

if __name__ == "__main__":
    # Smoke test: run the discriminator on a random (input, target) pair and
    # print the shape of its logit map, (batch, 1, out_h, out_w); with the
    # default settings and 256x256 inputs that is torch.Size([1, 1, 30, 30]).
    D = Discriminator()
    x, y = (torch.randn(1, 3, 256, 256) for _ in range(2))  # x: input, y: target
    output = D(x, y)
    print(output.shape)

In [None]:
class Block(nn.Module):
    """Generator building block: (transposed) conv -> BatchNorm -> activation [-> dropout].

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        down: if True, a stride-2 ``Conv2d`` with reflect padding (halves H and W);
            otherwise a stride-2 ``ConvTranspose2d`` (doubles H and W).
        act: ``"relu"`` for ReLU or ``"leaky"`` for LeakyReLU(0.2).
        use_dropout: if True, apply ``Dropout(0.5)`` after the activation.

    Raises:
        ValueError: if ``act`` is not ``"relu"`` or ``"leaky"``.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        # Fail at construction time; an unknown name used to leave ``self.activation``
        # unset and only surface as an AttributeError on the first forward().
        if act not in ("relu", "leaky"):
            raise ValueError(f"act must be 'relu' or 'leaky', got {act!r}")

        if down:
            self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5) if use_dropout else None

    def forward(self, x):
        x = self.activation(self.batch_norm(self.conv(x)))
        if self.use_dropout:
            x = self.dropout(x)
        return x


In [None]:
class Generator(nn.Module):
    """U-Net-style encoder/decoder generator with skip concatenations.

    Eight stride-2 stages (initial_down, down1..down6, bottleneck) halve the
    spatial size and eight stride-2 transposed convolutions (up1..up7, final_up)
    double it back. The output therefore has the input's H and W and
    ``in_channels`` channels, squashed to [-1, 1] by the final ``Tanh``. H and W
    must be divisible by 256 so that every skip concatenation lines up.

    Besides the usual U-Net skips, ``d1`` is downsampled by an extra 3x3 stride-2
    conv (``self.downsample``) and concatenated onto ``d2`` before ``down2``. That
    is why ``down2`` takes ``features * 3`` input channels.

    NOTE: submodules are registered in the order written below (``downsample``
    last). That order fixes ``parameters()`` order, and optimizer state saved by
    ``save_checkpoint`` is keyed by parameter position. Do not reorder.

    Args:
        in_channels: image channels of both the input and the output.
        features: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()
        # ---- Encoder (comments: output channels, spatial size for an H x W input) ----
        # d1: features, H/2 (no BatchNorm on the first layer)
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        # d2: features*2, H/4
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        # d3: features*4, H/8; input is cat(d2, downsample(d1)) = features*2 + features channels
        self.down2 = Block(features * 3, features * 4, down=True, act="leaky", use_dropout=False)
        # d4..d7: features*8, H/16 .. H/128
        self.down3 = Block(features * 4, features * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        # bottleneck: features*8, H/256 (1x1 for a 256x256 input); plain conv + ReLU, no BatchNorm
        self.bottleneck = nn.Sequential(nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU())

        # ---- Decoder: every stage doubles H and W. From up2 on, the input is
        # cat(previous decoder output, encoder output of the same size), hence "* 2".
        # Dropout(0.5) is used in the three innermost decoder blocks only.
        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        # Output: in_channels, H, values in [-1, 1]; input is cat(up7, d1)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
        # Extra skip path: brings d1 (H/2) down to d2's size (H/4) so the two can be concatenated.
        self.downsample = nn.Conv2d(features, features, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Encoder
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d1_downsampled = self.downsample(d1)  # Downsample d1 to match d2's spatial size
        d2_concat = torch.cat([d2, d1_downsampled], 1)
        d3 = self.down2(d2_concat)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        # Decoder: concatenate each upsampled map with the encoder map of equal size
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))

In [None]:
# Smoke test: the generator must map an image to an image of the same shape.
# Build a dedicated input here instead of reusing ``x`` from the discriminator
# cell; that variable only exists if the ``__main__``-guarded cell already ran.
model = Generator(in_channels=CHANNELS_IMG, features=64)
sample_input = torch.randn(1, CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE)
with torch.no_grad():
    preds = model(sample_input)
assert preds.shape == sample_input.shape, f"unexpected generator output shape {tuple(preds.shape)}"
print(preds.shape)

In [None]:
# Ensure the output directories exist before training writes into them:
# loss-curve images go to GRAPH_SAVE_FOLDER, generated samples to SAVE_FOLDER
# (torchvision's save_image does not create missing folders).
os.makedirs(GRAPH_SAVE_FOLDER, exist_ok=True)
os.makedirs(SAVE_FOLDER, exist_ok=True)

# Perceptual ("style") loss on frozen VGG19 features
class StyleLoss(nn.Module):
    """Mean squared error between VGG19 feature maps of two images.

    Despite the name, no Gram matrices are involved: this compares feature maps
    directly, i.e. it is a perceptual/content-style loss.

    Inputs are expected in [-1, 1] (the generator's Tanh range and the dataset
    normalisation). They are mapped to [0, 1] and ImageNet-normalised before
    being fed to VGG.
    """

    def __init__(self):
        super(StyleLoss, self).__init__()
        # features[:36] = every conv/ReLU up to relu5_4 (all 16 conv layers of VGG19),
        # excluding the final max-pool.
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
        # ``weights=VGG19_Weights.IMAGENET1K_V1``.
        self.vgg = models.vgg19(pretrained=True).features[:36].eval()
        # VGG is a fixed feature extractor. Freezing it stops backward() from
        # computing, and endlessly accumulating, gradients for its weights.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        # ImageNet statistics the pretrained weights expect (non-persistent: not in state_dict).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def _features(self, img):
        """VGG features of an image batch given in [-1, 1]."""
        img = (img + 1) / 2  # [-1, 1] -> [0, 1]
        return self.vgg((img - self.mean) / self.std)

    def forward(self, y_fake, y):
        # The previous ``torch.mean(a - b)`` was a signed mean: it could go negative
        # and the generator could drive it down without bound. Use MSE instead;
        # ``mse_loss`` also runs in float32 under autocast.
        return nn.functional.mse_loss(self._features(y_fake), self._features(y))


def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, style_loss, g_scaler, d_scaler):
    """Train ``disc`` and ``gen`` for one pass over ``loader``.

    Each batch runs one discriminator update and then one generator update.
    The discriminator update averages the BCE of real pairs labelled 1 and
    generated pairs labelled 0. The generator loss is adversarial BCE towards 1,
    plus ``L1_LAMBDA`` * L1, plus the style term. Both steps run under
    ``torch.cuda.amp.autocast`` with their own GradScaler.

    Returns:
        ``(G_losses, D_losses)``: lists of per-batch generator and discriminator losses.
    """
    progress = tqdm(loader, leave=True)
    gen_history = []
    disc_history = []

    for step, (inputs, targets) in enumerate(progress):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        # ---- Discriminator step ----
        with torch.cuda.amp.autocast():
            fakes = gen(inputs)
            real_logits = disc(inputs, targets)
            # detach(): this step must not push gradients into the generator
            fake_logits = disc(inputs, fakes.detach())
            disc_loss = (
                bce(real_logits, torch.ones_like(real_logits))
                + bce(fake_logits, torch.zeros_like(fake_logits))
            ) / 2

        opt_disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step (re-scores the fakes with the freshly updated discriminator) ----
        with torch.cuda.amp.autocast():
            fake_logits = disc(inputs, fakes)
            gen_loss = (
                bce(fake_logits, torch.ones_like(fake_logits))
                + l1_loss(fakes, targets) * L1_LAMBDA
                + style_loss(fakes, targets)
            )

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        gen_history.append(gen_loss.item())
        disc_history.append(disc_loss.item())

        # Progress bar shows mean D(real) and mean D(fake) probabilities every 10 batches.
        if step % 10 == 0:
            progress.set_postfix(
                D_real=torch.sigmoid(real_logits).mean().item(),
                D_fake=torch.sigmoid(fake_logits).mean().item(),
            )

    return gen_history, disc_history

In [None]:
def save_some_examples(gen, test_loader, epoch, folder):
    """Save generator output for the first batch of ``test_loader`` as PNG files.

    Writes ``y_gen_{epoch}.png`` (generator output) and ``input_{epoch}.png`` into
    ``folder``. When ``epoch == 1`` it also writes ``label_{epoch}.png`` (the ground
    truth). Tensors are mapped from [-1, 1] back to [0, 1] before saving.

    The generator runs in eval mode under ``no_grad``; its previous train/eval
    mode is restored afterwards, even if saving fails.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    batch = next(iter(test_loader), None)
    if batch is None:
        raise ValueError("test_loader yielded no batches; check the test directory")
    x, y = batch
    x, y = x.to(DEVICE), y.to(DEVICE)

    os.makedirs(folder, exist_ok=True)  # save_image does not create missing folders

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x) * 0.5 + 0.5  # undo the mean=0.5 / std=0.5 normalisation
            save_image(y_fake, os.path.join(folder, f"y_gen_{epoch}.png"))
            save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            # NOTE(review): main() counts epochs from 0, so this fires on the second call.
            if epoch == 1:
                save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        gen.train(was_training)

In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise ``model`` and ``optimizer`` state dicts to ``filename`` with ``torch.save``.

    The file holds a dict with keys ``"state_dict"`` (model) and ``"optimizer"``,
    which is the layout ``load_checkpoint`` reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

In [None]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Tensors are mapped onto ``DEVICE``. Afterwards every optimizer param group's
    learning rate is overwritten with ``lr``. Without this, training would silently
    continue with whatever rate was stored in the checkpoint.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects (depending on the
    installed torch's ``weights_only`` default); only load trusted checkpoints.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    for group in optimizer.param_groups:
        group["lr"] = lr

In [None]:
from IPython.display import clear_output
def main():
    """Train the generator/discriminator pair end to end with live loss plotting.

    Steps:
    1. Build models, Adam optimisers and losses; optionally restore checkpoints
       (``LOAD_MODEL``).
    2. For each of ``NUM_EPOCHS`` epochs, train one pass over ``TRAIN_DIR``.
    3. Save checkpoints on epochs 0, 3, 6, ... when ``SAVE_MODEL`` is set.
    4. Write sample images for the first test batch to ``SAVE_FOLDER``.
    5. Redraw the loss curves in the notebook and save them to
       ``GRAPH_SAVE_FOLDER``.
    """
    # Same module as the file-level ``clear_output`` import.
    from IPython.display import display

    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3, features=64).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()
    style_loss = StyleLoss().to(DEVICE)

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    train_dataset = MapDataset(root_dir=TRAIN_DIR)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    test_dataset = TestDataset(root_dir=TEST_DIR)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Output folders must exist before save_image / savefig write into them.
    os.makedirs(SAVE_FOLDER, exist_ok=True)
    os.makedirs(GRAPH_SAVE_FOLDER, exist_ok=True)

    G_losses = []
    D_losses = []

    plt.ion()  # Turn on interactive mode for live plotting
    fig, ax = plt.subplots(figsize=(10, 5))

    for epoch in range(NUM_EPOCHS):
        g_loss, d_loss = train_fn(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, bce, style_loss, g_scaler, d_scaler)
        G_losses.extend(g_loss)
        D_losses.extend(d_loss)

        if SAVE_MODEL and epoch % 3 == 0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        save_some_examples(gen, test_loader, epoch, folder=SAVE_FOLDER)

        # Live plot: clear_output() wipes the previously shown figure, so the
        # refreshed one has to be re-emitted explicitly. plt.draw() only redraws
        # the canvas and does not put the figure back into the cleared output.
        clear_output(wait=True)
        ax.clear()
        ax.set_title("Generator and Discriminator Loss During Training")
        ax.plot(G_losses, label="G")
        ax.plot(D_losses, label="D")
        ax.set_xlabel("iterations")
        ax.set_ylabel("Loss")
        ax.legend()
        display(fig)

        # Save this figure explicitly rather than whatever pyplot considers "current".
        fig.savefig(os.path.join(GRAPH_SAVE_FOLDER, f"epoch_{epoch}.png"))

    plt.ioff()  # Turn off interactive mode
    plt.show()

# Launch full training when this notebook/script is executed directly.
# Inside a Jupyter kernel ``__name__`` is "__main__", so running this cell starts main().
if __name__ == "__main__":
    main()