[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rslab-ntua/MSc_GBDA/2022/Lab2_simple_cnn.ipynb)

In [None]:
# Download the EuroSAT all-bands archive, extract it into the working
# directory, and delete the zip afterwards to free disk space.
# NOTE(review): extraction is expected to create the directory tree that
# DATA_ROOT (next cells) points to — confirm if the archive layout changes.
!wget http://madm.dfki.de/files/sentinel/EuroSATallBands.zip
!unzip EuroSATallBands.zip
!rm EuroSATallBands.zip

In [None]:
# Extra dependencies: rasterio reads the .tif images, pytorch-lightning
# drives training.
# NOTE(review): torchmetrics and torchsummary are imported later but not
# installed here; this assumes torchmetrics comes in as a pytorch-lightning
# dependency and torchsummary is preinstalled in the runtime — add
# `!pip install torchsummary` if the import fails.
!pip install rasterio
!pip install pytorch-lightning

# Enable the %tensorboard magic used to inspect training logs.
%load_ext tensorboard

In [None]:
# Read data
import os
from glob import glob
from typing import Callable, List, Optional

import numpy as np
import rasterio
from torch.utils.data import DataLoader, Dataset, random_split

# Root directory of the extracted EuroSAT all-bands images; the dataset class
# below uses each .tif file's parent directory name as its category.
DATA_ROOT: str = "ds/images/remote_sensing/otherDatasets/sentinel_2/tif/"

# A transform maps one sample dict to a (possibly new) sample dict.
TransformFun = Callable[[dict], dict]

class EuroSAT(Dataset):
    """EuroSAT samples indexed from a directory tree of ``.tif`` files.

    Every ``*.tif`` found recursively under ``data_root`` becomes one sample.
    The name of the directory containing a file is its category; categories
    are numbered ``0..K-1`` in alphabetical order of their names.

    Args:
        data_root: Directory searched recursively for ``*.tif`` files.
        transforms: Callables applied in order to each sample dict in
            ``__getitem__``. ``None`` (the default) means no transforms.
    """

    def __init__(self, data_root, transforms: Optional[List[TransformFun]] = None):
        super().__init__()
        self._build_db(data_root)
        # A None default avoids one mutable list being shared by all instances.
        self.transforms = [] if transforms is None else transforms

    def _build_db(self, data_root) -> None:
        """Index every ``.tif`` under *data_root* into ``self.db``.

        Sets ``self.categories`` (name -> id) and ``self.db``, a list of
        ``{"url", "category_name", "category_id"}`` records sorted by path.
        """
        sample_urls = sorted(glob(os.path.join(data_root, "**/*.tif"), recursive=True))

        def parse_category(url):
            # The category is the name of the directory holding the file.
            return os.path.basename(os.path.dirname(url))

        # Parse each file's category once and reuse it below.
        url_categories = [(url, parse_category(url)) for url in sample_urls]

        # Unique category names in alphabetical order -> contiguous ids.
        categories = sorted({name for _, name in url_categories})
        self.categories = {c_name: idx for idx, c_name in enumerate(categories)}

        self.db = [
            {
                "url": url,
                "category_name": name,
                "category_id": self.categories[name],
            }
            for url, name in url_categories
        ]

    @property
    def num_categories(self):
        """Number of distinct categories found under ``data_root``."""
        return len(self.categories)

    def __getitem__(self, index):
        """Return the sample at *index* after applying all transforms.

        The stored record is shallow-copied first. Transforms that add or
        overwrite keys (such as loading pixel data) then cannot write into
        ``self.db``. Without the copy, every loaded image would stay cached
        in the index and later reads would see already-transformed records.
        """
        sample = dict(self.db[index])
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def __len__(self):
        return len(self.db)
    
        

def load_data():
    """Build a transform that reads a sample's raster from disk.

    Returns:
        A callable that takes a sample dict with a ``"url"`` key. It returns
        a new dict with the same entries plus ``"data"``, the array returned
        by rasterio's ``read()`` with no band selection (all bands). The
        input dict is not modified.

    The returned callable raises ``KeyError`` if the sample has no
    ``"url"`` key.
    """
    def apply(x: dict) -> dict:
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if "url" not in x:
            raise KeyError("load_data expects a sample with a 'url' key")
        with rasterio.open(x["url"]) as dataset:
            data = dataset.read()
        # Return a new dict so records owned by the caller stay untouched.
        return {**x, "data": data}
    return apply

def normalize(factor=10000):
    """Build a transform that rescales a sample's ``"data"`` array to float32.

    Args:
        factor: Divisor applied to the raw values. The default of 10000 is
            presumably the Sentinel-2 reflectance scale; confirm it for other
            data.

    Returns:
        A callable mapping a sample dict with a ``"data"`` array to a new
        dict whose ``"data"`` is ``data.astype(np.float32) / factor``. The
        input dict is not modified.

    The returned callable raises ``KeyError`` if the sample has no
    ``"data"`` key.
    """
    def apply(x: dict) -> dict:
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if "data" not in x:
            raise KeyError("normalize expects a sample with a 'data' key")
        return {**x, "data": x["data"].astype(np.float32) / factor}
    return apply


