In [1]:
import os
import ray
from ray import tune
from ray.tune import Tuner, TuneConfig, with_resources
from ray.tune.schedulers import ASHAScheduler
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score

In [2]:
# Constants
EPOCHS = 10
CLASSES = 10  # Assume 10 classes for the StateFarm dataset

In [3]:
def define_model(use_gpu):
    """
    Defines the pretrained ViT_B_16 model with a modified last linear layer and frozen base layers.
    """
    pretrained_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    model = torchvision.models.vit_b_16(weights=pretrained_weights)
    
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Modify the classifier for the target dataset
    # model.heads = nn.Linear(model.heads.in_features, CLASSES)
    model.heads = nn.Linear(in_features=768, out_features=CLASSES)
    return model, pretrained_weights.transforms()

In [4]:
def get_data_loaders(transform):
    """
    Creates the train and validation dataloaders.
    """
    train_dir = "/home/sur06423/wacv_paper/wacv_paper/data/imbalanced_v2/train"
    val_dir = "/home/sur06423/wacv_paper/wacv_paper/data/imbalanced_v2/validation"
    
    trainset = ImageFolder(root=train_dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=1024, shuffle=True)
    valset = ImageFolder(root=val_dir, transform=transform)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=1024, shuffle=True)

    return train_loader, val_loader

In [5]:
def train_model(model, optimizer, train_loader, device):
    model.train()
    total_loss = 0
    total_correct = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
    return total_loss / len(train_loader.dataset), total_correct / len(train_loader.dataset)

def validate_model(model, val_loader, device):
    model.eval()
    total_loss = 0
    total_correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
    return total_loss / len(val_loader.dataset), total_correct / len(val_loader.dataset)


In [6]:
class TrainViT(tune.Trainable):
    def setup(self, config):
        self.device = torch.device("cuda" if torch.cuda.is_available() and config.get("use_gpu", False) else "cpu")
        self.model, transforms = define_model(config.get("use_gpu", False))
        self.model.to(self.device)
        self.train_loader, self.val_loader = get_data_loaders(transforms)
        self.optimizer = optim.SGD(self.model.parameters(), lr=config["lr"], momentum=config["momentum"])

    def step(self):
        train_loss, train_acc = train_model(self.model, self.optimizer, self.train_loader, self.device)
        val_loss, val_acc = validate_model(self.model, self.val_loader, self.device)
        return {"loss": train_loss, "accuracy": train_acc, "val_loss": val_loss, "val_acc": val_acc}

    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }, path)
        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [7]:
N_TRIALS = 8
# ASHA Scheduler for early stopping
scheduler = ASHAScheduler(
    metric="val_acc",
    mode="max",
    max_t=EPOCHS,
    grace_period=5,
    reduction_factor=2
)

# Configuration for hyperparameters
config = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "momentum": tune.uniform(0.8, 0.99),
    "use_gpu": True  # This can be dynamically adjusted if some trials should not use a GPU
}

# Setting up the Tuner with dynamic resource allocation
trainable_with_resources = with_resources(
    TrainViT,
    resources=lambda config: {"gpu": 1, "cpu": 2} if config.get("use_gpu", False) else {"cpu": 2}
)

tune_config = TuneConfig(
    num_samples=N_TRIALS,
    max_concurrent_trials=4  # Adjust based on the number of available GPUs
)

""" 
checkpoint_config = ray.train.CheckpointConfig(num_to_keep: int | None = None, 
                                               checkpoint_score_attribute: str | None = None, 
                                               checkpoint_score_order: str | None = 'max', 
                                               checkpoint_frequency: int | None = 0, 
                                               checkpoint_at_end: bool | None = None, 
                                               _checkpoint_keep_all_ranks: bool | None = 'DEPRECATED', 
                                               _checkpoint_upload_from_workers: bool | None = 'DEPRECATED')
"""

run_config = ray.train.RunConfig(name="Dynamic_Trial_Exp_1",
                                 storage_path="/home/sur06423/wacv_paper/wacv_paper/ray_results",
                                 stop={"training_iteration": 5},
                                 checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=2, checkpoint_at_end=True),
                                 # checpoint_config = checkpoint_config
)


