# Dataset preparation from HF

> Data will be downloaded from hugging face and then will be processed to get the data in the format we want.

In [2]:
#| default_exp hf_data_prep

In [3]:
#| export
from typing import List, Dict, Union, Optional
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import albumentations as A
from tqdm.auto import tqdm
from statistics import mean
from fastcore.test import *
from fastcore.basics import *
mpl.rcParams['image.cmap'] = 'gray' 

In [4]:
#| export
import torch
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR
import monai
############################################################
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

2024-05-01 20:53:14.080062: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-01 20:53:14.080168: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-01 20:53:14.080195: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-01 20:53:14.312162: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
#| export
from datasets import load_dataset, Dataset
from transformers import SamProcessor
from transformers import SamModel
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')

In [6]:
dataset = load_dataset("nielsr/breast-cancer", split="train")
dataset


Dataset({
    features: ['image', 'label'],
    num_rows: 130
})

In [69]:
img = dataset[0]['image']
msk = dataset[0]['label']
img.size, img.mode, msk.size, msk.mode

((256, 256), 'RGB', (256, 256), 'I')

(256, 256)

In [34]:
#| export
def get_bounding_box(
        ground_truth_map:np.ndarray, # mask image type cv2
        ):
    "Get bounding box coordinates from mask image"


    y_, x_ = np.where(ground_truth_map> 0)
    x_min, x_max = np.min(x_), np.max(x_)
    y_min, y_max = np.min(y_), np.max(y_)

    # add perturbation to bounding box coordinates
    H, W = ground_truth_map.shape
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))
    bbox = [x_min, y_min, x_max, y_max]

    return bbox

tesing get_bounding_box

In [35]:
#| export
class SAMDataset(TorchDataset):
    "Creating dataset for SAM Training"

    def __init__(
            self,
            dataset:TorchDataset, # pytorch dataset
            processor:SamProcessor # hf model processor
            ):
        super().__init__()
        store_attr('dataset, processor')
        __repr__ = basic_repr()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        mask = np.array(item['label'])
        
        prompt = get_bounding_box(mask) # bounding box around mask

        # prepare image and prompt for model
        inputs = self.processor(
                                 image, 
                                 input_boxes=[[prompt]],
                                 return_tensors="pt"
                                 )
        # remove batch dimension created by the processor
        inputs = {k:v.squeeze(0) for k,v in inputs.items()}
        inputs['ground_truth_mask'] = mask
        return inputs

# Creating pytorch dataset

In [36]:
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
train_dataset = SAMDataset(dataset=dataset, processor=processor)
train_dataloader = DataLoader(
                            train_dataset, 
                            batch_size=2,
                            shuffle=True)
example = train_dataset[0]
for k,v in example.items():
  print(k,v.shape)
     
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)



pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([1, 4])
ground_truth_mask (256, 256)


# Creating pytorch pytorch dataloader   

In [42]:
train_dataloader = DataLoader(
    train_ds, 
    batch_size=2, 
    shuffle=True)
val_dataloader = DataLoader(
    val_ds, 
    batch_size=2, 
    shuffle=True)

testing pytorch dataloader

In [37]:

batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)
     

pixel_values torch.Size([2, 3, 1024, 1024])
original_sizes torch.Size([2, 2])
reshaped_input_sizes torch.Size([2, 2])
input_boxes torch.Size([2, 1, 4])
ground_truth_mask torch.Size([2, 256, 256])


# loading Model

In [38]:
model = SamModel.from_pretrained('facebook/sam-vit-base')

In [39]:
# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)
     

In [40]:
NUM_EPOCHS = 2
T_0 = int(0.5 * NUM_EPOCHS)
ITERS = len(train_dataloader)

In [41]:
optimizer = AdamW(
    model.mask_decoder.parameters(),
    lr=0.001,
    weight_decay=0.0001)

In [42]:

device = "cuda" if torch.cuda.is_available() else "cpu"
device= "cpu"
# in case of very small gpu memory, like me then use cpu
#device = "cpu"
model.to(device)
scheduler = CosineAnnealingWarmRestarts(
   T_0=T_0,
   optimizer=optimizer, 
   eta_min=0.00001)


In [43]:
device='cpu'

In [44]:
num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

seg_loss = monai.losses.DiceCELoss(
    sigmoid=True, 
    squared_pred=True, 
    reduction='mean')
model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
      # forward pass
      print(batch['pixel_values'].shape)
      outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)
      print(f'outputs shape {outputs.pred_masks.shape}')

      # compute loss
      predicted_masks = outputs.pred_masks.squeeze(1)
      print(f'predicted_masks shape {predicted_masks.shape}') 
      ground_truth_masks = batch["ground_truth_mask"].float().to(device)
      print(f'ground_truth_masks shape {ground_truth_masks.shape}')
      loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

      # backward pass (compute gradients of parameters w.r.t. loss)
      optimizer.zero_grad()
      loss.backward()

      # optimize
      optimizer.step()
      epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

  0%|          | 0/65 [00:00<?, ?it/s]

torch.Size([2, 3, 1024, 1024])


  0%|          | 0/65 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [45]:
def pt_train(
        train_dataloader:DataLoader,
        model:SamModel,
        optimizer:torch.optim.Optimizer,
        device:Union[str, None]='cpu',
        epoch_n:int=2,
    ):


    seg_loss = monai.losses.DiceCELoss(
        sigmoid=True, 
        squared_pred=True, 
        reduction='mean')
    model.train()
    for epoch in range(epoch_n):
        epoch_losses = []
        print(f"Epoch {epoch+1}")
        for idx,  batch in tqdm(enumerate(train_dataloader)):

            # forward pass
            outputs = model(
                pixel_values=batch['pixel_values'].to(device),
                input_boxes=batch['input_boxes'].to(device),
                            multimask_output=False)
            
            # compute loss
            pred_masks = outputs.pred_masks.squeeze(1)
            ground_truth_mask = batch['ground_truth_mask'].float().to(device)


            loss = seg_loss(pred_masks, ground_truth_mask.unsqueeze(1))

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # update weights
            optimizer.step()
            epoch_losses.append(loss.item())
    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')


In [46]:

pt_train(
    train_dataloader=train_dataloader,
    model=model, 
    optimizer=optimizer, 
    device=device, 
    epoch_n=2)

Epoch 1


23it [08:41, 22.89s/it]

: 