In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from PIL import Image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
class Down_ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.down_conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, padding_mode='reflect')  
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(negative_slope=0.2) 

    def forward(self, x):
        out = self.down_conv(x)
        out = self.bn(out)
        out = self.act(out) 

        return out
    

class Up_ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()
        self.do = nn.Dropout2d(p=0.5)

    def forward(self, x):
        out = self.up_conv(x)
        out = self.bn(out)
        out = self.act(out)
        out = self.do(out)
        
        return out
    

class UNet_Generator(nn.Module):
    def __init__(self, input_channels=3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.first_down = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1, bias=False, padding_mode='reflect'), 
            nn.LeakyReLU(negative_slope=0.2)
        )
        
        self.down2 = Down_ConvBlock(64, 128)
        self.down3 = Down_ConvBlock(128, 256)
        self.down4 = Down_ConvBlock(256, 512)
        self.down5 = Down_ConvBlock(512, 512)
        self.down6 = Down_ConvBlock(512, 512)
        self.down7 = Down_ConvBlock(512, 512)

        self.last_down = Down_ConvBlock(512, 512)

        self.first_up = Up_ConvBlock(512, 512)

        self.up2 = Up_ConvBlock(512*2, 512)
        self.up3 = Up_ConvBlock(512*2, 512)
        self.up4 = Up_ConvBlock(512*2, 512)
        self.up5 = Up_ConvBlock(512*2, 256)
        self.up6 = Up_ConvBlock(256*2, 128)
        self.up7 = Up_ConvBlock(128*2, 64)

        self.last_up = nn.Sequential(
            nn.ConvTranspose2d(64*2, input_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        down1 = self.first_down(x)

        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        down5 = self.down5(down4)
        down6 = self.down6(down5)
        down7 = self.down7(down6)

        bottleneck = self.last_down(down7)
        
        up1 = self.first_up(bottleneck)

        up2 = self.up2(torch.cat([down7, up1], dim=1))
        up3 = self.up3(torch.cat([down6, up2], dim=1))
        up4 = self.up4(torch.cat([down5, up3], dim=1))
        up5 = self.up5(torch.cat([down4, up4], dim=1))
        up6 = self.up6(torch.cat([down3, up5], dim=1))
        up7 = self.up7(torch.cat([down2, up6], dim=1))

        out = self.last_up(torch.cat([down1, up7], dim=1))
        
        return out

x = torch.zeros(8, 3, 256, 256).to(device)
model = UNet_Generator().to(device)
output = model(x)
print(output.shape)

torch.Size([8, 3, 256, 256])


In [4]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=False, padding_mode='reflect') 
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)

        return out
    

class Patch_Discriminator(nn.Module):
    def __init__(self, input_channels=3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.first_conv = nn.Sequential(
            nn.Conv2d(input_channels*2, 64, kernel_size=4, stride=2, padding=1, bias=False, padding_mode='reflect'),
            nn.ReLU(),
        )
        self.conv2 = ConvBlock(64, 128, 2)
        self.conv3 = ConvBlock(128, 256, 2)
        self.conv4 = ConvBlock(256, 512, 1)
        self.last_conv = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode='reflect')


    def forward(self, x, y):
        out = self.first_conv(torch.cat([x, y], dim=1))
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.last_conv(out)

        return out
    
x = torch.zeros(8, 3, 256, 256).to(device)
y = torch.zeros(8, 3, 256, 256).to(device)
model = Patch_Discriminator().to(device)
output = model(x, y)
output.shape

torch.Size([8, 1, 30, 30])

In [5]:
def init_weights(model):
    std = 2e-2
    for m in model.modules():
        if type(m) in {
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.Linear,
        }:
            nn.init.normal_(m.weight.data, mean=0.0, std=std)
        if type(m) in {
            nn.BatchNorm2d,
        }:
            nn.init.normal_(m.weight.data, mean=1.0, std=std)
            nn.init.constant_(m.bias.data, val=0.0)

In [6]:
model_d = Patch_Discriminator().to(device)
model_g = UNet_Generator().to(device)
init_weights(model_d), init_weights(model_g)

