# Use this cell to train all three models: Generator, Discriminator, and Classifier.

In [None]:
# Necessary imports
import os
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler, autocast

# Kaggle paths for the dataset and for the checkpoints written by this run
base_path = r"/kaggle/input/sentinel12-image-pairs-segregated-by-terrain/v_2"
checkpoint_dir = r"/kaggle/working/SAR_Model"
# exist_ok=True replaces the separate existence check (no check-then-create race).
os.makedirs(checkpoint_dir, exist_ok=True)

# Dataset class to handle both SAR (grayscale) and optical (RGB) image pairs
class SARImageDataset(Dataset):
    """Paired SAR (grayscale) / optical (RGB) images with terrain-class labels.

    Expected layout::

        base_path/<category>/s1/<files>   # SAR images
        base_path/<category>/s2/<files>   # optical images

    for each ``category`` in ``CATEGORIES``. The label of a pair is the index
    of its category in ``CATEGORIES``. Within a category, files are paired by
    position after sorting each folder's paths.

    Args:
        base_path: root directory holding one sub-folder per category.
        transform: transform applied to the SAR image. When it is set, the
            optical image is transformed with ``target_transform``. If that
            is ``None``, the module-level ``color_transform`` is used, which
            is the original behaviour.
        target_transform: optional transform for the optical image, so the
            dataset need not rely on a global.

    Raises:
        ValueError: if a category's ``s1`` and ``s2`` folders hold a
            different number of files. Previously this silently misaligned
            every later pair.
    """

    CATEGORIES = ('agri', 'barrenland', 'grassland', 'urban')

    def __init__(self, base_path, transform=None, target_transform=None):
        self.s1_paths = []
        self.s2_paths = []
        self.labels = []

        for label, category in enumerate(self.CATEGORIES):
            s1_files = self._sorted_files(os.path.join(base_path, category, 's1'))
            s2_files = self._sorted_files(os.path.join(base_path, category, 's2'))
            if len(s1_files) != len(s2_files):
                raise ValueError(
                    f"{category!r}: {len(s1_files)} files in s1 but "
                    f"{len(s2_files)} files in s2; cannot pair them."
                )
            self.s1_paths += s1_files
            self.s2_paths += s2_files
            self.labels += [label] * len(s1_files)

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _sorted_files(folder):
        """Return the full paths of the entries of ``folder``, sorted."""
        return sorted(os.path.join(folder, name) for name in os.listdir(folder))

    def __len__(self):
        return len(self.s1_paths)

    def __getitem__(self, idx):
        # `with` closes the file handles. convert() returns a new image, so it
        # stays valid after the source file is closed.
        with Image.open(self.s1_paths[idx]) as img:
            s1_image = img.convert('L')  # Grayscale SAR image
        with Image.open(self.s2_paths[idx]) as img:
            s2_image = img.convert('RGB')  # Color Optical image
        label = self.labels[idx]  # Numeric label

        if self.transform:
            s1_image = self.transform(s1_image)
            # Fall back to the module-level pipeline to keep the original behaviour.
            target_tf = self.target_transform if self.target_transform is not None else color_transform
            s2_image = target_tf(s2_image)

        return s1_image, s2_image, label  # Return label with images

# Preprocessing pipelines: resize to 256x256, convert to a tensor, then map
# pixel values from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
IMAGE_SIZE = (256, 256)

