In [None]:
import os
import tempfile

import kornia.augmentation as K
from torchvision import transforms

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT
from torchgeo.models import ResNet18_Weights, resnet18

from tqdm import tqdm  # progress bar wrapped around the training loop

# Seed PyTorch's global random number generator for reproducibility.
torch.manual_seed(0)

In [None]:
# Directory the EuroSAT files are stored in (and downloaded to).
root = 'torchgeo_data'

# First instantiation: ask torchgeo to download the data and verify its
# checksum so the split-specific datasets below can reuse the local copy.
dataset = EuroSAT(root=root, download=True, checksum=True)


In [None]:
# One EuroSAT instance per predefined split, unpacked in train/val/test order.
train_dataset, val_dataset, test_dataset = (
    EuroSAT(root, split=split_name) for split_name in ('train', 'val', 'test')
)

In [None]:
batch_size = 32

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)

# Print the sizes of the datasets.
# len(<DataLoader>) is the number of *batches*, so the sample counts are read
# from the underlying datasets instead.
print(f"Number of samples in the training set: {len(train_dataset)}")
print(f"Number of samples in the validation set: {len(val_dataset)}")
print(f"Number of samples in the testing set: {len(test_dataset)}")

In [None]:
# Preprocessing and augmentation
# Normalisation with mean 0 and std 10000, i.e. inputs are divided by 10000.
preprocess = K.Normalize(mean=0, std=10000)

# Training-time augmentations, applied in this order: random horizontal flip,
# random vertical flip, then a random rotation configured with degrees=10.
augment = K.ImageSequential(
    K.RandomHorizontalFlip(),
    K.RandomVerticalFlip(),
    K.RandomRotation(degrees=10),
)


In [None]:
# model and device setup
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ResNet-18 initialised from the SENTINEL2_ALL_MOCO pretrained weights.
# NOTE(review): the classification head size is left at the library default —
# confirm it matches the number of EuroSAT classes.
model = resnet18(ResNet18_Weights.SENTINEL2_ALL_MOCO)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
# Built only after the model has been moved, so the optimizer is handed the
# parameters that already live on the target device.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# train() switches the model back to training mode at the start of each epoch.
model.eval()

In [None]:
# training function with progress tracking
def train(dataloader):
    """Run one training epoch over ``dataloader`` and show live metrics.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``device``,
    ``augment`` and ``preprocess`` objects. The tqdm progress bar displays the
    running mean loss per sample and the running accuracy (in percent).

    Args:
        dataloader: Sized iterable of batches; each batch is indexed with the
            keys ``'image'`` and ``'label'``.
    """
    model.train()
    total_loss = 0.0  # sum of per-sample losses seen so far
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, total=len(dataloader), desc="Training")
    for batch in progress_bar:
        x = batch['image'].to(device)
        y = batch['label'].to(device)

        # Apply augmentation and preprocessing
        x = preprocess(augment(x))

        # Start from clean gradients for this step
        optimizer.zero_grad()

        # Forward pass
        y_hat = model(x)
        loss = loss_fn(y_hat, y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # loss_fn returns the batch mean, so scale it back to a per-sample sum
        # before accumulating; the last batch may be smaller than the others.
        batch_count = y.size(0)
        total_loss += loss.item() * batch_count

        # accuracy calculation
        _, predicted = y_hat.max(1)
        correct += predicted.eq(y).sum().item()
        total += batch_count

        # Divide by the number of samples (not batches) so the displayed loss
        # is a true running mean on the same scale as the batch loss.
        progress_bar.set_postfix(loss=total_loss / total, accuracy=100. * correct / total)

In [None]:
# evaluation function
def evaluate(dataloader):
    """Print the classification accuracy of the global ``model`` on ``dataloader``.

    The model is put in eval mode and gradients are disabled. Inputs are
    preprocessed with the notebook-level ``preprocess`` (no augmentation).

    Args:
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'image'`` and ``'label'``.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for sample in dataloader:
            # Move to the device and normalise the imagery in one go
            images = preprocess(sample['image'].to(device))
            labels = sample['label'].to(device)

            # Predicted class = index of the largest score along dim 1
            scores = model(images)
            predicted = scores.max(1)[1]

            n_correct += predicted.eq(labels).sum().item()
            n_seen += labels.size(0)

    accuracy = n_correct / n_seen
    print(f'Accuracy: {accuracy:.2%}')

In [None]:
# Training loop
# Each epoch makes one pass over the training split and then prints the
# accuracy on the validation split.
epochs = 10
for epoch in range(epochs):
    # epoch is zero-based; print it one-based
    print(f'Epoch: {epoch +1}')
    train(train_dataloader)
    evaluate(val_dataloader)

In [None]:
# Final check: print the accuracy on the test split.
evaluate(test_dataloader)

In [None]:
# saving the model
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
model_dir = os.path.join(root, 'models')
os.makedirs(model_dir, exist_ok=True)

# Full pickled module: convenient, but loading it later needs the same model
# class definitions to be importable.
torch.save(model, os.path.join(model_dir, 'torchgeo_resnet18.pth'))
# Weights only (state_dict): the more portable checkpoint.
torch.save(model.state_dict(), os.path.join(model_dir, 'torchgeo_resnet18_weights.pth'))