IMPORTING THE LIBRARIES

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import sys
import matplotlib.pyplot as plt
import numpy as np

DATASET CLASS INITIALISATION

In [2]:
class SatelliteDataset(Dataset):
    """Paired grayscale/colour image dataset read from ``root_dir``.

    ``root_dir`` must contain two sub-directories, ``gray`` and ``color``.
    Only file names ending in ``.png``/``.jpg``/``.jpeg`` (case-insensitive)
    that exist in BOTH directories are used; each item is the pair of images
    loaded from the two files sharing that name.

    Args:
        root_dir: Directory holding the ``gray`` and ``color`` folders.
        transform: Optional callable applied to both PIL images. When omitted,
            images are resized to 256x256, converted to tensors and normalised
            per channel with mean 0.5 / std 0.5.
    """

    # File-name suffixes (compared in lower case) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.gray_dir = os.path.join(root_dir, 'gray')
        self.color_dir = os.path.join(root_dir, 'color')

        # Sort the intersection: a set of strings has no stable iteration
        # order between interpreter runs, which would make the index -> file
        # mapping (and therefore any seeded shuffle/split) unreproducible.
        self.image_names = sorted(
            self._list_images(self.gray_dir) & self._list_images(self.color_dir)
        )

        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    @classmethod
    def _list_images(cls, directory):
        """Return the set of image file names found directly in ``directory``."""
        return {
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        }

    def __len__(self):
        """Number of image pairs available."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(gray_tensor, color_tensor)`` for the pair at ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.

        A pair that cannot be read or transformed is reported on stdout and
        replaced by two all-zero ``(3, 256, 256)`` tensors (best-effort
        loading, so one corrupt file does not abort an epoch).
        """
        # Resolve the name outside the try block: an out-of-range index must
        # surface as IndexError, and the error message below needs the name.
        image_name = self.image_names[idx]
        gray_path = os.path.join(self.gray_dir, image_name)
        color_path = os.path.join(self.color_dir, image_name)

        try:
            # Context managers close the underlying files deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(gray_path) as gray_file:
                gray_image = gray_file.convert('RGB')
            with Image.open(color_path) as color_file:
                color_image = color_file.convert('RGB')

            gray_tensor = self.transform(gray_image)
            color_tensor = self.transform(color_image)

            return gray_tensor, color_tensor
        except Exception as e:
            print(f"Error loading image {image_name}: {e}")
            dummy = torch.zeros(3, 256, 256)
            # Separate tensors so input and target never alias each other.
            return dummy, dummy.clone()

PRINT PROGRESS FUNCTION

In [3]:
def print_progress(current, total, prefix='', suffix='', decimals=1, length=50, fill='█', print_end="\r"):
    """Draw a single-line text progress bar on stdout.

    Args:
        current: Number of completed steps.
        total: Total number of steps. ``0`` is treated as "nothing to do" and
            rendered as a complete bar instead of raising ZeroDivisionError.
        prefix: Text printed before the bar.
        suffix: Text printed after the percentage.
        decimals: Number of decimals shown in the percentage.
        length: Width of the bar in characters.
        fill: Character used for the completed part of the bar.
        print_end: Line terminator; the default carriage return makes the
            next call overwrite the same line.
    """
    if total:
        fraction = current / float(total)
        filled_length = int(length * current // total)
    else:
        # Empty workload: show a full bar rather than dividing by zero.
        fraction = 1.0
        filled_length = length

    # Keep the bar exactly `length` characters wide even if `current`
    # overshoots `total` or is negative.
    filled_length = max(0, min(length, filled_length))

    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bar = fill * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end=print_end)
    sys.stdout.flush()

    # Move to a fresh line once the final step has been reported.
    if current == total:
        print()

DOWNSAMPLE CLASS INITIALISATION

