In [None]:
from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTImageProcessor
import torch
import numpy as np
import ray.train.huggingface.transformers
from ray.air import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback

In [None]:
def train_func(config):
    """Fine-tune a ViT image classifier on the ``beans`` dataset.

    Intended to be handed to Ray Train's ``TorchTrainer`` as the training
    function. It loads the dataset and the pretrained checkpoint, builds a
    Hugging Face ``Trainer``, attaches ``RayTrainReportCallback``, wraps the
    trainer with ``prepare_trainer`` and runs ``trainer.train()``.

    Args:
        config: Optional dict of overrides; ``None`` or ``{}`` keeps the
            defaults. Recognized keys:
            ``model_name_or_path`` (default
            ``'google/vit-base-patch16-224-in21k'``),
            ``output_dir`` (default ``"/mnt/local_storage/output"``),
            ``num_train_epochs`` (default ``2``).

    Returns:
        None.
    """
    # These two imports are function-local in the original cell; kept that way.
    from datasets import load_dataset
    import evaluate

    config = config or {}
    model_name_or_path = config.get('model_name_or_path', 'google/vit-base-patch16-224-in21k')
    output_dir = config.get('output_dir', "/mnt/local_storage/output")
    num_train_epochs = config.get('num_train_epochs', 2)

    ds = load_dataset('beans')
    image_processor = ViTImageProcessor.from_pretrained(model_name_or_path)

    def transform(example_batch):
        """Turn a batch of PIL images into pixel-value tensors, keeping the labels."""
        inputs = image_processor(list(example_batch['image']), return_tensors='pt')
        # The labels must travel with the pixel values so collate_fn can batch them.
        inputs['labels'] = example_batch['labels']
        return inputs

    # Applied lazily, on access, rather than materializing a processed copy.
    prepared_ds = ds.with_transform(transform)

    # Class names, in label-id order.
    label_names = ds['train'].features['labels'].names

    model = ViTForImageClassification.from_pretrained(
        model_name_or_path,
        num_labels=len(label_names),
        id2label={str(i): c for i, c in enumerate(label_names)},
        label2id={c: str(i) for i, c in enumerate(label_names)},
    )

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Compute accuracy from a ``(logits, references)`` evaluation pair."""
        logits, references = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=references)

    # Hugging Face Training Args + training Trainer
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        # RayTrainReportCallback reports to Ray Train when a checkpoint is
        # saved, so checkpointing is aligned with the per-epoch evaluation:
        # every report then carries that epoch's eval metrics and a checkpoint.
        save_strategy="epoch",
        num_train_epochs=num_train_epochs,
        logging_steps=10,
        # Keep the 'image' column, which `transform` reads.
        remove_unused_columns=False,
    )

    def collate_fn(batch):
        """Stack per-example pixel values and labels into batch tensors."""
        return {
            'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
            'labels': torch.tensor([x['labels'] for x in batch]),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_ds["train"],
        eval_dataset=prepared_ds["validation"],
    )

    # Forward metrics and checkpoints from the Hugging Face Trainer to Ray Train.
    trainer.add_callback(RayTrainReportCallback())

    trainer = prepare_trainer(trainer)
    trainer.train()

In [None]:
# Run `train_func` on two GPU workers; Ray Train uses the given storage path
# for this run's outputs.
ray_trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(use_gpu=True, num_workers=2),
    run_config=ray.air.RunConfig(storage_path='/mnt/cluster_storage'),
)
# Blocks until training finishes; as the cell's last expression, the returned
# result is displayed by the notebook.
ray_trainer.fit()