<a href="https://colab.research.google.com/github/AlexanderLontke/ssl-remote-sensing/blob/main/notebooks/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN Implementation Remote Sensing

## Import necessary packages

In [1]:
!pip install rasterio
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rasterio
  Downloading rasterio-1.2.10-cp37-cp37m-manylinux1_x86_64.whl (19.3 MB)
[K     |████████████████████████████████| 19.3 MB 503 kB/s 
Collecting click-plugins
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Collecting affine
  Downloading affine-2.3.1-py2.py3-none-any.whl (16 kB)
Collecting snuggs>=1.4.1
  Downloading snuggs-1.4.7-py3-none-any.whl (5.4 kB)
Collecting cligj>=0.5
  Downloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Installing collected packages: snuggs, cligj, click-plugins, affine, rasterio
Successfully installed affine-2.3.1 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.2.10 snuggs-1.4.7
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.4-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 27.2 MB/s 
Collecting setproctitle

In [52]:
# import standard python libraries
import os
import datetime as time
from datetime import datetime
import numpy as np
import pickle as pkl
import glob
from pathlib import Path
from math import floor
import random

# import data reader, logging and transforms
from torchvision import transforms
import rasterio as rio
from rasterio.plot import reshape_as_image
import wandb

# import the PyTorch deep learning library
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# import matplotlib and enabling notebook inline plotting:
import matplotlib.pyplot as plt
%matplotlib inline

Mount Google Drive Directories for data access

In [4]:
# import the Google Colab GDrive connector
from google.colab import drive

# mount GDrive inside the Colab notebook
drive.mount('/content/drive')

Mounted at /content/drive


In [21]:
# create Colab Notebooks directory
notebook_directory = Path('/content/drive/MyDrive/Projects/DeepLearning')
if not os.path.exists(notebook_directory): os.makedirs(notebook_directory)

 # create data sub-directory inside the Colab Notebooks directory
data_directory = Path('/content/drive/MyDrive/Projects/DeepLearning/data')
if not os.path.exists(data_directory): os.makedirs(data_directory)

 # create models sub-directory inside the Colab Notebooks directory
models_directory = Path('/content/drive/MyDrive/Projects/DeepLearning/models')
if not os.path.exists(models_directory): os.makedirs(models_directory)

## Helper Functions

In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu').type

print('[LOG] notebook with {} computation enabled'.format(str(device)))

[LOG] notebook with cuda computation enabled


In [34]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
if device == "cuda":
  torch.cuda.manual_seed(SEED)

## Data

Define the directory on your drive to reproduce results. Downloaded data from https://madm.dfki.de/files/sentinel/EuroSATallBands.zip should be within the data folder created in previous steps. 

In [26]:
eurosat_dir = data_directory.joinpath("ds/images/remote_sensing/otherDatasets/sentinel_2/tif")
samples = glob.glob(os.path.join(eurosat_dir, "*", "*.tif"))
len(samples)

27000

In [39]:
classes = {
    0: "AnnualCrop",
    1: "Forest",
    2: "HerbaceousVegetation",
    3: "Highway",
    4: "Industrial",
    5: "Pasture",
    6: "PermanentCrop",
    7: "Residential",
    8: "River",
    9: "SeaLake"
}

In [41]:
class_to_idx = {value:key for key,value in classes.items()}

In [42]:
print(class_to_idx)

{'AnnualCrop': 0, 'Forest': 1, 'HerbaceousVegetation': 2, 'Highway': 3, 'Industrial': 4, 'Pasture': 5, 'PermanentCrop': 6, 'Residential': 7, 'River': 8, 'SeaLake': 9}


Create a user-defined data loader for the EuroSAT data to adjust specifically for a GAN model. This process includes transformations if wanted. 

In [35]:
def get_training_and_testing_sets(file_list):
    split = 0.8
    random.Random(SEED).shuffle(file_list)
    split_index = floor(len(file_list) * split)
    training = file_list[:split_index]
    testing = file_list[split_index:]
    return training, testing

In [37]:
train_df, test_df = get_training_and_testing_sets(samples)

In [43]:
transformer_train = transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                       std = [0.229, 0.224, 0.225])
])

transformer_test = transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                                       std = [0.229, 0.224, 0.225])
])