# One-channel pipeline for the SAR (grayscale) inputs.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Three-channel pipeline for the optical (RGB) targets.
color_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Dataset plus a shuffled loader; the small batch (8) keeps memory use down.
train_dataset = SARImageDataset(base_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Define DCT Residual Block
class DCTResidualBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers plus a learned 1x1 shortcut.

    Despite the name, no discrete cosine transform is computed. The ``dct``
    branch is a plain 1x1 convolution (marked as a DCT placeholder by the
    author) that projects the input to ``out_channels`` so it can be added
    to the main path. Stride is 1 and padding is size-preserving, so height
    and width are unchanged.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # These attribute names are state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dct = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1x1 shortcut ("DCT" placeholder)

    def forward(self, x):
        shortcut = self.dct(x)
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(main + shortcut)

# Define Light-ASPP
class LightASPP(nn.Module):
    """Lightweight atrous-spatial-pyramid block that sums five branches.

    Branches:
    - a 1x1 conv;
    - three 3x3 convs with dilation 1, 6 and 12 (padding equals dilation, so
      height and width are preserved);
    - a global average pool followed by a 1x1 conv, whose (N, C, 1, 1) output
      broadcasts over the spatial grid in the sum.

    No normalisation or activation is applied.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # atrous_block1..3 (names are checkpoint state_dict keys).
        for idx, rate in enumerate((1, 6, 12), start=1):
            setattr(
                self,
                f"atrous_block{idx}",
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rate, dilation=rate),
            )
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        branches = [
            self.conv1(x),
            self.atrous_block1(x),
            self.atrous_block2(x),
            self.atrous_block3(x),
            self.conv2(self.global_avg_pool(x)),
        ]
        return sum(branches)

# Define CCMB
class CCMB(nn.Module):
    """Global colour-context head: spatially pooled features -> 3 values per sample.

    Averages ``x`` over its spatial dimensions, applies a linear layer with
    3 outputs, and reshapes the result to (N, 3, 1, 1). That shape broadcasts
    against an (N, 3, H, W) tensor.

    NOTE(review): not referenced by any model in the visible code.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, 3)  # one output per RGB channel

    def forward(self, x):
        pooled = self.global_avg_pool(x).flatten(start_dim=1)
        return self.fc(pooled).view(-1, 3, 1, 1)

# Define Generator (encoder -> decoder, no skip connections)
class Generator(nn.Module):
    """SAR-to-optical generator: 1-channel input -> 3-channel 256x256 output.

    Encoder: DCTResidualBlock(1->32) -> DCTResidualBlock(32->64) ->
    LightASPP(64->128) -> DCTResidualBlock(128->256). Every encoder stage
    uses stride 1 with size-preserving padding, so it keeps the input's
    height and width.

    Decoder: three ConvTranspose2d(k=4, s=2, p=1) layers followed by ReLU.
    Each doubles height and width (channels 256 -> 128 -> 64 -> 32). A 3x3
    conv then maps to 3 channels with no output activation. The result is
    bilinearly resized to 256x256.

    NOTE(review): despite the earlier "U-Net / skip connections" wording, no
    encoder features reach the decoder. Because the encoder never
    downsamples, a 256x256 input is upsampled to 2048x2048 before being
    resized back to 256x256, which is very expensive. Changing this would
    alter the learned mapping, so it is left as is.
    """

    def __init__(self):
        super().__init__()
        # Encoder (attribute names are checkpoint state_dict keys)
        self.encoder1 = DCTResidualBlock(1, 32)
        self.encoder2 = DCTResidualBlock(32, 64)
        self.encoder3 = LightASPP(64, 128)
        self.encoder4 = DCTResidualBlock(128, 256)

        # Decoder: each transposed conv doubles height and width
        self.decoder1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.decoder2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.decoder3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.final_layer = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)  # 32 channels -> RGB

    def forward(self, x):
        h = x
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            h = stage(h)
        for up in (self.decoder1, self.decoder2, self.decoder3):
            h = F.relu(up(h))
        rgb = self.final_layer(h)
        # Always resample to a fixed 256x256 output, whatever the input size.
        return F.interpolate(rgb, size=(256, 256), mode='bilinear', align_corners=True)

