In [1]:
import torch
import torch.nn as nn
from transformers import (
    SamVisionConfig,
    SamPromptEncoderConfig,
    SamMaskDecoderConfig,
    SamModel,
    SamProcessor,
    SamImageProcessor
)
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from PIL import Image
import json
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import SamModel, SamConfig

In [2]:
def normalize_per_image(img):
    """
    Scale a tensor by the maximum of its last two dimensions, plus one.

    The maximum is taken over the last two axes (height and width) independently
    for every leading index, so each channel of each image is divided by its own
    ``max + 1``.

    Args:
        img (torch.Tensor): Tensor whose last two dimensions are spatial, e.g.
            (channels, height, width) or (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Tensor with the same shape as ``img`` containing
            ``img / (max + 1)``.
    """
    # keepdim=True leaves singleton spatial axes so the divisor broadcasts over H and W.
    # The +1 offset keeps the divisor non-zero when a channel is entirely zero.
    divisor = 1 + img.amax(dim=(-2, -1), keepdim=True)
    return img / divisor


In [3]:
class HyperspectralExpandedDataset(Dataset):
    """
    Dataset over a directory of pre-expanded samples.

    Every sub-directory of ``root_dir`` that contains all three of ``bands.pt``,
    ``binary_mask.tif`` and ``prompt.json`` is treated as one sample. The list of
    samples is built once, at construction time.
    """

    def __init__(self, root_dir):
        """
        Args:
            root_dir (str): Root directory containing the expanded dataset.
        """
        self.root_dir = root_dir
        self.samples = self._load_samples()

    def _load_samples(self):
        """
        Scans the directory structure to find all saved samples.

        Returns:
            list: List of dictionaries with the keys "bands", "mask" and "prompt",
                each mapping to a file path. Entries are ordered by sample
                directory name.
        """
        required_files = (
            ("bands", "bands.pt"),
            ("mask", "binary_mask.tif"),
            ("prompt", "prompt.json"),
        )

        samples = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping identical across runs and machines.
        for sample_name in sorted(os.listdir(self.root_dir)):
            sample_path = os.path.join(self.root_dir, sample_name)
            if not os.path.isdir(sample_path):
                continue

            paths = {key: os.path.join(sample_path, filename) for key, filename in required_files}

            # Incomplete sample directories are skipped rather than reported.
            if all(os.path.exists(path) for path in paths.values()):
                samples.append(paths)
        return samples

    def __len__(self):
        """Number of complete samples found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Loads a sample.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (prompt, bands, binary_mask) where ``prompt`` is the parsed JSON
                content, ``bands`` is the loaded tensor after ``normalize_per_image``
                and ``binary_mask`` is the mask image converted to a tensor with
                its leading channel dimension squeezed out.
        """
        sample = self.samples[idx]

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")
            # NOTE(review): torch.load unpickles the file, which can execute arbitrary
            # code. Only point this dataset at trusted files; consider passing
            # weights_only=True once the stored objects are confirmed to be plain tensors.
            bands = torch.load(sample["bands"])

        # Context manager closes the underlying file handle deterministically instead of
        # leaving it open until the image object is garbage collected.
        with Image.open(sample["mask"]) as mask_image:
            binary_mask = to_tensor(mask_image).squeeze(0)  # Remove channel dimension

        with open(sample["prompt"], "r") as f:
            prompt = json.load(f)

        return prompt, normalize_per_image(bands), binary_mask


In [4]:
root_dir = "./expanded_dataset_output"
dataset = HyperspectralExpandedDataset(root_dir=root_dir)

# Split dataset: 90% training, 10% evaluation
train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size

# A seeded generator makes the train/eval membership identical on every
# Restart & Run All; without it the held-out set changes between kernel sessions.
SPLIT_SEED = 42
train_dataset, eval_dataset = random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoaders (only the training loader shuffles)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

In [5]:
# Sanity check: total number of samples, then the prompt and tensor shapes of
# the first training sample only.
print(len(dataset))
for prompt, img, mask in train_dataset:
    print(prompt, img.shape, mask.shape, sep="\n")
    break

