# Dataset preparation from HF

> Data will be downloaded from hugging face and then will be processed to get the data in the format we want.

In [137]:
#| default_exp hf_data_prep

In [138]:
#| export
from typing import List, Dict, Union, Optional
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import albumentations as A
from tqdm.auto import tqdm
from statistics import mean
from fastcore.test import *
from fastcore.basics import *
mpl.rcParams['image.cmap'] = 'gray' 

In [139]:
#| export
import torch
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR
import monai
############################################################
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

In [140]:
#| export
from datasets import load_dataset, Dataset
from transformers import SamProcessor
from transformers import SamModel
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')

In [141]:
from cv_tools.core import *

In [6]:
dataset = load_dataset("nielsr/breast-cancer", split="train")
dataset


Dataset({
    features: ['image', 'label'],
    num_rows: 130
})

In [7]:
img = dataset[0]['image']
msk = dataset[0]['label']
img.size, img.mode, msk.size, msk.mode

((256, 256), 'RGB', (256, 256), 'I')

In [142]:
from datasets import load_dataset

In [143]:
training_dataset = load_dataset(
    "hasangoni/Electron_microscopy_dataset",
    split="train"
    )
validation_dataset = load_dataset(
    "hasangoni/Electron_microscopy_dataset",
    split="test"
    )


In [144]:
training_dataset

Dataset({
    features: ['image', 'label'],
    num_rows: 1642
})

In [145]:
#| export
def get_bounding_box(
        ground_truth_map:np.ndarray, # mask image type cv2
        ):
    "Get bounding box coordinates from mask image"


    y_, x_ = np.where(ground_truth_map> 0)
    x_min, x_max = np.min(x_), np.max(x_)
    y_min, y_max = np.min(y_), np.max(y_)

    # add perturbation to bounding box coordinates
    H, W = ground_truth_map.shape
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))
    bbox = [x_min, y_min, x_max, y_max]

    return bbox

tesing get_bounding_box

In [146]:
#| export
class SAMDataset(TorchDataset):
    "Creating dataset for SAM Training"

    def __init__(
            self,
            dataset:TorchDataset, # pytorch dataset
            processor:SamProcessor, # hf model processor
            ):
        super().__init__()
        store_attr('dataset, processor')
        __repr__ = basic_repr()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image'].convert('RGB')
        mask = np.array(item['label'])
        
        prompt = get_bounding_box(mask) # bounding box around mask

        # prepare image and prompt for model
        inputs = self.processor(
                                 image, 
                                 input_boxes=[[prompt]],
                                 return_tensors="pt"
                                 )
        # remove batch dimension created by the processor
        inputs = {k:v.squeeze(0) for k,v in inputs.items()}
        inputs['ground_truth_mask'] = mask
        return inputs

# Creating pytorch dataset

In [147]:
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
train_dataset = SAMDataset(
  dataset=training_dataset, 
  processor=processor)
val_dataset = SAMDataset(
  dataset=validation_dataset, 
  processor=processor)


In [129]:
trn_ = np.transpose(train_dataset[0]['pixel_values'].to('cpu').numpy(), (1,2,0))
val_ = np.transpose(val_dataset[0]['pixel_values'].to('cpu').numpy(), (1,2,0))
#show_(trn_)

In [148]:

train_dataloader = DataLoader(
                            train_dataset, 
                            batch_size=2,
                            shuffle=True)
val_dataloader = DataLoader(
                            val_dataset, 
                            batch_size=2,
                            shuffle=False)
                            


In [131]:
example = train_dataset[0]
for k,v in example.items():
  print(k,v.shape)
     
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)



pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([1, 4])
ground_truth_mask (256, 256)


# Creating pytorch pytorch dataloader   

In [149]:
train_dataloader = DataLoader(
    train_ds, 
    batch_size=2, 
    shuffle=True)
val_dataloader = DataLoader(
    val_ds, 
    batch_size=2, 
    shuffle=False)

NameError: name 'train_ds' is not defined

testing pytorch dataloader

In [37]:

batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)
     

pixel_values torch.Size([2, 3, 1024, 1024])
original_sizes torch.Size([2, 2])
reshaped_input_sizes torch.Size([2, 2])
input_boxes torch.Size([2, 1, 4])
ground_truth_mask torch.Size([2, 256, 256])


# loading Model

In [150]:
model = SamModel.from_pretrained('facebook/sam-vit-base')

In [151]:
# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)
     

In [152]:
NUM_EPOCHS = 2
T_0 = int(0.5 * NUM_EPOCHS)
ITERS = len(train_dataloader)

In [153]:
optimizer = AdamW(
    model.mask_decoder.parameters(),
    lr=0.001,
    weight_decay=0.0001)

In [154]:

device = "cuda" if torch.cuda.is_available() else "cpu"
device= "cpu"
# in case of very small gpu memory, like me then use cpu
#device = "cpu"
model.to(device)
scheduler = CosineAnnealingWarmRestarts(
   T_0=T_0,
   optimizer=optimizer, 
   eta_min=0.00001)


In [155]:
device='cpu'

In [156]:
#| export
def validate(
        model:SamModel, 
        dataloader:DataLoader,
        loss_fn:monai.losses.dice, 
        device:str='cpu'):

    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in dataloader:
            outputs = model(
                pixel_values=batch["pixel_values"].to(device),
                input_boxes=batch["input_boxes"].to(device),
                multimask_output=False
                )
            predicted_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = batch["ground_truth_mask"].float().to(device)
            loss = loss_fn(predicted_masks, ground_truth_masks.unsqueeze(1))
            val_losses.append(loss.item())
    model.train()
    return mean(val_losses)

In [44]:
num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

