## Imports

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models 
from torch.utils.data import DataLoader
import os
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
import random
import PIL

## Importing Data

In [None]:
#UNCOMMENT IF YOU NEED TO DOWNLOAD DATASET------------------------------------------------------------------------------------------------------
# #Download Dataset
# import opendatasets as od

# od.download("https://www.kaggle.com/datasets/aysendegerli/qatacov19-dataset/data")

In [None]:
from skimage import io

import os

# Root folder of the extracted QaTa-COV19-v2 dataset. Set the QATA_COV19_ROOT
# environment variable to point somewhere else; the default is the location
# this notebook was originally written against.
QATA_ROOT = os.environ.get(
    "QATA_COV19_ROOT",
    "/home/cahsi/Josh/Research/venv/Semantic_Segmentation_Research/Dataset/qatacov19-dataset/QaTa-COV19/QaTa-COV19-v2",
)

train_img_path = os.path.join(QATA_ROOT, "Train Set", "Images", "*.png")
test_img_path = os.path.join(QATA_ROOT, "Test Set", "Images", "*.png")

train_mask_path = os.path.join(QATA_ROOT, "Train Set", "Ground-truths", "*.png")
test_mask_path = os.path.join(QATA_ROOT, "Test Set", "Ground-truths", "*.png")

# Load every .png matched by each glob pattern
train_img = io.imread_collection(train_img_path)
test_img = io.imread_collection(test_img_path)
train_mask = io.imread_collection(train_mask_path)
test_mask = io.imread_collection(test_mask_path)

# Fail early with a readable message when a pattern matched no files;
# otherwise the [0] lookups below raise a bare IndexError.
for _pattern, _collection in (
    (train_img_path, train_img),
    (test_img_path, test_img),
    (train_mask_path, train_mask),
    (test_mask_path, test_mask),
):
    if len(_collection) == 0:
        raise FileNotFoundError(f"No .png files found matching {_pattern!r}")

#Images Info
print("Length of training raw images: " + str(len(train_img)) + "      Shape of an training raw image: " + str(train_img[0].shape))
print("Length of training mask images: " + str(len(train_mask)) + "     Shape of an training mask image: " + str(train_mask[0].shape))

print("\nLength of test raw images: " + str(len(test_img)) + "       Shape of an test raw image: " + str(test_img[0].shape))
print("Length of test mask images: " + str(len(test_mask)) + "      Shape of an test mask image: " + str(test_mask[0].shape))

## Resizing Images

In [None]:
from skimage.transform import resize
import cv2

def resize_images(images, mask):
    """Resize every item of ``images`` to 256x256 and return them as a list.

    Args:
        images: Iterable of arrays. Non-mask items are first converted with
            ``cv2.cvtColor(..., cv2.COLOR_GRAY2RGB)``, so they are presumably
            single-channel grayscale images (verify against the data).
        mask: Truthy when ``images`` holds segmentation masks rather than
            raw images.

    Returns:
        list: One resized array per input item. Masks are resized with
        ``order=0`` (nearest neighbour), thresholded at 0.5 and returned as
        ``np.uint8`` arrays of 0/1; raw images are returned exactly as
        ``resize`` produces them.
    """
    target_shape = (256, 256)

    if not mask:
        # Raw images: replicate the gray channel to RGB, then resize.
        return [
            resize(
                cv2.cvtColor(raw_image, cv2.COLOR_GRAY2RGB),
                target_shape,
                anti_aliasing=False,
            )
            for raw_image in images
        ]

    # Masks: nearest-neighbour resizing keeps the values binary; the 0.5
    # threshold then maps the resized result onto 0/1.
    return [
        (resize(raw_mask, target_shape, order=0, anti_aliasing=False) > 0.5).astype(np.uint8)
        for raw_mask in images
    ]

# Resize every split to 256x256 and pack each result into one NumPy array.
# NOTE: test_img and test_mask are rebound here from the loaded collections
# to the resized arrays, so re-running this cell operates on the resized data.
train_images = np.array(resize_images(train_img, mask=False))  # training images
test_img = np.array(resize_images(test_img, mask=False))  # testing images
train_masks = np.array(resize_images(train_mask, mask=True))  # training masks
test_mask = np.array(resize_images(test_mask, mask=True))  # testing masks

In [None]:
# Report the array shape of each resized split
print("Shape of resized training raw image: ", train_images.shape, sep="")
print("Shape of resized training mask image: ", train_masks.shape, sep="")
print("Shape of resized testing raw image: ", test_img.shape, sep="")
print("Shape of resized testing mask image: ", test_mask.shape, sep="")

## Get Subset of Data

In [None]:
from sklearn.model_selection import train_test_split

# Fraction of the training data kept when the subset line below is enabled
subset_size = 0.03

#------------------>>>>> Comment if you want to use the full dataset for training
# train_images, _, train_masks, _ = train_test_split(train_images, train_masks, train_size=subset_size, random_state=25)

# Names used by the rest of the notebook for the (possibly reduced) training data
images, masks = train_images, train_masks

# Report the shapes that will actually be used for training
print("Shape of resized training raw image: ", images.shape, sep="")
print("Shape of resized training mask image: ", masks.shape, sep="")

## Create Datasets

In [None]:
from datasets import Dataset
from PIL import Image

# Images are multiplied by 255 and cast to uint8 before being wrapped as
# Pillow images; masks are wrapped as they are.
pil_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]
pil_masks = [Image.fromarray(mask) for mask in masks]

# Column name -> list of Pillow images
dataset_dict = {"image": pil_images, "mask": pil_masks}

# Build the datasets.Dataset from the two columns
dataset = Dataset.from_dict(dataset_dict)

In [None]:
# Show the dataset summary (its string representation)
print("Dataset Info:\n", dataset, sep="")

## Visualize Data

In [None]:
import matplotlib.pyplot as plt
import numpy as np 

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width). It is
            reshaped to (height, width, 1), so it must hold exactly
            height * width elements.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When True the overlay uses random RGB values;
            otherwise a fixed blue. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4): every pixel becomes mask value * RGBA.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

In [None]:
# Pick one training sample at random and show image, mask and overlay
img_num = random.randint(0, images.shape[0]-1)
example_image = dataset[img_num]["image"]
example_mask = dataset[img_num]["mask"]

# Convert once; the arrays are reused by several panels below
example_image_array = np.array(example_image)
example_mask_array = np.array(example_mask)

fig, axes = plt.subplots(1, 3, figsize=(20, 10))

# Left panel: the image on its own
axes[0].imshow(example_image_array, cmap='gray')
axes[0].set_title("Image")

# Middle panel: the ground-truth mask on its own
axes[1].imshow(example_mask_array, cmap='gray')
axes[1].set_title("Ground Truth Mask")

# Right panel: the image with the mask drawn on top
axes[2].imshow(example_image_array, cmap='gray')
show_mask(example_mask_array, axes[2])
axes[2].set_title("Result")

# Remove tick marks and tick labels from every panel
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Render the three panels side by side
plt.show()

### Get Bounding Boxes

In [None]:
#Function that gets bounding boxes from masks
def get_bounding_box(ground_truth_map):
    """Return a randomly padded bounding box around the mask foreground.

    Args:
        ground_truth_map: 2-D NumPy array; entries greater than 0 are
            treated as foreground.

    Returns:
        list: ``[x_min, y_min, x_max, y_max]``. Each side of the tight box is
        pushed outward by an independent random offset of 0-19 pixels and
        clamped to ``[0, W]`` horizontally and ``[0, H]`` vertically.
        When the mask has no foreground pixel the whole-image box
        ``[0, 0, W, H]`` is returned (previously ``np.min`` raised a
        ValueError on the empty index arrays).
    """
    H, W = ground_truth_map.shape

    # get bounding box from mask
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if y_indices.size == 0:
        # No foreground: there is no tight box to perturb.
        return [0, 0, W, H]

    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # add perturbation to bounding box coordinates
    # (draw order x_min, x_max, y_min, y_max is kept so seeded runs reproduce)
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    return [x_min, y_min, x_max, y_max]

### SAM Dataset/Processor

In [None]:
from torch.utils.data import Dataset

class SAMDataset(Dataset):
    """Torch dataset that turns (image, mask) items into SAM-ready inputs.

    Wraps an indexable ``dataset`` whose items expose "image" and "mask"
    entries, plus a ``processor`` that is called on the image together with
    a box prompt derived from the mask.
    """

    def __init__(self, dataset, processor):
        self.processor = processor
        self.dataset = dataset

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processor outputs for item ``idx`` plus its mask.

        The result holds every entry produced by the processor with the
        leading dimension squeezed away, and the mask as a NumPy array under
        "ground_truth_mask".
        """
        sample = self.dataset[idx]
        mask_array = np.array(sample["mask"])

        # Box prompt computed from the ground-truth mask
        box = get_bounding_box(mask_array)

        # Encode the image together with its box prompt
        encoded = self.processor(sample["image"], input_boxes=[[box]], return_tensors="pt")

        # Drop the batch dimension that the processor adds by default
        example = {key: value.squeeze(0) for key, value in encoded.items()}

        # Attach the ground-truth segmentation
        example["ground_truth_mask"] = mask_array
        return example

In [None]:
# Build the SAM processor from the pretrained base checkpoint
from transformers import SamProcessor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Wrap the image/mask dataset so each item comes out processor-encoded
train_dataset = SAMDataset(dataset, processor)

# Inspect the first item: print the shape stored under every key
print("Training: \n")
example = train_dataset[0]
for k in example:
    v = example[k]
    print(k, v.shape)

In [None]:
# Batch the training dataset: two shuffled samples per batch, last partial batch kept
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=False,
)

# Pull a single batch and print the shape stored under every key
print("Training: \n")
train_batch = next(iter(train_dataloader))
for k, v in train_batch.items():
    print(k, v.shape)

print("Ground Truth Mask Shape: ", train_batch["ground_truth_mask"].shape, sep="")

In [None]:
# Load the pretrained SAM base model
from transformers import SamModel
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze the vision encoder and the prompt encoder so that gradients are
# only computed for the parameters outside them (e.g. the mask decoder).
frozen_prefixes = ("vision_encoder", "prompt_encoder")
for name, param in model.named_parameters():
    if name.startswith(frozen_prefixes):
        param.requires_grad_(False)

### LoRA Config

In [None]:
# from segment_anything import build_sam, SamPredictor
# from segment_anything import sam_model_registry

# import math
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch import Tensor
# from torch.nn.parameter import Parameter
# from segment_anything.modeling import Sam
# from safetensors import safe_open
# from safetensors.torch import save_file


# class _LoRA_qkv(nn.Module):
#     """In Sam it is implemented as
#     self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#     B, N, C = x.shape
#     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
#     q, k, v = qkv.unbind(0)
#     """

#     def __init__(
#         self,
#         qkv: nn.Module,
#         linear_a_q: nn.Module,
#         linear_b_q: nn.Module,
#         linear_a_v: nn.Module,
#         linear_b_v: nn.Module,
#     ):
#         super().__init__()
#         self.qkv = qkv
#         self.linear_a_q = linear_a_q
#         self.linear_b_q = linear_b_q
#         self.linear_a_v = linear_a_v
#         self.linear_b_v = linear_b_v
#         self.dim = qkv.in_features
#         self.w_identity = torch.eye(qkv.in_features)

#     def forward(self, x):
#         qkv = self.qkv(x)  # B,N,N,3*org_C
#         new_q = self.linear_b_q(self.linear_a_q(x))
#         new_v = self.linear_b_v(self.linear_a_v(x))
#         qkv[:, :, :, : self.dim] += new_q
#         qkv[:, :, :, -self.dim :] += new_v
#         return qkv

# class LoRA_Sam(nn.Module):
#     """Applies low-rank adaptation to a Sam model's image encoder.

#     Args:
#         sam_model: a vision transformer model, see base_vit.py
#         r: rank of LoRA
#         num_classes: how many classes the model output, default to the vit model
#         lora_layer: which layer we apply LoRA.

#     Examples::
#         >>> model = ViT('B_16_imagenet1k')
#         >>> lora_model = LoRA_ViT(model, r=4)
#         >>> preds = lora_model(img)
#         >>> print(preds.shape)
#         torch.Size([1, 1000])
#     """

#     def __init__(self, sam_model: Sam, r: int, lora_layer=None):
#         super(LoRA_Sam, self).__init__()

#         assert r > 0
        
#         if lora_layer:
#             self.lora_layer = lora_layer
#         else:
#             self.lora_layer = list(range(len(sam_model.vision_encoder.layers)))
#         # create for storage, then we can init them or load weights
#         self.w_As = []  # These are linear layers
#         self.w_Bs = []

#         # lets freeze first
#         for param in sam_model.vision_encoder.layers.parameters():
#             param.requires_grad = False

#         # Here, we do the surgery
#         for t_layer_i, blk in enumerate(sam_model.vision_encoder.layers):
#             # If we only want few lora layer instead of all
#             if t_layer_i not in self.lora_layer:
#                 continue
#             w_qkv_linear = blk.attn.qkv
#             self.dim = w_qkv_linear.in_features
#             w_a_linear_q = nn.Linear(self.dim, r, bias=False)
#             w_b_linear_q = nn.Linear(r, self.dim, bias=False)
#             w_a_linear_v = nn.Linear(self.dim, r, bias=False)
#             w_b_linear_v = nn.Linear(r, self.dim, bias=False)
#             self.w_As.append(w_a_linear_q)
#             self.w_Bs.append(w_b_linear_q)
#             self.w_As.append(w_a_linear_v)
#             self.w_Bs.append(w_b_linear_v)
#             blk.attn.qkv = _LoRA_qkv(
#                 w_qkv_linear,
#                 w_a_linear_q,
#                 w_b_linear_q,
#                 w_a_linear_v,
#                 w_b_linear_v,
#             )
#         self.reset_parameters()
#         self.sam = sam_model

#     def load_fc_parameters(self, filename: str) -> None:
#         r"""Only safetensors is supported now.

#         pip install safetensor if you do not have one installed yet.
#         """

#         assert filename.endswith(".safetensors")
#         _in = self.lora_vit.head.in_features
#         _out = self.lora_vit.head.out_features
#         with safe_open(filename, framework="pt") as f:
#             saved_key = f"fc_{_in}in_{_out}out"
#             try:
#                 saved_tensor = f.get_tensor(saved_key)
#                 self.lora_vit.head.weight = Parameter(saved_tensor)
#             except ValueError:
#                 print("this fc weight is not for this model")

#     def save_lora_parameters(self, filename: str) -> None:
#         r"""Only safetensors is supported now.

#         pip install safetensor if you do not have one installed yet.
        
#         save both lora and fc parameters.
#         """

#         assert filename.endswith(".safetensors")

#         num_layer = len(self.w_As)  # actually, it is half
#         a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
#         b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
        
#         merged_dict = {**a_tensors, **b_tensors}
#         save_file(merged_dict, filename)

#     def load_lora_parameters(self, filename: str) -> None:
#         r"""Only safetensors is supported now.

#         pip install safetensor if you do not have one installed yet.\
            
#         load both lora and fc parameters.
#         """

#         assert filename.endswith(".safetensors")

#         with safe_open(filename, framework="pt") as f:
#             for i, w_A_linear in enumerate(self.w_As):
#                 saved_key = f"w_a_{i:03d}"
#                 saved_tensor = f.get_tensor(saved_key)
#                 w_A_linear.weight = Parameter(saved_tensor)

#             for i, w_B_linear in enumerate(self.w_Bs):
#                 saved_key = f"w_b_{i:03d}"
#                 saved_tensor = f.get_tensor(saved_key)
#                 w_B_linear.weight = Parameter(saved_tensor)
                
#     def reset_parameters(self) -> None:
#         for w_A in self.w_As:
#             nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
#         for w_B in self.w_Bs:
#             nn.init.zeros_(w_B.weight)

In [None]:
# Only parameters with requires_grad=True are summed, i.e. the trainable
# ones; frozen parameters are excluded, so the label says "trainable".
sam_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"SAM trainable params: {sam_total_params}")

In [None]:
#model.mask_decoder.transformer.layers

In [None]:
import loralib as lora

rank = 20


def _to_lora_linear(linear, r):
    """Return a ``lora.Linear`` that mirrors ``linear`` and keeps its weights.

    A newly constructed ``lora.Linear`` has freshly initialised base weights,
    so the pretrained weight/bias of ``linear`` are copied into it. Sizes and
    bias presence are read from ``linear`` instead of being hard-coded.
    """
    lora_linear = lora.Linear(
        linear.in_features,
        linear.out_features,
        r=r,
        bias=linear.bias is not None,
    )
    # strict=False: the source layer has no lora_A / lora_B entries, so those
    # keep their fresh LoRA initialisation.
    lora_linear.load_state_dict(linear.state_dict(), strict=False)
    # Keep the replacement on the same device / dtype as the layer it replaces.
    return lora_linear.to(device=linear.weight.device, dtype=linear.weight.dtype)


def _add_lora_to_attention(attention, r):
    """Swap the q/k/v projections of ``attention`` for LoRA-wrapped copies."""
    for proj_name in ("q_proj", "k_proj", "v_proj"):
        setattr(attention, proj_name, _to_lora_linear(getattr(attention, proj_name), r))


for layer in model.mask_decoder.transformer.layers:

    #Self attention block
    _add_lora_to_attention(layer.self_attn, rank)

    #MLP block
    layer.mlp.lin1 = _to_lora_linear(layer.mlp.lin1, rank)
    layer.mlp.lin2 = _to_lora_linear(layer.mlp.lin2, rank)

    #Cross attention block (Token -> Image)
    # Previously this section assigned cross_attn_image_to_token as well, so
    # the token-to-image attention never received LoRA layers.
    _add_lora_to_attention(layer.cross_attn_token_to_image, rank)

    #Cross attention block (Image -> Token)
    _add_lora_to_attention(layer.cross_attn_image_to_token, rank)


_add_lora_to_attention(model.mask_decoder.transformer.final_attn_token_to_image, rank)

In [None]:
# Only parameters with requires_grad=True are summed, so this is the number
# of trainable parameters given the requires_grad flags at this point.
sam_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"LoRA-SAM trainable params: {sam_total_params}")

### Train Model

In [None]:
from torch.optim import Adam
import monai

# Adam over the mask decoder's parameters, without weight decay
optimizer = Adam(
    model.mask_decoder.parameters(),
    lr=1e-5,
    weight_decay=0,
)

# Segmentation loss. Alternatives to try: DiceFocalLoss, FocalLoss, DiceCELoss
seg_loss = monai.losses.DiceCELoss(
    sigmoid=True,
    squared_pred=True,
    reduction='mean',
)

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from torch.nn.functional import threshold, normalize

#Training loop
num_epochs = 7
mean_epoch_loss = []  # one mean loss value per finished epoch

# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)

#Set layer-norm to be trainable
for layer in model.mask_decoder.transformer.layers:
    layer.layer_norm1.requires_grad_(True)
    layer.layer_norm2.requires_grad_(True)
    layer.layer_norm3.requires_grad_(True)
    layer.layer_norm4.requires_grad_(True)


# Move the model to the first GPU when CUDA is available, otherwise stay on
# the CPU instead of failing on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    # Compute the epoch mean once and record the value itself; the list was
    # previously appended to itself here.
    epoch_mean_loss = mean(epoch_losses)
    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {epoch_mean_loss}')
    mean_epoch_loss.append(epoch_mean_loss)

In [None]:
# Save the model's state dictionary to a file
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mean(epoch_losses),
            }, f"/home/cahsi/Josh/Research/venv/Semantic_Segmentation_Research/LoRA_SAM/LoRA_sam_model_checkpoint_rank{rank}.pth")

# Save the LoRA parameters of the model
torch.save(lora.lora_state_dict(model), f"/home/cahsi/Josh/Research/venv/Semantic_Segmentation_Research/LoRA_SAM/lora_rank{rank}.pt")

### Inference

In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch


In [None]:
# Load the model configuration
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the model architecture with the loaded configuration
lora_model = SamModel(config=model_config)

#Load pretrained model
lora_model.load_state_dict(torch.load(f"/home/cahsi/Josh/Research/venv/Semantic_Segmentation_Research/LoRA_SAM/LoRA_sam_model_checkpoint_rank{rank}.pth"), strict=False)

#Load LoRA checkpoint
lora_model.load_state_dict(torch.load(f"/home/cahsi/Josh/Research/venv/Semantic_Segmentation_Research/LoRA_SAM/lora_rank{rank}.pt"), strict=False)

In [None]:
# set the device to cuda if available, otherwise use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
lora_model.to(device)

In [None]:
import matplotlib.pyplot as plt
import numpy as np 

def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a semi-transparent coloured layer.

    Same behaviour as the ``show_mask`` defined in the "Visualize Data"
    section of this notebook.

    Args:
        mask: Array whose trailing two dimensions give (height, width); it is
            reshaped to (height, width, 1) before colouring.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: Use random RGB values when True, a fixed blue
            otherwise. Alpha is 0.6 either way.
    """
    alpha = np.array([0.6])
    rgba = (
        np.concatenate([np.random.random(3), alpha], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    rows, cols = mask.shape[-2:]
    # (rows, cols, 1) * (1, 1, 4) -> one RGBA value per pixel, scaled by the mask
    coloured = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(coloured)

In [None]:
import numpy as np
import random
import torch
import matplotlib.pyplot as plt

# let's take a random training example
idx = random.randint(0, images.shape[0]-1)

# load image
test_image = dataset[idx]["image"]
test_image = np.array(test_image.convert("RGB"))

# get box prompt based on ground truth segmentation map
ground_truth_mask = np.array(dataset[idx]["mask"])
prompt = get_bounding_box(ground_truth_mask)

# prepare image + box prompt for the model
inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")

# Move the input tensors to the device the model is on
inputs = {k: v.to(device) for k, v in inputs.items()}

lora_model.eval()

# forward pass
with torch.no_grad():
    outputs = lora_model(**inputs, multimask_output=False)

# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)


# Two panels only: ground truth and prediction. A third axis used to be
# created here but was never drawn on, leaving an empty frame in the figure.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# Left panel: image with the ground-truth mask overlaid
axes[0].imshow(test_image, cmap='gray')
show_mask(np.array(ground_truth_mask), axes[0])
axes[0].set_title("Ground Truth Mask")

# Right panel: image with the predicted mask overlaid
axes[1].imshow(test_image, cmap='gray')
show_mask(np.array(medsam_seg), axes[1])
axes[1].set_title("Predicted Mask")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

### Further Analysis

In [None]:
def calculateIoU(ground_mask, pred_mask):
    """Intersection over union of two binary masks.

    Pixels are classified exactly as in the original per-pixel loop: a value
    of 1 is foreground, a value of 0 is background, and any other value is
    counted in neither category.

    Args:
        ground_mask: Ground-truth mask (array or nested sequence of 0/1).
        pred_mask: Predicted mask with the same shape as ``ground_mask``.

    Returns:
        float: ``TP / (TP + FP + FN)``. When there is no foreground in either
        mask (TP + FP + FN == 0) the masks agree completely and 1.0 is
        returned instead of dividing by zero.

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    ground = np.asarray(ground_mask)
    pred = np.asarray(pred_mask)
    if ground.shape != pred.shape:
        raise ValueError(
            f"Mask shapes differ: ground {ground.shape} vs prediction {pred.shape}"
        )

    ground_fg = ground == 1
    pred_fg = pred == 1

    # Vectorised counts replace the per-pixel Python loop.
    true_pos = np.count_nonzero(ground_fg & pred_fg)
    false_pos = np.count_nonzero((ground == 0) & pred_fg)
    false_neg = np.count_nonzero(ground_fg & (pred == 0))

    union = true_pos + false_pos + false_neg
    if union == 0:
        return 1.0

    return true_pos / union

In [None]:
# Ground-truth mask of the sample shown above, compared with its prediction
example_mask_image = dataset[idx]["mask"]
example_ground_mask = np.array(example_mask_image)
print(f"IoU: {calculateIoU(example_ground_mask, medsam_seg)}")

### Testing on Test Dataset

In [None]:
from datasets import Dataset
from PIL import Image

# First 100 test samples only. Images are multiplied by 255 and cast to
# uint8 before being wrapped as Pillow images; masks are wrapped as they are.
test_pil_images = [Image.fromarray((img * 255).astype(np.uint8)) for img in test_img[0:100]]
test_pil_masks = [Image.fromarray(mask) for mask in test_mask[0:100]]

# Column name -> list of Pillow images
test_dataset_dict = {"image": test_pil_images, "mask": test_pil_masks}

# Build the datasets.Dataset from the two columns
test_dataset = Dataset.from_dict(test_dataset_dict)

In [None]:
test_ious = []
# NOTE: this evaluates `model` (the in-memory trained model), not `lora_model`.
model.to(device)
# Switch to inference mode; the training cell leaves the model in train() mode.
model.eval()
for idx, sample in enumerate(test_dataset):
    # Get Image and ground truth mask
    image = sample["image"]
    ground_truth_mask = np.array(sample["mask"])
    
    # get box prompt based on ground truth segmentation map
    prompt = get_bounding_box(ground_truth_mask)
    
    # prepare image + box prompt for the model
    inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)
    #inputs = {k:v.squeeze(0) for k,v in inputs.items()}
    
    # forward pass
    with torch.no_grad():
      outputs = model(**inputs, multimask_output=False)
    # apply sigmoid
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # convert soft mask to hard mask
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

    iou = calculateIoU(ground_truth_mask, medsam_seg)
    print(f"Sample {idx} IoU: {iou}")
    test_ious.append(iou)

    
print(f"Average IoUs over 100 test sample: {mean(test_ious)}")

# Results

**Model 1**
- Train Size = ALL
- Rank = 20
- Train Epochs = 7
- Test IOU = 0.32
- **Layers Edited**

    - model.mask_decoder.transformer.layers:
        - layer.self_attn.q_proj 
        - layer.self_attn.k_proj 
        - layer.self_attn.v_proj

        - layer.mlp.lin1
        - layer.mlp.lin2

        - layer.cross_attn_image_to_token.q_proj
        - layer.cross_attn_image_to_token.k_proj
        - layer.cross_attn_image_to_token.v_proj

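        - layer.cross_attn_image_to_token.q_proj (listed twice because this run assigned the image-to-token projections twice; `cross_attn_token_to_image` was not adapted)
        - layer.cross_attn_image_to_token.k_proj
        - layer.cross_attn_image_to_token.v_proj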

        - layer.layer_norm1.requires_grad_(True)

    - for layer in model.mask_decoder.transformer.layers:
        - layer.layer_norm1.requires_grad_(True)
        - layer.layer_norm2.requires_grad_(True)
        - layer.layer_norm3.requires_grad_(True)
        - layer.layer_norm4.requires_grad_(True)

    - model.mask_decoder.transformer.final_attn_token_to_image.q_proj
    - model.mask_decoder.transformer.final_attn_token_to_image.k_proj
    - model.mask_decoder.transformer.final_attn_token_to_image.v_proj

https://github.com/MathieuNlp/Sam_LoRA?tab=readme-ov-file

https://github.com/NielsRogge/Transformers-Tutorials/tree/master/SAM