10505
{'centroid': [28.445252158109952, 59.22353475692867], 'random_point': [61, 76]}
torch.Size([12, 120, 120])
torch.Size([120, 120])


In [6]:
# Sanity check: show how the first collated batch looks (prompt structure and
# batched tensor shapes), then stop iterating.
for prompt, img, mask in train_loader:
    print(prompt, img.shape, mask.shape, sep="\n")
    break

{'centroid': [tensor([100.1844,  77.7186,  48.0732,  60.1692,  93.5194,  67.8435,  18.2421,
         56.2103,  43.7935, 105.6491,  89.0308,  71.0735, 100.2617,   2.0000,
         54.6781,  48.0000,  40.7679,  93.3576,  64.7331, 116.4773,  63.0625,
         48.4096,  21.5341, 109.8931,  88.9138,  37.2393,   3.7972, 101.4298,
         17.7072,  58.7567,  54.4813, 100.5000, 117.7260, 118.1600, 118.4118,
         82.7295,  51.1835,  73.1219,  63.0000, 103.7015,  60.5992,  44.5043,
        104.6299,   5.7179,   3.5556, 113.7377, 112.3469,   5.5250,  81.3037,
        114.5000,  37.9047, 108.9514,  99.9271,   8.7778, 100.3565,  26.5843,
         58.0617, 118.3750, 117.0000,  81.1758,   6.4023, 115.0000,   8.0877,
         14.1176,  25.3725,   9.8534, 109.3719,  38.4598,  85.9409,  11.2918,
        116.6875,  30.5050,  12.6373,   0.7895,  24.0988,  58.0644,  81.0260,
        117.0000,  59.5000,  54.3667,  22.3966,  85.3821,  36.3410,  37.5990,
         97.4845,   2.1176,  92.5618,  56.8887, 10

In [7]:
class HyperspectralSAM(nn.Module):
    """
    Thin wrapper around a Hugging Face ``SamModel`` configured for multi-band input.

    The wrapped model is built from a custom ``SamConfig`` and loaded with
    ``ignore_mismatched_sizes=True``, so checkpoint weights whose shapes do not match
    the custom configuration are newly initialised instead of loaded.
    """

    def __init__(self, sam_checkpoint="facebook/sam-vit-base", num_input_channels=12, image_size=120):
        """
        Args:
            sam_checkpoint (str): Hugging Face SAM model checkpoint.
            num_input_channels (int): Number of input channels for hyperspectral data.
            image_size (int): Side length, in pixels, of the square input images. Used
                for both the vision encoder and the prompt encoder configuration.
        """
        super().__init__()

        # The channel count and image size come from the arguments so that the
        # constructor parameters actually drive the model configuration.
        vision_config = SamVisionConfig(num_channels=num_input_channels, image_size=image_size)
        decoder_config = SamMaskDecoderConfig(num_multimask_outputs=1)
        prompt_config = SamPromptEncoderConfig(image_size=image_size)

        config = SamConfig(
            vision_config=vision_config,
            prompt_encoder_config=prompt_config,
            mask_decoder_config=decoder_config,
            name_or_path=sam_checkpoint,
        )

        self.sam_model = SamModel.from_pretrained(sam_checkpoint, config=config, ignore_mismatched_sizes=True)
        self.sam_model.train()

    def forward(self, pixel_values, input_points=None):
        """
        Forward pass through the wrapped SAM model.

        Args:
            pixel_values (torch.Tensor): Input tensor of shape
                (batch_size, num_channels, height, width).
            input_points (torch.Tensor, optional): Point prompts. The training loop in
                this notebook passes a tensor of shape (batch_size, 1, 1, 2).
                NOTE(review): confirm the coordinate order and scale expected by SamModel.

        Returns:
            The output object returned by ``SamModel``, unmodified. The training loop
            in this notebook reads its ``"pred_masks"`` entry.
        """
        return self.sam_model(
            pixel_values=pixel_values,
            input_points=input_points,
        )


In [8]:
from torchvision.transforms import Resize, Normalize
from tqdm import tqdm

def train_model(train_loader, model, optimizer, criterion, device):
    """
    Run one pass over ``train_loader``, updating ``model`` after every batch.

    Each batch is expected to unpack as ``(prompt, img, mask)`` where ``prompt`` is a
    dict whose ``'centroid'`` entry unpacks into two per-batch coordinate tensors.
    The centroid is used as the single point prompt for every image.

    Args:
        train_loader (DataLoader): DataLoader for the training dataset.
        model (torch.nn.Module): Model to train. Called with ``pixel_values`` and
            ``input_points``; its output must be indexable by ``"pred_masks"``.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        criterion (torch.nn.Module): Loss function, called as
            ``criterion(pred_masks, target)``.
        device (torch.device): Device to use for training ('cuda' or 'cpu').

    Returns:
        float: Mean of the per-batch loss values, or ``nan`` when the loader
            yields no batches.
    """
    model.train()  # Set the model to training mode

    num_batches = len(train_loader)
    if num_batches == 0:
        # Avoid a ZeroDivisionError when averaging; nothing can be trained.
        print("Training skipped: train_loader contains no batches.")
        return float("nan")

    total_loss = 0.0
    progress_bar = tqdm(train_loader, total=num_batches, desc="Training", unit="batch")

    for prompt, img, mask in progress_bar:
        img = img.to(device)
        mask = mask.to(device)

        # Build the point prompt from the mask centroid.
        # NOTE(review): the first element is treated as x and the second as y, as in
        # the original unpacking — confirm against how prompt.json was written.
        centroid_x, centroid_y = prompt['centroid']
        points = torch.stack((centroid_x, centroid_y), dim=-1)  # (batch_size, 2)
        # Python floats are collated as float64 by the default collate function;
        # cast once so the prompt has an explicit float32 dtype.
        points = points.to(device=device, dtype=torch.float32)
        points = points.unsqueeze(1).unsqueeze(2)  # Shape: (batch_size, 1, 1, 2)

        # Forward pass
        optimizer.zero_grad()
        predictions = model(pixel_values=img, input_points=points)
        pred_masks = predictions["pred_masks"]

        # Resize the mask to the predictions' spatial size, then re-binarise because
        # antialiased resizing produces fractional values.
        resize_mask = Resize(pred_masks.shape[-2:], antialias=True)
        target = (resize_mask(mask) > 0.5).float()

        # Two singleton axes are added so the target lines up with pred_masks.
        loss = criterion(pred_masks, target.unsqueeze(1).unsqueeze(1))

        # Backward pass
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    # Average loss over all batches
    avg_loss = total_loss / num_batches
    print(f"Training completed. Average Loss: {avg_loss}")
    return avg_loss


In [9]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU,
# then build the 12-channel model and move it to that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HyperspectralSAM(num_input_channels=12)
model = model.to(device)

Some weights of SamModel were not initialized from the model checkpoint at facebook/sam-vit-base and are newly initialized because the shapes did not match:
- mask_decoder.iou_prediction_head.proj_out.bias: found shape torch.Size([4]) in the checkpoint and torch.Size([2]) in the model instantiated
- mask_decoder.iou_prediction_head.proj_out.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- mask_decoder.mask_tokens.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- vision_encoder.layers.11.attn.rel_pos_h: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64]) in the model instantiated
- vision_encoder.layers.11.attn.rel_pos_w: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64]) in the model instantiated
- vision_encoder.layers.2.attn.rel_pos_h: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64])

