In [75]:
import os
import sys

import accelerate
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTImageProcessor, TrainingArguments, Trainer

In [71]:
# Locations of the cleaned ISIC 2019 inputs, relative to this notebook
image_dir = '../../cleaned_data/ISIC_2019_Training_Input_cleaned'
labels_csv = '../../cleaned_data/ISIC_2019_Training_GroundTruth_Clean.csv'
metadata_csv = '../../cleaned_data/ISIC_2019_Training_Metadata.csv'

# Ground-truth table keyed by the `image` column; every remaining column is
# treated as one label
df_labels = pd.read_csv(labels_csv).set_index('image')

# Label names and the label matrix share the same column order
label_names = list(df_labels.columns)
encoded_labels = df_labels.to_numpy()  # NumPy array

In [72]:
class ISICDataset(Dataset):
    """Map-style dataset pairing preprocessed images with multi-label targets.

    Args:
        image_dir: Directory holding one ``<image_id>.jpg`` file per row of
            ``df_labels``.
        df_labels: DataFrame indexed by image id; each row is cast to float
            and returned as the label vector.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=image, return_tensors="pt")`` that
            returns a mapping of tensors with a leading batch axis.
        df_metadata: Optional DataFrame indexed by image id whose rows can be
            cast to float. When given, each sample also carries a
            ``metadata`` tensor.
    """

    def __init__(self, image_dir, df_labels, feature_extractor, df_metadata=None):
        self.image_dir = image_dir
        self.df_labels = df_labels
        self.feature_extractor = feature_extractor
        self.df_metadata = df_metadata

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.df_labels)

    def __getitem__(self, idx):
        """Load and preprocess the sample at positional index ``idx``.

        Returns:
            dict: the extractor's outputs with the batch axis removed, a float
            ``labels`` tensor and, when metadata was supplied, a float
            ``metadata`` tensor.

        Raises:
            KeyError: if ``df_metadata`` is set but has no row for the image id.
        """
        image_id = self.df_labels.index[idx]
        image_path = os.path.join(self.image_dir, f"{image_id}.jpg")
        image = default_loader(image_path)

        # Image preprocessing
        encoded = self.feature_extractor(images=image, return_tensors="pt")
        # Remove only the leading batch axis. A bare squeeze() would also
        # collapse any other size-1 axis (e.g. a single-channel or 1-pixel
        # dimension) and silently change the sample's rank.
        inputs = {k: v.squeeze(0) for k, v in encoded.items()}

        # Multi-labels
        labels = torch.tensor(self.df_labels.iloc[idx].values.astype(float), dtype=torch.float)

        # Optional metadata
        # NOTE(review): the stock ViT classifier has no `metadata` input, so a
        # custom model is presumably needed to consume this key -- confirm.
        if self.df_metadata is not None:
            metadata_row = self.df_metadata.loc[image_id].values.astype(float)
            inputs['metadata'] = torch.tensor(metadata_row, dtype=torch.float)

        inputs['labels'] = labels
        return inputs

In [73]:
# ViTFeatureExtractor is deprecated in transformers; ViTImageProcessor is its
# replacement and is called the same way. The variable keeps its original name
# because later cells refer to `feature_extractor`.
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Optional: load the metadata table, indexed by the `image` column
use_metadata = False
df_metadata = pd.read_csv(metadata_csv, index_col='image') if use_metadata else None

dataset = ISICDataset(
    image_dir=image_dir,
    df_labels=df_labels,
    feature_extractor=feature_extractor,
    df_metadata=df_metadata,
)



In [77]:
# Pretrained ViT backbone with one output per label column. The classifier
# weights are not in the checkpoint and are newly initialised, so the model
# must be trained before it is used for prediction.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    problem_type="multi_label_classification",
    num_labels=len(label_names),
)

# Hugging Face Trainer expects dict inputs from dataset
def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    The key set is taken from the first sample; every sample must provide a
    same-shaped tensor under each of those keys.
    """
    collated = {}
    for field in batch[0].keys():
        collated[field] = torch.stack([sample[field] for sample in batch])
    return collated

# Run configuration: output and log directories, logging cadence, epochs and
# batch size.
# Using Trainer with PyTorch requires `accelerate>=0.26.0` in the kernel
# environment; without it an ImportError is raised.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

# Tie together the model, the dataset and the batching function. The image
# preprocessor is handed to Trainer through its `tokenizer` argument.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    data_collator=collate_fn,
    tokenizer=feature_extractor,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>=0.26.0'`

In [59]:
# Start fine-tuning. `trainer` is created in the previous cell, so that cell
# must have run without error in the current kernel session; otherwise this
# line raises NameError.
trainer.train()

NameError: name 'trainer' is not defined