In [1]:
# Fine-tune SAM (Segment Anything) on the Liebherr dataset.
# Reference tutorial: https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb#scrollTo=XC35CzLxfdQU

In [2]:
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

import monai
import tqdm
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import DataLoader, random_split
from torch.nn.functional import threshold, normalize
from torchmetrics.classification import BinaryJaccardIndex

from test import is_valid_file
from utils import Embedding_Dataset, SAMPreprocess, PILToNumpy, NumpyToTensor, sample_point, SAMPostprocess

2023-08-27 17:42:30.877908: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-27 17:42:31.635927: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class SAM_Fine_Tune(torch.nn.Module):
    """SAM (ViT-B) wrapper that predicts a mask from a precomputed image embedding and one point prompt.

    The image encoder is not run here: `forward` receives image embeddings
    directly and only drives the prompt encoder and the mask decoder.

    Attributes:
        sam_model: model built by `sam_model_registry['vit_b']` from the
            `sam_vit_b_01ec64.pth` checkpoint.
        img_size: `img_size` of the SAM image encoder.
        postprocess_masks: `SAMPostprocess(img_size)` callable applied to the
            decoder output.
    """

    def __init__(self):
        super(SAM_Fine_Tune, self).__init__()
        self.sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
        self.img_size = self.sam_model.image_encoder.img_size
        self.postprocess_masks = SAMPostprocess(self.img_size)

    def forward(self, embeddings, points):
        """Decode one mask per sample from its embedding and a single prompt point.

        Args:
            embeddings: batch of image embeddings; the first dimension is the
                batch size.
            points: one prompt point per sample; a prompt dimension is inserted
                at axis 1 before the prompt encoder is called.

        Returns:
            The post-processed mask logits (no sigmoid or threshold is applied
            here).
        """
        # One label of 1 per prompt point. Create it directly on the points'
        # device: `Tensor.to` returns a new tensor, so calling it without using
        # the result would leave the labels on the CPU.
        labels = torch.ones(embeddings.shape[0], 1, device=points.device)
        sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
            points=(points.unsqueeze(1), labels),
            boxes=None,
            masks=None,
        )
        masks, iou_predictions = self.sam_model.mask_decoder(
            image_embeddings=embeddings,
            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,  # fine-tuning such that the first decoder output is the best
        )
        masks = self.postprocess_masks(masks)
        # No thresholding here: the loss is configured with sigmoid=True and
        # consumes raw logits.
        return masks  # B, 1, 256, 256 per the original note -- TODO confirm against SAMPostprocess

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SAM_Fine_Tune()
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = DataParallel(model)
model.to(device)

# DataParallel does not forward custom attributes such as `sam_model`; the
# wrapped network lives under `.module`.
base_model = model.module if isinstance(model, DataParallel) else model

# make sure we only compute gradients for mask decoder
# The SAM submodule is called `image_encoder` (see SAM_Fine_Tune.__init__), so
# that is the prefix to match; `vision_encoder` never matched any parameter.
for name, param in base_model.sam_model.named_parameters():
    if name.startswith(("image_encoder", "prompt_encoder")):
        param.requires_grad_(False)

# IoU metric shared by the training and evaluation loops.
jaccard = BinaryJaccardIndex().to(device)

In [4]:
# Read the SAM input size from the underlying SAM_Fine_Tune instance: when the
# model is wrapped in DataParallel, `img_size` is only reachable via `.module`.
img_size = (model.module if isinstance(model, DataParallel) else model).img_size

sam_transform = ResizeLongestSide(img_size)

# Pipeline passed to the dataset as `target_transform`; its last step is
# `sample_point`.
target_transform = Compose([
    sam_transform.apply_image_torch, # rescale
    SAMPreprocess(img_size, normalize=False), # padding
    sample_point,
])

# Pipeline passed to the dataset as `transform`.
transform = Compose([
    PILToNumpy(),
    sam_transform.apply_image, # rescale
    NumpyToTensor(),
    SAMPreprocess(img_size) # padding
])
def custom_collate(batch):
    """Collate `(image, (mask, point), embedding)` samples into batched tensors.

    Args:
        batch: sequence of samples, each a 3-tuple whose middle element is a
            `(mask, point)` pair.

    Returns:
        A 4-tuple of stacked tensors ordered as images, masks, points,
        embeddings.
    """
    # Transpose the list of samples into one sequence per field.
    images, targets, embeddings = (list(column) for column in zip(*batch))
    # Split every target pair into its mask and its point.
    masks, points = (list(column) for column in zip(*targets))
    return tuple(map(torch.stack, (images, masks, points, embeddings)))

In [5]:
# Training hyperparameters.
epochs = 10
batch_size = 8
lr = 1e-5

# Dataset root on the cluster scratch filesystem.
folder_path = '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/Liebherr/dataset'
dataset = Embedding_Dataset(
    root=folder_path,
    transform=transform,
    target_transform=target_transform,
    is_valid_file=is_valid_file,
)