In [10]:
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    """
    Soft Dice loss computed on sigmoid probabilities.

    Despite the class name, ``forward`` returns the Dice term only. ``bce_weight`` is
    stored on the instance but does not influence the returned value.
    """

    def __init__(self, bce_weight=0.7):
        """
        Args:
            bce_weight (float): Stored as ``self.bce_weight``; currently unused by
                ``forward``.
        """
        super().__init__()
        self.bce_weight = bce_weight

    def forward(self, logits, targets):
        """
        Args:
            logits (torch.Tensor): Raw model outputs; the first dimension is the batch.
            targets (torch.Tensor): Target mask with the same number of elements per
                batch item as ``logits``.

        Returns:
            torch.Tensor: Scalar ``1 - mean(dice)``, where dice is computed per batch item.

        Raises:
            ValueError: If the computed loss is NaN or infinite.
        """
        # Apply sigmoid to logits with clamping for stability
        probs = torch.sigmoid(logits).clamp(min=1e-6, max=1 - 1e-6)

        # Flatten everything except the batch dimension. reshape (unlike view) also
        # accepts non-contiguous inputs.
        probs_flat = probs.reshape(probs.size(0), -1)
        targets_flat = targets.reshape(targets.size(0), -1)

        # Dice loss with epsilon to avoid division by zero
        intersection = (probs_flat * targets_flat).sum(dim=1)
        denominator = probs_flat.sum(dim=1) + targets_flat.sum(dim=1) + 1e-6
        dice_loss = 1 - (2.0 * intersection / denominator).mean()

        # Only the Dice term contributes to the returned loss.
        total_loss = dice_loss

        # Validate loss for NaN / infinity. The diagnostics reference only names that
        # exist here, so this path raises the intended ValueError.
        if not torch.isfinite(total_loss):
            print("NaN detected in loss computation")
            print(f"Logits: {logits}")
            print(f"Targets: {targets}")
            print(f"Probs: {probs}")
            print(f"Dice Loss: {dice_loss}")
            raise ValueError("Loss computation resulted in NaN")

        return total_loss


