In [70]:
from transformers import ViTModel, ViTImageProcessor
import torch
from torch.utils.data import Dataset
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim 
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.miners import TripletMarginMiner

In [71]:
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
vit = ViTModel.from_pretrained(model_name, torch_dtype=torch.float32)
device = "cuda"

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [72]:
class ViTEmbeddingNet(nn.Module):
    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model

    def forward(self, x):
        outputs = self.vit(x)
        # Use [CLS] token (first token in the sequence) as embedding
        return outputs.last_hidden_state[:, 0]

In [73]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # Converts PIL to Tensor
])

class CustomImageDataset(Dataset):
    def __init__(self, csv_file, transform = None):
        self.data = pd.read_csv(csv_file)
        self.images_dir = "scene_data/train-scene classification/train/"
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path = self.images_dir + self.data.iloc[idx, 0]
        label = int(self.data.iloc[idx, 1])
        
        image = Image.open(image_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

In [74]:
dataset = CustomImageDataset("scene_data/train-scene classification/train.csv", transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)

In [75]:
encoder = ViTEmbeddingNet(vit)

encoder.to(device)

ViTEmbeddingNet(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_f

In [76]:
optimizer = optim.Adam(encoder.parameters(), lr=1e-5)  # Smaller LR for pretrained models
loss_func = TripletMarginLoss(margin=0.2)
miner = TripletMarginMiner(margin=0.2, type_of_triplets="semihard")

In [77]:
from tqdm import tqdm

for images, labels in tqdm(dataloader):
    pass

 16%|█▌        | 43/267 [01:22<07:09,  1.92s/it]


KeyboardInterrupt: 

In [None]:
for images, labels in tqdm(dataloader):
    images = images.to(device)
    labels = labels.to(device)

    embeddings = encoder(images)

    # Use miner to find triplets from labels + embeddings
    mined_triplets = miner(embeddings, labels)

    # Pass embeddings, labels, and mined triplets
    loss = loss_func(embeddings, labels, mined_triplets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()