In [1]:
import os
from transformers import ViTForImageClassification, ViTImageProcessor, TrainingArguments, Trainer
from torchvision.transforms import v2
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset_dir = "dataset"

# Load training dataset
train_dataset = load_dataset('imagefolder', data_dir=os.path.join(dataset_dir, 'train'))['train']
# Load validation dataset
val_dataset = load_dataset('imagefolder', data_dir=os.path.join(dataset_dir, 'valid'))['train']
# Load test dataset
test_dataset = load_dataset('imagefolder', data_dir=os.path.join(dataset_dir, 'test'))['train']

Downloading data: 100%|██████████| 6115/6115 [00:00<00:00, 118389.10files/s]
Generating train split: 6115 examples [00:00, 49821.13 examples/s]
Downloading data: 100%|██████████| 1833/1833 [00:00<00:00, 1077074.70files/s]
Generating train split: 1833 examples [00:00, 48982.89 examples/s]
Downloading data: 100%|██████████| 854/854 [00:00<00:00, 1016728.82files/s]
Generating train split: 854 examples [00:00, 46385.51 examples/s]


In [4]:
# Load pre-trained ViT image processor
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Get image mean and std
image_mean, image_std = processor.image_mean, processor.image_std

# Define transformations
train_transform = v2.Compose([
    v2.Resize((processor.size["height"], processor.size["width"])),
    v2.RandomHorizontalFlip(p=0.1),
    v2.RandomVerticalFlip(p=0.1),
    v2.RandomApply([v2.RandomRotation(degrees=30)], p=0.25),
    v2.RandomApply([v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)], p=0.25),
    v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5, 9))], p=0.25),
    v2.ToTensor(),
    v2.Normalize(mean=image_mean, std=image_std)
])

test_transform = v2.Compose([
    v2.Resize((processor.size["height"], processor.size["width"])),
    v2.ToTensor(),
    v2.Normalize(mean=image_mean, std=image_std)
])



In [5]:
# Define the preprocessing function
def preprocess_data(examples, transform):
    examples['pixel_values'] = [transform(image) for image in examples['image']]
    return examples

# Apply preprocessing to the datasets
train_dataset = train_dataset.map(lambda examples: preprocess_data(examples, train_transform), batched=True)
val_dataset = val_dataset.map(lambda examples: preprocess_data(examples, test_transform), batched=True)
test_dataset = test_dataset.map(lambda examples: preprocess_data(examples, test_transform), batched=True)

Map: 100%|██████████| 6115/6115 [00:16<00:00, 372.12 examples/s]
Map: 100%|██████████| 1833/1833 [00:03<00:00, 510.91 examples/s]
Map: 100%|██████████| 854/854 [00:01<00:00, 507.45 examples/s]


In [6]:
# Initialize the model
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=203, ignore_mismatched_sizes=True)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    optim="adamw_torch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    weight_decay=0.001,
)

# Initialize the trainer
trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=val_dataset,                 
)

trainer.train()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([203]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([203, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                 
100%|██████████| 192/192 [14:11<00:00,  4.44s/it]

{'eval_loss': 5.0429840087890625, 'eval_runtime': 127.5988, 'eval_samples_per_second': 14.365, 'eval_steps_per_second': 0.455, 'epoch': 1.0}
{'train_runtime': 851.8167, 'train_samples_per_second': 7.179, 'train_steps_per_second': 0.225, 'train_loss': 5.193100293477376, 'epoch': 1.0}





TrainOutput(global_step=192, training_loss=5.193100293477376, metrics={'train_runtime': 851.8167, 'train_samples_per_second': 7.179, 'train_steps_per_second': 0.225, 'total_flos': 4.7471718134486016e+17, 'train_loss': 5.193100293477376, 'epoch': 1.0})

In [None]:
evaluation_results = trainer.evaluate(eval_dataset=test_dataset)
print(evaluation_results)