In [None]:
import ray
import torch
import numpy as np
import ray.train.huggingface.transformers
from ray.air import ScalingConfig
from ray.train.torch import TorchTrainer
from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTImageProcessor
from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback

In [None]:
# S3 prefixes of the two image folders used for this two-class example.
dogs_train = 's3://anonymous@air-example-data-2/imagenette2/train/n02102040'
fish_train = 's3://anonymous@air-example-data-2/imagenette2/train/n01440764'

# Cap each folder at 200 images, then concatenate: dog rows first, fish rows second.
dog_images = ray.data.read_images(dogs_train).limit(200)
fish_images = ray.data.read_images(fish_train).limit(200)
train_ds_images = dog_images.union(fish_images)

In [None]:
# Copy the labels file from the workspace directory into /mnt/cluster_storage,
# which is where the following cell reads it from with ray.data.read_csv.
# NOTE(review): presumably /mnt/cluster_storage is storage shared by all cluster
# nodes -- confirm for your environment.
! cp /home/ray/default/AILibs/labels.csv /mnt/cluster_storage/labels.csv

In [None]:
# Load the labels as a Ray Dataset from the CSV file placed in cluster storage.
train_ds_labels = ray.data.read_csv(
    paths='/mnt/cluster_storage/labels.csv',
)

In [None]:
# Combine the image rows with the label rows into a single dataset.
labeled_ds = train_ds_images.zip(train_ds_labels)


def _has_three_dims(record):
    """Return True when the record's 'image' array has exactly three dimensions."""
    return record['image'].ndim == 3


# Keep only records whose image array is 3-D.
# NOTE(review): presumably this drops grayscale images that lack a channel axis -- confirm.
filtered_labeled_ds = labeled_ds.filter(_has_three_dims)

In [None]:
class Featurizer:
    """Callable batch transform that turns raw images into ViT model inputs.

    The image processor is loaded once in ``__init__`` and reused for every
    batch passed to ``__call__``.
    """

    def __init__(self):
        """Load the ViT image processor for the pretrained checkpoint."""
        self._model_name_or_path = 'google/vit-base-patch16-224-in21k'
        self._feature_extractor = ViTImageProcessor.from_pretrained(
            self._model_name_or_path
        )

    def __call__(self, batch):
        """Featurize one batch.

        Args:
            batch: Mapping with an 'image' column (iterable of images) and a
                'label' column.

        Returns:
            Dict with 'pixel_values' (the processor output) and 'labels'
            (the batch's 'label' column, passed through unchanged).
        """
        images = list(batch['image'])
        encoded = self._feature_extractor(images, return_tensors='pt')
        return {
            'pixel_values': encoded['pixel_values'],
            'labels': batch['label'],
        }
    
# Apply Featurizer batch-wise; the class itself (not an instance) is handed to
# map_batches and run on an actor pool of size 2.
featurizer_pool = ray.data.ActorPoolStrategy(size=2)
featurized_ds = filtered_labeled_ds.map_batches(
    Featurizer,
    compute=featurizer_pool,
)

In [None]:
# Hold out 20% of the rows for validation.
# The images were assembled as a union of dog images followed by fish images,
# so an unshuffled split would take the validation rows from the tail of the
# dataset and risk a validation set made up of a single class. Shuffle before
# splitting, with a fixed seed so the split is reproducible on re-run.
train_dataset, valid_dataset = featurized_ds.train_test_split(
    test_size=0.2,
    shuffle=True,
    seed=42,
)

In [None]:
def train_func(config):
    """Training loop executed by each Ray Train worker.

    Builds a Hugging Face ``Trainer`` around a ViT image classifier, feeds it
    from the "train" and "valid" Ray dataset shards, attaches the Ray Train
    report callback, and runs training.

    Args:
        config: Dict that must contain the key 'model' -- the name or path
            passed to ``ViTForImageClassification.from_pretrained``.

    Returns:
        None.
    """
    # Imported here so they are resolved inside the worker process.
    import evaluate
    from ray.train import get_dataset_shard

    # Batch iterators over the "train" and "valid" dataset shards.
    train_batches = get_dataset_shard("train").iter_torch_batches(batch_size=64)
    valid_batches = get_dataset_shard("valid").iter_torch_batches(batch_size=64)

    vit_model = ViTForImageClassification.from_pretrained(config['model'])

    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Compute accuracy from a (logits, labels) evaluation pair."""
        logits, labels = eval_pred
        return accuracy.compute(
            predictions=np.argmax(logits, axis=-1),
            references=labels,
        )

    # Hugging Face Training Args + Trainer
    # NOTE(review): per_device_train_batch_size is 128 while the shard iterators
    # above yield batches of 64 -- confirm which value is intended.
    hf_args = TrainingArguments(
        output_dir="/mnt/cluster_storage/output",
        evaluation_strategy="steps",
        eval_steps=3,
        per_device_train_batch_size=128,
        logging_steps=2,
        save_steps=4,
        max_steps=10,
    )

    hf_trainer = Trainer(
        model=vit_model,
        args=hf_args,
        compute_metrics=compute_metrics,
        train_dataset=train_batches,
        eval_dataset=valid_batches,
    )

    # Register the Ray Train report callback, then let Ray adapt the trainer
    # before starting training.
    hf_trainer.add_callback(RayTrainReportCallback())
    hf_trainer = prepare_trainer(hf_trainer)
    hf_trainer.train()

In [None]:
# Two GPU workers; run output is stored under /mnt/cluster_storage.
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
run_config = ray.air.RunConfig(storage_path='/mnt/cluster_storage')

# Wire the per-worker loop, its config, and the two dataset splits together.
# The dataset keys must match the names requested inside train_func.
ray_trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={'model': 'google/vit-base-patch16-224-in21k'},
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_dataset, "valid": valid_dataset},
)

# Launch training; as the cell's last expression, the return value of fit() is displayed.
ray_trainer.fit()