# Fine-tune the Segment Anything Model (SAM)

**Dataset** (Google Drive folder):

https://drive.google.com/drive/folders/1na6mkrFLiZZ6l0d4pKXqIsxAUxKhJDAu?usp=drive_link

In [None]:
# Mount Google Drive (Colab only) so files can be read from / written to it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import gdown
import os
import numpy as np
import tifffile
from random import randint
import matplotlib.pyplot as plt
import json

In [None]:
# Prepare the data: Google Drive share links (one per file id) for the image/label
# archives and the DVRPC annotation JSONs.
_DRIVE_SHARE_URL = 'https://drive.google.com/file/d/{}/view?usp=drive_link'

label_url = _DRIVE_SHARE_URL.format('1T8RDNBtxuBidm9ttNW9ShauDB49dBjWH')
train_url = _DRIVE_SHARE_URL.format('1De6cOV0UtS310-vkILWpmY7hiJZRSU9Y')
val_url = _DRIVE_SHARE_URL.format('1MFLm_5c0G6CUGNx2o2wrwGAKZHvUBCTI')
DVRPC_train_url = _DRIVE_SHARE_URL.format('1pHzGmjQUvrH1TY4XL1vw8xg72u8K5BuI')
DVRPC_val_url = _DRIVE_SHARE_URL.format('1YC5oUmGDa0sO14Qc4d-PM8cn2dbU1BKK')

In [None]:
# Download and unzip the files
# Everything is stored under ./data. Each archive is downloaded and extracted only
# when its extracted folder is missing, so re-running this cell skips finished work.
# The `!` lines are IPython shell escapes; `{name}` is replaced by the Python value.
data_path = os.path.join(os.getcwd(), 'data')
os.makedirs(data_path, exist_ok=True)
# Download destinations. The .tar.gz archives are deleted after extraction; the
# DVRPC .json annotation files are kept as downloaded.
label_path = os.path.join(data_path, 'label.tar.gz')
train_path = os.path.join(data_path,'train.tar.gz')
val_path = os.path.join(data_path,'val.tar.gz')
DVRPC_train_path = os.path.join(data_path,'DVRPC_train.json')
DVRPC_val_path = os.path.join(data_path,'DVRPC_val.json')

# Training images. The existence check assumes the archive extracts to data/Train.
# fuzzy=True lets gdown extract the file id from a Drive share-page URL like the ones above.
train_path_new = os.path.join(data_path, 'Train')
if not os.path.exists(train_path_new):
    gdown.download(train_url, train_path, fuzzy=True)
    !tar -xzf {train_path} -C {data_path}
    !rm -rf {train_path}
# From here on, train_path / label_path / val_path name extracted directories, not archives.
train_path = train_path_new
# Label masks; assumed to extract to data/Label with one subfolder per split.
label_path_new = os.path.join(data_path, 'Label')
if not os.path.exists(label_path_new):
    gdown.download(label_url, label_path, fuzzy=True)
    !tar -xzf {label_path} -C {data_path}
    # The archive is large, so delete it after extracting. The Label/Test2 split is
    # removed too, because nothing in this notebook reads it.
    !rm -rf {label_path} {os.path.join(label_path_new, 'Test2')}
label_path = label_path_new
# Validation images: the extracted 'Test' folder is used as the validation split.
val_path_new = os.path.join(data_path, 'Test')
if not os.path.exists(val_path_new):
    gdown.download(val_url, val_path, fuzzy=True)
    !tar -xzf {val_path} -C {data_path}
    !rm -rf {val_path}
val_path = val_path_new
# DVRPC annotation JSONs (image ids + bounding boxes); downloaded once, not extracted.
if not os.path.exists(DVRPC_train_path):
    gdown.download(DVRPC_train_url, DVRPC_train_path, fuzzy=True)
if not os.path.exists(DVRPC_val_path):
    gdown.download(DVRPC_val_url, DVRPC_val_path, fuzzy=True)

# Label masks share filenames with their image tiles, one subfolder per split.
train_label_path = os.path.join(label_path, 'Train')
val_label_path = os.path.join(label_path, 'Test')

