# Transfer Learning for BigEarthNet Model with Multispectral Images

This notebook uses a BigEarthNet classifier originally trained by TU Berlin and BIFOLD as the backbone for transfer learning with a custom classifier. The code uses PyTorch Lightning, which builds on PyTorch, the framework the original model was trained in.

[BIFOLD Hugging Face repo](https://huggingface.co/BIFOLD-BigEarthNetv2-0/convmixer_768_32-s2-v0.2.0)

Resources: PyTorch Lightning documentation and ChatGPT.

To download the model, create a Weights & Biases (wandb) academic account.





In [None]:
%pip install configilm
%pip install wandb
import wandb
wandb.login()  # paste your wandb API key when prompted and press Enter

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import EuroSAT
from torchmetrics.classification import Accuracy
import pytorch_lightning as pl
from reben_publication.BigEarthNetv2_0_ImageClassifier import BigEarthNetv2_0_ImageClassifier
import matplotlib.pyplot as plt

Select the compute device. This cell is for an Apple Silicon (M-series) Mac and falls back to the CPU if MPS is unavailable.

In [None]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

To run on an NVIDIA GPU instead, uncomment this cell (it falls back to the CPU if CUDA is unavailable).

In [None]:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using device: {device}")
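The two device cells above can also be folded into a single check that prefers CUDA, then Apple's MPS backend, then the CPU, so nothing needs to be commented out. A minimal sketch; `pick_device` is a hypothetical helper name, not part of PyTorch:

```python
import torch


def pick_device() -> str:
    """Return the best available PyTorch device name: 'cuda', then 'mps', then 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(f"Using device: {device}")
```

The order of the checks is a design choice: a machine with CUDA will never report MPS, so checking CUDA first only matters on hypothetical mixed setups, but it keeps the fallback chain explicit.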

## Load the datasets into torch dataloaders
DataLoaders read the images lazily and feed them to the model in mini-batches during training, so the full dataset never has to be held in memory at once.

In [None]:
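# Multispectral EuroSAT loader from TorchGeo (install with: %pip install torchgeo).
# torchvision's EuroSAT only provides the RGB images and has no `bands` argument.
from torch.utils.data import Dataset
from torchgeo.datasets import EuroSAT

class EuroSATMultispectral(Dataset):
    """Wraps TorchGeo's EuroSAT so each item is an (image, label) tuple."""
    def __init__(self, split, bands=EuroSAT.all_band_names):
        # Set download=False once the data is in `root`
        self.ds = EuroSAT(root="data", split=split, bands=bands, download=True)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]  # dict with "image" [C, H, W] and "label"
        return sample["image"].float(), sample["label"]

# Use TorchGeo's predefined train/val/test splits (all 13 bands by default).
# Check how many bands the pretrained backbone expects and pass only those via `bands`.
train_dataset = EuroSATMultispectral("train")
val_dataset = EuroSATMultispectral("val")
test_dataset = EuroSATMultispectral("test")

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32)
test_dataloader = DataLoader(test_dataset, batch_size=32)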
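Pretrained networks are normally trained on normalized inputs, so standardizing each spectral band before it reaches the backbone can matter for transfer learning. A minimal sketch that computes per-band mean and standard deviation over a DataLoader and applies them; this is an assumption about preprocessing, not the BigEarthNet authors' pipeline, and if the model card publishes the per-band statistics used during pretraining, those are the better choice. `band_stats` and `normalize` are hypothetical helper names.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def band_stats(loader: DataLoader) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-band mean and (population) std over every pixel of every image in `loader`.

    Assumes each batch is an (images, labels) pair with images shaped [B, C, H, W].
    """
    total, total_sq, count = None, None, 0
    for images, _ in loader:
        images = images.double()  # float64 keeps the running sums precise
        per_band = images.transpose(0, 1).reshape(images.shape[1], -1)  # [C, B*H*W]
        s, sq = per_band.sum(dim=1), (per_band ** 2).sum(dim=1)
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += per_band.shape[1]
    mean = total / count
    std = (total_sq / count - mean ** 2).clamp_min(0).sqrt()
    return mean.float(), std.float()


def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor,
              eps: float = 1e-6) -> torch.Tensor:
    """Standardize [C, H, W] or [B, C, H, W] images band by band."""
    return (images - mean.view(-1, 1, 1)) / (std.view(-1, 1, 1) + eps)


# Toy data: 8 images with 13 bands of 4x4 pixels and values in [0, 1000)
torch.manual_seed(0)
images = torch.rand(8, 13, 4, 4) * 1000
loader = DataLoader(TensorDataset(images, torch.zeros(8, dtype=torch.long)), batch_size=3)
mean, std = band_stats(loader)
normed = normalize(images, mean, std)
print(normed.mean(dim=(0, 2, 3)))  # close to 0 for every band
```

In the notebook the statistics would come from the training split only (for example `band_stats(train_dataloader)`), and `normalize` would then be applied to every batch, for instance at the start of `forward`.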

## Visualize data

Take a sample from the training set and visualize its RGB bands.

In [None]:
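# Take a sample from the train set and visualize its RGB bands
sample_img, sample_label = train_dataset[0]  # or any index

# Images are [C, H, W] tensors. In EuroSAT's 13-band order (B01, B02, ..., B08, B08A, B09, ..., B12)
# red, green and blue are B04, B03 and B02, i.e. indices 3, 2, 1 (not 0, 1, 2)
rgb_img = sample_img[[3, 2, 1]].float()
# Raw reflectance values lie far outside [0, 1]; stretch to the 2nd-98th percentile for display
lo, hi = rgb_img.quantile(0.02), rgb_img.quantile(0.98)
rgb_img = ((rgb_img - lo) / (hi - lo)).clamp(0, 1)
rgb_img = rgb_img.permute(1, 2, 0)  # [H, W, C] for matplotlib

plt.imshow(rgb_img)
plt.title(f"Label: {int(sample_label)}")
plt.axis('off')
plt.show()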

## Build the model

In [None]:
class BigEarthNetClassifier(pl.LightningModule):
    def __init__(self, num_target_classes=10, dropout_rate=0.15):  # EuroSAT has 10 classes
        super().__init__()
        self.save_hyperparameters()  # stored in the checkpoint so load_from_checkpoint can rebuild the model
        # Load pretrained model
        self.backbone = BigEarthNetv2_0_ImageClassifier.from_pretrained(
            "BIFOLD-BigEarthNetv2-0/convmixer_768_32-s2-v0.1.1")
        # Freeze the backbone; only the new classifier defined below is trained
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Remove the backbone's classification layer
        print(self.backbone)  # check that it actually has a `classifier` layer
        num_features = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Identity()

        # Create new classifier
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(num_features, num_target_classes)  # Input: num_features, output: num_target_classes

        # Define loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Define accuracy metrics (one per stage so their states are kept separate)
        self.train_acc = Accuracy(task="multiclass", num_classes=num_target_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_target_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_target_classes)

    def forward(self, x):
        features = self.backbone(x)
        x = self.dropout(features)
        x = self.classifier(x)
        return x

    def on_train_epoch_start(self):
        # Lightning puts every submodule in train mode; keep the frozen backbone in eval mode
        # so its BatchNorm running statistics are not updated during training
        self.backbone.eval()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.train_acc(logits, y)  # update the running accuracy
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.val_acc(logits, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.test_acc(logits, y)
        self.log('test_loss', loss)
        self.log('test_acc', self.test_acc)
        return loss

    def configure_optimizers(self):
        # Optimize only the parameters that still require gradients (the new classifier)
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = optim.AdamW(trainable_params, lr=0.001)
        return optimizer
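Two quick checks are worth running on the assembled model before training: how many input bands the backbone's first convolution expects (it must match the number of bands the DataLoader yields), and how many parameters will actually be trained. A minimal sketch using only standard PyTorch; `expected_in_channels` and `count_params` are hypothetical helpers, and treating the first `nn.Conv2d` returned by `modules()` as the input layer is an assumption that holds for typical CNN backbones but is worth confirming against `print(model)`.

```python
import torch
from torch import nn


def expected_in_channels(model: nn.Module) -> int:
    """Input channels of the first nn.Conv2d found in registration order."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            return module.in_channels
    raise ValueError("model contains no nn.Conv2d layer")


def count_params(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, frozen) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen


# Toy stand-in: a frozen convolutional "backbone" followed by a new linear head
backbone = nn.Sequential(nn.Conv2d(4, 8, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten())
for p in backbone.parameters():
    p.requires_grad = False
toy = nn.Sequential(backbone, nn.Linear(8, 10))

print(expected_in_channels(toy))  # 4
print(count_params(toy))  # (90, 296): 8*10 + 10 trainable, 8*4*3*3 + 8 frozen
```

On the notebook's model, `expected_in_channels(model.backbone)` can be compared with `train_dataset[0][0].shape[0]`, and `count_params(model)` should report only the new classifier as trainable.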

## Train the classifier
Once the model is assembled, the following code trains the new classifier layer; the frozen backbone only extracts features.

In [None]:
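model = BigEarthNetClassifier(num_target_classes=10)  # Create an instance of the model (EuroSAT has 10 classes)
# Create a PyTorch Lightning trainer on the device selected earlier.
# max_epochs=10 is an example value; without max_epochs (or max_steps) Lightning trains for up to 1000 epochs.
trainer = pl.Trainer(accelerator=device, max_epochs=10)
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)  # Train the model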
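Setting `requires_grad = False` on its own does not fully freeze a backbone that contains BatchNorm layers (ConvMixer-style networks use BatchNorm after each convolution): in train mode those layers update their running mean and variance on every forward pass, and Lightning switches the whole module to train mode for training. The sketch below demonstrates this on a toy `nn.BatchNorm2d` standing in for the backbone's normalization layers, and shows that `.eval()` keeps the statistics fixed.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False  # "frozen" in the gradient sense only

x = torch.randn(4, 3, 8, 8) + 5.0  # batch whose mean (about 5) is far from the initial running_mean of 0

bn.train()
bn(x)
stats_after_train_pass = bn.running_mean.clone()  # now about 0.5: momentum 0.1 toward the batch mean

bn.eval()
bn(x)
stats_after_eval_pass = bn.running_mean.clone()  # unchanged: eval mode uses the buffers but does not update them

print(stats_after_train_pass)
print(torch.equal(stats_after_train_pass, stats_after_eval_pass))  # True
```

The usual remedy in a LightningModule is to call `self.backbone.eval()` in a hook such as `on_train_epoch_start`, which runs after Lightning has switched the module to train mode.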

## Test the model
Load the trained model from its checkpoint and evaluate it on the test data.

In [None]:
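# Checkpoint written by the trainer's default ModelCheckpoint callback during fit
PATH = trainer.checkpoint_callback.best_model_path
model = BigEarthNetClassifier.load_from_checkpoint(PATH, num_target_classes=10)
model.freeze()  # eval mode and no gradients: makes the model read-only for inference
trainer.test(model, dataloaders=test_dataloader)  # runs test_step and reports test_loss / test_acc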
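Overall test accuracy can hide weak classes. A minimal sketch of a per-class breakdown, assuming the model returns logits and the DataLoader yields `(images, labels)` batches as in this notebook; `collect_predictions` and `confusion_matrix` are hypothetical helpers written with plain PyTorch (torchmetrics also provides a `MulticlassConfusionMatrix` metric).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def collect_predictions(model: nn.Module, loader: DataLoader, device: str = "cpu"):
    """Run `model` over `loader` and return (predicted classes, true labels) as 1-D tensors."""
    model.eval()
    preds, targets = [], []
    for x, y in loader:
        logits = model(x.to(device))
        preds.append(logits.argmax(dim=1).cpu())
        targets.append(y.cpu())
    return torch.cat(preds), torch.cat(targets)


def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Rows are true classes, columns are predicted classes."""
    idx = targets * num_classes + preds
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


# Toy check: the "model" is nn.Identity, so each input row is treated as its own logits
logits = torch.tensor([[2.0, 0.1, 0.0], [0.0, 3.0, 0.5], [0.2, 0.1, 1.0], [0.9, 0.0, 0.1]])
labels = torch.tensor([0, 1, 2, 2])
loader = DataLoader(TensorDataset(logits, labels), batch_size=2)

preds, targets = collect_predictions(nn.Identity(), loader)
cm = confusion_matrix(preds, targets, num_classes=3)
per_class_acc = cm.diag().float() / cm.sum(dim=1).clamp_min(1).float()
print(cm)
print(per_class_acc)  # tensor([1.0000, 1.0000, 0.5000])
```

In the notebook this would be called with the loaded model and `test_dataloader` (model and inputs on the same device) and `num_classes=10` for EuroSAT.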