In [2]:
# --- Import packages and setup device ---
import os
import yaml
import torch
import random
import matplotlib.pyplot as plt
from torchvision import datasets, models
from torch.utils.data import DataLoader, random_split

# custom modules
from datafactory import EuroSAT, get_transforms
from engine import train
from utils import *

# Read the experiment settings from the YAML file in the working directory
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)


# Choose the compute device: CUDA GPU when one is available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed from the configured value (set_seeds comes from the utils star-import)
set_seeds(config['seed'])
print(f"Device: {device}")

In [1]:
# --- Load EuroSAT data and inspect samples ---
print("Loading EuroSAT Data...")

# Build the folder-backed dataset from the configured directory; no transform
# is passed here, transforms are attached later by the EuroSAT wrapper.
raw_dataset = datasets.ImageFolder(root=config['data_dir'])

# Class names as discovered by ImageFolder, reused for plot titles below
class_names = raw_dataset.classes
print(f"Classes found: {raw_dataset.classes}")


def plot_samples(dataset, class_names, num_images=20):
    """Display a grid of randomly chosen samples from ``dataset``.

    Args:
        dataset: Indexable collection with ``len()`` whose items unpack as
            ``(image, label)``; ``image`` must be accepted by ``plt.imshow``.
        class_names: Sequence indexed by ``label`` giving the title of each
            image.
        num_images: Maximum number of samples to show. Capped at
            ``len(dataset)``; nothing is drawn when the result is zero or
            negative.

    The grid keeps five columns and grows as many rows as needed, so the
    default of 20 images still produces the 4x5 layout at 15x12 inches.
    """
    # random.sample raises if asked for more items than the population holds.
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        return

    n_cols = 5
    # Ceiling division: enough rows to hold every requested image.
    n_rows = (num_images + n_cols - 1) // n_cols

    # 3 inches of height per row matches the original 12 inches for 4 rows.
    plt.figure(figsize=(15, 3 * n_rows))
    indices = random.sample(range(len(dataset)), num_images)

    for i, idx in enumerate(indices):
        image, label = dataset[idx]
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(image)
        plt.title(f"{class_names[label]}", fontsize=9)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Preview a random selection of the untransformed images with their class names
plot_samples(dataset=raw_dataset, class_names=class_names)

In [4]:
# --- Data processing (splitting and transform) ---

# Create a stratified subset to speed up training; the fraction kept per class
# comes from config['percentage_per_class'].
indices = list(range(len(raw_dataset)))
# NOTE(review): assumes sample_subset_per_class returns a sequence of integer
# indices into the dataset it is given -- TODO confirm. The Subset passed here
# uses the identity mapping, so those are also indices into raw_dataset.
subset_indices = sample_subset_per_class(torch.utils.data.Subset(raw_dataset, indices),
                                         percentage=config['percentage_per_class'])

# Split into Train/Val
val_size = int(len(subset_indices) * config['val_ratio'])
train_size = len(subset_indices) - val_size
# Dedicated seeded generator so re-running this cell yields the same split.
split_generator = torch.Generator().manual_seed(config['seed'])
train_subset, val_subset = random_split(
    subset_indices, [train_size, val_size], generator=split_generator
)

# random_split's `.indices` are POSITIONS within `subset_indices`, not the raw
# dataset indices stored there. Translate them back; using the positions
# directly would select raw samples 0..len(subset_indices)-1, which for a
# class-ordered ImageFolder covers only the first few classes.
train_indices = [int(subset_indices[pos]) for pos in train_subset.indices]
val_indices = [int(subset_indices[pos]) for pos in val_subset.indices]

# Apply Transforms using the Wrapper
train_data = EuroSAT(torch.utils.data.Subset(raw_dataset, train_indices), get_transforms('train'))
val_data = EuroSAT(torch.utils.data.Subset(raw_dataset, val_indices), get_transforms('val'))

print(f"Training Samples: {len(train_data)}")
print(f"Validation Samples: {len(val_data)}")

# DataLoaders: shuffle only the training set; validation order stays fixed
train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=config['batch_size'], shuffle=False, num_workers=0)

In [None]:
# --- Model setup ---
print("Initializing ViT Model...")
weights = models.ViT_B_16_Weights.DEFAULT
model = models.vit_b_16(weights=weights)

# Swap the classification head for a fresh linear layer with
# config['num_classes'] outputs, keeping the head's existing input width
head_in_features = model.heads.head.in_features
model.heads.head = torch.nn.Linear(
    in_features=head_in_features,
    out_features=config['num_classes'],
)
model = model.to(device)

# Loss & Optimizer (get_optimizer comes from the utils star-import)
criterion = torch.nn.CrossEntropyLoss()
optimizer = get_optimizer(model, config)

In [None]:
# --- Training loop ---
# Create the output directory up front so a missing folder cannot cause a
# failure at save time after training has already run. An empty path means
# "current directory" for os.path.join, so there is nothing to create then.
output_dir = config['save_path']
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

save_path = os.path.join(output_dir, "best_model.pth")

# train() is the project's engine entry point; it returns the results object
# and the best model. NOTE(review): presumably it writes the checkpoint to
# model_save_path itself -- confirm in engine.py.
results, best_model = train(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    loss_fn=criterion,
    device=device,
    epochs=config['num_epochs'],
    model_save_path=save_path,
    use_scheduler=config['use_scheduler']
)

# Save results & plots next to the checkpoint
save_training_results(results, os.path.join(output_dir, "train_results.json"))
plot_curves(results, os.path.join(output_dir, "training_curves.png"))

print("Model saved.")