In [4]:
class DownSample(nn.Module):
    """Encoder stage: a bias-free 4x4 convolution with stride 2 and padding 1,
    followed by LeakyReLU with negative slope 0.2.

    Args:
        Input_Channels: Channel count of the incoming feature map.
        Output_Channels: Channel count produced by the convolution.
    """

    def __init__(self, Input_Channels, Output_Channels):
        super().__init__()
        strided_conv = nn.Conv2d(
            in_channels=Input_Channels,
            out_channels=Output_Channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        activation = nn.LeakyReLU(negative_slope=0.2)
        # Attribute name and layer order are kept so state_dict keys match.
        self.model = nn.Sequential(strided_conv, activation)

    def forward(self, x):
        """Apply convolution + activation to ``x`` and return the result."""
        features = self.model(x)
        return features

UPSAMPLE CLASS INITIALISATION

In [5]:
class Upsample(nn.Module):
    """Decoder stage: a bias-free 4x4 transposed convolution (stride 2,
    padding 1) followed by ReLU, whose output is concatenated with a skip
    tensor along the channel dimension.

    Args:
        Input_Channels: Channel count of the incoming feature map.
        Output_Channels: Channel count produced by the transposed convolution
            (the skip tensor's channels are added on top by ``forward``).
    """

    def __init__(self, Input_Channels, Output_Channels):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(
                in_channels=Input_Channels,
                out_channels=Output_Channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.ReLU(inplace=True),
        ]
        # Attribute name and layer order are kept so state_dict keys match.
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        """Upsample ``x`` and append ``skip_input`` on the channel axis."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), dim=1)

CREATION OF GENERATOR

In [6]:
class Generator(nn.Module):
    """U-Net style encoder/decoder generator with skip connections.

    Eight ``DownSample`` stages each halve the spatial size; seven
    ``Upsample`` stages double it again and concatenate the matching encoder
    output, and a final upsample + convolution + Tanh restores the input
    resolution. Because of the eight halvings and the skip concatenations,
    the input height and width must be multiples of 256.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image. (Previously this
            argument was accepted but ignored and 3 was always used.)
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(Generator, self).__init__()

        # Downsampling path
        self.down1 = DownSample(in_channels, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)
        self.down5 = DownSample(512, 512)
        self.down6 = DownSample(512, 512)
        self.down7 = DownSample(512, 512)
        self.down8 = DownSample(512, 512)

        # Upsampling path. From up2 onwards the input channel count is the
        # previous stage's output plus the concatenated skip tensor.
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 512)
        self.up4 = Upsample(1024, 512)
        self.up5 = Upsample(1024, 256)
        self.up6 = Upsample(512, 128)
        self.up7 = Upsample(256, 64)

        # up7 yields 64 + 64 (skip from down1) = 128 channels.
        # The one-pixel top/left zero pad combined with the 4x4, padding-1
        # convolution keeps the spatial size unchanged after the 2x upsample.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh()  # outputs in [-1, 1]
        )

    def forward(self, x):
        """Map an input batch to a generated batch of the same spatial size."""
        # Downsampling
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Upsampling with skip connections (mirror-image encoder outputs)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)

CREATION OF DISCRIMINATOR

