<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Notebook: fine-tune SAM (segment anything) on a custom dataset

In this notebook, we'll reproduce the [MedSAM](https://github.com/bowang-lab/MedSAM) project, which fine-tunes [SAM](https://huggingface.co/docs/transformers/main/en/model_doc/sam) on a dataset of medical images. For demo purposes, we'll use a toy dataset, but this can easily be scaled up.

Resources used to create this notebook (thanks 🙏):
* [Encord blog post](https://encord.com/blog/learn-how-to-fine-tune-the-segment-anything-model-sam/)
* [MedSAM repository](https://github.com/bowang-lab/MedSAM)

## Set-up environment

We first install 🤗 Transformers and 🤗 Datasets.

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
!pip install -q datasets

We also install the [MONAI](https://github.com/Project-MONAI/MONAI) library, as we'll use a custom loss function from it.

In [None]:
!pip install -q monai

## Load dataset

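Here we download a resized (512x512) version of the PANDA training data, consisting of (image, ground truth mask) pairs, from Kaggle. The Kaggle CLI needs an API token (`kaggle.json`), which the first cell below asks you to upload.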

To load your own images and masks, refer to the bottom of my [SAM inference notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Run_inference_with_MedSAM_using_HuggingFace_Transformers.ipynb).

See also [this guide](https://huggingface.co/docs/datasets/image_dataset).

In [None]:
!pip install kaggle
# Upload your Kaggle API token (kaggle.json) when prompted (Colab only)
from google.colab import files
files.upload()



In [None]:
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/


In [None]:
!chmod 600 ~/.kaggle/kaggle.json


In [None]:
!kaggle datasets download -d xhlulu/panda-resized-train-data-512x512


In [None]:
import zipfile

# Replace 'panda-resized-train-data-512x512.zip' with the name of the zip file you downloaded
with zipfile.ZipFile('panda-resized-train-data-512x512.zip', 'r') as zip_ref:
    zip_ref.extractall('./panda_resized_train_data')


In [None]:
import matplotlib.pyplot as plt
from PIL import Image
# Open the image
img = Image.open('/content/panda_resized_train_data/train_images/train_images/0005f7aaab2800f6170c399693a96917.png')

# Display the image using Matplotlib
plt.imshow(img)
plt.axis('off')  # Turn off axis numbers and ticks
plt.show()


In [None]:
import os
from PIL import Image
import numpy as np
import gc  # for clearing memory
from tqdm import tqdm

# Define paths to image and mask folders
image_folder = "/content/panda_resized_train_data/train_images/train_images"
mask_folder = "/content/panda_resized_train_data/train_label_masks/train_label_masks"

# Output directories to save patchified images
output_img_dir = "/content/patchified_data/images"
output_mask_dir = "/content/patchified_data/masks"
os.makedirs(output_img_dir, exist_ok=True)
os.makedirs(output_mask_dir, exist_ok=True)

# Define patch size and step
patch_size = 256
step = 256

def patchify_and_save(image, mask, image_name):
    """
    Create non-overlapping patches from image and mask,
    and save them individually to disk.
    """
    # Pad image and mask to be divisible by patch_size
    def pad(arr):
        h, w = arr.shape[:2]
        pad_h = (patch_size - (h % patch_size)) % patch_size
        pad_w = (patch_size - (w % patch_size)) % patch_size
        if arr.ndim == 3:
            return np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
        else:
            return np.pad(arr, ((0, pad_h), (0, pad_w)), mode='constant')

    image_padded = pad(image)
    mask_padded = pad(mask)

    # Extract and save patches
    h, w = image_padded.shape[:2]
    patch_index = 0
    for y in range(0, h, step):
        for x in range(0, w, step):
            patch_img = image_padded[y:y+patch_size, x:x+patch_size]
            patch_mask = mask_padded[y:y+patch_size, x:x+patch_size]

            img_patch_path = os.path.join(output_img_dir, f"{image_name}_{patch_index}.png")
            mask_patch_path = os.path.join(output_mask_dir, f"{image_name}_{patch_index}.png")

            Image.fromarray(patch_img.astype(np.uint8)).save(img_patch_path)
            Image.fromarray(patch_mask.astype(np.uint8)).save(mask_patch_path)
            patch_index += 1

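# Get file names
image_files = sorted(os.listdir(image_folder))
mask_files = sorted(os.listdir(mask_folder))

# Pair images and masks by file stem rather than by position: if any image has no
# mask, zip() over two sorted lists would silently misalign every later pair.
# Assumption: a mask's name is the image stem, optionally followed by "_mask".
mask_by_stem = {os.path.splitext(f)[0].replace("_mask", ""): f for f in mask_files}
pairs = [(f, mask_by_stem[os.path.splitext(f)[0]]) for f in image_files
         if os.path.splitext(f)[0] in mask_by_stem]
print(f"{len(pairs)} of {len(image_files)} images have a matching mask")

# Loop over pairs and process one at a time
for img_file, mask_file in tqdm(pairs):
    img_path = os.path.join(image_folder, img_file)
    mask_path = os.path.join(mask_folder, mask_file)

    image = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert("L"))  # grayscale mask

    image_name = os.path.splitext(img_file)[0]
    patchify_and_save(image, mask, image_name)

    # Free memory
    del image, mask
    gc.collect()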

print("✅ Patchification completed. Files saved in:")
print(f"  Images -> {output_img_dir}")
print(f"  Masks  -> {output_mask_dir}")
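The padding rule inside `patchify_and_save` can be checked on its own. The sketch below reproduces the same arithmetic with helper names of our own (`pad_amount` and `count_patches` are hypothetical, not part of the notebook) and shows how many patches a given image size yields when `step == patch_size`. Assuming the Kaggle images really are 512x512, as the dataset name suggests, each one should produce 4 patches.

```python
import numpy as np

def pad_amount(size: int, patch_size: int) -> int:
    """Padding needed so that `size` becomes a multiple of `patch_size`."""
    return (patch_size - (size % patch_size)) % patch_size

def count_patches(height: int, width: int, patch_size: int = 256) -> int:
    """Number of non-overlapping patches produced when step == patch_size."""
    padded_h = height + pad_amount(height, patch_size)
    padded_w = width + pad_amount(width, patch_size)
    return (padded_h // patch_size) * (padded_w // patch_size)

# A 512x512 image splits cleanly into a 2x2 grid of 256x256 patches
print(count_patches(512, 512))  # 4
# A 500x700 image is zero-padded to 512x768, giving a 2x3 grid
print(count_patches(500, 700))  # 6

# The same padding applied with np.pad, as in patchify_and_save
arr = np.ones((500, 700), dtype=np.uint8)
padded = np.pad(arr, ((0, pad_amount(500, 256)), (0, pad_amount(700, 256))), mode="constant")
print(padded.shape)  # (512, 768)
```

Because padding is zero-filled, patches at the right and bottom edges of non-divisible images contain black borders with an empty mask region.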



## Create PyTorch dataset

Below we define a regular PyTorch dataset, which gives us examples of the data prepared in the format for the model. Each example consists of:

* pixel values (which is the image prepared for the model)
* a prompt in the form of a bounding box
* a ground truth segmentation mask.

The function below defines how to get a bounding box prompt based on the ground truth segmentation. This was taken from [here](https://github.com/bowang-lab/MedSAM/blob/66cf4799a9ab9a8e08428a5087e73fc21b2b61cd/train.py#L29).

Note that SAM is always trained using certain "prompts", which could be bounding boxes, points, text, or rudimentary masks. The model is then trained to output the appropriate mask given the image and prompt.

In [None]:
def get_bounding_box(ground_truth_map):
    # Get nonzero pixel indices
    y_indices, x_indices = np.where(ground_truth_map > 0)

    # Handle empty mask safely
    if len(x_indices) == 0 or len(y_indices) == 0:
        H, W = ground_truth_map.shape
        return [0, 0, W, H]

    # Normal bounding box calculation
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # Add small random perturbation
    H, W = ground_truth_map.shape
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))

    return [x_min, y_min, x_max, y_max]
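To see what `get_bounding_box` returns, here is a small worked example on a toy mask. It uses a hypothetical variant, `bounding_box_with_jitter`, whose jitter size and random generator are parameters so the result can be made deterministic; the logic otherwise follows the function above (boxes are `[x_min, y_min, x_max, y_max]`, and an empty mask falls back to the whole patch). The random enlargement presumably makes the model robust to imprecise boxes, as in MedSAM.

```python
import numpy as np

def bounding_box_with_jitter(mask, max_jitter=20, rng=None):
    """Hypothetical variant of get_bounding_box with configurable jitter and RNG."""
    rng = rng if rng is not None else np.random.default_rng()
    ys, xs = np.nonzero(mask)
    h, w = mask.shape
    if xs.size == 0:
        return [0, 0, w, h]  # empty mask: fall back to the whole patch
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())
    if max_jitter > 0:
        # Enlarge each side by a random amount in [0, max_jitter), clipped to the image
        x_min = max(0, x_min - int(rng.integers(0, max_jitter)))
        x_max = min(w, x_max + int(rng.integers(0, max_jitter)))
        y_min = max(0, y_min - int(rng.integers(0, max_jitter)))
        y_max = min(h, y_max + int(rng.integers(0, max_jitter)))
    return [x_min, y_min, x_max, y_max]

demo_mask = np.zeros((10, 10), dtype=np.uint8)
demo_mask[2:5, 3:7] = 1  # foreground rows 2..4, columns 3..6

tight = bounding_box_with_jitter(demo_mask, max_jitter=0)
print(tight)  # [3, 2, 6, 4]

jittered = bounding_box_with_jitter(demo_mask, max_jitter=20, rng=np.random.default_rng(0))
print(jittered)  # always contains the tight box and stays inside the 10x10 patch
```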


In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

# Define your dataset class
class CustomSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))

        # Optional: apply a transform called as transform(image=..., mask=...) that returns
        # a dict with "image" and "mask" keys (the Albumentations convention)
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        return {"image": image, "mask": mask}

# Initialize the dataset on the patchified images and masks
train_dataset = CustomSegmentationDataset(
    image_dir="/content/patchified_data/images",
    mask_dir="/content/patchified_data/masks"
)

# Inspect the first example
sample = train_dataset[0]
for k, v in sample.items():
    print(k, type(v), getattr(v, "shape", None))
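The `transform` argument of `CustomSegmentationDataset` is called as `transform(image=image, mask=mask)` and must return a dict with `"image"` and `"mask"` keys. The sketch below is a hypothetical minimal transform with that interface (`RandomHorizontalFlip` is our own name, not a library class); it flips image and mask together so they stay aligned.

```python
import numpy as np

class RandomHorizontalFlip:
    """Hypothetical transform matching the transform(image=..., mask=...) call above."""

    def __init__(self, p=0.5, seed=None):
        self.p = p
        self.rng = np.random.default_rng(seed)

    def __call__(self, image, mask):
        if self.rng.random() < self.p:
            # Flip along the width axis; copy() makes the arrays contiguous again
            image = image[:, ::-1].copy()
            mask = mask[:, ::-1].copy()
        return {"image": image, "mask": mask}

flip = RandomHorizontalFlip(p=1.0)  # always flip, for demonstration
toy_image = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)  # H x W x C
toy_mask = np.array([[0, 1], [0, 1]], dtype=np.uint8)
out = flip(image=toy_image, mask=toy_mask)
print(out["mask"])  # [[1 0]
                    #  [1 0]]
```

It could be passed as `CustomSegmentationDataset(..., transform=RandomHorizontalFlip())`. Note that `SAMDataset`, defined further down and used for training, does not take a transform.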



In [None]:
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
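To our understanding, `SamProcessor` resizes each image so that its longest side is 1024 pixels and rescales `input_boxes` by the same factor. This is why the boxes from `get_bounding_box`, computed in 256x256 patch coordinates, can be passed to the processor as-is. The sketch below mirrors that arithmetic with hypothetical helper names; it is not the library's code.

```python
def preprocess_shape(old_h, old_w, longest_edge=1024):
    """Target (h, w) after resizing so the longest side equals longest_edge."""
    scale = longest_edge / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)

def rescale_box(box, original_hw, longest_edge=1024):
    """Scale an [x_min, y_min, x_max, y_max] box from original to resized coordinates."""
    old_h, old_w = original_hw
    new_h, new_w = preprocess_shape(old_h, old_w, longest_edge)
    x_min, y_min, x_max, y_max = box
    return [x_min * new_w / old_w, y_min * new_h / old_h,
            x_max * new_w / old_w, y_max * new_h / old_h]

print(preprocess_shape(256, 256))                 # (1024, 1024)
print(rescale_box([10, 20, 30, 40], (256, 256)))  # [40.0, 80.0, 120.0, 160.0]
print(preprocess_shape(300, 400))                 # (768, 1024)
```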

In [None]:
import glob
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class SAMDataset(Dataset):
  def __init__(self, dataset, processor):
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    image_path, mask_path = self.dataset[idx]

    image = Image.open(image_path).convert("RGB")
    ground_truth_mask = np.array(Image.open(mask_path).convert("L"))

    # get bounding box prompt from mask
    prompt = get_bounding_box(ground_truth_mask)

    # prepare image and prompt for the model
    inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

    # remove batch dimension created by the processor
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}

    # add ground truth segmentation, binarized (any non-zero label counts as foreground)
    # so that it is a valid target for the sigmoid-based Dice + cross-entropy loss
    inputs["ground_truth_mask"] = torch.tensor((ground_truth_mask > 0).astype(np.int64))

    return inputs

# 1️⃣ Get file paths for all image-mask pairs
image_paths = sorted(glob.glob("/content/patchified_data/images/*.png"))
mask_paths = sorted(glob.glob("/content/patchified_data/masks/*.png"))

# 2️⃣ Zip them together into a dataset list
dataset = list(zip(image_paths, mask_paths))

# 3️⃣ (Optional) Quick sanity check
print("Total patches found:", len(dataset))
print("Example pair:\n", dataset[0])

# 4️⃣ Now create your SAMDataset object
train_dataset = SAMDataset(dataset=dataset, processor=processor)
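Many patches are likely to be pure background, in which case `get_bounding_box` falls back to a box covering the whole patch. If you prefer to train only on patches that contain foreground, a filter like the hypothetical `keep_pairs_with_foreground` below could be applied to `dataset` before building `SAMDataset`. Whether to drop empty patches or keep some as negatives is a design choice, not something the original notebook prescribes.

```python
import numpy as np

def keep_pairs_with_foreground(pairs, load_mask, min_pixels=1):
    """Keep (image_path, mask_path) pairs whose mask has >= min_pixels non-zero pixels.

    Hypothetical helper; `load_mask` is any callable mapping a mask path to a 2-D array.
    """
    kept = []
    for image_path, mask_path in pairs:
        if np.count_nonzero(load_mask(mask_path)) >= min_pixels:
            kept.append((image_path, mask_path))
    return kept

# Toy check with in-memory masks instead of files
fake_masks = {
    "a_mask.png": np.zeros((4, 4), dtype=np.uint8),  # empty mask
    "b_mask.png": np.eye(4, dtype=np.uint8),         # 4 foreground pixels
}
toy_pairs = [("a.png", "a_mask.png"), ("b.png", "b_mask.png")]
kept_pairs = keep_pairs_with_foreground(toy_pairs, fake_masks.__getitem__)
print(kept_pairs)  # [('b.png', 'b_mask.png')]
```

In the notebook, the loader could be `lambda p: np.array(Image.open(p).convert("L"))`.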

In [None]:
sample = train_dataset[0]
for k, v in sample.items():
    print(k, type(v), getattr(v, "shape", None))



## Create PyTorch DataLoader

Next we define a PyTorch DataLoader, which lets us draw shuffled batches from the dataset.



In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

In [None]:
batch["ground_truth_mask"].shape
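The batch shapes printed above come from PyTorch's default collation, which, to our understanding, stacks each key of the dict samples along a new leading batch dimension. The NumPy sketch below mimics that behaviour (`collate_dicts` is a hypothetical helper, not PyTorch's implementation), assuming per-sample shapes like those `SAMDataset` returns for 256x256 patches. Stacking only works because every sample has the same shape, which the fixed patch size guarantees.

```python
import numpy as np

def collate_dicts(samples):
    """Stack each key's arrays along a new leading batch dimension."""
    return {key: np.stack([s[key] for s in samples]) for key in samples[0]}

# Assumed per-sample shapes (after SAMDataset removes the processor's batch dimension)
samples = [
    {"pixel_values": np.zeros((3, 1024, 1024), dtype=np.float32),
     "input_boxes": np.zeros((1, 4), dtype=np.float64),
     "ground_truth_mask": np.zeros((256, 256), dtype=np.int64)}
    for _ in range(2)
]
batch_np = collate_dicts(samples)
for k, v in batch_np.items():
    print(k, v.shape)
# pixel_values (2, 3, 1024, 1024)
# input_boxes (2, 1, 4)
# ground_truth_mask (2, 256, 256)
```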

## Load the model

In [None]:
from transformers import SamModel

model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder
for name, param in model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)

## Train the model

In [None]:
from torch.optim import Adam
import monai

# Note: Hyperparameter tuning could improve performance here
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
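Because `sigmoid=True`, the loss treats the mask decoder's outputs as logits and combines a Dice term with a cross-entropy term. The NumPy sketch below approximates what such a combination computes for a single binary mask. It is a simplification with hypothetical function names (no batching, channel handling, or MONAI's exact smoothing and reduction), not MONAI's implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dice_bce_loss(logits, target, smooth=1e-5):
    """Approximate sigmoid Dice + binary cross-entropy for one binary mask."""
    probs = sigmoid(logits)
    intersection = np.sum(probs * target)
    # squared_pred=True: squared predictions and targets in the Dice denominator
    denominator = np.sum(probs ** 2) + np.sum(target ** 2)
    dice = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
    # Numerically stable binary cross-entropy computed directly on logits
    bce = np.mean(np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits))))
    return dice + bce

target = np.array([[0.0, 1.0], [1.0, 0.0]])
good_logits = np.where(target > 0, 10.0, -10.0)  # confident and correct
bad_logits = -good_logits                        # confident and wrong
print(dice_bce_loss(good_logits, target))  # close to 0
print(dice_bce_loss(bad_logits, target))   # large (roughly 11)
```

This is also why the targets must be binary (0/1): the cross-entropy term assumes each pixel is either foreground or background.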

In [None]:
# Robust training loop with safe checkpoints, logging, and interruption handling.
# This cell replaces the original training loop. It will:
# - create OUTPUT_DIR ("/content/outputs")
# - save training_config.json and a minimal model_config.yaml for reproducibility
# - log training progress to training_log.txt
# - save periodic backup checkpoints and per-epoch checkpoints
# - save an "interrupted" model if manually stopped with Ctrl+C (KeyboardInterrupt)
# - save a loss curve (loss_curve.png) at the end or on interruption
# - import everything it needs, so it runs standalone once model, optimizer,
#   seg_loss and train_dataloader exist

import os, json, datetime, traceback
from statistics import mean
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

# ------------------------------
# OUTPUT / CONFIG SETUP
# ------------------------------
OUTPUT_DIR = "/content/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Basic training config (values already defined in earlier cells take precedence)
training_config = {
    "num_epochs": globals().get("num_epochs", 50),
    "save_every_batches": globals().get("save_every_batches", 200),
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "batch_size": getattr(train_dataloader, "batch_size", None),
    "timestamp": datetime.datetime.now().isoformat()
}

# Save training_config.json
with open(os.path.join(OUTPUT_DIR, "training_config.json"), "w") as f:
    json.dump(training_config, f, indent=2)

# Minimal model config: Hugging Face configs are objects, so convert to a dict first
model_config = getattr(model, "config", None)
if model_config is not None and hasattr(model_config, "to_dict"):
    model_config_to_save = model_config.to_dict()
else:
    model_config_to_save = {"model_type": type(model).__name__}

with open(os.path.join(OUTPUT_DIR, "model_config.yaml"), "w") as f:
    # simple YAML-like dump without requiring pyyaml (nested dicts are written inline)
    f.write("# Minimal model config (auto-generated)\n")
    for k, v in model_config_to_save.items():
        f.write(f"{k}: {v}\n")

# Setup logging file
log_path = os.path.join(OUTPUT_DIR, "training_log.txt")
log_f = open(log_path, "a", buffering=1)  # line-buffered

def log(msg):
    ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    line = f"[{ts}] {msg}"
    print(line)
    try:
        log_f.write(line + "\n")
    except Exception:
        pass  # never let logging crash training

log("STARTING TRAINING (robust loop)")
log(f"Config: {training_config}")

# ------------------------------
# Hyperparams (use existing variables if present)
# ------------------------------
num_epochs = globals().get("num_epochs", training_config["num_epochs"])
save_every_batches = globals().get("save_every_batches", training_config["save_every_batches"])
device = training_config["device"]
model.to(device)

# For plotting loss curve
all_epoch_means = []

# Paths of every saved checkpoint (listed in results_summary.md at the end)
saved_checkpoints = []

# ------------------------------
# TRAINING
# ------------------------------
try:
    model.train()
    for epoch in range(num_epochs):
        epoch_losses = []
        log(f"Starting Epoch {epoch + 1}/{num_epochs}")

        for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")):

            # ---------- Prepare inputs ----------
            pixel_values = batch.get("pixel_values", None)
            if pixel_values is None:
                raise ValueError("Batch does not contain 'pixel_values' key. Please check dataloader.")

            if pixel_values.dim() == 5:
                pixel_values = pixel_values.squeeze(1)

            pixel_values = pixel_values.to(device)
            input_boxes = batch.get("input_boxes", None)
            if input_boxes is not None:
                input_boxes = input_boxes.to(device)

            ground_truth_masks = batch.get("ground_truth_mask", None)
            if ground_truth_masks is None:
                raise ValueError("Batch does not contain 'ground_truth_mask' key.")

            ground_truth_masks = ground_truth_masks.float().to(device)
            if ground_truth_masks.dim() == 3:
                ground_truth_masks = ground_truth_masks.unsqueeze(1)

            # ---------- Forward pass ----------
            outputs = model(
                pixel_values=pixel_values,
                input_boxes=input_boxes,
                multimask_output=False
            )

            predicted_masks = outputs.pred_masks

            if predicted_masks.dim() == 5:
                predicted_masks = predicted_masks.squeeze(1)
            elif predicted_masks.dim() == 3:
                predicted_masks = predicted_masks.unsqueeze(1)

            if predicted_masks.shape[-2:] != ground_truth_masks.shape[-2:]:
                predicted_masks = F.interpolate(
                    predicted_masks,
                    size=ground_truth_masks.shape[-2:],
                    mode='bilinear',
                    align_corners=False
                )

            # ---------- Loss & Backprop ----------
            loss = seg_loss(predicted_masks, ground_truth_masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

            # ---------- Periodic backup ----------
            if (batch_idx + 1) % save_every_batches == 0:
                ckpt_path = os.path.join(OUTPUT_DIR, f"backup_epoch{epoch+1}_batch{batch_idx+1}.pt")
                torch.save(model.state_dict(), ckpt_path)
                saved_checkpoints.append(ckpt_path)
                log(f"Saved periodic backup: {ckpt_path} -- loss: {loss.item():.6f}")

            # Optional: log progress every 100 batches to console/file
            if (batch_idx + 1) % 100 == 0:
                log(f"Batch {batch_idx+1} | Batch Loss: {loss.item():.6f}")

        # ---------- End of epoch ----------
        epoch_ckpt = os.path.join(OUTPUT_DIR, f"model_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), epoch_ckpt)
        saved_checkpoints.append(epoch_ckpt)
        log(f"Saved epoch checkpoint → {epoch_ckpt}")

        mean_loss = mean(epoch_losses) if len(epoch_losses) else float('nan')
        all_epoch_means.append(mean_loss)
        log(f"EPOCH {epoch+1} COMPLETED | Mean Loss: {mean_loss:.6f}")

    # Save final model
    final_path = os.path.join(OUTPUT_DIR, "model_final.pt")
    torch.save(model.state_dict(), final_path)
    saved_checkpoints.append(final_path)
    log(f"Training finished. Saved final model → {final_path}")

except KeyboardInterrupt:
    # Manual stop by user (Ctrl+C). Save an interrupted checkpoint.
    interrupted_path = os.path.join(OUTPUT_DIR, "interrupted_model.pt")
    try:
        torch.save(model.state_dict(), interrupted_path)
        saved_checkpoints.append(interrupted_path)
        log("Manual KeyboardInterrupt received. Saved interrupted model.")
        log(f"Saved at: {interrupted_path}")
    except Exception as ex:
        log(f"Failed saving interrupted checkpoint: {ex}")
    # Re-raise so the notebook cell shows the interruption (remove to suppress it)
    raise

except Exception as e:
    # Unexpected crash: save emergency checkpoint and the traceback
    emergency_path = os.path.join(OUTPUT_DIR, "crashed_model.pt")
    try:
        torch.save(model.state_dict(), emergency_path)
        saved_checkpoints.append(emergency_path)
        log("ERROR occurred during training. Saved emergency checkpoint.")
        log(f"Saved at: {emergency_path}")
    except Exception as ex:
        log(f"Failed saving emergency checkpoint: {ex}")
    log("Traceback:")
    log(traceback.format_exc())
    raise

finally:
    # Save loss curve if we have epoch means
    try:
        if len(all_epoch_means) > 0:
            plt.figure()
            plt.plot(range(1, len(all_epoch_means) + 1), all_epoch_means)
            plt.xlabel('Epoch')
            plt.ylabel('Mean Loss')
            plt.title('Training Loss Curve (mean loss per epoch)')
            loss_png = os.path.join(OUTPUT_DIR, 'loss_curve.png')
            plt.savefig(loss_png, bbox_inches='tight')
            log(f"Saved loss curve → {loss_png}")
    except Exception as ex:
        print("Failed saving loss curve:", ex)

    # Write a small results_summary.md
    try:
        summary_path = os.path.join(OUTPUT_DIR, "results_summary.md")
        with open(summary_path, "w") as sf:
            sf.write("# Training Results Summary\n\n")
            sf.write(f"**Timestamp:** {datetime.datetime.now().isoformat()}\n\n")
            sf.write("**Saved checkpoints:**\n\n")
            for ck in saved_checkpoints:
                sf.write(f"- {ck}\n")
            sf.write("\n**Notes:** Automatic summary created by notebook.\n")
    except Exception as ex:
        print("Failed writing summary:", ex)

    log("CLEANUP DONE. Check the outputs folder for saved artifacts.")

    # Close the log file last, so the messages above are also written to it
    try:
        log_f.close()
    except Exception:
        pass

## Inference

Important note: as we used the Dice loss with `sigmoid=True`, the model's `pred_masks` are logits, so we need to apply a sigmoid activation ourselves to obtain probabilities. Hence we won't use the processor's `post_process_masks` method here.

In [None]:
import numpy as np
from PIL import Image

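# let's take a training example
idx = 10

# load image (dataset is the list of (image_path, mask_path) pairs built above)
image_path, mask_path = dataset[idx]
image = Image.open(image_path).convert("RGB")
image

In [None]:
# get box prompt based on ground truth segmentation map
# (binarized so that the overlay below is a clean 0/1 mask)
ground_truth_mask = (np.array(Image.open(mask_path).convert("L")) > 0).astype(np.uint8)
prompt = get_bounding_box(ground_truth_mask)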

# prepare image + box prompt for the model
inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)
for k,v in inputs.items():
  print(k,v.shape)

In [None]:
model.eval()

# forward pass
with torch.no_grad():
  outputs = model(**inputs, multimask_output=False)

In [None]:
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
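Since the sigmoid is monotonic and sigmoid(0) = 0.5, thresholding probabilities at 0.5 is equivalent to thresholding the raw logits at 0. The small check below illustrates this on a handful of made-up logit values.

```python
import numpy as np

demo_logits = np.array([-3.0, -0.1, 0.0, 0.2, 4.0])
probs = 1.0 / (1.0 + np.exp(-demo_logits))
mask_from_probs = (probs > 0.5).astype(np.uint8)
mask_from_logits = (demo_logits > 0).astype(np.uint8)
print(mask_from_probs)                                    # [0 0 0 1 1]
print(np.array_equal(mask_from_probs, mask_from_logits))  # True
```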

In [None]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

fig, axes = plt.subplots()

axes.imshow(np.array(image))
show_mask(medsam_seg, axes)
axes.title.set_text(f"Predicted mask")
axes.axis("off")

Compare this to the ground truth segmentation:

In [None]:
fig, axes = plt.subplots()

axes.imshow(np.array(image))
show_mask(ground_truth_mask, axes)
axes.title.set_text(f"Ground truth mask")
axes.axis("off")
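Beyond the visual comparison, the agreement between the two masks can be quantified. The sketch below computes the Dice coefficient and IoU for two binary masks of the same shape; `dice_and_iou` is a hypothetical helper, and returning (1.0, 1.0) when both masks are empty is a convention we chose.

```python
import numpy as np

def dice_and_iou(pred, target):
    """Dice coefficient and IoU for two same-shaped masks (non-zero = foreground)."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0, 1.0  # both masks empty: treat as perfect agreement
    dice = 2.0 * intersection / total
    iou = intersection / union
    return float(dice), float(iou)

pred_demo = np.array([[1, 1, 0], [0, 0, 0]])
gt_demo = np.array([[1, 0, 0], [1, 0, 0]])
print(dice_and_iou(pred_demo, gt_demo))  # (0.5, 0.3333333333333333)
```

In the notebook this could be called as `dice_and_iou(medsam_seg, ground_truth_mask)`; both are 256x256 here, since the patches and SAM's `pred_masks` share that size.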

## Legacy

The code below was used during the creation of this notebook, but is no longer used.

In [None]:
import torch.nn.functional as F
from typing import Tuple
from torch.nn import MSELoss

loss_fn = MSELoss()

def postprocess_masks(masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...], image_size=1024) -> torch.Tensor:
    """
    Remove padding and upscale masks to the original image size.

    Args:
      masks (torch.Tensor):
        Batched masks from the mask_decoder, in BxCxHxW format.
      input_size (tuple(int, int)):
        The size of the image input to the model, in (H, W) format. Used to remove padding.
      original_size (tuple(int, int)):
        The original size of the image before resizing for input to the model, in (H, W) format.

    Returns:
      (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
        is given by original_size.
    """
    masks = F.interpolate(
        masks,
        (image_size, image_size),
        mode="bilinear",
        align_corners=False,
    )
    masks = masks[..., : input_size[0], : input_size[1]]
    masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
    return masks

In [None]:
# upscaled_masks = postprocess_masks(low_res_masks.squeeze(1), batch["reshaped_input_sizes"][0].tolist(), batch["original_sizes"][0].tolist()).to(device)
# predicted_masks = normalize(threshold(upscaled_masks, 0.0, 0)).squeeze(1)
# loss = loss_fn(predicted_masks, ground_truth_masks)