<a href="https://colab.research.google.com/github/allan-jt/IJEPA-Thermal-Benchmark/blob/notebook/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preliminaries



In [None]:
!pip install --upgrade transformers

In [None]:
import requests
from PIL import Image, ImageOps
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from transformers import AutoModel, AutoProcessor

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Hugging Face checkpoint used as the backbone, and the mini-batch size used by
# every DataLoader in this notebook.
model_id, batch_size = "facebook/ijepa_vith14_22k", 32

# Data Loader

In [None]:
import kagglehub

# Fetch the thermal image dataset from Kaggle; `path` is the local directory
# that kagglehub returns.
path = kagglehub.dataset_download(
    "breejeshdhar/thermal-image-dataset-for-object-classification"
)

In [None]:
import os
import shutil

# Move the downloaded dataset into the working directory. The result is
# <cwd>/<basename of path>, which on Colab is the /content/1 used below.
# shutil.move replaces `!mv {path} ...` so that paths containing spaces or
# shell metacharacters are not split or interpreted by the shell.
# The existence guard makes re-running this cell a no-op. Without it, a second
# run would fail or nest a copy inside the existing folder.
dataset_root = os.path.join(os.getcwd(), os.path.basename(os.path.normpath(path)))
if not os.path.exists(dataset_root):
    shutil.move(path, dataset_root)

In [None]:
# The dataset's folders need two fixes before ImageFolder can use them:
#   * Train: the "Car" and "Cat" folders contain each other's images, so swap them.
#   * Test: folder names are lower-case ("car", "cat", "man"). Capitalise them
#     so the class names, and therefore the label indices, match the Train set.
# Both fixes are guarded so that re-running this cell is a no-op. Otherwise a
# re-run would silently swap Car/Cat back (corrupting the labels) or crash on
# folders that were already renamed.
train_dir = "/content/1/Thermal Image Dataset/SeekThermal/Train"
test_dir = "/content/1/Thermal Image Dataset/SeekThermal/Test"

# Marker file written once the swap is done. It is a plain file in train_dir,
# not a sub-directory, so it is not picked up as a class folder.
swap_marker = os.path.join(train_dir, ".car_cat_swapped")

if not os.path.exists(swap_marker):
    car_folder = os.path.join(train_dir, "Car")
    cat_folder = os.path.join(train_dir, "Cat")
    temp_folder = os.path.join(train_dir, "TempFolder")
    os.rename(car_folder, temp_folder)  # "Car" -> "TempFolder"
    os.rename(cat_folder, car_folder)   # "Cat" -> "Car"
    os.rename(temp_folder, cat_folder)  # "TempFolder" -> "Cat"
    open(swap_marker, "w").close()

# Compare exact names from listdir rather than os.path.exists, so the check is
# case-exact even on case-insensitive filesystems.
existing_test_entries = set(os.listdir(test_dir))
for lower_name in ("car", "cat", "man"):
    if lower_name in existing_test_entries:
        os.rename(
            os.path.join(test_dir, lower_name),
            os.path.join(test_dir, lower_name.capitalize()),
        )

In [None]:
# Remove test data from train data folder: any Train image whose filename also
# appears in the same class folder of Test is deleted from Train.

# Subdirectories
categories = ["Cat", "Car", "Man"]

for category in categories:
    train_path = os.path.join(train_dir, category)
    # Filenames present in both the Train and Test folder of this class.
    shared_names = set(os.listdir(train_path)).intersection(
        os.listdir(os.path.join(test_dir, category))
    )
    for shared_name in shared_names:
        os.remove(os.path.join(train_path, shared_name))

print("Cleanup complete!")


In [None]:
# Count the number of images we have in each class in the training data

def count_only_files(folder_path):
    """Return how many regular files (sub-directories excluded) are directly in folder_path."""
    return len([f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))])

