In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from datasets import load_dataset
from timm import create_model
from transformers import ViTImageProcessor
from transformers import ViTConfig, ViTModel
import evaluate 
import numpy as np


In [2]:
from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path, return_tensors='pt')


In [3]:
# Initializing a ViT vit-base-patch16-224 style configuration
configuration = ViTConfig()

# Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
model = ViTModel(configuration)



In [4]:
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs


In [5]:
ds = load_dataset("C:/Tesis/DatasetBinario", num_proc=3)

Resolving data files:   0%|          | 0/20959 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/668 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/768 [00:00<?, ?it/s]

In [10]:
dataset_train = ds['train']
num_classes = len(set(dataset_train['label']))
labels = ds['train'].features['label']
print(num_classes, labels)

2 ClassLabel(names=['Melanoma', 'No Melanoma'], id=None)


In [None]:
prepared_ds = ds.with_transform(transform)

In [None]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }


In [None]:
from transformers import ViTForImageClassification

labels = ds['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
metric = evaluate.load("accuracy")

In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [None]:
from transformers import TrainingArguments, Trainer


training_args = TrainingArguments(output_dir="test_trainer", eval_strategy="epoch")

In [None]:
prepared_ds

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 20959
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 668
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 768
    })
})

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset= prepared_ds["train"],
    eval_dataset= prepared_ds["validation"],
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    tokenizer=processor
)

  trainer = Trainer(


In [None]:
trainer.train()

  0%|          | 0/7860 [00:00<?, ?it/s]

KeyError: 'image'