In [None]:
!pip install -q transformers datasets torch torchvision evaluate gradio

In [None]:
import torch
import numpy as np
import evaluate
import gradio as gr
from datasets import load_dataset
from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)

In [None]:
import torch.nn.functional as F

In [None]:
# Open the training split in streaming mode, then keep only the first
# 5000 streamed examples as an in-memory list.
raw_dataset = load_dataset(
    "zacharielegault/PatchCamelyon",
    split='train',
    streaming=True,
)
dataset_subset = [example for example in raw_dataset.take(5000)]

In [None]:
# Checkpoint identifier; the same name is reused further down when the
# classification model itself is loaded.
model_name = "google/vit-base-patch16-224"
# Image processor loaded from that checkpoint; transform_medical_images
# calls it to turn dataset images into model inputs.
processor = ViTImageProcessor.from_pretrained(model_name)

def transform_medical_images(examples):
    """Convert one dataset example into model-ready inputs.

    Args:
        examples: Mapping with an ``'image'`` entry (handed to ``processor``)
            and a ``'label'`` entry. Despite the plural name, this notebook
            calls the function with a single example at a time.

    Returns:
        The processor output with two changes: ``'labels'`` holds the
        example's label, and ``'pixel_values'`` has its leading batch
        axis removed.
    """
    inputs = processor(examples['image'], return_tensors='pt')
    inputs['labels'] = examples['label']
    # The processor returns a batch of one. Drop only that leading axis:
    # a bare .squeeze() would also collapse any other size-1 dimension.
    inputs['pixel_values'] = inputs['pixel_values'].squeeze(0)
    return inputs

# Run the preprocessing function over every sampled example, in order,
# and keep the results as a list.
processed_dataset = list(map(transform_medical_images, dataset_subset))

In [None]:
# Class index -> display name, in index order.
id2label = dict(enumerate(("Healthy", "Malignant")))
# Display name -> class index, derived as the inverse of id2label.
label2id = {name: index for index, name in id2label.items()}

# Load the checkpoint as a classifier with one output per entry in id2label.
# ignore_mismatched_sizes=True is passed so loading does not fail on weight
# shape mismatches (presumably the checkpoint's classification head has a
# different number of outputs - confirm).
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)



In [None]:
# Load the "accuracy" metric through the evaluate library; compute_metrics
# below calls metric.compute on it.
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score a set of predictions with the module-level ``metric``.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        The result of ``metric.compute`` for the arg-max class of each row
        of logits (taken over the last axis) against ``labels``.
    """
    model_outputs, true_labels = eval_pred
    predicted_classes = np.argmax(model_outputs, axis=-1)
    return metric.compute(
        predictions=predicted_classes,
        references=true_labels,
    )

# Mixed-precision (fp16) training is enabled only when a CUDA device is present.
use_fp16 = torch.cuda.is_available()

# Settings for the fine-tuning run: checkpoints are saved to output_dir once
# per epoch, no evaluation is scheduled during training, and reporting to
# external experiment trackers is turned off.
training_args = TrainingArguments(
    output_dir="./vit-cancer-model",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=5e-6,
    weight_decay=0.01,
    fp16=use_fp16,
    logging_steps=10,
    eval_strategy="no",
    save_strategy="epoch",
    report_to="none",
)

# NOTE(review): no eval_dataset is supplied and eval_strategy is "no", so
# compute_metrics is presumably only exercised if evaluation is run
# explicitly with a dataset - confirm this is intended.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset,
)

In [None]:
# Run the fine-tuning loop configured on `trainer`; the two prints bracket
# the run in the cell output.
print("Starting Training...")
trainer.train()
print("Training Complete!")

In [None]:
# Interactive Hugging Face Hub login from inside the notebook; presumably
# required before the push_to_hub calls that follow.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Upload the model and its image processor to the same Hub repository.
hub_repo_id = "cancer-ai-model"
model.push_to_hub(hub_repo_id)
processor.push_to_hub(hub_repo_id)