In [1]:
import torch
import torch.nn as nn
from transformers import (
    SamVisionConfig,
    SamPromptEncoderConfig,
    SamMaskDecoderConfig,
    SamModel,
    SamProcessor,
    SamImageProcessor
)
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from PIL import Image
import json
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import SamModel, SamConfig

In [2]:
def normalize_per_image(img):
    """
    Normalize an image tensor by dividing each pixel value by the maximum value in the image (plus 1).

    Args:
        img (torch.Tensor): Image tensor of shape (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Normalized image tensor with values in the range [0, 1].
    """
    # Compute the maximum value per image
    max_val = img.amax(dim=(-1, -2), keepdim=True)  # Max value per channel
    max_val = max_val + 1  # Add 1 to avoid division by zero

    # Normalize by max value
    normalized_img = img / max_val
    return normalized_img


In [3]:
class HyperspectralExpandedDataset(Dataset):
    def __init__(self, root_dir):
        """
        Args:
            root_dir (str): Root directory containing the expanded dataset.
        """
        self.root_dir = root_dir
        self.samples = self._load_samples()

    def _load_samples(self):
        """
        Scans the directory structure to find all saved samples.

        Returns:
            list: List of dictionaries containing file paths for each sample.
        """
        samples = []
        for sample_name in os.listdir(self.root_dir):
            sample_path = os.path.join(self.root_dir, sample_name)
            if not os.path.isdir(sample_path):
                continue

            # Collect file paths for bands, binary mask, and prompt
            bands_path = os.path.join(sample_path, "bands.pt")
            mask_path = os.path.join(sample_path, "binary_mask.tif")
            prompt_path = os.path.join(sample_path, "prompt.json")

            if os.path.exists(bands_path) and os.path.exists(mask_path) and os.path.exists(prompt_path):
                samples.append({
                    "bands": bands_path,
                    "mask": mask_path,
                    "prompt": prompt_path
                })
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Loads a sample.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (prompt, bands, binary_mask)
        """
        sample = self.samples[idx]

        bands = None
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")
            bands = torch.load(sample["bands"])

        binary_mask = to_tensor(Image.open(sample["mask"])).squeeze(0)  # Remove channel dimension

        with open(sample["prompt"], "r") as f:
            prompt = json.load(f)

        return prompt, normalize_per_image(bands), binary_mask


In [4]:
root_dir = "./expanded_dataset_output"
dataset = HyperspectralExpandedDataset(root_dir=root_dir)

# Split dataset: 90% training, 10% evaluation
train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

# DataLoaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

In [5]:
print(len(dataset))
for (prompt, img, mask) in train_dataset:
    print(prompt)
    print(img.shape)
    print(mask.shape)
    break

10505
{'centroid': [56.394591887831744, 92.83508596227675], 'random_point': [3, 76]}
torch.Size([12, 120, 120])
torch.Size([120, 120])


In [6]:
for (prompt, img, mask) in train_loader:
    print(prompt)
    print(img.shape)
    print(mask.shape)
    break

{'centroid': [tensor([ 80.1880,  88.9929,  37.1098, 114.1507], dtype=torch.float64), tensor([ 16.4165,   6.5786,  14.8592, 116.4247], dtype=torch.float64)], 'random_point': [tensor([ 84,  82,  40, 117]), tensor([ 26,   3,   2, 114])]}
torch.Size([4, 12, 120, 120])
torch.Size([4, 120, 120])


In [7]:
class HyperspectralSAM(nn.Module):
    def __init__(self, sam_checkpoint="facebook/sam-vit-base", num_input_channels=12):
        """
        Adapt SAM for hyperspectral data by modifying the input layer to handle more channels
        and adding a final layer for binary segmentation.

        Args:
            sam_checkpoint (str): Hugging Face SAM model checkpoint.
            num_input_channels (int): Number of input channels for hyperspectral data.
        """
        super(HyperspectralSAM, self).__init__()

        vision_config = SamVisionConfig(num_channels=12, image_size=1024)
        decoder_config = SamMaskDecoderConfig(num_multimask_outputs = 1)
        prompt_config  = SamPromptEncoderConfig(image_size=1024)
        
        config = SamConfig(vision_config = vision_config, 
                           prompt_encoder_config = prompt_config, 
                           mask_decoder_config = decoder_config, 
                           name_or_path=sam_checkpoint
                          )

        
        # self.processor = SamProcessor(img_processor)
        self.sam_model = SamModel.from_pretrained(sam_checkpoint, config=config, ignore_mismatched_sizes=True)
        self.sam_model.train()
    def forward(self, pixel_values, input_points=None):
        """
        Forward pass for the adapted SAM model.

        Args:
            pixel_values (torch.Tensor): Input tensor of shape (batch_size, num_channels, height, width).
            input_points (torch.Tensor, optional): Points as input prompts, of shape (batch_size, num_points, 2).
            input_boxes (torch.Tensor, optional): Boxes as input prompts, of shape (batch_size, num_boxes, 4).
            input_masks (torch.Tensor, optional): Masks as input prompts, of shape (batch_size, height, width).

        Returns:
            torch.Tensor: Binary segmentation logits of shape (batch_size, 1, height, width).
        """

        outputs = self.sam_model(
            pixel_values=pixel_values,
            input_points=input_points
        )
        return outputs
        # outputs["iou_scores"]
        # logits = self.final_conv(outputs["pred_masks"][:, 0, :, :, :])
        # return {"pred_masks": logits}


In [8]:
from torchvision.transforms import Resize, Normalize
from tqdm import tqdm

def train_model(train_loader, model, optimizer, criterion, device):
    """
    Training loop for a model that processes image, mask, and prompt data from a train loader.

    Args:
        train_loader (DataLoader): DataLoader for the training dataset.
        model (torch.nn.Module): Model to train.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        criterion (torch.nn.Module): Loss function.
        device (torch.device): Device to use for training ('cuda' or 'cpu').

    Returns:
        None
    """
    model.train()  # Set the model to training mode
    total_loss = 0.0

    # Resize and Normalize transformations
    # Explicit normalization for RGB and hyperspectral bands
    # normalize_img = Normalize(
    #     mean=[0.5, 0.485, 0.456, 0.406] + [0.5] * 8,  # RGB + Hyperspectral
    #     std=[0.5, 0.229, 0.224, 0.225] + [0.5] * 8   # RGB + Hyperspectral
    # )

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training", unit="batch")
    resize_img = Resize((1024, 1024), antialias=True)  

    for batch_idx, (prompt, img, mask) in progress_bar:
        # Preprocess the image
        # img = normalize_img(img)  # Normalize input images
        img = resize_img(img).to(device)  # Move to the correct device
        # print(torch.min(img), torch.max(img))
        # Preprocess the mask
        mask = mask.to(device)

        # Preprocess input points
        random_point_x, random_point_y = prompt['centroid']
        random_point = torch.stack((random_point_x, random_point_y), dim=-1).to(device)  # Combine and move to device
        random_point = random_point.unsqueeze(1).unsqueeze(2).to(device)  # Shape: (batch_size, 1, 1, 2)
        # Forward pass
        optimizer.zero_grad()
        predictions = model(pixel_values=img, input_points=random_point)

        # Resize mask to match the predictions' spatial dimensions
        predictions_shape = predictions["pred_masks"].shape[-2:]  # (height, width)
        resize_mask = Resize(predictions_shape, antialias=True)  # Dynamically adjust mask size
        mask = resize_mask(mask)
        mask = (mask > 0.5).float()
        # print(mask)
        # Calculate loss
        loss = criterion(predictions["pred_masks"], mask.unsqueeze(1).unsqueeze(1).float())

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    # Average loss over all batches
    avg_loss = total_loss / len(train_loader)
    print(f"Training completed. Average Loss: {avg_loss}")


In [9]:
# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HyperspectralSAM(num_input_channels=12).to(device)

Some weights of SamModel were not initialized from the model checkpoint at facebook/sam-vit-base and are newly initialized because the shapes did not match:
- mask_decoder.iou_prediction_head.proj_out.bias: found shape torch.Size([4]) in the checkpoint and torch.Size([2]) in the model instantiated
- mask_decoder.iou_prediction_head.proj_out.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- mask_decoder.mask_tokens.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- vision_encoder.patch_embed.projection.weight: found shape torch.Size([768, 3, 16, 16]) in the checkpoint and torch.Size([768, 12, 16, 16]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    def __init__(self, bce_weight=0.7):
        super(DiceBCELoss, self).__init__()
        self.bce_weight = bce_weight

    def forward(self, logits, targets):
        # Apply sigmoid to logits with clamping for stability
        probs = torch.sigmoid(logits).clamp(min=1e-6, max=1-1e-6)

        # Flatten tensors for Dice and BCE calculations
        probs_flat = probs.view(probs.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)

        # Weighted BCE loss
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        weight = torch.ones_like(targets) * 0.38  # Background weight
        weight[targets == 1] = 1.0  # Foreground weight
        bce_loss = (bce_loss * weight).mean()

        # Dice loss with epsilon to avoid division by zero
        intersection = (probs_flat * targets_flat).sum(dim=1)
        dice_loss = 1 - (2.0 * intersection / (probs_flat.sum(dim=1) + targets_flat.sum(dim=1) + 1e-6)).mean()

        # Combine BCE and Dice loss
        total_loss = self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss

        # Validate loss for NaN
        if not torch.isfinite(total_loss):
            print("NaN detected in loss computation")
            print(f"Logits: {logits}")
            print(f"Targets: {targets}")
            print(f"Probs: {probs}")
            print(f"BCE Loss: {bce_loss}, Dice Loss: {dice_loss}")
            raise ValueError("Loss computation resulted in NaN")

        return total_loss


In [None]:
import os
from torch.optim.lr_scheduler import StepLR

experiment_name = "centroid_prompt_1024"
epochs = 15

# Loss function and optimizer
criterion = DiceBCELoss()  # For binary segmentation masks
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Scheduler to halve the learning rate every 5 epochs
scheduler = StepLR(optimizer, step_size=4, gamma=0.8)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}, Current LR: {scheduler.get_last_lr()}")
    
    # Train the model
    train_model(train_loader, model, optimizer, criterion, device)
    
    # Save the model checkpoint
    save_dir = f"models/{experiment_name}"
    os.makedirs(save_dir, exist_ok=True)
    model_path = f"{save_dir}/hyperspectral_sam_epoch_{epoch+1}.pt"
    torch.save(model, model_path)
    
    # Step the scheduler
    scheduler.step()


Epoch 1/15, Current LR: [0.0005]


Training: 100%|██████████| 2364/2364 [49:57<00:00,  1.27s/batch, loss=0.399]


Training completed. Average Loss: 0.2990678672547974
Epoch 2/15, Current LR: [0.0005]


Training: 100%|██████████| 2364/2364 [47:12<00:00,  1.20s/batch, loss=0.0856]


Training completed. Average Loss: 0.2171702542406855
Epoch 3/15, Current LR: [0.0005]


Training: 100%|██████████| 2364/2364 [47:12<00:00,  1.20s/batch, loss=0.226] 


Training completed. Average Loss: 0.2091868589592817
Epoch 4/15, Current LR: [0.0005]


Training: 100%|██████████| 2364/2364 [47:14<00:00,  1.20s/batch, loss=0.189] 


Training completed. Average Loss: 0.2031842088153469
Epoch 5/15, Current LR: [0.0004]


Training: 100%|██████████| 2364/2364 [47:11<00:00,  1.20s/batch, loss=0.164] 


Training completed. Average Loss: 0.1993908326294393
Epoch 6/15, Current LR: [0.0004]


Training: 100%|██████████| 2364/2364 [47:07<00:00,  1.20s/batch, loss=0.38]  


Training completed. Average Loss: 0.19399780560719784
Epoch 7/15, Current LR: [0.0004]


Training: 100%|██████████| 2364/2364 [47:11<00:00,  1.20s/batch, loss=0.156] 


Training completed. Average Loss: 0.1945723706727151
Epoch 8/15, Current LR: [0.0004]


Training: 100%|██████████| 2364/2364 [47:12<00:00,  1.20s/batch, loss=0.213] 


Training completed. Average Loss: 0.19181925751892898
Epoch 9/15, Current LR: [0.00032]


Training: 100%|██████████| 2364/2364 [47:09<00:00,  1.20s/batch, loss=0.168] 


Training completed. Average Loss: 0.186587841243699
Epoch 10/15, Current LR: [0.00032]


Training:  77%|███████▋  | 1810/2364 [36:06<11:00,  1.19s/batch, loss=0.222] 