# 70 % train, 15 % validation, and whatever remains for test, so the three
# sizes always add up to the dataset size.
dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.15 * dataset_size)
test_size = dataset_size - (train_size + val_size)

# Fixed seed so the split is the same on every run.
generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    dataset,
    lengths=[train_size, val_size, test_size],
    generator=generator,
)

# All three loaders share the same batch size and collate function.
# NOTE(review): the training loader does not shuffle between epochs -- confirm
# this is intended.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, collate_fn=custom_collate)
    for subset in (train_set, val_set, test_set)
)

In [6]:
# Only the mask decoder is optimised. Reach through `.module` when the model is
# wrapped in DataParallel, which does not expose `sam_model` directly.
sam_model = (model.module if isinstance(model, DataParallel) else model).sam_model
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=lr, weight_decay=0)
# sigmoid=True: the loss applies the sigmoid itself, so the model outputs raw logits.
criterion = monai.losses.DiceFocalLoss(sigmoid=True, squared_pred=True, lambda_focal=20.) # maybe include_background=False

In [7]:
def train_epoch(model, epoch, criterion, optimizer, dataloader, device):
    """Run one optimisation pass over `dataloader` and report the mean loss and IoU.

    Args:
        model: callable as `model(embeddings, points)`; put into train mode.
        epoch: current epoch index (not used inside the function).
        criterion: loss called as `criterion(outputs, masks)`.
        optimizer: optimizer stepped once per batch.
        dataloader: yields `(images, masks, points, embeddings)` batches and
            exposes `.dataset` so the number of samples can be read.
        device: device the batch tensors are moved to.

    Returns:
        `(epoch_loss, epoch_iou)` as Python floats: the per-batch loss and IoU
        weighted by batch size and divided by `len(dataloader.dataset)`.
    """
    model.train()

    num_samples = len(dataloader.dataset)

    running_loss = 0.0
    running_iou = 0.0

    # The images are not needed here: the model consumes the embeddings.
    for _images, masks, points, embeddings in tqdm.tqdm(dataloader):
        # Transfer Data to GPU if available
        embeddings, points, masks = embeddings.to(device), points.to(device), masks.to(device)

        # Clear the gradients
        optimizer.zero_grad()

        # Forward Pass
        outputs = model(embeddings, points)

        # Compute Loss
        loss = criterion(outputs, masks)

        # Calculate gradients
        loss.backward()

        # Update Weights
        optimizer.step()

        # Accumulate batch-size-weighted metrics as plain floats so no device
        # tensor is carried across iterations or returned to the caller.
        # Outputs are logits, so `> 0` is the 0.5-probability decision boundary.
        batch_samples = embeddings.size(0)
        running_loss += loss.item() * batch_samples
        running_iou += jaccard(masks > 0, outputs > 0).item() * batch_samples

    epoch_loss = running_loss / num_samples
    epoch_iou = running_iou / num_samples

    return epoch_loss, epoch_iou

In [8]:
def test_epoch(model, epoch, criterion, optimizer, dataloader, device):
    model.eval()
    
    num_batches = len(dataloader)
    num_samples = len(dataloader.dataset)
    
    with torch.no_grad():
        running_loss = 0.0
        running_iou = 0.0

        for batch, (images, masks, points, embeddings) in enumerate(tqdm.tqdm(dataloader)):
            # Transfer Data to GPU if available
            embeddings, points, masks = embeddings.to(device), points.to(device), masks.to(device)

            # Clear the gradients
            optimizer.zero_grad()
            
            # Forward Pass
            outputs = model(embeddings, points)

            # Compute Loss
            loss = criterion(outputs, masks)

            # Calculate Loss
            running_loss += loss.item() * embeddings.size(0)
            running_iou += jaccard(masks > 0, outputs > 0) * embeddings.size(0)
            
        epoch_loss = running_loss / num_samples
        epoch_iou = running_iou / num_samples

    return epoch_loss, epoch_iou

In [9]:
import wandb

# Hyperparameters recorded alongside the run.
wandb_config = dict(epochs=epochs, lr=lr, batch_size=batch_size)