In [7]:
class Discriminator(nn.Module):
    """Convolutional patch discriminator for a pair of images.

    ``forward`` concatenates the two images on the channel axis and returns a
    single-channel map of raw logits (no sigmoid is applied here).

    Args:
        in_channels: Channels of EACH of the two images; the first
            convolution therefore takes ``2 * in_channels`` channels.
            (Previously this argument was ignored and 6 was hard-coded.)
            Both images are assumed to have the same channel count.
    """

    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # Four stride-2 stages, each halving the spatial size.
            nn.Conv2d(in_channels * 2, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # One-pixel top/left pad + 4x4 conv with padding 1 keeps the size.
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        """Return the logit map for the pair ``(img_A, img_B)``."""
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

PRINT PROGRESS FUNCTION

In [8]:
def print_progress(current, total, prefix='', suffix='', decimals=1, length=50, fill='█', print_end="\r"):
    """Draw a single-line text progress bar on stdout.

    Note: this cell redefines a function of the same name from an earlier
    cell; when the notebook runs top to bottom, this is the definition in
    effect afterwards.

    Args:
        current: Number of completed steps.
        total: Total number of steps. ``0`` is treated as "nothing to do" and
            rendered as a complete bar instead of raising ZeroDivisionError.
        prefix: Text printed before the bar.
        suffix: Text printed after the percentage.
        decimals: Number of decimals shown in the percentage.
        length: Width of the bar in characters.
        fill: Character used for the completed part of the bar.
        print_end: Line terminator; the default carriage return makes the
            next call overwrite the same line.
    """
    if total:
        fraction = current / float(total)
        filled_length = int(length * current // total)
    else:
        # Empty workload: show a full bar rather than dividing by zero.
        fraction = 1.0
        filled_length = length

    # Keep the bar exactly `length` characters wide even if `current`
    # overshoots `total` or is negative.
    filled_length = max(0, min(length, filled_length))

    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bar = fill * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end=print_end)
    sys.stdout.flush()

    # Move to a fresh line once the final step has been reported.
    if current == total:
        print()

DISPLAY OF THE IMAGES

In [9]:
def save_and_display_images(inputs, generated_images, targets, epoch, output_dir='output', max_images=4):
    """Save and show a grid comparing input, generated and target images.

    Each row shows one sample from the batch in three columns:
    input (grayscale), generated, target.

    Args:
        inputs: Batch of input images, expected in the [-1, 1] range.
        generated_images: Batch of generator outputs, same range.
        targets: Batch of ground-truth images, same range.
        epoch: Epoch label used in the subplot titles and the file name.
        output_dir: Directory the figure is written to (created if missing).
        max_images: Maximum number of samples (rows) taken from the batch.
    """
    os.makedirs(output_dir, exist_ok=True)

    def _to_unit_range(batch):
        # Detach from the autograd graph, move to CPU and map [-1, 1] to
        # [0, 1]. The clamp guards the float -> 8-bit conversion done by
        # ToPILImage against values slightly outside the valid range.
        return ((batch.detach().cpu() + 1) / 2.0).clamp(0.0, 1.0)

    columns = (
        ('Input (Grayscale)', _to_unit_range(inputs)),
        ('Generated', _to_unit_range(generated_images)),
        ('Target', _to_unit_range(targets)),
    )

    num_images = min(max_images, inputs.size(0))
    if num_images <= 0:
        # Nothing to draw (empty batch); avoid creating a zero-height figure.
        print(f"No images to display for epoch {epoch}")
        return

    to_pil = transforms.ToPILImage()
    fig = plt.figure(figsize=(15, 5 * num_images))

    for row in range(num_images):
        for col, (label, batch) in enumerate(columns):
            ax = fig.add_subplot(num_images, 3, 3 * row + col + 1)
            ax.set_title(f'Epoch {epoch}: {label}')
            ax.imshow(to_pil(batch[row]))
            ax.axis('off')

    fig.tight_layout()

    # A forward slash keeps the default path and message identical on every OS.
    output_path = f'{output_dir}/images_epoch_{epoch}.png'
    fig.savefig(output_path)
    print(f"Images for epoch {epoch} saved to {output_path}")

    # Show briefly without blocking training, then release only this figure
    # so figures opened elsewhere are left untouched.
    plt.show(block=False)
    plt.pause(2)
    plt.close(fig)

HYPERPARAMETERS

In [None]:
# Weight applied to the L1 reconstruction term in the generator loss.
L1_lambda = 10
# Number of training epochs, consumed by the training loop.
NUM_EPOCHS = 150  
# Adam learning rate, shared by the generator and discriminator optimizers.
lr = 0.0001
# Adam beta coefficients, passed to both optimizers as betas=(beta1, beta2).
beta1 = 0.5
beta2 = 0.999
# Mini-batch size handed to the training DataLoader.
batch_size = 4
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

    

DATASET AND DATALOADER

In [None]:
# Dataset location. Set the SAR_DATASET_DIR environment variable to point at
# your own copy; the machine-specific path below is only the fallback default.
dataset_path = os.environ.get(
    "SAR_DATASET_DIR",
    r"C:\Users\hp\OneDrive\Desktop\SAR_Dataset\Patching\Patching",
)
dataset = SatelliteDataset(dataset_path)

# With drop_last=True a dataset smaller than one batch yields zero batches,
# which would only surface later as an obscure error; fail early and clearly.
if len(dataset) < batch_size:
    raise ValueError(
        f"Found {len(dataset)} paired images in '{dataset_path}', "
        f"but at least batch_size={batch_size} are required."
    )

dataloader_train = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory=(device.type == "cuda"),
    # Drop the final short batch so every step sees exactly batch_size samples.
    drop_last=True,
)
total_batches = len(dataloader_train)
print(f"Total batches per epoch: {total_batches}")


INITIALISING GENERATOR AND DISCRIMINATOR

In [12]:
# Build both networks with their default arguments and move them to the
# selected device. The construction order is left as-is because it determines
# the order in which randomly initialised weights are drawn.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

INITIALIZATION OF LOSS CALCULATION

In [13]:
# Adversarial criterion: binary cross-entropy computed on raw logits (the
# discriminator ends in a plain convolution, with no sigmoid of its own).
adversarial_loss = nn.BCEWithLogitsLoss()
# Pixel-wise reconstruction criterion between generated and target images.
l1_loss = nn.L1Loss()

INITIALISING THE OPTIMIZER

In [14]:
# One Adam optimizer per network, both using the same learning rate and betas.
discriminator_opt = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
generator_opt = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
# Per-epoch average losses; the training cell resets these lists before it runs.
disc_losses = []
gen_losses = []

TRAINING LOOP

In [None]:
# Per-epoch average losses (reset here so re-running this cell starts clean).
disc_losses = []
gen_losses = []

checkpoint_folder = "checkpoints_10"
os.makedirs(checkpoint_folder, exist_ok=True)

# A full checkpoint is written every CHECKPOINT_EVERY epochs.
CHECKPOINT_EVERY = 10

if total_batches == 0:
    # Without batches the averages below would divide by zero.
    raise RuntimeError("The training dataloader produced no batches; check the dataset path and batch size.")

