In [2]:
import torch
from datasets import load_dataset
import random
from PIL import ImageDraw, ImageFont, Image
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
import numpy as np
from evaluate import load
import pandas as pd
from datasets import Dataset
import os
from collections import Counter, defaultdict

In [44]:
seam_carved_datasets = False

#metadata_filepath = "StanfordCars/stanford_cars_with_class_names.xlsx"
train_metadata_df = pd.read_excel(
    "StanfordCars/stanford_cars_with_class_names.xlsx", sheet_name='train'
)
test_metadata_df = pd.read_excel(
    "StanfordCars/stanford_cars_with_class_names_test.xlsx", sheet_name='test'
)

train_rg_folder = "StanfordCars/cars_train/cars_train"
train_sc_folder = "StanfordCars/Train"

test_rg_folder = "StanfordCars/cars_test/cars_test"
test_sc_folder = "StanfordCars/Test"

if seam_carved_datasets:
    train_folder = train_sc_folder
    test_folder = test_sc_folder
else: 
    train_folder = train_rg_folder
    test_folder = test_rg_folder

folders = {
    "train": (train_folder, train_metadata_df),
    "test": (test_folder, test_metadata_df),
}

In [45]:
image_label_pairs =[]
class_id_to_label_id = {}
class_name_to_label_id = {}
label_id_to_class_name = {}
next_label_id = 0

for _k, (folder, df) in folders.items():
    image_files = [f for f in os.listdir(folder) if f.lower()
                     .endswith(('png', 'jpg', 'jpeg', 'bmp', 'tiff'))]
    
    for image_file in image_files:
        image_path = os.path.join(folder, image_file)
    
        try:
            with Image.open(image_path) as img:
    
                # There are a few Grayscale images (less than 0.1%) in the dataset 
                # that we do not consider since the downstream ViT expects three 
                # input channels (RGB) for each image
                if img.mode != "RGB":
                    continue
    
                metadata_filename = image_file.replace("_sc", "") if seam_carved_datasets else image_file
                metadata = df.loc[df['image'] == metadata_filename]
                
                if not metadata.empty:
                    class_name = metadata['ture_class_name'].values[0]
                    class_id = metadata['class'].values[0]
                    if class_id not in [1,10,15,25,45,75]:
                       continue 
    
                    if class_id not in class_id_to_label_id:
                        class_id_to_label_id[class_id] = next_label_id
                        class_name_to_label_id[class_name] = next_label_id
                        label_id_to_class_name[next_label_id] = class_name
                        next_label_id += 1
                    
                    label_id = class_id_to_label_id[class_id]
                    image_label_pairs.append((img, label_id))
                else:
                    print(f"Could not fine metadata for {image_file}.") 
        except ValueError as ve:
            print(f"ValueError encountered with image: {image_file}, skipping it.\n")


label_counts = Counter(label_id for _, label_id in image_label_pairs)
for label_id, count in label_counts.items():
    print(f"Label ID {label_id}: {count} images")

Label ID 0: 79 images
Label ID 1: 86 images
Label ID 2: 88 images
Label ID 3: 66 images
Label ID 4: 65 images
Label ID 5: 88 images


In [46]:
# Generate training and test datasets
split_ratio = 0.8

train_pairs = []
test_pairs = []

label_to_pairs = defaultdict(list)
for image, label_id in image_label_pairs:
    label_to_pairs[label_id].append((image, label_id))

for label_id, pairs in label_to_pairs.items():
    #random.shuffle(pairs) 
    sorted_pairs = sorted(
        pairs,
        key=lambda x: x[0].filename.split("/")[-1]  # Extract and sort by the filename
    )
    split_index = int(len(pairs) * split_ratio)  
    train_pairs.extend(pairs[:split_index])    
    test_pairs.extend(pairs[split_index:])      

print(f"Total training pairs: {len(train_pairs)}")
print(f"Total testing pairs: {len(test_pairs)}\n")

train_counts = Counter(label_id for _, label_id in train_pairs)
test_counts = Counter(label_id for _, label_id in test_pairs)

print("Training set distribution:")
for label_id, count in train_counts.items():
    print(f"Label ID {label_id}: {count} images")

