In [1]:
# Fine-tune SAM on Liebherr Dataset
# https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb#scrollTo=XC35CzLxfdQU

In [2]:
import monai
import tqdm
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import DataLoader, random_split
from torch.nn.functional import threshold, normalize
from torchmetrics.classification import BinaryJaccardIndex

from segment_anything.utils.transforms import ResizeLongestSide

from datasets import Embedding_Dataset
from utils import SAMPreprocess, PILToNumpy, NumpyToTensor, SamplePoint, embedding_collate, is_valid_file
from models import SAM_Fine_Tune

2023-08-28 16:06:10.577610: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-28 16:06:11.097287: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SAM_Fine_Tune()
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = DataParallel(model)
model.to(device)

# make sure we only compute gradients for mask decoder
for name, param in model.sam_model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

jaccard = BinaryJaccardIndex().to(device)

In [4]:
sam_transform = ResizeLongestSide(model.img_size)
target_transform = Compose([
    sam_transform.apply_image_torch, # rescale
    SAMPreprocess(model.img_size, normalize=False), # padding
    SamplePoint(),
])
transform = Compose([
    PILToNumpy(),
    sam_transform.apply_image, # rescale
    NumpyToTensor(),
    SAMPreprocess(model.img_size) # padding
])

In [5]:
epochs = 10
batch_size = 8
lr = 1e-5

folder_path = '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/Liebherr/dataset'
dataset = Embedding_Dataset(root=folder_path, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file)

dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.15 * dataset_size)
test_size = dataset_size - train_size - val_size
generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size], generator)

train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=embedding_collate)
val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=embedding_collate)
test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=embedding_collate)

In [6]:
optimizer = torch.optim.Adam(model.sam_model.mask_decoder.parameters(), lr=lr, weight_decay=0)
criterion = monai.losses.DiceFocalLoss(sigmoid=True, squared_pred=True, lambda_focal=20.) # maybe include_background=False

In [7]:
def train_epoch(model, epoch, criterion, optimizer, dataloader, device):
    model.train()
    
    num_batches = len(dataloader)
    num_samples = len(dataloader.dataset)
    
    running_loss = 0.0
    running_iou = 0.0

    for batch, (images, masks, points, embeddings) in enumerate(tqdm.tqdm(dataloader)):
        # Transfer Data to GPU if available
        embeddings, points, masks = embeddings.to(device), points.to(device), masks.to(device)

        # Clear the gradients
        optimizer.zero_grad()

        # Forward Pass
        outputs = model(embeddings, points)

        # Compute Loss
        loss = criterion(outputs, masks)

        # Calculate gradients
        loss.backward()

        # Update Weights
        optimizer.step()

        # Calculate Loss
        running_loss += loss.item() * embeddings.size(0)
        running_iou += jaccard(masks > 0, outputs > 0) * embeddings.size(0)

    epoch_loss = running_loss / num_samples
    epoch_iou = running_iou / num_samples

    return epoch_loss, epoch_iou

In [8]:
def test_epoch(model, epoch, criterion, optimizer, dataloader, device):
    model.eval()
    
    num_batches = len(dataloader)
    num_samples = len(dataloader.dataset)
    
    with torch.no_grad():
        running_loss = 0.0
        running_iou = 0.0

        for batch, (images, masks, points, embeddings) in enumerate(tqdm.tqdm(dataloader)):
            # Transfer Data to GPU if available
            embeddings, points, masks = embeddings.to(device), points.to(device), masks.to(device)

            # Clear the gradients
            optimizer.zero_grad()
            
            # Forward Pass
            outputs = model(embeddings, points)

            # Compute Loss
            loss = criterion(outputs, masks)

            # Calculate Loss
            running_loss += loss.item() * embeddings.size(0)
            running_iou += jaccard(masks > 0, outputs > 0) * embeddings.size(0)
            
        epoch_loss = running_loss / num_samples
        epoch_iou = running_iou / num_samples

    return epoch_loss, epoch_iou

In [9]:
import wandb

wandb_config = {
    "epochs": epochs,
    "lr": lr,
    "batch_size": batch_size,
}