seg_loss = monai.losses.DiceCELoss(
    sigmoid=True, 
    squared_pred=True, 
    reduction='mean')
model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
      # forward pass
      print(batch['pixel_values'].shape)
      outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)
      print(f'outputs shape {outputs.pred_masks.shape}')

      # compute loss
      predicted_masks = outputs.pred_masks.squeeze(1)
      print(f'predicted_masks shape {predicted_masks.shape}') 
      ground_truth_masks = batch["ground_truth_mask"].float().to(device)
      print(f'ground_truth_masks shape {ground_truth_masks.shape}')
      loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

      # backward pass (compute gradients of parameters w.r.t. loss)
      optimizer.zero_grad()
      loss.backward()

      # optimize
      optimizer.step()
      epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean Training Loss: {mean(epoch_losses)}')
    validation_loss = validate(
        model=model, 
        dataloader=val_dataloader, 
        loss_fn=seg_loss, 
        device=device)

    print(f'Validation Loss: {validation_loss}')

  0%|          | 0/65 [00:00<?, ?it/s]

torch.Size([2, 3, 1024, 1024])


  0%|          | 0/65 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [157]:
#| export
def pt_train(
        train_dataloader:DataLoader,
        model:SamModel,
        optimizer:torch.optim.Optimizer,
        device:Union[str, None]='cpu',
        epoch_n:int=2,
    ):


    seg_loss = monai.losses.DiceCELoss(
        sigmoid=True, 
        squared_pred=True, 
        reduction='mean')
    model.train()
    for epoch in range(epoch_n):
        epoch_losses = []
        print(f"Epoch {epoch+1}")
        for idx,  batch in tqdm(enumerate(train_dataloader)):

            # forward pass
            outputs = model(
                pixel_values=batch['pixel_values'].to(device),
                input_boxes=batch['input_boxes'].to(device),
                            multimask_output=False)
            
            # compute loss
            pred_masks = outputs.pred_masks.squeeze(1)
            ground_truth_mask = batch['ground_truth_mask'].float().to(device)


            loss = seg_loss(pred_masks, ground_truth_mask.unsqueeze(1))

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # update weights
            optimizer.step()
            epoch_losses.append(loss.item())
    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')


In [46]:

pt_train(
    train_dataloader=train_dataloader,
    model=model, 
    optimizer=optimizer, 
    device=device, 
    epoch_n=2)

Epoch 1


23it [08:41, 22.89s/it]

: 

In [158]:
#| export
def validate(
        model:SamModel,  # SAM model
        dataloader:DataLoader,  # Torch dataloader
        loss_fn:monai.losses,  # Monai loss function
        device:str='cpu' # whether to use cpu or gpu
        ):
    """
    Validate the model using a validation dataloader.

    Parameters:
    - model: The PyTorch model to validate.
    - dataloader: DataLoader for validation data.
    - loss_fn: Loss function used for validation.
    - device: Device to run validation on ('cuda' or 'cpu').
    """
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in dataloader:
            outputs = model(pixel_values=batch["pixel_values"].to(device),
                            input_boxes=batch["input_boxes"].to(device),
                            multimask_output=False)
            predicted_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = batch["ground_truth_mask"].float().to(device)
            loss = loss_fn(predicted_masks, ground_truth_masks.unsqueeze(1))
            val_losses.append(loss.item())
    model.train()
    return mean(val_losses)

In [159]:
#| export
def train_and_validate(
        model:SamModel,  # SAM model
        num_epochs:int,  # Number of epochs to train for
        optimizer:torch.optim.Optimizer, # Optimizer to use
        scheduler:torch.optim.lr_scheduler,  # Learning rate scheduler
        train_dataloader:DataLoader,  # DataLoader for training data
        val_dataloader:DataLoader, # DataLoader for validation data
        loss_fn:monai.losses,  # Loss function used for training
        device:str='cpu' # Device to train on ('cuda' or 'cpu')
        ):
    """Train and validate a model with the given parameters."""
    
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        epoch_losses = []

        progress_bar = tqdm(
            train_dataloader, 
            desc=f'Epoch {epoch+1}/{num_epochs}', total=len(train_dataloader))

        for batch in progress_bar:
            # Forward pass
            outputs = model(pixel_values=batch["pixel_values"].to(device),
                            input_boxes=batch["input_boxes"].to(device),
                            multimask_output=False)
            
            # Compute loss
            predicted_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = batch["ground_truth_mask"].float().to(device)
            loss = loss_fn(predicted_masks, ground_truth_masks.unsqueeze(1))

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Optimize
            optimizer.step()
            epoch_losses.append(loss.item())

        # Scheduler step (if any)
        if scheduler:
            scheduler.step()

        # Validation phase
        validation_loss = validate(
            model=model, 
            dataloader=val_dataloader, 
            loss_fn=loss_fn, 
            device=device)

        # Print epoch results
        print(f'EPOCH: {epoch}')
        print(f'Mean Training Loss: {mean(epoch_losses)}')
        print(f'Validation Loss: {validation_loss}')


In [160]:
#model = SamModel.from_pretrained('facebook/sam-vit-base')
optimizer = AdamW(
    model.mask_decoder.parameters(),
    lr=0.001,
    weight_decay=0.0001)

scheduler = CosineAnnealingWarmRestarts(
   T_0=10,
   T_mult=2,
   optimizer=optimizer, 
   eta_min=0.00001)
seg_loss = monai.losses.DiceCELoss()
device='cpu'

In [161]:
train_and_validate(
    model=model,
    num_epochs=2,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=seg_loss,
    device=device
)

Epoch 1/2:   0%|          | 1/821 [00:36<8:21:37, 36.70s/it]


KeyboardInterrupt: 