In [None]:
# Mount Google Drive so that paths under /content/drive (the dataset archive
# and the checkpoint directory used later in this notebook) are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip /content/drive/MyDrive/sardata_small.zip -d /content/

In [5]:
import os
# import cv2
import torch

import pandas as pd
# import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt

from torch import manual_seed
# from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
# from torchvision.io import decode_image
from PIL import Image
from torchvision import transforms
from torch import nn
from torchvision.models import vgg19, VGG19_Weights
# from tqdm import tqdm

In [6]:
# --- Dataset locations (relative to the current working directory) ---
CSV_FILE_PATH = './sardata_small.csv'
IMAGE_DIR_SAR = './sardata_small/s1'
IMAGE_DIR_COL = './sardata_small/s2'

# --- Data pipeline settings ---
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 16

# --- Reproducibility: seed torch's global random number generator ---
SEED = 42
manual_seed(SEED)

# --- Compute device: CUDA when available, CPU otherwise ---
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"device: {DEVICE}")

device: cuda


In [7]:
# Load the metadata table. SarColorDataset reads its 's1_image', 's2_image'
# and 'type' columns.
data_df = pd.read_csv(CSV_FILE_PATH)

# Optional stratified train/test split, currently disabled (the full table is
# passed to the dataset):
# train_df, test_df = train_test_split(data_df, test_size=0.15, random_state=SEED, shuffle=True, stratify=data_df['type'])
# print(train_df.groupby('type').count())
# print(test_df.groupby('type').count())

In [8]:
class SarColorDataset(Dataset):
    """Paired SAR / colour image dataset driven by a metadata table.

    Every row of ``data_df`` must provide three columns:
    ``'s1_image'`` (file name inside ``image_dir_sar``), ``'s2_image'``
    (file name inside ``image_dir_col``) and ``'type'`` (returned as the
    label without modification).

    Args:
        data_df: pandas DataFrame with one row per image pair.
        image_dir_sar: Directory holding the SAR images.
        image_dir_col: Directory holding the colour images.
        transform_sar: Optional callable applied to the loaded SAR image.
        transform_col: Optional callable applied to the loaded colour image.
    """

    def __init__(self, data_df, image_dir_sar, image_dir_col, transform_sar=None, transform_col=None):
        self.data_df = data_df
        self.image_dir_sar = image_dir_sar
        self.image_dir_col = image_dir_col
        self.transform_sar = transform_sar
        self.transform_col = transform_col

    def __len__(self):
        """Return the number of image pairs (rows of the metadata table)."""
        return len(self.data_df)

    @staticmethod
    def _load_image(path, mode):
        """Open ``path``, convert it to PIL mode ``mode`` and close the file."""
        # The converted image is returned from inside the with-block so the
        # file handle opened by Image.open is released deterministically
        # instead of waiting for garbage collection.
        with Image.open(path) as image:
            return image.convert(mode)

    def __getitem__(self, index):
        """Return ``(sar_image, colour_image, label)`` for the row at ``index``.

        The SAR image is loaded as single-channel ('L') and the colour image
        as 'RGB'; each is passed through its transform when one was given.
        """
        row = self.data_df.iloc[index]
        label = row['type']

        image_path_sar = os.path.join(self.image_dir_sar, row['s1_image'])
        image_path_col = os.path.join(self.image_dir_col, row['s2_image'])

        image_sar = self._load_image(image_path_sar, 'L')
        image_col = self._load_image(image_path_col, 'RGB')

        if self.transform_sar:
            image_sar = self.transform_sar(image_sar)

        if self.transform_col:
            image_col = self.transform_col(image_col)

        return image_sar, image_col, label

In [9]:
def _make_transform(num_channels):
    """Build the preprocessing pipeline for an image with ``num_channels`` channels.

    The pipeline resizes to IMAGE_SIZE (bilinear), converts to a tensor and
    normalises every channel with mean 0.5 and std 0.5.
    """
    channel_stats = (0.5,) * num_channels
    return transforms.Compose([
        transforms.Resize(size=IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_stats, std=channel_stats, inplace=False),
    ])


# Single-channel SAR input and three-channel colour target share the same steps.
transform_sar = _make_transform(num_channels=1)
transform_col = _make_transform(num_channels=3)