In [88]:
class TrainData(Dataset):

    def __init__(self, directories, transform=None):
        self.directories = directories
        self.transform = transform

    def __len__(self):
      return len(self.directories)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        input = self.directories[idx]
        label = input.split('/')[-1].split('_')[0]
        label = class_to_idx[label]
        with rio.open(input, "r") as d:
          image = d.read([4,3,2]).astype(int)
          image = reshape_as_image(image)

        if self.transform:
            image = self.transform(image.astype(float))
            image = image.type(torch.FloatTensor)

        return image, label

In [89]:
trainData = TrainData(directories = train_df, transform = transformer_train)

In [56]:
class TestData(Dataset):
    def __init__(self, directories, transform=None):
        self.directories = directories
        self.transform = transform

    def __len__(self):
      return len(self.directories)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        input = self.directories[idx]
        label = input.split('/')[-1].split('_')[0]
        label = class_to_idx[label]
        with rio.open(input, "r") as d:
          image = d.read([4,3,2]).astype(int)
          image = reshape_as_image(image)

        if self.transform:
            image = self.transform(image.astype(float))

        return image, label

In [57]:
testData = TestData(directories = test_df, transform = transformer_test)

In [60]:
testData[0][0].shape

torch.Size([3, 64, 64])

## Setup GAN 
The GAN architecture is composed of the generative model $G$ and the discriminative model $D$.

The discriminator $D$ is a binary classifier trying to determine whether the input sample $X$ is real or fake. Real pictures come from the EuroSAT dataset whereas fake inputs are generated by generator $G$. Thus, $D$ outputs a scalar which is then transformed to a probability measure using the sigmoid function. 

- 1 - Sample is part of the real dataset
- 0 - Sample is a fake generated by $G$

### Discriminator
Due to the binary classification case of the discriminator we will use the BCE with logits loss because cross entropy not only punishes incorrect but confident predictions but also correct but less confident predictions. Including the logits there is no need to apply the sigmoid activation function in the network. 

In [61]:
# implement the Discriminator network architecture
class Discriminator(nn.Module):

    # define the class constructor
    def __init__(self):

        # call super class constructor
        super(Discriminator, self).__init__()
        
        # specify convoluted layer 1: in 64*64 and three layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.activation1 = nn.LeakyReLU(0.2, inplace=True) # the non-linearity
        
        # specify convoluted layer 2: in 68*68*5, out 32*32*1
        self.conv2 = nn.Conv2d(5, 64, kernel_size=(1, 1), stride=(2, 2), padding=(0, 0), bias=False)
        self.activation2 = nn.LeakyReLU(0.2, inplace=True) # the non-linearity

        # specify fc layer 3: in 32*32, out 64
        self.fc3 = nn.Linear(32*32, 64) # the linearity W*x+b
        self.activation3 = nn.LeakyReLU(0.2, inplace=True) # the non-linearity
        
        # specify fc layer 4: in 64, out 1
        self.fc4 = nn.Linear(64, 1) # the linearity W*x+b

        # dropout layer
        self.dropout = nn.Dropout(0.3)
        
    # define network forward pass
    def forward(self, x):

        # define fc layer 1 forward pass and add dropout
        x = self.activation1(self.conv1(x))
        x = self.dropout(x)

        # define fc layer 2 forward pass and add dropout
        x = self.activation2(self.conv2(x))
        x = self.dropout(x)

        # flatten image
        x = x.view(-1, 32*32)

        # define fc layer 3 forward pass and add dropout
        x = self.activation3(self.fc3(x))
        x = self.dropout(x)
        
        # define fc layer 4 forward pass
        out = self.fc4(x)

        # return forward pass result
        return out

### Generator
As the generator is not a discriminative model its aim is to generate data. 
Thus, we draw the latent variable $z \in \mathbb{R}^d$ from a random distribution such as a Gaussian or a uniform distribution. 
Accordingly, $G$ produces the following output: $X' = G(z)$. 
As we want to fool $D$ to fail in distinguishing fake and real remote sensing data we aspire $D(G(z)) ≈ 1$. Hence, the goal is to maximize cross-entropy loss in case $y=0$ (fake data). 

