# Generate Image Embeddings for the Duke-WLOA-AMD Dataset

- The original dataset is in .mat format. For ease of access, we converted the images and annotations to .png and tried extracting embeddings from those. However, the conversion to .png might have introduced noise or artifacts, and that approach is not working.
- Here I first preprocess the .mat data files directly and then extract the embeddings.



### Table of Contents: <a id = 'table_of_contents'></a>
0. [imports](#imports)
1. [Data Loading](#dataload)
2. [Data Preprocessing & Transform](#dataprocess)
3. [Load Model and Extract Embedding](#load)
4. [Save Embeddings](#save)


## 0. Imports <a id ='imports'></a>
[Back to top](#table_of_contents)

In [7]:
import os 
import torch
import scipy as sp 
import scipy.io as sio 
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from transformers import ViTModel, ViTImageProcessor

  from .autonotebook import tqdm as notebook_tqdm


## 1. Data Loading <a id = 'dataload'></a> 
[Back to top](#table_of_contents)

In [8]:
# Directory containing the .mat files (machine-specific absolute path).
data_dir = '/home/suraj/Data/Duke_WLOA_RL_Annotated/AMD'

In [9]:
## SESSION SETUP
# Collect only the .mat volumes, sorted by path. os.listdir() returns entries in
# arbitrary order and may include other files, so without the filter and the sort
# data_files[i] could point at a different (or non-.mat) file from one run to the next.
data_files = sorted(
    os.path.join(data_dir, f)
    for f in os.listdir(data_dir)
    if f.lower().endswith('.mat') and os.path.isfile(os.path.join(data_dir, f))
)  # List of data files to process
output_dir = os.path.join(data_dir, 'embeddings')  # Directory to save the output files
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Read .mat files one after another: call this once per file.
def read_mat_file(file_path) -> "dict | None":
    """Read a MATLAB .mat file and return its contents as a dictionary.

    Args:
        file_path (str): Path to the .mat file.

    Returns:
        dict | None: The dictionary produced by ``scipy.io.loadmat`` (loaded
        with ``squeeze_me=True``, so unit dimensions are squeezed out of the
        arrays), or ``None`` if the file could not be read.
    """
    try:
        return sp.io.loadmat(file_path, squeeze_me=True)
    except Exception as e:
        # Deliberately best-effort: report the problem and signal failure with
        # None so the caller can decide whether to skip the file or stop.
        print(f"Error reading {file_path}: {e}")
        return None

## 2. Data Preprocessing & Transformation <a id = 'dataprocess'> </a>
[Back to top](#table_of_contents)

In [15]:
# Read one subject's volume; replace 0 with the index of the file you want to read.
subject_dct = read_mat_file(data_files[0])
if subject_dct is None:
    # read_mat_file() returns None when loading fails. Stop here with a clear message
    # instead of failing on the subscript below with a confusing TypeError.
    raise RuntimeError(f"Could not load {data_files[0]}")
images = subject_dct["images"]  # Shape: (512, 1000, 100); 512 is the height, 1000 is the width, and 100 is the number of b-scans
layer_maps = subject_dct["layerMaps"]  # Shape: (100, 1000, 3); # 100 is the number of b-scans, 1000 is the width, and 3 is the number of layers (ILM, RPE, and BR)
# Last expression of the cell, so the notebook displays the available keys.
subject_dct.keys()

dict_keys(['__header__', '__version__', '__globals__', 'images', 'layerMaps', 'Age'])

In [16]:
# Report the shapes of the image volume and of its layer annotations.
for label, array in (("Images shape:", images), ("Layer maps shape:", layer_maps)):
    print(label, array.shape)

Images shape: (512, 1000, 100)
Layer maps shape: (100, 1000, 3)


In [17]:
# preprocess each b-scan in one .mat file


In [18]:
# Preprocessing applied to every B-scan image before it is passed to the ViT.
transform = transforms.Compose([
    # InterpolationMode is the enum torchvision documents for this argument;
    # it selects the same bilinear resampling as the PIL constant did.
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    # The google/vit-base-patch16-224 checkpoint is preprocessed with mean = std = 0.5
    # per channel (see its model card / image-processor config), not with the
    # ImageNet statistics (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225).
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [19]:
# Rescale each B-scan to 8-bit, replicate it to 3 channels (RGB), apply `transform`,
# and stack the resulting tensors into a single tensor.
b_scans = []
for i in range(images.shape[2]):
    b_scan = images[:, :, i]
    peak = b_scan.max()
    # Scale so the brightest pixel maps to 255. A blank (all-zero) B-scan is kept as
    # zeros instead of dividing by zero, which would produce NaN and then undefined
    # values after the uint8 cast.
    scaled = b_scan / peak * 255 if peak > 0 else np.zeros_like(b_scan)
    b_scan_rgb = np.stack([scaled.astype(np.uint8)] * 3, axis=-1)
    img = Image.fromarray(b_scan_rgb)
    b_scans.append(transform(img))
b_scans = torch.stack(b_scans)

## 3. Load Model and Extract Embeddings <a id = 'load'></a>
[Back to top](#table_of_contents)

In [None]:
# 3. Load Model
# The ViT backbone and its image processor are both loaded from the same checkpoint.
model = ViTModel.from_pretrained('google/vit-base-patch16-224')
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# Inference mode; as the last expression of the cell this also displays the model.
model.eval()

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTSdpaAttention(
          (attention): ViTSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUAct

In [None]:
# 4. Extract Embeddings
# One forward pass per B-scan; the final-layer hidden state at sequence position 0
# is kept as that scan's embedding.
embeddings = []
with torch.no_grad():
    for scan in b_scans:
        batch = scan[None]  # add the leading batch dimension
        first_token = model(pixel_values=batch).last_hidden_state[0, 0]
        embeddings.append(first_token.numpy())

In [None]:
# Turn the list of per-scan vectors into a single NumPy array.
embeddings = np.asarray(embeddings)

## 4. Save Embeddings <a id = 'save'></a>
[Back to top](#table_of_contents)

In [None]:
# 4. Save
# Write into output_dir (the embeddings directory created during session setup)
# rather than into whatever the current working directory happens to be.
np.save(os.path.join(output_dir, 'b_scan_embeddings.npy'), embeddings)
# Disabled: `volume_embedding` is not defined in the cells shown here.
#np.save('volume_embedding.npy', volume_embedding)

In [None]:

# Load model and preprocessor
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTModel.from_pretrained('google/vit-base-patch16-224')
model.eval()

# Preprocessing applied to every B-scan image before it is passed to the ViT.
transform = transforms.Compose([
    # InterpolationMode is the enum torchvision documents for this argument;
    # it selects the same bilinear resampling as the PIL constant did.
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    # Take the normalization statistics from the checkpoint's own processor so they
    # match what the model expects, instead of hard-coding the ImageNet mean/std.
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

# List of .mat files
mat_files = [f'path_to_data/patient_{i:03d}.mat' for i in range(1, 201)]  # 200 files



NameError: name 'ViTImageProcessor' is not defined

In [None]:
import torch

# For every .mat file: load the volume, preprocess each B-scan, run the model, and
# save one embedding array per patient as 'patient_<id>_embeddings.npy'.
for mat_file in mat_files:
    # Patient ID is the file stem without its 'patient_' prefix
    # (e.g. 'path_to_data/patient_001.mat' -> '001'). Working on the basename keeps a
    # 'patient_' inside a directory name from being mistaken for the prefix, and a
    # file name without the prefix falls back to the whole stem instead of raising.
    stem = os.path.splitext(os.path.basename(mat_file))[0]
    patient_id = stem.split('patient_')[-1]

    # Load data
    mat_data = sio.loadmat(mat_file)
    # The dataset's .mat files store the volume under 'images' (see the keys listed
    # earlier in this notebook); 'image' is kept as a fallback for other files.
    image = mat_data['images'] if 'images' in mat_data else mat_data['image']  # (512, 1000, 100)

    # Process B-scans
    b_scans = []
    for i in range(image.shape[2]):
        b_scan = image[:, :, i]
        peak = b_scan.max()
        # Scale so the brightest pixel maps to 255. A blank (all-zero) B-scan is kept
        # as zeros instead of dividing by zero, which would produce NaN.
        scaled = b_scan / peak * 255 if peak > 0 else np.zeros_like(b_scan)
        b_scan_rgb = np.stack([scaled.astype(np.uint8)] * 3, axis=-1)
        img = Image.fromarray(b_scan_rgb)
        b_scans.append(transform(img))
    b_scans = torch.stack(b_scans)

    # Extract embeddings: one forward pass per B-scan, keeping the final-layer
    # hidden state at sequence position 0.
    embeddings = []
    with torch.no_grad():
        for b_scan in b_scans:
            inputs = b_scan.unsqueeze(0)
            outputs = model(pixel_values=inputs)
            embedding = outputs.last_hidden_state[:, 0, :].squeeze()
            embeddings.append(embedding.numpy())

    embeddings = np.array(embeddings)  # (100, 768)
    # Save embeddings
    np.save(f'patient_{patient_id}_embeddings.npy', embeddings)
    print(f"Saved embeddings for patient {patient_id}")


In [None]:
import pandas as pd

# Build a lookup table from each embedding file name to its zero-padded patient ID
# (001 .. 200) and write it to 'embedding_map.csv'.
patient_ids = [f'{i:03d}' for i in range(1, 201)]
mapping_data = {
    'file_name': [f'patient_{pid}_embeddings.npy' for pid in patient_ids],
    'patient_id': patient_ids,
}
df = pd.DataFrame(mapping_data)
df.to_csv('embedding_map.csv', index=False)
