In [247]:
# ViT dataset pre-processing and training notebook

import torch
from datasets import load_dataset
import random
from PIL import ImageDraw, ImageFont, Image
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
import numpy as np
from evaluate import load
import pandas as pd
from datasets import Dataset
import os
from collections import Counter, defaultdict

In [248]:
family_sedans = [2, 3, 4, 5, 20, 23, 24, 26, 29, 35, 44, 47, 49, 51, 61, 63, 66, 67, 73, 79, 93, 97, 144, 152, 96, 
                 105, 115, 117, 129, 134, 135, 136, 137, 138, 140, 156, 162, 164, 165, 167, 173, 176, 177, 181, 182, 
                 184, 185, 187, 188, 194]
suv = [32, 33, 37, 48, 50, 52, 58, 62, 68, 76, 89, 94, 95, 109, 110, 118, 120, 121, 131, 132, 133, 142, 143, 145, 
       146, 147, 148, 149, 154, 155, 159, 186, 189, 195]
coupe = [9, 11, 15, 42, 43, 13, 14, 22, 25, 28, 34, 72, 77, 80, 81, 100, 101, 102, 103, 104, 107, 112, 123, 128, 
         141, 150, 151, 153, 157, 158, 160, 161, 163, 171, 172, 175, 179, 180, 196]
convertible = [8, 10, 39, 12, 21, 27, 31, 36, 38, 55, 59]
misfits = [1, 18, 30, 40, 41, 45, 53, 46, 54, 56, 57, 60, 64, 69, 65, 70, 71, 7, 19, 6, 16, 17, 74, 75, 78, 82, 83, 84, 
           85, 86, 87, 88, 90, 91, 92, 98, 99, 106, 108, 111, 113, 114, 116, 119, 122, 124, 125, 126, 127, 130, 139, 166, 
           168, 169, 170, 174, 178, 183, 190, 191, 192, 193]
random_1 = [16, 40, 1, 54, 46, 7, 41, 60, 18, 57, 5, 2, 29, 26, 73, 4, 47, 3, 44, 66, 59, 10, 8, 21, 55, 31, 150, 11, 151, 100, 80]
random_2 = [1, 10, 15, 25, 45, 75]

"""
Utility function
def populate_random_datasets():
    dataset = []
    dataset.append(random.sample(family_sedans, 2))
    dataset.append(random.sample(suv, 2))
    dataset.append(random.sample(coupe, 2))
    dataset.append(random.sample(convertible, 2))
    dataset.append(random.sample(misfits, 4))
    dataset = [item for sublist in dataset for item in sublist]
    return dataset
    
random_1 = populate_random_datasets()
random_2 = populate_random_datasets()
"""

'\nUtility function\ndef populate_random_datasets():\n    dataset = []\n    dataset.append(random.sample(family_sedans, 2))\n    dataset.append(random.sample(suv, 2))\n    dataset.append(random.sample(coupe, 2))\n    dataset.append(random.sample(convertible, 2))\n    dataset.append(random.sample(misfits, 4))\n    dataset = [item for sublist in dataset for item in sublist]\n    return dataset\n    \nrandom_1 = populate_random_datasets()\nrandom_2 = populate_random_datasets()\n'

In [249]:
seam_carved_datasets = True

#metadata_filepath = "StanfordCars/stanford_cars_with_class_names.xlsx"
train_metadata_df = pd.read_excel(
    "StanfordCars/stanford_cars_with_class_names.xlsx", sheet_name='train'
)
test_metadata_df = pd.read_excel(
    "StanfordCars/stanford_cars_with_class_names_test.xlsx", sheet_name='test'
)

# Update these folders 
train_rg_folder = "StanfordCars/cars_train/cars_train"
#train_sc_folder = "StanfordCars/Train"
train_sc_folder = "StanfordCars/cars_train_sc"

test_rg_folder = "StanfordCars/cars_test/cars_test"
#test_sc_folder = "StanfordCars/Test"
test_sc_folder = "StanfordCars/cars_test_sc"

