In [None]:
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

In [None]:
plt.style.use("dark_background")
!pip install git+https://github.com/facebookresearch/segment-anything-2.git
!pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
BASE_PATH = "/kaggle/input/lgg-mri-segmentation/kaggle_3m"
# Files are laid out as
#   {BASE_PATH}/TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_<slice>.tif
#   {BASE_PATH}/TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_<slice>_mask.tif
# The offsets below locate <slice> inside such a path. Deriving them from
# BASE_PATH (instead of hard-coding 89) keeps them right if the data moves.
# Assumes every patient directory has the same TCGA_XX_XXXX_YYYYMMDD length.
_EXAMPLE_PATIENT = "TCGA_CS_4941_19960909"
BASE_LEN = len(BASE_PATH) + len(f"/{_EXAMPLE_PATIENT}/{_EXAMPLE_PATIENT}_")  # 89 for the Kaggle path
END_LEN = len(".tif")            # 4
END_MASK_LEN = len("_mask.tif")  # 9
IMG_SIZE = 512  # NOTE(review): not referenced by any cell shown here

In [None]:
def load_data(base_path):
    """Index every file inside each sub-directory of ``base_path``.

    Args:
        base_path: Root folder holding one sub-directory per patient.

    Returns:
        DataFrame with columns ``dir_name`` (name of the sub-directory) and
        ``image_path`` (full path of a file inside it -- for this dataset both
        the images and their masks). Entries of ``base_path`` that are not
        directories are skipped with an ``[INFO]`` message.

    Both directory listings are sorted. ``os.listdir`` order is arbitrary and
    filesystem-dependent, and any split built on these rows inherits it.
    """
    data = []
    for dir_ in sorted(os.listdir(base_path)):
        dir_path = os.path.join(base_path, dir_)
        if not os.path.isdir(dir_path):
            print(f"[INFO] This is not a directory --> {dir_path}")
            continue
        for filename in sorted(os.listdir(dir_path)):
            data.append([dir_, os.path.join(dir_path, filename)])
    return pd.DataFrame(data, columns=["dir_name", "image_path"])

In [None]:
df = load_data(BASE_PATH)
is_mask = df["image_path"].str.contains("mask")
df_imgs = df[~is_mask]
df_masks = df[is_mask]


def _slice_key(path):
    """Sort key ``(patient_dir, slice_number)`` parsed from a kaggle_3m file path.

    Reads the trailing ``_<n>`` (image) or ``_<n>_mask`` (mask) of the file
    stem instead of slicing at fixed character offsets, so it does not depend
    on the length of BASE_PATH or of the directory names.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith("_mask"):
        stem = stem[: -len("_mask")]
    return os.path.basename(os.path.dirname(path)), int(stem.rsplit("_", 1)[1])


imgs = sorted(df_imgs["image_path"].values, key=_slice_key)
masks = sorted(df_masks["image_path"].values, key=_slice_key)

# Every image must sit next to its own mask, or training silently uses wrong labels.
assert len(imgs) == len(masks), f"{len(imgs)} images vs {len(masks)} masks"
_mismatched = [(i, m) for i, m in zip(imgs, masks) if _slice_key(i) != _slice_key(m)]
assert not _mismatched, f"image/mask pairing broken, e.g. {_mismatched[:3]}"

In [None]:
# Derive each row's patient from the image's own directory. `df_imgs` is in
# directory-listing order while `imgs` was re-sorted, so pairing
# df_imgs.dir_name with imgs row-by-row would mislabel patients.
dff = pd.DataFrame({
    "patient": [os.path.basename(os.path.dirname(p)) for p in imgs],
    "image_path": imgs,
    "mask_path": masks,
})

In [None]:
def check_patient(mask_path):
    """Return 1 if the mask at ``mask_path`` has any non-zero pixel, else 0.

    The mask is read single-channel, matching the ``cv2.imread(path, 0)``
    mask reads used elsewhere in this notebook. The label then agrees with
    the ground truth the model is trained on.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns None instead of raising, which would otherwise surface
            as an obscure comparison error.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    return int(mask.max() > 0)


dff["diagnosis"] = dff["mask_path"].apply(check_patient)

