In [15]:
!pip install -U hyperopt



In [2]:
from functools import partial
import os
from os import path
import tempfile
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray import train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch
import ray.cloudpickle as pickle

from torchvision.datasets import ImageFolder
from sklearn.metrics import balanced_accuracy_score

In [3]:
# Constants
EPOCHS = 20
N_TRIALS = 20
CLASSES = 10  # StateFarm has 10 classes

In [4]:
def define_model():
    """
    Defines the pretrained ViT_B_16 model with a modified last linear layer and frozen base layers.
    The model-specific transforms are also obtained from the pretrained weights.

    Returns:
        nn.Module: A Vision Transformer model with a modified last layer for 10 classes.
        Callable: Data transforms specific to the pretrained model.
    """
    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
    
    # Freeze the base parameters
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False

    # Modify the final layer for 10 classes (StateFarm)
    pretrained_vit.heads = nn.Linear(in_features=768, out_features=CLASSES)
    
    # Get the data transforms from the pretrained model
    pretrained_vit_transforms = pretrained_vit_weights.transforms()

    return pretrained_vit, pretrained_vit_transforms

In [5]:
def get_data_loaders(transform):
    """
    Creates the train and validation datasets for the StateFarm dataset using ImageFolder
    and the provided data transforms.

    Args:
        transform: The data transformations to apply to the dataset images.

    Returns:
        Dataloader, Dataloader: Dataloaders for training and validation sets.
    """
    train_dir = "/home/sur06423/wacv_paper/wacv_paper/data/imbalanced_v2/train"
    val_dir = "/home/sur06423/wacv_paper/wacv_paper/data/imbalanced_v2/validation"

    # Apply the model-specific transforms to the datasets
    trainset = ImageFolder(root=train_dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=1024, shuffle=True)
    valset = ImageFolder(root=val_dir, transform=transform)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=1024, shuffle=True)

    return train_loader, val_loader

In [6]:
from sklearn.metrics import balanced_accuracy_score

def calculate_balanced_accuracy(y_pred, y_true, num_classes):
    """
    Calculates the balanced accuracy score using PyTorch operations.
    (y_pred == c): Creates a boolean tensor where each element is True 
    if the predicted label equals class c, and False otherwise.

    (y_true == c): Creates another boolean tensor where each element is True 
    if the true label equals class c, and False otherwise.

    &: Performs a logical AND operation between the two boolean tensors. 
    The result is a tensor where each element is True only if both conditions 
    are met: the predicted label is class c, and the true label is also class c. 
    This effectively filters out the true positives for class c.

    .sum(): Sums up the True values in the resultant tensor, which corresponds
    to the count of true positive predictions for class c.

    Args:
        y_pred (torch.Tensor): Tensor of predicted class labels( No Logits & Probabilities, only labels).
        y_true (torch.Tensor): Tensor of true class labels.
        num_classes (int): Number of classes.

    Returns:
        float: The balanced accuracy score.
    """
    correct_per_class = torch.zeros(num_classes, device=y_pred.device)
    total_per_class = torch.zeros(num_classes, device=y_pred.device)

    for c in range(num_classes):
        # The number of true positive predictions for class c. 
        # True positives are instances that are correctly identified as 
        # belonging to class c by the classifier.
        true_positives = ((y_pred == c) & (y_true == c)).sum()
        # Condition Positive: total number of instances that actually belong to class c, 
        # regardless of whether they were correctly identified by the classifier or not.
        condition_positives = (y_true == c).sum()
        
        correct_per_class[c] = true_positives.float()
        total_per_class[c] = condition_positives.float()

    # .clamp(min=1) function ensures that no value in the total_per_class tensor is less than 1
    recall_per_class = correct_per_class / total_per_class.clamp(min=1)
    balanced_accuracy = recall_per_class.mean().item()  # Convert to Python scalar for compatibility

    return balanced_accuracy

# Define the Training & Evaluation Functions
def train(model, optimizer, train_loader, device=None):
    device = device or torch.device("cpu")
    model.train()
    running_loss, num_samples = 0.0, 0
    y_pred_all = []
    y_all = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * output.size(0)
        num_samples += output.size(0)
        y_pred_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
        y_pred_all.append(y_pred_class)
        y_all.append(target)

    t_average_loss = running_loss / num_samples
    # t_balanced_accuracy = calculate_balanced_accuracy(torch.cat(y_pred_all), torch.cat(y_all), CLASSES)
    t_balanced_accuracy = balanced_accuracy_score(torch.cat(y_all).cpu().numpy(), 
                                              torch.cat(y_pred_all).cpu().numpy())

    return t_balanced_accuracy, t_average_loss

