In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image

from transformers import (
    Trainer,
    TrainingArguments,
    PreTrainedModel,
    PretrainedConfig,
    AutoImageProcessor,
    ViTModel
)

import numpy as np
from sklearn.metrics import accuracy_score


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SiameseDataset(Dataset):
    def __init__(self, root_dir, image_processor, num_pairs=10000):
        """
        Args:
            root_dir (str): Path to folder containing class subfolders.
            image_processor: A Hugging Face image processor for ViT.
            num_pairs (int): Number of pairs (samples) to generate each epoch.
        """
        self.root_dir = root_dir
        self.image_processor = image_processor
        self.num_pairs = num_pairs
        
        # Gather all classes (subfolders)
        self.classes = [d for d in os.listdir(root_dir) 
                        if os.path.isdir(os.path.join(root_dir, d))]
        
        # For each class, gather the list of image paths
        self.class_to_images = {}
        for cls in self.classes:
            cls_path = os.path.join(root_dir, cls)
            images = [
                os.path.join(cls_path, f)
                for f in os.listdir(cls_path)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            ]
            self.class_to_images[cls] = images
        
        # Make a list of class names for easy random selection
        self.class_list = list(self.class_to_images.keys())

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, idx):
        # 1) Randomly pick one class and one image from that class
        class1 = random.choice(self.class_list)
        img1_path = random.choice(self.class_to_images[class1])
        
        # 2) Decide if second image is from the same class (label=1) or different (label=0)
        if random.random() < 0.5:
            class2 = class1
            label = 1
        else:
            class2 = random.choice([c for c in self.class_list if c != class1])
            label = 0
        
        img2_path = random.choice(self.class_to_images[class2])
        
        # 3) Load images
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")
        
        # 4) Apply the image processor (returns dict with "pixel_values")
        enc1 = self.image_processor(img1, return_tensors="pt")
        enc2 = self.image_processor(img2, return_tensors="pt")
        
        pixel_values1 = enc1["pixel_values"].squeeze(0)  # (3, H, W)
        pixel_values2 = enc2["pixel_values"].squeeze(0)  # (3, H, W)
        
        return {
            "pixel_values1": pixel_values1,
            "pixel_values2": pixel_values2,
            "labels": label
        }


In [3]:
class SiameseCollator:
    def __call__(self, features):
        # features is a list of dicts:
        # [
        #   {"pixel_values1": Tensor, "pixel_values2": Tensor, "labels": int},
        #   {"pixel_values1": Tensor, "pixel_values2": Tensor, "labels": int},
        #   ...
        # ]
        
        pixel_values1 = torch.stack([f["pixel_values1"] for f in features], dim=0)
        pixel_values2 = torch.stack([f["pixel_values2"] for f in features], dim=0)
        labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
        
        return {
            "pixel_values1": pixel_values1,  # (batch_size, 3, H, W)
            "pixel_values2": pixel_values2,  # (batch_size, 3, H, W)
            "labels": labels                 # (batch_size,)
        }


In [4]:
class SiameseViTMSNConfig(PretrainedConfig):
    model_type = "vit-msn"

    def __init__(self, model_name="facebook/vit-msn-base", embed_dim=256, **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.embed_dim = embed_dim


class SiameseViTMSN(PreTrainedModel):
    config_class = SiameseViTMSNConfig
    
    def __init__(self, config):
        super().__init__(config)
        
        # Load base ViT
        self.vit = ViTModel.from_pretrained(config.model_name)
        
        hidden_size = self.vit.config.hidden_size
        self.projector = nn.Linear(hidden_size, config.embed_dim)
        self.classifier = nn.Linear(config.embed_dim, 1)
        
        # This is needed for correct saving/loading
        self.post_init()
    
    def forward(
        self,
        pixel_values1=None,
        pixel_values2=None,
        labels=None,
        **kwargs
    ):
        """
        pixel_values1: (batch, 3, H, W)
        pixel_values2: (batch, 3, H, W)
        labels: (batch,) -> 1 if same class, 0 otherwise
        """
        # Pass each set of images through the ViT
        outputs1 = self.vit(pixel_values1)
        outputs2 = self.vit(pixel_values2)
        
        # CLS token is at index 0
        cls1 = outputs1.last_hidden_state[:, 0]  # (batch, hidden_size)
        cls2 = outputs2.last_hidden_state[:, 0]  # (batch, hidden_size)
        
        # Project down (optional)
        proj1 = self.projector(cls1)  # (batch, embed_dim)
        proj2 = self.projector(cls2)  # (batch, embed_dim)
        
        # Absolute difference for Siamese
        diff = torch.abs(proj1 - proj2)  # (batch, embed_dim)
        
        # Binary logits
        logits = self.classifier(diff).squeeze(-1)  # (batch,)
        
        loss = None
        if labels is not None:
            labels = labels.float()
            # BCE with logits for binary classification
            loss = F.binary_cross_entropy_with_logits(logits, labels)
        
        return {
            "loss": loss,
            "logits": logits
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Override to allow a quick creation from a model name without manually
        creating the config. We'll create a default config, then load weights.
        """
        config = SiameseViTMSNConfig(model_name=pretrained_model_name_or_path)
        model = cls(config)
        
        # Replace the internal ViT with the pretrained one
        model.vit = ViTModel.from_pretrained(pretrained_model_name_or_path)
        return model

In [5]:
def compute_metrics(eval_pred):
    """
    eval_pred is a namedtuple: (predictions, label_ids)
    - predictions: shape (batch,) -> raw logits
    - label_ids: shape (batch,)
    """
    logits, labels = eval_pred
    preds = (1 / (1 + np.exp(-logits))) > 0.5  # sigmoid > 0.5
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc}

In [8]:
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-base")
    
    # 2. Create your train & val datasets
train_dataset = SiameseDataset(
    root_dir=r"..\train-dataset\train",
    image_processor=image_processor,
    num_pairs=8000  # Number of pairs to sample each epoch
)
val_dataset = SiameseDataset(
    root_dir=r"..\train-dataset\train",
    image_processor=image_processor,
    num_pairs=2000
)
data_collator = SiameseCollator()

# 4. Load the Siamese model
#    This uses the "facebook/vit-msn-base" as the underlying ViT.
model = SiameseViTMSN.from_pretrained("facebook/vit-msn-base")

You are using a model of type vit_msn to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.
Some weights of ViTModel were not initialized from the model checkpoint at facebook/vit-msn-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You are using a model of type vit_msn to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.
Some weights of ViTModel were not initialized from the model checkpoint at facebook/vit-msn-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
training_args = TrainingArguments(
    output_dir="./siamese_vitmsn",
     eval_strategy="steps",
  num_train_epochs=2,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-5,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to=None,
  load_best_model_at_end=True,
)

# 6. Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [13]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy
100,0.7034,0.693854,0.5165
200,0.6938,0.693589,0.515
300,0.6929,0.691501,0.5055
400,0.6873,0.693225,0.5075
500,0.6946,0.692387,0.5165


KeyboardInterrupt: 