if seam_carved_datasets:
    train_folder = train_sc_folder
    test_folder = test_sc_folder
else: 
    train_folder = train_rg_folder
    test_folder = test_rg_folder

folders = {
    "train": (train_folder, train_metadata_df),
    "test": (test_folder, test_metadata_df),
}

In [250]:
# Load dataset
image_label_pairs =[]
class_id_to_label_id = {}
class_name_to_label_id = {}
label_id_to_class_name = {}

# Reset labels from 0
next_label_id = 0

for _k, (folder, df) in folders.items():
    image_files = [f for f in os.listdir(folder) if f.lower()
                     .endswith(('png', 'jpg', 'jpeg', 'bmp', 'tiff'))]
    
    for image_file in image_files:
        image_path = os.path.join(folder, image_file)
    
        try:
            with Image.open(image_path) as img:
    
                # There are a few Grayscale images (less than 0.1%) in the dataset 
                # that we do not consider since the downstream ViT expects three 
                # input channels (RGB) for each image
                if img.mode != "RGB":
                    continue
    
                metadata_filename = image_file.replace("_sc", "") if seam_carved_datasets else image_file
                metadata = df.loc[df['image'] == metadata_filename]
                
                if not metadata.empty:
                    class_name = metadata['ture_class_name'].values[0]
                    class_id = metadata['class'].values[0]
                    # replace this with subset, comment out if testing on full dataset
                    if class_id not in random_2: 
                       continue 
    
                    if class_id not in class_id_to_label_id:
                        class_id_to_label_id[class_id] = next_label_id
                        class_name_to_label_id[class_name] = next_label_id
                        label_id_to_class_name[next_label_id] = class_name
                        next_label_id += 1
                    
                    label_id = class_id_to_label_id[class_id]
                    image_label_pairs.append((img, label_id))
                else:
                    print(f"Could not fine metadata for {image_file}.") 
        except ValueError as ve:
            print(f"ValueError encountered with image: {image_file}, skipping it.\n")


label_counts = Counter(label_id for _, label_id in image_label_pairs)
for label_id, count in label_counts.items():
    print(f"Label ID {label_id}: {count} images")

Label ID 0: 86 images
Label ID 1: 66 images
Label ID 2: 88 images
Label ID 3: 88 images
Label ID 4: 65 images
Label ID 5: 79 images


In [251]:
# Generate training and test datasets
split_ratio = 0.8

train_pairs = []
test_pairs = []

label_to_pairs = defaultdict(list)
for image, label_id in image_label_pairs:
    label_to_pairs[label_id].append((image, label_id))

for label_id, pairs in label_to_pairs.items():
    #random.shuffle(pairs) 
    sorted_pairs = sorted(
        pairs,
        key=lambda x: x[0].filename.split("/")[-1]  # Extract and sort by the filename
    )
    split_index = int(len(pairs) * split_ratio)  
    train_pairs.extend(pairs[:split_index])    
    test_pairs.extend(pairs[split_index:])      

print(f"Total training pairs: {len(train_pairs)}")
print(f"Total testing pairs: {len(test_pairs)}\n")

train_counts = Counter(label_id for _, label_id in train_pairs)
test_counts = Counter(label_id for _, label_id in test_pairs)

print("Training set distribution:")
for label_id, count in train_counts.items():
    print(f"Label ID {label_id}: {count} images")

print("\nTesting set distribution:")
for label_id, count in test_counts.items():
    print(f"Label ID {label_id}: {count} images")

Total training pairs: 375
Total testing pairs: 97

Training set distribution:
Label ID 0: 68 images
Label ID 1: 52 images
Label ID 2: 70 images
Label ID 3: 70 images
Label ID 4: 52 images
Label ID 5: 63 images

Testing set distribution:
Label ID 0: 18 images
Label ID 1: 14 images
Label ID 2: 18 images
Label ID 3: 18 images
Label ID 4: 13 images
Label ID 5: 16 images