def get_file_count(root_dir=None, class_names=None):
    """Count the files in each class sub-folder of root_dir.

    Args:
        root_dir: directory holding one sub-folder per class. Defaults to the
            notebook-level ``train_dir``.
        class_names: class sub-folder names to count. Defaults to the
            notebook-level ``categories``.

    Returns:
        (min_count, counts): the smallest per-class file count (0 when
        class_names is empty) and a dict mapping class name -> file count.
    """
    if root_dir is None:
        root_dir = train_dir
    if class_names is None:
        class_names = categories
    counts = {c: count_only_files(os.path.join(root_dir, c)) for c in class_names}
    # min() replaces the old hard-coded 10000 sentinel. That sentinel reported
    # 10000 as the minimum whenever every class had more than 10000 images.
    return min(counts.values(), default=0), counts

# Report the per-class image counts of the training data (and their minimum).
min_img_count, class_to_img_count = get_file_count()
print(f'Mininum file count: {min_img_count}')
for class_name, img_count in class_to_img_count.items():
    print(f'{class_name} : {img_count} images')

In [None]:
# Balance training data if needed by removing images from classes with more images

import random
def remove_random_files(folder_path, num_files_to_remove, rng=None):
    """Delete ``num_files_to_remove`` randomly chosen regular files from folder_path.

    Only regular files directly inside folder_path are candidates. If more files
    are requested than exist, an error message is printed and nothing is
    deleted. A non-positive count is a no-op.

    Args:
        folder_path: directory to thin out.
        num_files_to_remove: how many files to delete.
        rng: optional ``random.Random`` instance. Defaults to the global
            ``random`` module, so ``random.seed(...)`` still controls the choice.
    """
    if num_files_to_remove <= 0:
        return

    # Sort so the sample depends only on the RNG state, not on the
    # filesystem-dependent order returned by os.listdir.
    files = sorted(
        f for f in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, f))
    )

    if num_files_to_remove > len(files):
        print("Error: Number of files to remove exceeds the number of files in the folder.")
        return

    sampler = random if rng is None else rng
    files_to_remove = sampler.sample(files, num_files_to_remove)

    for file_name in files_to_remove:
        os.remove(os.path.join(folder_path, file_name))

    print(f"Removed {len(files_to_remove)} images from {folder_path}")

# Remove images from a class to make the dataset more balanced.
# Recount first, so that re-running this cell does not delete more files based
# on stale counts from the earlier counting cell. Seed Python's RNG so the
# random selection is repeatable for the same file listing.
random.seed(42)
min_img_count, class_to_img_count = get_file_count()
for c in categories:
    folder_path = os.path.join(train_dir, c)
    num_files_to_remove = class_to_img_count[c] - min_img_count
    if num_files_to_remove > 0:
        remove_random_files(folder_path, num_files_to_remove)

In [None]:
# Check that the training data is now balanced: every class should report the
# same count as the minimum.

min_img_count, class_to_img_count = get_file_count()
print(f'Mininum file count: {min_img_count}')
for category_name in class_to_img_count:
    print(f'{category_name} : {class_to_img_count[category_name]} images')

In [None]:
# Preprocessing pipeline: undo any EXIF rotation, then let the model's own
# Hugging Face processor turn the PIL image into the pixel_values tensor.
processor = AutoProcessor.from_pretrained(model_id)


def fix_exif_orientation(image):
    """Return `image` transposed according to its EXIF orientation tag (if any)."""
    return ImageOps.exif_transpose(image)


def apply_processor(image):
    """Run the HF processor on one PIL image and drop the leading batch dimension."""
    return processor(image, return_tensors="pt")["pixel_values"].squeeze(0)


transform = transforms.Compose([
    transforms.Lambda(fix_exif_orientation),  # Fix EXIF orientation
    transforms.Lambda(apply_processor),       # Apply processor
])

In [None]:
# Load the dataset.
# Train and Test must each keep their original layout: one sub-folder per class.
# Reuse train_dir/test_dir so the dataset paths are defined in a single place.
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Labels are only comparable across splits if both map class names to the same
# indices (this fails, for example, if the Test folders were not capitalised).
assert train_dataset.class_to_idx == test_dataset.class_to_idx, (
    train_dataset.class_to_idx, test_dataset.class_to_idx)

