In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold
import pandas as pd
from PIL import Image
import os
from glob import glob
from transformers import get_cosine_schedule_with_warmup, ViTForImageClassification

import lightning as L
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
import ray.train
import ray.train.lightning


import mlflow
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryAUROC
import numpy as np

In [2]:
# ray.init()
mlflow.set_tracking_uri("http://127.0.0.1:8080")

In [None]:
class CustomImageDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.annotations = csv_file
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        label = torch.tensor(int(self.annotations.iloc[idx, 1]))
        if self.transform:
            image = self.transform(image)
        return image, label

In [4]:
def create_dataloaders(csv_file, img_dir, img_size=(224, 224), batch_size=32, n_fold=0):
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(), 
        transforms.RandomHorizontalFlip()
    ])

    dataset = CustomImageDataset(csv_file=csv_file, img_dir=img_dir, transform=transform)
    
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2024)
    for i, (train_index, val_index) in enumerate(skf.split(np.zeros(len(csv_file)), csv_file.iloc[:, 1].values)):
        if i == n_fold:
            break
            
    train_dataset = Subset(dataset, train_index)
    dataset = CustomImageDataset(csv_file=csv_file, img_dir=img_dir, 
                               transform=transforms.Compose([
                                   transforms.Resize(img_size), 
                                   transforms.ToTensor()
                               ]))
    val_dataset = Subset(dataset, val_index)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    return train_loader, val_loader

In [5]:
def train_func(config):
    # Set up MLFlow
    mlflow_logger = MLFlowLogger(
        experiment_name="vit-ai-detection",
        tracking_uri=mlflow.get_tracking_uri(),
        run_name=f"fold-{config['n_fold']}"
    )

    # Preparing data
    train_loader, val_loader = create_dataloaders(
        csv_file=config["labels"],
        img_dir=config["img_dir"],
        img_size=config["img_size"],
        batch_size=config["batch_size"],
        n_fold=config["n_fold"]
    )

    # Model
    class LitViTModel(L.LightningModule):
        def __init__(self, model_name, lr=2e-5, warmup_epochs=0):
            super().__init__()
            self.model = ViTForImageClassification.from_pretrained(model_name)
            self.criterion = nn.BCEWithLogitsLoss()
            self.lr = lr
            self.warmup_epochs = warmup_epochs
            
            self.train_acc = BinaryAccuracy()
            self.val_acc = BinaryAccuracy()
            self.val_f1 = BinaryF1Score()
            self.val_auc = BinaryAUROC()

        def forward(self, x):
            return self.model(x).logits[:, :1]

        def training_step(self, batch, batch_idx):
            x, y = batch
            y = y.float().unsqueeze(1)
            logits = self(x)
            loss = self.criterion(logits, y)
            self.log("train_loss", loss, prog_bar=True)
            self.train_acc(torch.sigmoid(logits), y)
            self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y = y.float().unsqueeze(1)
            logits = self(x)
            loss = self.criterion(logits, y)
            
            probs = torch.sigmoid(logits)
            self.val_acc(probs, y)
            self.val_f1(probs, y)
            self.val_auc(probs, y)
            
            self.log("val_loss", loss, prog_bar=True)
            self.log("val_acc", self.val_acc, prog_bar=True)
            self.log("val_f1", self.val_f1, prog_bar=True)
            self.log("val_auc", self.val_auc, prog_bar=True)
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.trainer.estimated_stepping_batches * self.warmup_epochs,
                num_training_steps=self.trainer.estimated_stepping_batches * self.trainer.max_epochs
            )
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    # Model
    model = LitViTModel(
        model_name=config["model_name"],
        lr=config["lr"],
        warmup_epochs=config["warmup_epochs"]
    )
    
    # Callbacks
    early_stop = EarlyStopping(
        monitor="val_f1",
        patience=3,
        mode="max",
        verbose=True
    )
    
    checkpoint_callback = ModelCheckpoint(
        monitor="val_f1",
        mode="max",
        save_top_k=1,
        filename="best-checkpoint"
    )
    
    # Trainer
    trainer = L.Trainer(
        logger=mlflow_logger,
        callbacks=[early_stop, checkpoint_callback, ray.train.lightning.RayTrainReportCallback()],
        max_epochs=config["num_epochs"],
        accelerator="cpu",
        devices="auto",
        enable_progress_bar=True,
        log_every_n_steps=10,
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins = [ray.train.lightning.RayLightningEnvironment()]
    )

    # Log parameters
    mlflow_logger.log_hyperparams(config)
    
    # Train
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
if __name__ == "__main__":
    # Configuration
    # ray.init(
    #     num_cpus=6,  # Adjust based on your Mac's CPU cores
    #     include_dashboard=True,  # Disable dashboard to reduce overhead
    #     ignore_reinit_error=True
    # )

    # print("Ray cluster resources:", ray.cluster_resources())

    config = {
        "labels": pd.read_csv("./train.csv").iloc[:, 1:].copy(),
        "img_dir": "./",
        "model_name": "google/vit-base-patch16-224",
        "img_size": (224, 224),
        "batch_size": 32,  # Reduced for local execution
        "lr": 2e-5,
        "num_epochs": 10,
        "warmup_epochs": 0,
        "n_fold": 0,
        "num_workers": 2  # Number of parallel training workers
    }

    try:

        scaling_config = ScalingConfig(
            num_workers=config.get("num_workers", 1),  # Using 2 workers for local Mac
            use_gpu=False,  # Mac typically doesn't have supported GPUs for PyTorch
            resources_per_worker={"CPU": 2}  # Allocate 2 CPUs per worker
        )
        
        run_config = RunConfig(
            # storage_path="/tmp/ray_results",  # Local storage path
            name="vit_training"
        )
        
        trainer = TorchTrainer(
            train_func,
            train_loop_config=config,
            scaling_config=scaling_config,
            run_config=run_config
        )
        
        result = trainer.fit()
        print("Training completed successfully.")
    finally:
        # ray.shutdown()
        pass