# Start the Weights & Biases run that the training loop logs to.
wandb.init(config=wandb_config, project="Fine-Tune-SAM", entity="frankfundel")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfrankfundel[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Best validation loss seen so far; starts at infinity so the first epoch always saves.
min_val_loss = np.inf

for epoch in range(epochs):
    end = time.time()
    print(f"==================== Starting at epoch {epoch} ====================", flush=True)

    # One pass over the training split.
    train_loss, train_iou = train_epoch(
        model=model,
        epoch=epoch,
        criterion=criterion,
        optimizer=optimizer,
        dataloader=train_loader,
        device=device,
    )
    print('Training loss: {:.4f} IoU: {:.4f}'.format(train_loss, train_iou), flush=True)

    # One pass over the validation split.
    val_loss, val_iou = test_epoch(
        model=model,
        epoch=epoch,
        criterion=criterion,
        optimizer=optimizer,
        dataloader=val_loader,
        device=device,
    )
    print('Validation loss: {:.4f} IoU: {:.4f}'.format(val_loss, val_iou), flush=True)

    epoch_metrics = dict(
        train_loss=train_loss,
        train_iou=train_iou,
        val_loss=val_loss,
        val_iou=val_iou,
    )
    wandb.log(epoch_metrics)

    # Checkpoint only when the validation loss improves on the best so far.
    if not (val_loss < min_val_loss):
        continue

    print('val_loss decreased, saving model', flush=True)
    min_val_loss = val_loss

    # Saving State Dict
    torch.save(model.state_dict(), 'Fine-Tune-SAM.pth')



100%|██████████| 75/75 [04:28<00:00,  3.58s/it]

Training loss: 0.4849 IoU: 0.5015



100%|██████████| 16/16 [00:59<00:00,  3.69s/it]

Validation loss: 0.3536 IoU: 0.5636
val_loss decreased, saving model







100%|██████████| 75/75 [04:05<00:00,  3.27s/it]

Training loss: 0.3373 IoU: 0.5775



100%|██████████| 16/16 [00:52<00:00,  3.27s/it]

Validation loss: 0.3532 IoU: 0.5753
val_loss decreased, saving model







100%|██████████| 75/75 [03:55<00:00,  3.14s/it]

Training loss: 0.3152 IoU: 0.6143



100%|██████████| 16/16 [00:52<00:00,  3.28s/it]

Validation loss: 0.2868 IoU: 0.6375
val_loss decreased, saving model







100%|██████████| 75/75 [04:01<00:00,  3.22s/it]

Training loss: 0.3207 IoU: 0.6017



100%|██████████| 16/16 [00:52<00:00,  3.25s/it]

Validation loss: 0.2775 IoU: 0.6639
val_loss decreased, saving model







100%|██████████| 75/75 [03:59<00:00,  3.19s/it]

Training loss: 0.2961 IoU: 0.6117



100%|██████████| 16/16 [00:52<00:00,  3.25s/it]

Validation loss: 0.2587 IoU: 0.6372
val_loss decreased, saving model







100%|██████████| 75/75 [04:01<00:00,  3.22s/it]

Training loss: 0.3030 IoU: 0.6017



100%|██████████| 16/16 [00:52<00:00,  3.29s/it]

Validation loss: 0.2855 IoU: 0.6254



100%|██████████| 75/75 [03:56<00:00,  3.15s/it]

Training loss: 0.2915 IoU: 0.6212



100%|██████████| 16/16 [00:51<00:00,  3.21s/it]

Validation loss: 0.2812 IoU: 0.6092



100%|██████████| 75/75 [04:00<00:00,  3.20s/it]

Training loss: 0.2903 IoU: 0.6332



100%|██████████| 16/16 [00:51<00:00,  3.24s/it]

Validation loss: 0.2357 IoU: 0.6547
val_loss decreased, saving model







100%|██████████| 75/75 [03:54<00:00,  3.12s/it]

Training loss: 0.2706 IoU: 0.6578



100%|██████████| 16/16 [00:51<00:00,  3.23s/it]

Validation loss: 0.2600 IoU: 0.6456



100%|██████████| 75/75 [03:59<00:00,  3.20s/it]

Training loss: 0.2647 IoU: 0.6560



100%|██████████| 16/16 [00:52<00:00,  3.31s/it]

Validation loss: 0.2824 IoU: 0.6133





In [11]:
# Load after training
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load('Fine-Tune-SAM.pth', map_location=device))

<All keys matched successfully>

In [12]:
# Evaluate the reloaded checkpoint on the held-out test split; 0 is passed as
# the epoch argument.
test_loss, test_iou = test_epoch(
    model=model,
    epoch=0,
    criterion=criterion,
    optimizer=optimizer,
    dataloader=test_loader,
    device=device,
)
print('Test loss: {:.4f} IoU: {:.4f}'.format(test_loss, test_iou), flush=True)

100%|██████████| 16/16 [01:02<00:00,  3.91s/it]

Test loss: 0.2719 IoU: 0.6458





In [13]:
# Record the test metrics, then close the wandb run.
test_metrics = dict(test_loss=test_loss, test_iou=test_iou)
wandb.log(test_metrics)

wandb.finish()

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test_iou,▁
test_loss,▁
train_iou,▁▄▆▅▆▅▆▇██
train_loss,█▃▃▃▂▂▂▂▁▁
val_iou,▁▂▆█▆▅▄▇▇▄
val_loss,██▄▃▂▄▄▁▂▄

0,1
test_iou,0.64576
test_loss,0.27187
train_iou,0.65602
train_loss,0.26473
val_iou,0.61334
val_loss,0.28243
