In [1]:
import torch
import torch.nn as nn
from transformers import (
    SamVisionConfig,
    SamPromptEncoderConfig,
    SamMaskDecoderConfig,
    SamModel,
    SamProcessor,
    SamImageProcessor
)
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from PIL import Image
import json
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import SamModel, SamConfig

In [2]:
def normalize_per_image(img, offset=1.0):
    """
    Scale each channel of an image tensor by that channel's maximum value plus ``offset``.

    Despite the function name, the maximum is taken over the last two (spatial)
    dimensions only, so every channel of every image is scaled independently.

    Args:
        img (torch.Tensor): Tensor whose last two dimensions are (height, width), e.g.
            (channels, height, width) for a single dataset sample, or
            (batch_size, channels, height, width) for a batch.
        offset (float): Value added to the per-channel maximum before dividing. The
            default of 1.0 reproduces the original behaviour and, for non-negative
            input, avoids division by zero on all-zero channels.

    Returns:
        torch.Tensor: Floating-point tensor with the same shape as ``img``. For
        non-negative input and a positive ``offset`` the values lie in [0, 1).
    """
    # Per-channel maximum over the spatial dims; keepdim so it broadcasts back over H and W.
    channel_max = img.amax(dim=(-2, -1), keepdim=True)
    return img / (channel_max + offset)


