In [1]:
from datasets import load_dataset

# Fetch the cats dataset through the Hugging Face `datasets` API.
ds = load_dataset('ChrisGuarino/cats')

# Data exploration: keep direct handles to the splits inspected and used below.
# No handle is created for a 'test' split.
train_data, validation_data = ds['train'], ds['validation']

In [2]:
# Show the schema of the training split (column names and their feature types).
train_data.features

{'image': Image(decode=True, id=None),
 'labels': ClassLabel(names=['prim', 'rupe'], id=None)}

In [3]:
# Fetch the image processor published with the ViT checkpoint on the Hugging Face Hub.
# The same checkpoint name is used further down when the model itself is loaded.
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)

: 

In [None]:
def process_example(example):
    """Preprocess a single dataset example for the ViT model.

    Parameters
    ----------
    example : mapping
        One dataset row with an 'image' entry (a PIL image) and a 'labels' entry.

    Returns
    -------
    The processor output (PyTorch tensors, as requested via ``return_tensors='pt'``)
    with the example's 'labels' value added under the 'labels' key.
    """
    # Normalise the colour mode first: grayscale, palette or RGBA images do not
    # have the three channels the processor is configured for, so convert to RGB.
    rgb_image = example['image'].convert('RGB')
    inputs = processor(rgb_image, return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs


# Sanity check: run the preprocessing on the first training example.
process_example(ds['train'][0])

In [None]:
def transform(example_batch):
    """Preprocess a batch of dataset rows for the ViT model.

    Parameters
    ----------
    example_batch : mapping
        Column-oriented batch with an 'image' list (PIL images) and a 'labels' list.

    Returns
    -------
    The processor output (PyTorch tensors) with the batch's 'labels' list attached
    under the 'labels' key.
    """
    # Convert every image to RGB before processing so that grayscale, palette or
    # RGBA images yield the same three-channel input as the rest of the batch.
    rgb_images = [image.convert('RGB') for image in example_batch['image']]
    inputs = processor(rgb_images, return_tensors='pt')

    # The labels must travel with the pixel values so the collator can batch them.
    inputs['labels'] = example_batch['labels']
    return inputs


# Attach the transform to the dataset; `prepared_ds` is what the Trainer cells consume.
prepared_ds = ds.with_transform(transform)

In [None]:
# Display the dataset object returned by `ds.with_transform(transform)`.
prepared_ds

In [None]:
import torch
def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Each item must provide a 'pixel_values' tensor and a 'labels' value. The
    pixel tensors are stacked along a new leading dimension and the labels are
    gathered into a single tensor.
    """
    pixel_rows = []
    label_rows = []
    for item in batch:
        pixel_rows.append(item['pixel_values'])
        label_rows.append(item['labels'])

    return {
        'pixel_values': torch.stack(pixel_rows),
        'labels': torch.tensor(label_rows),
    }

In [None]:
import numpy as np

# Accuracy is computed locally with NumPy. The previous implementation relied on
# `datasets.load_metric(..., trust_remote_code=True)`, which is deprecated/removed in
# recent `datasets` releases and executes a metric script fetched from the Hub.
# NOTE: a later cell defines `compute_metrics` again; whichever definition was
# executed last is the one handed to the Trainer.


def compute_metrics(p):
    """Return the classification accuracy for one evaluation pass.

    Parameters
    ----------
    p : object
        Must expose `predictions` (per-class scores with one row per example)
        and `label_ids` (the reference class ids).

    Returns
    -------
    dict
        ``{"accuracy": <fraction of rows whose arg-max class equals the label>}``.
    """
    predicted = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted == references))}


In [None]:
from transformers import ViTForImageClassification

# Class names in label-id order, read from the dataset's ClassLabel feature.
labels = ds['train'].features['labels'].names

# Load the pretrained ViT backbone with a classification head sized for our classes.
# The label maps use integer ids (id -> name and name -> id) so the ids stored in the
# model config are integers rather than strings.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(labels),
    id2label={i: name for i, name in enumerate(labels)},
    label2id={name: i for i, name in enumerate(labels)},
)

In [None]:
from sklearn.metrics import accuracy_score

# NOTE: this redefines `compute_metrics` from an earlier cell; the definition that
# was executed last is the one passed to the Trainer.
def compute_metrics(eval_pred):
    """Compute accuracy from a ``(logits, labels)`` evaluation pair.

    The predicted class of each row is the arg-max over the last axis of the
    logits; the result is returned as ``{"accuracy": value}``.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)
    accuracy = accuracy_score(references, predicted_classes)
    return {"accuracy": accuracy}

In [None]:
from transformers import TrainingArguments, Trainer

# Hyper-parameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    # Where checkpoints and metrics are written.
    output_dir="cat_ds",

    # Optimisation.
    num_train_epochs=4,
    per_device_train_batch_size=16,
    learning_rate=2e-4,
    fp16=False,

    # Evaluation and checkpointing both happen once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): with both strategies set to "epoch" these step intervals look
    # unused — confirm against the TrainingArguments documentation.
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,

    # Logging / reporting.
    logging_steps=10,
    report_to='none',

    # Keep the raw 'image' column: the batch transform reads it when examples are
    # accessed (presumably it would otherwise be dropped — confirm).
    remove_unused_columns=False,

    # Nothing is uploaded automatically during training.
    push_to_hub=False,
)

# Wire the model, the run configuration, the data and the helper callables together.
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: both splits come from the dataset with the preprocessing transform attached.
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],  # requires the dataset to have a validation split
    # Batching and evaluation helpers defined in earlier cells.
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The image processor is handed over in the tokenizer slot
    # (presumably so it is saved alongside the model — confirm).
    tokenizer=processor,
)

## Training

In [None]:
# Run the fine-tuning loop; the returned object exposes the run's metrics via `.metrics`.
train_results = trainer.train()
# Save the model (no path is passed, so presumably the configured output_dir is used — confirm).
trainer.save_model()
# Report the training metrics, persist them, and finally save the Trainer's state.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
# Upload the trained model to the Hugging Face Hub.
# NOTE(review): in recent transformers releases the first positional parameter of
# Trainer.push_to_hub appears to be `commit_message`, so this string may end up as
# the commit message rather than selecting a repository — confirm. If a specific
# target repo is intended, it is normally chosen via `hub_model_id` in TrainingArguments.
trainer.push_to_hub("ChrisGuarino/cats")