In [10]:
# Wrap the metadata table and image folders in the paired dataset.
dataset = SarColorDataset(
    data_df,
    IMAGE_DIR_SAR,
    IMAGE_DIR_COL,
    transform_sar=transform_sar,
    transform_col=transform_col,
)

# Shuffled mini-batches of BATCH_SIZE pairs.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [11]:
# Number of batches per epoch.
len(dataloader)

125

In [12]:
class DownSample(nn.Module):
    """Convolutional down-sampling block: Conv2d -> optional norm -> LeakyReLU(0.2).

    Args:
        inp_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size, stride, padding: Forwarded to the convolution.
        use_bias: Whether the convolution may have a bias; it is only honoured
            when ``normalization`` is falsy.
        normalization: ``'batch'`` adds BatchNorm2d, ``'instance'`` adds
            InstanceNorm2d; any other value adds no normalisation layer.
    """

    def __init__(self, inp_c, out_c, kernel_size=4, stride=2, padding=1, use_bias=True, normalization='batch'):
        super().__init__()

        # The bias is dropped whenever a normalisation mode was requested.
        conv_has_bias = bool(use_bias) and not normalization

        layers = [
            nn.Conv2d(
                in_channels=inp_c,
                out_channels=out_c,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=conv_has_bias,
            ),
        ]

        if normalization == 'batch':
            layers.append(nn.BatchNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True))
        elif normalization == 'instance':
            layers.append(nn.InstanceNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=False, track_running_stats=False))

        layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=False))

        # Registered under the same attribute name and layer order as before,
        # so parameter names in state dicts stay the same.
        self.down = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.down(x)

class UpSample(nn.Module):
    """Transposed-convolution up-sampling block with a skip connection.

    Layer order: ConvTranspose2d -> optional norm -> optional Dropout -> ReLU.
    ``forward`` concatenates the block output with ``skip`` along the channel
    dimension, so the result has ``out_c`` plus ``skip``'s channels.

    Args:
        inp_c: Number of input channels.
        out_c: Number of channels produced by the transposed convolution.
        kernel_size, stride, padding: Forwarded to the transposed convolution.
        use_bias: Whether the convolution may have a bias; it is only honoured
            when ``normalization`` is falsy.
        normalization: ``'batch'`` adds BatchNorm2d, ``'instance'`` adds
            InstanceNorm2d; any other value adds no normalisation layer.
        apply_dropout: Insert a Dropout layer before the activation.
        dropout_rate: Drop probability used when ``apply_dropout`` is true.
    """

    def __init__(self, inp_c, out_c, kernel_size=4, stride=2, padding=1, use_bias = True, normalization='batch', apply_dropout=False, dropout_rate=0.5):
        super(UpSample, self).__init__()

        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=inp_c, out_channels=out_c, kernel_size=kernel_size, stride=stride, padding=padding, bias=(not normalization) and use_bias),
        )

        if (normalization == 'batch'):
            self.up.append(nn.BatchNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True))
        # The branch used to test for the misspelled 'isinstance', so asking
        # for 'instance' (the name DownSample uses) silently produced a block
        # with neither normalisation nor bias. The misspelling is still
        # accepted so existing callers keep working.
        elif (normalization in ('instance', 'isinstance')):
            self.up.append(nn.InstanceNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=False, track_running_stats=False))

        if apply_dropout:
            self.up.append(nn.Dropout(p=dropout_rate, inplace=False))

        self.up.append(nn.ReLU(inplace=False))

    def forward(self, x, skip):
        """Up-sample ``x`` and concatenate ``skip`` along dim 1 (channels)."""
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        return x

