In [None]:
import torch  # PyTorch library for tensor computations and deep learning
import torch.nn as nn  # Neural network modules and layers
import torch.optim as optim  # Optimization algorithms for training
import torchvision  # PyTorch's computer vision library
from torchvision import datasets, transforms as T  # Datasets and image transformations
from torch.utils.data import DataLoader, SubsetRandomSampler  # Data loading utilities
import numpy as np  # Numerical operations on arrays
import timm  # PyTorch Image Models library with pretrained models, including Vision Transformers
import os  # Operating system interface for file and directory operations
import matplotlib.pyplot as plt  # Plotting library for visualizations
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

In [None]:
image_size = 224  # matches the 224x224 input of the 'vit_base_patch16_224' model used below

# Random augmentations. They are applied each time an image is loaded, so every
# epoch sees a slightly different version of each training image.
augmentations = [
    T.RandomHorizontalFlip(),
    T.RandomRotation(15),  # angle sampled from [-15, 15] degrees
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
]

# Deterministic preprocessing: fixed spatial size, float tensor, per-channel normalisation.
# NOTE(review): these are the ImageNet mean/std; some timm ViT checkpoints were
# trained with mean=std=0.5 — confirm via timm.data.resolve_data_config(model=model).
preprocessing = [
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

transform = T.Compose(augmentations + preprocessing)

In [None]:
data_dir = '/kaggle/input/eurosat10-classes/EuroSAT_RGB/'  # Path to the dataset directory
SEED = 42  # fixes the train/validation split and sampling order across kernel restarts

# Seed torch's global RNG as well: SubsetRandomSampler and the random
# augmentations inside `transform` draw from it, not from NumPy.
torch.manual_seed(SEED)

# Create a dataset from the images in the specified directory, applying the defined transformations
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

classes = dataset.classes
print("Classes:", classes)

batch_size = 32
valid_size = 0.2  # fraction of images held out for validation

num_data = len(dataset)
# Seeded permutation instead of an unseeded in-place np.random.shuffle, so the
# same images land in the validation split on every run.
rng = np.random.default_rng(SEED)
indices = rng.permutation(num_data).tolist()

split = int(np.floor(valid_size * num_data))

# First `split` shuffled indices -> validation, remainder -> training (disjoint).
train_idx, valid_idx = indices[split:], indices[:split]
assert len(train_idx) + len(valid_idx) == num_data
assert split > 0, "validation split is empty; increase valid_size or check data_dir"

train_sampler = SubsetRandomSampler(train_idx)  # Randomly samples elements from the training indices
valid_sampler = SubsetRandomSampler(valid_idx)

In [None]:
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)

# Validation must not go through the random flip/rotation/jitter in `transform`:
# that makes validation metrics noisy and non-reproducible. Build a second
# ImageFolder over the same directory with deterministic preprocessing only.
eval_transform = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_dataset = datasets.ImageFolder(root=data_dir, transform=eval_transform)
# `valid_idx` indexes `dataset`; it is only valid here if both folders list the
# same (path, label) pairs in the same order.
assert eval_dataset.samples == dataset.samples, "eval dataset order differs from training dataset"
valid_loader = DataLoader(eval_dataset, batch_size=batch_size, sampler=valid_sampler)

model = timm.create_model('vit_base_patch16_224', pretrained=True)
num_classes = len(classes)

# Freeze every pretrained parameter; only the layers re-enabled below are trained.
for param in model.parameters():
    param.requires_grad = False

# Swap in a classification head sized for this dataset. A freshly constructed
# nn.Linear has requires_grad=True, so the new head is trainable.
model.head = nn.Linear(model.head.in_features, num_classes)

# Also unfreeze the last transformer block for light fine-tuning of the backbone.
for param in model.blocks[-1].parameters():
    param.requires_grad = True

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # nn.Module.to moves parameters in place and returns the same module

# Multi-class classification loss on the head's logits.
criterion = nn.CrossEntropyLoss()

# Only the trainable parts are optimised. The newly created head gets a larger
# learning rate than the unfrozen last block, which starts from pretrained weights.
param_groups = [
    dict(params=model.head.parameters(), lr=1e-3),
    dict(params=model.blocks[-1].parameters(), lr=1e-4),
]
optimizer = optim.AdamW(param_groups, weight_decay=1e-4)

In [None]:
num_epochs = 5

# Anneal over the real training length. T_max counts scheduler.step() calls
# (one per epoch below); T_max=10 with only 5 epochs stopped halfway down the cosine.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Per-epoch history, kept at notebook level so later cells can plot it.
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

best_val_acc = 0.0
best_model_path = 'best_vit_eurosat.pth'


def _epoch_metrics(running_loss, correct, total):
    """Convert summed per-sample loss and correct count into (mean loss, accuracy %).

    Raises:
        ValueError: if the loader produced no samples (avoids a bare ZeroDivisionError).
    """
    if total == 0:
        raise ValueError("data loader yielded no samples")
    return running_loss / total, 100.0 * correct / total


def _train_one_epoch(loader):
    """Run one optimisation pass of the notebook-level `model` over `loader`.

    Uses the notebook-level `criterion`, `optimizer` and `device`.
    Returns (mean training loss per sample, training accuracy in percent).
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight by batch size so the epoch
        # average is per sample even when the last batch is smaller.
        running_loss += loss.item() * labels.size(0)
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += labels.size(0)
    return _epoch_metrics(running_loss, correct, total)


@torch.no_grad()
def _evaluate(loader):
    """Score the notebook-level `model` on `loader` without tracking gradients.

    Returns (mean validation loss per sample, validation accuracy in percent).
    """
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        running_loss += loss.item() * labels.size(0)
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += labels.size(0)
    return _epoch_metrics(running_loss, correct, total)


for epoch in range(num_epochs):
    epoch_loss, epoch_acc = _train_one_epoch(train_loader)
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

    val_epoch_loss, val_epoch_acc = _evaluate(valid_loader)
    val_losses.append(val_epoch_loss)
    val_accuracies.append(val_epoch_acc)

    scheduler.step()

    print(
        f"Epoch {epoch + 1}/{num_epochs} | "
        f"train loss {epoch_loss:.4f}, acc {epoch_acc:.2f}% | "
        f"val loss {val_epoch_loss:.4f}, acc {val_epoch_acc:.2f}%"
    )

    # Checkpoint the weights (state_dict) whenever validation accuracy improves.
    if val_epoch_acc > best_val_acc:
        best_val_acc = val_epoch_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved Best Model (val acc {best_val_acc:.2f}%) to {best_model_path}")