In [1]:
import torch
import torch.nn as nn
from transformers import (
    SamVisionConfig,
    SamPromptEncoderConfig,
    SamMaskDecoderConfig,
    SamModel,
    SamProcessor,
    SamImageProcessor
)
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from PIL import Image
import json
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import SamModel, SamConfig

In [2]:
def normalize_per_image(img):
    """
    Normalize an image tensor by dividing each pixel value by the maximum value in the image (plus 1).

    Args:
        img (torch.Tensor): Image tensor of shape (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Normalized image tensor with values in the range [0, 1].
    """
    # Compute the maximum value per image
    max_val = img.amax(dim=(-1, -2), keepdim=True)  # Max value per channel
    max_val = max_val + 1  # Add 1 to avoid division by zero

    # Normalize by max value
    normalized_img = img / max_val
    return normalized_img


In [3]:
class HyperspectralExpandedDataset(Dataset):
    def __init__(self, root_dir):
        """
        Args:
            root_dir (str): Root directory containing the expanded dataset.
        """
        self.root_dir = root_dir
        self.samples = self._load_samples()

    def _load_samples(self):
        """
        Scans the directory structure to find all saved samples.

        Returns:
            list: List of dictionaries containing file paths for each sample.
        """
        samples = []
        for sample_name in os.listdir(self.root_dir):
            sample_path = os.path.join(self.root_dir, sample_name)
            if not os.path.isdir(sample_path):
                continue

            # Collect file paths for bands, binary mask, and prompt
            bands_path = os.path.join(sample_path, "bands.pt")
            mask_path = os.path.join(sample_path, "binary_mask.tif")
            prompt_path = os.path.join(sample_path, "prompt.json")

            if os.path.exists(bands_path) and os.path.exists(mask_path) and os.path.exists(prompt_path):
                samples.append({
                    "bands": bands_path,
                    "mask": mask_path,
                    "prompt": prompt_path
                })
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Loads a sample.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (prompt, bands, binary_mask)
        """
        sample = self.samples[idx]

        bands = None
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*")
            bands = torch.load(sample["bands"])

        binary_mask = to_tensor(Image.open(sample["mask"])).squeeze(0)  # Remove channel dimension

        with open(sample["prompt"], "r") as f:
            prompt = json.load(f)

        return prompt, normalize_per_image(bands), binary_mask


In [4]:
root_dir = "./expanded_dataset_output"
dataset = HyperspectralExpandedDataset(root_dir=root_dir)

# Split dataset: 90% training, 10% evaluation
train_size = int(0.9 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

# DataLoaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

In [5]:
print(len(dataset))
for (prompt, img, mask) in train_dataset:
    print(prompt)
    print(img.shape)
    print(mask.shape)
    break

10505
{'centroid': [117.45, 2.15], 'random_point': [117, 1]}
torch.Size([12, 120, 120])
torch.Size([120, 120])


In [6]:
for (prompt, img, mask) in train_loader:
    print(prompt)
    print(img.shape)
    print(mask.shape)
    break

{'centroid': [tensor([ 73.6996,  20.8202,  72.0199,  24.2275, 118.0667,  76.6011,  41.8462,
         97.5854, 114.5000,  18.2540,  64.5663, 101.1000,  66.0522,  38.6026,
         51.6794,  97.4505, 117.8684,  39.6109,   6.2841, 113.7049,   3.7132,
         54.9539,  37.4452,  14.4762,  72.3264, 112.5969, 109.1109,  81.0731,
         94.3195,  86.0193, 105.7262,  84.5280,  59.5000,  63.3437,  15.8971,
         11.1844,   8.3231,   8.6261,  93.1406,  74.0059,  19.2321,  59.9540,
        111.2772, 116.9841,  38.7669,   2.4626,  91.5000,  20.2246,  15.6869,
          0.2857,   2.8017,   2.5532,  61.0629, 117.6667,   3.7077, 112.5752,
         32.0000,  14.4755,  72.5252,  99.6927, 112.7690,  30.4091,   5.6750,
        114.3601,  75.6715,  47.9559, 103.0000,  13.7155, 114.9437,  15.1524,
        112.2319,  91.0000,  53.0374,   1.4857,  63.8888, 116.4151,   3.7200,
         17.5373,  65.1592,  94.3737, 113.3171,   6.8251, 119.0000,  14.1250,
        104.7735, 119.0000,  77.1954,  54.1170,  6

In [7]:
class HyperspectralSAM(nn.Module):
    def __init__(self, sam_checkpoint="facebook/sam-vit-base", num_input_channels=12):
        """
        Adapt SAM for hyperspectral data by modifying the input layer to handle more channels
        and adding a final layer for binary segmentation.

        Args:
            sam_checkpoint (str): Hugging Face SAM model checkpoint.
            num_input_channels (int): Number of input channels for hyperspectral data.
        """
        super(HyperspectralSAM, self).__init__()

        vision_config = SamVisionConfig(num_channels=12, image_size=120)
        decoder_config = SamMaskDecoderConfig(num_multimask_outputs = 1)
        prompt_config  = SamPromptEncoderConfig(image_size=120)
        
        config = SamConfig(vision_config = vision_config, 
                           prompt_encoder_config = prompt_config, 
                           mask_decoder_config = decoder_config, 
                           name_or_path=sam_checkpoint
                          )

        
        # self.processor = SamProcessor(img_processor)
        self.sam_model = SamModel.from_pretrained(sam_checkpoint, config=config, ignore_mismatched_sizes=True)
        self.sam_model.train()
    def forward(self, pixel_values, input_points=None):
        """
        Forward pass for the adapted SAM model.

        Args:
            pixel_values (torch.Tensor): Input tensor of shape (batch_size, num_channels, height, width).
            input_points (torch.Tensor, optional): Points as input prompts, of shape (batch_size, num_points, 2).
            input_boxes (torch.Tensor, optional): Boxes as input prompts, of shape (batch_size, num_boxes, 4).
            input_masks (torch.Tensor, optional): Masks as input prompts, of shape (batch_size, height, width).

        Returns:
            torch.Tensor: Binary segmentation logits of shape (batch_size, 1, height, width).
        """

        outputs = self.sam_model(
            pixel_values=pixel_values,
            input_points=input_points
        )
        return outputs
        # outputs["iou_scores"]
        # logits = self.final_conv(outputs["pred_masks"][:, 0, :, :, :])
        # return {"pred_masks": logits}


In [8]:
from torchvision.transforms import Resize, Normalize
from tqdm import tqdm

def train_model(train_loader, model, optimizer, criterion, device):
    """
    Training loop for a model that processes image, mask, and prompt data from a train loader.

    Args:
        train_loader (DataLoader): DataLoader for the training dataset.
        model (torch.nn.Module): Model to train.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        criterion (torch.nn.Module): Loss function.
        device (torch.device): Device to use for training ('cuda' or 'cpu').

    Returns:
        None
    """
    model.train()  # Set the model to training mode
    total_loss = 0.0

    # Resize and Normalize transformations
    # Explicit normalization for RGB and hyperspectral bands
    # normalize_img = Normalize(
    #     mean=[0.5, 0.485, 0.456, 0.406] + [0.5] * 8,  # RGB + Hyperspectral
    #     std=[0.5, 0.229, 0.224, 0.225] + [0.5] * 8   # RGB + Hyperspectral
    # )

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training", unit="batch")

    for batch_idx, (prompt, img, mask) in progress_bar:
        # Preprocess the image
        # img = normalize_img(img)  # Normalize input images
        img = img.to(device)  # Move to the correct device
        # print(torch.min(img), torch.max(img))
        # Preprocess the mask
        mask = mask.to(device)

        # Preprocess input points
        random_point_x, random_point_y = prompt['random_point']
        random_point = torch.stack((random_point_x, random_point_y), dim=-1).to(device)  # Combine and move to device
        random_point = random_point.unsqueeze(1).unsqueeze(2).to(device)  # Shape: (batch_size, 1, 1, 2)
        # Forward pass
        optimizer.zero_grad()
        predictions = model(pixel_values=img, input_points=random_point)

        # Resize mask to match the predictions' spatial dimensions
        predictions_shape = predictions["pred_masks"].shape[-2:]  # (height, width)
        resize_mask = Resize(predictions_shape, antialias=True)  # Dynamically adjust mask size
        mask = resize_mask(mask)
        mask = (mask > 0.5).float()
        # print(mask)
        # Calculate loss
        loss = criterion(predictions["pred_masks"], mask.unsqueeze(1).unsqueeze(1).float())

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    # Average loss over all batches
    avg_loss = total_loss / len(train_loader)
    print(f"Training completed. Average Loss: {avg_loss}")


In [9]:
# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HyperspectralSAM(num_input_channels=12).to(device)

Some weights of SamModel were not initialized from the model checkpoint at facebook/sam-vit-base and are newly initialized because the shapes did not match:
- mask_decoder.iou_prediction_head.proj_out.bias: found shape torch.Size([4]) in the checkpoint and torch.Size([2]) in the model instantiated
- mask_decoder.iou_prediction_head.proj_out.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- mask_decoder.mask_tokens.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([2, 256]) in the model instantiated
- vision_encoder.layers.11.attn.rel_pos_h: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64]) in the model instantiated
- vision_encoder.layers.11.attn.rel_pos_w: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64]) in the model instantiated
- vision_encoder.layers.2.attn.rel_pos_h: found shape torch.Size([127, 64]) in the checkpoint and torch.Size([13, 64])

In [10]:
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    def __init__(self, bce_weight=0.7):
        super(DiceBCELoss, self).__init__()
        self.bce_weight = bce_weight

    def forward(self, logits, targets):
        # Apply sigmoid to logits with clamping for stability
        probs = torch.sigmoid(logits).clamp(min=1e-6, max=1-1e-6)

        # Flatten tensors for Dice and BCE calculations
        probs_flat = probs.view(probs.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)

        # Weighted BCE loss
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        weight = torch.ones_like(targets) * 0.38  # Background weight
        weight[targets == 1] = 1.0  # Foreground weight
        bce_loss = (bce_loss * weight).mean()

        # Dice loss with epsilon to avoid division by zero
        intersection = (probs_flat * targets_flat).sum(dim=1)
        dice_loss = 1 - (2.0 * intersection / (probs_flat.sum(dim=1) + targets_flat.sum(dim=1) + 1e-6)).mean()

        # Combine BCE and Dice loss
        total_loss = self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss

        # Validate loss for NaN
        if not torch.isfinite(total_loss):
            print("NaN detected in loss computation")
            print(f"Logits: {logits}")
            print(f"Targets: {targets}")
            print(f"Probs: {probs}")
            print(f"BCE Loss: {bce_loss}, Dice Loss: {dice_loss}")
            raise ValueError("Loss computation resulted in NaN")

        return total_loss


In [11]:
import os
from torch.optim.lr_scheduler import StepLR

experiment_name = "random_prompt"
epochs = 50

# Loss function and optimizer
criterion = DiceBCELoss()  # For binary segmentation masks
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Scheduler to halve the learning rate every 5 epochs
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}, Current LR: {scheduler.get_last_lr()}")
    
    # Train the model
    train_model(train_loader, model, optimizer, criterion, device)
    
    # Save the model checkpoint
    save_dir = f"models/{experiment_name}"
    os.makedirs(save_dir, exist_ok=True)
    model_path = f"{save_dir}/hyperspectral_sam_epoch_{epoch+1}.pt"
    torch.save(model, model_path)
    
    # Step the scheduler
    scheduler.step()


Epoch 1/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.451]


Training completed. Average Loss: 0.4748010772305566
Epoch 2/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:16<00:00,  1.03s/batch, loss=0.43] 


Training completed. Average Loss: 0.4410170288504781
Epoch 3/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:17<00:00,  1.05s/batch, loss=0.427]


Training completed. Average Loss: 0.4366259160074028
Epoch 4/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.19s/batch, loss=0.432]


Training completed. Average Loss: 0.43109300571518977
Epoch 5/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:25<00:00,  1.16s/batch, loss=0.437]


Training completed. Average Loss: 0.41959352026114594
Epoch 6/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:20<00:00,  1.09s/batch, loss=0.432]


Training completed. Average Loss: 0.4088892840050362
Epoch 7/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.384]


Training completed. Average Loss: 0.401325719582068
Epoch 8/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.405]


Training completed. Average Loss: 0.3950336893668046
Epoch 9/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.25s/batch, loss=0.372]


Training completed. Average Loss: 0.3884747020296148
Epoch 10/50, Current LR: [0.0005]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.391]


Training completed. Average Loss: 0.38807787363593643
Epoch 11/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.381]


Training completed. Average Loss: 0.38283543006793874
Epoch 12/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.358]


Training completed. Average Loss: 0.38121634319021896
Epoch 13/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.26s/batch, loss=0.388]


Training completed. Average Loss: 0.38049358048954524
Epoch 14/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.19s/batch, loss=0.379]


Training completed. Average Loss: 0.37940340630106023
Epoch 15/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.349]


Training completed. Average Loss: 0.3792246406948244
Epoch 16/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:26<00:00,  1.17s/batch, loss=0.371]


Training completed. Average Loss: 0.3793349068712544
Epoch 17/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:27<00:00,  1.18s/batch, loss=0.365]


Training completed. Average Loss: 0.37890425203619776
Epoch 18/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.359]


Training completed. Average Loss: 0.37854889840693084
Epoch 19/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.405]


Training completed. Average Loss: 0.37823318065823736
Epoch 20/50, Current LR: [0.00025]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.23s/batch, loss=0.364]


Training completed. Average Loss: 0.37779253600416957
Epoch 21/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.383]


Training completed. Average Loss: 0.3753152175529583
Epoch 22/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.372]


Training completed. Average Loss: 0.3749855849388483
Epoch 23/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.361]


Training completed. Average Loss: 0.37398003283384684
Epoch 24/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:36<00:00,  1.30s/batch, loss=0.407]


Training completed. Average Loss: 0.3746216240766886
Epoch 25/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.383]


Training completed. Average Loss: 0.3738975488656276
Epoch 26/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:29<00:00,  1.21s/batch, loss=0.378]


Training completed. Average Loss: 0.3734238937094405
Epoch 27/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.389]


Training completed. Average Loss: 0.37358608680802424
Epoch 28/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.364]


Training completed. Average Loss: 0.3730669992195593
Epoch 29/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:35<00:00,  1.29s/batch, loss=0.396]


Training completed. Average Loss: 0.373314478912869
Epoch 30/50, Current LR: [0.000125]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.24s/batch, loss=0.365]


Training completed. Average Loss: 0.3727270517800305
Epoch 31/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:58<00:00,  1.60s/batch, loss=0.429]


Training completed. Average Loss: 0.37044317577336283
Epoch 32/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:41<00:00,  1.37s/batch, loss=0.374]


Training completed. Average Loss: 0.36900631036307363
Epoch 33/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.358]


Training completed. Average Loss: 0.36906491058903773
Epoch 34/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.19s/batch, loss=0.381]


Training completed. Average Loss: 0.36839548719895854
Epoch 35/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:35<00:00,  1.29s/batch, loss=0.36] 


Training completed. Average Loss: 0.367613024405531
Epoch 36/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.27s/batch, loss=0.393]


Training completed. Average Loss: 0.3679727420613572
Epoch 37/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:28<00:00,  1.20s/batch, loss=0.382]


Training completed. Average Loss: 0.3669643273224702
Epoch 38/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:31<00:00,  1.24s/batch, loss=0.344]


Training completed. Average Loss: 0.3661971325809891
Epoch 39/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:33<00:00,  1.26s/batch, loss=0.399]


Training completed. Average Loss: 0.36744203599723607
Epoch 40/50, Current LR: [6.25e-05]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.23s/batch, loss=0.38] 


Training completed. Average Loss: 0.3662784703679987
Epoch 41/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.28s/batch, loss=0.372]


Training completed. Average Loss: 0.36385917864941264
Epoch 42/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.25s/batch, loss=0.42] 


Training completed. Average Loss: 0.36279071786919154
Epoch 43/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:36<00:00,  1.30s/batch, loss=0.361]


Training completed. Average Loss: 0.3621385987545993
Epoch 44/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:32<00:00,  1.25s/batch, loss=0.371]


Training completed. Average Loss: 0.36204983817564473
Epoch 45/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.361]


Training completed. Average Loss: 0.36134609700860204
Epoch 46/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:30<00:00,  1.22s/batch, loss=0.336]


Training completed. Average Loss: 0.3607161463917913
Epoch 47/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:35<00:00,  1.29s/batch, loss=0.338]


Training completed. Average Loss: 0.3604062831885106
Epoch 48/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.27s/batch, loss=0.402]


Training completed. Average Loss: 0.3605216342855144
Epoch 49/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:34<00:00,  1.27s/batch, loss=0.403]


Training completed. Average Loss: 0.35968271501966426
Epoch 50/50, Current LR: [3.125e-05]


Training: 100%|██████████| 74/74 [01:39<00:00,  1.34s/batch, loss=0.387]


Training completed. Average Loss: 0.35973068266301544