In [None]:
# Split train/val/test set: 70% / 20% / remainder.
import torch

# Fixed seed so the split, and therefore the held-out test set, is the same
# on every run. This keeps results comparable across runs.
SPLIT_SEED = 42

dset = EuroSAT(data_root=DATA_ROOT, transforms=[load_data(), normalize()])

n_total = len(dset)
n_train = int(0.7 * n_total)
n_val = int(0.2 * n_total)
n_test = n_total - n_train - n_val  # the remainder absorbs rounding

train_dset, val_dset, test_dset = random_split(
    dset,
    lengths=[n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dloader = DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=2)
val_dloader = DataLoader(val_dset, batch_size=128, shuffle=False, num_workers=2)
test_dloader = DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=2)

In [None]:
# Train a simple CNN
from torch import nn
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy, ConfusionMatrix
from torchsummary import summary

class CNN(pl.LightningModule):
    """Small three-stage CNN classifier for multi-band image patches.

    Args:
        channels_in: Number of input channels (bands).
        num_classes: Number of output classes.
        lr: Adam learning rate.
    """

    def __init__(self, channels_in, num_classes, lr=1e-3):
        super().__init__()
        # Two stride-2 max-pools, then an adaptive average pool to a fixed
        # 4x4 grid. The classifier's input size (4*4*128) is therefore
        # independent of the input image size.
        self.encoder = nn.Sequential(
            nn.Conv2d(channels_in, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(4, 4)),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4 * 4 * 128, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

        self.lr = lr

        # torchmetrics >= 0.11 requires an explicit ``task`` (plus
        # ``num_classes`` for multiclass). The bare ``Accuracy()`` and
        # ``ConfusionMatrix(num_classes)`` forms raise on current versions.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_confusion_matrix = ConfusionMatrix(task="multiclass", num_classes=num_classes)

        self.save_hyperparameters()

    def forward(self, x):
        """Return unnormalised class logits for a batch of images."""
        return self.classifier(self.encoder(x))

    def _shared_step(self, batch):
        """Run the model on one batch; return ``(logits, loss, targets)``."""
        X = batch["data"]
        y = batch["category_id"]
        logits = self(X)
        # Same value as nll_loss(log_softmax(logits)), computed in one call.
        loss = F.cross_entropy(logits, y)
        return logits, loss, y

    def training_step(self, batch, batch_idx):
        logits, loss, y = self._shared_step(batch)
        self.log("loss/train", loss, on_epoch=True, on_step=False)

        self.train_accuracy(logits, y)
        self.log("accuracy/train", self.train_accuracy, on_epoch=True, on_step=False)

        return loss

    def on_validation_epoch_start(self):
        # The confusion matrix is never passed to self.log, so Lightning does
        # not reset it. Without this reset it would accumulate every
        # validation run (including the sanity check) into one matrix.
        self.val_confusion_matrix.reset()

    def validation_step(self, batch, batch_idx):
        logits, loss, y = self._shared_step(batch)
        self.log("loss/val", loss, on_epoch=True, on_step=False)

        self.val_accuracy(logits, y)
        self.log("accuracy/val", self.val_accuracy, on_epoch=True, on_step=False)

        # Accumulated over the epoch; read out manually with .compute().
        self.val_confusion_matrix(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    
# torchsummary's summary() prints the layer table itself and returns None, so
# wrapping it in print() only added a stray "None" line.
summary(CNN(13, dset.num_categories), input_size=(13, 64, 64), device="cpu")

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop once validation accuracy has not improved for 3 epochs. Keep the
# checkpoint with the best validation accuracy, plus the last one.
callbacks = [
    EarlyStopping(monitor="accuracy/val", mode="max", patience=3),
    ModelCheckpoint(monitor="accuracy/val", mode="max", save_last=True),
]

model = CNN(13, dset.num_categories)  # 13 input bands, as in the summary above
trainer = pl.Trainer(
    # "auto" uses a GPU when one is available and falls back to the CPU,
    # instead of failing on a CPU-only runtime as a hard-coded "gpu" does.
    accelerator="auto",
    devices=1,
    max_epochs=20,
    callbacks=callbacks,
    default_root_dir="simple_cnn",
)

trainer.fit(model, train_dataloaders=train_dloader, val_dataloaders=val_dloader)


In [None]:
# Inspect the logs Lightning wrote under default_root_dir="simple_cnn".
%tensorboard --logdir simple_cnn/lightning_logs

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay
from matplotlib import pyplot as plt

# Evaluate the checkpoint with the best validation accuracy on the test split.
# (The previous code loaded best_model but then evaluated the last-epoch model.)
best_model = CNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

# validation_step feeds val_confusion_matrix. Clear it first so the matrix
# holds only the test-set predictions from this run.
best_model.val_confusion_matrix.reset()
trainer.validate(best_model, dataloaders=test_dloader)

cm = best_model.val_confusion_matrix.compute().cpu().numpy()

# Category ids were assigned in the key order of dset.categories, so its keys
# line up with the matrix rows and columns.
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=list(dset.categories))
fig, ax = plt.subplots(figsize=(20, 20), dpi=100)
disp.plot(ax=ax, xticks_rotation=90)
plt.show()