In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.models import resnet50, ResNet50_Weights, VGG16_Weights
from tqdm import tqdm
import cv2

In [2]:
class MedicalImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_dir = self.image_dir.replace('imgs', 'masks')
        mask_file=self.images[idx].split('.')[0]+'_mask.png'
        mask_path = os.path.join(mask_dir, mask_file)
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

# Define the transform
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

image_dataset = MedicalImageDataset('/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/train/imgs', transform=transform)

In [3]:
# Calculate lengths for the splits
train_size = int(0.95 * len(image_dataset))
val_size = len(image_dataset) - train_size

# Split the dataset
train_dataset, val_dataset = random_split(image_dataset, [train_size, val_size])

In [4]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [5]:
len(train_loader),len(val_loader)

(567, 30)

In [6]:
from segment_anything import SamPredictor, sam_model_registry

model_type = 'vit_h'
checkpoint = '/media/rohit/mirlproject2/fetal head circumference/sam_vit_h_4b8939.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

org_sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
org_sam_model.to(device)

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
      (1): Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, b

In [7]:
def find_combined_bounding_box1(im, min_area_threshold=100):

    if im.dtype != np.uint8:
        im = (im * 255).astype(np.uint8)

    gray = im

    # Find contours
    contours, _ = cv2.findContours(
        gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Filter out contours that are too small
    valid_contours = [cnt for cnt in contours if cv2.contourArea(
        cnt) > min_area_threshold]

    # If there's only one valid contour, get its bounding box
    if len(valid_contours) == 1:
        x, y, w, h = cv2.boundingRect(valid_contours[0])
        return np.array([x, y, x+w, y+h])

    # If there are multiple contours, combine them into one bounding box
    x_min = min([cv2.boundingRect(cnt)[0] for cnt in valid_contours])
    y_min = min([cv2.boundingRect(cnt)[1] for cnt in valid_contours])
    x_max = max([cv2.boundingRect(cnt)[0] + cv2.boundingRect(cnt)[2]
                for cnt in valid_contours])
    y_max = max([cv2.boundingRect(cnt)[1] + cv2.boundingRect(cnt)[3]
                for cnt in valid_contours])

    return np.array([x_min, y_min, x_max, y_max])

In [8]:
import torch.nn.functional as F


def dice_pytorch(inputs, targets, smooth=1):

    # comment out if your model contains a sigmoid or equivalent activation layer
    # inputs = torch.sigmoid(inputs)

    # flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    intersection = (inputs * targets).sum()
    dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

    return 1 - dice


In [9]:
loss_fn = dice_pytorch

In [10]:
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path  # Path to save the checkpoint
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        # Round the validation loss to 5 decimal places
        val_loss = round(val_loss, 4)
        score = -val_loss  # Convert to a maximization problem

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decreases."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.4f} ----------> {val_loss:.4f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)  # Save the model state to the specified path
        self.val_loss_min = val_loss

In [11]:
# Fine-tuning setup
best_val_loss = float('inf')
optimizer = torch.optim.Adam(org_sam_model.mask_decoder.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
loss_fn = dice_pytorch
num_epochs = 100
train_losses = []
val_losses = []

early_stopping = EarlyStopping(patience=5, verbose=True, path='/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/org_sam_model_best_finetune_early_stop.pth')



for epoch in range(num_epochs):
    running_train_loss = 0.0
    org_sam_model.train()  # Ensure the model is in training mode
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for x, y in train_loader_tqdm:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)

        image_np = x[0].cpu().detach().numpy()
        image_np = np.transpose(image_np, (1, 2, 0))  # Change from CHW to HWC

        # Convert float32 to uint8
        image_np = (image_np * 255).astype(np.uint8)
        image = image_np[..., [2, 1, 0]]

        transform = ResizeLongestSide(org_sam_model.image_encoder.img_size)
        input_image = transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=device)
        transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
        input_image = org_sam_model.preprocess(transformed_image)

        # No grad here as we don't want to optimize the encoders
        with torch.no_grad():
            image_embedding = org_sam_model.image_encoder(input_image)
            #image_embedding = sam_model.image_encoder(x)

        prompt_box = find_combined_bounding_box1(y[0][0].cpu().numpy())
        box = prompt_box
        box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
        box_torch = box_torch[None, :]

        sparse_embeddings, dense_embeddings = org_sam_model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )

        low_res_masks, iou_predictions = org_sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=org_sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        upscaled_masks = org_sam_model.postprocess_masks(low_res_masks, (512, 512), (512, 512)).to(device)

        binary_mask = torch.sigmoid(upscaled_masks)
        ground_truth_masks = (y[0][0].cpu().numpy() == 1).astype(np.float32)
        ground_truth_masks = torch.from_numpy(ground_truth_masks).unsqueeze(0).unsqueeze(0).to(device)

        loss = loss_fn(binary_mask, ground_truth_masks)
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
        train_loader_tqdm.set_postfix({"Train Loss": running_train_loss / (train_loader_tqdm.n + 1)})

    avg_train_loss = running_train_loss / len(train_loader)  # Average loss per batch
    train_losses.append(avg_train_loss)
    print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}')

    # Validation phase
    org_sam_model.eval()  # Set model to evaluation mode
    running_val_loss = 0.0
    val_loader_tqdm = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch")

    with torch.no_grad():
        for x, y in val_loader_tqdm:
            x, y = x.to(device), y.to(device)

            image_np = x[0].cpu().detach().numpy()
            image_np = np.transpose(image_np, (1, 2, 0))  # Change from CHW to HWC

            # Convert float32 to uint8
            image_np = (image_np * 255).astype(np.uint8)
            image = image_np[..., [2, 1, 0]]

            transform = ResizeLongestSide(org_sam_model.image_encoder.img_size)
            input_image = transform.apply_image(image)
            input_image_torch = torch.as_tensor(input_image, device=device)
            transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
            input_image = org_sam_model.preprocess(transformed_image)

            image_embedding = org_sam_model.image_encoder(input_image)

            prompt_box = find_combined_bounding_box1(y[0][0].cpu().numpy())
            box = prompt_box
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            box_torch = box_torch[None, :]

            sparse_embeddings, dense_embeddings = org_sam_model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )

            low_res_masks, iou_predictions = org_sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=org_sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            upscaled_masks = org_sam_model.postprocess_masks(low_res_masks, (512, 512), (512, 512)).to(device)

            binary_mask = torch.sigmoid(upscaled_masks)
            ground_truth_masks = (y[0][0].cpu().numpy() == 1).astype(np.float32)
            ground_truth_masks = torch.from_numpy(ground_truth_masks).unsqueeze(0).unsqueeze(0).to(device)

            val_loss = loss_fn(binary_mask, ground_truth_masks)
            running_val_loss += val_loss.item()
            val_loader_tqdm.set_postfix({"Val Loss": running_val_loss / (val_loader_tqdm.n + 1)})

    avg_val_loss = running_val_loss / len(val_loader)  # Average loss per batch
    val_losses.append(avg_val_loss)
    print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')

    # Save the best model based on validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(org_sam_model.state_dict(), '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/org_sam_model_best_finetune_exp_2.pth')
        print(f'Saving best model with validation loss----------------------------------------------> {best_val_loss:.4f}')

    scheduler.step(avg_val_loss)

    # Print the current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Learning rate: {current_lr:.6f}")

    # Save the model checkpoint after each epoch
    torch.save(org_sam_model.state_dict(), '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/org_sam_model_latest_finetune_exp_2.pth')

    # Early stopping
    early_stopping(avg_val_loss, org_sam_model)

    if early_stopping.early_stop:
        print("Early stopping")
        break

    org_sam_model.train()


