In [1]:
!pip install -q monai

In [2]:
!git clone -b batch_size https://github.com/sushmanthreddy/segment-anything.git

Cloning into 'segment-anything'...
remote: Enumerating objects: 306, done.[K
remote: Counting objects: 100% (172/172), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 306 (delta 131), reused 116 (delta 116), pack-reused 134[K
Receiving objects: 100% (306/306), 18.31 MiB | 22.58 MiB/s, done.
Resolving deltas: 100% (165/165), done.


In [3]:
%cd segment-anything/


/kaggle/working/segment-anything


In [4]:
ls -a


[0m[01;34m.[0m/       .gitignore          README.md   [01;34mnotebooks[0m/         setup.py
[01;34m..[0m/      CODE_OF_CONDUCT.md  [01;34massets[0m/     [01;34mscripts[0m/
.flake8  CONTRIBUTING.md     [01;34mdemo[0m/       [01;34msegment_anything[0m/
[01;34m.git[0m/    LICENSE             [01;32mlinter.sh[0m*  setup.cfg


In [5]:
!pip install -e .

Obtaining file:///kaggle/working/segment-anything
  Preparing metadata (setup.py) ... [?25ldone
[?25hInstalling collected packages: segment-anything
  Attempting uninstall: segment-anything
    Found existing installation: segment-anything 1.0
    Uninstalling segment-anything-1.0:
      Successfully uninstalled segment-anything-1.0
  Running setup.py develop for segment-anything
Successfully installed segment-anything-1.0


In [6]:
import numpy as np
import matplotlib.pyplot as plt
import os

join = os.path.join
from tqdm import tqdm
from skimage import transform
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import monai
from segment_anything import sam_model_registry
import torch.nn.functional as F
import argparse
import random
from datetime import datetime
import shutil
import glob
from os import listdir
from os.path import isfile, join
import pandas as pd
from PIL import Image

# set seeds
torch.manual_seed(2023)
torch.cuda.empty_cache()



In [7]:
%cd ..


/kaggle/working


In [8]:
%ls -a

[0m[01;34m.[0m/  [01;34m..[0m/  [01;34m.virtual_documents[0m/  [01;34msegment-anything[0m/


In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [10]:
import monai

In [11]:
def resize(path):
  dirs = os.listdir( path )
  for item in tqdm(dirs):
    if os.path.isfile(path+item):
      im = Image.open(path+item)
      f, e = os.path.splitext(path+item)
      imResize = im.resize((1024,1024), Image.NEAREST)
      imResize.save(f+e, 'PNG', quality=100)

label_path =  "/kaggle/input/nucleus-data/nucleus_data/segmentation_maps"
output_features_path = "/kaggle/input/nucleus-data/nucleus_data/features"
resize(label_path)

100%|██████████| 6790/6790 [00:03<00:00, 1806.95it/s]


In [12]:
ids=[]
label_filenames = [f for f in listdir(label_path) if isfile(join(label_path, f))]
feature_filenames = [f for f in listdir(output_features_path) if isfile(join(output_features_path, f))]
for i in range(len(feature_filenames)):
  ids.append(feature_filenames[i][1:])
print(len(ids))

df = pd.DataFrame(ids ,columns=["file_ids"])
df.to_csv('full_file_ids.csv', index=False)

#sanity check
df = pd.read_csv('full_file_ids.csv')
df.head()

6756


Unnamed: 0,file_ids
0,182_22.png
1,167_27.png
2,86_29.png
3,154_16.png
4,177_8.png


In [13]:
import pandas as pd
import os
import cv2

df = pd.read_csv('full_file_ids.csv')
ids = df['file_ids'].tolist()
non_empty_ids = []

for file_id in ids:
    mask_path = os.path.join(label_path, 'L' + file_id)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if cv2.countNonZero(mask) > 0:
        non_empty_ids.append(file_id)

df_non_empty = pd.DataFrame(non_empty_ids, columns=["file_ids"])
df_non_empty.sort_values(by='file_ids', inplace=True)  # Sort the DataFrame by 'file_ids'
df_non_empty.to_csv('file_ids.csv', index=False)


dif = pd.read_csv('file_ids.csv')
dif.head(15)

Unnamed: 0,file_ids
0,0_10.png
1,0_11.png
2,0_12.png
3,0_13.png
4,0_14.png
5,0_15.png
6,0_16.png
7,0_17.png
8,0_18.png
9,0_19.png


In [14]:
class SegmentationDataset(Dataset):
    def __init__(self, csv_file, bbox_shift=20):
        self.df = pd.read_csv(csv_file)
        self.ids = self.df["file_ids"]
        self.img_path = "/kaggle/input/nucleus-data/nucleus_data/features/"
        self.mask_path = "/kaggle/input/nucleus-data/nucleus_data/segmentation_maps/"
        self.bbox_shift = bbox_shift
        print(f"number of images: {len(self.ids)}")

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        # Load image and mask using the ID from the CSV
        img_name = f"F{self.ids[index]}"
        mask_name = f"L{self.ids[index]}"

        # Load and resize image to 1024x1024, then convert to RGB
        img = Image.open(join(self.img_path, img_name)).resize((1024, 1024)).convert("RGB")
        img = np.array(img)  # Convert image to numpy array

        img = img / 255.0

        # Load and resize mask to 1024x1024
        mask = Image.open(join(self.mask_path, mask_name)).resize((1024, 1024))
        mask = np.array(mask)  # Convert mask to numpy array

        # Convert the shape to (3, H, W) for image and (1, H, W) for mask
        img = np.transpose(img, (2, 0, 1))
        mask = np.expand_dims(mask, axis=0)  # Add an extra dimension for the channel

        label_ids = np.unique(mask)[1:]
        mask_binary = np.uint8(mask == random.choice(label_ids.tolist()))[1]  # only one label, (1024, 1024)


        y_indices, x_indices = np.where(mask_binary > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = mask_binary.shape
        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
        bboxes = np.array([x_min, y_min, x_max, y_max])

        return (
            torch.tensor(img).float(),
            torch.tensor(mask_binary[None, :, :]).long(),
            torch.tensor(bboxes).float(),
            img_name,
        )


In [37]:
%mkdir checkpoint_save


mkdir: cannot create directory ‘checkpoint_save’: File exists


In [38]:
lr=0.0001
batch_size = 4
data_path = "/kaggle/input/nucleus-data/nucleus_data"
checkpoint = "/kaggle/working/sam_vit_b_01ec64.pth"
model_type = "vit_b"
work_dir = "/kaggle/working/checkpoint_save"
num_epochs = 10
num_workers=0
use_wandb = 1
use_amp = 0
resume = ""
task_name = "CellSAM-ViT-B"
num_epochs = num_epochs
iter_num = 0
start_epoch = 0
losses = []
best_loss = 1e10

In [39]:
os.makedirs(work_dir, exist_ok=True)

In [40]:
# Instantiate your dataset
tr_dataset = SegmentationDataset(csv_file='file_ids.csv',)
tr_dataloader = DataLoader(tr_dataset, batch_size=batch_size, shuffle=True)

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
    )

for step, (image, mask_binary, bboxes, img_name) in enumerate(tr_dataloader):
    print(image.shape, mask_binary.shape, bboxes.shape)
    # show the example
    _, axs = plt.subplots(1, 2, figsize=(25, 25))
    idx = random.randint(0, image.size(0) - 1)  # Update this line to get a valid index
    axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
    show_mask(mask_binary[idx].cpu().numpy()[0], axs[0])  # Passing the 2D mask to show_mask
    show_box(bboxes[idx].numpy(), axs[0])
    axs[0].axis("off")
    # set title
    axs[0].set_title(img_name[idx])
    idx = random.randint(0, image.size(0) - 1)  # Update this line to get a valid index
    axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
    show_mask(mask_binary[idx].cpu().numpy()[0], axs[1])  # Passing the 2D mask to show_mask
    show_box(bboxes[idx].numpy(), axs[1])
    axs[1].axis("off")
    # set title
    axs[1].set_title(img_name[idx])
    # plt.show()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.savefig("./data_sanitycheck.png", bbox_inches="tight", dpi=300)
    plt.close()
    break



number of images: 4978
torch.Size([4, 3, 1024, 1024]) torch.Size([4, 1, 1024, 1024]) torch.Size([4, 4])


In [41]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

--2023-10-04 06:01:34--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.164.78.81, 18.164.78.72, 18.164.78.121, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.164.78.81|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 375042383 (358M) [binary/octet-stream]
Saving to: ‘sam_vit_b_01ec64.pth.1’


2023-10-04 06:01:36 (223 MB/s) - ‘sam_vit_b_01ec64.pth.1’ saved [375042383/375042383]



In [42]:
class CellSAM(nn.Module):
    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # freeze prompt encoder

        for param in self.prompt_encoder.parameters():
            param.requires_grad = False

        for param in self.image_encoder.parameters():
            param.requires_grad = False

    def forward(self, image, box):
        image_embedding = self.image_encoder(image)  # (B, 256, 64, 64)
        # do not compute gradients for prompt encoder
        with torch.no_grad():
            box_torch = torch.as_tensor(box, dtype=torch.float32, device=image.device)
            if len(box_torch.shape) == 2:
                box_torch = box_torch[:, None, :]  # (B, 1, 4)

            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )


        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding,  # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,
        )
        ori_res_masks = F.interpolate(
            low_res_masks,
            size=(image.shape[2], image.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        return ori_res_masks

In [43]:
ls -a


[0m[01;34m.[0m/                   data_sanitycheck.png  sam_vit_b_01ec64.pth.1
[01;34m..[0m/                  file_ids.csv          [01;34msegment-anything[0m/
[01;34m.virtual_documents[0m/  full_file_ids.csv     [01;34mwandb[0m/
[01;34mcheckpoint_save[0m/     sam_vit_b_01ec64.pth


In [44]:
model = sam_model_registry[model_type](checkpoint=checkpoint)

In [45]:
cellsam_model = CellSAM(
        image_encoder=model.image_encoder,
        mask_decoder=model.mask_decoder,
        prompt_encoder=model.prompt_encoder,
    ).to(device)

cellsam_model.train()

CellSAM(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
  

In [46]:
print(
        "Number of total parameters: ",
        sum(p.numel() for p in cellsam_model.parameters()),
    )

Number of total parameters:  93735472


In [47]:
print(
        "Number of trainable parameters: ",
        sum(p.numel() for p in cellsam_model.parameters() if p.requires_grad),

    )

Number of trainable parameters:  4058340


In [48]:
img_mask_encdec_params = cellsam_model.mask_decoder.parameters()


In [49]:
optimizer = torch.optim.AdamW(
        img_mask_encdec_params, lr=0.0001, weight_decay=0.01
    )


In [50]:
seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction="mean")


In [51]:
ce_loss = nn.BCEWithLogitsLoss(reduction="mean")


In [52]:
print("Number of training samples: ", len(tr_dataset))


Number of training samples:  4978


In [53]:
!pip install -q wandb

In [54]:
if use_wandb:
    import wandb

    wandb.login()
    wandb.init(
        project= task_name,
        config={
            "lr": lr,
            "batch_size": batch_size,
            "data_path": data_path,
            "model_type": model_type,
        },
    )



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch_loss,▁

0,1
epoch_loss,0.98063


In [55]:
run_id = datetime.now().strftime("%Y%m%d-%H%M")
model_save_path = join(work_dir, task_name + "-" + run_id)
device = torch.device(device)

In [56]:
os.makedirs(model_save_path, exist_ok=True)

In [57]:
if resume is not None:
    if os.path.isfile(resume):
        checkpoint = torch.load(resume, map_location=device)
        start_epoch = checkpoint["epoch"] + 1
        optimizer.load_state_dict(checkpoint["optimizer"])

if use_amp:
    scaler = torch.cuda.amp.GradScaler()

In [58]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Oct  4 06:02:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    32W / 250W |  12421MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [59]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 16.8 gigabytes of available RAM

Not using a high-RAM runtime


In [60]:
    for epoch in range(start_epoch, num_epochs):
        epoch_loss = 0
        for step, (image, gt2D, boxes, _) in enumerate(tqdm(tr_dataloader)):
            optimizer.zero_grad()
            boxes_np = boxes.detach().cpu().numpy()
            image, gt2D = image.to(device), gt2D.to(device)
            if use_amp:
                ## AMP
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    cellsam_pred = cellsam_model(image, boxes_np)
                    loss = seg_loss(cellsam_pred, gt2D) + ce_loss(
                        cellsam_pred, gt2D.float()
                    )
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            else:
                cellsam_pred = cellsam_model(image,boxes_np)
                loss = seg_loss(cellsam_pred, gt2D) + ce_loss(cellsam_pred, gt2D.float())
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            iter_num += 1

        epoch_loss /= step
        losses.append(epoch_loss)
        if use_wandb:
            wandb.log({"epoch_loss": epoch_loss})
        print(
            f'Time: {datetime.now().strftime("%Y%m%d-%H%M")}, Epoch: {epoch}, Loss: {epoch_loss}'
        )
        ## save the latest model
        checkpoint = {
            "model": cellsam_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        }
        torch.save(checkpoint, join(model_save_path, "cellsam_model_latest.pth"))
        ## save the best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            checkpoint = {
                "model": cellsam_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
            }
            torch.save(checkpoint, join(model_save_path, "cellsam_model_best.pth"))

        # %% plot loss
        plt.plot(losses)
        plt.title("Dice + Cross Entropy Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(join(model_save_path, task_name + "train_loss.png"))
        plt.close()

100%|██████████| 1245/1245 [26:25<00:00,  1.27s/it]


Time: 20231004-0628, Epoch: 0, Loss: 0.9797209558283785


100%|██████████| 1245/1245 [26:13<00:00,  1.26s/it]


Time: 20231004-0655, Epoch: 1, Loss: 0.973245225104105


100%|██████████| 1245/1245 [26:19<00:00,  1.27s/it]


Time: 20231004-0721, Epoch: 2, Loss: 0.9697802496588882


100%|██████████| 1245/1245 [26:17<00:00,  1.27s/it]


Time: 20231004-0747, Epoch: 3, Loss: 0.9685803066783396


100%|██████████| 1245/1245 [26:13<00:00,  1.26s/it]


Time: 20231004-0814, Epoch: 4, Loss: 0.969251671021391


100%|██████████| 1245/1245 [26:14<00:00,  1.26s/it]


Time: 20231004-0840, Epoch: 5, Loss: 0.9672019827116723


100%|██████████| 1245/1245 [26:22<00:00,  1.27s/it]


Time: 20231004-0906, Epoch: 6, Loss: 0.9664490674277977


100%|██████████| 1245/1245 [26:21<00:00,  1.27s/it]


Time: 20231004-0933, Epoch: 7, Loss: 0.9657802456253212


100%|██████████| 1245/1245 [26:27<00:00,  1.27s/it]


Time: 20231004-0959, Epoch: 8, Loss: 0.9649513142476894


100%|██████████| 1245/1245 [26:18<00:00,  1.27s/it]


Time: 20231004-1026, Epoch: 9, Loss: 0.966124671975516