In [None]:
def _has_foreground(label_dir: str, filename: str) -> bool:
    """Return True if the label mask ``filename`` in ``label_dir`` has any pixel > 0."""
    return bool(np.max(tifffile.imread(os.path.join(label_dir, filename))) > 0)


# Keep only .tif tiles whose label mask contains at least one positive pixel; empty
# tiles give the model nothing to segment. The label mask shares the tile's filename.
# sorted() makes the file order reproducible, since os.listdir order is arbitrary.
train_files = sorted(f for f in os.listdir(train_path)
                     if f.endswith('.tif') and _has_foreground(train_label_path, f))
val_files = sorted(f for f in os.listdir(val_path)
                   if f.endswith('.tif') and _has_foreground(val_label_path, f))

In [None]:
# Process the annotation JSON files to get the bounding boxes of each image
def preprocess_json(json_file: str):
    """Index a COCO-style annotation file (``images`` / ``annotations`` lists) by image.

    Args:
        json_file: Path to the annotation JSON.

    Returns:
        A ``(filename_id_map, ann_map)`` tuple. ``filename_id_map`` maps each image's
        ``file_name`` to its ``id``. ``ann_map`` maps an image ``id`` to the list of
        ``bbox`` values of its annotations, in file order. Images without
        annotations do not appear in ``ann_map``.

    NOTE(review): in the COCO format ``bbox`` is ``[x, y, width, height]``, while SAM
    box prompts are ``[x0, y0, x1, y1]``. Confirm the DVRPC files' convention before
    feeding these boxes to SAM unchanged.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    filename_id_map = {image['file_name']: image['id'] for image in data['images']}
    ann_map = {}
    for ann in data['annotations']:
        ann_map.setdefault(ann['image_id'], []).append(ann['bbox'])
    return filename_id_map, ann_map

def filename2bbox(filename: str, filename_id_map: dict, ann_map: dict):
    """Return the annotation boxes recorded for image ``filename``.

    The filename is first resolved to its image id through ``filename_id_map``.
    That id is then looked up in ``ann_map``.

    Returns:
        The list of boxes for that image, or None if the image has no annotations.

    Raises:
        KeyError: if ``filename`` is not present in ``filename_id_map``.
    """
    return ann_map.get(filename_id_map[filename])

# Build the filename -> image-id and image-id -> boxes lookups for both splits
# from the downloaded DVRPC annotation JSONs.
train_filename_id_map, train_ann_map = preprocess_json(DVRPC_train_path)
val_filename_id_map, val_ann_map = preprocess_json(DVRPC_val_path)

In [None]:
# Visualize the data
# Show a random training tile beside its label mask, then print its annotation boxes.
%matplotlib inline
index = randint(0, len(train_files)-1)  # random.randint includes both endpoints

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
img = tifffile.imread(os.path.join(train_path, train_files[index]))
label = tifffile.imread(os.path.join(train_label_path, train_files[index]))
ax[0].imshow(img)
ax[0].set_title('Image')
# NOTE(review): with cmap='gray', imshow rescales to the data range anyway. `* 255`
# presumably assumes a 0/1 mask; on a uint8 mask holding values > 1 it would overflow.
ax[1].imshow(label * 255, cmap='gray')
ax[1].set_title('Mask')
plt.show()

# Boxes come from the DVRPC training JSON. None means the tile has no annotations;
# a tile missing from the JSON raises KeyError instead.
print('bboxes: ', filename2bbox(train_files[index], train_filename_id_map, train_ann_map))

In [None]:
# Install the required libraries
# Transformers: provides SamModel and SamProcessor used below
%pip install -q transformers
# MONAI: provides the DiceLoss used as the training objective
%pip install -q monai

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import SamModel, SamProcessor
from torch.optim import Adam
from monai.losses import DiceLoss
from tqdm import tqdm
import statistics

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class SidewalkDataset(Dataset):
    """Sidewalk image tiles with their label masks and box prompts, encoded for SAM.

    Each item reads an image tile from ``data_path`` and the label mask with the same
    filename from ``label_path``. It looks up the tile's annotation boxes with
    ``filename2bbox`` and encodes image + boxes with ``processor``.

    Args:
        data_path: Directory holding the image tiles.
        label_path: Directory holding the label masks (same filenames as the tiles).
        filename_id_map: ``file_name -> image id`` map (from ``preprocess_json``).
        ann_map: ``image id -> list of boxes`` map (from ``preprocess_json``).
        files: Tile filenames that make up the dataset, in index order.
        processor: A ``SamProcessor`` (or compatible callable).
        transform: Optional ``(img, label) -> (img, label)`` callable, applied after
            the boxes are looked up. NOTE(review): boxes are not transformed with it,
            so geometric transforms would misalign them.
    """

    def __init__(self, data_path: str, label_path: str, filename_id_map: dict, ann_map: dict, files: list, processor, transform=None):
        self.data_path = data_path
        self.label_path = label_path
        self.filename_id_map = filename_id_map
        self.ann_map = ann_map
        self.files = files
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _as_processor_boxes(bboxes):
        """Wrap one image's boxes in the nesting that SamProcessor expects.

        ``SamProcessor`` validates ``input_boxes`` as list[image][box][coord], so the
        flat per-image box list returned by ``filename2bbox`` must be wrapped once.
        Returns None when there are no boxes; the processor is then called without
        a box prompt.
        """
        if not bboxes:
            return None
        return [[[float(coord) for coord in box] for box in bboxes]]

    def __getitem__(self, idx):
        filename = self.files[idx]
        img = tifffile.imread(os.path.join(self.data_path, filename))
        label = tifffile.imread(os.path.join(self.label_path, filename))
        # NOTE(review): the annotation JSON is COCO-structured. COCO boxes are
        # [x, y, width, height], but SAM expects [x0, y0, x1, y1] — confirm and
        # convert if needed.
        bboxes = filename2bbox(filename, self.filename_id_map, self.ann_map)
        if self.transform:
            img, label = self.transform(img, label)
        inputs = self.processor(img, input_boxes=self._as_processor_boxes(bboxes), return_tensors='pt')
        # remove batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        # Leading channel dimension on the label mask.
        inputs['labels'] = torch.tensor(label).unsqueeze(0)
        return inputs

In [None]:
# Load the processor and model from the same pretrained SAM checkpoint.
sam_checkpoint = "facebook/sam-vit-base"
sam_processor = SamProcessor.from_pretrained(sam_checkpoint)
sam_model = SamModel.from_pretrained(sam_checkpoint)

In [None]:
# Create datasets and dataloaders
# NOTE(review): DataLoader's default collate stacks each field, so all samples in a
# batch must have the same number of boxes (`input_boxes`) and the same label size.
# If box counts vary between tiles, use batch_size=1 or a custom collate_fn.
BATCH_SIZE = 2
train_dataset = SidewalkDataset(
    data_path=train_path,
    label_path=train_label_path,
    filename_id_map=train_filename_id_map,
    ann_map=train_ann_map,
    files=train_files,
    processor=sam_processor,
)
val_dataset = SidewalkDataset(
    data_path=val_path,
    label_path=val_label_path,
    filename_id_map=val_filename_id_map,
    ann_map=val_ann_map,
    files=val_files,
    processor=sam_processor,
)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Show example of the dataset: print the shape of every field in one training batch.
sample = next(iter(train_dataloader))
for key in sample:
    print(key, sample[key].shape)

In [None]:
# Freeze the image encoder and prompt encoder, so gradients are only computed for the
# remaining parameters (intended: the mask decoder).
for name, param in sam_model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

# Optimizer and loss function.
# Give Adam only the parameters that are still trainable; the frozen ones never
# receive gradients, so listing them is pure overhead.
trainable_params = [p for p in sam_model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=1e-5)
# sigmoid=True: the model outputs raw mask logits, and DiceLoss applies the sigmoid.
loss_fn = DiceLoss(sigmoid=True, squared_pred=True)

In [None]:
# Training loop
num_epochs = 10

sam_model.to(device).train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Forward pass: box prompts only, one mask per box.
        outputs = sam_model(pixel_values=batch['pixel_values'].to(device),
                            input_boxes=batch['input_boxes'].to(device),
                            multimask_output=False)
        # pred_masks is (B, num_boxes, 1, h, w) low-resolution logits. The ground truth
        # is a single mask per image, so merge the per-box logits (max ~ union of the
        # boxes' masks) into (B, 1, h, w). With one box this equals `.squeeze(1)`.
        predicted_masks = outputs.pred_masks[:, :, 0].amax(dim=1, keepdim=True)
        # SidewalkDataset already adds a channel dim to each label -> (B, 1, H, W).
        # Only add one if it is missing; adding it twice made the shapes mismatch.
        ground_truth_masks = batch['labels'].float().to(device)
        if ground_truth_masks.dim() == 3:
            ground_truth_masks = ground_truth_masks.unsqueeze(1)
        # DiceLoss requires identical shapes, so resize the logits to label resolution.
        # NOTE(review): SAM pads its input to a square, so this direct resize only lines
        # up for square tiles; otherwise crop via sam_processor.post_process_masks.
        predicted_masks = torch.nn.functional.interpolate(
            predicted_masks, size=ground_truth_masks.shape[-2:],
            mode='bilinear', align_corners=False)
        loss = loss_fn(predicted_masks, ground_truth_masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    # statistics.mean raises on an empty list (e.g. an empty dataloader).
    mean_loss = statistics.mean(epoch_losses) if epoch_losses else float('nan')
    print(f'Epoch {epoch+1}, Loss: {mean_loss}')


In [None]:
# Save the model's full state dict (frozen encoders included) to the mounted Drive.
checkpoint_path = '/content/drive/MyDrive/sam_model.pth'
model_state = sam_model.state_dict()
torch.save(model_state, checkpoint_path)

In [None]:
# Evaluate the model: mean Dice loss over the validation set.
sam_model.to(device).eval()
val_losses = []
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        outputs = sam_model(pixel_values=batch['pixel_values'].to(device),
                            input_boxes=batch['input_boxes'].to(device),
                            multimask_output=False)
        # (B, num_boxes, 1, h, w) logits -> one (B, 1, h, w) map per image (max over boxes)
        predicted_masks = outputs.pred_masks[:, :, 0].amax(dim=1, keepdim=True)
        # Labels already carry a channel dim from SidewalkDataset; add one only if missing.
        ground_truth_masks = batch['labels'].float().to(device)
        if ground_truth_masks.dim() == 3:
            ground_truth_masks = ground_truth_masks.unsqueeze(1)
        # Match label resolution so DiceLoss sees identical shapes.
        # NOTE(review): assumes square tiles (SAM pads its input to a square).
        predicted_masks = torch.nn.functional.interpolate(
            predicted_masks, size=ground_truth_masks.shape[-2:],
            mode='bilinear', align_corners=False)
        loss = loss_fn(predicted_masks, ground_truth_masks)
        val_losses.append(loss.item())

# statistics.mean raises on an empty list (e.g. no validation tiles).
mean_val_loss = statistics.mean(val_losses) if val_losses else float('nan')
print(f'Validation Loss: {mean_val_loss}')

In [None]:
# Visualize the results
sample = next(iter(val_dataloader))
with torch.no_grad():
    outputs = sam_model(pixel_values=sample['pixel_values'].to(device),
                        input_boxes=sample['input_boxes'].to(device),
                        multimask_output=False)
    predicted_masks = outputs.pred_masks.squeeze(1)
    ground_truth_masks = sample['labels'].float().to(device)

%matplotlib inline
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(sample['pixel_values'][0].cpu().numpy())
ax[0].set_title('Image')
ax[1].imshow(predicted_masks[0].cpu().numpy(), cmap='gray')
ax[1].set_title('Predicted Mask')
ax[2].imshow(ground_truth_masks[0].cpu().numpy(), cmap='gray')
ax[2].set_title('Ground Truth Mask')

plt.show()