In [11]:
import os
from torch.optim.lr_scheduler import StepLR

experiment_name = "centroid_prompt_dice_only"
epochs = 50

# Loss function and optimizer
criterion = DiceBCELoss()  # For binary segmentation masks
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Scheduler halves the learning rate every 10 epochs (step_size=10, gamma=0.5)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# The checkpoint directory does not change between epochs, so create it once.
save_dir = f"models/{experiment_name}"
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}, Current LR: {scheduler.get_last_lr()}")

    # Train the model for one pass over the training loader
    train_model(train_loader, model, optimizer, criterion, device)

    # Save a checkpoint after every epoch.
    # NOTE(review): torch.save(model, ...) pickles the whole module, which ties the
    # file to this class definition and import path. Saving model.state_dict() is
    # more portable, but would change the format for anything that loads these files.
    model_path = f"{save_dir}/hyperspectral_sam_epoch_{epoch+1}.pt"
    torch.save(model, model_path)

    # Step the scheduler once per epoch
    scheduler.step()


Epoch 1/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [03:08<00:00,  2.54s/batch, loss=0.792]


Training completed. Average Loss: 0.7863446071341231
Epoch 2/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.02s/batch, loss=0.776]


Training completed. Average Loss: 0.7834294881369617
Epoch 3/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.764]


Training completed. Average Loss: 0.7834054336354539
Epoch 4/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.04s/batch, loss=0.802]


Training completed. Average Loss: 0.7834778476405788
Epoch 5/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.01s/batch, loss=0.804]


Training completed. Average Loss: 0.7834811299233824
Epoch 6/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.03s/batch, loss=0.805]


Training completed. Average Loss: 0.7834840851861078
Epoch 7/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.778]


Training completed. Average Loss: 0.7834325811347446
Epoch 8/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.782]


Training completed. Average Loss: 0.7834406599805162
Epoch 9/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.78] 


Training completed. Average Loss: 0.7834365456490904
Epoch 10/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.02s/batch, loss=0.786]


Training completed. Average Loss: 0.7834481210321993
Epoch 11/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.03s/batch, loss=0.8]  


Training completed. Average Loss: 0.7834748835177034
Epoch 12/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.808]


