# Making a Pix2Pix GAN to generate satellite images from Google Maps images
### I may try to replace the training process with WGAN training, because I really dislike the way 'vanilla' DCGANs converge
### I try to follow the paper's instructions as closely as possible, but since I find some things unclear I sometimes improvise

In [1]:
import torch 
import torch.nn as nn 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import sklearn.model_selection as ms
from tqdm import tqdm,trange
import albumentations as A
from torchvision import transforms
import torch.optim as optim
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import os 
import cv2

In [2]:
class Hyperparams:
    """Single namespace holding the configuration shared by the notebook cells."""

    # --- Optimisation ---
    num_epochs = 20
    discriminator_lr = 5e-4
    generator_lr = 5e-4
    # Number of discriminator updates performed per generator update.
    discriminator_steps = 1

    # --- Loss weighting ---
    # Multiplier of the L1 reconstruction term in the generator loss.
    l1_lambda = 100
    # Not referenced by the visible training code.
    grad_penalty_lambda = 10

    # --- Data ---
    # First two entries are handed to A.Resize as (height, width).
    img_shape = (256, 256, 3)
    # Not referenced by the visible code.
    num_latent_features = 100

    # (x - 0.5) / 0.5: maps values in [0, 1] to [-1, 1].
    normalise_transform = transforms.Compose(
        [
            transforms.Normalize(mean=0.5, std=0.5),
        ]
    )

    # Random flips / 90-degree rotations; not applied anywhere in the visible code.
    augment_transform = A.Compose(
        [
            A.HorizontalFlip(p=0.4),
            A.VerticalFlip(p=0.4),
            A.RandomRotate90(p=0.4),
        ],
        p=0.6,
    )

In [3]:
class SatelliteDataset(Dataset):
    """Paired (map, satellite) image dataset backed by ``.npy`` files.

    Every item is an ``(image, label)`` tuple: ``image`` is the map tile (the
    generator input) and ``label`` is the matching satellite tile (the target).
    Both are float32 tensors resized to the height/width in
    ``Hyperparams.img_shape`` and permuted from ``(H, W, C)`` to ``(C, H, W)``
    (assumes the stored arrays are channel-last -- TODO confirm).

    Args:
        metadata_df: DataFrame with ``sat_path`` and ``map_path`` columns that
            point at ``.npy`` files.
        normalise_transform: Optional callable applied to each tensor after
            resizing. A falsy value skips normalisation.
        train: Stored on the instance; not used by ``__getitem__`` at present.
    """

    def __init__(self, metadata_df, normalise_transform = Hyperparams.normalise_transform, train = True):
        self.metadata_df = metadata_df
        self.normalise_transform = normalise_transform
        self.resize = A.Resize(Hyperparams.img_shape[0], Hyperparams.img_shape[1],interpolation=cv2.INTER_LANCZOS4, always_apply=True)
        self.train = train

    def __len__(self):
        """Return the number of image pairs listed in the metadata."""
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Load, resize and (optionally) normalise the pair at position ``idx``.

        Returns:
            tuple: ``(map_tensor, sat_tensor)``, both float32 ``(C, H, W)``.
        """
        # Positional lookup: samplers hand out positions in [0, len(self)),
        # which only coincide with index labels when the index was reset.
        row = self.metadata_df.iloc[int(idx)]

        # Scale to [0, 1]; assumes 8-bit pixel values -- TODO confirm.
        sat_img = np.load(row['sat_path']) / 255
        map_img = np.load(row['map_path']) / 255

        # NOTE(review): the map travels through the ``mask`` slot, so
        # albumentations may resize it with its mask interpolation rather
        # than the Lanczos setting above -- confirm this is intended.
        resized = self.resize(image=sat_img, mask=map_img)

        # Cast unconditionally so the dtype does not depend on whether a
        # normalise transform was supplied (division above yields float64).
        sat_tensor = torch.tensor(resized["image"]).permute(2, 0, 1).float()
        map_tensor = torch.tensor(resized["mask"]).permute(2, 0, 1).float()

        if self.normalise_transform:
            sat_tensor = self.normalise_transform(sat_tensor)
            map_tensor = self.normalise_transform(map_tensor)

        # (input image, target label)
        return map_tensor, sat_tensor

In [4]:
metadata = pd.read_csv('data/SATELLITE/metadata.csv')

# Take the fixed preview rows BEFORE removing them from the pool. Doing it the
# other way round would pick rows that are still part of the train/valid split
# (and silently throw away the real first five rows).
fixed_validate_dataset = metadata.iloc[:5]
metadata = metadata.iloc[5:].reset_index(drop=True)


# 80/20 split with a fixed seed so the partition is reproducible.
train_metadata, valid_metadata = ms.train_test_split(metadata, test_size=0.2, train_size=0.8, random_state=19, shuffle=True)

In [5]:
# Datasets index their metadata by position, hence the index reset after the split.
train_dataset = SatelliteDataset(train_metadata.reset_index(), train=True)
valid_dataset = SatelliteDataset(valid_metadata.reset_index(), train=False)

# Only the training batches are reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

In [6]:
class PrintShape(nn.Module):
    """Debugging pass-through layer: prints the input's shape, returns the input."""

    def forward(self, x):
        shape = x.shape
        print(shape)
        return x


