# Load PEFT Model

In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel,AutoImageProcessor
from peft import PeftModel

class AutoModelForImageEmbedding(nn.Module):
    """Wrap a Hugging Face vision backbone so that ``forward`` returns image embeddings.

    The embedding is the backbone's ``pooler_output``, optionally L2-normalised
    along dim 1. Attributes that are not found on the wrapper (e.g. ``config``)
    are looked up on the wrapped backbone, so the wrapper can stand in for the
    backbone when it is wrapped by ``peft.PeftModel``.

    Args:
        model_name: Hub id or local path passed to ``AutoModel.from_pretrained``.
        normalize: If True, L2-normalise each embedding row in ``forward``.
    """

    def __init__(self, model_name, normalize=True):
        super().__init__()

        # Load a pre-trained backbone (e.g. a Vision Transformer or similar).
        # transformers newly initialises weights that are missing from the
        # checkpoint; for google/vit-large-patch16-224 that is the pooler.
        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = normalize

    def forward(self, images):
        """Embed a batch of images.

        Args:
            images: pixel-value tensor passed positionally to the backbone;
                presumably (batch, channels, height, width) - TODO confirm.

        Returns:
            The backbone's ``pooler_output``, L2-normalised along dim 1 when
            ``self.normalize`` is True.
        """
        model_output = self.model(images)
        pooler_output = model_output['pooler_output']

        if self.normalize:
            pooler_output = torch.nn.functional.normalize(pooler_output, p=2, dim=1)

        return pooler_output

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped backbone has
                ``name``, or if the backbone has not been registered yet.
        """
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # Guard against unbounded recursion: if ``self.model`` is not
            # registered yet (e.g. an instance created via ``__new__``), the
            # lookup of ``self.model`` below would re-enter this method forever.
            if name == 'model':
                raise
            return getattr(self.model, name)
        

In [2]:
# Base checkpoint on the Hugging Face Hub, and the local directory that holds
# the fine-tuned PEFT adapter (and saved image processor) for that checkpoint.
pretrain_model_name = 'google/vit-large-patch16-224'
model_name = f'model_tif_3/{pretrain_model_name}'

In [3]:
# Seed torch before building the model. The load warning shows the ViT pooler
# ('vit.pooler.dense.*') is newly initialised here. Unless the adapter
# checkpoint restores it (NOTE(review): confirm), the embeddings would
# otherwise change on every kernel restart.
SEED = 42
torch.manual_seed(SEED)

# base model
model = AutoModelForImageEmbedding(pretrain_model_name)

# Load the trained PEFT adapter on top of the base model (inference only).
model = PeftModel.from_pretrained(model, model_name)
# Make inference mode explicit: the LoRA layers carry Dropout(p=0.1), which
# must be inactive when computing embeddings.
model.eval()

print(model)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PeftModel(
  (base_model): LoraModel(
    (model): AutoModelForImageEmbedding(
      (model): ViTModel(
        (embeddings): ViTEmbeddings(
          (patch_embeddings): ViTPatchEmbeddings(
            (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
          )
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): ViTEncoder(
          (layer): ModuleList(
            (0-23): 24 x ViTLayer(
              (attention): ViTSdpaAttention(
                (attention): ViTSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=16, bias=False)
                    )
                    (lora_B): ModuleDict(
       

# Image Processor and Evaluation Transforms

In [4]:
# Load the image processor saved next to the adapter; its size/mean/std drive the eval transforms.
image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path=model_name)

In [5]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Evaluation-time preprocessing built from the processor's settings:
# resize and centre-crop using size["height"], convert to a tensor, then
# normalise with the processor's mean/std.
# NOTE(review): per torchvision docs an int Resize scales the shorter edge,
# unlike a fixed (H, W) resize; confirm this matches how the model was trained.
normalize = Normalize(image_processor.image_mean, image_processor.image_std)

val_transforms = Compose(
    [Resize(image_processor.size["height"]), CenterCrop(image_processor.size["height"])]
    + [ToTensor(), normalize]
)



def preprocess_val(example_batch):
    """Attach evaluation-ready pixel tensors to a batch of images.

    Each entry of ``example_batch["image"]`` is converted to RGB and passed
    through ``val_transforms``. The results are stored in place under
    ``"pixel_values"``, and the same (mutated) batch is returned.
    """
    transformed = []
    for img in example_batch["image"]:
        transformed.append(val_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = transformed
    return example_batch

# Load Experimental Images and Compute Embeddings

In [6]:
import os
import pandas as pd

In [7]:
experimental_data_folder = 'data/experimental_data/1223/images'

# List the sample directories once, in sorted order. os.listdir returns
# entries in arbitrary order, so without sorting the key order of the JSON
# files written later would not be stable across runs. Only directories are
# kept because each sample's images live at '<name>/<name>_<suffix>.tif'.
# NOTE(review): the meaning of the 'D' exclusion is not visible here; it is
# kept exactly as in the original selection.
sample_names = sorted(
    x for x in os.listdir(experimental_data_folder)
    if 'D' not in x and os.path.isdir(os.path.join(experimental_data_folder, x))
)
file_path_list = [os.path.join(experimental_data_folder, x, f'{x}_1.tif') for x in sample_names]
file_path_list2 = [os.path.join(experimental_data_folder, x, f'{x}_2.tif') for x in sample_names]
file_path_list3 = [os.path.join(experimental_data_folder, x, f'{x}_Merge.tif') for x in sample_names]

In [8]:
from PIL import Image
from tqdm.notebook import tqdm
Embedding_dict = {}

for v in tqdm(file_path_list):

    # Close the file deterministically: Pillow may keep the handle open after
    # loading (notably for multi-frame formats such as TIFF).
    with Image.open(v) as image:
        encoding = val_transforms(image.convert("RGB"))

    with torch.no_grad():
        outputs = model(encoding.unsqueeze(0))

    # Key = file name up to the first '_'. os.path.basename (rather than
    # split('/')) stays correct on Windows, where os.path.join inserts '\\'.
    Embedding_dict[os.path.basename(v).split('_')[0]] = outputs.numpy().flatten().tolist()
    

  0%|          | 0/134 [00:00<?, ?it/s]

In [9]:
import json

# Persist the embeddings built in the previous cell as a single JSON object.
with open('data/experiment_embedding_data_tif.json', 'w') as fh:
    fh.write(json.dumps(Embedding_dict))

In [10]:
from PIL import Image
from tqdm.notebook import tqdm
Embedding_dict = {}

for v in tqdm(file_path_list2):

    # Close the file deterministically: Pillow may keep the handle open after
    # loading (notably for multi-frame formats such as TIFF).
    with Image.open(v) as image:
        encoding = val_transforms(image.convert("RGB"))

    with torch.no_grad():
        outputs = model(encoding.unsqueeze(0))

    # Key = file name up to the first '_'. os.path.basename (rather than
    # split('/')) stays correct on Windows, where os.path.join inserts '\\'.
    Embedding_dict[os.path.basename(v).split('_')[0]] = outputs.numpy().flatten().tolist()
    

  0%|          | 0/134 [00:00<?, ?it/s]

In [11]:
import json

# Persist the embeddings built in the previous cell as a single JSON object.
with open('data/experiment_embedding_data_tif2.json', 'w') as fh:
    fh.write(json.dumps(Embedding_dict))

In [12]:
from PIL import Image
from tqdm.notebook import tqdm
Embedding_dict = {}

for v in tqdm(file_path_list3):

    # Close the file deterministically: Pillow may keep the handle open after
    # loading (notably for multi-frame formats such as TIFF).
    with Image.open(v) as image:
        encoding = val_transforms(image.convert("RGB"))

    with torch.no_grad():
        outputs = model(encoding.unsqueeze(0))

    # Key = file name up to the first '_'. os.path.basename (rather than
    # split('/')) stays correct on Windows, where os.path.join inserts '\\'.
    Embedding_dict[os.path.basename(v).split('_')[0]] = outputs.numpy().flatten().tolist()
    

  0%|          | 0/134 [00:00<?, ?it/s]

In [13]:
import json

# Persist the embeddings built in the previous cell as a single JSON object.
with open('data/experiment_embedding_data_tif3.json', 'w') as fh:
    fh.write(json.dumps(Embedding_dict))