In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from torchvision.models import vit_b_16
import matplotlib.pyplot as plt
import numpy as np
import optuna
import mlflow

In [None]:
def load_data(batch_size=32, train_frac=0.01, val_frac=0.01, seed=None):
    """Download EuroSAT and build train / validation / test dataloaders.

    Images are resized to 224x224 and normalised with the ImageNet channel
    mean and standard deviation.

    Args:
        batch_size: Batch size used by all three dataloaders.
        train_frac: Fraction of the dataset placed in the training split.
            Defaults to 0.01, i.e. a deliberately tiny subset.
        val_frac: Fraction of the dataset placed in the validation split.
        seed: Optional integer. When given, the random split is drawn from a
            dedicated generator seeded with this value so that it is the same
            on every call; when ``None`` the global torch RNG is used.

    Returns:
        A tuple ``(dataset, train_dataloader, val_dataloader, test_dataloader)``
        where ``dataset`` is the full, unsplit EuroSAT dataset. The test split
        receives every sample not assigned to train or validation.

    Raises:
        ValueError: If the fractions are out of range or sum to more than 1.
    """
    if not 0 < train_frac <= 1 or not 0 <= val_frac < 1 or train_frac + val_frac > 1:
        raise ValueError(
            f"Invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    # Define the transformation (ImageNet mean and std)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load the dataset
    dataset = datasets.EuroSAT(root='./data', download=True, transform=transform)

    # Split the dataset into train, validation, test sets; the test split
    # absorbs the rounding remainder so the three sizes always sum to len(dataset).
    train_size = int(train_frac * len(dataset))
    val_size = int(val_frac * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Only pass a generator when a seed was requested, so the default call
    # keeps using the global RNG exactly as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return dataset, train_dataloader, val_dataloader, test_dataloader

In [None]:
# Load the pretrained vit model
def load_vit(num_classes=10, unfreeze=5):
    """Build a pretrained ViT-B/16 with a new classification head.

    All pretrained parameters are frozen first; then the last ``unfreeze``
    encoder blocks are made trainable again and the classifier head is
    replaced by a freshly initialised linear layer.

    Args:
        num_classes: Number of output classes of the new head.
        unfreeze: How many encoder blocks, counted from the end, to leave
            trainable. Values below 0 are treated as 0 and values above the
            number of encoder blocks are treated as "all blocks".

    Returns:
        The modified model (not yet moved to any device).
    """
    model = vit_b_16(weights='DEFAULT')
    for param in model.parameters():
        param.requires_grad = False

    encoder_layers = model.encoder.layers
    number_of_layers = len(encoder_layers)

    # Clamp the request: an oversized value would otherwise produce negative
    # indices (silently wrapping around) or an IndexError.
    unfreeze = max(0, min(unfreeze, number_of_layers))

    # Unfreeze the last few layers
    for i in range(number_of_layers - unfreeze, number_of_layers):
        for param in encoder_layers[i].parameters():
            param.requires_grad = True

    # Replace the classifier head. This happens after the freezing loop, so
    # the new layer keeps its default requires_grad=True and is always trained.
    num_features = model.heads.head.in_features
    model.heads.head = torch.nn.Linear(num_features, num_classes)

    return model

In [None]:
# Train the model
def train(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module to optimise; switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)`` aggregated over every sample seen.
    """
    model.train()

    loss_sum, n_correct, n_seen = 0, 0, 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the final division
        # yields a per-sample average even when the last batch is smaller.
        loss_sum += batch_loss.item() * batch_images.size(0)

        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
# Evaluate the model
def eval(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Note that this function shadows Python's builtin ``eval`` inside the
    notebook namespace.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)`` aggregated over every sample seen.
    """
    model.eval()

    loss_sum, n_correct, n_seen = 0, 0, 0

    # Gradients are not needed for scoring.
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

            # Weight by batch size so the result is a per-sample average.
            loss_sum += batch_loss.item() * batch_images.size(0)

            predicted = logits.max(dim=1).indices
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
# Tune the model
def objective(trial):
    """Optuna objective: fine-tune ViT on EuroSAT and report validation accuracy.

    Samples the learning rate, batch size, number of unfrozen encoder blocks
    and weight decay, trains for up to 5 epochs with early stopping
    (patience 2 on validation accuracy), and logs parameters and per-epoch
    metrics to an MLflow run named after the trial number.

    Args:
        trial: The ``optuna.trial.Trial`` supplying the hyperparameters.

    Returns:
        The best validation accuracy (in percent) seen during the trial.
        A single value is returned because the study is single-objective and
        this function uses ``trial.report``; the epoch at which the best
        accuracy occurred is stored as the ``best_epoch`` user attribute of
        the trial and as an MLflow metric instead of being returned.

    Raises:
        optuna.TrialPruned: When the study's pruner asks to stop the trial.
    """
    lr = trial.suggest_float('lr', 1e-5, 5e-5, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])
    unfreeze = trial.suggest_categorical('unfreeze', [0, 6, 12])
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 0.01, log=True)

    # torch.backends.mps.is_available() is the long-standing availability
    # check; torch.mps.is_available() is missing from older PyTorch releases.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Load the data (the test loader is not used during tuning)
    dataset, train_dataloader, val_dataloader, _ = load_data(batch_size)

    # Load the model
    model = load_vit(num_classes=len(dataset.classes), unfreeze=unfreeze).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    with mlflow.start_run(run_name=f"trial_{trial.number}"):
        mlflow.log_param("lr", lr)
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("unfreeze", unfreeze)
        mlflow.log_param("weight_decay", weight_decay)

        patience = 2
        patience_cnt = 0
        best_val_acc = 0
        best_epoch = 0
        epochs = 5

        for epoch in range(epochs):

            train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
            val_loss, val_acc = eval(model, val_dataloader, criterion, device)

            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_acc", train_acc, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_acc", val_acc, step=epoch)

            # Intermediate value consumed by the pruner.
            trial.report(val_acc, epoch)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                patience_cnt = 0
            else:
                patience_cnt += 1

                # Early stopping: no improvement for `patience` epochs in a row.
                if patience_cnt >= patience:
                    break

            if trial.should_prune():
                raise optuna.TrialPruned()

        mlflow.log_metric("best_val_acc", best_val_acc)
        mlflow.log_metric("best_epoch", best_epoch)

    # Keep the best epoch retrievable from the study without making it part
    # of the objective value.
    trial.set_user_attr("best_epoch", best_epoch)

    return best_val_acc