# Training loop: exactly NUM_EPOCHS epochs (the previous range(NUM_EPOCHS + 1)
# ran one extra epoch and reported it as "Epoch 151/150").
for epoch in range(NUM_EPOCHS):
    # 1-based epoch number used consistently for logs, images and checkpoints.
    epoch_number = epoch + 1

    # Reset epoch metrics
    total_disc_loss = 0
    total_gen_loss = 0

    print(f"\nEpoch {epoch_number}/{NUM_EPOCHS}")

    # Iterate through batches
    for batch_idx, (inputs, targets) in enumerate(dataloader_train):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # ---- Discriminator step ----
        discriminator_opt.zero_grad()

        # Real pairs should be classified as 1
        real_output = discriminator(inputs, targets)
        real_labels = torch.ones_like(real_output)
        real_loss = adversarial_loss(real_output, real_labels)

        # Fake pairs should be classified as 0. detach() keeps this backward
        # pass from propagating into the generator.
        generated_image = generator(inputs)
        fake_output = discriminator(inputs, generated_image.detach())
        fake_labels = torch.zeros_like(fake_output)
        fake_loss = adversarial_loss(fake_output, fake_labels)

        disc_loss = (real_loss + fake_loss) / 2
        disc_loss.backward()
        discriminator_opt.step()

        # ---- Generator step ----
        generator_opt.zero_grad()
        # Re-score the fakes with the freshly updated discriminator; the
        # generator wants them classified as real (label 1).
        fake_output = discriminator(inputs, generated_image)
        gen_labels = torch.ones_like(fake_output)

        # Adversarial loss
        adv_loss = adversarial_loss(fake_output, gen_labels)

        # L1 loss
        l1_reconstruction_loss = l1_loss(generated_image, targets)

        # Total generator loss
        gen_loss = adv_loss + L1_lambda * l1_reconstruction_loss
        # This also accumulates gradients on the discriminator; they are
        # discarded by discriminator_opt.zero_grad() at the next iteration.
        gen_loss.backward()
        generator_opt.step()

        # Accumulate losses
        total_disc_loss += disc_loss.item()
        total_gen_loss += gen_loss.item()

        # Update progress bar with the running averages
        print_progress(
            batch_idx + 1,
            total_batches,
            prefix='Training:',
            suffix=f'Disc Loss: {total_disc_loss/(batch_idx+1):.4f} Gen Loss: {total_gen_loss/(batch_idx+1):.4f}'
        )

    # Epoch averages
    avg_disc_loss = total_disc_loss / total_batches
    avg_gen_loss = total_gen_loss / total_batches

    # Append losses to tracking lists
    disc_losses.append(avg_disc_loss)
    gen_losses.append(avg_gen_loss)

    print(f"\nEpoch {epoch_number} Summary:")
    print(f"Average Discriminator Loss: {avg_disc_loss:.4f}")
    print(f"Average Generator Loss: {avg_gen_loss:.4f}")

    # Save and display images from the last batch of the epoch
    save_and_display_images(inputs, generated_image, targets, epoch_number)

    # Save model checkpoints every CHECKPOINT_EVERY epochs
    if epoch_number % CHECKPOINT_EVERY == 0:
        checkpoint = {
            'epoch': epoch_number,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'generator_optimizer_state_dict': generator_opt.state_dict(),
            'discriminator_optimizer_state_dict': discriminator_opt.state_dict(),
            'disc_losses': disc_losses,
            'gen_losses': gen_losses
        }
        checkpoint_path = os.path.join(checkpoint_folder, f'checkpoint_epoch_{epoch_number}.pth')
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

GEN AND DISC LOSS VISUALIZATION

In [None]:
# Make sure the target directory exists even if no epoch images were saved
# in this session (e.g. when only this cell is re-run).
os.makedirs('output', exist_ok=True)

fig, (disc_ax, gen_ax) = plt.subplots(1, 2, figsize=(15, 6))

# One panel per network: (axes, per-epoch losses, display name, line colour)
loss_panels = (
    (disc_ax, disc_losses, 'Discriminator', 'blue'),
    (gen_ax, gen_losses, 'Generator', 'red'),
)

for ax, losses, name, colour in loss_panels:
    # Epochs are numbered from 1 on the x axis.
    ax.plot(range(1, len(losses) + 1), losses, label=f'{name} Loss', color=colour)
    ax.set_title(f'{name} Loss per Epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

# Adjust layout and save
fig.tight_layout()
fig.savefig('output/training_losses.png')
plt.show()

SAVING THE MODEL

In [None]:
# Persist only the generator's weights (its state_dict); to reuse them, build
# a Generator and load this file into it.
torch.save(generator.state_dict(), 'Colourized_model.pth')
print("Model saved successfully")