<a href="https://colab.research.google.com/github/AlexanderLontke/ssl-remote-sensing/blob/gan%2Fpipeline/notebooks/Classification_Downstream_Task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

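# Classification Downstream Task

This notebook trains a `DownstreamClassificationNet` on EuroSAT on top of encoders loaded from pretext-task checkpoints, plus a randomly initialised baseline. For each checkpoint it logs training/validation losses and a per-class classification report to Weights & Biases.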

In [1]:
!pip install ssl_remote_sensing@git+https://github.com/AlexanderLontke/ssl-remote-sensing.git@gan/pipeline

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ssl_remote_sensing@ git+https://github.com/AlexanderLontke/ssl-remote-sensing.git@gan/pipeline
  Cloning https://github.com/AlexanderLontke/ssl-remote-sensing.git (to revision gan/pipeline) to /tmp/pip-install-fx82imqt/ssl-remote-sensing_5f23d22ea442468c8cc209071be0b5d8
  Running command git clone -q https://github.com/AlexanderLontke/ssl-remote-sensing.git /tmp/pip-install-fx82imqt/ssl-remote-sensing_5f23d22ea442468c8cc209071be0b5d8
  Running command git checkout -b gan/pipeline --track origin/gan/pipeline
  Switched to a new branch 'gan/pipeline'
  Branch 'gan/pipeline' set up to track remote branch 'gan/pipeline' from 'origin'.
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected pack

In [2]:
# Authenticate with Weights & Biases: every training run below logs its metrics there.
# wandb.login() returns True once a session is available (see the cell output).
import wandb

wandb_logged_in = wandb.login()
wandb_logged_in

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmarccgrau[0m ([33munisg-ds-nlp[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import numpy as np
import torch
import torchvision.transforms as T
from sklearn.metrics import classification_report
from torch import nn
from tqdm import tqdm

from ssl_remote_sensing.constants import RANDOM_INITIALIZATION
from ssl_remote_sensing.data.get_eurosat import (
    get_eurosat_normalizer,
    get_eurosat_dataloader,
)
from ssl_remote_sensing.downstream_tasks.classification.model import (
    DownstreamClassificationNet,
)
from ssl_remote_sensing.pretext_tasks.utils import (
    load_encoder_checkpoint_from_pretext_model,
)

## Run Configuration and Dataset Loading ##

In [None]:
@dataclass
class RunConfig:
    """Hyperparameters and bookkeeping for the downstream classification runs.

    Every field has a default, so ``RunConfig()`` reproduces the standard settings;
    individual values can be overridden by keyword, e.g. ``RunConfig(num_epochs=5)``.
    The instance ``__dict__`` is passed as ``config=`` to ``wandb.init`` by the
    training loop, so every field here is logged with each run.
    """

    num_epochs: int = 30  # number of training epochs
    seed: int = 1234  # randomness seed
    save: str = "./saved_models/"  # save checkpoint directory
    batch_size: int = 256
    learning_rate: float = 1e-3
    embedding_size: int = 128  # papers value is 128
    test_split_ratio: float = 0.2
    # Path of the checkpoint currently being evaluated; set by the training loop.
    checkpoint_name: Optional[str] = None


config = RunConfig()

In [None]:
# Build the EuroSAT train/test dataloaders. Images are converted to tensors and
# then normalised with the normaliser returned by get_eurosat_normalizer().
eurosat_normalizer = get_eurosat_normalizer()
eurosat_transform = T.Compose([T.ToTensor(), eurosat_normalizer])

# NOTE(review): `split` receives a bool, not the ratio itself. It is True only while
# test_split_ratio == 0.2, so the loader presumably applies its own fixed split.
# Confirm against get_eurosat_dataloader before changing test_split_ratio.
use_train_test_split = config.test_split_ratio == 0.2

train_dl, test_dl = get_eurosat_dataloader(
    root="./",
    transform=eurosat_transform,
    batchsize=config.batch_size,
    numworkers=os.cpu_count(),
    split=use_train_test_split,
)

## Model Training ##


In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device used: {device}")

# Training objective for the downstream classifier: cross-entropy loss.
loss_criterion = nn.CrossEntropyLoss()
loss_criterion = loss_criterion.to(device)

## Set Up Checkpoint Loading ##

In [None]:
# Mount Google Drive (Colab only) and collect the pretext-task checkpoints to evaluate.
from google.colab import drive

drive.mount("/content/drive")
g_drive_path = "/content/drive/MyDrive/deep_learning_checkpoints"

# Sort the listing: os.listdir returns entries in arbitrary order, and the next
# cell selects checkpoints by position.
check_point_paths = [
    os.path.join(g_drive_path, checkpoint_file)
    for checkpoint_file in sorted(os.listdir(g_drive_path))
]
# Append RANDOM_INITIALIZATION only after the Drive prefix has been applied.
# Presumably it is a sentinel understood by load_encoder_checkpoint_from_pretext_model
# rather than a file on Drive; prefixing it would turn it into
# "<g_drive_path>/<RANDOM_INITIALIZATION>".
check_point_paths.append(RANDOM_INITIALIZATION)

In [None]:
# Evaluate only entries 3, 4 and 5 (0-based) of the checkpoint list in this session.
# NOTE(review): this rebinds check_point_paths, so running the cell twice slices an
# already-sliced list. Re-run the listing cell above before re-running this one.
CHECKPOINT_SELECTION = slice(3, 6)
check_point_paths = check_point_paths[CHECKPOINT_SELECTION]

In [None]:
# Helpers for the per-checkpoint experiment loop below. They read the notebook-level
# `device`, `loss_criterion` and `wandb` defined in earlier cells.


def _train_one_epoch(model, optimizer, dataloader):
    """Run one optimisation pass over ``dataloader`` and return the mean mini-batch loss.

    Switches ``model`` to training mode first. The previous epoch's validation pass
    leaves it in eval mode, which would otherwise change the behaviour of mode-dependent
    layers (e.g. dropout, batch norm) for every epoch after the first.
    Each mini-batch loss is also logged to W&B as ``step/training_loss``.
    """
    model.train()
    batch_losses = []
    for images, labels in dataloader:
        # push mini-batch data to computation device
        images = images.to(device)
        labels = labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        loss = loss_criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        wandb.log({"step/training_loss": batch_loss})
    return np.mean(batch_losses)


def _validation_loss(model, dataloader):
    """Return the mean ``loss_criterion`` value over ``dataloader`` (eval mode, no gradients)."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(device))
            batch_losses.append(loss_criterion(outputs, labels.to(device)).item())
    return np.mean(batch_losses)


def _predict_labels(model, dataloader):
    """Return ``(y_true, y_pred)`` lists of ground-truth and highest-scoring predicted labels."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Predict labels"):
            outputs = model(images.to(device))  # Feed Network
            _, predicted = torch.max(outputs, 1)
            y_pred.extend(predicted.cpu().numpy())  # Save Prediction
            y_true.extend(labels.numpy())  # Save Truth
    return y_true, y_pred


for filename in check_point_paths:
    # Update checkpoint name so it is logged with the W&B run config
    config.checkpoint_name = filename
    # Seed before building the encoder/classifier so that weight initialisation (and,
    # presumably, data-loader shuffling) is reproducible across checkpoints.
    torch.manual_seed(config.seed)
    # Load Encoder from different pre-text architectures
    encoder = load_encoder_checkpoint_from_pretext_model(
        path_to_checkpoint=filename,
    )
    wandb.init(
        project="ssl-remote-sensing-classification",
        name=filename,
        config=config.__dict__,
    )
    # Model Setup: BiGAN checkpoints use a 100-dim classifier input with the GAN encoder
    # path; every other checkpoint uses a 512-dim input.
    if "bigan" in config.checkpoint_name.lower():
        model = DownstreamClassificationNet(
            input_dim=100, encoder=encoder, gan_encoder=True
        ).to(device)
    else:
        model = DownstreamClassificationNet(input_dim=512, encoder=encoder).to(device)
    # define learning rate and optimization strategy
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Shown in the progress bar before the first epoch has produced any losses.
    train_epoch_loss = np.nan
    validation_epoch_loss = np.nan

    with tqdm(range(config.num_epochs)) as tq:
        for epoch in tq:
            # The description shows the losses of the previous epoch.
            now = datetime.utcnow().strftime("%Y%m%d-%H:%M:%S")
            tq.set_description(
                f"[{now}] epoch: {epoch+1} train-loss: {train_epoch_loss} "
                f"validation-loss: {validation_epoch_loss}"
            )
            train_epoch_loss = _train_one_epoch(model, optimizer, train_dl)
            validation_epoch_loss = _validation_loss(model, test_dl)
            wandb.log(
                {
                    "epoch/training_loss": train_epoch_loss,
                    "epoch/validation_loss": validation_epoch_loss,
                }
            )

    # Per-class precision / recall / F1 on the test split, logged and printed.
    y_true, y_pred = _predict_labels(model, test_dl)
    report = classification_report(y_true, y_pred, output_dict=True)
    wandb.log({"classification_report": report})
    print(classification_report(y_true, y_pred))
    # Close this checkpoint's run so the next wandb.init() starts a fresh one.
    wandb.finish()