In [None]:
# Clone finetune-SAM into /kaggle/working/finetune-SAM (the path every later cell uses).
# Skipped when the checkout already exists: a plain `git clone` fails if the destination
# folder is present, which would break Restart & Run All.
![ -d /kaggle/working/finetune-SAM/.git ] || git clone https://github.com/mazurowski-lab/finetune-SAM.git /kaggle/working/finetune-SAM

In [1]:
# Upgrade pip, then install torch 2.3.0 with the matching torchvision 0.18.0 / torchaudio 2.3.0,
# using the CUDA 11.8 (cu118) wheels from the PyTorch package index.
!pip install --upgrade pip
!pip install --index-url https://download.pytorch.org/whl/cu118 \
  torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0

Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.2
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch==2.3.0
  Downloading https://download.pytorch.org/whl/cu118/torch-2.3.0%2Bcu118-cp311-cp311-linux_x86_64.whl (839.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m839.7/839.7 MB[0m [31m43.9 MB/s[0m  [33m0:00:07[0m:00:01[0m00:01[0m
[?25hCollecting torchvision==0.18.0
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.18.0%2Bcu118-cp311-cp311-linux_x86_64.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
%%writefile /kaggle/working/requirements.txt

# Python dependencies for finetune-SAM, installed by the `pip install -r` cell below.
# torch / torchvision / torchaudio are intentionally not listed: the earlier pip cell
# installs them pinned to 2.3.0 / 0.18.0 / 2.3.0 (CUDA 11.8 wheels).
# Version-pinned here: monai, numpy, segment-anything, timm, torchio; the rest float.
asttokens
certifi
charset-normalizer
click
colorama
contourpy
cycler
Deprecated
einops
executing
filelock
fonttools
fsspec
huggingface-hub
humanize
icecream
idna
imageio
Jinja2
kiwisolver
markdown-it-py
MarkupSafe
matplotlib
mdurl
monai==1.3.1
mpmath
networkx
nibabel
nptyping
numpy==1.26.4
opencv-python-headless
packaging
pandas
pillow
Pygments
pynrrd
pyparsing
python-dateutil
pytz
PyWavelets
PyYAML
requests
rich
safetensors
scikit-image
scipy
seaborn
segment-anything==1.0
shellingham
SimpleITK
six
slicerio
sympy
tifffile
timm==1.0.3
torchio==0.19.6
tqdm
typer
typing_extensions
tzdata
urllib3
wrapt
tensorboardX

In [2]:
# Install the packages listed in /kaggle/working/requirements.txt (written by the previous cell).
!pip install -r /kaggle/working/requirements.txt

Collecting executing (from -r /kaggle/working/requirements.txt (line 11))
  Downloading executing-2.2.0-py2.py3-none-any.whl.metadata (8.9 kB)
Collecting icecream (from -r /kaggle/working/requirements.txt (line 17))
  Downloading icecream-2.1.7-py3-none-any.whl.metadata (1.5 kB)
Collecting monai==1.3.1 (from -r /kaggle/working/requirements.txt (line 26))
  Downloading monai-1.3.1-py3-none-any.whl.metadata (10 kB)
Collecting nptyping (from -r /kaggle/working/requirements.txt (line 30))
  Downloading nptyping-2.5.0-py3-none-any.whl.metadata (7.6 kB)
Collecting pynrrd (from -r /kaggle/working/requirements.txt (line 37))
  Downloading pynrrd-1.1.3-py3-none-any.whl.metadata (5.4 kB)
Collecting slicerio (from -r /kaggle/working/requirements.txt (line 53))
  Downloading slicerio-1.1.2-py3-none-any.whl.metadata (12 kB)
Collecting timm==1.0.3 (from -r /kaggle/working/requirements.txt (line 56))
  Downloading timm-1.0.3-py3-none-any.whl.metadata (43 kB)
Collecting torchio==0.19.6 (from -r /kaggl

In [3]:
# Sanity check: report the installed torch / CUDA build and fail fast if no GPU is visible.
import torch, torchvision, triton  # torchvision and triton are imported here only to check that they load

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
print("torchvision", torchvision.__version__)
if not torch.cuda.is_available():
    # get_device_name(0) would otherwise fail with a much less helpful error.
    raise RuntimeError("No CUDA device is visible; enable a GPU accelerator before training.")
print(torch.cuda.get_device_name(0))

2.3.0+cu118 11.8 True
Tesla T4


In [None]:
# Download the ViT-B SAM checkpoint. Both steps are idempotent so the cell is safe to re-run:
# `mkdir -p` does not fail when the folder exists, and the download is skipped when a
# non-empty checkpoint file is already present (`-s`: exists and size > 0).
!mkdir -p /kaggle/working/sam_vit_b_weights

![ -s /kaggle/working/sam_vit_b_weights/sam_vit_b_01ec64.pth ] || wget -O /kaggle/working/sam_vit_b_weights/sam_vit_b_01ec64.pth \
  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
# Make sure the repository directory exists. `mkdir -p` is a no-op when the directory is
# already there (e.g. after the git clone above), so this cell is safe to re-run.
!mkdir -p /kaggle/working/finetune-SAM/

In [None]:
%%writefile /kaggle/working/finetune-SAM/models/sam/modeling/image_encoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math

from typing import Optional, Tuple, Type

from .common import LayerNorm2d, MLPBlock, Adapter



#class PromptEncoderViT

# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """SAM ViT image encoder (lightly adapted from the ViTDet backbone).

    Pipeline: ``PatchEmbed`` -> optional learned absolute position embedding ->
    ``depth`` transformer ``Block``s -> convolutional neck with ``out_chans`` outputs.
    ``args`` is stored on the encoder and handed to every ``Block``, which reads its
    adapter / 3D-branch options from it.
    """

    def __init__(
        self,
        args,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            args: run-configuration namespace; passed unchanged to every Block.
            img_size (int): Input image size (also sizes the absolute position embedding).
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Number of transformer blocks.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of output channels of the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes of blocks that use global attention.
        """
        super().__init__()
        self.img_size = img_size
        self.in_chans = in_chans
        self.args = args
        self.depth = depth

        # Non-overlapping patch_size x patch_size patches projected to embed_dim channels.
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        tokens_per_side = img_size // patch_size
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Learned absolute position embedding at the pretraining resolution, channels last.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, tokens_per_side, tokens_per_side, embed_dim, dtype=torch.float)
            )

        # Blocks whose index is in global_attn_indexes get window_size 0 (global attention);
        # every other block uses the configured window_size.
        self.blocks = nn.ModuleList(
            Block(
                args=self.args,
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                depth=block_idx,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=0 if block_idx in global_attn_indexes else window_size,
                input_size=(tokens_per_side, tokens_per_side),
            )
            for block_idx in range(depth)
        )

        # 1x1 projection to out_chans followed by a 3x3 refinement, each with LayerNorm2d.
        self.neck = nn.Sequential(
            nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
            LayerNorm2d(out_chans),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images into a channels-first feature map produced by the neck."""
        tokens = self.patch_embed(x)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed

        for blk in self.blocks:
            tokens = blk(tokens)

        # Blocks operate channels-last (B, H, W, C); the conv neck expects channels-first.
        return self.neck(tokens.permute(0, 3, 1, 2))


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks.

    Options read from ``args``:
      * ``if_encoder_adapter`` / ``encoder_adapter_depths``: when adapters are enabled
        for this block's ``depth``, ``Space_Adapter`` is applied to the attention output,
        ``MLP_Adapter`` runs in parallel to the MLP (weighted by ``scale``), and
        ``Depth_Adapter`` is created for the 3D branch.
      * ``thd`` / ``chunk``: when ``thd`` is set, an extra attention pass over the
        ``chunk`` slices folded into the batch dimension is added to the spatial
        attention output. This branch needs ``Depth_Adapter``, i.e. adapters enabled
        at this depth.
    """

    def __init__(
        self,
        args,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        scale: float = 0.5,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        depth = 1,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            args: run-configuration namespace (see class docstring for the fields used).
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            scale (float): Weight of the MLP-adapter branch (only used when adapters are enabled).
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            depth: the index of this block within the encoder
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.args = args
        self.norm1 = norm_layer(dim)
        self.depth = depth
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )
        if self._adapter_enabled():
            self.MLP_Adapter = Adapter(dim, skip_connect=False)  # MLP-adapter, no skip connection
            self.Space_Adapter = Adapter(dim)  # with skip connection
            self.scale = scale
            self.Depth_Adapter = Adapter(dim, skip_connect=False)  # no skip connection

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def _adapter_enabled(self) -> bool:
        """Return True when ``args`` enables encoder adapters for this block's depth."""
        return bool(self.args.if_encoder_adapter) and (self.depth in self.args.encoder_adapter_depths)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_adapter = self._adapter_enabled()
        shortcut = x
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        ## 3d branch: attention across the `chunk` slices folded into the batch dimension.
        # No per-call logging here: this runs for every block on every forward pass.
        if self.args.thd:
            if not hasattr(self, "Depth_Adapter"):
                raise AttributeError(
                    f"args.thd requires encoder adapters on block {self.depth}: enable "
                    "if_encoder_adapter and include this depth in encoder_adapter_depths"
                )
            hh, ww = x.shape[1], x.shape[2]
            depth = self.args.chunk
            xd = rearrange(x, '(b d) h w c -> (b h w) d c ', d=depth)
            xd = self.norm1(xd)
            # Lay the depth axis out as a dh x dw grid so the 2D attention module can be reused.
            dh, _ = closest_numbers(depth)
            xd = rearrange(xd, 'bhw (dh dw) c -> bhw dh dw c', dh= dh)
            xd = self.Depth_Adapter(self.attn(xd))
            xd = rearrange(xd, '(b n) dh dw c ->(b dh dw) n c', n= hh * ww )

        x = self.norm1(x)
        x = self.attn(x)
        if use_adapter:
            x = self.Space_Adapter(x)

        if self.args.thd:
            xd = rearrange(xd, 'b (hh ww) c -> b  hh ww c', hh= hh )
            x = x + xd
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        xn = self.norm2(x)
        if use_adapter:
            x = x + self.mlp(xn) + self.scale * self.MLP_Adapter(xn)
        else:
            x = x + self.mlp(xn)
        return x


class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) token grid, with optional
    decomposed relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): Accepted for interface compatibility; the relative
                position tables are always created as zeros.
            input_size (tuple(int, int) or None): Token-grid size used to size the relative
                position tables; required when ``use_rel_pos`` is True.
        """
        super().__init__()
        per_head = dim // num_heads
        self.num_heads = num_heads
        self.scale = per_head**-0.5

        # One fused projection produces q, k and v.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if not self.use_rel_pos:
            return

        assert (
            input_size is not None
        ), "Input size must be provided if using relative positional encoding."
        # One entry per possible offset along each axis: 2 * size - 1.
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, per_head))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, per_head))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, height, width, _ = x.shape
        tokens = height * width
        heads = self.num_heads

        # (B, H, W, 3*C) -> (3, B, heads, H*W, C/heads)
        qkv = self.qkv(x).reshape(batch, tokens, 3, heads, -1).permute(2, 0, 3, 1, 4)
        # Fold heads into the batch dim: each of q, k, v is (B*heads, H*W, C/heads).
        q, k, v = qkv.reshape(3, batch * heads, tokens, -1).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            scores = add_decomposed_rel_pos(
                scores, q, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
        weights = scores.softmax(dim=-1)

        mixed = weights @ v
        # Unfold the heads and restore the (B, H, W, C) grid before the output projection.
        out = (
            mixed.view(batch, heads, height, width, -1)
            .permute(0, 2, 3, 1, 4)
            .reshape(batch, height, width, -1)
        )
        return self.proj(out)


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a (B, H, W, C) token grid into non-overlapping square windows.

    H and W are zero-padded at the bottom/right up to the next multiple of
    ``window_size`` first. Windows are ordered row-major within each image.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window side length.

    Returns:
        windows: [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    batch, height, width, channels = x.shape

    # Amount needed to reach the next multiple of window_size (0 if already aligned).
    pad_h = -height % window_size
    pad_w = -width % window_size
    if pad_h or pad_w:
        # F.pad pads from the last dim backwards: C untouched, then W, then H.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    padded_h, padded_w = height + pad_h, width + pad_w

    grid = x.view(
        batch, padded_h // window_size, window_size, padded_w // window_size, window_size, channels
    )
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Inverse of window partitioning: reassemble windows into a grid and crop the padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window side length.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw

    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch, padded_h // window_size, padded_w // window_size, window_size, window_size, -1
    )
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    # Drop the bottom/right padding added during partitioning, if any.
    if padded_h > height or padded_w > width:
        x = x[:, :height, :width, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) position pair.

    If the table does not have exactly ``2 * max(q_size, k_size) - 1`` rows, it is
    first linearly resampled to that length.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C) with the embedding for each pair.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # Treat the table as a C-channel 1D signal of length L and resample it.
        as_signal = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        resampled = F.interpolate(as_signal, size=max_rel_dist, mode="linear")
        table = resampled.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k lengths differ, stretch the shorter axis so offsets stay comparable.
    q_stretch = max(k_size / q_size, 1.0)
    k_stretch = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_stretch
    k_coords = torch.arange(k_size)[None, :] * k_stretch
    # Shift so the most negative offset maps to index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_stretch

    return table[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Add decomposed (height + width) relative positional biases to an attention map,
    as in :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention map, viewable as (B, q_h, q_w, k_h, k_w).
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map (B, q_h * q_w, k_h * k_w) with the biases added.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, channels = q.shape
    q_grid = q.reshape(batch, q_h, q_w, channels)
    # Per-query bias for each key row (bias_h) and each key column (bias_w).
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w)

    attn_5d = attn.view(batch, q_h, q_w, k_h, k_w)
    # Broadcast bias_h over key columns and bias_w over key rows.
    attn_5d = attn_5d + bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)
    return attn_5d.view(batch, q_h * q_w, k_h * k_w)

def closest_numbers(target):
    """Return the factor pair ``(a, b)`` of ``target`` with ``a <= b`` that is closest to square.

    Used by the 3D branch to lay a depth of ``target`` slices out as an ``a x b`` grid.
    Examples: 96 -> (8, 12), 12 -> (3, 4), 16 -> (4, 4), 7 -> (1, 7), 1 -> (1, 1).

    The previous two-pointer search started ``b`` at ``a + 1``, so it looped forever for
    ``target == 1`` and never returned ``(a, a)`` for perfect squares (16 gave (2, 8)).

    Raises:
        ValueError: if ``target`` is not a positive integer.
    """
    n = int(target)
    if n != target or n < 1:
        raise ValueError(f"closest_numbers expects a positive integer, got {target!r}")
    # The largest divisor not exceeding sqrt(n) gives the most square pair; 1 always divides.
    for a in range(math.isqrt(n), 0, -1):
        if n % a == 0:
            return (a, n // a)


class PatchEmbed(nn.Module):
    """
    Image to patch embedding: a strided Conv2d whose output is returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection conv.
            stride (Tuple): stride of the projection conv.
            padding (Tuple): padding of the projection conv.
            in_chans (int): Number of input image channels.
            embed_dim (int): Number of output (embedding) channels.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project (B, C, H, W) images and return (B, H', W', embed_dim) tokens."""
        embedded = self.proj(x)
        # Channels-first conv output -> channels-last token grid.
        return embedded.permute(0, 2, 3, 1)


In [3]:
%%writefile /kaggle/working/finetune-SAM/cfg.py

import argparse

_TRUE_STRINGS = {"true", "t", "yes", "y", "1"}
_FALSE_STRINGS = {"false", "f", "no", "n", "0"}


def _str2bool(value):
    """Parse a command-line boolean (true/false, t/f, yes/no, y/n, 1/0; case-insensitive).

    ``type=bool`` is a trap in argparse: it calls ``bool(string)``, which is True for
    every non-empty string, so ``-if_encoder_adapter False`` used to switch the flag ON.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _split_list(value):
    """Split "[0,1,10]", "0,1,10" or "0 1 10" into string items ("" or "[]" -> [])."""
    text = str(value).strip().strip("[]()")
    return text.replace(",", " ").split()