In [None]:
# visualize the predictions by running the model on the test set
def visualize_predictions(model, dataloader, device, dataset, n_samples=25):
    """Plot predictions for images taken from the first batch of ``dataloader``.

    Each image is un-normalised with the ImageNet mean/std and titled with
    the predicted and true class names looked up in ``dataset.classes``.

    Args:
        model: Classifier to run; switched to evaluation mode.
        dataloader: Iterable yielding ``(images, labels)`` batches; only the
            first batch is used.
        device: Device the batch is moved to for the forward pass.
        dataset: Object exposing a ``classes`` sequence of class names.
        n_samples: Maximum number of images to draw. If the first batch holds
            fewer images, only those are drawn.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    model.eval()

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        # Nothing to draw for an empty dataloader.
        return

    images, labels = first_batch
    with torch.no_grad():
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

    # Never index past the end of the batch.
    n_show = min(n_samples, images.size(0))
    if n_show <= 0:
        return

    # Up to 5 columns; enough rows to fit every image (ceiling division).
    n_cols = min(5, n_show)
    n_rows = -(-n_show // n_cols)
    # 6x4 inches per cell reproduces the original 30x20 figure for a 5x5 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows), squeeze=False)

    std = torch.tensor([0.229, 0.224, 0.225])
    mean = torch.tensor([0.485, 0.456, 0.406])

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_show:
            # Leave surplus grid cells blank.
            continue
        # unnormalize the image (channels-last layout for imshow)
        image = images[i].cpu().permute(1, 2, 0)
        image = (image * std + mean).clamp(0, 1)
        ax.imshow(image.numpy())
        # set the title to the predicted and true label not number
        ax.set_title(f"Pred: {dataset.classes[preds[i].item()]}, True: {dataset.classes[labels[i].item()]}")
    plt.show()
           

In [None]:
import subprocess
from pyngrok import ngrok, conf
import getpass

# Seed the global torch and NumPy random number generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

MLFLOW_TRACKING_URI = "file:./mlruns"
EXPERIMENT_NAME = "EuroSAT_ViT_Classification"

# Start the MLflow UI in the background, but only once per kernel: keep the
# process handle so that re-running this cell does not spawn a second server
# competing for port 8080. Stop it with `mlflow_ui_process.terminate()`.
_running_ui = globals().get("mlflow_ui_process")
if _running_ui is None or _running_ui.poll() is not None:
    mlflow_ui_process = subprocess.Popen(
        ["mlflow", "ui", "--backend-store-uri", MLFLOW_TRACKING_URI, "--port", "8080"]
    )

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

In [None]:
# Seed the sampler explicitly: the seeds set for torch and NumPy do not make
# Optuna's hyperparameter suggestions reproducible, because the sampler is
# seeded through its own `seed` argument.
study = optuna.create_study(
    direction="maximize",
    study_name=EXPERIMENT_NAME,
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=3)

In [None]:
# Build the dataloaders with the default batch size and report how many
# samples ended up in the train / validation / test splits.
dataset, train_dataloader, val_dataloader, test_dataloader = load_data()
split_sizes = [
    len(loader.dataset)
    for loader in (train_dataloader, val_dataloader, test_dataloader)
]
print(*split_sizes)

In [None]:
# Pick the best available accelerator: CUDA first, then Apple's MPS backend,
# otherwise the CPU. torch.backends.mps.is_available() is used because
# torch.mps.is_available() is missing from older PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

In [None]:
# Instantiate the ViT with its default head size / unfreeze depth and move
# it onto the selected device.
model = load_vit().to(device)

In [None]:
# Report the total parameter count (frozen and trainable alike), in millions.
num_params = sum(map(torch.numel, model.parameters())) / 1e6
print(f"Number of parameters in the model: {num_params:.2f}M")

In [None]:
# Short baseline fine-tuning run with fixed hyperparameters: one training
# pass and one validation pass per epoch, with a summary line printed each time.
num_epochs = 2
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    train_loss, train_acc = train(
        model=model,
        dataloader=train_dataloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    val_loss, val_acc = eval(
        model=model,
        dataloader=val_dataloader,
        criterion=criterion,
        device=device,
    )

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

# Test-set evaluation is currently disabled; uncomment to score the held-out split.
#test_loss, test_acc = eval(model, test_dataloader, criterion, device)
#print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

In [None]:
# Show model predictions against ground truth for the first test batch.
visualize_predictions(
    model=model,
    dataloader=test_dataloader,
    device=device,
    dataset=dataset,
)