# Define Discriminator
class Discriminator(nn.Module):
    """Fully convolutional (PatchGAN-style) discriminator for 3-channel images.

    Four 4x4 convolutions with strides 2, 2, 2, 1 widen the channels
    3 -> 64 -> 128 -> 256 -> 512. Each is followed by LeakyReLU(0.2), with
    BatchNorm on all but the first. A final 4x4 stride-1 conv emits a
    1-channel spatial map of raw scores with no final activation.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        # Sequential indices (model.0, model.2, ...) are checkpoint state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Simple Classifier for image categorization
class Classifier(nn.Module):
    """Small CNN that outputs ``num_classes`` logits for a 3-channel image.

    Three stages of 3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool
    (3 -> 32 -> 64 -> 128 channels), then Linear(128*32*32 -> 256) -> ReLU ->
    Dropout(0.4) -> Linear(256 -> num_classes).

    ``fc1`` is sized for a 32x32 feature map, i.e. a 256x256 input after
    three 2x2 pools. The features are now adaptively average-pooled to 32x32
    before flattening:
    - at 256x256 every pooling bin holds exactly one element, so results are
      unchanged;
    - other input sizes no longer fail with a shape mismatch in ``fc1``.

    No parameters are added, so existing checkpoints load as before.

    Args:
        num_classes: number of output logits (default 4).
    """

    # Spatial size fc1 was built for: 256 / 2**3 = 32.
    _FC_INPUT_HW = (32, 32)

    def __init__(self, num_classes=4):
        super().__init__()

        # First Convolutional block (256 -> 128 for a 256x256 input)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second Convolutional block (128 -> 64)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third Convolutional block (64 -> 32)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Fully Connected Layers
        self.fc1 = nn.Linear(128 * self._FC_INPUT_HW[0] * self._FC_INPUT_HW[1], 256)
        self.fc2 = nn.Linear(256, num_classes)

        # Dropout to prevent overfitting
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        for conv, bn, pool in (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        ):
            x = pool(F.relu(bn(conv(x))))

        # Identity for 256x256 inputs; resamples other sizes to what fc1 expects.
        x = F.adaptive_avg_pool2d(x, self._FC_INPUT_HW)
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# Perceptual Loss (VGG)
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 feature maps of two images.

    Uses the first 16 modules of torchvision's VGG19 ``features`` with
    ImageNet weights (``IMAGENET1K_V1``), in eval mode with gradients
    disabled for all of its parameters.

    The module no longer calls ``.cuda()`` inside ``__init__``. The caller
    chooses the device (e.g. ``VGGPerceptualLoss().cuda()``), so the class
    also works on CPU-only machines.

    NOTE(review): ImageNet-pretrained VGG weights are normally fed RGB in
    [0, 1] normalised with ImageNet mean/std. ``x`` and ``y`` are passed
    through unchanged here. The training code appears to feed images
    normalised with mean/std 0.5 (roughly [-1, 1]); confirm and consider
    re-normalising before the VGG pass.
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:16].eval()
        vgg.requires_grad_(False)  # frozen feature extractor
        self.vgg = vgg

    def forward(self, x, y):
        return F.mse_loss(self.vgg(x), self.vgg(y))
        
# Define the unnormalize function
def unnormalize(tensor, mean, std):
    """Undo a per-channel ``Normalize(mean, std)`` and clamp the result to [0, 1].

    Each channel ``c`` is mapped back as ``tensor[c] * std[c] + mean[c]``.
    Channels are paired with ``mean``/``std`` via ``zip``, so only the first
    ``min(len(tensor), len(mean), len(std))`` channels are rescaled.

    The input is not modified. The arithmetic runs on a clone, whereas
    previously ``mul_``/``add_`` wrote into the caller's tensor. For a CPU
    tensor ``.cpu().detach()`` shares storage with the original, so that
    silently altered the source batch.

    Args:
        tensor: image tensor with channels on the first dimension, e.g. (C, H, W).
        mean: per-channel means used for normalisation.
        std: per-channel standard deviations used for normalisation.

    Returns:
        A new tensor of the same shape and dtype, clamped to [0, 1].
    """
    restored = tensor.clone()
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)  # Inverse of normalization
    return restored.clamp(0, 1)

# Loss functions
criterion_GAN = nn.BCEWithLogitsLoss()        # adversarial loss on raw discriminator scores
criterion_pixelwise = nn.L1Loss()             # pixel-wise reconstruction loss
criterion_classifier = nn.CrossEntropyLoss()  # For classification task

# Instantiate models on the GPU
generator = Generator().cuda()
discriminator = Discriminator().cuda()
classifier = Classifier().cuda()
vgg_loss = VGGPerceptualLoss().cuda()

# Optimizers for generator, discriminator, and classifier
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_C = optim.Adam(classifier.parameters(), lr=0.0002)

# Mixed precision training: one GradScaler shared by all three optimizers.
# `torch.cuda.amp.GradScaler()` is deprecated; `torch.amp.GradScaler("cuda")`
# is the current equivalent (the second cell already uses torch.amp).
scaler = torch.amp.GradScaler("cuda")

# Load pre-trained generator, discriminator and classifier if resuming.
# Checkpoint naming convention (matches the save step of the training loop):
# `model_epoch_{k}.pth` holds the weights after epoch k has finished, and
# `start_epoch` is the first epoch trained now. So we resume from
# `start_epoch - 1`. Loading `start_epoch` itself would make the next saved
# checkpoint reuse the loaded file's epoch number.
start_epoch = 36
checkpoint_path = os.path.join(r'/kaggle/input/dct/pytorch/default/1', f'model_epoch_{start_epoch - 1}.pth')
if os.path.exists(checkpoint_path):
    # weights_only=True: the file holds only state_dicts; refuse arbitrary pickled objects.
    checkpoint = torch.load(checkpoint_path, weights_only=True)
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    classifier.load_state_dict(checkpoint['classifier'])
    print(f"Loaded pretrained models from previous epoch {start_epoch - 1} checkpoint :)")
else:
    # Previously a missing file was skipped silently. Training then started
    # from fresh weights while still being labelled as epoch `start_epoch`.
    print(f"WARNING: no checkpoint found at {checkpoint_path}; training starts from freshly initialised weights.")

# Training loop
epochs = 100           # number of epochs to run in this session
sample_interval = 500  # log and plot every `sample_interval` batches

# Epochs are numbered from `start_epoch`, so log lines and checkpoint names
# continue the numbering of the resumed run. The denominator is the last
# epoch of this session; previously the log read "[Epoch N/100]" even when
# N > 100.
end_epoch = start_epoch + epochs - 1

for epoch in range(start_epoch, end_epoch + 1):
    for i, (s1, s2, label) in enumerate(train_loader):
        s1, s2, label = s1.cuda(), s2.cuda(), label.cuda()

        # ------------------ Train Generator ------------------
        optimizer_G.zero_grad()
        with torch.amp.autocast(device_type='cuda'):
            fake_s2 = generator(s1)
            pred_fake = discriminator(fake_s2)
            gan_loss = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            pixel_loss = criterion_pixelwise(fake_s2, s2)
            perceptual_loss = vgg_loss(fake_s2, s2)
            g_loss = gan_loss + pixel_loss + perceptual_loss
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_G)

        # ------------------- Train Discriminator -------------------
        # zero_grad() also discards the gradients that g_loss.backward() left
        # on the discriminator's parameters.
        optimizer_D.zero_grad()
        with torch.amp.autocast(device_type='cuda'):
            pred_real = discriminator(s2)
            pred_fake = discriminator(fake_s2.detach())
            real_loss = criterion_GAN(pred_real, torch.ones_like(pred_real))
            fake_loss = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            d_loss = (real_loss + fake_loss) / 2
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)

        # ------------------- Train Classifier -------------------
        optimizer_C.zero_grad()
        with torch.amp.autocast(device_type='cuda'):
            # Detached so the classification loss does not backpropagate into the generator.
            pred_class = classifier(fake_s2.detach())
            class_loss = criterion_classifier(pred_class, label)
        scaler.scale(class_loss).backward()
        scaler.step(optimizer_C)

        # One GradScaler is shared by all three optimizers. Its scale is
        # updated once per iteration, after every optimizer has stepped,
        # instead of after each individual step.
        scaler.update()

        # Log progress and visualize images
        if i % sample_interval == 0:
            print(f"[Epoch {epoch}/{end_epoch}] [Batch {i}/{len(train_loader)}] [G loss: {g_loss.item():.4f}] [D loss: {d_loss.item():.4f}] [Class loss: {class_loss.item():.4f}]")

            # Show sample images with predicted and actual classes as titles
            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(s1[0].cpu().permute(1, 2, 0), cmap='gray')
            axs[0].set_title('Grayscale SAR Image')
            generated_image_unnormalized = unnormalize(fake_s2[0].cpu().detach(),
                                                       mean=[0.5, 0.5, 0.5],
                                                       std=[0.5, 0.5, 0.5])
            axs[1].imshow(generated_image_unnormalized.permute(1, 2, 0))
            pred_label = torch.argmax(pred_class[0]).item()
            axs[1].set_title(f'Predicted: {pred_label}, Actual: {label[0].item()}')
            s2_unnormalized = unnormalize(s2[0].cpu().detach(),
                                          mean=[0.5, 0.5, 0.5],
                                          std=[0.5, 0.5, 0.5])
            axs[2].imshow(s2_unnormalized.permute(1, 2, 0))
            axs[2].set_title('Ground Truth')
            plt.show()
            plt.close(fig)  # don't let figures accumulate if the backend keeps them open

    # Save model checkpoints: `model_epoch_{epoch}.pth` holds the weights after this epoch.
    checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pth')
    torch.save({
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'classifier': classifier.state_dict(),
    }, checkpoint_path)
    print(f"Checkpoint saved for epoch {epoch}.")

print('Training Completed, Sir!')

# Use this cell to train only the Generator and Discriminator (the Classifier is loaded, used for display, and saved, but not trained).

In [None]:
# Necessary imports
import os
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torch.amp import GradScaler, autocast

# Kaggle paths for the dataset and for the checkpoints written by this run
base_path = r"/kaggle/input/sentinel12-image-pairs-segregated-by-terrain/v_2"
checkpoint_dir = r"/kaggle/working/SAR_Model"
# exist_ok=True replaces the separate existence check (no check-then-create race).
os.makedirs(checkpoint_dir, exist_ok=True)

# Dataset class to handle both SAR (grayscale) and optical (RGB) image pairs
class SARImageDataset(Dataset):
    """Paired SAR (grayscale) / optical (RGB) images with terrain-class labels.

    Expected layout::

        base_path/<category>/s1/<files>   # SAR images
        base_path/<category>/s2/<files>   # optical images

    for each ``category`` in ``CATEGORIES``. The label of a pair is the index
    of its category in ``CATEGORIES``. Within a category, files are paired by
    position after sorting each folder's paths.

    Args:
        base_path: root directory holding one sub-folder per category.
        transform: transform applied to the SAR image. When it is set, the
            optical image is transformed with ``target_transform``. If that
            is ``None``, the module-level ``color_transform`` is used, which
            is the original behaviour.
        target_transform: optional transform for the optical image, so the
            dataset need not rely on a global.

    Raises:
        ValueError: if a category's ``s1`` and ``s2`` folders hold a
            different number of files. Previously this silently misaligned
            every later pair.
    """

    CATEGORIES = ('agri', 'barrenland', 'grassland', 'urban')

    def __init__(self, base_path, transform=None, target_transform=None):
        self.s1_paths = []
        self.s2_paths = []
        self.labels = []

        for label, category in enumerate(self.CATEGORIES):
            s1_files = self._sorted_files(os.path.join(base_path, category, 's1'))
            s2_files = self._sorted_files(os.path.join(base_path, category, 's2'))
            if len(s1_files) != len(s2_files):
                raise ValueError(
                    f"{category!r}: {len(s1_files)} files in s1 but "
                    f"{len(s2_files)} files in s2; cannot pair them."
                )
            self.s1_paths += s1_files
            self.s2_paths += s2_files
            self.labels += [label] * len(s1_files)

        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _sorted_files(folder):
        """Return the full paths of the entries of ``folder``, sorted."""
        return sorted(os.path.join(folder, name) for name in os.listdir(folder))

    def __len__(self):
        return len(self.s1_paths)

    def __getitem__(self, idx):
        # `with` closes the file handles. convert() returns a new image, so it
        # stays valid after the source file is closed.
        with Image.open(self.s1_paths[idx]) as img:
            s1_image = img.convert('L')  # Grayscale SAR image
        with Image.open(self.s2_paths[idx]) as img:
            s2_image = img.convert('RGB')  # Color Optical image
        label = self.labels[idx]  # Numeric label

        if self.transform:
            s1_image = self.transform(s1_image)
            # Fall back to the module-level pipeline to keep the original behaviour.
            target_tf = self.target_transform if self.target_transform is not None else color_transform
            s2_image = target_tf(s2_image)

        return s1_image, s2_image, label  # Return label with images

# Preprocessing pipelines: resize to 256x256, convert to a tensor, then map
# pixel values from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
IMAGE_SIZE = (256, 256)

# One-channel pipeline for the SAR (grayscale) inputs.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Three-channel pipeline for the optical (RGB) targets.
color_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Dataset plus a shuffled loader; the small batch (8) keeps memory use down.
train_dataset = SARImageDataset(base_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Define DCT Residual Block
class DCTResidualBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers plus a learned 1x1 shortcut.

    Despite the name, no discrete cosine transform is computed. The ``dct``
    branch is a plain 1x1 convolution (marked as a DCT placeholder by the
    author) that projects the input to ``out_channels`` so it can be added
    to the main path. Stride is 1 and padding is size-preserving, so height
    and width are unchanged.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # These attribute names are state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dct = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1x1 shortcut ("DCT" placeholder)

    def forward(self, x):
        shortcut = self.dct(x)
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(main + shortcut)

# Define Light-ASPP
class LightASPP(nn.Module):
    """Lightweight atrous-spatial-pyramid block that sums five branches.

    Branches:
    - a 1x1 conv;
    - three 3x3 convs with dilation 1, 6 and 12 (padding equals dilation, so
      height and width are preserved);
    - a global average pool followed by a 1x1 conv, whose (N, C, 1, 1) output
      broadcasts over the spatial grid in the sum.

    No normalisation or activation is applied.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # atrous_block1..3 (names are checkpoint state_dict keys).
        for idx, rate in enumerate((1, 6, 12), start=1):
            setattr(
                self,
                f"atrous_block{idx}",
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rate, dilation=rate),
            )
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        branches = [
            self.conv1(x),
            self.atrous_block1(x),
            self.atrous_block2(x),
            self.atrous_block3(x),
            self.conv2(self.global_avg_pool(x)),
        ]
        return sum(branches)