def _int_list(value):
    """Parse a list of ints (``type=list`` split the argument into single characters)."""
    return [int(item) for item in _split_list(value)]


def _float_list(value):
    """Parse a list of floats (``type=list`` split the argument into single characters)."""
    return [float(item) for item in _split_list(value)]


def parse_args(argv=None):
    """Build the finetune-SAM command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list (e.g. ``[]``) builds the
            configuration programmatically, e.g. from a notebook.

    Returns:
        argparse.Namespace with one attribute per option: the leading ``-`` is stripped
        and inner dashes become underscores (``-encoder-adapter-depths`` ->
        ``encoder_adapter_depths``).

    Boolean options accept true/false, yes/no, 1/0 (case-insensitive). List options
    accept "0,1,10,11", "[0,1,10,11]" or "0 1 10 11".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, default='sam', help='net type')
    parser.add_argument('-arch', type=str, default='vit_b', help='net architecture, pick between vit_h, vit_b, vit_t')
    parser.add_argument('-baseline', type=str, default='unet', help='baseline net type')
    parser.add_argument('-dataset_name', type=str, default='MRI-Prostate', help='the name of dataset to be finetuned')

    parser.add_argument('-img_folder', type=str, default='./datasets/', help='the folder putting images')
    parser.add_argument('-mask_folder', type=str, default='./datasets/', help='the folder putting masks')
    parser.add_argument('-train_img_list', type=str, default='./datasets/train.csv')
    parser.add_argument('-val_img_list', type=str, default='./datasets/val.csv')
    parser.add_argument('-test_img_list', type=str, default='./datasets/train.csv')
    parser.add_argument('-targets', type=str, default='combine_all')

    parser.add_argument('-finetune_type', type=str, default='adapter', help='finetune type, pick among vanilla,adapter,lora')
    parser.add_argument('-normalize_type', type=str, default='sam', help='normalization type, pick between sam or medsam')

    parser.add_argument('-dir_checkpoint', type=str, default='checkpoints', help='the checkpoint folder to save final model')
    parser.add_argument('-num_cls', type=int, default=2, help='the number of output channels (need to be your target cls num +1)')
    parser.add_argument('-epochs', type=int, default=200, help='the number of largest epochs to train')
    parser.add_argument('-sam_ckpt', type=str, default='sam_vit_b_01ec64.pth', help='the path to the checkpoint to load')

    parser.add_argument('-type', type=str, default='map', help='condition type:ave,rand,rand_map')
    parser.add_argument('-vis', type=int, default=None, help='visualization')
    parser.add_argument('-reverse', type=_str2bool, default=False, help='adversary reverse')
    parser.add_argument('-pretrain', type=_str2bool, default=False, help='adversary reverse')
    parser.add_argument('-val_freq', type=int, default=100, help='interval between each validation')
    parser.add_argument('-gpu', type=_str2bool, default=True, help='use gpu or not')
    parser.add_argument('-gpu_device', type=int, default=0, help='use which gpu')
    parser.add_argument('-sim_gpu', type=int, default=0, help='split sim to this gpu')
    parser.add_argument('-epoch_ini', type=int, default=1, help='start epoch')
    parser.add_argument('-image_size', type=int, default=1024, help='image_size')
    parser.add_argument('-out_size', type=int, default=256, help='output_size')
    parser.add_argument('-patch_size', type=int, default=2, help='patch_size')
    parser.add_argument('-dim', type=int, default=512, help='dim_size')
    parser.add_argument('-depth', type=int, default=64, help='depth')
    parser.add_argument('-heads', type=int, default=16, help='heads number')
    parser.add_argument('-mlp_dim', type=int, default=1024, help='mlp_dim')
    parser.add_argument('-w', type=int, default=4, help='number of workers for dataloader')
    parser.add_argument('-b', type=int, default=4, help='batch size for dataloader')
    parser.add_argument('-s', type=_str2bool, default=True, help='whether shuffle the dataset')
    parser.add_argument('-if_warmup', type=_str2bool, default=False, help='if warm up training phase')
    parser.add_argument('-warmup_period', type=int, default=200, help='warm up training phase')
    parser.add_argument('-lr', type=float, default=1e-3, help='initial learning rate')
    parser.add_argument('-uinch', type=int, default=1, help='input channel of unet')
    parser.add_argument('-imp_lr', type=float, default=3e-4, help='implicit learning rate')
    parser.add_argument('-weights', type=str, default=0, help='the weights file you want to test')
    parser.add_argument('-base_weights', type=str, default=0, help='the weights baseline')
    parser.add_argument('-sim_weights', type=str, default=0, help='the weights sim')
    parser.add_argument('-distributed', default='none', type=str, help='multi GPU ids to use')
    parser.add_argument('-dataset', default='isic', type=str, help='dataset name')
    parser.add_argument('-thd', type=_str2bool, default=False, help='3d or not')
    parser.add_argument('-chunk', type=int, default=96, help='crop volume depth')
    parser.add_argument('-num_sample', type=int, default=4, help='sample pos and neg')
    parser.add_argument('-roi_size', type=int, default=96, help='resolution of roi')

    parser.add_argument('-if_update_encoder', type=_str2bool, default=False, help='if update_image_encoder')
    parser.add_argument('-if_encoder_adapter', type=_str2bool, default=False, help='if add adapter to encoder')

    parser.add_argument('-encoder-adapter-depths', type=_int_list, default=[0, 1, 10, 11], help='the depth of blocks to add adapter')
    parser.add_argument('-if_mask_decoder_adapter', type=_str2bool, default=False, help='if add adapter to mask decoder')
    parser.add_argument('-decoder_adapt_depth', type=int, default=2, help='the depth of the decoder adapter')

    parser.add_argument('-if_encoder_lora_layer', type=_str2bool, default=False, help='if add lora to encoder')
    parser.add_argument('-if_decoder_lora_layer', type=_str2bool, default=False, help='if add lora to decoder')
    parser.add_argument('-encoder_lora_layer', type=_int_list, default=[0, 1, 10, 11], help='the depth of blocks to add lora, if [], it will add at each layer')

    parser.add_argument('-if_split_encoder_gpus', type=_str2bool, default=False, help='if split encoder to multiple gpus')
    parser.add_argument('-devices', type=_int_list, default=[0, 1], help='GPU ids used when splitting the encoder across gpus')
    parser.add_argument('-gpu_fractions', type=_float_list, default=[0.5, 0.5], help='how to split encoder to multiple gpus')

    parser.add_argument('-evl_chunk', type=int, default=None, help='evaluation chunk')
    return parser.parse_args(argv)

Overwriting /kaggle/working/finetune-SAM/cfg.py


In [None]:
%%writefile /kaggle/working/finetune-SAM/DDP_splitgpu_train_finetune_noprompt.py

#from segment_anything import SamPredictor, sam_model_registry
from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from skimage.measure import label
from models.sam_LoRa import LoRA_Sam
#Scientific computing 
import numpy as np
import os
#Pytorch packages
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from tensorboardX import SummaryWriter
#Visulization
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
#Others
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import matplotlib.pyplot as plt
import copy
from utils.dataset import Public_dataset
import torch.nn.functional as F
from torch.nn.functional import one_hot
from pathlib import Path
from tqdm import tqdm
from utils.losses import DiceLoss
from utils.dsc import dice_coeff
import cv2
import monai
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from utils.utils import vis_image
import cfg
# Parsed once, at import time, from the process command line (cfg.parse_args);
# `setup` below passes this module-level namespace on to the training function.
args = cfg.parse_args()

def cleanup():
    """Tear down the default torch.distributed process group, if one is active.

    Guarded so it is safe to call from error-handling paths where
    ``init_process_group`` never ran; ``destroy_process_group`` raises in that case.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