In [3]:
class HyperspectralExpandedDataset(Dataset):
    """
    Dataset over a directory of per-sample sub-directories, each containing
    ``bands.pt``, ``binary_mask.tif`` and ``prompt.json``.
    """

    def __init__(self, root_dir):
        """
        Args:
            root_dir (str): Root directory containing the expanded dataset.
        """
        self.root_dir = root_dir
        self.samples = self._load_samples()

    def _load_samples(self):
        """
        Scan ``root_dir`` for complete sample directories.

        Sub-directories are visited in sorted order so that sample indices (and thus any
        index-based split such as ``random_split``) do not depend on the filesystem's
        arbitrary ``os.listdir`` order. Non-directories and directories missing any of
        the three expected files are skipped.

        Returns:
            list[dict]: One dict per sample with keys "bands", "mask" and "prompt"
            mapping to the corresponding file paths.
        """
        samples = []
        for sample_name in sorted(os.listdir(self.root_dir)):
            sample_path = os.path.join(self.root_dir, sample_name)
            if not os.path.isdir(sample_path):
                continue

            # Collect file paths for bands, binary mask, and prompt
            paths = {
                "bands": os.path.join(sample_path, "bands.pt"),
                "mask": os.path.join(sample_path, "binary_mask.tif"),
                "prompt": os.path.join(sample_path, "prompt.json"),
            }
            if all(os.path.exists(p) for p in paths.values()):
                samples.append(paths)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one sample from disk.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (prompt, bands, binary_mask) where ``prompt`` is the parsed JSON
            object, ``bands`` is the loaded tensor passed through
            ``normalize_per_image``, and ``binary_mask`` is the mask converted with
            ``to_tensor`` with its leading channel dimension squeezed away.
        """
        sample = self.samples[idx]

        # Suppress FutureWarnings that mention torch.load while loading the bands.
        # NOTE(review): torch.load unpickles arbitrary objects - only use on trusted data
        # (consider weights_only=True if bands.pt holds plain tensors).
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")
            bands = torch.load(sample["bands"])

        # Context manager releases the file handle once the mask has been decoded.
        with Image.open(sample["mask"]) as mask_image:
            binary_mask = to_tensor(mask_image).squeeze(0)  # Remove channel dimension

        with open(sample["prompt"], "r") as f:
            prompt = json.load(f)

        return prompt, normalize_per_image(bands), binary_mask


In [4]:
root_dir = "./expanded_dataset_output"
dataset = HyperspectralExpandedDataset(root_dir=root_dir)
assert len(dataset) > 0, f"No complete samples found under {root_dir}"

# Seeded generator so the random train/eval permutation is reproducible across kernel restarts.
SPLIT_SEED = 42

# Split dataset: 90% training, 10% evaluation
train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

In [5]:
print(len(dataset))
# Take the first training sample, if any, and show its prompt and tensor shapes.
first_sample = next(iter(train_dataset), None)
if first_sample is not None:
    prompt, img, mask = first_sample
    print(prompt)
    print(img.shape)
    print(mask.shape)

10505
{'centroid': [22.37883169462117, 17.358010410641988], 'random_point': [26, 1]}
torch.Size([12, 120, 120])
torch.Size([120, 120])


In [6]:
# Pull a single collated batch, if any, to inspect the batched prompt and tensor shapes.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    prompt, img, mask = first_batch
    print(prompt)
    print(img.shape)
    print(mask.shape)

{'centroid': [tensor([105.7406,  59.5000,  93.1959,   9.7339,  87.6094,  15.9882,  74.9401,
         94.2621, 111.3608, 109.7624, 108.0814,   1.4706,  39.8946,  85.4658,
          5.8182,  31.7907,  60.0295,  85.5452,  30.9752,  18.6792,   8.1914,
          1.3696,  30.4769, 117.7143,  58.6408,  61.7861,   8.1685,  24.5406,
         66.8777,  59.1497,  58.8262, 114.2101,   4.1600,  93.2593,  17.6250,
         52.3945,  70.6667,  17.6201,   0.3636,   8.0444,  12.7674,  87.1157,
          7.6866, 109.4115,  50.8318, 109.2927,   7.3060,  42.4276,  17.9211,
         60.3002, 111.3866, 109.4826,  85.4127,  17.6509,  47.6922, 104.7000,
         24.4521, 103.7389,  55.2075, 109.7931,   5.8421, 111.9107,  10.4127,
         61.3568,  59.6406,   8.4538,  89.6814, 117.7260,  59.9958,   5.7348,
         48.0000, 113.8051, 113.4851,  45.9837,   5.4828,   9.2998,  71.6631,
          2.0000,  42.1155, 108.7899, 117.4000,  75.6263, 116.7391,  65.1853,
        111.6468,  24.0650,  33.7565,  34.6345,  9

In [7]:
class HyperspectralSAM(nn.Module):
    """Thin wrapper around a Hugging Face ``SamModel`` configured for multi-band input."""

    def __init__(self, sam_checkpoint="facebook/sam-vit-base", num_input_channels=12, image_size=120):
        """
        Build a SAM model for hyperspectral input and load pretrained weights into it.

        The vision encoder is configured for ``num_input_channels`` channels at
        ``image_size`` resolution, the prompt encoder for the same image size, and the
        mask decoder with a single multimask output. Checkpoint weights whose shapes no
        longer match this configuration are re-initialised
        (``ignore_mismatched_sizes=True``).

        Args:
            sam_checkpoint (str): Hugging Face SAM model checkpoint.
            num_input_channels (int): Number of input channels (bands) of the images.
            image_size (int): Height/width of the square input images.
        """
        super().__init__()

        # NOTE(review): SAM's default patch size is 16; an image_size that is not a multiple
        # of it (e.g. 120) appears to leave the trailing rows/columns outside the patch
        # embedding - confirm whether this matters for the task.
        vision_config = SamVisionConfig(num_channels=num_input_channels, image_size=image_size)
        decoder_config = SamMaskDecoderConfig(num_multimask_outputs=1)
        prompt_config = SamPromptEncoderConfig(image_size=image_size)

        config = SamConfig(
            vision_config=vision_config,
            prompt_encoder_config=prompt_config,
            mask_decoder_config=decoder_config,
            name_or_path=sam_checkpoint,
        )

        self.sam_model = SamModel.from_pretrained(sam_checkpoint, config=config, ignore_mismatched_sizes=True)
        # from_pretrained returns the model in eval mode; switch it to training mode.
        self.sam_model.train()

    def forward(self, pixel_values, input_points=None):
        """
        Run the wrapped SAM model on images and optional point prompts.

        Args:
            pixel_values (torch.Tensor): Image batch of shape
                (batch_size, num_channels, height, width).
            input_points (torch.Tensor, optional): Point prompts forwarded unchanged to
                ``SamModel``; the training loop in this notebook passes shape
                (batch_size, 1, 1, 2).

        Returns:
            The raw ``SamModel`` output; callers read its ``"pred_masks"`` entry
            (mask logits).
        """
        return self.sam_model(pixel_values=pixel_values, input_points=input_points)


In [8]:
from torchvision.transforms import Resize, Normalize
from tqdm import tqdm

def train_model(train_loader, model, optimizer, criterion, device):
    """
    Run one training epoch over ``train_loader``.

    Each batch is a ``(prompt, img, mask)`` triple from the default collate of
    ``HyperspectralExpandedDataset`` samples. The prompt's ``"centroid"`` coordinates
    are used as a single point prompt per image. The ground-truth mask is resized to
    the predicted masks' spatial size and re-binarised, and ``criterion`` is applied
    to the ``"pred_masks"`` logits.

    Args:
        train_loader (DataLoader): DataLoader for the training dataset.
        model (torch.nn.Module): Model to train; called as
            ``model(pixel_values=..., input_points=...)`` and indexed with ``"pred_masks"``.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        criterion (torch.nn.Module): Loss function called as ``criterion(logits, targets)``.
        device (torch.device): Device to use for training ('cuda' or 'cpu').

    Returns:
        float: Average per-batch loss for the epoch (0.0 if the loader is empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(train_loader, total=len(train_loader), desc="Training", unit="batch")

    for prompt, img, mask in progress_bar:
        img = img.to(device)
        mask = mask.to(device)

        # Default collate turns each sample's [a, b] centroid list into two batched tensors.
        # NOTE(review): SAM expects (x, y) point coordinates - confirm the stored centroid order.
        centroid_a, centroid_b = prompt["centroid"]
        # Cast to float32 whatever the collated dtype (float64 for JSON floats),
        # then shape to (batch_size, 1, 1, 2): one point for one mask per image.
        input_points = torch.stack((centroid_a, centroid_b), dim=-1).float()
        input_points = input_points[:, None, None, :].to(device)

        optimizer.zero_grad()
        predictions = model(pixel_values=img, input_points=input_points)
        pred_masks = predictions["pred_masks"]

        # Downsample the mask to the prediction's spatial size; re-threshold because the
        # antialiased resize yields fractional values along object borders.
        resize_mask = Resize(pred_masks.shape[-2:], antialias=True)
        target = (resize_mask(mask) > 0.5).float()
        # (batch_size, H, W) -> (batch_size, 1, 1, H, W) to line up with pred_masks.
        target = target[:, None, None, :, :]

        loss = criterion(pred_masks, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(loss=batch_loss)

    # Average loss over all batches (guard against an empty loader).
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Training completed. Average Loss: {avg_loss}")
    return avg_loss


In [9]:
# Pick the GPU when one is visible, otherwise fall back to the CPU, then build the model on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = HyperspectralSAM(num_input_channels=12)
model = model.to(device)

Some weights of SamModel were not initialized from the model checkpoint at facebook/sam-vit-base and are newly initialized because the shapes did not match:
- mask_decoder.iou_prediction_head.proj_out.bias: found shape torch.Size([4]) in the checkpoint and torch.Size([2]) in the model instantiated
- mask_decoder.iou_prediction_head.proj_out.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- mask_decoder.mask_tokens.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- vision_encoder.layers.11.attn.rel_pos_h: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64]) in the model instantiated
- vision_encoder.layers.11.attn.rel_pos_w: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64]) in the model instantiated
- vision_encoder.layers.2.attn.rel_pos_h: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64])