In [None]:
def get_bounding_box(ground_truth_map, max_jitter=20, rng=None):
    """Return a jittered ``[x_min, y_min, x_max, y_max]`` box around a mask.

    The box first tightly encloses all pixels of ``ground_truth_map`` that are
    > 0. Each edge is then pushed outward by an independent random amount in
    ``[0, max_jitter)`` and clipped to the map (x to ``[0, W]``, y to
    ``[0, H]``), so the prompt is not pixel-perfect.

    Args:
        ground_truth_map: 2-D array indexed as ``[row, col]``.
        max_jitter: Exclusive upper bound on the per-edge padding. Use 0 for
            a deterministic tight box (e.g. for evaluation).
        rng: Optional ``numpy.random.Generator``. When None, the global
            ``np.random`` state is used, as before.

    Returns:
        List of four Python ints; ``[0, 0, 0, 0]`` when the mask is empty.
    """
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if len(x_indices) == 0:
        return [0, 0, 0, 0]

    x_min, x_max = int(x_indices.min()), int(x_indices.max())
    y_min, y_max = int(y_indices.min()), int(y_indices.max())

    def jitter():
        # randint(0, 0) raises, so a non-positive bound means "no jitter".
        if max_jitter <= 0:
            return 0
        if rng is None:
            return int(np.random.randint(0, max_jitter))
        return int(rng.integers(0, max_jitter))

    H, W = ground_truth_map.shape
    # Draw order (x_min, x_max, y_min, y_max) matches the original so seeded runs reproduce.
    x_min = max(0, x_min - jitter())
    x_max = min(W, x_max + jitter())
    y_min = max(0, y_min - jitter())
    y_max = min(H, y_max + jitter())
    return [x_min, y_min, x_max, y_max]

In [None]:
# The Hugging Face-style Sam2Processor / Sam2Model API used below is provided
# by transformers (installed from git above). The facebookresearch repo
# installs a `sam2` package and there is no `segment_anything_2` module.
from transformers import Sam2Config, Sam2Model, Sam2Processor

class SAM2Dataset(Dataset):
    """Torch dataset yielding SAM2 processor inputs plus a binary ground-truth mask.

    ``df`` must have ``image_path`` and ``mask_path`` columns. Each item is
    the processor's encoding of the image with a box prompt built from the
    mask by ``get_bounding_box``, plus ``ground_truth_mask`` as a {0, 1}
    uint8 array.
    """

    def __init__(self, df, processor):
        self.df = df
        self.processor = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # NOTE(review): cv2 loads BGR; evaluation also feeds BGR, so change
        # both together if switching to RGB.
        image = cv2.imread(row["image_path"])
        mask = cv2.imread(row["mask_path"], 0)
        if image is None or mask is None:
            raise FileNotFoundError(
                f"Could not read {row['image_path']!r} or {row['mask_path']!r}"
            )
        # Masks are stored as 0/255; DiceCELoss(sigmoid=True) needs {0, 1} targets.
        ground_truth_mask = (mask > 0).astype(np.uint8)
        prompt = get_bounding_box(ground_truth_mask)

        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        # Drop the batch dimension added by the processor; the DataLoader re-batches.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = ground_truth_mask

        return inputs

In [None]:
# NOTE(review): confirm this Hub id; published SAM2 checkpoints are
# size-specific (e.g. "facebook/sam2.1-hiera-tiny").
processor = Sam2Processor.from_pretrained("facebook/sam2")

SPLIT_SEED = 42  # fixes the train/val/test partition across kernel restarts

# 10% validation, then 12% of the remainder as test. Both splits are
# stratified on `diagnosis` so each keeps the same tumour / no-tumour ratio.
train_df, val_df = train_test_split(
    dff, stratify=dff.diagnosis, test_size=0.1, random_state=SPLIT_SEED
)
train_df, test_df = train_test_split(
    train_df, stratify=train_df.diagnosis, test_size=0.12, random_state=SPLIT_SEED
)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

train_dataset = SAM2Dataset(train_df, processor)
val_dataset = SAM2Dataset(val_df, processor)
test_dataset = SAM2Dataset(test_df, processor)