# Define CCMB
class CCMB(nn.Module):
    """Global colour-context head: spatially pooled features -> 3 values per sample.

    Averages ``x`` over its spatial dimensions, applies a linear layer with
    3 outputs, and reshapes the result to (N, 3, 1, 1). That shape broadcasts
    against an (N, 3, H, W) tensor.

    NOTE(review): not referenced by any model in the visible code.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, 3)  # one output per RGB channel

    def forward(self, x):
        pooled = self.global_avg_pool(x).flatten(start_dim=1)
        return self.fc(pooled).view(-1, 3, 1, 1)

# Define Generator (encoder -> decoder, no skip connections)
class Generator(nn.Module):
    """SAR-to-optical generator: 1-channel input -> 3-channel 256x256 output.

    Encoder: DCTResidualBlock(1->32) -> DCTResidualBlock(32->64) ->
    LightASPP(64->128) -> DCTResidualBlock(128->256). Every encoder stage
    uses stride 1 with size-preserving padding, so it keeps the input's
    height and width.

    Decoder: three ConvTranspose2d(k=4, s=2, p=1) layers followed by ReLU.
    Each doubles height and width (channels 256 -> 128 -> 64 -> 32). A 3x3
    conv then maps to 3 channels with no output activation. The result is
    bilinearly resized to 256x256.

    NOTE(review): despite the earlier "U-Net / skip connections" wording, no
    encoder features reach the decoder. Because the encoder never
    downsamples, a 256x256 input is upsampled to 2048x2048 before being
    resized back to 256x256, which is very expensive. Changing this would
    alter the learned mapping, so it is left as is.
    """

    def __init__(self):
        super().__init__()
        # Encoder (attribute names are checkpoint state_dict keys)
        self.encoder1 = DCTResidualBlock(1, 32)
        self.encoder2 = DCTResidualBlock(32, 64)
        self.encoder3 = LightASPP(64, 128)
        self.encoder4 = DCTResidualBlock(128, 256)

        # Decoder: each transposed conv doubles height and width
        self.decoder1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.decoder2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.decoder3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.final_layer = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)  # 32 channels -> RGB

    def forward(self, x):
        h = x
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            h = stage(h)
        for up in (self.decoder1, self.decoder2, self.decoder3):
            h = F.relu(up(h))
        rgb = self.final_layer(h)
        # Always resample to a fixed 256x256 output, whatever the input size.
        return F.interpolate(rgb, size=(256, 256), mode='bilinear', align_corners=True)

# Define Discriminator
class Discriminator(nn.Module):
    """Fully convolutional (PatchGAN-style) discriminator for 3-channel images.

    Four 4x4 convolutions with strides 2, 2, 2, 1 widen the channels
    3 -> 64 -> 128 -> 256 -> 512. Each is followed by LeakyReLU(0.2), with
    BatchNorm on all but the first. A final 4x4 stride-1 conv emits a
    1-channel spatial map of raw scores with no final activation.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        # Sequential indices (model.0, model.2, ...) are checkpoint state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Simple Classifier for image categorization