In [10]:
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    """
    Mix of class-weighted binary cross-entropy and soft Dice loss:
    ``total = bce_weight * weighted_bce + (1 - bce_weight) * dice_loss``.
    """

    def __init__(self, bce_weight=0.7, background_weight=0.38, foreground_weight=1.0, eps=1e-6):
        """
        Args:
            bce_weight (float): Mixing weight of the BCE term; Dice gets ``1 - bce_weight``.
            background_weight (float): Per-pixel BCE weight where the target is not 1.
            foreground_weight (float): Per-pixel BCE weight where the target equals 1.
            eps (float): Probability clamp margin and Dice denominator epsilon.
        """
        super().__init__()
        self.bce_weight = bce_weight
        self.background_weight = background_weight
        self.foreground_weight = foreground_weight
        self.eps = eps

    def forward(self, logits, targets):
        """
        Args:
            logits (torch.Tensor): Raw mask logits; the first dimension is the batch.
            targets (torch.Tensor): Binary targets with the same shape as ``logits``.

        Returns:
            torch.Tensor: Scalar loss.

        Raises:
            ValueError: If the combined loss is NaN or infinite.
        """
        # Sigmoid probabilities, clamped away from exact 0/1, feed the Dice term.
        probs = torch.sigmoid(logits).clamp(min=self.eps, max=1 - self.eps)

        # reshape (not view) so non-contiguous inputs such as permuted/sliced outputs work.
        probs_flat = probs.reshape(probs.size(0), -1)
        targets_flat = targets.reshape(targets.size(0), -1)

        # Class-weighted BCE, computed from logits for numerical stability.
        bce_per_pixel = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        pixel_weight = torch.where(
            targets == 1,
            torch.full_like(targets, self.foreground_weight),
            torch.full_like(targets, self.background_weight),
        )
        bce_loss = (bce_per_pixel * pixel_weight).mean()

        # Per-sample soft Dice averaged over the batch; eps avoids division by zero.
        intersection = (probs_flat * targets_flat).sum(dim=1)
        denominator = probs_flat.sum(dim=1) + targets_flat.sum(dim=1) + self.eps
        dice_loss = 1 - (2.0 * intersection / denominator).mean()

        total_loss = self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss

        # Fail loudly on NaN/Inf so a diverging run is caught at the offending batch.
        if not torch.isfinite(total_loss):
            print("NaN detected in loss computation")
            print(f"Logits: {logits}")
            print(f"Targets: {targets}")
            print(f"Probs: {probs}")
            print(f"BCE Loss: {bce_loss}, Dice Loss: {dice_loss}")
            raise ValueError("Loss computation resulted in NaN")

        return total_loss