def test(model, optimizer, train_loader, device=None):
    device = device or torch.device("cpu")
    model.eval()
    running_loss, num_samples = 0.0, 0
    y_pred_all = []
    y_all = []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.cross_entropy(output, target)
            running_loss += loss.item() * output.size(0)
            num_samples += output.size(0)
            y_pred_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
            y_pred_all.append(y_pred_class)
            y_all.append(target)

    e_average_loss = running_loss / num_samples
    # e_balanced_accuracy = calculate_balanced_accuracy(torch.cat(y_pred_all), torch.cat(y_all), CLASSES)
    e_balanced_accuracy = balanced_accuracy_score(torch.cat(y_all).cpu().numpy(), 
                                              torch.cat(y_pred_all).cpu().numpy())
    return e_balanced_accuracy, e_average_loss

In [7]:
# Define the Trainable class for Ray Tune     
class TrainViT(tune.Trainable):
    
    def setup(self, config):
        # detect if cuda is availalbe as ray will assign GPUs if available and configured
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrained_vit, pretrained_vit_transforms = define_model() 
        
        self.train_loader, self.test_loader = get_data_loaders(pretrained_vit_transforms)
        self.model = pretrained_vit.to(self.device)
        
        #setup the optimiser (try Adam instead and change parameters we are tuning)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))
                
    def step(self):
        train_balanced_accuracy, train_average_loss = train(self.model, self.optimizer, self.train_loader, device=self.device)
        test_balanced_accuracy, test_average_loss = test(self.model, self.test_loader, self.device)  
        return {"train_bal_accuracy": train_balanced_accuracy, "train_loss": train_average_loss, "test_bal_accuracy": test_balanced_accuracy, "test_loss": test_average_loss}
    
    def save_checkpoint(self, checkpoint_dir):
        # checkpoint_path = path.join(checkpoint_dir, "model.pth")
        checkpoint_data = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        torch.save(checkpoint_data, path.join(checkpoint_dir, "model.pth"))
        return checkpoint_dir   
    
    def load_checkpoint(self, checkpoint_dir):
        checkpoint_path = path.join(checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))

In [8]:
# Define the scheduler
asha = ASHAScheduler(
        time_attr='training_iteration',
        metric="test_balanced_accuracy",
        mode="max",
        max_t=100,
        grace_period=10,
        reduction_factor=3,
        brackets=1
    )

In [9]:
import ray
ray.shutdown()
ray.init(num_cpus=24, num_gpus=0, include_dashboard=True)

2024-10-21 03:55:13,152	INFO worker.py:1777 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


0,1
Python version:,3.10.12
Ray version:,2.37.0
Dashboard:,http://127.0.0.1:8265


In [10]:
config={
    "lr": tune.uniform(0.001, 0.1),
    "momentum": tune.uniform(0.1, 0.9),
}

asha = ASHAScheduler(
        time_attr='training_iteration',
        metric="test_balanced_accuracy",
        mode="max",
        max_t=100,
        grace_period=10,
        reduction_factor=3,
        brackets=1
    )

analysis = tune.run(
    TrainViT,
    storage_path="/home/sur06423/wacv_paper/wacv_paper/ray_results",
    resources_per_trial={
        "cpu": 2,
        "gpu": 1
    },
    num_samples=10,
    checkpoint_at_end=True,
    checkpoint_freq=10,
    # keep_checkpoints_num=3,
    scheduler=asha,
#     stop={
#         "mean_accuracy": 0.95,
#         "training_iteration": 100,
#     },
    config=config)

2024-10-21 03:55:48,454	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


0,1
Current time:,2024-10-21 04:15:47
Running for:,00:19:58.79
Memory:,18.3/503.4 GiB

Trial name,# failures,error file
TrainViT_97e30_00000,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00000_0_lr=0.0777,momentum=0.5816_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00001,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00001_1_lr=0.0011,momentum=0.1920_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00002,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00002_2_lr=0.0987,momentum=0.1433_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00003,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00003_3_lr=0.0539,momentum=0.3793_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00004,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00004_4_lr=0.0904,momentum=0.6155_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00005,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00005_5_lr=0.0804,momentum=0.3498_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00006,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00006_6_lr=0.0190,momentum=0.7658_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00007,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00007_7_lr=0.0835,momentum=0.1120_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00008,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00008_8_lr=0.0498,momentum=0.5000_2024-10-21_03-55-48/error.txt"
TrainViT_97e30_00009,1,"/tmp/ray/session_2024-10-21_03-55-09_785944_447631/artifacts/2024-10-21_03-55-48/TrainViT_2024-10-21_03-55-48/driver_artifacts/TrainViT_97e30_00009_9_lr=0.0403,momentum=0.1523_2024-10-21_03-55-48/error.txt"

Trial name,status,loc,lr,momentum
TrainViT_97e30_00000,ERROR,10.56.7.46:450450,0.0776726,0.581605
TrainViT_97e30_00001,ERROR,10.56.7.46:450451,0.00108159,0.19198
TrainViT_97e30_00002,ERROR,10.56.7.46:450452,0.0987402,0.143346
TrainViT_97e30_00003,ERROR,10.56.7.46:450455,0.0538677,0.379301
TrainViT_97e30_00004,ERROR,10.56.7.46:450454,0.0904447,0.615472
TrainViT_97e30_00005,ERROR,10.56.7.46:450453,0.0804162,0.349808
TrainViT_97e30_00006,ERROR,10.56.7.46:450456,0.0190002,0.765757
TrainViT_97e30_00007,ERROR,10.56.7.46:450457,0.0835095,0.112044
TrainViT_97e30_00008,ERROR,10.56.7.46:450458,0.0497818,0.499975
TrainViT_97e30_00009,ERROR,10.56.7.46:450459,0.0403021,0.152309


2024-10-21 04:13:55,593	ERROR tune_controller.py:1331 -- Trial task failed for trial TrainViT_97e30_00005
Traceback (most recent call last):
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/worker.py", line 2691, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanc

Trial name
TrainViT_97e30_00000
TrainViT_97e30_00001
TrainViT_97e30_00002
TrainViT_97e30_00003
TrainViT_97e30_00004
TrainViT_97e30_00005
TrainViT_97e30_00006
TrainViT_97e30_00007
TrainViT_97e30_00008
TrainViT_97e30_00009


2024-10-21 04:14:01,826	ERROR tune_controller.py:1331 -- Trial task failed for trial TrainViT_97e30_00007
Traceback (most recent call last):
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/worker.py", line 2691, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/sur06423/miniconda3/envs/deepl/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanc

TuneError: ('Trials did not complete', [TrainViT_97e30_00000, TrainViT_97e30_00001, TrainViT_97e30_00002, TrainViT_97e30_00003, TrainViT_97e30_00004, TrainViT_97e30_00005, TrainViT_97e30_00006, TrainViT_97e30_00007, TrainViT_97e30_00008, TrainViT_97e30_00009])

In [None]:
print("Best config is:", results_grid.get_best_config(metric="test_balanced_accuracy", mode='max'))

# Configuration and Running the Hyperparameter search

In [None]:
config = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "momentum": tune.uniform(0.8, 0.99)
}

scheduler = ASHAScheduler(
    metric="val_acc",
    mode="max",
    max_t=100,
    grace_period=5,
    reduction_factor=2,
    brackets=3
)

analysis = tune.run(
    TrainViT,
    resources_per_trial={"cpu": 2, "gpu": 1},
    num_samples=2,
    scheduler=scheduler,
    config=config
)


# To specifies the max number of trials to run concurrently, set max_concurrent_trials in TuneConfig.

# Note that actual parallelism can be less than max_concurrent_trials and will be determined 
# by how many trials can fit in the cluster at once (i.e., if you have a trial that requires 16 GPUs, 
# your cluster has 32 GPUs, and max_concurrent_trials=10, the Tuner can only run 2 trials concurrently).

from ray.tune import TuneConfig

config = TuneConfig(
    # ...
    num_samples=100,
    max_concurrent_trials=10,
)

# If you have 4 CPUs on your machine, this will run 2 concurrent trials at a time.
trainable_with_resources = tune.with_resources(trainable, {"cpu": 2})
tuner = tune.Tuner(
    trainable_with_resources,
    tune_config=tune.TuneConfig(num_samples=10)
)
results = tuner.fit()

# If you have 4 CPUs on your machine, this will run 1 trial at a time.
trainable_with_resources = tune.with_resources(trainable, {"cpu": 4})
tuner = tune.Tuner(
    trainable_with_resources,
    tune_config=tune.TuneConfig(num_samples=10)
)
results = tuner.fit()

# Fractional values are also supported, (i.e., {"cpu": 0.5}).
# If you have 4 CPUs on your machine, this will run 8 concurrent trials at a time.
trainable_with_resources = tune.with_resources(trainable, {"cpu": 0.5})
tuner = tune.Tuner(
    trainable_with_resources,
    tune_config=tune.TuneConfig(num_samples=10)
)
results = tuner.fit()

# Custom resource allocation via lambda functions are also supported.
# If you want to allocate gpu resources to trials based on a setting in your config
trainable_with_resources = tune.with_resources(trainable,
    resources=lambda spec: {"gpu": 1} if spec.config.use_gpu else {"gpu": 0})
tuner = tune.Tuner(
    trainable_with_resources,
    tune_config=tune.TuneConfig(num_samples=10)
)
results = tuner.fit()