In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
import cv2
import os
import torch
import albumentations as A
import imageio.v2 as imageio
import matplotlib.pyplot as plt
from dataset import *
from model import *
from torchsummary import summary
from torch_snippets import *
DF_PATH = "metadata.csv"
DEVICE= torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv(DF_PATH)

In [2]:
df_train = df[:3680]
len(df_train)

3680

In [3]:
train_dataset = SAR2OpticalDataset(df_train, train=True, optical_rgb=False , device=DEVICE)
train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=False)

In [4]:
generator = GeneratorUNet(1,1).to(DEVICE)
discriminator = Discriminator(1).to(DEVICE)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
lambda_pixel = 200

### Load Saved Models

In [5]:
# generator = GeneratorUNet(1,1).to(DEVICE)
# discriminator = Discriminator(1).to(DEVICE)
# epochs_performed = 2
# model_path = f'models/generator_model_{epochs_performed}_epochs.pth'
# check_point = torch.load(model_path)
# generator.load_state_dict(check_point['model_state_dict'])
# g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [6]:
def discriminator_train_step(real_src, real_trg, fake_trg):
    d_optimizer.zero_grad()

    prediction_real = discriminator(real_trg, real_src)
    error_real = criterion_GAN(prediction_real, torch.ones(len(real_src), 1, 32, 32).cuda())
    error_real.backward()

    prediction_fake = discriminator(fake_trg.detach(), real_src)
    error_fake = criterion_GAN(prediction_fake, torch.zeros(len(real_src), 1, 32, 32).cuda())
    error_fake.backward()

    d_optimizer.step()

    return error_real + error_fake

def generator_train_step(real_src, fake_trg):
    g_optimizer.zero_grad()
    prediction = discriminator(fake_trg, real_src)

    loss_GAN = criterion_GAN(prediction, torch.ones(len(real_src), 1, 32, 32).cuda())
    loss_pixel = criterion_pixelwise(fake_trg, real_trg)
    loss_G = loss_GAN + lambda_pixel * loss_pixel

    loss_G.backward()
    g_optimizer.step()
    return loss_G


In [7]:
epochs = 10
log = Report(epochs)

for epoch in range(epochs):
    N = len(train_dl)
    for bx, batch in enumerate(train_dl):
        real_src, real_trg = batch
        fake_trg = generator(real_src)
        
        errD = discriminator_train_step(real_src, real_trg, fake_trg)
        errG = generator_train_step(real_src, fake_trg)
        log.record(pos=epoch+(1+bx)/N, errD=errD.item(), errG=errG.item(), end='\r')

    log.report_avgs(epoch+1)

EPOCH: 1.000  errD: 0.824  errG: 114.631  (2679.96s - 24119.68s remaining)
EPOCH: 2.000  errD: 0.213  errG: 107.382  (5193.87s - 20775.48s remaining)
EPOCH: 3.000  errD: 0.189  errG: 100.081  (7688.35s - 17939.49s remaining)
EPOCH: 4.000  errD: 0.130  errG: 94.628  (10183.74s - 15275.61s remaining))
EPOCH: 5.000  errD: 0.107  errG: 91.355  (12634.36s - 12634.36s remaining))
EPOCH: 6.000  errD: 0.122  errG: 89.074  (15060.54s - 10040.36s remaining))
EPOCH: 7.000  errD: 0.085  errG: 87.061  (17493.01s - 7497.00s remaining)))
EPOCH: 8.000  errD: 0.079  errG: 85.609  (19917.15s - 4979.29s remaining))
EPOCH: 9.000  errD: 0.065  errG: 84.434  (22326.58s - 2480.73s remaining))
EPOCH: 10.000  errD: 0.109  errG: 83.141  (24744.53s - 0.00s remaining))))


### Save Generator Model

In [9]:
epochs_performed =10
model_path = f'models/generator_model_{epochs_performed}_epochs.pth'
torch.save({
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': g_optimizer.state_dict(),
            'loss': errG,
            }, model_path)

In [11]:
df_val = df[3680:]
val_dataset = SAR2OpticalDataset(df_val, train=True, optical_rgb=False , device=DEVICE)
val_dl = DataLoader(val_dataset, batch_size=20, shuffle=True, drop_last=False)

In [13]:
batch = next(iter(val_dl))
real_src, real_trg = batch
generator.eval()
fake_trg = generator(real_src)
sar_ds_mean = 12.29
sar_ds_st = 5.27
b8_ds_mean = 132.99
b8_ds_st = 37.32

for i in range(20):
    sar_image = real_src[i].detach().cpu().permute(1,2,0).numpy()
    real_b8 = real_trg[i].detach().cpu().permute(1,2,0).numpy()
    fake_b8 = fake_trg[i].detach().cpu().permute(1,2,0).numpy()
    denorm_sar_image = (sar_image * sar_ds_st) + sar_ds_mean
    denorm_real_b8 = (real_b8 * b8_ds_st) + b8_ds_mean
    denorm_fake_b8 = (fake_b8 * b8_ds_st) + b8_ds_mean
    plt.figure(figsize=(20,60))
    plt.subplot(1,3,1)
    plt.title("SAR Image")
    plt.imshow(denorm_sar_image , cmap='gray')
    plt.subplot(1,3,2)
    plt.title("Real Optical B8")
    plt.imshow(denorm_real_b8 , cmap='gray')
    plt.subplot(1,3,3)
    plt.title("Generated Optical B8 ")
    plt.imshow(denorm_fake_b8 , cmap='gray')
    plt.show()

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.50 GiB (GPU 0; 6.00 GiB total capacity; 10.58 GiB already allocated; 0 bytes free; 11.05 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [14]:
generator.state_dict()

OrderedDict([('down1.model.0.weight',
              tensor([[[[-2.8490e-03, -2.6163e-02, -5.3768e-02,  2.6750e-02],
                        [-1.9640e-02,  4.4427e-02,  3.1913e-02,  1.5955e-02],
                        [ 1.8465e-02, -9.3049e-03,  7.8714e-03, -1.1878e-02],
                        [-1.0525e-03,  1.5362e-02,  1.1701e-03, -1.2856e-02]]],
              
              
                      [[[ 1.4695e-02, -3.5305e-02,  6.6357e-03,  2.0663e-02],
                        [-3.7254e-02, -9.1089e-03, -1.1082e-03,  2.2159e-02],
                        [-7.5486e-03, -3.5186e-02,  1.3894e-02,  3.7486e-02],
                        [ 3.9548e-02,  3.3676e-03,  7.8851e-03, -1.2043e-02]]],
              
              
                      [[[ 9.9287e-04, -2.5859e-02, -9.6056e-03,  3.9714e-02],
                        [ 1.9939e-02,  3.9842e-02,  4.1848e-02, -8.4238e-03],
                        [ 1.3302e-02,  1.9267e-02,  1.0345e-02,  1.9936e-02],
                        [ 1.0001e-02,  9

In [15]:
g_optimizer.state_dict()

{'state': {0: {'step': tensor(2300.),
   'exp_avg': tensor([[[[-1.6371, -1.3110, -1.2229, -1.0374],
             [-1.2427, -1.1632, -1.3485, -1.1360],
             [-1.1421, -0.9387, -1.0483, -1.1450],
             [-1.1873, -0.8928, -1.0473, -1.1351]]],
   
   
           [[[-0.2939, -0.5281, -0.5134, -0.5272],
             [-0.2611, -0.3875, -0.5041, -0.4135],
             [-0.3274, -0.3423, -0.5538, -0.4449],
             [-0.3458, -0.3674, -0.4095, -0.2833]]],
   
   
           [[[-0.3720, -0.4002, -0.3568, -0.4193],
             [-0.2950, -0.3566, -0.3082, -0.2610],
             [-0.2143, -0.2229, -0.2540, -0.2133],
             [-0.2861, -0.2928, -0.2902, -0.4119]]],
   
   
           ...,
   
   
           [[[-0.4731, -0.3813, -0.6071, -0.5842],
             [-0.4151, -0.3608, -0.6133, -0.5970],
             [-0.2881, -0.1835, -0.6706, -0.5774],
             [-0.4174, -0.3395, -0.7085, -0.5533]]],
   
   
           [[[-2.3841, -2.0171, -2.1829, -2.1210],
             [-2.295