print("\nTesting set distribution:")
for label_id, count in test_counts.items():
    print(f"Label ID {label_id}: {count} images")

Total training pairs: 375
Total testing pairs: 97

Training set distribution:
Label ID 0: 63 images
Label ID 1: 68 images
Label ID 2: 70 images
Label ID 3: 52 images
Label ID 4: 52 images
Label ID 5: 70 images

Testing set distribution:
Label ID 0: 16 images
Label ID 1: 18 images
Label ID 2: 18 images
Label ID 3: 14 images
Label ID 4: 13 images
Label ID 5: 18 images


In [47]:
# Generate datasets that can be consumed by HuggingFace APIs

train_images = [image for image, _ in train_pairs]
train_labels = [label for _, label in train_pairs]

train_data_dict = {
    "image": train_images,
    "labels": train_labels,
}
train_dataset = Dataset.from_dict(train_data_dict) # Loading into a HuggingFace Dataset object

test_images = [image for image, _ in test_pairs]
test_labels = [label for _, label in test_pairs]

test_data_dict = {
    "image": test_images,
    "labels": test_labels,
}
test_dataset = Dataset.from_dict(test_data_dict) # Loading into a HuggingFace Dataset object



In [48]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [49]:
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    inputs['labels'] = example_batch['labels']
    return inputs

prepared_train_ds = train_dataset.with_transform(transform)
prepared_test_ds = test_dataset.with_transform(transform)
prepared_train_ds[0]

{'pixel_values': tensor([[[-0.2627, -0.2078, -0.1922,  ...,  0.8824,  0.8824,  0.8824],
          [-0.2471, -0.1608, -0.1451,  ...,  0.8824,  0.8824,  0.8745],
          [-0.2471, -0.1451, -0.1137,  ...,  0.8824,  0.8824,  0.8745],
          ...,
          [ 0.5373,  0.5451,  0.5608,  ...,  0.7961,  0.7961,  0.7961],
          [ 0.5216,  0.5294,  0.5373,  ...,  0.7961,  0.7961,  0.7961],
          [ 0.5216,  0.5294,  0.5294,  ...,  0.7882,  0.7882,  0.7882]],
 
         [[-0.2314, -0.1765, -0.1686,  ...,  0.8902,  0.8902,  0.8902],
          [-0.2157, -0.1294, -0.1216,  ...,  0.8902,  0.8902,  0.8824],
          [-0.2000, -0.1059, -0.0745,  ...,  0.8902,  0.8902,  0.8824],
          ...,
          [ 0.5686,  0.5765,  0.5922,  ...,  0.8039,  0.8039,  0.8039],
          [ 0.5451,  0.5608,  0.5608,  ...,  0.8039,  0.8039,  0.8039],
          [ 0.5451,  0.5529,  0.5529,  ...,  0.7961,  0.7961,  0.7961]],
 
         [[ 0.0510,  0.1451,  0.1608,  ...,  0.9294,  0.9294,  0.9294],
          [ 

In [50]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

metric = load("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

labels = list(class_id_to_label_id.keys())
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=label_id_to_class_name,
    label2id=class_name_to_label_id
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
training_args = TrainingArguments(
    output_dir="./vit-model",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=20,
    fp16=False,
    no_cuda=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train_ds,
    eval_dataset=prepared_test_ds,
    tokenizer=processor,
    data_collator=collate_fn,
)

  trainer = Trainer(


In [52]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss,Accuracy
100,0.0447,0.174697,0.958763
200,0.0198,0.153308,0.958763
300,0.0133,0.151796,0.958763
400,0.0105,0.153833,0.958763


***** train metrics *****
  epoch                    =        20.0
  total_flos               = 541294699GF
  train_loss               =      0.0871
  train_runtime            =  0:17:20.90
  train_samples_per_second =       7.205
  train_steps_per_second   =       0.461


In [53]:
metrics = trainer.evaluate(prepared_test_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =       20.0
  eval_accuracy           =     0.9588
  eval_loss               =     0.1518
  eval_runtime            = 0:00:05.17
  eval_samples_per_second =     18.759
  eval_steps_per_second   =      2.514