Epoch 1/100: 100%|██████████| 567/567 [10:18<00:00,  1.09s/batch, Train Loss=0.147]


Epoch 1, Training Loss: 0.1467


Validation 1/100: 100%|██████████| 30/30 [00:32<00:00,  1.07s/batch, Val Loss=0.142]


Epoch 1, Validation Loss: 0.1423
Saving best model with validation loss----------------------------------------------> 0.1423
Learning rate: 0.000100
Validation loss decreased (inf ----------> 0.1423).  Saving model ...


Epoch 2/100: 100%|██████████| 567/567 [08:06<00:00,  1.17batch/s, Train Loss=0.129]


Epoch 2, Training Loss: 0.1289


Validation 2/100: 100%|██████████| 30/30 [00:24<00:00,  1.23batch/s, Val Loss=0.126]


Epoch 2, Validation Loss: 0.1264
Saving best model with validation loss----------------------------------------------> 0.1264
Learning rate: 0.000100
Validation loss decreased (0.1423 ----------> 0.1264).  Saving model ...


Epoch 3/100: 100%|██████████| 567/567 [08:04<00:00,  1.17batch/s, Train Loss=0.128]


Epoch 3, Training Loss: 0.1282


Validation 3/100: 100%|██████████| 30/30 [00:25<00:00,  1.20batch/s, Val Loss=0.13] 


