In [2]:
import torch
import numpy as np
import cv2
import os
import pandas as pd
from datasets import Dataset


class CenterDataset(torch.utils.data.Dataset):
    """Grayscale frames, binary masks and a normalized (x, y) center point for one CSV split.

    The CSV is expected to have the columns ``FileName``, ``Split``, ``x`` and ``y``.
    Each item is a dict with:

    * ``'image'``    -- float32 tensor ``[1, H, W]`` with values in ``[0, 1]``
    * ``'filename'`` -- the whitespace-stripped ``FileName`` value
    * ``'mask'``     -- ``uint8`` ndarray ``[H, W]`` with values in ``{0, 1}``
    * ``'label'``    -- float32 tensor ``[x / coord_scale, y / coord_scale]``

    This is a map-style *PyTorch* dataset. It must not subclass Hugging Face
    ``datasets.Dataset``: that class defines an Arrow-backed ``__getitems__`` which
    ``torch.utils.data.DataLoader`` would call in preference to ``__getitem__``.
    """

    def __init__(self, csv_file, image_dir, split='TRAIN', transform=None,
                 mask_dir=None, coord_scale=112.0):
        """
        Args:
            csv_file: path to the file-list CSV.
            image_dir: directory containing ``<FileName>.png`` frames.
            split: value of the ``Split`` column to keep (e.g. ``'TRAIN'``, ``'VAL'``).
            transform: stored on the instance but NOT applied by ``__getitem__``.
            mask_dir: directory containing ``<FileName>.png`` masks. When ``None``,
                it is derived from ``image_dir`` by replacing every ``'frames'``
                substring with ``'mask'`` (the historical behaviour).
            coord_scale: divisor applied to the CSV ``x``/``y`` coordinates.
        """
        rows = pd.read_csv(csv_file)
        rows = rows[rows['Split'] == split].reset_index(drop=True)
        rows['FileName'] = rows['FileName'].astype(str).str.strip()
        self._data = rows
        self.image_dir: str = image_dir
        self.mask_dir = mask_dir
        self.coord_scale = coord_scale
        # NOTE(review): kept for API compatibility; images are normalized manually below.
        self.transform = transform

    def __len__(self):
        return len(self._data)

    @staticmethod
    def _read_grayscale(path, what):
        """Load *path* as a single-channel uint8 array or raise FileNotFoundError."""
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if array is None:  # cv2 signals unreadable/missing files by returning None
            raise FileNotFoundError(f"{what} not found: {path}")
        return array

    def __getitem__(self, idx):
        row = self._data.iloc[idx]
        filename = row['FileName']
        # str.replace rewrites *every* 'frames' occurrence (parent directories included);
        # pass mask_dir explicitly when the image path contains 'frames' elsewhere.
        mask_dir = self.mask_dir if self.mask_dir is not None else self.image_dir.replace('frames', 'mask')
        img_path = os.path.join(self.image_dir, filename + '.png')
        mask_path = os.path.join(mask_dir, filename + '.png')

        image = self._read_grayscale(img_path, "Image")
        image = torch.from_numpy(image.astype(np.float32) / 255.0).unsqueeze(0)  # [1, H, W]

        mask = (self._read_grayscale(mask_path, "Mask") > 0).astype(np.uint8)

        label = torch.tensor([float(row['x']), float(row['y'])], dtype=torch.float32)
        label /= self.coord_scale

        return {
            'image': image,
            'filename': filename,
            'mask': mask,
            "label": label
        }