In [None]:
# Hold out 20% of the (balanced) training images for validation; train on the rest.
n_total = len(train_dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size

# Seeding the global torch RNG fixes which images land in each split.
torch.manual_seed(42)
train_data, val_data = random_split(train_dataset, lengths=[train_size, val_size])

In [None]:
# One DataLoader per split; only the training loader reshuffles every epoch.
loader_kwargs = dict(batch_size=batch_size, num_workers=1)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Sanity-check split sizes and the class-name -> label-index mapping.
n_train, n_val, n_test = len(train_data), len(val_data), len(test_dataset)
print(f"Train data size: {n_train}")
print(f"Validation data size: {n_val}")
print(f"Testing data size: {n_test}")
print(train_dataset.class_to_idx)

In [None]:
# Verify a batch from train_loader
from PIL import Image, ImageOps
from IPython.display import display

# Pull one (shuffled) batch and inspect its shape, labels and first sample.
images, labels = next(iter(train_loader))
print(f"Batch image shape: {images.shape}")  # [batch, channels, H, W]
print(f"Batch labels: {labels}")            # e.g. tensor([0, 1, ...])
print(images[0])
print(labels[0])

# Model: Frozen Backbone + Linear Classification Head

In [None]:
class ClassificationHead(nn.Module):
    """Single linear layer mapping a feature vector to per-class logits.

    The layer lives in ``self.model``, so its state_dict keys are
    ``model.weight`` and ``model.bias``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.model = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        logits = self.model(x)
        return logits

In [None]:
class TransformerWithClassificationHead(nn.Module):
    """Frozen vision backbone followed by a trainable linear head (linear probe).

    Args:
        vit: backbone module. Must expose ``config.hidden_size`` and return an
            object whose ``last_hidden_state`` is (batch, tokens, hidden).
        output_dim: number of classes.
        withLayerNorm: if True, apply a trainable LayerNorm to the pooled
            features before the head.
        pooling: how token features are reduced to one vector per image.
            ``"mean"`` (default) averages all tokens. ``"first"`` takes token 0,
            which was the previous behaviour.

    Raises:
        ValueError: if ``pooling`` is not ``"mean"`` or ``"first"``.
    """

    def __init__(self, vit, output_dim, withLayerNorm=False, pooling="mean"):
        super(TransformerWithClassificationHead, self).__init__()
        if pooling not in ("mean", "first"):
            raise ValueError(f"pooling must be 'mean' or 'first', got {pooling!r}")
        self.pooling = pooling
        self.vit = vit
        self.vit.requires_grad_(False)
        self.layer_norm = None
        if withLayerNorm:
            self.layer_norm = nn.LayerNorm(self.vit.config.hidden_size)
        self.classification_head = ClassificationHead(
            self.vit.config.hidden_size,
            output_dim,
        )
        self.classification_head.requires_grad_(True)

    def train(self, mode=True):
        """Set train/eval mode on the head while always keeping the frozen backbone in eval mode.

        The backbone's weights are never updated. Letting ``model.train()``
        switch it to training mode would only enable its stochastic layers
        (e.g. dropout, if it has any) and add noise to the probed features.
        """
        super().train(mode)
        self.vit.eval()
        return self

    def forward(self, pixel_values):
        tokens = self.vit(pixel_values).last_hidden_state  # (batch, tokens, hidden)
        # Hugging Face I-JEPA encoders have no [CLS] token, so token 0 is just
        # the first image patch. Averaging over all tokens gives an image-level
        # embedding, which is the pooling used in the HF I-JEPA docs.
        if self.pooling == "mean":
            features = tokens.mean(dim=1)
        else:
            features = tokens[:, 0, :]
        if self.layer_norm is not None:
            features = self.layer_norm(features)
        return self.classification_head(features)

# Load Backbone and Set Up Training

In [None]:
# Load the pretrained encoder weights for `model_id` as a base model.
backbone = AutoModel.from_pretrained(pretrained_model_name_or_path=model_id)

In [None]:
# Wrap the backbone with a linear head sized to the classes ImageFolder found,
# instead of a hard-coded 3, so the head always matches the labels.
num_classes = len(train_dataset.classes)
model = TransformerWithClassificationHead(
    backbone,
    num_classes,
)

In [None]:
import torch.optim as optim
from tqdm import tqdm

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training epoch over `dataloader`.

    Args:
        model: module mapping a batch of inputs to class logits.
        dataloader: iterable of (inputs, labels) batches.
        optimizer: optimizer over the trainable parameters.
        criterion: batch loss. Assumed to use reduction="mean", as
            nn.CrossEntropyLoss() does by default.
        device: device each batch is moved to.

    Returns:
        (mean_loss, accuracy_percent) over all samples seen. Both are NaN if the
        loader yields no samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    for inputs, labels in tqdm(dataloader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # Weight each batch's mean loss by its size so a smaller final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * n_batch
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += n_batch

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, 100.0 * correct / total

In [None]:
def evaluate_epoch(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` without tracking gradients.

    Args:
        model: module mapping a batch of inputs to class logits.
        dataloader: iterable of (inputs, labels) batches.
        criterion: batch loss. Assumed to use reduction="mean", as
            nn.CrossEntropyLoss() does by default.
        device: device each batch is moved to.

    Returns:
        (mean_loss, accuracy_percent) over all samples seen. Both are NaN if the
        loader yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            n_batch = labels.size(0)
            # Size-weighted so the average is per sample, not per batch.
            loss_sum += loss.item() * n_batch
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += n_batch

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, 100.0 * correct / total

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Only parameters with requires_grad (the head, plus the optional LayerNorm) are
# trained, so hand just those to Adam.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)
criterion = nn.CrossEntropyLoss()

train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

# Training parameters
num_epochs = 10
# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. With 0 and the strict ">" below, a run whose validation accuracy
# stayed at 0% never saved one, and loading best_model.pth later failed.
best_accuracy = float("-inf")
best_epoch = 0
best_model_path = "best_model.pth"

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Train
    train_loss, train_accuracy = train_epoch(
        model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

    # Evaluate
    val_loss, val_accuracy = evaluate_epoch(
        model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # Save the best model (by validation accuracy)
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_epoch = epoch
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved!")

print(f"Training complete. Best Validation Accuracy: {best_accuracy:.2f}%")

# Results

In [None]:
# Evaluate the final-epoch weights (not yet the best checkpoint) on the test set.
test_loss, test_accuracy = evaluate_epoch(
    model=model, dataloader=test_loader, criterion=criterion, device=device
)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

In [None]:
# Test set: reload the checkpoint with the best validation accuracy and evaluate it.
best_model_path = "best_model.pth"
best_state = torch.load(best_model_path, weights_only=True)
model.load_state_dict(best_state)
test_loss, test_accuracy = evaluate_epoch(
    model=model, dataloader=test_loader, criterion=criterion, device=device
)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

In [None]:
# Side-by-side loss and accuracy curves. The red dot marks the best checkpoint's
# test result at the epoch it was saved.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(20, 7))

panels = [
    (loss_ax, train_losses, val_losses, test_loss,
     'Losses Over Epochs', 'Loss', (0, 1.3)),
    (acc_ax, train_accuracies, val_accuracies, test_accuracy,
     'Accuracies Over Epochs', 'Accuracy', (30, 100)),
]

for ax, train_curve, val_curve, test_value, title, y_label, y_limits in panels:
    ax.plot(train_curve, label='Train', marker='o', color='blue')
    ax.plot(val_curve, label='Validation', marker='o', color='orange')
    ax.scatter([best_epoch], [test_value], color='red', label='Test', zorder=5)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True)
    ax.set_ylim(*y_limits)

plt.show()