class Classifier(nn.Module):
    """Small CNN that outputs ``num_classes`` logits for a 3-channel image.

    Three stages of 3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool
    (3 -> 32 -> 64 -> 128 channels), then Linear(128*32*32 -> 256) -> ReLU ->
    Dropout(0.4) -> Linear(256 -> num_classes).

    ``fc1`` is sized for a 32x32 feature map, i.e. a 256x256 input after
    three 2x2 pools. The features are now adaptively average-pooled to 32x32
    before flattening:
    - at 256x256 every pooling bin holds exactly one element, so results are
      unchanged;
    - other input sizes no longer fail with a shape mismatch in ``fc1``.

    No parameters are added, so existing checkpoints load as before.

    Args:
        num_classes: number of output logits (default 4).
    """

    # Spatial size fc1 was built for: 256 / 2**3 = 32.
    _FC_INPUT_HW = (32, 32)

    def __init__(self, num_classes=4):
        super().__init__()

        # First Convolutional block (256 -> 128 for a 256x256 input)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second Convolutional block (128 -> 64)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third Convolutional block (64 -> 32)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Fully Connected Layers
        self.fc1 = nn.Linear(128 * self._FC_INPUT_HW[0] * self._FC_INPUT_HW[1], 256)
        self.fc2 = nn.Linear(256, num_classes)

        # Dropout to prevent overfitting
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        for conv, bn, pool in (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        ):
            x = pool(F.relu(bn(conv(x))))

        # Identity for 256x256 inputs; resamples other sizes to what fc1 expects.
        x = F.adaptive_avg_pool2d(x, self._FC_INPUT_HW)
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# Perceptual Loss (VGG)
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 feature maps of two images.

    Uses the first 16 modules of torchvision's VGG19 ``features`` with
    ImageNet weights (``IMAGENET1K_V1``), in eval mode with gradients
    disabled for all of its parameters.

    The module no longer calls ``.cuda()`` inside ``__init__``. The caller
    chooses the device (e.g. ``VGGPerceptualLoss().cuda()``), so the class
    also works on CPU-only machines.

    NOTE(review): ImageNet-pretrained VGG weights are normally fed RGB in
    [0, 1] normalised with ImageNet mean/std. ``x`` and ``y`` are passed
    through unchanged here. The training code appears to feed images
    normalised with mean/std 0.5 (roughly [-1, 1]); confirm and consider
    re-normalising before the VGG pass.
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:16].eval()
        vgg.requires_grad_(False)  # frozen feature extractor
        self.vgg = vgg

    def forward(self, x, y):
        return F.mse_loss(self.vgg(x), self.vgg(y))
        