Training completed. Average Loss: 0.7834889365209116
Epoch 13/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.02s/batch, loss=0.761]


Training completed. Average Loss: 0.7834010277245496
Epoch 14/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.04s/batch, loss=0.779]


Training completed. Average Loss: 0.7834340543360323
Epoch 15/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.03s/batch, loss=0.783]


Training completed. Average Loss: 0.7834415105549065
Epoch 16/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.02s/batch, loss=0.753]


Training completed. Average Loss: 0.7833840444281295
Epoch 17/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.812]


Training completed. Average Loss: 0.7834969686495291
Epoch 18/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.797]


Training completed. Average Loss: 0.7834694361364519
Epoch 19/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.00s/batch, loss=0.731]


Training completed. Average Loss: 0.7833422704322918
Epoch 20/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.04s/batch, loss=0.786]


Training completed. Average Loss: 0.7834468443651457
Epoch 21/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.00s/batch, loss=0.785]


Training completed. Average Loss: 0.783445623275396
Epoch 22/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.01batch/s, loss=0.798]


Training completed. Average Loss: 0.7834711453399142
Epoch 23/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:12<00:00,  1.01batch/s, loss=0.766]


Training completed. Average Loss: 0.7834089180907687
Epoch 24/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.811]


Training completed. Average Loss: 0.783494966255652
Epoch 25/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.781]


Training completed. Average Loss: 0.7834383007642385
Epoch 26/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.02s/batch, loss=0.791]


Training completed. Average Loss: 0.7834565099832174
Epoch 27/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.744]


Training completed. Average Loss: 0.7833684247893256
Epoch 28/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.01batch/s, loss=0.746]


Training completed. Average Loss: 0.7833720163719075
Epoch 29/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.01batch/s, loss=0.797]


Training completed. Average Loss: 0.7834694594950289
Epoch 30/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.794]


Training completed. Average Loss: 0.7834624905843992
Epoch 31/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.732]


Training completed. Average Loss: 0.7833459143703049
Epoch 32/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:15<00:00,  1.02s/batch, loss=0.789]


Training completed. Average Loss: 0.783453102047379
Epoch 33/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.786]


Training completed. Average Loss: 0.7834468040917371
Epoch 34/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.01batch/s, loss=0.775]


Training completed. Average Loss: 0.7834266923569344
Epoch 35/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.848]


Training completed. Average Loss: 0.7835655719847292
Epoch 36/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.00s/batch, loss=0.773]


Training completed. Average Loss: 0.7834237999207264
Epoch 37/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.01batch/s, loss=0.729]


Training completed. Average Loss: 0.7833387505363774
Epoch 38/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.767]


Training completed. Average Loss: 0.7834123783820385
Epoch 39/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.739]


Training completed. Average Loss: 0.7833580439155167
Epoch 40/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.00s/batch, loss=0.803]


Training completed. Average Loss: 0.7834808608970126
Epoch 41/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.01batch/s, loss=0.83] 


Training completed. Average Loss: 0.7835309263822194
Epoch 42/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:13<00:00,  1.00batch/s, loss=0.786]


Training completed. Average Loss: 0.7834476256692732
Epoch 43/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.793]


Training completed. Average Loss: 0.7834617551919576
Epoch 44/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.05s/batch, loss=0.781]


Training completed. Average Loss: 0.783437393001608
Epoch 45/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:19<00:00,  1.08s/batch, loss=0.768]


Training completed. Average Loss: 0.7834137307631003
Epoch 46/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:18<00:00,  1.06s/batch, loss=0.762]


Training completed. Average Loss: 0.7834012460064244
Epoch 47/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.04s/batch, loss=0.761]


Training completed. Average Loss: 0.7834003616023708
Epoch 48/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.793]


Training completed. Average Loss: 0.7834614918038652
Epoch 49/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.743]


Training completed. Average Loss: 0.7833653559555879
Epoch 50/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:14<00:00,  1.01s/batch, loss=0.796]


Training completed. Average Loss: 0.783465759979712