In [3]:
from torchvision import transforms
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from segment_anything import sam_model_registry
from tqdm import tqdm

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between mask logits and a binary target.

    Args:
        pred: raw logits; a sigmoid is applied here.
        target: ground-truth mask with the same shape as ``pred``.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - (2*|P∩T| + eps) / (|P| + |T| + eps)``, summed over
        every element (the whole batch is treated as one region).
    """
    probs = torch.sigmoid(pred)
    overlap = torch.sum(probs * target)
    total = torch.sum(probs) + torch.sum(target)
    dice = (2 * overlap + eps) / (total + eps)
    return 1 - dice

def _prepare_sam_input(sam, image, point_xy, device):
    """Map a ``[C, H, W]`` image in ``[0, 1]`` and an (x, y) pixel point into SAM's input frame.

    SAM's encoder expects a 3-channel, 0-255 image whose longest side equals
    ``sam.image_encoder.img_size``; ``sam.preprocess`` then normalizes and pads it
    to a square. Prompt points must be expressed in the resized frame.

    Returns:
        ``(input_image [1, 3, S, S], point_coords [1, 1, 2], input_size (new_h, new_w))``
    """
    target = sam.image_encoder.img_size
    h, w = image.shape[-2:]
    scale = target / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    x = F.interpolate(image.unsqueeze(0).to(device), size=(new_h, new_w),
                      mode='bilinear', align_corners=False)
    if x.shape[1] == 1:
        x = x.expand(-1, 3, -1, -1)  # grayscale -> 3 channels for the RGB-trained encoder
    x = sam.preprocess(x * 255.0)  # SAM's pixel mean/std are on the 0-255 scale
    coords = point_xy.to(device) * torch.tensor([new_w / w, new_h / h], device=device)
    return x, coords.view(1, 1, 2), (new_h, new_w)


def _predict_logits(sam, batch, device, coord_scale):
    """Run SAM on the first sample of *batch*, prompted by its center point.

    Returns ``(mask_logits [1, 1, H, W], gt_mask [1, 1, H, W])`` at the original
    image resolution.
    """
    image = batch['image'][0]  # [1, H, W], values in [0, 1]
    original_size = tuple(image.shape[-2:])
    point_xy = batch['label'][0] * coord_scale  # undo the dataset's coordinate normalization
    input_image, point_coords, input_size = _prepare_sam_input(sam, image, point_xy, device)
    point_labels = torch.tensor([[1]], device=device)  # one positive (foreground) point

    with torch.no_grad():  # encoder is frozen: don't build its autograd graph
        image_embedding = sam.image_encoder(input_image)

    sparse_embeddings, dense_embeddings = sam.prompt_encoder(
        points=(point_coords, point_labels),
        boxes=None,
        masks=None
    )
    low_res_masks, _ = sam.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False
    )
    # Upscale to the padded square, crop the padding, then resize to the original image.
    logits = sam.postprocess_masks(low_res_masks, input_size, original_size)
    gt_mask = batch['mask'][:1].to(device).float().unsqueeze(1)
    return logits, gt_mask


def finetune_sam(csv_file="filelist_frames_dataset.csv", image_dir=None,
                 checkpoint="sam_vit_h.pth", model_type="vit_h", epochs=5, lr=1e-4,
                 coord_scale=112.0, output_path='finetuned_sam.pth', device=None,
                 validate=True):
    """Fine-tune SAM's mask decoder on center-point prompts and save the weights.

    The image and prompt encoders are frozen; only ``sam.mask_decoder`` is optimized
    with Adam on the Dice loss. If ``validate`` is set, the mean VAL-split Dice loss
    is reported after each epoch.

    Args:
        csv_file: file-list CSV understood by ``CenterDataset``.
        image_dir: frames directory; defaults to ``$ECHONET_FRAMES_DIR`` or the
            original machine-specific path.
        checkpoint, model_type: passed to ``sam_model_registry``.
        epochs, lr: training schedule.
        coord_scale: divisor the dataset applied to the (x, y) labels.
        output_path: where ``sam.state_dict()`` is saved.
        device: torch device; defaults to CUDA when available, else CPU.
        validate: run a no-grad pass over the VAL split after every epoch.
    """
    if image_dir is None:
        # NOTE(review): machine-specific fallback; set ECHONET_FRAMES_DIR or pass image_dir.
        image_dir = os.environ.get(
            "ECHONET_FRAMES_DIR",
            "/Users/bennosteinegger/.cache/kagglehub/datasets/foghorn/echonet-frames-masks-dataset/versions/1/Echonet-Frames-Masks-Dataset/",
        )
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device)

    # Only the mask decoder is in the optimizer. Freezing the prompt encoder too
    # stops its never-zeroed gradients from accumulating across steps.
    for module in (sam.image_encoder, sam.prompt_encoder):
        for param in module.parameters():
            param.requires_grad = False

    train_dataset = CenterDataset(csv_file=csv_file, image_dir=image_dir, split="TRAIN")
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = None
    if validate:
        val_dataset = CenterDataset(csv_file=csv_file, image_dir=image_dir, split="VAL")
        if len(val_dataset) > 0:
            val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    optimizer = torch.optim.Adam(sam.mask_decoder.parameters(), lr=lr)

    for epoch in range(epochs):
        sam.train()
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch in loop:
            pred_mask, gt_mask = _predict_logits(sam, batch, device, coord_scale)
            loss = dice_loss(pred_mask, gt_mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loop.set_postfix(loss=loss.item())

        if val_loader is not None:
            sam.eval()
            with torch.no_grad():
                val_losses = [dice_loss(*_predict_logits(sam, batch, device, coord_scale)).item()
                              for batch in val_loader]
            tqdm.write(f"Epoch {epoch+1}: val dice loss {sum(val_losses) / len(val_losses):.4f}")

    torch.save(sam.state_dict(), output_path)

# Run the fine-tuning defined above; this trains and saves weights to disk, so it is slow.
# Requires the third-party `segment_anything` package — the saved output of this cell
# shows it was not installed in the kernel when the notebook was last run.
finetune_sam()


ModuleNotFoundError: No module named 'segment_anything'