wandb.init(project="Fine-Tune-SAM", entity="frankfundel", config=wandb_config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfrankfundel[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666995108438035, max=1.0)…

In [10]:
min_val_loss = np.inf

for epoch in range(epochs):
    end = time.time()
    print(f"==================== Starting at epoch {epoch} ====================", flush=True)
    
    train_loss, train_iou = train_epoch(model, epoch, criterion, optimizer, train_loader, device)
    print('Training loss: {:.4f} IoU: {:.4f}'.format(train_loss, train_iou), flush=True)
    
    val_loss, val_iou = test_epoch(model, epoch, criterion, optimizer, val_loader, device)
    print('Validation loss: {:.4f} IoU: {:.4f}'.format(val_loss, val_iou), flush=True)
    
    wandb.log({
        "train_loss": train_loss,
        "train_iou": train_iou,
        "val_loss": val_loss,
        "val_iou": val_iou,
    })
    
    if min_val_loss > val_loss:
        print('val_loss decreased, saving model', flush=True)
        min_val_loss = val_loss
        
        # Saving State Dict
        torch.save(model.state_dict(), 'Fine-Tune-SAM.pt')



100%|██████████| 75/75 [04:27<00:00,  3.57s/it]

Training loss: 0.3793 IoU: 0.5782



100%|██████████| 16/16 [00:57<00:00,  3.61s/it]

Validation loss: 0.2979 IoU: 0.6450
val_loss decreased, saving model







100%|██████████| 75/75 [03:03<00:00,  2.45s/it]

Training loss: 0.2628 IoU: 0.6501



100%|██████████| 16/16 [00:37<00:00,  2.32s/it]

Validation loss: 0.2488 IoU: 0.6742
val_loss decreased, saving model







100%|██████████| 75/75 [02:52<00:00,  2.30s/it]

Training loss: 0.2254 IoU: 0.6885



100%|██████████| 16/16 [00:37<00:00,  2.33s/it]

Validation loss: 0.2457 IoU: 0.6746
val_loss decreased, saving model







100%|██████████| 75/75 [02:56<00:00,  2.35s/it]

Training loss: 0.2203 IoU: 0.6917



100%|██████████| 16/16 [00:37<00:00,  2.32s/it]

Validation loss: 0.2215 IoU: 0.6967
val_loss decreased, saving model







100%|██████████| 75/75 [02:50<00:00,  2.27s/it]

Training loss: 0.2226 IoU: 0.6848



100%|██████████| 16/16 [00:37<00:00,  2.34s/it]

Validation loss: 0.1905 IoU: 0.7288
val_loss decreased, saving model







100%|██████████| 75/75 [02:49<00:00,  2.26s/it]

Training loss: 0.2151 IoU: 0.6994



100%|██████████| 16/16 [00:37<00:00,  2.34s/it]

Validation loss: 0.2049 IoU: 0.6916



100%|██████████| 75/75 [02:49<00:00,  2.26s/it]

Training loss: 0.2041 IoU: 0.7059



100%|██████████| 16/16 [00:37<00:00,  2.31s/it]

Validation loss: 0.2134 IoU: 0.7053



100%|██████████| 75/75 [02:49<00:00,  2.26s/it]

Training loss: 0.1992 IoU: 0.7153



100%|██████████| 16/16 [00:37<00:00,  2.32s/it]

Validation loss: 0.2081 IoU: 0.6900



100%|██████████| 75/75 [02:49<00:00,  2.26s/it]

Training loss: 0.1913 IoU: 0.7266



100%|██████████| 16/16 [00:37<00:00,  2.31s/it]

Validation loss: 0.2491 IoU: 0.6697



100%|██████████| 75/75 [02:49<00:00,  2.26s/it]

Training loss: 0.1975 IoU: 0.7169



100%|██████████| 16/16 [00:37<00:00,  2.32s/it]

Validation loss: 0.1792 IoU: 0.7319
val_loss decreased, saving model





In [11]:
# Load after training
model.load_state_dict(torch.load('checkpoints/Fine-Tune-SAM.pth'))

<All keys matched successfully>

In [12]:
test_loss, test_iou = test_epoch(model, 0, criterion, optimizer, test_loader, device)
print('Test loss: {:.4f} IoU: {:.4f}'.format(test_loss, test_iou), flush=True)

100%|██████████| 16/16 [00:51<00:00,  3.25s/it]

Test loss: 0.1752 IoU: 0.7382





In [13]:
wandb.log({
    "test_loss": test_loss,
    "test_iou": test_iou
})

wandb.finish()

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test_iou,▁
test_loss,▁
train_iou,▁▄▆▆▆▇▇▇██
train_loss,█▄▂▂▂▂▁▁▁▁
val_iou,▁▃▃▅█▅▆▅▃█
val_loss,█▅▅▃▂▃▃▃▅▁

0,1
test_iou,0.7382
test_loss,0.17521
train_iou,0.7169
train_loss,0.19751
val_iou,0.73191
val_loss,0.17916
