In [1]:
import os
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader, Subset

import evaluate
from sklearn.model_selection import train_test_split
from transformers import ViTForImageClassification, ViTFeatureExtractor, Trainer, TrainingArguments

import warnings
warnings.simplefilter('ignore')

In [2]:
class PlaneImageDataset(Dataset):
    def __init__(self, img_dir, feature_extractor, transform=None):
        self.img_dir = img_dir
        self.img_labels = [(f, 0 if 'neg' in f else 1) for f in os.listdir(img_dir)]
        self.feature_extractor = feature_extractor
        self.transform = transform

    def __len__(self):
        return len(self.img_labels) 

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx][0])
        image = Image.open(img_path)
        encoding = self.feature_extractor(images=image, return_tensors="pt")
        item = {key: val.squeeze() for key, val in encoding.items()}
        item['labels'] = torch.tensor(self.img_labels[idx][1])
        return item


In [3]:
model_name = 'google/vit-base-patch16-224'
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
def get_datasets(dataset, train_size=0.7, val_size=0.15, test_size=0.15):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    train_indices, temp_indices = train_test_split(indices, train_size=train_size)
    val_indices, test_indices = train_test_split(temp_indices, test_size=test_size / (val_size + test_size))

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    test_dataset = Subset(dataset, test_indices)

    return train_dataset, val_dataset, test_dataset

# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])

dataset = PlaneImageDataset('train_images', feature_extractor)
train_dataset, val_dataset, test_dataset = get_datasets(dataset)

print(f"Total images: {len(dataset)}")
print(f"Train images: {len(train_dataset)}")
print(f"Val images: {len(val_dataset)}")
print(f"Test images: {len(test_dataset)}")

Total images: 3352
Train images: 2346
Val images: 503
Test images: 503


In [5]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [6]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
)




In [7]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
        compute_metrics=compute_metrics
)


In [8]:
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
1,0.0073,0.013913,0.998012
2,0.0003,0.015561,0.998012
3,0.0005,0.016044,0.998012




TrainOutput(global_step=441, training_loss=0.08615347555073774, metrics={'train_runtime': 249.4667, 'train_samples_per_second': 28.212, 'train_steps_per_second': 1.768, 'total_flos': 5.453886229074985e+17, 'train_loss': 0.08615347555073774, 'epoch': 3.0})

In [9]:
eval_results = trainer.evaluate()
print(eval_results)



{'eval_loss': 0.01604439690709114, 'eval_accuracy': 0.9980119284294234, 'eval_runtime': 11.2075, 'eval_samples_per_second': 44.881, 'eval_steps_per_second': 1.428, 'epoch': 3.0}


In [10]:
model.save_pretrained('./my_model')
feature_extractor.save_pretrained('./my_model')

['./my_model/preprocessor_config.json']

In [11]:
test_predictions = trainer.predict(test_dataset)
test_accuracy = compute_metrics((test_predictions.predictions, test_predictions.label_ids))
print(f"Test accuracy: {test_accuracy['accuracy']}")



Test accuracy: 0.9960238568588469


In [12]:
train_predictions = trainer.predict(train_dataset)
train_accuracy = compute_metrics((train_predictions.predictions, train_predictions.label_ids))
print(f"Train accuracy: {train_accuracy['accuracy']}")



Train accuracy: 0.9987212276214834