# Define the unnormalize function
def unnormalize(tensor, mean, std):
    """Undo a per-channel ``Normalize(mean, std)`` and clamp the result to [0, 1].

    Each channel ``c`` is mapped back as ``tensor[c] * std[c] + mean[c]``.
    Channels are paired with ``mean``/``std`` via ``zip``, so only the first
    ``min(len(tensor), len(mean), len(std))`` channels are rescaled.

    The input is not modified. The arithmetic runs on a clone, whereas
    previously ``mul_``/``add_`` wrote into the caller's tensor. For a CPU
    tensor ``.cpu().detach()`` shares storage with the original, so that
    silently altered the source batch.

    Args:
        tensor: image tensor with channels on the first dimension, e.g. (C, H, W).
        mean: per-channel means used for normalisation.
        std: per-channel standard deviations used for normalisation.

    Returns:
        A new tensor of the same shape and dtype, clamped to [0, 1].
    """
    restored = tensor.clone()
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)  # Inverse of normalization
    return restored.clamp(0, 1)

# Loss functions: adversarial BCE on raw scores plus pixel-wise L1.
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_pixelwise = nn.L1Loss()

# Models, all on the GPU. The classifier gets no optimizer in this cell.
generator = Generator().cuda()
discriminator = Discriminator().cuda()
classifier = Classifier().cuda()
vgg_loss = VGGPerceptualLoss().cuda()

# Same Adam settings for the generator and the discriminator.
gan_adam_kwargs = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **gan_adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **gan_adam_kwargs)

# Mixed precision: a single GradScaler shared by both optimizers.
scaler = GradScaler()

# Load pre-trained generator, discriminator and classifier if resuming.
# `model_epoch_{k}.pth` holds the weights after epoch k has finished, and
# `start_epoch` is the first epoch trained now, so we resume from `start_epoch - 1`.
start_epoch = 36
checkpoint_path = os.path.join(r'/kaggle/input/dct/pytorch/default/1', f'model_epoch_{start_epoch-1}.pth')
if os.path.exists(checkpoint_path):
    # weights_only=True: the file holds only state_dicts; refuse arbitrary pickled objects.
    checkpoint = torch.load(checkpoint_path,weights_only=True)
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    classifier.load_state_dict(checkpoint['classifier'])
    print(f"Loaded pretrained models from previous epoch {start_epoch-1} checkpoint :)")
else:
    # Previously a missing file was skipped silently. Training then started
    # from fresh weights while still being labelled as epoch `start_epoch`.
    print(f"WARNING: no checkpoint found at {checkpoint_path}; training starts from freshly initialised weights.")

# Training loop (generator + discriminator only; the classifier is not optimised here)
epochs = 100          # number of epochs to run in this session
sample_interval = 10  # log and plot every `sample_interval` batches