lr = 2e-4
beta1 = 0.5
beta2 = 0.999
optimizer_d = optim.Adam(model_d.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_g = optim.Adam(model_g.parameters(), lr=lr, betas=(beta1, beta2))

scaler_d = torch.cuda.amp.GradScaler()
scaler_g = torch.cuda.amp.GradScaler()

loss_fns = [nn.BCEWithLogitsLoss(), nn.L1Loss()]

In [7]:
img_size = 256
num_channels = 3
root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/maps/'

common_transform = A.Compose(
    [
        A.Resize(height=img_size, width=img_size),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={'real_map': 'image'}
)

aerial_transform = A.Compose(
    [
        A.ColorJitter(p=0.5),
        A.Normalize(
            mean=[0.5 for _ in range(num_channels)],
            std=[0.5 for _ in range(num_channels)],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

map_transform = A.Compose(
    [
        A.Normalize(
            mean=[0.5 for _ in range(num_channels)],
            std=[0.5 for _ in range(num_channels)],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

class Aerial2Map(Dataset):
    def __init__(self, data_path, common_transform, aerial_transform, map_transform):
        super().__init__()
        self.data_path = data_path
        self.list_filenames = os.listdir(data_path)
        self.common_transform = common_transform
        self.aerial_transform = aerial_transform
        self.map_transform = map_transform
    
    def __len__(self):
        return len(self.list_filenames)
    
    def __getitem__(self, index):
        image_filename = self.list_filenames[index]
        image_path = os.path.join(self.data_path, image_filename)

        image = Image.open(image_path)
        image_arr = np.array(image)

        width = image_arr.shape[1]
        split_at = width // 2
        aerial, real_map = image_arr[:, :split_at, :], image_arr[:, split_at:, :]

        common = self.common_transform(image=aerial, real_map=real_map)
        aerial, real_map = common['image'], common['real_map']
        aerial = self.aerial_transform(image=aerial)['image']
        real_map = self.map_transform(image=real_map)['image']

        return aerial, real_map
    
train_data = Aerial2Map(root_path+'train', common_transform, aerial_transform, map_transform)
val_data = Aerial2Map(root_path+'val', common_transform, aerial_transform, map_transform)

In [8]:
class Trainer:
    def __init__(
            self,
            optimizer_d,
            optimizer_g,
            scaler_d,
            scaler_g,
            model_d,
            model_g,
            loss_fns,
            lambda_coeff,
            device = device
    ):
        self.optimizer_d = optimizer_d
        self.optimizer_g = optimizer_g
        self.scaler_d = scaler_d
        self.scaler_g = scaler_g
        self.model_d = model_d
        self.model_g = model_g
        self.bce_loss_fn, self.l1_loss_fn = loss_fns
        self.lambda_coeff = lambda_coeff
        self.device = device


    def calc_disc_loss(self, aerial, real_map, is_train):
        with torch.set_grad_enabled(is_train):
            real_args = self.model_d(aerial, real_map)
            loss_real = self.bce_loss_fn(real_args, torch.ones_like(real_args))

            fake_map = self.model_g(aerial)
            fake_args = self.model_d(aerial, fake_map)
            loss_fake = self.bce_loss_fn(fake_args, torch.zeros_like(fake_args))

            loss_d = (loss_real + loss_fake) / 2

        return loss_d
    

    def calc_gen_loss(self, aerial, real_map, is_train):
        with torch.set_grad_enabled(is_train):
            fake_map = self.model_g(aerial)
            fake_args = self.model_d(aerial, fake_map)
            loss_fake = self.bce_loss_fn(fake_args, torch.ones_like(fake_args))

            loss_l1 = self.lambda_coeff * self.l1_loss_fn(fake_map, real_map)

            loss_g = loss_fake + loss_l1
        
        return loss_g
    
    
    def calc_metrics(self, metrics_dict, train_loader, val_loader):
        self.model_d.eval(), self.model_g.eval()

        final_str = ''
        loaders_list = [('Train', train_loader), ('Val', val_loader)]
            
        if metrics_dict == None:
            metrics_dict = {'Train': {'DiscLoss': [], 'GenLoss': []}, 'Val': {'DiscLoss': [], 'GenLoss': []}}

        for name, loader in loaders_list:
            len_data = 0
            total_loss_d, total_loss_g = 0, 0

            for aerial, real_map in loader:
                aerial, real_map = aerial.to(device), real_map.to(device)
                batch_size = aerial.shape[0]
                len_data += batch_size

                with torch.cuda.amp.autocast():
                    loss_d = self.calc_disc_loss(aerial, real_map, is_train=False)
                total_loss_d += loss_d

                with torch.cuda.amp.autocast():
                    loss_g = self.calc_gen_loss(aerial, real_map, is_train=False)
                total_loss_g += loss_g
                
            disc_loss = total_loss_d/len_data
            gen_loss = total_loss_g/len_data

            final_str += ' -- {} Disc Loss: {:.5f} -- {} Gen Loss: {:.5f}'.format(name, disc_loss, name, gen_loss)
                
            metrics_dict[name]['DiscLoss'].append(disc_loss.item())
            metrics_dict[name]['GenLoss'].append(gen_loss.item()) 

        self.model_d.train(), self.model_g.train()

        return final_str, metrics_dict
    

    def visualize_tensorboard(self, aerial, real_map, writer_pix2pix, steps):
        self.model_g.eval()

        with torch.no_grad():
            fake_map = self.model_g(aerial)
            combined_grid = torch.cat([
                aerial[:4]*0.5 + 0.5,
                real_map[:4]*0.5 + 0.5,
                fake_map[:4]*0.5 + 0.5
            ], dim=0
            )
            image_grid = make_grid(combined_grid, nrow=4, normalize=False)
            writer_pix2pix.add_image('Fake', image_grid, global_step=steps)

        self.model_g.train()

        return None
    
    
    def fit(self, n_epochs, train_loader, val_loader, writer_pix2pix):
        self.model_d.train(), self.model_g.train()

        metrics_dict = None
        steps = 1
        for epoch in range(1, n_epochs+1):
            for batch_idx, (aerial, real_map) in enumerate(train_loader):
                aerial, real_map = aerial.to(device), real_map.to(device)
                
                with torch.cuda.amp.autocast():
                    loss_d = self.calc_disc_loss(aerial, real_map, is_train=True)
                    
                self.optimizer_d.zero_grad()
                self.scaler_d.scale(loss_d).backward()
                self.scaler_d.step(self.optimizer_d)
                self.scaler_d.update()

                with torch.cuda.amp.autocast():
                    loss_g = self.calc_gen_loss(aerial, real_map, is_train=True)

                self.optimizer_g.zero_grad()
                self.scaler_g.scale(loss_g).backward()
                self.scaler_g.step(self.optimizer_g)
                self.scaler_g.update()

                if batch_idx % 5 == 0:
                    print(f'Epoch: {epoch:2d}/{n_epochs} -- Batch: {batch_idx+1:3d}/{len(train_loader)}' + f' -- Train Disc Loss: {loss_d:.4f} -- Train Gen Loss: {loss_g:.4f}')
                    self.visualize_tensorboard(aerial, real_map, writer_pix2pix, steps)
                    steps += 1
            
            final_str, metrics_dict = self.calc_metrics(metrics_dict, train_loader, val_loader)
            print('Epoch: {:2d}'.format(epoch) + final_str)
            
        self.metrics_dict = metrics_dict
    
        return None

In [9]:
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [10]:
lambda_coeff = 100
trainer = Trainer(optimizer_d, optimizer_g, scaler_d, scaler_g, model_d, model_g, loss_fns, lambda_coeff, device)

n_epochs = 50
writer_pix2pix = SummaryWriter('logs/pix2pix')
trainer.fit(n_epochs, train_loader, val_loader, writer_pix2pix)

Epoch:  1/50 -- Batch:   1/18 -- Train Disc Loss: 0.8361 -- Train Gen Loss: 80.7823
Epoch:  1/50 -- Batch:   6/18 -- Train Disc Loss: 0.7163 -- Train Gen Loss: 80.8421
Epoch:  1/50 -- Batch:  11/18 -- Train Disc Loss: 0.6510 -- Train Gen Loss: 70.3852
Epoch:  1/50 -- Batch:  16/18 -- Train Disc Loss: 0.5390 -- Train Gen Loss: 59.2353
