In [None]:
!pip install mlflow
!pip install pyngrok
!pip install optuna

Collecting mlflow
  Downloading mlflow-2.21.3-py3-none-any.whl.metadata (30 kB)
Collecting mlflow-skinny==2.21.3 (from mlflow)
  Downloading mlflow_skinny-2.21.3-py3-none-any.whl.metadata (31 kB)
Collecting alembic!=1.10.0,<2 (from mlflow)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting docker<8,>=4.0.0 (from mlflow)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting graphene<4 (from mlflow)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==2.21.3->mlflow)
  Downloading databricks_sdk-0.50.0-py3-none-any.whl.metadata (38 kB)
Collecting fastapi<1 (from mlflow-skinny==2.21.3->mlflow)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn<1 (from mlflow-skinny==2.21.3->mlflow)
  Downloading uvicorn-0.34.1-py3-none-any.whl.metadata (6.5 k

In [None]:
# This notebook is used to train and tune the model via Colab
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from torchvision.models import vit_b_16
import matplotlib.pyplot as plt
import numpy as np
import optuna
import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient

In [None]:
def load_data(batch_size=32, seed=42):
    """Download EuroSAT and return it with train/validation/test DataLoaders.

    The dataset is split 70% / 15% / 15%; the test split absorbs any rounding
    leftovers. The split is drawn from a dedicated generator seeded with
    ``seed``, so every call yields the same partition. Without it, each Optuna
    trial and the final best-model run would see a different split, and images
    used for validation/model selection could end up in the final test set.

    Args:
        batch_size: samples per batch for all three loaders.
        seed: seed for the random train/val/test split.

    Returns:
        ``(dataset, train_dataloader, val_dataloader, test_dataloader)``; only
        the training loader shuffles.
    """
    # Resize to the ViT input size and normalise with the ImageNet mean and std
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = datasets.EuroSAT(root='./data', download=True, transform=transform)

    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Independent, seeded generator: the split no longer depends on how much of
    # the global torch RNG earlier cells or trials have consumed.
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], generator=split_generator
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)

    return dataset, train_dataloader, val_dataloader, test_dataloader

In [None]:
# Load the pretrained vit model
def load_vit(num_classes=10, unfreeze=5):
    """Build a torchvision ViT-B/16 with its default pretrained weights for fine-tuning.

    Every parameter is frozen first. Then the last ``unfreeze`` encoder blocks
    are made trainable again. Finally the classification head is replaced by a
    freshly created ``Linear`` layer with ``num_classes`` outputs, which is
    trainable because it is new.

    Args:
        num_classes: number of output classes of the new head.
        unfreeze: number of trailing encoder blocks to train. Values <= 0 keep
            the whole encoder frozen; values larger than the number of blocks
            unfreeze all of them.

    Returns:
        The modified model; the caller moves it to the target device.
    """
    model = vit_b_16(weights='DEFAULT')
    for param in model.parameters():
        param.requires_grad = False

    encoder_layers = model.encoder.layers
    # Clamp so an oversized value unfreezes every block. The previous index
    # arithmetic went negative: it wrapped around and eventually raised IndexError.
    n_trainable = min(max(unfreeze, 0), len(encoder_layers))
    if n_trainable > 0:
        # Slicing a Sequential shares the same submodules, so these are the model's own parameters
        for param in encoder_layers[-n_trainable:].parameters():
            param.requires_grad = True

    # Replace the classifier head
    num_features = model.heads.head.in_features
    model.heads.head = torch.nn.Linear(num_features, num_classes)

    return model

In [None]:
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader``.

    For each ``(images, labels)`` batch: move it to ``device``, run a forward
    pass, compute ``criterion``, back-propagate and take one ``optimizer`` step.

    Returns:
        ``(average loss per sample, accuracy in percent)`` over the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Assumes criterion averages over the batch (CrossEntropyLoss's default),
        # so re-weight by batch size to accumulate a per-sample sum.
        loss_sum += batch_loss.item() * batch_x.size(0)
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_y).sum().item()
        n_seen += batch_y.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