class Generator(nn.Module):
    """Encoder/decoder generator with skip connections.

    The encoder is eight stride-2 DownSample blocks; the decoder is seven
    UpSample blocks, each concatenating the matching encoder output, followed
    by a final bilinear up-sampling + convolution + Tanh head that produces
    ``out_channels`` channels.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()

        # Encoder: first and last stages are built without normalisation.
        inner_encoder_channels = [(64, 128), (128, 256), (256, 512), (512, 512), (512, 512), (512, 512)]
        encoder = [DownSample(inp_c=in_channels, out_c=64, normalization=None)]
        encoder.extend(DownSample(inp_c=c_in, out_c=c_out) for c_in, c_out in inner_encoder_channels)
        encoder.append(DownSample(inp_c=512, out_c=512, normalization=None))
        self.down_stack = nn.ModuleList(encoder)

        # Decoder: after the first stage every input is the previous output
        # concatenated with a skip tensor, hence the doubled channel counts.
        decoder_channels = [(512, 512), (1024, 512), (1024, 512), (1024, 512), (1024, 256), (512, 128), (256, 64)]
        self.up_stack = nn.ModuleList([UpSample(inp_c=c_in, out_c=c_out) for c_in, c_out in decoder_channels])

        # Output head: up-sample, pad asymmetrically so the 4x4 convolution
        # keeps the up-sampled size, then squash with Tanh.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ZeroPad2d((2, 1, 2, 1)),
            nn.Conv2d(in_channels=128, out_channels=out_channels, kernel_size=4, stride=1, padding=0, bias=True),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through encoder, decoder (with skips) and the output head."""
        encoder_outputs = []
        for down in self.down_stack:
            x = down(x)
            encoder_outputs.append(x)

        # The deepest encoder output is the decoder input, not a skip; the
        # remaining outputs are consumed deepest-first.
        skip_tensors = reversed(encoder_outputs[:-1])

        for up, skip in zip(self.up_stack, skip_tensors):
            x = up(x, skip)

        return self.final(x)

class Discriminator(nn.Module):
    """Conditional discriminator scoring a (SAR, colour) image pair.

    The two images are concatenated along the channel dimension and passed
    through four DownSample blocks and a final convolution that outputs a
    single-channel map of raw scores (no activation is applied here).

    Args:
        in_channels: Total channels of the concatenated input pair.
    """

    def __init__(self, in_channels=4):
        super().__init__()

        stages = [
            DownSample(inp_c=in_channels, out_c=64, normalization=None),
            DownSample(inp_c=64, out_c=128),
            DownSample(inp_c=128, out_c=256),
            # Stride 1 from here on: no further halving of the feature map.
            DownSample(inp_c=256, out_c=512, stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=True),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, sar_img, col_img):
        """Return the score map for ``sar_img`` paired with ``col_img``."""
        pair = torch.cat([sar_img, col_img], dim=1)
        return self.model(pair)

In [13]:
# Build both networks, then move each to the selected device.
generator = Generator(in_channels=1, out_channels=3)
generator = generator.to(DEVICE)

discriminator = Discriminator(in_channels=4)
discriminator = discriminator.to(DEVICE)

In [14]:
def initialize_weights(m):
    """Initialise convolutional layers in place; intended for ``Module.apply``.

    Conv2d and ConvTranspose2d weights are drawn from N(0, 0.02) and their
    biases (when present) are set to zero. Every other module type is left
    untouched.
    """
    is_conv_layer = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    if not is_conv_layer:
        return

    nn.init.normal_(m.weight, mean=0.0, std=0.02)

    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)

In [15]:
# Re-initialise every Conv2d / ConvTranspose2d inside the generator
# (weights ~ N(0, 0.02), biases zero); other layers keep their defaults.
generator.apply(initialize_weights)

