# Classification Downstream Task 

In [18]:
!pip install ssl_remote_sensing@git+https://github.com/AlexanderLontke/ssl-remote-sensing.git@feature/pipeline

In [None]:
# Log in to your W&B account
import wandb
wandb.login()

In [2]:
from datetime import datetime
import torch
import wandb
import numpy as np
from tqdm import tqdm

from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import EuroSAT
from sklearn.metrics import classification_report

from ssl_remote_sensing.downstream_tasks.classification.model import DownstreamClassificationNet
from ssl_remote_sensing.downstream_tasks.classification.util import get_subset_samplers_for_train_test_split

## Dataset Loading ##

In [9]:
class RunConfig:
    def __init__(self):
        self.num_epochs = 10  # number of training epochs
        self.seed = 1234  # randomness seed
        self.save = "./saved_models/"  # save checkpoint
        self.batch_size = 64
        self.learning_rate = 1e-3
        self.embedding_size = 128  # papers value is 128
        self.test_split_ratio = 0.2
config = RunConfig()
wandb.init(
    project="ssl-remote-sensing-classification",
    config=config.__dict__,
)

In [10]:
# Setup data loading
# TODO add normalization
eurosat_ds = EuroSAT(root="./", download=True, transform=T.ToTensor())
dataset_size = len(eurosat_ds)

train_sampler, test_sampler = get_subset_samplers_for_train_test_split(
    dataset_size, test_split_ratio=config.test_split_ratio
)
train_dl = DataLoader(
    dataset=eurosat_ds,
    batch_size=128,
    sampler=train_sampler,
)

test_dl = DataLoader(
    dataset=eurosat_ds,
    batch_size=128,
    sampler=test_sampler,
)

## Model Training ##


In [11]:
# First of all, let's verify if a GPU is available on our compute machine. If not, the cpu will be used instead.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device used: {}".format(device))

model = DownstreamClassificationNet(input_dim=72 * 13 * 13).to(device)

# define the optimization criterion / loss function
loss_criterion = nn.CrossEntropyLoss().to(device)

# define learning rate and optimization strategy
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

Device used: cuda:0


In [14]:
model.train()

for epoch in tqdm(range(config.num_epochs)):
    # init collection of mini-batch losses
    train_mini_batch_losses = []

    # iterate over all-mini batches
    for i, (images, labels) in enumerate(train_dl):

        # push mini-batch data to computation device
        images = images.to(device)
        labels = labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        out = model(images)
        loss = loss_criterion(out, labels)
        loss.backward()
        optimizer.step()

        # collect mini-batch reconstruction loss
        train_mini_batch_losses.append(loss.data.item())
        wandb.log({
            "step/training_loss": loss.data.item(),
        })

    # determine mean min-batch loss of epoch
    train_epoch_loss = np.mean(train_mini_batch_losses)

    # Specify you are in evaluation mode
    model.eval()
    with torch.no_grad():
        validation_mini_batch_losses = []
        for (images, labels) in test_dl:
            images = images.to(device)
            labels = labels.to(device)
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest energy is what we choose as prediction
            validation_epoch_loss = loss_criterion(outputs, labels)
            # collect mini-batch reconstruction loss
            validation_mini_batch_losses.append(validation_epoch_loss.data.item())
        validation_epoch_loss = np.mean(validation_mini_batch_losses)

    wandb.log({
        "epoch/training_loss": train_epoch_loss,
        "epoch/validation_loss": validation_epoch_loss,
    })

    # print epoch loss
    now = datetime.utcnow().strftime("%Y%m%d-%H:%M:%S")
    print(
        f"[LOG {now}] epoch: {epoch+1} train-loss: {train_epoch_loss} validation-loss: {validation_epoch_loss}"
    )

100%|██████████| 169/169 [00:15<00:00, 11.11it/s]


[LOG 20221027-13:33:44] epoch: 1 train-loss: 1.480928829435766 validation-loss: 1.2402197122573853


100%|██████████| 169/169 [00:12<00:00, 13.25it/s]


[LOG 20221027-13:33:59] epoch: 2 train-loss: 1.0071332916705567 validation-loss: 0.8630091547966003


100%|██████████| 169/169 [00:13<00:00, 12.65it/s]


[LOG 20221027-13:34:14] epoch: 3 train-loss: 0.7849056743658506 validation-loss: 0.7523846626281738


100%|██████████| 169/169 [00:12<00:00, 13.11it/s]


[LOG 20221027-13:34:30] epoch: 4 train-loss: 0.6741639231083661 validation-loss: 0.7110217213630676


100%|██████████| 169/169 [00:13<00:00, 12.81it/s]


[LOG 20221027-13:34:45] epoch: 5 train-loss: 0.6494043100867751 validation-loss: 0.5024420619010925


100%|██████████| 169/169 [00:12<00:00, 13.26it/s]


[LOG 20221027-13:35:00] epoch: 6 train-loss: 0.5615866159546304 validation-loss: 0.5328946113586426


100%|██████████| 169/169 [00:12<00:00, 13.33it/s]


[LOG 20221027-13:35:16] epoch: 7 train-loss: 0.5410281814767058 validation-loss: 0.7607574462890625


100%|██████████| 169/169 [00:12<00:00, 13.21it/s]


[LOG 20221027-13:35:31] epoch: 8 train-loss: 0.479756369745943 validation-loss: 0.3574117422103882


100%|██████████| 169/169 [00:16<00:00, 10.32it/s]


[LOG 20221027-13:35:51] epoch: 9 train-loss: 0.44355890243011115 validation-loss: 0.48850345611572266


100%|██████████| 169/169 [00:13<00:00, 12.76it/s]


[LOG 20221027-13:36:07] epoch: 10 train-loss: 0.40241120023840277 validation-loss: 0.3943125307559967


## Model Evaluation ##

In [22]:
y_pred = []
y_true = []
# iterate over test data
model.eval()
with torch.no_grad():
    for (images, labels) in tqdm(test_dl, desc="Predict labels"):
        images = images.to(device)

        outputs = model(images)  # Feed Network
        _, predicted = torch.max(outputs, 1)

        y_pred.extend(predicted.cpu().numpy())  # Save Prediction
        y_true.extend(labels.numpy())  # Save Truth

Predict labels: 100%|██████████| 43/43 [00:02<00:00, 17.45it/s]


In [23]:
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.92      0.77      0.84       595
           1       0.84      0.92      0.88       606
           2       0.80      0.72      0.76       602
           3       0.69      0.56      0.62       515
           4       0.83      0.97      0.89       493
           5       0.75      0.75      0.75       423
           6       0.65      0.73      0.69       473
           7       0.91      0.96      0.93       597
           8       0.61      0.79      0.69       490
           9       0.98      0.78      0.87       606

    accuracy                           0.80      5400
   macro avg       0.80      0.79      0.79      5400
weighted avg       0.81      0.80      0.80      5400