In [8]:
# Initialize Ray
ray.shutdown()
ray.init(num_cpus=24, num_gpus=4, include_dashboard=True, dashboard_port=8267)  # Explicitly set the number of GPUs

2024-10-23 13:55:26,504	INFO worker.py:1777 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8267 [39m[22m


0,1
Python version:,3.10.12
Ray version:,2.37.0
Dashboard:,http://127.0.0.1:8267


In [9]:
print(ray.available_resources())

{'CPU': 24.0, 'node:__internal_head__': 1.0, 'memory': 362998148096.0, 'object_store_memory': 159856349184.0, 'GPU': 4.0, 'node:10.56.7.46': 1.0}


In [10]:
# Define the directories to be added to LD_LIBRARY_PATH
library_paths = [
    "/usr/lib/xorg-nvidia-525.116.04/lib/x86_64-linux-gnu",
    "/usr/lib/xorg/lib/x86_64-linux-gnu",
    "/usr/lib/xorg-nvidia-535.113.01/lib/x86_64-linux-gnu"
]

# Current LD_LIBRARY_PATH from the environment
current_ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')

# Adding each path only if it is not already in the LD_LIBRARY_PATH
new_paths = [path for path in library_paths if path not in current_ld_library_path]

# Join all new paths with the existing LD_LIBRARY_PATH
os.environ['LD_LIBRARY_PATH'] = ':'.join(new_paths + [current_ld_library_path])

# Verify the update
print("Updated LD_LIBRARY_PATH:")
print(os.environ['LD_LIBRARY_PATH'])

Updated LD_LIBRARY_PATH:
/usr/lib/xorg-nvidia-525.116.04/lib/x86_64-linux-gnu:/usr/lib/xorg/lib/x86_64-linux-gnu:/usr/lib/xorg-nvidia-535.113.01/lib/x86_64-linux-gnu:


In [11]:
# Create the Tuner and run the trials
tuner = Tuner(trainable_with_resources,
              param_space=config, 
              tune_config=tune_config,
              run_config=run_config 
              )
results = tuner.fit()

[36m(TrainViT pid=777824)[0m   return F.conv2d(input, weight, bias, self.stride,
[36m(TrainViT pid=777822)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00001_1_lr=0.0066,momentum=0.9658_2024-10-23_13-55-38/checkpoint_000000)
[36m(TrainViT pid=777822)[0m   return F.conv2d(input, weight, bias, self.stride,[32m [repeated 3x across cluster][0m
[36m(TrainViT pid=777822)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00001_1_lr=0.0066,momentum=0.9658_2024-10-23_13-55-38/checkpoint_000001)[32m [repeated 4x across cluster][0m
[36m(TrainViT pid=777822)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00001_1_lr=0.0066,momentum=0.9658_2024-10-23_13-55-38/che

In [13]:
print(results)

ResultGrid<[
  Result(
    metrics={'loss': 1.2158935063617233, 'accuracy': 0.6403351286654698, 'val_loss': 5.176045616798537, 'val_acc': 0.2630796670630202},
    path='/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00000_0_lr=0.0043,momentum=0.9544_2024-10-23_13-55-38',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00000_0_lr=0.0043,momentum=0.9544_2024-10-23_13-55-38/checkpoint_000002)
  ),
  Result(
    metrics={'loss': 1.039598422447127, 'accuracy': 0.7489527229204069, 'val_loss': 6.2236026447536545, 'val_acc': 0.23573127229488705},
    path='/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00001_1_lr=0.0066,momentum=0.9658_2024-10-23_13-55-38',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/sur06423/wacv_paper/wacv_paper/ray_results/Dynamic_Trial_Exp_1/TrainViT_b8372_00

In [14]:
best_result = results.get_best_result(metric="val_acc", mode="max")
print("Best trial config: {}".format(best_result.config))
print("Best trial final validation accuracy: {}".format(best_result.metrics["val_acc"]))



Best trial config: {'lr': 0.02947919852342335, 'momentum': 0.8730365265780353, 'use_gpu': True}
Best trial final validation accuracy: 0.33561236623067775