def eval(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    NOTE: this name shadows Python's builtin ``eval``; it is kept because the
    rest of the notebook calls it by this name.

    Returns:
        ``(average loss per sample, accuracy in percent)``.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)

            # Re-weight the (assumed batch-mean) loss into a per-sample sum
            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            n_correct += (logits.max(dim=1).indices == batch_y).sum().item()
            n_seen += batch_y.size(0)

    return loss_sum / n_seen, 100 * n_correct / n_seen

In [None]:
# Tune the model
def objective(trial):
    """Optuna objective: fine-tune ViT-B/16 on EuroSAT with sampled hyper-parameters.

    Samples the learning rate, batch size, number of unfrozen encoder blocks and
    AdamW weight decay. Trains for up to ``epochs`` epochs with patience-based
    early stopping on validation accuracy, and records params and per-epoch
    metrics in an MLflow run named ``trial_<trial.number>``.

    Args:
        trial: the Optuna trial that supplies hyper-parameters and receives
            per-epoch validation accuracy for pruning.

    Returns:
        The best validation accuracy (percent) reached in any epoch.

    Raises:
        optuna.TrialPruned: if the study's pruner decides to stop the trial.
    """
    lr = trial.suggest_float('lr', 1e-5, 5e-5, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64])
    unfreeze = trial.suggest_categorical('unfreeze', [0, 3])
    weight_decay = trial.suggest_categorical('weight_decay', [0.0, 0.01])

    # torch.backends.mps.is_available() is the documented MPS availability check;
    # torch.mps may not expose is_available() in every PyTorch release.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Load the data (the test loader is not used during tuning)
    dataset, train_dataloader, val_dataloader, _ = load_data(batch_size)

    # Load the model
    model = load_vit(num_classes=len(dataset.classes), unfreeze=unfreeze).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    with mlflow.start_run(run_name=f"trial_{trial.number}"):
        mlflow.log_param("lr", lr)
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("unfreeze", unfreeze)
        mlflow.log_param("weight_decay", weight_decay)

        patience = 2  # epochs without val_acc improvement before stopping
        patience_cnt = 0
        best_val_acc = 0
        epochs = 5

        for epoch in range(epochs):
            train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
            val_loss, val_acc = eval(model, val_dataloader, criterion, device)

            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_acc", train_acc, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_acc", val_acc, step=epoch)

            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            trial.report(val_acc, epoch)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_cnt = 0
            else:
                patience_cnt += 1

            # Running best = the value this objective returns. val_acc alone is
            # just the latest epoch, so rank MLflow runs by this metric instead.
            mlflow.log_metric("best_val_acc", best_val_acc, step=epoch)

            if patience_cnt >= patience:
                break  # early stopping

            if trial.should_prune():
                # Tag first so pruned runs can be told apart from crashed ones in MLflow
                mlflow.set_tag("pruned", "true")
                raise optuna.TrialPruned()

    return best_val_acc

In [None]:
# visualize the predictions by running the model on the test set
def visualize_predictions(model, dataloader, device, dataset, n_samples=25):
    """Plot images from the first batch titled with predicted vs. true class names.

    Only the first batch yielded by ``dataloader`` is used. At most ``n_samples``
    images are drawn, capped at that batch's size, in a grid of up to 5 columns;
    unused grid cells are hidden. Pixels are un-normalised with the same
    mean/std that ``load_data`` applies, then clamped to [0, 1] for display.

    Args:
        model: classifier; the index of the max over dim 1 of its output is
            taken as the predicted class.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the model lives on; the batch is moved there.
        dataset: object whose ``classes`` sequence maps class index -> name.
        n_samples: maximum number of images to show (default 25, a 5x5 grid).
    """
    model.eval()

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return  # empty loader: nothing to show
    images, labels = first_batch

    # Cap at the batch size and derive the grid from the count, so n_samples > 25
    # or a small final batch no longer indexes past the axes / images.
    n_shown = min(n_samples, images.size(0))
    if n_shown <= 0:
        return

    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        _, preds = torch.max(model(images), 1)

    ncols = min(5, n_shown)
    nrows = (n_shown + ncols - 1) // ncols  # ceiling division
    # 6x4 inches per panel reproduces the original (30, 20) figure for a 5x5 grid
    _, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows), squeeze=False)

    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_shown:
            continue  # leftover cell in the last row
        # CHW -> HWC, then undo Normalize so imshow receives values in [0, 1]
        image = (images[i].cpu().permute(1, 2, 0) * std + mean).clamp(0, 1)
        ax.imshow(image)
        ax.set_title(f"Pred: {dataset.classes[preds[i].item()]}, True: {dataset.classes[labels[i].item()]}")
    plt.show()