# Epochs are numbered from `start_epoch`, so log lines and checkpoint names
# continue the numbering of the resumed run. The denominator is the last
# epoch of this session; previously the log read "[Epoch N/100]" even when
# N > 100.
end_epoch = start_epoch + epochs - 1

for epoch in range(start_epoch, end_epoch + 1):
    for i, (s1, s2, label) in enumerate(train_loader):
        s1, s2, label = s1.cuda(), s2.cuda(), label.cuda()

        # ------------------ Train Generator ------------------
        optimizer_G.zero_grad()
        with autocast('cuda'):
            fake_s2 = generator(s1)
            pred_fake = discriminator(fake_s2)
            gan_loss = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            pixel_loss = criterion_pixelwise(fake_s2, s2)
            perceptual_loss = vgg_loss(fake_s2, s2)
            g_loss = gan_loss + 100 * pixel_loss + 0.1 * perceptual_loss
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_G)

        # ------------------- Train Discriminator -------------------
        # zero_grad() also discards the gradients that g_loss.backward() left
        # on the discriminator's parameters.
        optimizer_D.zero_grad()
        with autocast('cuda'):
            pred_real = discriminator(s2)
            pred_fake = discriminator(fake_s2.detach())
            real_loss = criterion_GAN(pred_real, torch.ones_like(pred_real))
            fake_loss = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            d_loss = (real_loss + fake_loss) / 2
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)

        # One GradScaler is shared by both optimizers. Its scale is updated
        # once per iteration, after both have stepped, instead of after each
        # individual step.
        scaler.update()

        # Log progress and visualize images
        if i % sample_interval == 0:
            # The classifier is only used for the plot title, so it runs only
            # when plotting. no_grad avoids building an autograd graph; eval
            # mode disables dropout and leaves its BatchNorm buffers alone.
            # Previously it ran every batch in train mode. Its mode is restored
            # afterwards.
            was_training = classifier.training
            classifier.eval()
            with torch.no_grad():
                pred_class = classifier(fake_s2.detach())
            classifier.train(was_training)

            print(f"[Epoch {epoch}/{end_epoch}] [Batch {i}/{len(train_loader)}] [G loss: {g_loss.item():.4f}] [D loss: {d_loss.item():.4f}]")

            # Show sample images with predicted and actual classes as titles
            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(s1[0].cpu().permute(1, 2, 0), cmap='gray')
            axs[0].set_title('Grayscale SAR Image')
            generated_image_unnormalized = unnormalize(fake_s2[0].cpu().detach(),
                                                       mean=[0.5, 0.5, 0.5],
                                                       std=[0.5, 0.5, 0.5])
            axs[1].imshow(generated_image_unnormalized.permute(1, 2, 0))
            pred_label = torch.argmax(pred_class[0]).item()
            axs[1].set_title(f'Predicted: {pred_label}, Actual: {label[0].item()}')
            s2_unnormalized = unnormalize(s2[0].cpu().detach(),
                                          mean=[0.5, 0.5, 0.5],
                                          std=[0.5, 0.5, 0.5])
            axs[2].imshow(s2_unnormalized.permute(1, 2, 0))
            axs[2].set_title('Ground Truth')
            plt.show()
            plt.close(fig)  # don't let figures accumulate if the backend keeps them open

    # Save model checkpoints (all three models, so the checkpoint stays complete).
    # `model_epoch_{epoch}.pth` holds the weights after this epoch.
    checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pth')
    torch.save({
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'classifier': classifier.state_dict(),
    }, checkpoint_path)
    print(f"Checkpoint saved for epoch {epoch}.")

print('Training Completed, Sir!')