# Follows the appendix of the Pix2Pix paper, minus the final down conv and the
# first up conv: batch norm breaks on 1x1 inputs (Instance Norm would be an alternative).
class Generator(nn.Module):
    """U-Net style generator: 7 stride-2 encoder stages, 7 stride-2 decoder stages.

    Takes a 3-channel image and returns a 3-channel image of the same spatial
    size. Each decoder stage (except the last) is concatenated channel-wise
    with the encoder activation of matching resolution, so height and width
    should be multiples of 128 for the skip connections to line up.
    No output activation is applied.
    """

    def __init__(self):
        super().__init__()

        # --- Encoder (first stage has no batch norm) ---
        self.down_conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.down_conv2 = self.make_down_conv(64, 128)
        self.down_conv3 = self.make_down_conv(128, 256)
        self.down_conv4 = self.make_down_conv(256, 512)
        self.down_conv5 = self.make_down_conv(512, 512)
        self.down_conv6 = self.make_down_conv(512, 512)
        self.down_conv7 = self.make_down_conv(512, 512)

        # --- Decoder (input widths are doubled by the skip concatenations) ---
        self.up_conv7 = self.make_up_conv(512, 512, dropout=False)
        self.up_conv6 = self.make_up_conv(1024, 512)
        self.up_conv5 = self.make_up_conv(1024, 512)
        self.up_conv4 = self.make_up_conv(1024, 256, dropout=False)
        self.up_conv3 = self.make_up_conv(512, 128, dropout=False)
        self.up_conv2 = self.make_up_conv(256, 64, dropout=False)
        self.up_conv1 = self.make_up_conv(128, 64, dropout=False)

        # 1x1 projection back to 3 image channels.
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)

    def make_down_conv(self, in_channels, out_channels):
        """Build one encoder stage: stride-2 3x3 conv -> LeakyReLU(0.2) -> batch norm."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(out_channels),
        )

    def make_up_conv(self, in_channels, out_channels, dropout=True):
        """Build one decoder stage: stride-2 4x4 transposed conv -> ReLU -> batch norm.

        When ``dropout`` is true a ``Dropout(0.5)`` layer is appended.
        """
        stage = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        if dropout:
            stage.append(nn.Dropout(0.5))
        return stage

    def forward(self, x):
        """Encode ``x``, then decode with skip connections; returns the 3-channel output."""
        # Encoder pass: keep every activation for the skip connections.
        enc1 = self.down_conv1(x)
        enc2 = self.down_conv2(enc1)
        enc3 = self.down_conv3(enc2)
        enc4 = self.down_conv4(enc3)
        enc5 = self.down_conv5(enc4)
        enc6 = self.down_conv6(enc5)
        enc7 = self.down_conv7(enc6)

        # Decoder pass: upsample, then append the same-resolution encoder
        # activation along the channel axis.
        dec = torch.cat([self.up_conv7(enc7), enc6], 1)
        dec = torch.cat([self.up_conv6(dec), enc5], 1)
        dec = torch.cat([self.up_conv5(dec), enc4], 1)
        dec = torch.cat([self.up_conv4(dec), enc3], 1)
        dec = torch.cat([self.up_conv3(dec), enc2], 1)
        dec = torch.cat([self.up_conv2(dec), enc1], 1)
        dec = self.up_conv1(dec)

        return self.final_conv(dec)



class Discriminator(nn.Module):
    """Conditional patch discriminator.

    Receives two 3-channel images, stacks them along the channel axis and
    reduces the resolution by 16x through four stride-2 convolutions. The
    output is a single-channel map of sigmoid scores (16x16 for 256x256
    inputs), one score per patch.
    """

    def make_down_conv(self, in_channels, out_channels, apply_batchnorm=True):
        """Build one stage: stride-2 4x4 conv -> LeakyReLU(0.2) [-> batch norm]."""
        stage = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        if apply_batchnorm:
            stage.append(nn.BatchNorm2d(out_channels))
        return stage

    def __init__(self):
        super().__init__()

        # 6 input channels: the generator input and the (real or generated)
        # output image are concatenated before the first convolution.
        self.down_conv1 = self.make_down_conv(6, 64, apply_batchnorm=False)
        self.down_conv2 = self.make_down_conv(64, 128)
        self.down_conv3 = self.make_down_conv(128, 256)
        self.down_conv4 = self.make_down_conv(256, 512)

        # Resolution-preserving conv producing one logit per patch.
        self.last_conv = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, real_image, fake_image):
        """Score the pair; returns a ``(batch, 1, H/16, W/16)`` map of values in (0, 1)."""
        features = torch.cat([real_image, fake_image], dim=1)

        stages = (self.down_conv1, self.down_conv2, self.down_conv3, self.down_conv4)
        for stage in stages:
            features = stage(features)

        logits = self.last_conv(features)
        return torch.sigmoid(logits)



#Init all weights by a value sampled from a (0,0.02) Normal distrib (Very important for Vanilla GAN convergence)
def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    * Modules whose class name contains ``Conv`` or ``Linear``: weight drawn
      from N(0, 0.02); the bias is left untouched.
    * Modules whose class name contains ``BatchNorm`` or ``Instance``: weight
      drawn from N(1, 0.02) and bias set to 0.

    Modules without the relevant parameter (e.g. norm layers built with
    ``affine=False``, which is the default for ``InstanceNorm2d``) are skipped
    instead of raising on ``None``.

    Args:
        m: The ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default also covers modules that have no such attribute.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname or 'Instance' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [7]:
def train(train_loader, generator, discriminator, device, g_optimizer, d_optimizer,criterion,epoch):
    """Run one epoch of alternating discriminator / generator updates.

    Args:
        train_loader: Iterable of ``(images, real_labels)`` batches, where
            ``images`` is the generator input and ``real_labels`` the target.
        generator: Model mapping ``images`` to generated labels.
        discriminator: Model called as ``discriminator(images, labels)`` and
            returning a map of scores.
        device: Device the models and batches are moved to.
        g_optimizer: Optimizer over the generator parameters.
        d_optimizer: Optimizer over the discriminator parameters.
        criterion: Adversarial loss comparing discriminator scores with
            all-ones / all-zeros targets of the same shape.
        epoch: Current epoch number (unused; kept for call compatibility).

    Returns:
        tuple: ``(mean_d_loss, mean_g_loss)`` over the epoch.
    """
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    # Stateless, so build it once rather than once per batch.
    l1_criterion = nn.L1Loss()

    d_losses = []
    g_losses = []

    for images, real_labels in tqdm(train_loader, total = len(train_loader)):
        images, real_labels = images.to(device), real_labels.to(device)

        #First Step -> Discriminator training
        discriminator.train()
        generator.eval()

        for _ in range(Hyperparams.discriminator_steps):
            # The generator is not updated in this step, so skip its graph.
            with torch.no_grad():
                fake_labels = generator(images)

            # The discriminator is conditional: it takes the input image and
            # the (real or generated) label as two separate arguments.
            real_scores = discriminator(images, real_labels)
            fake_scores = discriminator(images, fake_labels)

            # Targets are derived from the score map so they always match its
            # shape, whatever the input resolution.
            real_loss = criterion(real_scores, torch.ones_like(real_scores))
            fake_loss = criterion(fake_scores, torch.zeros_like(fake_scores))
            discriminator_loss = (real_loss + fake_loss) / 2

            d_losses.append(discriminator_loss.item())
            d_optimizer.zero_grad()
            discriminator_loss.backward()
            d_optimizer.step()


        #Second Step -> Generator training
        discriminator.eval()
        generator.train()

        fake_labels = generator(images)
        l1loss = l1_criterion(fake_labels, real_labels)

        # The generator wants its output to be scored as real.
        credibility = discriminator(images, fake_labels)
        expected_credibility = torch.ones_like(credibility)

        generator_loss = criterion(credibility, expected_credibility) + Hyperparams.l1_lambda * l1loss

        g_losses.append(generator_loss.item())
        g_optimizer.zero_grad()
        generator_loss.backward()
        g_optimizer.step()

    mean_d_loss = np.mean(d_losses)
    mean_g_loss = np.mean(g_losses)
    print(f'Training Discriminator loss : {mean_d_loss} | Training Generator loss : {mean_g_loss}')

    return mean_d_loss, mean_g_loss

In [8]:
def _to_displayable(image):
    """Convert a normalised ``(C, H, W)`` tensor to an ``(H, W, C)`` array in [0, 1].

    Assumes the tensor was normalised with mean 0.5 / std 0.5 (as the
    notebook's ``Hyperparams.normalise_transform`` does); values outside the
    valid range, e.g. unbounded generator outputs, are clipped.
    """
    restored = image.detach().cpu() * 0.5 + 0.5
    return restored.clamp(0, 1).permute(1, 2, 0).numpy()


def validate_with_labels(valid_loader, generator, discriminator, device, fixed_input_images, fixed_actual_labels, epoch):
    """Measure discriminator accuracy and plot generator output on fixed samples.

    Args:
        valid_loader: Iterable of ``(images, real_labels)`` validation batches.
        generator: Model mapping ``images`` to generated labels.
        discriminator: Model called as ``discriminator(images, labels)``.
        device: Device the batches are moved to.
        fixed_input_images: Batch of generator inputs plotted every epoch.
        fixed_actual_labels: Ground-truth labels matching ``fixed_input_images``.
        epoch: Epoch number, used in the saved plot's file name.

    Returns:
        tuple: ``(accuracy, fig)`` -- the fraction of samples (real and
        generated) the discriminator classifies correctly at a 0.5 threshold,
        and the matplotlib figure with the Input / Actual / Generated grid.
    """
    discriminator.eval()
    generator.eval()
    pred_chunks = []
    expected_chunks = []

    with torch.no_grad():
        for images, real_labels in tqdm(valid_loader, total=len(valid_loader)):

            # First Step -> Validate Discriminator
            images, real_labels = images.to(device), real_labels.to(device)

            fake_labels = generator(images)

            combined_images = torch.cat([images, images], 0)
            combined_labels = torch.cat([real_labels, fake_labels], 0)

            # The discriminator emits a map of patch scores; average it per
            # sample so there is exactly one prediction per expected label.
            patch_scores = discriminator(combined_images, combined_labels)
            sample_scores = patch_scores.flatten(start_dim=1).mean(dim=1)

            ones = torch.ones(real_labels.shape[0], device=device)
            zeros = torch.zeros(fake_labels.shape[0], device=device)

            pred_chunks.append(sample_scores)
            expected_chunks.append(torch.cat([ones, zeros], 0))

        if pred_chunks:
            all_preds = torch.cat(pred_chunks, 0)
            all_expected = torch.cat(expected_chunks, 0)
            accuracy = ((all_preds > 0.5).float() == all_expected).float().mean().item()
        else:
            # Empty loader: nothing to score.
            accuracy = float('nan')
        print(f'Valid Discriminator accuracy : {accuracy}')

        # Second Step -> Validate Generator
        fixed_input_images = fixed_input_images.to(device)
        fake_labels = generator(fixed_input_images)

        n_rows = len(fixed_input_images)
        # squeeze=False keeps ``axes`` two-dimensional even for a single row.
        fig, axes = plt.subplots(n_rows, 3, figsize=(12, 3 * n_rows), squeeze=False)

        for idx, (real_img, actual_label, fake_img) in enumerate(zip(fixed_input_images, fixed_actual_labels, fake_labels)):
            columns = (("Input", real_img), ("Actual", actual_label), ("Generated", fake_img))
            for col, (title, tensor) in enumerate(columns):
                axes[idx, col].imshow(_to_displayable(tensor))
                axes[idx, col].set_title(title)
                axes[idx, col].axis("off")

        save_dir = "saved_plots/SATELLITE_PIX2PIX"
        os.makedirs(save_dir, exist_ok=True)

        plot_filename = os.path.join(save_dir, f"plot_epoch_{epoch}.png")
        plt.tight_layout()
        plt.savefig(plot_filename)
        plt.show()

    return accuracy, fig

In [9]:
# Build the fixed preview batch: map tiles as generator inputs and the matching
# satellite tiles as ground truth, preprocessed like the dataset items.
fixed_input_images = torch.tensor([])
fixed_actual_labels = torch.tensor([])

resize =  A.Resize(Hyperparams.img_shape[0], Hyperparams.img_shape[1],interpolation=cv2.INTER_LANCZOS4, always_apply=True)
for _, row in fixed_validate_dataset.iterrows():
    # Local names avoid ``map`` on purpose: assigning it at module level would
    # clobber the builtin for every later cell.
    sat_img = np.load(row['sat_path']) / 255
    map_img = np.load(row['map_path']) / 255

    resized = resize(image=sat_img, mask=map_img)

    # (H, W, C) -> (C, H, W)
    sat_tensor = torch.tensor(resized["image"]).permute(2,0,1)
    map_tensor = torch.tensor(resized["mask"]).permute(2,0,1)

    # Normalise and add a leading batch axis so the samples can be concatenated.
    sat_tensor = Hyperparams.normalise_transform(sat_tensor.float()).unsqueeze(0)
    map_tensor = Hyperparams.normalise_transform(map_tensor.float()).unsqueeze(0)

    fixed_input_images = torch.cat([fixed_input_images, map_tensor], 0)
    fixed_actual_labels = torch.cat([fixed_actual_labels, sat_tensor], 0)

In [10]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build each network on the target device and apply the custom initialisation.
generator = Generator().to(device)
generator.apply(weights_init)
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)

# Adversarial loss on the discriminator's sigmoid scores.
criterion = nn.BCELoss()

# One Adam optimizer per network, each with its own learning rate.
g_optimizer = optim.Adam(
    generator.parameters(),
    lr=Hyperparams.generator_lr,
    betas=(0.5, 0.9),
)
d_optimizer = optim.Adam(
    discriminator.parameters(),
    lr=Hyperparams.discriminator_lr,
    betas=(0.5, 0.9),
)

# The fixed preview batch lives on the same device as the models.
fixed_input_images = fixed_input_images.to(device)
fixed_actual_labels = fixed_actual_labels.to(device)

In [None]:
# Main loop: one training pass, then validation plus a preview plot, per epoch.
for epoch in range(Hyperparams.num_epochs):
    print(f'Epoch {epoch +1}/{Hyperparams.num_epochs}')
    train(train_loader, generator, discriminator, device, g_optimizer, d_optimizer,criterion,epoch)
    # validate_with_labels is the validation routine defined in this notebook;
    # it needs the ground-truth labels and returns (accuracy, figure).
    accuracy, fig = validate_with_labels(valid_loader, generator, discriminator, device, fixed_input_images, fixed_actual_labels, epoch)
    # The plot is already saved and shown; release it so figures do not pile up.
    plt.close(fig)
    print('-' * 100)

Epoch 1/20


  0%|          | 0/110 [00:00<?, ?it/s]