In [None]:
import subprocess
from pyngrok import ngrok, conf
import getpass

import random
import socket

# Fix the seeds (torch.manual_seed also seeds every CUDA device)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set up MLflow tracking server
MLFLOW_TRACKING_URI = "sqlite:///mlflow.db"
EXPERIMENT_NAME = "EuroSAT_ViT_Classification"
MLFLOW_UI_PORT = 5000


def _port_in_use(port, host="127.0.0.1"):
    """Return True if something already accepts TCP connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        return sock.connect_ex((host, port)) == 0


# Launch the UI only once: re-running this cell would otherwise start another
# `mlflow ui` process that fails to bind the already-used port.
if not _port_in_use(MLFLOW_UI_PORT):
    subprocess.Popen(["mlflow", "ui", "--backend-store-uri", MLFLOW_TRACKING_URI, "--port", str(MLFLOW_UI_PORT)])

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

2025/04/14 18:29:21 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/04/14 18:29:21 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Running upgrade  -> 451aebb31d03, add metric step
INFO  [alembic.runtime.migration] Running upgrade 451aebb31d03 -> 90e64c465722, migrate user column to tags
INFO  [alembic.runtime.migration] Running upgrade 90e64c465722 -> 181f10493468, allow nulls for metric values
INFO  [alembic.runtime.migration] Running upgrade 181f10493468 -> df50e92ffc5e, Add Experiment Tags Table
INFO  [alembic.runtime.migration] Running upgrade df50e92ffc5e -> 7ac759974ad8, Update run tags with larger limit
INFO  [alembic.runtime.migration] Running upgrade 7ac759974ad8 -> 89d4b8295536, create latest metrics table
INFO  [89d4b8295536_create_latest_metrics_table_py] Migration complete!
INFO  

<Experiment: artifact_location='/content/mlruns/1', creation_time=1744655362619, experiment_id='1', last_update_time=1744655362619, lifecycle_stage='active', name='EuroSAT_ViT_Classification', tags={}>

In [None]:
# Set up ngrok to expose the MLflow UI.
# Use the NGROK_AUTHTOKEN environment variable when it is set, so the notebook
# can run unattended; otherwise prompt for the token without echoing it.
import os

ngrok_token = os.environ.get("NGROK_AUTHTOKEN")
if not ngrok_token:
    print("Enter your authtoken, which can be copied from https://dashboard.ngrok.com/get-started/your-authtoken")
    ngrok_token = getpass.getpass()
conf.get_default().auth_token = ngrok_token
del ngrok_token  # don't leave the secret lying around in the notebook namespace

Enter your authtoken, which can be copied from https://dashboard.ngrok.com/get-started/your-authtoken
··········


In [None]:
port = 5000
# Tunnel the local MLflow UI port to a public URL through the ngrok agent
tunnel = ngrok.connect(port)
public_url = tunnel.public_url
print(f' * ngrok tunnel "{public_url}" -> "http://127.0.0.1:{port}"')

 * ngrok tunnel "https://299f-34-142-227-62.ngrok-free.app" -> "http://127.0.0.1:5000"


In [None]:
# Seed the sampler so the sequence of sampled hyper-parameters is reproducible,
# consistent with the global SEED used for torch/numpy.
study = optuna.create_study(
    direction="maximize",
    study_name=EXPERIMENT_NAME,
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=10)

[I 2025-04-14 18:29:46,780] A new study created in memory with name: EuroSAT_ViT_Classification
100%|██████████| 94.3M/94.3M [00:00<00:00, 387MB/s]
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:05<00:00, 67.3MB/s]


Epoch [1/5], Train Loss: 1.6406, Train Acc: 58.49%, Val Loss: 1.1266, Val Acc: 80.54%
Epoch [2/5], Train Loss: 0.8918, Train Acc: 84.54%, Val Loss: 0.7143, Val Acc: 87.16%
Epoch [3/5], Train Loss: 0.6185, Train Acc: 88.16%, Val Loss: 0.5335, Val Acc: 89.98%
Epoch [4/5], Train Loss: 0.4847, Train Acc: 90.02%, Val Loss: 0.4336, Val Acc: 91.04%


[I 2025-04-14 18:55:15,866] Trial 0 finished with value: 91.92592592592592 and parameters: {'lr': 3.172390142984219e-05, 'batch_size': 32, 'unfreeze': 0, 'weight_decay': 0.0}. Best is trial 0 with value: 91.92592592592592.


Epoch [5/5], Train Loss: 0.4055, Train Acc: 91.13%, Val Loss: 0.3708, Val Acc: 91.93%
Epoch [1/5], Train Loss: 0.6183, Train Acc: 85.43%, Val Loss: 0.1858, Val Acc: 94.77%
Epoch [2/5], Train Loss: 0.1202, Train Acc: 96.68%, Val Loss: 0.1158, Val Acc: 96.57%
Epoch [3/5], Train Loss: 0.0719, Train Acc: 98.09%, Val Loss: 0.0935, Val Acc: 97.23%
Epoch [4/5], Train Loss: 0.0475, Train Acc: 98.78%, Val Loss: 0.0890, Val Acc: 97.38%


[I 2025-04-14 19:28:57,082] Trial 1 finished with value: 97.38271604938272 and parameters: {'lr': 1.1703764483606674e-05, 'batch_size': 64, 'unfreeze': 3, 'weight_decay': 0.0}. Best is trial 1 with value: 97.38271604938272.


Epoch [5/5], Train Loss: 0.0311, Train Acc: 99.25%, Val Loss: 0.0978, Val Acc: 97.23%
Epoch [1/5], Train Loss: 0.3035, Train Acc: 91.95%, Val Loss: 0.1062, Val Acc: 96.79%
Epoch [2/5], Train Loss: 0.0705, Train Acc: 97.85%, Val Loss: 0.0851, Val Acc: 97.09%
Epoch [3/5], Train Loss: 0.0377, Train Acc: 98.90%, Val Loss: 0.0781, Val Acc: 97.65%
Epoch [4/5], Train Loss: 0.0191, Train Acc: 99.53%, Val Loss: 0.0861, Val Acc: 97.26%


[I 2025-04-14 20:02:32,053] Trial 2 finished with value: 97.65432098765432 and parameters: {'lr': 2.2992169210769673e-05, 'batch_size': 32, 'unfreeze': 3, 'weight_decay': 0.0}. Best is trial 2 with value: 97.65432098765432.


Epoch [5/5], Train Loss: 0.0123, Train Acc: 99.71%, Val Loss: 0.0824, Val Acc: 97.63%
Epoch [1/5], Train Loss: 0.2718, Train Acc: 92.78%, Val Loss: 0.1090, Val Acc: 96.79%
Epoch [2/5], Train Loss: 0.0569, Train Acc: 98.26%, Val Loss: 0.0736, Val Acc: 97.75%
Epoch [3/5], Train Loss: 0.0292, Train Acc: 99.14%, Val Loss: 0.0839, Val Acc: 97.28%


[I 2025-04-14 20:29:28,103] Trial 3 finished with value: 97.75308641975309 and parameters: {'lr': 4.721786288863328e-05, 'batch_size': 64, 'unfreeze': 3, 'weight_decay': 0.01}. Best is trial 3 with value: 97.75308641975309.


Epoch [4/5], Train Loss: 0.0170, Train Acc: 99.54%, Val Loss: 0.0995, Val Acc: 97.21%
Epoch [1/5], Train Loss: 0.6402, Train Acc: 85.11%, Val Loss: 0.1996, Val Acc: 94.96%
Epoch [2/5], Train Loss: 0.1342, Train Acc: 96.29%, Val Loss: 0.1294, Val Acc: 96.05%
Epoch [3/5], Train Loss: 0.0816, Train Acc: 97.61%, Val Loss: 0.1114, Val Acc: 96.49%
Epoch [4/5], Train Loss: 0.0561, Train Acc: 98.44%, Val Loss: 0.0982, Val Acc: 96.86%


[I 2025-04-14 21:03:05,792] Trial 4 finished with value: 96.88888888888889 and parameters: {'lr': 1.022305809475511e-05, 'batch_size': 64, 'unfreeze': 3, 'weight_decay': 0.0}. Best is trial 3 with value: 97.75308641975309.


Epoch [5/5], Train Loss: 0.0388, Train Acc: 99.15%, Val Loss: 0.0910, Val Acc: 96.89%
Epoch [1/5], Train Loss: 0.3781, Train Acc: 90.34%, Val Loss: 0.1107, Val Acc: 96.96%
Epoch [2/5], Train Loss: 0.0822, Train Acc: 97.56%, Val Loss: 0.0790, Val Acc: 97.51%
Epoch [3/5], Train Loss: 0.0442, Train Acc: 98.76%, Val Loss: 0.0668, Val Acc: 97.80%
Epoch [4/5], Train Loss: 0.0252, Train Acc: 99.33%, Val Loss: 0.1137, Val Acc: 96.64%


[I 2025-04-14 21:36:42,358] Trial 5 finished with value: 97.87654320987654 and parameters: {'lr': 2.5342001417793284e-05, 'batch_size': 64, 'unfreeze': 3, 'weight_decay': 0.01}. Best is trial 5 with value: 97.87654320987654.


Epoch [5/5], Train Loss: 0.0175, Train Acc: 99.58%, Val Loss: 0.0653, Val Acc: 97.88%


[I 2025-04-14 21:41:35,654] Trial 6 pruned. 


Epoch [1/5], Train Loss: 2.0401, Train Acc: 32.80%, Val Loss: 1.7613, Val Acc: 53.31%


[I 2025-04-14 21:46:38,890] Trial 7 pruned. 


Epoch [1/5], Train Loss: 2.1119, Train Acc: 29.31%, Val Loss: 1.8266, Val Acc: 51.85%


[I 2025-04-14 21:51:42,476] Trial 8 pruned. 


Epoch [1/5], Train Loss: 1.4306, Train Acc: 64.98%, Val Loss: 0.8820, Val Acc: 83.90%


[I 2025-04-14 21:56:35,526] Trial 9 pruned. 


Epoch [1/5], Train Loss: 2.2477, Train Acc: 20.22%, Val Loss: 2.0563, Val Acc: 31.83%


In [None]:
# Fetch the MLflow run of the trial Optuna selected as best.
# Ordering all runs by `metrics.val_acc` would rank them by their *last* logged
# epoch rather than the best epoch the objective returns. It would also match
# non-trial runs such as "best_model" if this cell were re-run.
client = MlflowClient()
experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
best_trial_number = study.best_trial.number
best_trial_runs = client.search_runs(
    experiment_ids=[experiment.experiment_id],
    filter_string=f"tags.mlflow.runName = 'trial_{best_trial_number}'",
    # The study lives in memory, so trial names can repeat across sessions that
    # share mlflow.db; take the most recent one.
    order_by=["attributes.start_time DESC"],
    max_results=1,
)
if not best_trial_runs:
    raise RuntimeError(f"No MLflow run found for Optuna trial {best_trial_number}")
best_run = best_trial_runs[0]

In [None]:
# Summarise the selected run: id, status, final metrics and hyper-parameters
run_id = best_run.info.run_id
status = best_run.info.status
metrics = best_run.data.metrics
best_params = best_run.data.params

print(f"\nRun : {run_id}")
print(f"Status: {status}")
for heading, entries in (("Metrics:", metrics), ("Params:", best_params)):
    print(heading)
    for key, value in entries.items():
        print(f"  {key}: {value}")


Run : 2e9cd812af4c4257a4cb7986bb05a39b
Status: FINISHED
Metrics:
  train_loss: 0.017508617192950278
  train_acc: 99.57671957671958
  val_loss: 0.0653166193963477
  val_acc: 97.87654320987654
Params:
  lr: 2.5342001417793284e-05
  batch_size: 64
  unfreeze: 3
  weight_decay: 0.01


In [None]:
# Train the best model
def train_eval_best_model(best_params):
    """Retrain a fresh ViT with the chosen hyper-parameters and evaluate it on the test set.

    Args:
        best_params: mapping with keys ``'lr'``, ``'batch_size'``, ``'unfreeze'``
            and ``'weight_decay'``. Values may be strings (as MLflow returns
            run params); they are cast to float/int here.

    Returns:
        ``(model, test_loss, test_acc)`` where ``test_acc`` is in percent.

    Side effects:
        Logs params, per-epoch metrics, test metrics and the model to an MLflow
        run named ``"best_model"``, and writes ``vit_eurosat_best_model.pth``
        (the state dict) to the current working directory.
    """
    lr = float(best_params['lr'])
    batch_size = int(best_params['batch_size'])
    unfreeze = int(best_params['unfreeze'])
    weight_decay = float(best_params['weight_decay'])

    # torch.backends.mps.is_available() is the documented MPS availability check;
    # torch.mps may not expose is_available() in every PyTorch release.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    dataset, train_dataloader, val_dataloader, test_dataloader = load_data(batch_size)

    model = load_vit(num_classes=len(dataset.classes), unfreeze=unfreeze).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    with mlflow.start_run(run_name="best_model"):
        mlflow.log_param("lr", lr)
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("unfreeze", unfreeze)
        mlflow.log_param("weight_decay", weight_decay)

        epochs = 5

        for epoch in range(epochs):
            train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
            val_loss, val_acc = eval(model, val_dataloader, criterion, device)

            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_acc", train_acc, step=epoch)
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metric("val_acc", val_acc, step=epoch)

            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Final evaluation on the held-out test split (the weights after the last epoch)
        test_loss, test_acc = eval(model, test_dataloader, criterion, device)
        mlflow.log_metric("test_loss", test_loss)
        mlflow.log_metric("test_acc", test_acc)

        # Log the full model via MLflow, plus a plain state_dict artifact
        mlflow.pytorch.log_model(model, "vit_eurosat_best_model")
        torch.save(model.state_dict(), "vit_eurosat_best_model.pth")
        mlflow.log_artifact("vit_eurosat_best_model.pth")

    return model, test_loss, test_acc

In [None]:
# Unpack the result: a bare call would render the entire ViT module repr as cell output
best_model, test_loss, test_acc = train_eval_best_model(best_params)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

Epoch [1/5], Train Loss: 0.3763, Train Acc: 90.46%, Val Loss: 0.1282, Val Acc: 95.95%
Epoch [2/5], Train Loss: 0.0784, Train Acc: 97.57%, Val Loss: 0.1061, Val Acc: 96.57%
Epoch [3/5], Train Loss: 0.0439, Train Acc: 98.78%, Val Loss: 0.0849, Val Acc: 97.09%
Epoch [4/5], Train Loss: 0.0261, Train Acc: 99.35%, Val Loss: 0.1040, Val Acc: 96.79%
Epoch [5/5], Train Loss: 0.0147, Train Acc: 99.71%, Val Loss: 0.0800, Val Acc: 97.58%




(VisionTransformer(
   (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
   (encoder): Encoder(
     (dropout): Dropout(p=0.0, inplace=False)
     (layers): Sequential(
       (encoder_layer_0): EncoderBlock(
         (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
         (self_attention): MultiheadAttention(
           (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
         )
         (dropout): Dropout(p=0.0, inplace=False)
         (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
         (mlp): MLPBlock(
           (0): Linear(in_features=768, out_features=3072, bias=True)
           (1): GELU(approximate='none')
           (2): Dropout(p=0.0, inplace=False)
           (3): Linear(in_features=3072, out_features=768, bias=True)
           (4): Dropout(p=0.0, inplace=False)
         )
       )
       (encoder_layer_1): EncoderBlock(
         (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine