In [1]:
import os
import ray
from ray import tune
from ray.tune import Tuner, TuneConfig, with_resources
from ray.tune.schedulers import ASHAScheduler
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score

In [2]:
# Constants
EPOCHS = 10
CLASSES = 10  # Assume 10 classes for the StateFarm dataset

In [6]:
def define_model(use_gpu, num_classes=10):
    """
    Defines the pretrained ViT_B_16 model with a modified last linear layer and frozen base layers.
    """
    pretrained_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    model = torchvision.models.vit_b_16(weights=pretrained_weights)
    
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Modify the classifier for the target dataset
    # model.heads = nn.Linear(model.heads.in_features, CLASSES)
    model.heads = nn.Linear(in_features=768, out_features=num_classes)
    return model, pretrained_weights.transforms()

In [7]:
def get_data_loaders(transform):
    """
    Creates the train and validation dataloaders.
    """
    train_dir = "/home/sur06423/wacv_paper/wacv_paper/data/imbalanced_v2/train"
    val_dir = "/home/sur06423/wacv_paper/wacv_paper/data/imbalanced_v2/validation"
    
    trainset = ImageFolder(root=train_dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=1024, shuffle=True)
    valset = ImageFolder(root=val_dir, transform=transform)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=1024, shuffle=True)

    return train_loader, val_loader

In [8]:
def calculate_balanced_accuracy(y_pred, y_true, num_classes=10):
    """
    Calculates the balanced accuracy score using PyTorch operations.
    (y_pred == c): Creates a boolean tensor where each element is True 
    if the predicted label equals class c, and False otherwise.

    (y_true == c): Creates another boolean tensor where each element is True 
    if the true label equals class c, and False otherwise.

    &: Performs a logical AND operation between the two boolean tensors. 
    The result is a tensor where each element is True only if both conditions 
    are met: the predicted label is class c, and the true label is also class c. 
    This effectively filters out the true positives for class c.

    .sum(): Sums up the True values in the resultant tensor, which corresponds
    to the count of true positive predictions for class c.

    Args:
        y_pred (torch.Tensor): Tensor of predicted class labels( No Logits & Probabilities, only labels).
        y_true (torch.Tensor): Tensor of true class labels.
        num_classes (int): Number of classes.

    Returns:
        float: The balanced accuracy score.
    """
    correct_per_class = torch.zeros(num_classes, device=y_pred.device)
    total_per_class = torch.zeros(num_classes, device=y_pred.device)

    for c in range(num_classes):
        # The number of true positive predictions for class c. 
        # True positives are instances that are correctly identified as 
        # belonging to class c by the classifier.
        true_positives = ((y_pred == c) & (y_true == c)).sum()
        # Condition Positive: total number of instances that actually belong to class c, 
        # regardless of whether they were correctly identified by the classifier or not.
        condition_positives = (y_true == c).sum()
        
        correct_per_class[c] = true_positives.float()
        total_per_class[c] = condition_positives.float()

    # .clamp(min=1) function ensures that no value in the total_per_class tensor is less than 1
    recall_per_class = correct_per_class / total_per_class.clamp(min=1)
    balanced_accuracy = recall_per_class.mean().item()  # Convert to Python scalar for compatibility

    return balanced_accuracy

In [9]:
# Define the Training Functions
def train_model(model, optimizer, train_loader, device):
    model.train()
    running_loss = 0.0 
    num_samples = 0
    all_predictions = []
    all_labels = []
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)
        batch_predictions = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        all_predictions.append(batch_predictions)
        all_labels.append(labels)

    train_loss = running_loss / num_samples
    train_balanced_accuracy = calculate_balanced_accuracy(torch.cat(all_predictions), torch.cat(all_labels))
    return train_loss, train_balanced_accuracy

In [10]:
# Define the Validation Functions
def validate_model(model, val_loader, device):
    model.eval()
    running_loss = 0.0 
    num_samples = 0
    all_predictions = []
    all_labels = []
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        running_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)
        batch_predictions = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        all_predictions.append(batch_predictions)
        all_labels.append(labels)

    val_loss = running_loss / num_samples
    val_balanced_accuracy = calculate_balanced_accuracy(torch.cat(all_predictions), torch.cat(all_labels))
    return val_loss, val_balanced_accuracy

In [11]:
class TrainViT(tune.Trainable):
    def setup(self, config):
        self.device = torch.device("cuda" if torch.cuda.is_available() and config.get("use_gpu", False) else "cpu")
        self.model, transforms = define_model(config.get("use_gpu", False))
        self.model.to(self.device)
        self.train_loader, self.val_loader = get_data_loaders(transforms)
        self.optimizer = optim.SGD(self.model.parameters(), lr=config["lr"], momentum=config["momentum"])

    def step(self):
        train_loss, train_acc = train_model(self.model, self.optimizer, self.train_loader, self.device)
        val_loss, val_acc = validate_model(self.model, self.val_loader, self.device)
        return {"loss": train_loss, "accuracy": train_acc, "val_loss": val_loss, "val_acc": val_acc}

    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }, path)
        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [12]:
N_TRIALS = 8
# ASHA Scheduler for early stopping
scheduler = ASHAScheduler(
    metric="val_acc",
    mode="max",
    max_t=EPOCHS,
    grace_period=5,
    reduction_factor=2
)

# Configuration for hyperparameters
config = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "momentum": tune.uniform(0.8, 0.99),
    "use_gpu": True  # This can be dynamically adjusted if some trials should not use a GPU
}

# Setting up the Tuner with dynamic resource allocation
trainable_with_resources = with_resources(
    TrainViT,
    resources=lambda config: {"gpu": 1, "cpu": 2} if config.get("use_gpu", False) else {"cpu": 2}
)

tune_config = TuneConfig(
    scheduler=scheduler,
    num_samples=N_TRIALS,
    max_concurrent_trials=4  # Adjust based on the number of available GPUs
)

""" 
checkpoint_config = ray.train.CheckpointConfig(num_to_keep: int | None = None, 
                                               checkpoint_score_attribute: str | None = None, 
                                               checkpoint_score_order: str | None = 'max', 
                                               checkpoint_frequency: int | None = 0, 
                                               checkpoint_at_end: bool | None = None, 
                                               _checkpoint_keep_all_ranks: bool | None = 'DEPRECATED', 
                                               _checkpoint_upload_from_workers: bool | None = 'DEPRECATED')
"""

run_config = ray.train.RunConfig(name="Dynamic_Trial_Exp_2",
                                 storage_path="/home/sur06423/wacv_paper/wacv_paper/ray_results",
                                 stop={"training_iteration": 5},
                                 checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=2, checkpoint_at_end=True),
                                 # checpoint_config = checkpoint_config
)


In [14]:
# Initialize Ray
ray.shutdown()
ray.init(num_cpus=24, num_gpus=4, include_dashboard=True)  # Explicitly set the number of GPUs

2024-10-23 15:46:57,984	INFO worker.py:1777 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


0,1
Python version:,3.10.12
Ray version:,2.37.0
Dashboard:,http://127.0.0.1:8265


In [15]:
print(ray.available_resources())

{'CPU': 24.0, 'node:__internal_head__': 1.0, 'memory': 359288880128.0, 'object_store_memory': 158266662912.0, 'GPU': 4.0, 'node:10.56.7.46': 1.0}


In [16]:
# Define the directories to be added to LD_LIBRARY_PATH
library_paths = [
    "/usr/lib/xorg-nvidia-525.116.04/lib/x86_64-linux-gnu",
    "/usr/lib/xorg/lib/x86_64-linux-gnu",
    "/usr/lib/xorg-nvidia-535.113.01/lib/x86_64-linux-gnu"
]

# Current LD_LIBRARY_PATH from the environment
current_ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')

# Adding each path only if it is not already in the LD_LIBRARY_PATH
new_paths = [path for path in library_paths if path not in current_ld_library_path]

# Join all new paths with the existing LD_LIBRARY_PATH
os.environ['LD_LIBRARY_PATH'] = ':'.join(new_paths + [current_ld_library_path])

# Verify the update
print("Updated LD_LIBRARY_PATH:")
print(os.environ['LD_LIBRARY_PATH'])

Updated LD_LIBRARY_PATH:
/usr/lib/xorg-nvidia-525.116.04/lib/x86_64-linux-gnu:/usr/lib/xorg/lib/x86_64-linux-gnu:/usr/lib/xorg-nvidia-535.113.01/lib/x86_64-linux-gnu:


In [17]:
# Create the Tuner and run the trials
tuner = Tuner(trainable_with_resources,
              param_space=config, 
              tune_config=tune_config,
              run_config=run_config 
              )
results = tuner.fit()

0,1
Current time:,2024-10-23 15:59:01
Running for:,00:11:46.20
Memory:,21.9/503.4 GiB

Trial name,status,loc,lr,momentum,iter,total time (s),loss,accuracy,val_loss
TrainViT_50386_00000,TERMINATED,10.56.7.46:805349,0.0212223,0.930988,5,342.135,0.768011,0.283573,5.61011
TrainViT_50386_00001,TERMINATED,10.56.7.46:805350,0.000276973,0.804294,5,342.585,1.65239,0.100069,2.53833
TrainViT_50386_00002,TERMINATED,10.56.7.46:805351,0.00936625,0.827588,5,342.794,0.93142,0.215273,3.51941
TrainViT_50386_00003,TERMINATED,10.56.7.46:805352,0.000781866,0.969549,5,339.888,1.39883,0.101845,3.8463
TrainViT_50386_00004,TERMINATED,10.56.7.46:807117,0.000118637,0.809959,5,339.75,1.7801,0.101455,2.42752
TrainViT_50386_00005,TERMINATED,10.56.7.46:807193,0.00165991,0.926366,5,341.2,1.36349,0.147419,4.04002
TrainViT_50386_00006,TERMINATED,10.56.7.46:807276,0.037852,0.800948,5,339.58,0.593183,0.316232,3.45825
TrainViT_50386_00007,TERMINATED,10.56.7.46:807279,0.000301147,0.891682,5,335.766,1.50962,0.1,2.87732


[36m(TrainViT pid=805349)[0m   return F.conv2d(input, weight, bias, self.stride,
[36m(TrainViT pid=805352)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_00003_3_lr=0.0008,momentum=0.9695_2024-10-23_15-47-15/checkpoint_000000)
[36m(TrainViT pid=805350)[0m   return F.conv2d(input, weight, bias, self.stride,[32m [repeated 3x across cluster][0m
[36m(TrainViT pid=805352)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_00003_3_lr=0.0008,momentum=0.9695_2024-10-23_15-47-15/checkpoint_000001)[32m [repeated 4x across cluster][0m
[36m(TrainViT pid=805352)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_00003_3_lr=0.0008,momentum=0.9695_2024-10-23_15-47-15/che

In [18]:
print(results)

ResultGrid<[
  Result(
    metrics={'loss': 0.768010599394889, 'accuracy': 0.2835727035999298, 'val_loss': 5.610112845118065, 'val_acc': 0.2902044951915741},
    path='/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_00000_0_lr=0.0212,momentum=0.9310_2024-10-23_15-47-15',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_00000_0_lr=0.0212,momentum=0.9310_2024-10-23_15-47-15/checkpoint_000002)
  ),
  Result(
    metrics={'loss': 1.6523921482581567, 'accuracy': 0.10006870329380035, 'val_loss': 2.5383329842235187, 'val_acc': 0.10000000149011612},
    path='/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_00001_1_lr=0.0003,momentum=0.8043_2024-10-23_15-47-15',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_2/TrainViT_50386_0

In [19]:
best_result = results.get_best_result(metric="val_acc", mode="max")
print("Best trial config: {}".format(best_result.config))
print("Best trial final validation accuracy: {}".format(best_result.metrics["val_acc"]))



Best trial config: {'lr': 0.037852035384788575, 'momentum': 0.8009478620477835, 'use_gpu': True}
Best trial final validation accuracy: 0.31241798400878906


In [20]:
ray.shutdown()