In [None]:
import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import os
import torch
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
from torchvision.utils import save_image
from skimage import transform

In [None]:
class Load_data(Dataset):
    """Paired SAR -> RGB image dataset.

    Every file in ``root`` is treated as a SAR image. Its paired RGB target is found by
    rewriting the SAR path: ``'/s1'`` -> ``'/s2'`` and ``'_s1_'`` -> ``'_s2_'``.

    Each item is a ``(SAR_image, RGB_image)`` pair of float tensors. Both are resized to
    256x256 and laid out channel-first by ToTensorV2. Pixel values are scaled from [0, 255]
    to [-1, 1], the range of the generator's Tanh output. The same random horizontal flip
    is applied to both images of a pair.

    Args:
        root: directory containing the SAR images.
    """

    def __init__(self, root):
        self.root = root
        # Sorted so that sample order does not depend on the filesystem's listing order.
        self.n_samples = sorted(os.listdir(self.root))
        self.img_conversion = A.Compose(
            [
                A.Resize(width=256, height=256),
                # (pixel - 127.5) / 127.5 maps uint8 [0, 255] to [-1, 1]: the range the
                # generator's Tanh produces and that `* 0.5 + 0.5` undoes when saving images.
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
                ToTensorV2(),
            ]
        )

    def __len__(self):
        return len(self.n_samples)

    @staticmethod
    def _read_rgb(path):
        """Load an image file as an (H, W, 3) uint8 array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.asarray(img.convert('RGB'))

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_name = self.n_samples[idx]
        SAR_image_path = os.path.join(self.root, image_name)
        # NOTE: str.replace rewrites every '/s1' in the path, including the one in `root`
        # itself; that is how the sibling s2 directory is located.
        RGB_image_path = SAR_image_path.replace('/s1', '/s2').replace("_s1_", "_s2_")

        SAR_image = self.img_conversion(image=self._read_rgb(SAR_image_path))["image"].float()
        RGB_image = self.img_conversion(image=self._read_rgb(RGB_image_path))["image"].float()

        # Horizontal-flip augmentation, applied to both images so the pair stays aligned.
        # Tensors are (C, H, W) here, so width is the last dim (np.fliplr would flip H).
        # torch's RNG is reseeded per DataLoader worker, so workers don't share draws.
        if torch.rand(1).item() < 0.5:
            SAR_image = torch.flip(SAR_image, dims=[-1])
            RGB_image = torch.flip(RGB_image, dims=[-1])

        return (SAR_image, RGB_image)

In [None]:
def save_some_examples(gen, val_loader, epoch, folder):
    """Save one batch of generator outputs, inputs and targets as PNG image grids.

    Args:
        gen: generator module. It is switched to eval mode for inference and always put
            back into train mode afterwards, even if saving fails.
        val_loader: iterable of (input, target) batches; only the first batch is used.
        epoch: epoch index, embedded in the output file names.
        folder: output directory; created if it does not already exist.
    """
    os.makedirs(folder, exist_ok=True)
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            # Map [-1, 1] back to [0, 1] for save_image.
            # Assumes inputs/targets were normalised to [-1, 1]; confirm against the dataset.
            save_image(y_fake * 0.5 + 0.5, os.path.join(folder, f"y_gen_{epoch}.png"))
            save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write model and optimizer state to ``filename`` with ``torch.save``.

    The saved object is a dict with two keys:
      * ``"state_dict"``: ``model.state_dict()``
      * ``"optimizer"``: ``optimizer.state_dict()``
    """
    print("=> Saving checkpoint")
    model_state = model.state_dict()
    optim_state = optimizer.state_dict()
    torch.save({"state_dict": model_state, "optimizer": optim_state}, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model/optimizer state from a file written by save_checkpoint.

    Tensors are mapped onto the module-level ``DEVICE``. After the optimizer state is
    loaded, every param group's learning rate is overwritten with ``lr``.

    Args:
        checkpoint_file: path to the checkpoint.
        model: module receiving ``checkpoint["state_dict"]``.
        optimizer: optimizer receiving ``checkpoint["optimizer"]``.
        lr: learning rate to force onto all param groups.

    NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The optimizer state carries the lr it was saved with. Override it so the currently
    # configured rate is used when resuming.
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
class Block(nn.Module):
    """Two (3x3 conv -> ReLU -> Dropout(0.5)) stages.

    Reflect padding of 1 keeps H and W unchanged; the channel count goes from
    ``in_channels`` to ``out_channels`` in the first conv.

    NOTE(review): the convs use bias=False with no normalisation layer after them, which is
    unusual. Left as-is so existing checkpoints stay compatible.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for c_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(c_in, out_channels, kernel_size=(3, 3), padding=1, bias=False, padding_mode="reflect"),
                nn.ReLU(),
                nn.Dropout(0.5),
            ]
        # A single nn.Sequential named `conv` keeps state_dict keys as conv.0.* / conv.3.*.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)



class Generator(nn.Module):
    """Nested U-Net (UNet++-style) image-to-image generator.

    Node ``X{i}_{j}`` sits at depth ``i`` (1 = full resolution, 5 = four 2x poolings down)
    and grid column ``j``. Each decoder node concatenates two things on the channel axis:
      * the upsampled node from one level deeper;
      * every earlier node on its own level.
    The final full-resolution node ``X1_5`` is projected back to ``in_channels`` channels
    and passed through Tanh, so outputs lie in [-1, 1].

    Args:
        in_channels: channels of the input image; also the number of output channels.
        nb_filter: feature widths for depths 1..5 (only indices 0-4 are read).

    Note: there are four 2x max-pools, so H and W should be divisible by 16. Otherwise the
    upsampled and skip tensors will presumably disagree in size at torch.cat.
    """

    def __init__(self, in_channels=3, nb_filter=[8, 16, 32, 64, 128]):
        super().__init__()
        # NOTE(review): `nb_filter` has a mutable list default; it is only read, so harmless.
        f0, f1, f2, f3, f4 = nb_filter[0], nb_filter[1], nb_filter[2], nb_filter[3], nb_filter[4]

        def down():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        def up(c_in, c_out):
            # 2x2 stride-2 transposed conv: doubles H and W.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=(2, 2), stride=(2, 2), bias=False)

        # Submodules are created in exactly this order. It fixes parameters() order, which
        # optimizer state dicts index by, and the sequence of random initialisations.

        # Diagonal 1: X1_1, X2_1 -> X1_2
        self.conv1_1 = Block(in_channels, f0)
        self.pool1_1 = down()
        self.conv2_1 = Block(f0, f1)
        self.pool2_1 = down()
        self.up1_2 = up(f1, f0)
        self.conv1_2 = Block(f0 * 2, f0)

        # Diagonal 2: X3_1 -> X2_2 -> X1_3
        self.conv3_1 = Block(f1, f2)
        self.pool3_1 = down()
        self.up2_2 = up(f2, f1)
        self.conv2_2 = Block(f1 * 2, f1)
        self.up1_3 = up(f1, f0)
        self.conv1_3 = Block(f0 * 3, f0)

        # Diagonal 3: X4_1 -> X3_2 -> X2_3 -> X1_4
        self.conv4_1 = Block(f2, f3)
        self.pool4_1 = down()
        self.up3_2 = up(f3, f2)
        self.conv3_2 = Block(f2 * 2, f2)
        self.up2_3 = up(f2, f1)
        self.conv2_3 = Block(f1 * 3, f1)
        self.up1_4 = up(f1, f0)
        self.conv1_4 = Block(f0 * 4, f0)

        # Diagonal 4: X5_1 (deepest) -> X4_2 -> X3_3 -> X2_4 -> X1_5
        self.conv5_1 = Block(f3, f4)
        self.up4_2 = up(f4, f3)
        self.conv4_2 = Block(f3 * 2, f3)
        self.up3_3 = up(f3, f2)
        self.conv3_3 = Block(f2 * 3, f2)
        self.up2_4 = up(f2, f1)
        self.conv2_4 = Block(f1 * 4, f1)
        self.up1_5 = up(f1, f0)
        self.conv1_5 = Block(f0 * 5, f0)

        # Project back to image channels; Tanh bounds the output to [-1, 1].
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f0, in_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Nodes are evaluated diagonal by diagonal: each new encoder level, then the decoder
        # nodes it unlocks. Block (and therefore Dropout) calls happen in the same order as
        # before. Concatenation order is significant: it fixes which weights see which channels.
        x1_1 = self.conv1_1(x)
        x2_1 = self.conv2_1(self.pool1_1(x1_1))
        x1_2 = self.conv1_2(torch.cat([self.up1_2(x2_1), x1_1], 1))

        x3_1 = self.conv3_1(self.pool2_1(x2_1))
        x2_2 = self.conv2_2(torch.cat([self.up2_2(x3_1), x2_1], 1))
        x1_3 = self.conv1_3(torch.cat([self.up1_3(x2_2), x1_1, x1_2], 1))

        x4_1 = self.conv4_1(self.pool3_1(x3_1))
        x3_2 = self.conv3_2(torch.cat([self.up3_2(x4_1), x3_1], 1))
        x2_3 = self.conv2_3(torch.cat([self.up2_3(x3_2), x2_1, x2_2], 1))
        x1_4 = self.conv1_4(torch.cat([self.up1_4(x2_3), x1_1, x1_2, x1_3], 1))

        x5_1 = self.conv5_1(self.pool4_1(x4_1))
        x4_2 = self.conv4_2(torch.cat([self.up4_2(x5_1), x4_1], 1))
        x3_3 = self.conv3_3(torch.cat([self.up3_3(x4_2), x3_1, x3_2], 1))
        x2_4 = self.conv2_4(torch.cat([self.up2_4(x3_3), x2_1, x2_2, x2_3], 1))
        x1_5 = self.conv1_5(torch.cat([self.up1_5(x2_4), x1_1, x1_2, x1_3, x1_4], 1))

        return self.final_up(x1_5)


In [None]:
class CNN_Block(nn.Module):
    """Conv(4x4) -> BatchNorm -> LeakyReLU(0.2) stage used inside the Discriminator.

    Args:
        in_channels: input channels of the conv.
        out_channels: output channels of the conv (and BatchNorm width).
        stride: conv stride. With the default padding, stride 2 halves H/W and stride 1
            shrinks them by one pixel.
        padding: reflect padding per side. It defaults to 1 so that
            ``padding_mode="reflect"`` actually takes effect. With padding 0 the reflect mode
            was dead configuration and every stage shrank the feature map more than intended.
            Pass ``padding=0`` to reproduce the old output sizes. The change adds no
            parameters, so checkpoints remain loadable.
    """

    def __init__(self, in_channels, out_channels, stride=2, padding=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, padding, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """Conditional PatchGAN-style discriminator.

    The conditioning image ``x`` and the candidate image ``y`` are concatenated on the
    channel axis, so the first conv sees ``in_channels * 2`` channels. The output is a
    spatial map of raw logits with one channel: there is no sigmoid, so pair it with
    ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: channels of each of the two input images.
        features: widths of the initial conv and of each following CNN_Block. Every block
            downsamples with stride 2 except the one for the last entry, which uses stride 1.
    """

    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for i, feature in enumerate(features[1:], start=1):
            # Select the stride by position, not by value: comparing `feature == features[-1]`
            # would give stride 1 to every stage whose width repeats the last one.
            layers.append(CNN_Block(prev_channels, feature, stride=1 if i == last_index else 2))
            prev_channels = feature
        # Final projection to a single-channel logit map.
        layers.append(
            nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score ``y`` given ``x``.

        Args:
            x: conditioning (input) image batch.
            y: real target or generator output, with the same spatial size as ``x``.

        Returns:
            Per-patch logits with shape (N, 1, H', W').
        """
        xy = torch.cat([x, y], dim=1)
        return self.model(self.initial(xy))

In [None]:
# Let cuDNN benchmark and cache the fastest conv algorithms. This pays off here because
# every image is resized to the same fixed size.
torch.backends.cudnn.benchmark = True

# Per-iteration loss histories. train() appends to them and the final cell plots them.
Gen_loss, Dis_loss = [], []

In [None]:
def train(netG, netD, train_dl, OptimizerG, OptimizerD, L1_Loss, BCE_Loss):
    """Run one epoch of conditional-GAN training over ``train_dl``.

    Each (x, y) batch gets two steps:
      1. Discriminator step: BCE on real pairs (x, y) -> 1 and on fake pairs
         (x, G(x).detach()) -> 0, averaged.
      2. Generator step: BCE on (x, G(x)) -> 1, plus ``L1_LAMBDA * L1(G(x), y)``.

    Per-iteration losses are appended to the module-level ``Gen_loss`` / ``Dis_loss`` lists.
    Batches are moved to the module-level ``DEVICE``.

    Args:
        netG: generator module.
        netD: discriminator module, called as ``netD(x, y)``.
        train_dl: iterable of (x, y) batches.
        OptimizerG: optimizer for the generator's parameters.
        OptimizerD: optimizer for the discriminator's parameters.
        L1_Loss: L1 criterion.
        BCE_Loss: BCE-with-logits criterion.
    """
    loop = tqdm(train_dl)
    for idx, (x, y) in enumerate(loop):
        # Follow the configured DEVICE instead of a hard-coded .cuda(), so CPU runs work too.
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # ---- Discriminator step ----
        y_fake = netG(x)
        D_real = netD(x, y)
        D_real_loss = BCE_Loss(D_real, torch.ones_like(D_real))
        # detach(): the discriminator step must not backpropagate into the generator.
        D_fake = netD(x, y_fake.detach())
        D_fake_loss = BCE_Loss(D_fake, torch.zeros_like(D_fake))
        D_loss = (D_real_loss + D_fake_loss) / 2

        # Clears any gradients that the previous generator step accumulated in netD.
        netD.zero_grad()
        Dis_loss.append(D_loss.item())
        D_loss.backward()
        OptimizerD.step()

        # ---- Generator step ----
        # Re-score the non-detached fake with the just-updated D so gradients reach G.
        D_fake = netD(x, y_fake)
        G_fake_loss = BCE_Loss(D_fake, torch.ones_like(D_fake))
        L1 = L1_Loss(y_fake, y) * L1_LAMBDA
        G_loss = G_fake_loss + L1

        OptimizerG.zero_grad()
        Gen_loss.append(G_loss.item())
        # Also writes gradients into netD's parameters. OptimizerG does not step them, and
        # netD.zero_grad() at the next discriminator step clears them.
        G_loss.backward()
        OptimizerG.step()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

In [None]:
def main():
    """Build the GAN, its optimizers and data loaders, then train for NUM_EPOCHS epochs.

    Reads the module-level configuration: DEVICE, LEARNING_RATE, BETA1, BATCH_SIZE,
    NUM_WORKERS, NUM_EPOCHS, TRAIN_DIR, VAL_DIR, LOAD_MODEL, SAVE_MODEL, CHECKPOINT_GEN and
    CHECKPOINT_DISC.

    Side effects:
      * when SAVE_MODEL is set, checkpoints are written every 50 epochs (starting at
        epoch 0) and after the final epoch;
      * example images are written to ./evaluation every 2 epochs.
    """
    netD = Discriminator(in_channels=3).to(DEVICE)
    netG = Generator(in_channels=3).to(DEVICE)
    OptimizerD = torch.optim.Adam(netD.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    OptimizerG = torch.optim.Adam(netG.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
    BCE_Loss = nn.BCEWithLogitsLoss()
    L1_Loss = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, netG, OptimizerG, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, netD, OptimizerD, LEARNING_RATE)

    train_dataset = Load_data(root=TRAIN_DIR)
    train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
    val_dataset = Load_data(root=VAL_DIR)
    val_dl = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=True)

    eval_folder = "evaluation"
    # save_image does not create directories, so make sure the output folder exists.
    os.makedirs(eval_folder, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        train(netG, netD, train_dl, OptimizerG, OptimizerD, L1_Loss, BCE_Loss)

        # Also save after the final epoch. `epoch % 50 == 0` alone would leave up to 49
        # trailing epochs of training unsaved (e.g. epochs 751-799 of 800).
        is_last_epoch = epoch == NUM_EPOCHS - 1
        if SAVE_MODEL and (epoch % 50 == 0 or is_last_epoch):
            save_checkpoint(netG, OptimizerG, filename=CHECKPOINT_GEN)
            save_checkpoint(netD, OptimizerD, filename=CHECKPOINT_DISC)
        if epoch % 2 == 0:
            save_some_examples(netG, val_dl, epoch, folder=eval_folder)

In [None]:
# ---- Hardware ----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Data ----
TRAIN_DIR = './Datasets/v_2/agri/s1'
# NOTE(review): VAL_DIR is the same directory as TRAIN_DIR, so the saved "validation"
# examples come from training data. TODO: point this at a held-out split.
VAL_DIR = './Datasets/v_2/agri/s1'
BATCH_SIZE = 15
NUM_WORKERS = 2
# NOTE(review): IMAGE_SIZE and CHANNELS_IMG are not read by any cell shown here; the
# dataset hard-codes a 256x256 resize and the models hard-code 3 channels.
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# ---- Optimisation ----
LEARNING_RATE = 2e-4
BETA1 = 0.5
NUM_EPOCHS = 800

# ---- Loss weights ----
L1_LAMBDA = 100
# NOTE(review): LAMBDA_GP (gradient-penalty weight) is not referenced by any cell shown here.
LAMBDA_GP = 10

# ---- Checkpointing ----
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

In [None]:
# Launch the full training run (NUM_EPOCHS epochs). This is long-running and needs the
# configuration cell above to have been executed first.
main()

In [None]:
# Loss curves recorded by train(). train() appends one value per training batch, so the
# x-axis counts iterations, not epochs.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(Gen_loss, label="Generator")
ax.plot(Dis_loss, label="Discriminator")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()