batch_size = 4
lr = 1e-4  # NOTE(review): the optimizer in the training cell passes its own lr

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# Evaluation does not benefit from shuffling; a fixed order keeps runs comparable.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

model = Sam2Model.from_pretrained("facebook/sam2")

# Only the mask decoder is fine-tuned; freeze the image and prompt encoders.
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

In [None]:
from torch.optim import Adam
import monai
from tqdm import tqdm
from statistics import mean

optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

num_epochs = 3
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def _batch_loss(batch):
    """Run one batch through `model` and return its Dice + cross-entropy loss."""
    outputs = model(
        pixel_values=batch["pixel_values"].to(device),
        input_boxes=batch["input_boxes"].to(device),
        multimask_output=False,
    )
    predicted_masks = outputs.pred_masks.squeeze(1)
    # Masks on disk are 0/255; binarize so the CE term sees {0, 1} targets.
    ground_truth_masks = (batch["ground_truth_mask"] > 0).float().to(device)
    return seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))


for epoch in range(num_epochs):
    model.train()
    epoch_losses = []
    for batch in tqdm(train_loader):
        loss = _batch_loss(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    # Held-out loss on val_loader, to spot over-fitting between epochs.
    model.eval()
    with torch.no_grad():
        val_losses = [_batch_loss(batch).item() for batch in val_loader]

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
    print(f'Mean val loss: {mean(val_losses)}')

path = "files/"
os.makedirs(path, exist_ok=True)
checkpoints_path = os.path.join(path, "sam2_model.pth")

torch.save(model.state_dict(), checkpoints_path)

model_config = Sam2Config.from_pretrained("facebook/sam2")
processor = Sam2Processor.from_pretrained("facebook/sam2")
my_mito_model = Sam2Model(config=model_config)
# map_location lets a GPU-saved checkpoint load on any device. weights_only
# refuses to unpickle arbitrary objects; a state_dict holds only tensors.
my_mito_model.load_state_dict(
    torch.load(checkpoints_path, map_location=device, weights_only=True)
)
my_mito_model.to(device)

def evaluate_and_visualize(image, mask, model, processor, device, threshold=0.5):
    """Predict a box-prompted SAM2 segmentation for one image.

    Despite its name this function does no plotting; it returns the arrays
    the caller displays.

    Args:
        image: Image array passed straight to ``processor``.
        mask: Ground-truth mask. Only used to build the box prompt with
            ``get_bounding_box``, whose prompt includes random jitter.
        model: SAM2 model. It is switched to eval mode as a side effect.
        processor: Processor matching ``model``.
        device: Device that ``model`` lives on.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        ``(seg_mask, seg_prob)``: a uint8 {0, 1} mask and the sigmoid
        probability map, both as squeezed numpy arrays.
    """
    box_prompt = get_bounding_box(mask)
    encoded = processor(image, input_boxes=[[box_prompt]], return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**encoded, multimask_output=False)

    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
    seg_mask = (seg_prob > threshold).astype(np.uint8)
    return seg_mask, seg_prob

num_samples = 3

# Show held-out test slices that contain a tumour. Drawing from the full
# `imgs` list mixes in training images, and tumour-free slices only get an
# empty [0, 0, 0, 0] box prompt.
candidates = test_df[test_df["diagnosis"] == 1]
if candidates.empty:
    candidates = test_df
samples = candidates.sample(n=min(num_samples, len(candidates)), random_state=0)

fig, axes = plt.subplots(len(samples), 4, figsize=(20, 5 * len(samples)), squeeze=False)

for row, (img_path, mask_path) in enumerate(zip(samples["image_path"], samples["mask_path"])):
    img_test = cv2.imread(img_path)
    mask_test = cv2.imread(mask_path, 0)

    seg_mask, seg_prob = evaluate_and_visualize(img_test, mask_test, my_mito_model, processor, device)

    panels = [
        # OpenCV loads BGR; convert for display only (the model still gets BGR).
        (cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB), "Image"),
        (mask_test, "Ground Truth Mask"),
        (seg_mask, "Segmentation Mask"),
        (seg_prob, "Probability Map"),
    ]
    for ax, (panel, title) in zip(axes[row], panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
plt.show()