Epoch 3, Validation Loss: 0.1296
Learning rate: 0.000100
EarlyStopping counter: 1 out of 5


Epoch 4/100: 100%|██████████| 567/567 [08:06<00:00,  1.16batch/s, Train Loss=0.128]


Epoch 4, Training Loss: 0.1275


Validation 4/100: 100%|██████████| 30/30 [00:23<00:00,  1.27batch/s, Val Loss=0.124]


Epoch 4, Validation Loss: 0.1236
Saving best model with validation loss----------------------------------------------> 0.1236
Learning rate: 0.000100
Validation loss decreased (0.1264 ----------> 0.1236).  Saving model ...


Epoch 5/100: 100%|██████████| 567/567 [08:05<00:00,  1.17batch/s, Train Loss=0.123]


Epoch 5, Training Loss: 0.1228


Validation 5/100: 100%|██████████| 30/30 [00:32<00:00,  1.09s/batch, Val Loss=0.123]


Epoch 5, Validation Loss: 0.1233
Saving best model with validation loss----------------------------------------------> 0.1233
Learning rate: 0.000100
Validation loss decreased (0.1236 ----------> 0.1233).  Saving model ...


Epoch 6/100: 100%|██████████| 567/567 [09:45<00:00,  1.03s/batch, Train Loss=0.123]


Epoch 6, Training Loss: 0.1228


Validation 6/100: 100%|██████████| 30/30 [00:25<00:00,  1.17batch/s, Val Loss=0.123]


Epoch 6, Validation Loss: 0.1234
Learning rate: 0.000100
EarlyStopping counter: 1 out of 5


Epoch 7/100: 100%|██████████| 567/567 [08:16<00:00,  1.14batch/s, Train Loss=0.123]


Epoch 7, Training Loss: 0.1229


Validation 7/100: 100%|██████████| 30/30 [00:25<00:00,  1.16batch/s, Val Loss=0.123]


Epoch 7, Validation Loss: 0.1232
Saving best model with validation loss----------------------------------------------> 0.1232
Learning rate: 0.000100
Validation loss decreased (0.1233 ----------> 0.1232).  Saving model ...


Epoch 8/100: 100%|██████████| 567/567 [21:47<00:00,  2.31s/batch, Train Loss=0.122]


Epoch 8, Training Loss: 0.1217


Validation 8/100: 100%|██████████| 30/30 [00:43<00:00,  1.46s/batch, Val Loss=0.12] 


Epoch 8, Validation Loss: 0.1197
Saving best model with validation loss----------------------------------------------> 0.1197
Learning rate: 0.000100
Validation loss decreased (0.1232 ----------> 0.1197).  Saving model ...


Epoch 9/100: 100%|██████████| 567/567 [12:28<00:00,  1.32s/batch, Train Loss=0.118] 


Epoch 9, Training Loss: 0.1177


Validation 9/100: 100%|██████████| 30/30 [00:27<00:00,  1.11batch/s, Val Loss=0.119]


Epoch 9, Validation Loss: 0.1188
Saving best model with validation loss----------------------------------------------> 0.1188
Learning rate: 0.000100
Validation loss decreased (0.1197 ----------> 0.1188).  Saving model ...