In [11]:
import os
from torch.optim.lr_scheduler import StepLR

experiment_name = "centroid_prompt"
epochs = 50

# Loss function and optimizer
criterion = DiceBCELoss()  # Weighted BCE + Dice for binary segmentation masks
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Scheduler to halve the learning rate every 10 epochs
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# The checkpoint directory does not change between epochs, so create it once.
save_dir = f"models/{experiment_name}"
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): eval_loader is not used in this loop, so no validation loss is tracked.
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}, Current LR: {scheduler.get_last_lr()}")

    # Train the model for one epoch
    train_model(train_loader, model, optimizer, criterion, device)

    # Save the model checkpoint.
    # NOTE(review): this pickles the whole module; model.state_dict() is more portable,
    # but any code that loads these files would need to change accordingly.
    model_path = f"{save_dir}/hyperspectral_sam_epoch_{epoch+1}.pt"
    torch.save(model, model_path)

    # Step the scheduler once per epoch
    scheduler.step()


Epoch 1/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.332]


Training completed. Average Loss: 0.4028641738601633
Epoch 2/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.191]


Training completed. Average Loss: 0.2254990820546408
Epoch 3/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.201]


Training completed. Average Loss: 0.20081608339741425
Epoch 4/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.168]


Training completed. Average Loss: 0.19630842192752943
Epoch 5/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.184]


Training completed. Average Loss: 0.19491329346154188
Epoch 6/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.20s/batch, loss=0.204]


Training completed. Average Loss: 0.19313984464954687
Epoch 7/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.192]


Training completed. Average Loss: 0.19169234242793676
Epoch 8/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.185]


Training completed. Average Loss: 0.19084249517402133
Epoch 9/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.201]


Training completed. Average Loss: 0.1911328314123927
Epoch 10/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.27s/batch, loss=0.194]


Training completed. Average Loss: 0.18878419052910161
Epoch 11/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:27<00:00,  1.19s/batch, loss=0.195]