In [62]:
# implement the Generator network architecture
class Generator(nn.Module):

    # define the class constructor
    def __init__(self):

        # call super class constructor
        super(Generator, self).__init__()
        
        # specify fc layer 1: in 100 (from latent space z), out 128
        self.fc1 = nn.Linear(100, 128) # the linearity W*x+b
        self.activation1 = nn.LeakyReLU(0.2, inplace=True) # the non-linearity

        # specify fc layer 2: in 32, out 64
        self.fc2 = nn.Linear(128, 256) # the linearity W*x+b
        self.activation2 = nn.LeakyReLU(0.2, inplace=True) # the non-linearity

        # specify fc layer 3: in 64, out 128
        self.fc3 = nn.Linear(256, 512) # the linearity W*x+b
        self.activation3 = nn.LeakyReLU(0.2, inplace=True) # the non-linearity
        
        # specify fc layer 4: in 128, out 28*28
        self.fc4 = nn.Linear(512, 64*64) # the linearity W*x+b
       
        # dropout layer 
        self.dropout = nn.Dropout(0.3)

    # define network forward pass
    def forward(self, x):

        # define fc layer 1 forward pass and add dropout
        x = self.activation1(self.fc1(x))
        x = self.dropout(x)

        # define fc layer 2 forward pass and add dropout
        x = self.activation2(self.fc2(x))
        x = self.dropout(x)

        # define fc layer 3 forward pass and add dropout
        x = self.activation3(self.fc3(x))
        x = self.dropout(x)

        # define fc layer 4 with tanh applied
        out = self.fc4(x).tanh()

        # return forward pass result
        return out

### MinMax-Game
We thus define the following value function $V$ by combining both the targets of $G$ and $D$:

$$min_{G}max_{D} V(D, G) = \mathbb{E}_{x∼Data} [logD(x)] + \mathbb{E}_{z∼Noise} [log(1-D(G(z)))]$$


In [63]:
# instantiate both D and G
D = Discriminator()
G = Generator()

In [64]:
# push to device (GPU)
D = D.to(device)
G = G.to(device)

In [70]:
# check whether model is loaded to GPU
!nvidia-smi

Tue Oct 11 16:07:26 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P0    26W /  70W |    632MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [66]:
print('[LOG] Discriminator architecture:\n\n{}\n'.format(D))

[LOG] Discriminator architecture:

Discriminator(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation1): LeakyReLU(negative_slope=0.2, inplace=True)
  (conv2): Conv2d(5, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
  (activation2): LeakyReLU(negative_slope=0.2, inplace=True)
  (fc3): Linear(in_features=1024, out_features=64, bias=True)
  (activation3): LeakyReLU(negative_slope=0.2, inplace=True)
  (fc4): Linear(in_features=64, out_features=1, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)



In [67]:
print('[LOG] Generator architecture:\n\n{}\n'.format(G))

[LOG] Generator architecture:

Generator(
  (fc1): Linear(in_features=100, out_features=128, bias=True)
  (activation1): LeakyReLU(negative_slope=0.2, inplace=True)
  (fc2): Linear(in_features=128, out_features=256, bias=True)
  (activation2): LeakyReLU(negative_slope=0.2, inplace=True)
  (fc3): Linear(in_features=256, out_features=512, bias=True)
  (activation3): LeakyReLU(negative_slope=0.2, inplace=True)
  (fc4): Linear(in_features=512, out_features=4096, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)



In [69]:
# define the loss for the discriminator and push to device
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(device)

For further training tips of GANs look at: [GAN hacks](https://https://github.com/soumith/ganhacks)

In [71]:
# set learning rate
lr = 0.002

# create optimizers for the discriminator and generator
d_optimizer = optim.SGD(D.parameters(), 0.02) 
g_optimizer = optim.Adam(G.parameters(), 0.002) 

## Training

In [82]:
# specify the training parameters
num_epochs = 20 # number of training epochs
mini_batch_size=64 # size of the mini-batches

In [90]:
train_loader = torch.utils.data.DataLoader(trainData, 
                                           batch_size=mini_batch_size,
                                           shuffle=True
                                           )

In [84]:
len(train_loader)

338

In [85]:
# establish convention for real and fake labels during training
real_label = 1
fake_label = 0

In [86]:
# define size of latent vector
z_size = 100

# define sample size
sample_size = 4

# uniformly distribute data of size z_size over an interval of -1; 1
fixed_noise = np.random.normal(0, 1, size=(sample_size, z_size))

# create numpy array into tensor, and convert data to float
fixed_noise = torch.from_numpy(fixed_noise).float()

# push the fixed vector to the device that's enabled
fixed_noise = fixed_noise.to(device)

In [92]:
# initialize list of the generated (fake) images
fake_images = []

# initialize collection of batch losses
D_batch_losses = []
G_batch_losses = []

# initialize collection of epoch losses
D_epoch_losses = []
G_epoch_losses = []

# set networks to training mode
D.train()
G.train()

# define time right before training
start = time.datetime.now()

# train the GANs
for epoch in range(num_epochs):

    # iterate over mini batches
    for i, data in enumerate(train_loader, 0):

        # define real images and push to computation device
        real_images = data[0].to(device)

        # define batch size as size of the images to make sure the loader is emptied completely
        batch_size = real_images.size(0)

        # --------------------------------------------------------------------------
        # (1) Update Discriminator network

        #### train with real images

        # create tensor of same size as mini-batch and filled with 1's (real_label)
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        # rescaling input images from [0,1) to [-1, 1), which is needed for network
        real_images = real_images*2 - 1

        # run forward pass through Discriminator
        output = D(real_images) #.view(-1)

        # reset graph gradients
        D.zero_grad()

        # determine loss on Discriminator
        errD_real = criterion(output, label)

        # run backward pass
        errD_real.backward()
    
        #### train with fake images

        # generate batch of latent vectors
        z = np.random.normal(0, 1, size=(batch_size, z_size))
        z = torch.from_numpy(z).float()
        z = z.to(device)

        # generate fake image batch with Generator
        fake = G(z)

        # fills label tensor with 0's (fake_label)
        label.fill_(fake_label)

        # classify all fake batch with Discriminator
        output = D(fake.detach()).view(-1)

        # get discriminator loss on the fake batch
        errD_fake = criterion(output, label)

        # run backward pass
        errD_fake.backward()

        # compute error of Discriminator as sum of loss over the fake and the real batches
        errD = errD_fake + errD_real

        # update Discriminator parameters
        d_optimizer.step()


        # --------------------------------------------------------------------------
        # (2) Update Generator network

        # reset graph gradients
        G.zero_grad()

        # fake labels are real for generator
        label.fill_(real_label)

        # since we just updated D, perform another forward pass of fake batch through the Discriminator
        output = D(fake).view(-1)

        # get Generator loss based on this output
        errG = criterion(output, label)

        # run backward pass
        errG.backward()

        # update Generator paramaters
        g_optimizer.step()

        # --------------------------------------------------------------------------

        # each 250 iterations (4x per epoch), print losses
        if i % 500 == 0:
          now = datetime.utcnow().strftime("%H:%M:%S")
          print('[LOG {}] Epoch [{}/{}] \t[{}/{}] \t d_loss: {} \t g_loss: {}'.format(
              now, epoch+1, num_epochs,i, len(train_loader), errD.item(), errG.item()))
          
        # save losses for plotting later
        D_batch_losses.append(errD.item())
        G_batch_losses.append(errG.item())

        # set Generator to eval mode for generating samples (equivalent to 'testing' the model)
        G.eval() 

        # make Generator generate samples from the fixed noise ditribution
        samples = G(fixed_noise.float())

        # if you are using a GPU, copy tensor to host memory (cpu) - needed for later operations
        if device == 'cuda':
          samples = samples.cpu()

        # append generated fixed samples to the fake_images list
        fake_images.append(samples)

        # set Generator back to train mode
        G.train()

    # determine mean min-batch loss of epoch
    D_epoch_loss = np.mean(D_batch_losses)

    D_epoch_losses.append(D_epoch_loss)

    # determine mean min-batch loss of epoch
    G_epoch_loss = np.mean(G_batch_losses)

    G_epoch_losses.append(G_epoch_loss)

    # set filename of actual model
    d_model_name = 'gan_d_model_epoch_{}.pth'.format(str(epoch+1))

    # set filename of actual model
    g_model_name = 'gan_g_model_epoch_{}.pth'.format(str(epoch+1))

    # save current model to GDrive models directory
    torch.save(D.state_dict(), os.path.join(models_directory, d_model_name))

    # save current model to GDrive models directory
    torch.save(G.state_dict(), os.path.join(models_directory, g_model_name))

# save generated samples with pickle
with open('fake_images.pkl', 'wb') as f:
  pkl.dump(fake_images, f)

# print total training time
print('\nTotal training time:', time.datetime.now() - start)

RuntimeError: ignored