In [10]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory (i.e. /kaggle/input)
# Note: the template's loop that listed every file under that directory is not included in this cell, so running it only imports libraries

import os
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import statistics
from tqdm import tqdm
import pickle
from PIL import Image

# Device configuration: use CUDA whenever PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Announce multi-GPU runs (train() wraps both networks in nn.DataParallel).
_gpu_count = torch.cuda.device_count()
if _gpu_count > 1:
    print("Let's use", _gpu_count, "GPUs!")

# Release unused cached blocks held by PyTorch's CUDA caching allocator
# (e.g. left over from an earlier run in this kernel).
torch.cuda.empty_cache()

# Define Gray transform
class Gray(object):
    def __call__(self, img):
        gray = img.convert('L')
        return gray

# Custom dataset
class CustomDataset(Dataset):
    """Dataset of paired Sentinel-1 (SAR) / Sentinel-2 (optical) images.

    Both directory listings are sorted independently and paired by position:
    the i-th filename in ``root_s1`` is matched with the i-th filename in
    ``root_s2``. Every directory entry is treated as an image file.

    Args:
        root_s1: Directory containing the S1 images.
        root_s2: Directory containing the S2 images.
        transform_s1: Optional callable applied to each S1 PIL image.
        transform_s2: Optional callable applied to each S2 PIL image.
    """

    def __init__(self, root_s1, root_s2, transform_s1=None, transform_s2=None):
        self.root_s1, self.root_s2 = root_s1, root_s2
        self.transform_s1, self.transform_s2 = transform_s1, transform_s2
        self.s1_images, self.s2_images = (sorted(os.listdir(folder)) for folder in (root_s1, root_s2))
        # Positional pairing only makes sense when both folders hold the same number of entries.
        assert len(self.s1_images) == len(self.s2_images), "Mismatched number of images in s1 and s2 directories"

    def __len__(self):
        return len(self.s1_images)

    def __getitem__(self, idx):
        """Return ``(s1, s2)`` for pair ``idx``; each is opened as RGB, then transformed if a transform is set."""
        sources = ((self.root_s1, self.s1_images), (self.root_s2, self.s2_images))
        paths = [os.path.join(root, names[idx]) for root, names in sources]
        images = [Image.open(path).convert("RGB") for path in paths]
        pipelines = (self.transform_s1, self.transform_s2)
        return tuple(fn(img) if fn else img for img, fn in zip(images, pipelines))

# Load datasets
def load_datasets(
    root_s1='/kaggle/input/sentinel12-image-pairs-segregated-by-terrain/v_2/agri/s1',
    root_s2='/kaggle/input/sentinel12-image-pairs-segregated-by-terrain/v_2/agri/s2',
    batch_size=32,
    num_workers=4,
    shuffle=True,
):
    """Build the training DataLoader of paired (SAR, optical) tensors.

    SAR (s1) images are converted to 1-channel grayscale; optical (s2) images
    keep 3 RGB channels. Both are mapped from [0, 1] to [-1, 1] with
    Normalize(mean=0.5, std=0.5), which matches the Tanh output range of the
    generator.

    Args:
        root_s1: Directory of SAR images (defaults to the Kaggle "agri" split).
        root_s2: Directory of optical images paired with ``root_s1`` by sorted filename.
        batch_size: Samples per batch. ``train()`` sizes its label tensors from the
            discriminator output, so any value works there.
        num_workers: Worker processes used for loading.
        shuffle: Whether to reshuffle every epoch.

    Returns:
        torch.utils.data.DataLoader yielding ``(sar, optical)`` batches.
    """
    SAR_transform = transforms.Compose([
        Gray(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])

    opt_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    dataset = CustomDataset(
        root_s1=root_s1,
        root_s2=root_s2,
        transform_s1=SAR_transform,
        transform_s2=opt_transform
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; on CPU-only machines
        # requesting it just triggers a warning.
        pin_memory=torch.cuda.is_available(),
    )

# Define Generator
class Generator(nn.Module):
    """U-Net-style generator mapping a 1-channel SAR image to a 3-channel optical image.

    Encoder: ``enc1`` keeps full resolution. ``enc2``/``enc3``/``enc4`` then
    average-pool by 4, 2 and 2 (x1/16 overall) while widening from 64 to 512 channels.

    Decoder: ``dec1``-``dec3`` upsample (nearest) by 2, 2 and 4. Each stage's output
    is concatenated with the matching encoder feature map along channels before
    the next stage. ``dec1``/``dec2`` also apply Dropout2d(0.5). ``dec4`` maps to
    3 channels and applies Tanh, so outputs lie in [-1, 1].

    Input height and width should be divisible by 16 so the skip connections
    line up spatially.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.enc1 = self.conv_bn_relu(1, 64, kernel_size=5)
        self.enc2 = self.conv_bn_relu(64, 128, kernel_size=3, pool_kernel=4)
        self.enc3 = self.conv_bn_relu(128, 256, kernel_size=3, pool_kernel=2)
        self.enc4 = self.conv_bn_relu(256, 512, kernel_size=3, pool_kernel=2)

        self.dec1 = self.conv_bn_relu(512, 256, kernel_size=3, pool_kernel=-2, flag=True, enc=False)
        self.dec2 = self.conv_bn_relu(256+256, 128, kernel_size=3, pool_kernel=-2, flag=True, enc=False)
        self.dec3 = self.conv_bn_relu(128+128, 64, kernel_size=3, pool_kernel=-4, enc=False)
        self.dec4 = nn.Sequential(
            nn.Conv2d(64 + 64, 3, kernel_size=5, padding=2),
            nn.Tanh()
        )

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, flag=None, enc=True):
        """Build ``[resample] -> Conv2d -> BatchNorm2d -> [Dropout2d] -> activation``.

        Args:
            in_ch: Conv input channels.
            out_ch: Conv output channels.
            kernel_size: Odd kernel size; the ``(k - 1) // 2`` padding then keeps H and W unchanged.
            pool_kernel: ``> 0`` prepends ``AvgPool2d(pool_kernel)``. ``< 0`` prepends
                nearest-neighbour upsampling by ``-pool_kernel``. ``None``/``0`` adds no resampling.
            flag: If truthy, inserts ``Dropout2d(0.5)`` after BatchNorm.
            enc: ``True`` -> LeakyReLU(0.2) (encoder style); ``False`` -> ReLU (decoder style).

        Returns:
            nn.Sequential with layers in the order above. The order is fixed so
            state-dict keys stay stable.
        """
        layers = []
        if pool_kernel is not None and pool_kernel > 0:
            layers.append(nn.AvgPool2d(pool_kernel))
        elif pool_kernel is not None and pool_kernel < 0:
            # nn.Upsample(mode="nearest") is the non-deprecated equivalent of
            # nn.UpsamplingNearest2d. It has no parameters, so checkpoints are unaffected.
            layers.append(nn.Upsample(scale_factor=-pool_kernel, mode="nearest"))
        layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kernel_size - 1) // 2))
        layers.append(nn.BatchNorm2d(out_ch))
        # Truthiness check so that flag=False really means "no dropout".
        if flag:
            layers.append(nn.Dropout2d(0.5))
        layers.append(nn.LeakyReLU(0.2, inplace=True) if enc else nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the U-Net: encode, then decode with skip concatenations (dim=1 = channels)."""
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        out = self.dec1(x4)
        out = self.dec2(torch.cat([out, x3], dim=1))
        out = self.dec3(torch.cat([out, x2], dim=1))
        out = self.dec4(torch.cat([out, x1], dim=1))
        return out

# Define Discriminator
class Discriminator(nn.Module):
    """Conditional patch discriminator over 4-channel inputs.

    Expects an optical image (3 channels) concatenated with its SAR input
    (1 channel) along dim 1, as built in ``train``. Downsamples with AvgPool by
    4 and then 2, and returns a 1-channel map of raw logits at 1/8 of the input
    resolution. No sigmoid is applied, so pair the output with BCEWithLogitsLoss.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_bn_relu(4, 16, kernel_size=5, reps=1)
        self.conv2 = self.conv_bn_relu(16, 32, pool_kernel=4)
        self.conv3 = self.conv_bn_relu(32, 64, pool_kernel=2)
        self.out_patch = nn.Conv2d(64, 1, kernel_size=1)

    def conv_bn_relu(self, in_ch, out_ch, kernel_size=3, pool_kernel=None, reps=2):
        """Stack ``reps`` (Conv2d -> BatchNorm2d -> LeakyReLU(0.2)) units.

        A single AvgPool2d(pool_kernel) is placed in front when ``pool_kernel`` is
        given and at least one unit is built. Padding preserves H and W for odd
        kernel sizes.
        """
        pad = (kernel_size - 1) // 2
        modules = [nn.AvgPool2d(pool_kernel)] if (pool_kernel is not None and reps > 0) else []
        channels = in_ch
        for _ in range(reps):
            modules += [
                nn.Conv2d(channels, out_ch, kernel_size, padding=pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels = out_ch
        return nn.Sequential(*modules)

    def forward(self, x):
        features = x
        for stage in (self.conv1, self.conv2, self.conv3):
            features = stage(features)
        return self.out_patch(features)

# Training function
def _save_epoch_samples(out_dir, epoch, input_gray, fake_color, real_color, max_images=100):
    """Save image grids of one batch's SAR inputs, generated images and targets.

    Files are written as ``{out_dir}/{gray,fake,real}_epoch_{epoch:03}.png`` with
    at most ``max_images`` images per grid. ``normalize=True`` rescales each grid
    by its min/max for display.
    """
    count = min(len(real_color), max_images)
    for tag, images in (("gray", input_gray), ("fake", fake_color), ("real", real_color)):
        torchvision.utils.save_image(
            images[:count],
            os.path.join(out_dir, f"{tag}_epoch_{epoch:03}.png"),
            normalize=True,
        )


def _plot_history(result):
    """Plot the per-epoch mean losses in ``result`` on one labelled, legended axis."""
    colors = {
        "log_loss_G_sum": "red",
        "log_loss_G_bce": "blue",
        "log_loss_G_mae": "green",
        "log_loss_D": "black",
    }
    fig, ax = plt.subplots()
    for key, color in colors.items():
        ax.plot(result[key], color=color, label=key)
    ax.set_xlabel("epoch")
    ax.set_ylabel("mean loss")
    ax.set_title("SAR -> optical GAN training losses")
    ax.legend()
    plt.show()


def train(num_epochs=100, lambda_l1=100.0, out_dir="SARtoOpt", checkpoint_every=10):
    """Train the pix2pix-style SAR -> optical GAN; save samples, checkpoints and loss logs.

    Each iteration makes two updates:
    1. Generator: adversarial BCE plus an L1 reconstruction term weighted by ``lambda_l1``.
    2. Discriminator: real pairs against detached generated pairs.

    The discriminator is conditioned on the SAR input by channel concatenation.

    Args:
        num_epochs: Number of passes over the training DataLoader.
        lambda_l1: Weight of the L1 (MAE) term in the generator loss.
        out_dir: Directory for sample grids, ``models/`` checkpoints and ``logs.pkl``.
        checkpoint_every: Save G/D state dicts when ``epoch % checkpoint_every == 0``
            and after the last epoch. Must be >= 1.

    Returns:
        dict mapping each ``log_loss_*`` key to its list of per-epoch mean losses.

    Raises:
        ValueError: If ``checkpoint_every`` < 1 or the DataLoader yields no batches.
    """
    if checkpoint_every < 1:
        raise ValueError("checkpoint_every must be >= 1")

    torch.backends.cudnn.benchmark = True

    model_G = nn.DataParallel(Generator()).to(device)
    model_D = nn.DataParallel(Discriminator()).to(device)
    model_G.train()
    model_D.train()

    opt_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()

    loss_keys = ("log_loss_G_sum", "log_loss_G_bce", "log_loss_G_mae", "log_loss_D")
    result = {key: [] for key in loss_keys}

    model_dir = os.path.join(out_dir, "models")
    os.makedirs(model_dir, exist_ok=True)

    train_loader = load_datasets()
    if len(train_loader) == 0:
        raise ValueError("Training DataLoader yields no batches; check the dataset directories.")

    for epoch in range(num_epochs):
        epoch_logs = {key: [] for key in loss_keys}

        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
            for input_gray, real_color in train_loader:
                input_gray = input_gray.to(device)
                real_color = real_color.to(device)

                # --- Generator update: fool D while staying close to the target in L1. ---
                fake_color = model_G(input_gray)
                # Detached copy is reused for D's update and for the saved sample grid.
                fake_color_detached = fake_color.detach()
                fake_logits = model_D(torch.cat([fake_color, input_gray], dim=1))
                # Labels are shaped from D's output, so any batch size / image size works.
                loss_G_bce = bce_loss(fake_logits, torch.ones_like(fake_logits))
                loss_G_mae = lambda_l1 * mae_loss(fake_color, real_color)
                loss_G_sum = loss_G_bce + loss_G_mae

                # This backward also fills D's grads; they are cleared before D's own step.
                opt_D.zero_grad()
                opt_G.zero_grad()
                loss_G_sum.backward()
                opt_G.step()

                # --- Discriminator update: real pairs -> 1, generated pairs -> 0. ---
                real_logits = model_D(torch.cat([real_color, input_gray], dim=1))
                loss_D_real = bce_loss(real_logits, torch.ones_like(real_logits))
                fake_logits_D = model_D(torch.cat([fake_color_detached, input_gray], dim=1))
                loss_D_fake = bce_loss(fake_logits_D, torch.zeros_like(fake_logits_D))
                loss_D = loss_D_real + loss_D_fake

                opt_D.zero_grad()
                opt_G.zero_grad()
                loss_D.backward()
                opt_D.step()

                for key, loss in zip(loss_keys, (loss_G_sum, loss_G_bce, loss_G_mae, loss_D)):
                    epoch_logs[key].append(loss.item())

                pbar.update(1)
                # Progress-bar keys drop the "log_" prefix, e.g. "log_loss_D" -> "loss_D".
                pbar.set_postfix({key[len("log_"):]: statistics.mean(vals) for key, vals in epoch_logs.items()})

        for key in loss_keys:
            result[key].append(statistics.mean(epoch_logs[key]))

        # Sample grids come from the epoch's last batch (may be smaller than batch_size).
        _save_epoch_samples(out_dir, epoch, input_gray, fake_color_detached, real_color)

        if epoch % checkpoint_every == 0 or epoch == num_epochs - 1:
            # Saved from the DataParallel wrappers, so every key carries a "module." prefix.
            torch.save(model_G.state_dict(), os.path.join(model_dir, f"gen_{epoch:03}.pt"))
            torch.save(model_D.state_dict(), os.path.join(model_dir, f"dis_{epoch:03}.pt"))

    with open(os.path.join(out_dir, "logs.pkl"), "wb") as fp:
        pickle.dump(result, fp)

    _plot_history(result)
    return result

# Inside a notebook kernel __name__ is "__main__", so executing this cell
# starts the full training run immediately.
if __name__ == "__main__":
    train()


Let's use 2 GPUs!


Epoch 2/100:  93%|█████████▎| 116/125 [09:16<00:43,  4.81s/it, loss_G_sum=32, loss_G_bce=1.09, loss_G_mae=30.9, loss_D=0.871]  