Training completed. Average Loss: 0.18509201061081243
Epoch 12/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:27<00:00,  1.18s/batch, loss=0.183]


Training completed. Average Loss: 0.1846822528420268
Epoch 13/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.187]


Training completed. Average Loss: 0.1844200985254468
Epoch 14/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.185]


Training completed. Average Loss: 0.18409485833064929
Epoch 15/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.188]


Training completed. Average Loss: 0.18431227472988335
Epoch 16/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.20s/batch, loss=0.185]


Training completed. Average Loss: 0.1835605520251635
Epoch 17/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.196]


Training completed. Average Loss: 0.18315811133062518
Epoch 18/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.2]  


Training completed. Average Loss: 0.1831262562725995
Epoch 19/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.26s/batch, loss=0.172]


Training completed. Average Loss: 0.18301720957498294
Epoch 20/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.181]


Training completed. Average Loss: 0.18325059699851112
Epoch 21/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.175]


Training completed. Average Loss: 0.18058806416150686
Epoch 22/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.20s/batch, loss=0.182]


Training completed. Average Loss: 0.17993737616249034
Epoch 23/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:35<00:00,  1.29s/batch, loss=0.172]


Training completed. Average Loss: 0.17948512269838438
Epoch 24/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.199]


Training completed. Average Loss: 0.17908597797960848
Epoch 25/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.25s/batch, loss=0.185]


Training completed. Average Loss: 0.17885931077841166
Epoch 26/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [02:07<00:00,  1.72s/batch, loss=0.183]


Training completed. Average Loss: 0.17847177567514214
Epoch 27/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.22s/batch, loss=0.176]


Training completed. Average Loss: 0.17824003664222923
Epoch 28/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.26s/batch, loss=0.179]


Training completed. Average Loss: 0.17790134392074636
Epoch 29/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.178]


Training completed. Average Loss: 0.17731785452043689
Epoch 30/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.178]


Training completed. Average Loss: 0.17718866004331693
Epoch 31/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.26s/batch, loss=0.155]


Training completed. Average Loss: 0.17548460855677323
Epoch 32/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.25s/batch, loss=0.173]


Training completed. Average Loss: 0.17420577902246165
Epoch 33/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.184]


Training completed. Average Loss: 0.17402407889430588
Epoch 34/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:35<00:00,  1.29s/batch, loss=0.161]


Training completed. Average Loss: 0.1732252276427037
Epoch 35/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.26s/batch, loss=0.171]


Training completed. Average Loss: 0.17273326802092628
Epoch 36/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.25s/batch, loss=0.176]


Training completed. Average Loss: 0.17216884807960406
Epoch 37/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.27s/batch, loss=0.136]


Training completed. Average Loss: 0.1717123167740332
Epoch 38/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.159]


Training completed. Average Loss: 0.17098433121636109
Epoch 39/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.26s/batch, loss=0.184]


Training completed. Average Loss: 0.17065903967296756
Epoch 40/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.182]


Training completed. Average Loss: 0.17023863039306691
Epoch 41/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:35<00:00,  1.29s/batch, loss=0.158]


Training completed. Average Loss: 0.1683791761060019
Epoch 42/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.157]


Training completed. Average Loss: 0.16712380945682526
Epoch 43/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.178]


Training completed. Average Loss: 0.1663636289335586
Epoch 44/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:40<00:00,  1.36s/batch, loss=0.181]


Training completed. Average Loss: 0.16586488304105965
Epoch 45/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.20s/batch, loss=0.141]


Training completed. Average Loss: 0.16516890739266937
Epoch 46/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:18<00:00,  1.06s/batch, loss=0.167]


Training completed. Average Loss: 0.16459156613092166
Epoch 47/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.05s/batch, loss=0.184]


Training completed. Average Loss: 0.16425005809680834
Epoch 48/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.157]


Training completed. Average Loss: 0.16345207977133827
Epoch 49/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.04s/batch, loss=0.157]


Training completed. Average Loss: 0.1630487486317351
Epoch 50/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.161]


Training completed. Average Loss: 0.16237938786680634