Epoch 10/100: 100%|██████████| 567/567 [09:05<00:00,  1.04batch/s, Train Loss=0.121]


Epoch 10, Training Loss: 0.1205


Validation 10/100: 100%|██████████| 30/30 [00:28<00:00,  1.04batch/s, Val Loss=0.122]


Epoch 10, Validation Loss: 0.1215
Learning rate: 0.000100
EarlyStopping counter: 1 out of 5


Epoch 11/100: 100%|██████████| 567/567 [09:06<00:00,  1.04batch/s, Train Loss=0.123]


Epoch 11, Training Loss: 0.1226


Validation 11/100: 100%|██████████| 30/30 [00:28<00:00,  1.04batch/s, Val Loss=0.125]


Epoch 11, Validation Loss: 0.1255
Learning rate: 0.000100
EarlyStopping counter: 2 out of 5


Epoch 12/100: 100%|██████████| 567/567 [09:00<00:00,  1.05batch/s, Train Loss=0.117]


Epoch 12, Training Loss: 0.1174


Validation 12/100: 100%|██████████| 30/30 [00:27<00:00,  1.07batch/s, Val Loss=0.123]


Epoch 12, Validation Loss: 0.1226
Learning rate: 0.000100
EarlyStopping counter: 3 out of 5


Epoch 13/100: 100%|██████████| 567/567 [09:36<00:00,  1.02s/batch, Train Loss=0.121]


Epoch 13, Training Loss: 0.1212


Validation 13/100: 100%|██████████| 30/30 [00:31<00:00,  1.05s/batch, Val Loss=0.117]


Epoch 13, Validation Loss: 0.1173
Saving best model with validation loss----------------------------------------------> 0.1173
Learning rate: 0.000100
Validation loss decreased (0.1188 ----------> 0.1173).  Saving model ...


Epoch 14/100: 100%|██████████| 567/567 [09:43<00:00,  1.03s/batch, Train Loss=0.116]


Epoch 14, Training Loss: 0.1164


Validation 14/100: 100%|██████████| 30/30 [00:32<00:00,  1.10s/batch, Val Loss=0.12] 


Epoch 14, Validation Loss: 0.1204
Learning rate: 0.000100
EarlyStopping counter: 1 out of 5


Epoch 15/100: 100%|██████████| 567/567 [10:16<00:00,  1.09s/batch, Train Loss=0.119]


Epoch 15, Training Loss: 0.1188


Validation 15/100: 100%|██████████| 30/30 [00:29<00:00,  1.01batch/s, Val Loss=0.123]


Epoch 15, Validation Loss: 0.1225
Learning rate: 0.000100
EarlyStopping counter: 2 out of 5


Epoch 16/100: 100%|██████████| 567/567 [10:06<00:00,  1.07s/batch, Train Loss=0.116]


Epoch 16, Training Loss: 0.1162


Validation 16/100: 100%|██████████| 30/30 [00:31<00:00,  1.04s/batch, Val Loss=0.124]


Epoch 16, Validation Loss: 0.1242
Learning rate: 0.000100
EarlyStopping counter: 3 out of 5


Epoch 17/100: 100%|██████████| 567/567 [10:01<00:00,  1.06s/batch, Train Loss=0.116]


Epoch 17, Training Loss: 0.1162


Validation 17/100: 100%|██████████| 30/30 [00:32<00:00,  1.08s/batch, Val Loss=0.12] 


Epoch 17, Validation Loss: 0.1204
Learning rate: 0.000100
EarlyStopping counter: 4 out of 5


Epoch 18/100: 100%|██████████| 567/567 [10:01<00:00,  1.06s/batch, Train Loss=0.116]


Epoch 18, Training Loss: 0.1162


Validation 18/100: 100%|██████████| 30/30 [00:31<00:00,  1.06s/batch, Val Loss=0.123]


Epoch 18, Validation Loss: 0.1232
Learning rate: 0.000100
EarlyStopping counter: 5 out of 5
Early stopping


: 

## Test