2025-04-13 23:25:50,734	INFO worker.py:1843 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-04-13 23:25:51,168	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
2025-04-13 23:25:51,171	INFO tensorboardx.py:193 -- pip install "ray[tune]" to see TensorBoard files.


Ray cluster resources: {'CPU': 6.0, 'memory': 19427491840.0, 'node:127.0.0.1': 1.0, 'object_store_memory': 2147483648.0, 'node:__internal_head__': 1.0}
== Status ==
Current time: 2025-04-13 23:25:51 (running for 00:00:00.15)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/6 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-04-13_23-25-49_925262_14706/artifacts/2025-04-13_23-25-51/vit_training/driver_artifacts
Number of trials: 1/1 (1 PENDING)




[36m(RayTrainWorker pid=14724)[0m Setting up process group for: env:// [rank=0, world_size=2]
[36m(TorchTrainer pid=14723)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=14723)[0m - (node_id=159c4f85ac39ea9039d7c8d27ac8eccb694b20959b2a822df3bc0a51, ip=127.0.0.1, pid=14724) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=14723)[0m - (node_id=159c4f85ac39ea9039d7c8d27ac8eccb694b20959b2a822df3bc0a51, ip=127.0.0.1, pid=14725) world_rank=1, local_rank=1, node_rank=0


== Status ==
Current time: 2025-04-13 23:25:56 (running for 00:00:05.25)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/6 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-04-13_23-25-49_925262_14706/artifacts/2025-04-13_23-25-51/vit_training/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




[36m(RayTrainWorker pid=14724)[0m GPU available: True (mps), used: False
[36m(RayTrainWorker pid=14724)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=14724)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=14724)[0m /opt/miniconda3/envs/mlops/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
[36m(RayTrainWorker pid=14724)[0m Loading `train_dataloader` to estimate number of stepping batches.
[36m(RayTrainWorker pid=14724)[0m /opt/miniconda3/envs/mlops/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
[36m(RayTrainWorker pid=14724)[0m 
[36m(RayTrainWorker pid=14724)[0m   | Name      | Type            

Sanity Checking: |          | 0/? [00:00<?, ?it/s]🏃 View run fold-0 at: http://127.0.0.1:8080/#/experiments/634370650965998394/runs/7105ac036d394b43b96b43e99f8bca44
[36m(RayTrainWorker pid=14724)[0m 🧪 View experiment at: http://127.0.0.1:8080/#/experiments/634370650965998394


2025-04-13 23:25:58,195	ERROR tune.py:1037 -- Trials did not complete: [TorchTrainer_2a685_00000]
2025-04-13 23:25:58,195	INFO tune.py:1041 -- Total run time: 7.03 seconds (6.97 seconds for the tuning loop).


== Status ==
Current time: 2025-04-13 23:25:58 (running for 00:00:07.02)
Using FIFO scheduling algorithm.
Logical resource usage: 5.0/6 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-04-13_23-25-49_925262_14706/artifacts/2025-04-13_23-25-51/vit_training/driver_artifacts
Number of trials: 1/1 (1 ERROR)
Number of errored trials: 1
+--------------------------+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name               |   # failures | error file                                                                                                                                                             |
|--------------------------+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| TorchTrainer_2a685_0000

TrainingFailedError: The Ray Train run failed. Please inspect the previous error messages for a cause. After fixing the issue (assuming that the error is not caused by your own application logic, but rather an error such as OOM), you can restart the run from scratch or continue this run.
To continue this run, you can use: `trainer = TorchTrainer.restore("/Users/anshsarkar/ray_results/vit_training")`.
To start a new run that will retry on training failures, set `train.RunConfig(failure_config=train.FailureConfig(max_failures))` in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` for unlimited retries.