In [252]:
# Format datasets so that they can be consumed by HuggingFace APIs

train_images = [image for image, _ in train_pairs]
train_labels = [label for _, label in train_pairs]

train_data_dict = {
    "image": train_images,
    "labels": train_labels,
}
train_dataset = Dataset.from_dict(train_data_dict) # Loading into a HuggingFace Dataset object

test_images = [image for image, _ in test_pairs]
test_labels = [label for _, label in test_pairs]

test_data_dict = {
    "image": test_images,
    "labels": test_labels,
}
test_dataset = Dataset.from_dict(test_data_dict) # Loading into a HuggingFace Dataset object

In [253]:
# Load Image processor
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
processor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [254]:
# Transform images so ViT can understand them
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    inputs['labels'] = example_batch['labels']
    return inputs

prepared_train_ds = train_dataset.with_transform(transform)
prepared_test_ds = test_dataset.with_transform(transform)
prepared_train_ds[0]

{'pixel_values': tensor([[[0.6549, 0.6627, 0.7255,  ..., 0.9686, 0.9922, 0.9922],
          [0.6941, 0.7020, 0.7569,  ..., 0.9843, 1.0000, 0.9843],
          [0.7098, 0.7176, 0.7647,  ..., 0.9922, 1.0000, 0.9765],
          ...,
          [0.6471, 0.6941, 0.6941,  ..., 0.8510, 0.8902, 0.8980],
          [0.6314, 0.7176, 0.6784,  ..., 0.8824, 0.9216, 0.9373],
          [0.6392, 0.6706, 0.5843,  ..., 0.8667, 0.9059, 0.9216]],
 
         [[0.4745, 0.4824, 0.5137,  ..., 0.9765, 0.9922, 0.9922],
          [0.5137, 0.5216, 0.5451,  ..., 0.9922, 1.0000, 0.9843],
          [0.5294, 0.5373, 0.5529,  ..., 1.0000, 1.0000, 0.9765],
          ...,
          [0.4980, 0.5608, 0.5451,  ..., 0.8510, 0.8118, 0.8118],
          [0.4667, 0.5765, 0.5137,  ..., 0.8824, 0.8510, 0.8275],
          [0.4745, 0.5059, 0.4196,  ..., 0.8667, 0.8353, 0.8196]],
 
         [[0.2314, 0.2314, 0.2784,  ..., 1.0000, 1.0000, 1.0000],
          [0.2706, 0.2706, 0.3098,  ..., 1.0000, 1.0000, 0.9843],
          [0.2863, 0.286

In [255]:
# Setup model
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

metric = load("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

labels = list(class_id_to_label_id.keys())
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=label_id_to_class_name,
    label2id=class_name_to_label_id
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [256]:
# Setup training parameters and Trainer
training_args = TrainingArguments(
    output_dir="./vit-model",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=20,
    fp16=False,
    no_cuda=True, # If you have access to a GPU, set this to False to speedup training
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train_ds,
    eval_dataset=prepared_test_ds,
    tokenizer=processor,
    data_collator=collate_fn,
)

  trainer = Trainer(


In [257]:
# Train 
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss,Accuracy
100,0.05,0.195372,0.938144
200,0.0202,0.137147,0.969072
300,0.0132,0.131128,0.979381
400,0.0106,0.128859,0.979381


***** train metrics *****
  epoch                    =        20.0
  total_flos               = 541294699GF
  train_loss               =      0.1017
  train_runtime            =  0:15:51.15
  train_samples_per_second =       7.885
  train_steps_per_second   =       0.505


In [258]:
# Eval
metrics = trainer.evaluate(prepared_test_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  epoch                   =       20.0
  eval_accuracy           =     0.9794
  eval_loss               =     0.1289
  eval_runtime            = 0:00:04.02
  eval_samples_per_second =     24.084
  eval_steps_per_second   =      3.228