"""
def setup(rank, world_size, model_basic, trainloader, valloader,dir_checkpoint, backend='nccl'): 
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12333'
    # initialize the process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    # Give the DataLoaders the samplers so they serve unique data slices
    trainloader.sampler.set_epoch(0) # You can set this to the current epoch in the training loop
    valloader.sampler.set_epoch(0)
    model_basic(args,rank, world_size,trainloader,valloader,dir_checkpoint)
"""
def setup(rank, world_size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint, backend='nccl'):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12333'
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    # Pass the datasets down to the training function
    model_basic_fn(args, rank, world_size, train_dataset, eval_dataset, dir_checkpoint)

                    
def model_basic(args,rank, world_size,trainloader,valloader,dir_checkpoint):
    """Fine-tune SAM (vit_b) without prompts, each DDP rank spread over two GPUs.

    Rank ``r`` sends images to GPU ``2*r`` and masks/predictions to GPU
    ``2*r + 1``; ``args.devices`` is set to that pair before the model is built
    (presumably the project's ``sam_model_registry`` places sub-modules from it
    -- TODO confirm). ``world_size`` must therefore be at most
    ``torch.cuda.device_count() // 2``.

    Args:
        args: parsed ``cfg`` namespace (reads b, lr, if_warmup, warmup_period,
            epochs, out_size, sam_ckpt, finetune_type, if_update_encoder).
        rank: process rank.
        world_size: number of DDP processes.
        trainloader, valloader: ready-made DataLoaders, or raw datasets (which is
            what ``setup`` forwards); datasets are wrapped here in
            DistributedSampler-backed loaders so each rank gets its own shard.
        dir_checkpoint: directory receiving ``log/`` and ``checkpoint_best.pth``.
    """
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    args.devices = [dev0,dev1]

    # setup() hands over datasets, not loaders: batch and shard them per rank.
    if not isinstance(trainloader, DataLoader):
        train_sampler = DistributedSampler(trainloader, num_replicas=world_size, rank=rank, shuffle=True)
        trainloader = DataLoader(trainloader, batch_size=args.b, shuffle=False, num_workers=0, sampler=train_sampler, pin_memory=True)
    if not isinstance(valloader, DataLoader):
        val_sampler = DistributedSampler(valloader, num_replicas=world_size, rank=rank, shuffle=False)
        valloader = DataLoader(valloader, batch_size=args.b, shuffle=False, num_workers=0, sampler=val_sampler, pin_memory=True)

    b_lr = args.lr / args.warmup_period if args.if_warmup else args.lr
    # Logged every iteration; initialised so logging works when warmup is off
    # (the schedule below only assigns it when args.if_warmup is set).
    lr_ = b_lr
    epochs = args.epochs
    iter_num = 0
    max_iterations = epochs * len(trainloader)
    writer = None

    def _forward(data):
        """Image encoder -> empty prompt encoder -> mask decoder; returns (pred, msks)."""
        imgs = data['image'].to(dev0)
        msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask'])
        msks = msks.to(dev1)  # output will be in device 1
        img_emb = ddp_model.module.image_encoder(imgs)
        sparse_emb, dense_emb = ddp_model.module.prompt_encoder(points=None, boxes=None, masks=None)
        pred, _ = ddp_model.module.mask_decoder(
            image_embeddings=img_emb,
            image_pe=ddp_model.module.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
        return pred, msks

    def _average_gradients():
        """Average gradients across ranks.

        _forward calls ddp_model.module.<submodule> directly, bypassing
        DistributedDataParallel.forward, so DDP's reducer does not all-reduce
        gradients. Averaging by hand keeps the replicas identical (and is a no-op
        if the gradients were already synchronised).
        """
        if world_size < 2:
            return
        for p in ddp_model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad.div_(world_size)

    try:
        # Only rank 0 logs, so ranks don't write event files into one directory.
        if rank == 0:
            writer = SummaryWriter(dir_checkpoint + '/log')

        print(f"Running basic DDP example on rank {rank}.")
        model = sam_model_registry["vit_b"](args,checkpoint=args.sam_ckpt,num_classes=2)

        if args.finetune_type == 'adapter':
            for n, value in model.named_parameters():
                if "Adapter" not in n: # only update parameters in adapter
                    value.requires_grad = False
        elif args.finetune_type == 'vanilla' and args.if_update_encoder==False:
            for n, value in model.image_encoder.named_parameters():
                value.requires_grad = False
        elif args.finetune_type == 'lora':
            model = LoRA_Sam(args,model,r=4).sam

        # No device_ids: the module is meant to span two devices.
        ddp_model = DDP(model)

        optimizer = optim.AdamW(ddp_model.parameters(), lr=b_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.1, amsgrad=False)
        optimizer.zero_grad()
        criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True,reduction='mean')
        criterion2 = nn.CrossEntropyLoss()

        pbar = tqdm(range(epochs)) if rank == 0 else range(epochs)
        should_stop = torch.tensor(0, device=dev1)  # early-stop flag shared via broadcast
        val_largest_dsc = 0
        last_update_epoch = 0

        for epoch in pbar:
            if isinstance(trainloader.sampler, DistributedSampler):
                trainloader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
            ddp_model.train()
            train_loss = 0
            for data in trainloader:
                pred, msks = _forward(data)
                loss_dice = criterion1(pred,msks.float())
                loss_ce = criterion2(pred,torch.squeeze(msks.long(),1))
                loss = loss_dice + loss_ce

                loss.backward()
                _average_gradients()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if args.if_warmup and iter_num < args.warmup_period:
                    lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_
                elif args.if_warmup:
                    shift_iter = iter_num - args.warmup_period
                    assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
                    lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # learning rate adjustment depends on the max iterations
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_

                train_loss += loss.item()
                iter_num += 1
                if writer is not None:
                    writer.add_scalar('info/lr', lr_, iter_num)
                    writer.add_scalar('info/total_loss', loss.item(), iter_num)
                    writer.add_scalar('info/loss_ce', loss_ce.item(), iter_num)
                    writer.add_scalar('info/loss_dice', loss_dice.item(), iter_num)

            train_loss /= max(len(trainloader), 1)
            if rank == 0:
                pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch,train_loss))

            if epoch % 2 == 0:
                eval_loss = 0
                dsc = 0
                ddp_model.eval()
                with torch.no_grad():
                    for data in valloader:
                        pred, msks = _forward(data)
                        loss = criterion1(pred,msks.float()) + criterion2(pred,torch.squeeze(msks.long(),1))
                        eval_loss += loss.item()
                        dsc += dice_coeff((pred[:,1,:,:].cpu()>0).long(),msks.cpu().long()).item()

                # Average the per-rank means so every rank sees the same metrics.
                n_val = max(len(valloader), 1)
                stats = torch.tensor([eval_loss / n_val, dsc / n_val], dtype=torch.float64, device=dev1)
                dist.all_reduce(stats, op=dist.ReduceOp.SUM)
                eval_loss, dsc = (stats / world_size).tolist()

                if rank == 0:
                    writer.add_scalar('eval/loss', eval_loss, epoch)
                    writer.add_scalar('eval/dice', dsc, epoch)
                    print('***Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch,eval_loss,dsc))
                    if dsc > val_largest_dsc:
                        val_largest_dsc = dsc
                        last_update_epoch = epoch
                        print('largest DSC now: {}'.format(dsc))
                        Path(dir_checkpoint).mkdir(parents=True,exist_ok = True)
                        torch.save(ddp_model.module.state_dict(),dir_checkpoint + '/checkpoint_best.pth')
                    elif (epoch - last_update_epoch) > 20:
                        print('Training finished####################')
                        # the network hasn't improved for 20 epochs
                        should_stop.fill_(1)

                # All ranks must leave together; a rank that kept training would
                # block forever in its next collective.
                dist.broadcast(should_stop, src=0)
                if should_stop.item() == 1:
                    break
    except Exception as e:
        print(f"Error in rank {rank}: {e}")
        raise
    finally:
        if writer is not None:
            try:
                writer.flush()
                writer.close()
            except Exception as e:
                print(f"Warning: Error closing writer: {e}")
        cleanup()
    

def model_basic_lora(args,rank, world_size,train_dataset,val_dataset,dir_checkpoint):
    """Fine-tune a LoRA-wrapped SAM (vit_b) without prompts, one GPU per DDP rank.

    Runs inside a process started by ``setup`` (process group already
    initialised). It builds per-rank DistributedSampler loaders and trains for
    ``args.epochs`` epochs, evaluating every second epoch. It stops early once
    validation Dice has not improved for more than 20 epochs. Only rank 0 owns
    the TensorBoard writer, the progress bar and checkpointing. The stop decision
    is broadcast so every rank leaves the loop together.

    Args:
        args: parsed ``cfg`` namespace (reads b, lr, if_warmup, warmup_period,
            epochs, out_size, sam_ckpt).
        rank: process rank, also used as the CUDA device index.
        world_size: number of DDP processes.
        train_dataset, val_dataset: datasets yielding dicts with 'image' and
            'mask' entries.
        dir_checkpoint: directory receiving ``log/`` and ``checkpoint_best.pth``.
    """
    device = rank
    # Bind this process to its GPU so NCCL collectives issued below run there.
    torch.cuda.set_device(device)

    # Samplers are created here, inside the worker, one shard per rank.
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    trainloader = DataLoader(train_dataset, batch_size=args.b, shuffle=False, num_workers=0, sampler=train_sampler, pin_memory=True)
    valloader = DataLoader(val_dataset, batch_size=args.b, shuffle=False, num_workers=0, sampler=val_sampler, pin_memory=True)

    b_lr = args.lr / args.warmup_period if args.if_warmup else args.lr
    # Logged every iteration; initialised so logging works when warmup is off
    # (the schedule below only assigns it when args.if_warmup is set).
    lr_ = b_lr
    epochs = args.epochs
    iter_num = 0
    max_iterations = epochs * len(trainloader)
    writer = None

    def _forward(data):
        """Image encoder -> empty prompt encoder -> mask decoder; returns (pred, msks)."""
        imgs = data['image'].to(device)
        msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask'])
        msks = msks.to(device)
        img_emb = ddp_model.module.image_encoder(imgs)
        sparse_emb, dense_emb = ddp_model.module.prompt_encoder(points=None, boxes=None, masks=None)
        pred, _ = ddp_model.module.mask_decoder(
            image_embeddings=img_emb,
            image_pe=ddp_model.module.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
        return pred, msks

    def _average_gradients():
        """Average gradients across ranks.

        _forward calls ddp_model.module.<submodule> directly, bypassing
        DistributedDataParallel.forward, so DDP's reducer does not all-reduce
        gradients. Averaging by hand keeps the replicas identical (and is a no-op
        if the gradients were already synchronised).
        """
        if world_size < 2:
            return
        for p in ddp_model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad.div_(world_size)

    try:
        if rank == 0:
            writer = SummaryWriter(dir_checkpoint + '/log')

        print(f"Running basic DDP example on rank {rank}.")
        # create model and move it to GPU with id rank
        model = sam_model_registry["vit_b"](args,checkpoint=args.sam_ckpt,num_classes=2)
        model = LoRA_Sam(args,model,r=4).sam
        model.to(device)
        ddp_model = DDP(model, device_ids=[device], output_device=device)

        optimizer = optim.AdamW(ddp_model.parameters(), lr=b_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.1, amsgrad=False)
        optimizer.zero_grad()
        criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True,reduction='mean')
        criterion2 = nn.CrossEntropyLoss()

        # Only rank 0 shows a progress bar.
        pbar = tqdm(range(epochs)) if rank == 0 else range(epochs)
        should_stop = torch.tensor(0, device=device)  # early-stop flag shared via broadcast
        val_largest_dsc = 0
        last_update_epoch = 0

        for epoch in pbar:
            trainloader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
            ddp_model.train()
            train_loss = 0
            for data in trainloader:
                pred, msks = _forward(data)
                loss_dice = criterion1(pred,msks.float())
                loss_ce = criterion2(pred,torch.squeeze(msks.long(),1))
                loss = loss_dice + loss_ce

                loss.backward()
                _average_gradients()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if args.if_warmup and iter_num < args.warmup_period:
                    lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_
                elif args.if_warmup:
                    shift_iter = iter_num - args.warmup_period
                    assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
                    lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # learning rate adjustment depends on the max iterations
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_

                train_loss += loss.item()
                iter_num += 1
                if writer is not None:
                    writer.add_scalar('info/lr', lr_, iter_num)
                    writer.add_scalar('info/total_loss', loss.item(), iter_num)
                    writer.add_scalar('info/loss_ce', loss_ce.item(), iter_num)
                    writer.add_scalar('info/loss_dice', loss_dice.item(), iter_num)

            train_loss /= max(len(trainloader), 1)
            if rank == 0:
                pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch,train_loss))

            if epoch % 2 == 0:
                eval_loss = 0
                dsc = 0
                ddp_model.eval()
                with torch.no_grad():
                    for data in valloader:
                        pred, msks = _forward(data)
                        loss = criterion1(pred,msks.float()) + criterion2(pred,torch.squeeze(msks.long(),1))
                        eval_loss += loss.item()
                        dsc += dice_coeff((pred[:,1,:,:].cpu()>0).long(),msks.cpu().long()).item()

                # Average the per-rank means so every rank sees the same metrics.
                n_val = max(len(valloader), 1)
                stats = torch.tensor([eval_loss / n_val, dsc / n_val], dtype=torch.float64, device=device)
                dist.all_reduce(stats, op=dist.ReduceOp.SUM)
                eval_loss, dsc = (stats / world_size).tolist()

                if rank == 0:
                    writer.add_scalar('eval/loss', eval_loss, epoch)
                    writer.add_scalar('eval/dice', dsc, epoch)
                    print('***Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch,eval_loss,dsc))
                    if dsc > val_largest_dsc:
                        val_largest_dsc = dsc
                        last_update_epoch = epoch
                        print('largest DSC now: {}'.format(dsc))
                        Path(dir_checkpoint).mkdir(parents=True,exist_ok = True)
                        torch.save(ddp_model.module.state_dict(),dir_checkpoint + '/checkpoint_best.pth')
                    elif (epoch - last_update_epoch) > 20:
                        print('Training finished####################')
                        # the network hasn't improved for 20 epochs
                        should_stop.fill_(1)

                # All ranks must leave together; a rank that kept training would
                # block forever in its next collective.
                dist.broadcast(should_stop, src=0)
                if should_stop.item() == 1:
                    break
    except Exception as e:
        print(f"Error in rank {rank}: {e}")
        raise
    finally:
        # Closed once, here, rather than inside the epoch loop.
        if writer is not None:
            try:
                writer.flush()
                writer.close()
            except Exception as e:
                print(f"Warning: Error closing writer: {e}")
        cleanup()


#def run_demo(demo_fn, size, model_basic,trainloader,valloader,dir_checkpoint):
#    mp.spawn(demo_fn,
#             args=(size, model_basic if args.finetune_type!="lora" else model_basic_lora,trainloader,valloader,dir_checkpoint),
#             nprocs=size,
#             join=True)

def run_demo(demo_fn, size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint):
    """Start ``size`` worker processes and wait for all of them to finish.

    ``mp.spawn`` prepends the process rank, so each worker is invoked as
    ``demo_fn(rank, size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint)``.
    The datasets (not loaders) travel to the workers through this tuple.
    """
    worker_args = (size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint)
    mp.spawn(demo_fn, args=worker_args, nprocs=size, join=True)
    
if __name__ == "__main__":
    dataset_name = args.dataset_name
    print('train dataset: {}'.format(dataset_name)) 

    use_lora = args.finetune_type == "lora"
    n_gpus = torch.cuda.device_count()
    # model_basic_lora uses one GPU per rank; model_basic (split-GPU) gives rank r
    # the pair (2*r, 2*r + 1), so only n_gpus // 2 ranks fit on this machine.
    size = n_gpus if use_lora else n_gpus // 2
    if size < 1:
        raise RuntimeError('Need at least {} CUDA device(s) for finetune_type={!r}, found {}'.format(
            1 if use_lora else 2, args.finetune_type, n_gpus))

    train_dataset = Public_dataset(args,args.img_folder, args.mask_folder, args.train_img_list,phase='train',targets=[f'{args.targets}'],normalize_type='sam',if_prompt=False)
    val_dataset = Public_dataset(args,args.img_folder, args.mask_folder, args.val_img_list,phase='val',targets=[f'{args.targets}'],normalize_type='sam',if_prompt=False)

    # Datasets (not loaders) are handed over; each worker builds its own
    # per-rank DistributedSampler loaders after joining the process group.
    run_demo(setup, size, model_basic_lora if use_lora else model_basic, train_dataset, val_dataset, args.dir_checkpoint)

In [4]:
%%writefile /kaggle/working/finetune-SAM/DDP_splitgpu_train_finetune_noprompt.py

#from segment_anything import SamPredictor, sam_model_registry
from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from skimage.measure import label
from models.sam_LoRa import LoRA_Sam
#Scientific computing 
import numpy as np
import os
#Pytorch packages
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from tensorboardX import SummaryWriter
#Visulization
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
#Others
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import matplotlib.pyplot as plt
import copy
from utils.dataset import Public_dataset
import torch.nn.functional as F
from torch.nn.functional import one_hot
from pathlib import Path
from tqdm import tqdm
from utils.losses import DiceLoss
from utils.dsc import dice_coeff
import cv2
import monai
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from utils.utils import vis_image
import cfg
args = cfg.parse_args()

def cleanup():
    """Destroy the default ``torch.distributed`` process group, if one exists.

    Guarded so it is safe to call from ``finally`` blocks or more than once:
    ``dist.destroy_process_group`` raises when no default group has been
    initialised.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def setup(rank, world_size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint, backend='nccl'):
    """Per-process entry point for ``mp.spawn``: join the process group, then train.

    Args:
        rank: index of this process (supplied by ``mp.spawn``).
        world_size: total number of processes in the group.
        model_basic_fn: training function, called as
            ``model_basic_fn(args, rank, world_size, train_dataset, eval_dataset, dir_checkpoint)``;
            it is responsible for tearing the process group down.
        train_dataset, eval_dataset: datasets forwarded unchanged; the training
            function builds its own per-rank loaders from them.
        dir_checkpoint: output directory forwarded to ``model_basic_fn``.
        backend: ``torch.distributed`` backend name (default ``'nccl'``).
    """
    # Keep a rendezvous address/port the launcher may already have exported
    # (lets two jobs share a host); otherwise use the historical defaults.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12333')
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    # Pass the datasets down to the training function
    model_basic_fn(args, rank, world_size, train_dataset, eval_dataset, dir_checkpoint)
                    
def model_basic(args,rank, world_size,trainloader,valloader,dir_checkpoint):
    """Fine-tune SAM (vit_b) without prompts, each DDP rank spread over two GPUs.

    Rank ``r`` sends images to GPU ``2*r`` and masks/predictions to GPU
    ``2*r + 1``; ``args.devices`` is set to that pair before the model is built
    (presumably the project's ``sam_model_registry`` places sub-modules from it
    -- TODO confirm). ``world_size`` must therefore be at most
    ``torch.cuda.device_count() // 2``.

    Args:
        args: parsed ``cfg`` namespace (reads b, lr, if_warmup, warmup_period,
            epochs, out_size, sam_ckpt, finetune_type, if_update_encoder).
        rank: process rank.
        world_size: number of DDP processes.
        trainloader, valloader: ready-made DataLoaders, or raw datasets (which is
            what ``setup`` forwards); datasets are wrapped here in
            DistributedSampler-backed loaders so each rank gets its own shard.
        dir_checkpoint: directory receiving ``log/`` and ``checkpoint_best.pth``.
    """
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    args.devices = [dev0,dev1]

    # setup() hands over datasets, not loaders: batch and shard them per rank.
    if not isinstance(trainloader, DataLoader):
        train_sampler = DistributedSampler(trainloader, num_replicas=world_size, rank=rank, shuffle=True)
        trainloader = DataLoader(trainloader, batch_size=args.b, shuffle=False, num_workers=0, sampler=train_sampler, pin_memory=True)
    if not isinstance(valloader, DataLoader):
        val_sampler = DistributedSampler(valloader, num_replicas=world_size, rank=rank, shuffle=False)
        valloader = DataLoader(valloader, batch_size=args.b, shuffle=False, num_workers=0, sampler=val_sampler, pin_memory=True)

    b_lr = args.lr / args.warmup_period if args.if_warmup else args.lr
    # Logged every iteration; initialised so logging works when warmup is off
    # (the schedule below only assigns it when args.if_warmup is set).
    lr_ = b_lr
    epochs = args.epochs
    iter_num = 0
    max_iterations = epochs * len(trainloader)
    writer = None

    def _forward(data):
        """Image encoder -> empty prompt encoder -> mask decoder; returns (pred, msks)."""
        imgs = data['image'].to(dev0)
        msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask'])
        msks = msks.to(dev1)  # output will be in device 1
        img_emb = ddp_model.module.image_encoder(imgs)
        sparse_emb, dense_emb = ddp_model.module.prompt_encoder(points=None, boxes=None, masks=None)
        pred, _ = ddp_model.module.mask_decoder(
            image_embeddings=img_emb,
            image_pe=ddp_model.module.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
        return pred, msks

    def _average_gradients():
        """Average gradients across ranks.

        _forward calls ddp_model.module.<submodule> directly, bypassing
        DistributedDataParallel.forward, so DDP's reducer does not all-reduce
        gradients. Averaging by hand keeps the replicas identical (and is a no-op
        if the gradients were already synchronised).
        """
        if world_size < 2:
            return
        for p in ddp_model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad.div_(world_size)

    try:
        # Only rank 0 logs, so ranks don't write event files into one directory.
        if rank == 0:
            writer = SummaryWriter(dir_checkpoint + '/log')

        print(f"Running basic DDP example on rank {rank}.")
        model = sam_model_registry["vit_b"](args,checkpoint=args.sam_ckpt,num_classes=2)

        if args.finetune_type == 'adapter':
            for n, value in model.named_parameters():
                if "Adapter" not in n: # only update parameters in adapter
                    value.requires_grad = False
        elif args.finetune_type == 'vanilla' and args.if_update_encoder==False:
            for n, value in model.image_encoder.named_parameters():
                value.requires_grad = False
        elif args.finetune_type == 'lora':
            model = LoRA_Sam(args,model,r=4).sam

        # No device_ids: the module is meant to span two devices.
        ddp_model = DDP(model)

        optimizer = optim.AdamW(ddp_model.parameters(), lr=b_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.1, amsgrad=False)
        optimizer.zero_grad()
        criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True,reduction='mean')
        criterion2 = nn.CrossEntropyLoss()

        pbar = tqdm(range(epochs)) if rank == 0 else range(epochs)
        should_stop = torch.tensor(0, device=dev1)  # early-stop flag shared via broadcast
        val_largest_dsc = 0
        last_update_epoch = 0

        for epoch in pbar:
            if isinstance(trainloader.sampler, DistributedSampler):
                trainloader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
            ddp_model.train()
            train_loss = 0
            for data in trainloader:
                pred, msks = _forward(data)
                loss_dice = criterion1(pred,msks.float())
                loss_ce = criterion2(pred,torch.squeeze(msks.long(),1))
                loss = loss_dice + loss_ce

                loss.backward()
                _average_gradients()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if args.if_warmup and iter_num < args.warmup_period:
                    lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_
                elif args.if_warmup:
                    shift_iter = iter_num - args.warmup_period
                    assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
                    lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # learning rate adjustment depends on the max iterations
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_

                train_loss += loss.item()
                iter_num += 1
                if writer is not None:
                    writer.add_scalar('info/lr', lr_, iter_num)
                    writer.add_scalar('info/total_loss', loss.item(), iter_num)
                    writer.add_scalar('info/loss_ce', loss_ce.item(), iter_num)
                    writer.add_scalar('info/loss_dice', loss_dice.item(), iter_num)

            train_loss /= max(len(trainloader), 1)
            if rank == 0:
                pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch,train_loss))

            if epoch % 2 == 0:
                eval_loss = 0
                dsc = 0
                ddp_model.eval()
                with torch.no_grad():
                    for data in valloader:
                        pred, msks = _forward(data)
                        loss = criterion1(pred,msks.float()) + criterion2(pred,torch.squeeze(msks.long(),1))
                        eval_loss += loss.item()
                        dsc += dice_coeff((pred[:,1,:,:].cpu()>0).long(),msks.cpu().long()).item()

                # Average the per-rank means so every rank sees the same metrics.
                n_val = max(len(valloader), 1)
                stats = torch.tensor([eval_loss / n_val, dsc / n_val], dtype=torch.float64, device=dev1)
                dist.all_reduce(stats, op=dist.ReduceOp.SUM)
                eval_loss, dsc = (stats / world_size).tolist()

                if rank == 0:
                    writer.add_scalar('eval/loss', eval_loss, epoch)
                    writer.add_scalar('eval/dice', dsc, epoch)
                    print('***Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch,eval_loss,dsc))
                    if dsc > val_largest_dsc:
                        val_largest_dsc = dsc
                        last_update_epoch = epoch
                        print('largest DSC now: {}'.format(dsc))
                        Path(dir_checkpoint).mkdir(parents=True,exist_ok = True)
                        torch.save(ddp_model.module.state_dict(),dir_checkpoint + '/checkpoint_best.pth')
                    elif (epoch - last_update_epoch) > 20:
                        print('Training finished####################')
                        # the network hasn't improved for 20 epochs
                        should_stop.fill_(1)

                # All ranks must leave together; a rank that kept training would
                # block forever in its next collective.
                dist.broadcast(should_stop, src=0)
                if should_stop.item() == 1:
                    break
    except Exception as e:
        print(f"Error in rank {rank}: {e}")
        raise
    finally:
        if writer is not None:
            try:
                writer.flush()
                writer.close()
            except Exception as e:
                print(f"Warning: Error closing writer: {e}")
        cleanup()
    

def model_basic_lora(args,rank, world_size,train_dataset,val_dataset,dir_checkpoint):
    """Fine-tune a LoRA-wrapped SAM (vit_b) without prompts, one GPU per DDP rank.

    Runs inside a process started by ``setup`` (process group already
    initialised). It builds per-rank DistributedSampler loaders and trains for
    ``args.epochs`` epochs, evaluating every second epoch. It stops early once
    validation Dice has not improved for more than 20 epochs. Only rank 0 owns
    the TensorBoard writer, the progress bar and checkpointing. The stop decision
    is broadcast so every rank leaves the loop together.

    Args:
        args: parsed ``cfg`` namespace (reads b, lr, if_warmup, warmup_period,
            epochs, out_size, sam_ckpt).
        rank: process rank, also used as the CUDA device index.
        world_size: number of DDP processes.
        train_dataset, val_dataset: datasets yielding dicts with 'image' and
            'mask' entries.
        dir_checkpoint: directory receiving ``log/`` and ``checkpoint_best.pth``.
    """
    device = rank
    # Bind this process to its GPU so NCCL collectives issued below run there.
    torch.cuda.set_device(device)

    # Samplers are created here, inside the worker, one shard per rank.
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    trainloader = DataLoader(train_dataset, batch_size=args.b, shuffle=False, num_workers=0, sampler=train_sampler, pin_memory=True)
    valloader = DataLoader(val_dataset, batch_size=args.b, shuffle=False, num_workers=0, sampler=val_sampler, pin_memory=True)

    b_lr = args.lr / args.warmup_period if args.if_warmup else args.lr
    # Logged every iteration; initialised so logging works when warmup is off
    # (the schedule below only assigns it when args.if_warmup is set).
    lr_ = b_lr
    epochs = args.epochs
    iter_num = 0
    max_iterations = epochs * len(trainloader)
    writer = None

    def _forward(data):
        """Image encoder -> empty prompt encoder -> mask decoder; returns (pred, msks)."""
        imgs = data['image'].to(device)
        msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask'])
        msks = msks.to(device)
        img_emb = ddp_model.module.image_encoder(imgs)
        sparse_emb, dense_emb = ddp_model.module.prompt_encoder(points=None, boxes=None, masks=None)
        pred, _ = ddp_model.module.mask_decoder(
            image_embeddings=img_emb,
            image_pe=ddp_model.module.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
        return pred, msks

    def _average_gradients():
        """Average gradients across ranks.

        _forward calls ddp_model.module.<submodule> directly, bypassing
        DistributedDataParallel.forward, so DDP's reducer does not all-reduce
        gradients. Averaging by hand keeps the replicas identical (and is a no-op
        if the gradients were already synchronised).
        """
        if world_size < 2:
            return
        for p in ddp_model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad.div_(world_size)

    try:
        if rank == 0:
            writer = SummaryWriter(dir_checkpoint + '/log')

        print(f"Running basic DDP example on rank {rank}.")
        # create model and move it to GPU with id rank
        model = sam_model_registry["vit_b"](args,checkpoint=args.sam_ckpt,num_classes=2)
        model = LoRA_Sam(args,model,r=4).sam
        model.to(device)
        ddp_model = DDP(model, device_ids=[device], output_device=device)

        optimizer = optim.AdamW(ddp_model.parameters(), lr=b_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.1, amsgrad=False)
        optimizer.zero_grad()
        criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True,reduction='mean')
        criterion2 = nn.CrossEntropyLoss()

        # Only rank 0 shows a progress bar.
        pbar = tqdm(range(epochs)) if rank == 0 else range(epochs)
        should_stop = torch.tensor(0, device=device)  # early-stop flag shared via broadcast
        val_largest_dsc = 0
        last_update_epoch = 0

        for epoch in pbar:
            trainloader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
            ddp_model.train()
            train_loss = 0
            for data in trainloader:
                pred, msks = _forward(data)
                loss_dice = criterion1(pred,msks.float())
                loss_ce = criterion2(pred,torch.squeeze(msks.long(),1))
                loss = loss_dice + loss_ce

                loss.backward()
                _average_gradients()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if args.if_warmup and iter_num < args.warmup_period:
                    lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_
                elif args.if_warmup:
                    shift_iter = iter_num - args.warmup_period
                    assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
                    lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # learning rate adjustment depends on the max iterations
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_

                train_loss += loss.item()
                iter_num += 1
                if writer is not None:
                    writer.add_scalar('info/lr', lr_, iter_num)
                    writer.add_scalar('info/total_loss', loss.item(), iter_num)
                    writer.add_scalar('info/loss_ce', loss_ce.item(), iter_num)
                    writer.add_scalar('info/loss_dice', loss_dice.item(), iter_num)

            train_loss /= max(len(trainloader), 1)
            if rank == 0:
                pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch,train_loss))

            if epoch % 2 == 0:
                eval_loss = 0
                dsc = 0
                ddp_model.eval()
                with torch.no_grad():
                    for data in valloader:
                        pred, msks = _forward(data)
                        loss = criterion1(pred,msks.float()) + criterion2(pred,torch.squeeze(msks.long(),1))
                        eval_loss += loss.item()
                        dsc += dice_coeff((pred[:,1,:,:].cpu()>0).long(),msks.cpu().long()).item()

                # Average the per-rank means so every rank sees the same metrics.
                n_val = max(len(valloader), 1)
                stats = torch.tensor([eval_loss / n_val, dsc / n_val], dtype=torch.float64, device=device)
                dist.all_reduce(stats, op=dist.ReduceOp.SUM)
                eval_loss, dsc = (stats / world_size).tolist()

                if rank == 0:
                    writer.add_scalar('eval/loss', eval_loss, epoch)
                    writer.add_scalar('eval/dice', dsc, epoch)
                    print('***Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch,eval_loss,dsc))
                    if dsc > val_largest_dsc:
                        val_largest_dsc = dsc
                        last_update_epoch = epoch
                        print('largest DSC now: {}'.format(dsc))
                        Path(dir_checkpoint).mkdir(parents=True,exist_ok = True)
                        torch.save(ddp_model.module.state_dict(),dir_checkpoint + '/checkpoint_best.pth')
                    elif (epoch - last_update_epoch) > 20:
                        print('Training finished####################')
                        # the network hasn't improved for 20 epochs
                        should_stop.fill_(1)

                # All ranks must leave together; a rank that kept training would
                # block forever in its next collective.
                dist.broadcast(should_stop, src=0)
                if should_stop.item() == 1:
                    break
    except Exception as e:
        print(f"Error in rank {rank}: {e}")
        raise
    finally:
        # Closed once, here, rather than inside the epoch loop.
        if writer is not None:
            try:
                writer.flush()
                writer.close()
            except Exception as e:
                print(f"Warning: Error closing writer: {e}")
        cleanup()

#def run_demo(demo_fn, size, model_basic,trainloader,valloader,dir_checkpoint):
#    mp.spawn(demo_fn,
#             args=(size, model_basic if args.finetune_type!="lora" else model_basic_lora,trainloader,valloader,dir_checkpoint),
#             nprocs=size,
#             join=True)

def run_demo(demo_fn, size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint):
    """Start ``size`` worker processes that each run ``demo_fn``.

    ``torch.multiprocessing.spawn`` passes the process index as the first
    positional argument. Each worker therefore receives
    ``(rank, size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint)``.
    The datasets themselves (not DataLoaders) are shipped to the workers.
    The call blocks until every worker has exited (``join=True``).
    """
    worker_args = (size, model_basic_fn, train_dataset, eval_dataset, dir_checkpoint)
    mp.spawn(demo_fn, args=worker_args, nprocs=size, join=True)
    
if __name__ == "__main__":
    dataset_name = args.dataset_name
    print('train dataset: {}'.format(dataset_name))
    # Train/val CSV locations come from the command line
    # (args.train_img_list / args.val_img_list), not from hard-coded paths.

    # Module-level settings; not read anywhere inside this block itself.
    num_workers = 0
    if_vis = True

    # One DDP worker process per visible GPU (e.g. 2 on a 2xT4 machine).
    n_gpus = torch.cuda.device_count()
    if n_gpus < 1:
        # mp.spawn(nprocs=0) launches no workers, so the script would exit
        # "successfully" without training anything -- fail loudly instead.
        raise RuntimeError(
            'No CUDA device visible: DDP fine-tuning needs at least one GPU '
            '(check CUDA_VISIBLE_DEVICES and the notebook accelerator setting).'
        )
    size = n_gpus

    train_dataset = Public_dataset(args,args.img_folder, args.mask_folder, args.train_img_list,phase='train',targets=[f'{args.targets}'],normalize_type='sam',if_prompt=False)
    val_dataset = Public_dataset(args,args.img_folder, args.mask_folder, args.val_img_list,phase='val',targets=[f'{args.targets}'],normalize_type='sam',if_prompt=False)

    # Datasets (not DataLoaders) are handed to the spawned workers; samplers and
    # loaders are presumably built per rank inside `setup` -- confirm there.
    run_demo(setup, size, model_basic_lora if args.finetune_type=="lora" else model_basic,train_dataset,val_dataset,args.dir_checkpoint)

Overwriting /kaggle/working/finetune-SAM/DDP_splitgpu_train_finetune_noprompt.py


In [7]:
from IPython.display import Javascript, display

# Browser-side keep-alive: periodically click the notebook body so the session
# registers activity during long training runs. This only works while the
# notebook tab stays open in a browser.
KEEP_ALIVE_MINUTES = 5

# The timer id is stored on `window`, so re-running this cell replaces the
# previous timer instead of stacking another interval on top of it.
keep_alive_js = f"""
if (window.__keepAliveTimer) {{
    clearInterval(window.__keepAliveTimer);
}}
function ClickConnect(){{
    console.log("Keeping session alive");
    document.querySelector("body").click();
}}
window.__keepAliveTimer = setInterval(ClickConnect, {KEEP_ALIVE_MINUTES} * 60 * 1000);
"""

display(Javascript(keep_alive_js))

<IPython.core.display.Javascript object>

In [8]:
# Run the training.
# Launches fine-tuning through the shell wrapper train_ddpgpu_demo.sh; all
# arguments (dataset, architecture, checkpoint dir) live inside that script.
# It presumably runs DDP_splitgpu_train_finetune_noprompt.py written above,
# which saves checkpoint_best.pth whenever validation Dice improves -- confirm in the .sh.
!bash /kaggle/working/finetune-SAM/train_ddpgpu_demo.sh

  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
2025-08-16 04:31:44.897471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755318705.256376     117 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755318705.350640     117 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
train dataset: xrayhip
Filtered data list to 98 entries.
Filtered data list to 14 entries.
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_mode

In [None]:
# Remove checkpoints (IF NEEDED).
# Guarded by a flag: the validation cells below load
# /kaggle/working/2D-SAM_vit_b_xrayhip/checkpoint_best.pth, so an unguarded
# delete would wipe the freshly trained model under "Restart & Run All".
# Set REMOVE_CHECKPOINTS = True only to start training from a clean directory.
import shutil
from pathlib import Path

CHECKPOINT_DIR = Path("/kaggle/working/2D-SAM_vit_b_xrayhip")
REMOVE_CHECKPOINTS = False

if REMOVE_CHECKPOINTS:
    # ignore_errors=True mirrors `rm -rf`: a missing directory is not an error.
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

### Validation

In [33]:
%%writefile /kaggle/working/finetune-SAM/val_finetune_noprompt.py

from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from models.sam_LoRa import LoRA_Sam
#Scientific computing 
import numpy as np
import os
#Pytorch packages
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import datasets
#Visulization
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
#Others
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import copy
from utils.dataset import Public_dataset
from pathlib import Path
from tqdm import tqdm
from utils.losses import DiceLoss
from utils.dsc import dice_coeff
from utils.utils import vis_image
import cfg
from argparse import Namespace
import json
from utils.nsd import normalized_surface_dice
from monai.metrics.surface_dice import SurfaceDiceMetric
import torch.nn.functional as F
from torchvision.transforms import InterpolationMode

def _load_finetuned_sam(args):
    """Build the SAM variant named by ``args.finetune_type`` and load ``checkpoint_best.pth``.

    * ``'adapter'`` / ``'vanilla'``: the fine-tuned checkpoint is passed straight
      to the ``sam_model_registry`` builder.
    * ``'lora'``: the base SAM weights (``args.sam_ckpt``) are built first and
      wrapped in ``LoRA_Sam`` with rank 4. The fine-tuned state dict is then
      loaded on top with ``strict=False``.

    Returns:
        The model on ``'cuda'`` in eval mode.

    Raises:
        ValueError: for any other ``finetune_type``. Previously this surfaced as an
            ``UnboundLocalError`` further down.
    """
    ckpt_path = os.path.join(args.dir_checkpoint, 'checkpoint_best.pth')
    if args.finetune_type in ('adapter', 'vanilla'):
        model = sam_model_registry[args.arch](args, checkpoint=ckpt_path, num_classes=args.num_cls)
    elif args.finetune_type == 'lora':
        base_sam = sam_model_registry[args.arch](args, checkpoint=os.path.join(args.sam_ckpt), num_classes=args.num_cls)
        model = LoRA_Sam(args, base_sam, r=4).to('cuda').sam
        load_result = model.load_state_dict(torch.load(ckpt_path), strict=False)
        # strict=False silently drops keys whose names do not match the model
        # (e.g. a leftover 'module.' prefix); report them so a mis-keyed
        # checkpoint is not evaluated as if it had loaded.
        if load_result.unexpected_keys:
            print('Warning: {} checkpoint keys did not match the model and were ignored, e.g. {}'.format(
                len(load_result.unexpected_keys), load_result.unexpected_keys[:3]))
    else:
        raise ValueError(
            "Unsupported finetune_type {!r}; expected 'adapter', 'vanilla' or 'lora'.".format(args.finetune_type))
    return model.to('cuda').eval()


def main(args,test_img_list):
    """Evaluate a prompt-free fine-tuned SAM on ``test_img_list`` and print metrics.

    Printed metrics, averaged over test images (the loader uses batch_size=1):

    * per-class Dice and IoU, where index ``j`` is label value ``j``;
    * foreground-union Dice, with every label > 0 merged into one foreground;
    * foreground-union normalized surface Dice (MONAI ``SurfaceDiceMetric``) at
      a tolerance of ``tau`` pixels. Images whose NSD is NaN are skipped.

    Args:
        args: parsed ``cfg`` namespace. Reads img/mask folders, targets,
            finetune_type, arch, sam_ckpt, dir_checkpoint, num_cls, out_size
            and dataset_name.
        test_img_list: path of the test list/CSV passed to ``Public_dataset``.

    Raises:
        ValueError: if the test list yields no samples, or for an unsupported
            ``args.finetune_type``.
    """
    # change to 'combine_all' if you want to combine all targets into 1 cls
    test_dataset = Public_dataset(args,args.img_folder, args.mask_folder, test_img_list,phase='val',targets=[args.targets],if_prompt=False)
    if len(test_dataset) == 0:
        # Fail before loading the model; averaging over zero images is meaningless.
        raise ValueError('No test samples found in {}'.format(test_img_list))
    testloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

    sam_fine_tune = _load_finetuned_sam(args)

    class_iou = torch.zeros(args.num_cls,dtype=torch.float)
    cls_dsc = torch.zeros(args.num_cls,dtype=torch.float)
    union_dsc_sum = 0.0
    nsd_union_sum = 0.0
    nsd_count = 0
    num_batches = 0
    eps = 1e-9

    # NSD tolerance in pixels, measured on the args.out_size x args.out_size grid
    # the GT masks are resized to below. A larger tau gives a more lenient
    # boundary match.
    # NOTE(review): an earlier comment suggested 2.0-4.0 (3.0 "near the paper")
    # while the value used is 7.0 -- confirm the intended tolerance.
    tau = 7.0
    sd_union = SurfaceDiceMetric(
        class_thresholds=[tau],     # one threshold: the single foreground channel
        include_background=False,   # channel 0 (background) is excluded
        reduction="none",           # per-item values; NaNs are averaged out manually below
    )
    # Nearest-neighbour keeps label values intact when resizing GT masks.
    resize_mask = torchvision.transforms.Resize((args.out_size,args.out_size), interpolation=InterpolationMode.NEAREST)

    for data in tqdm(testloader):
        num_batches += 1
        imgs = data['image'].to('cuda')
        # GT stays on CPU: every metric below is computed on CPU tensors.
        msks = resize_mask(data['mask'])

        with torch.no_grad():
            img_emb = sam_fine_tune.image_encoder(imgs)
            # No points/boxes/masks: the prompt encoder yields its "no prompt" embeddings.
            sparse_emb, dense_emb = sam_fine_tune.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
            )
            pred_logits, _ = sam_fine_tune.mask_decoder(
                image_embeddings=img_emb,
                image_pe=sam_fine_tune.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_emb,
                dense_prompt_embeddings=dense_emb,
                multimask_output=True,
            )

        # Predicted label map [B,H,W], moved to CPU once for all metrics.
        pred_cls = pred_logits.argmax(dim=1).cpu()
        # GT label map: drop a singleton channel dim ([B,1,H,W] -> [B,H,W]) if present.
        gt_cls = msks
        if gt_cls.ndim == 4 and gt_cls.size(1) == 1:
            gt_cls = gt_cls.squeeze(1)

        # Per-class IoU on flattened maps. A class absent from both GT and
        # prediction contributes 0 (not 1) because I = U = 0.
        yhat = pred_cls.long().flatten()
        y = gt_cls.flatten()
        for j in range(args.num_cls):
            y_bi = (y == j)
            yhat_bi = (yhat == j)
            inter = (y_bi & yhat_bi).sum().item()
            union = torch.logical_or(y_bi, yhat_bi).sum().item()
            class_iou[j] += inter / (union + eps)

        # Per-class Dice.
        for cls in range(args.num_cls):
            cls_dsc[cls] += dice_coeff(
                (pred_cls == cls).float(),
                (gt_cls == cls).float(),
            ).item()

        # Foreground-union Dice + NSD (all labels > 0 merged).
        pred_union = pred_cls > 0
        gt_union = gt_cls > 0
        union_dsc_sum += dice_coeff(pred_union.float(), gt_union.float()).item()

        # One-hot [B,H,W] -> [B,H,W,2] -> channel-first [B,2,H,W] as MONAI expects.
        pred_oh = F.one_hot(pred_union.long(), num_classes=2).movedim(-1, 1).float()
        gt_oh = F.one_hot(gt_union.long(), num_classes=2).movedim(-1, 1).float()

        sd_union(pred_oh, gt_oh)
        nsd_batch = sd_union.aggregate()   # may contain NaN (e.g. empty foreground)
        sd_union.reset()

        nsd_val = torch.nanmean(nsd_batch).item()
        if np.isfinite(nsd_val):
            nsd_union_sum += nsd_val
            nsd_count += 1

    # Averages over images (batch_size=1).
    denom = max(num_batches, 1)
    class_iou /= denom
    cls_dsc /= denom
    union_dsc = union_dsc_sum / float(denom)
    union_nsd = nsd_union_sum / float(max(nsd_count, 1))

    # Note: os.path.join discards 'test_results' when args.dir_checkpoint is an
    # absolute path, so this then resolves to the checkpoint directory itself.
    # Nothing is written into it yet.
    save_folder = os.path.join('test_results', args.dir_checkpoint)
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    print(args.dataset_name)
    print('class dsc:', cls_dsc)
    print('class iou:', class_iou)
    print(f'union dsc (foreground>0): {union_dsc:.4f}')
    print(f'union nsd @ tau={tau:.1f}px (foreground only): {union_nsd:.4f}')

    
if __name__ == "__main__":
    args = cfg.parse_args()

    # TODO (before the next training run): optionally rebuild `args` from the
    # training-time settings, i.e. read f"{args.dir_checkpoint}/args.json" with
    # json.load and wrap the resulting dict in argparse.Namespace(**...), so
    # evaluation uses exactly the training configuration.

    # Exposed as a module-level global as well as on `args`.
    dataset_name = args.dataset_name

    main(args, args.test_img_list)

Overwriting /kaggle/working/finetune-SAM/val_finetune_noprompt.py


In [34]:
%%writefile /kaggle/working/finetune-SAM/val_singlegpu_demo.sh

#!/bin/bash
# Single-GPU evaluation of the fine-tuned SAM checkpoint on the test split.
# Abort on any error, on use of an unset variable, and on failed pipeline stages.
set -euo pipefail

# Set which GPUs to use (evaluation runs on a single device)
export CUDA_VISIBLE_DEVICES="0"

# --- Variables ---
# Use 'vit_b' or 'vit_l', etc.
ARCH="vit_b"
# Fine-tuning method the checkpoint was trained with: 'lora', 'adapter' or 'vanilla'
FINETUNE_TYPE="lora"
# The name of your dataset, used to locate the checkpoint directory
DATASET_NAME="xrayhip"
# The root of your Kaggle working directory
BASE_DIR="/kaggle/working"
# The path to the finetune-SAM code
FINETUNE_SAM_DIR="${BASE_DIR}/finetune-SAM"

# --- Path Arguments for the Python Script ---
# Full path to the base SAM model weights (read only when FINETUNE_TYPE is 'lora')
SAM_CKPT="${BASE_DIR}/sam_vit_b_weights/sam_vit_b_01ec64.pth"
# The directory where your CSVs say the data is. Your CSVs have paths like
# "xrayhip/images/...", so the base input folder is "/kaggle/input/".
IMG_FOLDER="/kaggle/input"
MASK_FOLDER="/kaggle/input"

# CSV listing the test images and masks
TEST_IMG_LIST="${IMG_FOLDER}/${DATASET_NAME}/test.csv"

# Directory the TRAINING run saved checkpoint_best.pth into. It is read (not
# written) here, so it must match the training script's checkpoint directory.
DIR_CHECKPOINT="${BASE_DIR}/2D-SAM_${ARCH}_${DATASET_NAME}"

# Fail early with a clear message instead of a torch.load traceback after the
# backbone has already been built.
if [[ ! -f "${DIR_CHECKPOINT}/checkpoint_best.pth" ]]; then
    echo "Error: ${DIR_CHECKPOINT}/checkpoint_best.pth not found - run training first." >&2
    exit 1
fi

# Run the Python script
python "${FINETUNE_SAM_DIR}/val_finetune_noprompt.py" \
    -finetune_type "$FINETUNE_TYPE" \
    -arch "$ARCH" \
    -dataset_name "$DATASET_NAME" \
    -sam_ckpt "$SAM_CKPT" \
    -img_folder "$IMG_FOLDER" \
    -mask_folder "$MASK_FOLDER" \
    -test_img_list "$TEST_IMG_LIST" \
    -dir_checkpoint "$DIR_CHECKPOINT"

Overwriting /kaggle/working/finetune-SAM/val_singlegpu_demo.sh


In [35]:
# !pip install monai icecream torchio slicerio

In [36]:
# Run validation.
# Executes val_singlegpu_demo.sh (written in the cell above). It evaluates
# checkpoint_best.pth from /kaggle/working/2D-SAM_vit_b_xrayhip on the test CSV
# and prints per-class DSC/IoU plus foreground-union DSC and NSD.
!bash /kaggle/working/finetune-SAM/val_singlegpu_demo.sh

  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
2025-08-20 06:25:18.359676: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755671118.384025     408 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755671118.391239     408 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Filtered data list to 28 entries.
100%|███████████████████████████████████████████| 28/28 [00:12<00:00,  2.25it/s]
xrayhip
class dsc: tensor([0.9587, 0.9474])
class iou: tensor([0.9223, 0.9021])
union dsc (foreground>0): 0.9474
union nsd @ tau=7.0px (foregr