Generator(
  (down_stack): ModuleList(
    (0): DownSample(
      (down): Sequential(
        (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): DownSample(
      (down): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): DownSample(
      (down): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (3): DownSample(
      (down): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, aff

In [16]:
# Re-initialise every Conv2d / ConvTranspose2d inside the discriminator
# (weights ~ N(0, 0.02), biases zero); other layers keep their defaults.
discriminator.apply(initialize_weights)

Discriminator(
  (model): Sequential(
    (0): DownSample(
      (down): Sequential(
        (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): DownSample(
      (down): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): DownSample(
      (down): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (3): DownSample(
      (down): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affi

In [None]:
# from torchsummary import summary

# summary(generator, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,088
         LeakyReLU-2         [-1, 64, 128, 128]               0
        DownSample-3         [-1, 64, 128, 128]               0
            Conv2d-4          [-1, 128, 64, 64]         131,072
       BatchNorm2d-5          [-1, 128, 64, 64]             256
         LeakyReLU-6          [-1, 128, 64, 64]               0
        DownSample-7          [-1, 128, 64, 64]               0
            Conv2d-8          [-1, 256, 32, 32]         524,288
       BatchNorm2d-9          [-1, 256, 32, 32]             512
        LeakyReLU-10          [-1, 256, 32, 32]               0
       DownSample-11          [-1, 256, 32, 32]               0
           Conv2d-12          [-1, 512, 16, 16]       2,097,152
      BatchNorm2d-13          [-1, 512, 16, 16]           1,024
        LeakyReLU-14          [-1, 512,

In [None]:
# summary(discriminator, input_size=[(1, 256, 256), (3, 256, 256)])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           4,160
         LeakyReLU-2         [-1, 64, 128, 128]               0
        DownSample-3         [-1, 64, 128, 128]               0
            Conv2d-4          [-1, 128, 64, 64]         131,072
       BatchNorm2d-5          [-1, 128, 64, 64]             256
         LeakyReLU-6          [-1, 128, 64, 64]               0
        DownSample-7          [-1, 128, 64, 64]               0
            Conv2d-8          [-1, 256, 32, 32]         524,288
       BatchNorm2d-9          [-1, 256, 32, 32]             512
        LeakyReLU-10          [-1, 256, 32, 32]               0
       DownSample-11          [-1, 256, 32, 32]               0
           Conv2d-12          [-1, 512, 31, 31]       2,097,152
      BatchNorm2d-13          [-1, 512, 31, 31]           1,024
        LeakyReLU-14          [-1, 512,

In [17]:
# Criteria: BCE on raw discriminator scores and mean absolute error.
bce_loss = nn.BCEWithLogitsLoss().to(device=DEVICE)
mae_loss = nn.L1Loss().to(device=DEVICE)

# Pretrained VGG19 truncated to the first 27 layers of `features`, switched to
# eval mode and frozen; it is only used to compare feature maps.
vgg_features = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
feature_extractor = vgg_features[:27].to(DEVICE).eval()
for frozen_param in feature_extractor.parameters():
    frozen_param.requires_grad = False

def extract_features(img):
    """Return the frozen feature extractor's output for ``img``.

    ``img`` is shifted with ``(img + 1) / 2`` and then normalised with the
    ImageNet channel means / stds before being passed to the extractor.
    """
    imagenet_norm = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    rescaled = (img + 1) / 2
    return feature_extractor(imagenet_norm(rescaled))

def generator_loss(generated, target, patch_fake, lambda_image=100.0, lambda_perceptual=20.0):
    """Combine adversarial, pixel and perceptual terms into the generator loss.

    Args:
        generated: Generator output.
        target: Ground-truth image.
        patch_fake: Discriminator scores for the generated image.
        lambda_image: Weight of the L1 term between ``generated`` and ``target``.
        lambda_perceptual: Weight of the L1 term between their extracted features.

    Returns:
        ``(total_loss, adversarial, pixel, perceptual)`` where ``total_loss`` is
        a tensor and the three components are Python floats.
    """
    # Adversarial term: scores for the fake are pushed towards the 0.9 label.
    real_label = torch.full_like(patch_fake, fill_value=0.9, device=DEVICE)
    adversarial = bce_loss(patch_fake, real_label)

    # Pixel-space reconstruction term.
    pixel = mae_loss(generated, target)

    # Feature-space (perceptual) term.
    features_generated = extract_features(generated)
    features_target = extract_features(target)
    perceptual = mae_loss(features_generated, features_target)

    total = adversarial + lambda_image * pixel + lambda_perceptual * perceptual

    return total, adversarial.item(), pixel.item(), perceptual.item()

def discriminator_loss(patch_fake, patch_valid):
    """Average the discriminator's BCE losses on real and generated pairs.

    Real scores are compared against a 0.9 label and fake scores against a
    0.1 label.

    Returns:
        ``(loss, valid_loss, fake_loss)`` where ``loss`` is a tensor equal to
        the mean of the two terms and the terms themselves are Python floats.
    """
    label_for_valid = torch.full_like(patch_valid, fill_value=0.9, device=DEVICE)
    label_for_fake = torch.full_like(patch_fake, fill_value=0.1, device=DEVICE)

    loss_on_valid = bce_loss(patch_valid, label_for_valid)
    loss_on_fake = bce_loss(patch_fake, label_for_fake)

    combined = 0.5 * (loss_on_valid + loss_on_fake)

    return combined, loss_on_valid.item(), loss_on_fake.item()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 174MB/s]


In [18]:
# Both optimisers use the same Adam betas; the discriminator is given a
# smaller learning rate than the generator.
ADAM_BETAS = (0.5, 0.999)

optimizer_G = optim.Adam(generator.parameters(), lr=2e-4, betas=ADAM_BETAS)
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-5, betas=ADAM_BETAS)

In [19]:
def visualize_results(sar, fake, col):
    """
    Show the SAR input, the generated image and the optical target in one row.

    Args:
        sar (torch.Tensor): Grayscale image (1, H, W), range [-1, 1].
        fake (torch.Tensor): RGB image (3, H, W), range [-1, 1].
        col (torch.Tensor): RGB image (3, H, W), range [-1, 1].
    """
    def to_displayable(tensor):
        # Clamp to [-1, 1], rescale to [0, 1] and hand over to NumPy.
        array = ((tensor.clamp(-1, 1) + 1) / 2).cpu().detach().numpy()
        if array.shape[0] == 1:
            # Single channel: drop the channel axis and use a gray colormap.
            return array[0], "gray"
        # Channels-first -> channels-last for imshow.
        return array.transpose(1, 2, 0), None

    panels = (("SAR", sar), ("Generated", fake), ("Optical", col))

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, tensor) in zip(axes, panels):
        image, cmap = to_displayable(tensor)
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)

    plt.show()

In [None]:
# --- Training schedule -------------------------------------------------------
START_EPOCH = 1
EPOCHS_TO_TRAIN = 100
LAST_EPOCH = START_EPOCH + EPOCHS_TO_TRAIN - 1
STEPS_PER_EPOCH = len(dataloader)

# --- Checkpointing -----------------------------------------------------------
CHECKPOINT_DIR = '/content/drive/MyDrive/SAR_COLOR/checkpoints2'
CHECKPOINT_EVERY = 5       # save a checkpoint every N epochs
CHECKPOINTS_TO_KEEP = 5    # number of most recent checkpoints left on disk
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Fail early with a clear message instead of a ZeroDivisionError after the
# first (empty) epoch.
if STEPS_PER_EPOCH == 0:
    raise ValueError("dataloader is empty; nothing to train on")

# Report progress about five times per epoch. max() keeps the modulus non-zero
# when an epoch has fewer than five batches.
LOG_EVERY = max(1, STEPS_PER_EPOCH // 5)

# Per-iteration loss histories (plotted after training).
loss_Gs = []
loss_Ds = []

for epoch in range(START_EPOCH, LAST_EPOCH + 1):

    generator.train()
    discriminator.train()

    epoch_loss_G = 0.0
    epoch_loss_D = 0.0

    print(f'\nEpoch [{epoch:03}/{LAST_EPOCH:03}]')
    print('-' * 52)

    for i, (sar, col, _) in enumerate(dataloader):

        sar = sar.to(DEVICE)
        col = col.to(DEVICE)

        fake = generator(sar)

        # ---- Discriminator step ----
        for param in discriminator.parameters():
            param.requires_grad = True

        optimizer_D.zero_grad()
        patch_valid = discriminator(sar, col)
        # detach(): the discriminator update must not back-propagate into the generator.
        patch_fake = discriminator(sar, fake.detach())
        loss_D, valid_loss_D, fake_loss_D = discriminator_loss(patch_fake, patch_valid)
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator step ----
        # Freeze the discriminator so this backward pass only accumulates
        # gradients in the generator.
        for param in discriminator.parameters():
            param.requires_grad = False

        optimizer_G.zero_grad()
        patch_fake = discriminator(sar, fake)
        loss_G, bce_loss_G, img_loss_G, per_loss_G = generator_loss(fake, col, patch_fake)
        loss_G.backward()
        optimizer_G.step()

        # Read each scalar from the device once and reuse it below.
        loss_G_value = loss_G.item()
        loss_D_value = loss_D.item()

        if (i + 1) % LOG_EVERY == 0:
            print(
                f"Batch [{i+1:03}/{STEPS_PER_EPOCH:03}] | "
                f"Loss_D: {loss_D_value:8.6f} Loss_G: {loss_G_value:9.6f} | "
                f"Valid_D: {valid_loss_D:8.6f} Fake_D: {fake_loss_D:8.6f} | "
                f"BCE_G: {bce_loss_G:8.6f} Img_G: {img_loss_G:8.6f} Per_G: {per_loss_G:8.6f}"
            )

        loss_Gs.append(loss_G_value)
        loss_Ds.append(loss_D_value)

        epoch_loss_D += loss_D_value
        epoch_loss_G += loss_G_value

    epoch_loss_D /= STEPS_PER_EPOCH
    epoch_loss_G /= STEPS_PER_EPOCH

    print('-' * 52)
    print(
        f"Epoch [{epoch:03}/{LAST_EPOCH:03}] | "
        f"Loss_D: {epoch_loss_D:8.6f} Loss_G: {epoch_loss_G:9.6f} (Epoch Average)\n"
    )

    if epoch % CHECKPOINT_EVERY == 0:
        # 'loss_G' / 'loss_D' hold the losses of the epoch's last batch.
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'loss_G': loss_G_value,
            'loss_D': loss_D_value,
        }
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch:03}.pth')
        torch.save(checkpoint, checkpoint_path)

        # Prune the checkpoint that just fell out of the retention window. It
        # may be absent (e.g. when training was resumed from a later
        # START_EPOCH), which must not abort the run.
        stale_epoch = epoch - CHECKPOINT_EVERY * CHECKPOINTS_TO_KEEP
        if stale_epoch > 0:
            checkpoint_path_del = os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{stale_epoch:03}.pth')
            if os.path.exists(checkpoint_path_del):
                os.remove(checkpoint_path_del)

# The loop leaves the discriminator frozen after its last generator step;
# restore trainable parameters so later cells see a consistent model.
for param in discriminator.parameters():
    param.requires_grad = True


 Epoch [001/100]
 ----------------------------------------------------
 Batch [025/125] | Loss_D: 0.751097 Loss_G: 40.942726 | Valid_D: 0.765656 Fake_D: 0.736539 | BCE_G: 0.787108 Img_G: 0.327095 Per_G: 0.372304
 Batch [050/125] | Loss_D: 0.714026 Loss_G: 40.251263 | Valid_D: 0.716103 Fake_D: 0.711949 | BCE_G: 0.787265 Img_G: 0.323352 Per_G: 0.356439
 Batch [075/125] | Loss_D: 0.645039 Loss_G: 39.539150 | Valid_D: 0.664948 Fake_D: 0.625130 | BCE_G: 0.866222 Img_G: 0.317622 Per_G: 0.345538
 Batch [100/125] | Loss_D: 0.550900 Loss_G: 38.552498 | Valid_D: 0.538051 Fake_D: 0.563748 | BCE_G: 0.959976 Img_G: 0.308306 Per_G: 0.338096
 Batch [125/125] | Loss_D: 0.562957 Loss_G: 39.643475 | Valid_D: 0.618244 Fake_D: 0.507670 | BCE_G: 1.059304 Img_G: 0.312107 Per_G: 0.368671
 ----------------------------------------------------
 Epoch [001/100] | Loss_D: 0.682885 Loss_G: 40.915904 (Epoch Average)
 

 Epoch [002/100]
 ----------------------------------------------------
 Batch [025/125] | Loss_D

In [None]:
# from torchviz import make_dot

# dot = make_dot(loss_D, params=dict(discriminator.named_parameters()))
# dot.render("gradient_flow_D", format="png")
# dot.view()

# dot = make_dot(loss_G, params=dict(generator.named_parameters()))
# dot.render("gradient_flow_G", format="png")
# dot.view()

In [None]:
# Generator loss per training iteration.
plt.figure(figsize=(10, 5))
plt.title('Generator Loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.plot(loss_Gs)
plt.show()

In [None]:
# Discriminator loss per training iteration.
plt.figure(figsize=(10, 5))
plt.title('Discriminator Loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.plot(loss_Ds)
plt.show()

In [None]:
# Disconnect from the Colab runtime once all cells above have run.
# NOTE(review): presumably this is meant to stop consuming compute after an
# unattended training run; anything not saved to Drive is lost — confirm.
from google.colab import runtime
runtime.unassign()