In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
import cv2
import os
import torch
import albumentations as A
import imageio.v2 as imageio
import matplotlib.pyplot as plt
from dataset import *
from model import *
from torchsummary import summary
from torch_snippets import *
DF_PATH = "metadata.csv"
DEVICE= torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv(DF_PATH)

In [2]:
df_train = df[:3680]
len(df_train)

3680

In [3]:
train_dataset = SAR2OpticalDataset(df_train, train=True, optical_rgb=True , device=DEVICE)
train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=False)

In [4]:
generator = GeneratorUNet(1,3).to(DEVICE)
discriminator = Discriminator(1).to(DEVICE)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
lambda_pixel = 200

### Load Saved Models

In [5]:
# generator = GeneratorUNet(1,1).to(DEVICE)
# discriminator = Discriminator(1).to(DEVICE)
# epochs_performed = 2
# model_path = f'models/generator_model_{epochs_performed}_epochs.pth'
# check_point = torch.load(model_path)
# generator.load_state_dict(check_point['model_state_dict'])
# g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [5]:
def discriminator_train_step(real_src, real_trg, fake_trg):
    d_optimizer.zero_grad()

    prediction_real = discriminator(real_trg, real_src)
    error_real = criterion_GAN(prediction_real, torch.ones(len(real_src), 1, 32, 32).cuda())
    error_real.backward()

    prediction_fake = discriminator(fake_trg.detach(), real_src)
    error_fake = criterion_GAN(prediction_fake, torch.zeros(len(real_src), 1, 32, 32).cuda())
    error_fake.backward()

    d_optimizer.step()

    return error_real + error_fake

def generator_train_step(real_src, fake_trg):
    g_optimizer.zero_grad()
    prediction = discriminator(fake_trg, real_src)

    loss_GAN = criterion_GAN(prediction, torch.ones(len(real_src), 1, 32, 32).cuda())
    loss_pixel = criterion_pixelwise(fake_trg, real_trg)
    loss_G = loss_GAN + lambda_pixel * loss_pixel

    loss_G.backward()
    g_optimizer.step()
    return loss_G


In [6]:
epochs = 1
log = Report(epochs)

for epoch in range(epochs):
    N = len(train_dl)
    for bx, batch in enumerate(train_dl):
        real_src, real_trg = batch
        fake_trg = generator(real_src)
        
        errD = discriminator_train_step(real_src, real_trg, fake_trg)
        errG = generator_train_step(real_src, fake_trg)
        log.record(pos=epoch+(1+bx)/N, errD=errD.item(), errG=errG.item(), end='\r')

    log.report_avgs(epoch+1)

RuntimeError: Given groups=1, weight of size [64, 2, 4, 4], expected input[16, 4, 512, 512] to have 2 channels, but got 4 channels instead

### Save Generator Model

In [None]:
epochs_performed = 
model_path = f'models/generator_model_{epochs_performed}_epochs.pth'
torch.save({
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': g_optimizer.state_dict(),
            'loss': errG,
            }, model_path)