In [None]:
from datasets import load_dataset

# Fetch both splits of the Alzheimer MRI dataset with `load_dataset`
# (train first, then test), unpacking them into two module-level names.
dataset_train, dataset_test = (
    load_dataset("Falah/Alzheimer_MRI", split=split_name)
    for split_name in ("train", "test")
)

In [None]:
from transformers import AutoImageProcessor

# Pretrained MobileViT checkpoint; the same id is reused when the model is built.
checkpoint = "apple/mobilevit-small"
# Preprocessing (resize / normalise / channel handling) that matches this checkpoint.
image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
def transform(example_batch):
    """Preprocess a batch of dataset rows for the model.

    Every image is forced to 3-channel RGB before being handed to the global
    ``image_processor``. The class ids are then copied across under the
    ``labels`` key, which is the name the model/Trainer look for when
    computing the loss.

    Args:
        example_batch: Batch dict from ``datasets`` with an ``"image"`` column
            (image objects exposing ``.convert``, presumably PIL) and a
            ``"label"`` column.

    Returns:
        The processor output (requested as PyTorch tensors) with an added
        ``"labels"`` entry.
    """
    rgb_images = [image.convert('RGB') for image in example_batch['image']]
    encoded = image_processor(rgb_images, return_tensors='pt')
    encoded['labels'] = example_batch['label']
    return encoded

# Attach `transform` lazily: rows are converted and preprocessed when accessed,
# not once up front, and the original `dataset_train` / `dataset_test` are left untouched.
transformed_dataset_train, transformed_dataset_test = (
    split.with_transform(transform) for split in (dataset_train, dataset_test)
)

In [None]:
import evaluate
import numpy as np

# Top-1 accuracy metric from the `evaluate` library, used by `compute_metrics`.
accuracy = evaluate.load(path="accuracy")

def compute_metrics(eval_pred):
    """Compute top-1 accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class scores with the class axis last.
            Some models return a tuple whose first element is the logits; that
            case is unwrapped here.

    Returns:
        The dict produced by the global ``accuracy`` metric's ``compute``.
    """
    predictions, labels = eval_pred
    # Models that return (logits, *extras) hand over a tuple; np.argmax on it would
    # either fail (ragged shapes) or silently stack unrelated arrays.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Class scores live on the last axis; for (batch, num_labels) logits this equals axis=1.
    predicted_ids = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions=predicted_ids, references=labels)

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Derive the label vocabulary from the dataset's ClassLabel feature. The saved config
# then carries human-readable class names instead of generic LABEL_i placeholders,
# which matters for inference pipelines and the model card.
label_names = dataset_train.features["label"].names
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

# Build a classifier head sized for this dataset. ignore_mismatched_sizes=True lets
# from_pretrained re-initialise checkpoint weights whose shapes no longer match
# num_labels (presumably the original classification head — confirm in the load log).
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Fine-tuning configuration, grouped by concern.
training_args = TrainingArguments(
    output_dir="model",
    # Optimisation schedule.
    learning_rate=5e-5,
    num_train_epochs=5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # Evaluate and checkpoint once per epoch. load_best_model_at_end requires the
    # eval and save strategies to match. "accuracy" is the key returned by
    # compute_metrics; the Trainer looks it up with its "eval_" prefix.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Keep raw columns: the on-the-fly `transform` reads the "image" column, which
    # would otherwise be dropped because it is not a model forward argument.
    remove_unused_columns=False,
    logging_steps=10,
)

# Evaluate each epoch and pick the best checkpoint (load_best_model_at_end=True) on a
# validation slice carved out of the *training* split. Selecting the checkpoint on
# the test split, then reporting test accuracy, would leak test information into
# model selection and overstate the final score. The test split is reserved for
# the final evaluation only.
train_val_split = dataset_train.train_test_split(
    test_size=0.1,  # 10% of the training images become the validation set
    seed=training_args.seed,  # reproducible split, tied to the Trainer's seed
    stratify_by_column="label",  # keep class proportions identical in both parts
)
transformed_dataset_fit = train_val_split["train"].with_transform(transform)
transformed_dataset_val = train_val_split["test"].with_transform(transform)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=transformed_dataset_fit,
    eval_dataset=transformed_dataset_val,
    # Passed so the image processor is saved alongside checkpoints.
    # NOTE(review): newer transformers releases rename this argument to
    # `processing_class`; switch once the pinned version supports it.
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)


In [None]:
# Fine-tune for `num_train_epochs`, evaluating and checkpointing every epoch. Because
# load_best_model_at_end=True, the checkpoint with the best "accuracy" is reloaded
# into `trainer.model` when training finishes.
trainer.train()

In [None]:
# One inference pass over the test split. The returned PredictionOutput carries both
# the raw predictions (used by the cells below) and the metrics, so a separate
# trainer.evaluate() call would only repeat the same forward passes.
# metric_key_prefix="eval" keeps the familiar "eval_*" key names.
# Unlike evaluate(), predict() does not write these metrics to the training log.
outputs = trainer.predict(transformed_dataset_test, metric_key_prefix="eval")
metrics = outputs.metrics
metrics

In [None]:
# Display the first test image (decoded by `datasets`) as a visual sanity check
# for the prediction inspected in the next cells.
dataset_test[0]['image']

In [None]:
# Raw, un-normalised per-class scores the model produced for that same first test
# example (predict() iterates the dataset in order).
outputs.predictions[0]

In [None]:
# Translate the predicted class index for the first test image into its ClassLabel
# name and show it next to the ground truth, so the result reads without a lookup.
predicted_label = dataset_test.features["label"].int2str(int(outputs.predictions[0].argmax(-1)))
true_label = dataset_test.features["label"].int2str(dataset_test[0]["label"])
{"predicted": predicted_label, "actual": true_label}