In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from PIL import Image
import numpy as np
import os
from PIL import Image

In [2]:
def check_sample_shapes(input_data_dir):
    """
    Print the size of every band image in one sample folder and compare the
    band sizes with the size of the sample's reference map.

    The reference-map folder is derived from ``input_data_dir`` by replacing
    "SmallEarthNet-S2" with "Reference_Maps" in the path. Inside that folder
    the first file (in sorted order) ending in "_reference_map.tif" is used.

    Sizes are PIL ``Image.size`` tuples, i.e. (width, height).

    A missing reference folder or reference file is reported on stdout
    instead of raising, so the band report is still usable on its own.

    Args:
        input_data_dir (str): Path to the folder containing band images for a sample.
    """
    print(f"Checking sample folder: {input_data_dir}")

    # Check band shapes
    band_files = sorted(
        os.path.join(input_data_dir, band)
        for band in os.listdir(input_data_dir)
        if band.endswith('.tif')
    )

    band_shapes = []
    for band_file in band_files:
        with Image.open(band_file) as img:
            shape = img.size  # (width, height)
        band_shapes.append(shape)
        print(f"  Band: {os.path.basename(band_file)}, Shape: {shape}")

    if not band_shapes:
        # Without this branch an empty folder would be reported as "all bands match".
        print("  ❌ No band images (.tif) found in this sample folder.")
    elif len(set(band_shapes)) > 1:
        print("  ⚠️ Mismatch in band shapes detected in this sample folder!")
    else:
        print("  ✅ All bands have the same shape in this sample folder.")

    # Determine reference map path
    reference_map_dir = input_data_dir.replace("SmallEarthNet-S2", "Reference_Maps")
    reference_map_file = None
    if os.path.isdir(reference_map_dir):
        # Sorted so that the chosen file does not depend on directory listing order.
        reference_map_file = next(
            (
                os.path.join(reference_map_dir, name)
                for name in sorted(os.listdir(reference_map_dir))
                if name.endswith("_reference_map.tif")
            ),
            None,
        )

    if reference_map_file is None:
        print(f"  ❌ Reference map not found in directory: {reference_map_dir}")
        return

    # Check reference map shape
    with Image.open(reference_map_file) as ref_map:
        ref_map_shape = ref_map.size  # (width, height)
    print(f"Reference Map Shape: {ref_map_shape}")

    # Compare shapes of bands and reference map: every band has to agree,
    # not just the first one in sorted order.
    if band_shapes:
        if all(shape == ref_map_shape for shape in band_shapes):
            print("  ✅ Reference map shape matches band shapes.")
        else:
            print("  ⚠️ Reference map shape does NOT match band shapes!")
# Sample folder to inspect (tile folder / sub-tile folder under the input root)
input_data_dir = (
    "./SmallEarthNet-S2/"
    "S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP/"
    "S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57/"
)

# Print band sizes and the reference-map size for that folder
check_sample_shapes(input_data_dir=input_data_dir)

Checking sample folder: ./SmallEarthNet-S2/S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP/S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57/
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B01.tif, Shape: (20, 20)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B02.tif, Shape: (120, 120)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B03.tif, Shape: (120, 120)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B04.tif, Shape: (120, 120)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B05.tif, Shape: (60, 60)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B06.tif, Shape: (60, 60)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B07.tif, Shape: (60, 60)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B08.tif, Shape: (120, 120)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B09.tif, Shape: (20, 20)
  Band: S2A_MSIL2A_20170613T101031_N9999_R022_T33UUP_26_57_B11.tif, Shape: (60, 60)
  Band: S2A_

In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize
from torchvision.transforms.functional import to_tensor
from PIL import Image

class HyperspectralDataset(Dataset):
    """
    Dataset pairing the band images of a sample with its reference (label) map.

    Expected directory layout::

        input_base_dir/<tile>/<subtile>/*.tif
        reference_base_dir/<tile>/<subtile>/<subtile>_reference_map.tif

    Each item is ``(bands, label)``: ``bands`` is a float tensor with one
    resized band per ``.tif`` file (sorted by file name), ``label`` is the
    resized reference map as a float tensor.
    """

    def __init__(self, input_base_dir, reference_base_dir, target_size=(120, 120), transform=None):
        """
        Args:
            input_base_dir (str): Base directory containing all input image directories.
            reference_base_dir (str): Base directory containing all reference map directories.
            target_size (tuple): Target size (height, width) for resizing images.
            transform (callable, optional): Optional transform to apply to the images.
        """
        self.input_base_dir = input_base_dir
        self.reference_base_dir = reference_base_dir
        self.target_size = target_size
        self.transform = transform

        # Gather all tile/subtile pairs
        self.samples = self._get_samples()

    def _get_samples(self):
        """
        Collect all tile/subtile pairs from the input and reference directories.

        Tiles and subtiles are visited in sorted order so that a sample index
        always refers to the same subtile, regardless of the order in which
        the file system lists directory entries. Entries that are not
        directories are skipped. Reference maps are not checked for existence
        here; a missing one surfaces when the sample is loaded.

        Returns:
            list: ``(band_files, reference_map_path)`` tuples, one per subtile.
        """
        samples = []
        for tile in sorted(os.listdir(self.input_base_dir)):
            tile_input_dir = os.path.join(self.input_base_dir, tile)
            if not os.path.isdir(tile_input_dir):
                continue  # stray file next to the tile folders
            tile_reference_dir = os.path.join(self.reference_base_dir, tile)

            for subtile in sorted(os.listdir(tile_input_dir)):
                subtile_input_dir = os.path.join(tile_input_dir, subtile)
                if not os.path.isdir(subtile_input_dir):
                    continue  # stray file next to the subtile folders
                subtile_reference_dir = os.path.join(tile_reference_dir, subtile)

                # Gather all band files and the reference map
                band_files = sorted(
                    os.path.join(subtile_input_dir, f)
                    for f in os.listdir(subtile_input_dir)
                    if f.endswith(".tif") and "_reference_map" not in f
                )
                reference_map = os.path.join(
                    subtile_reference_dir,
                    f"{subtile}_reference_map.tif"
                )
                samples.append((band_files, reference_map))
        return samples

    def __len__(self):
        """Return the number of (band files, reference map) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int): Index into ``self.samples``.

        Returns:
            tuple: ``(input_image, label)`` as float tensors; ``input_image``
            stacks the resized bands along dimension 0.
        """
        # Imported here so the block stays self-contained; the module is cached after first use.
        from torchvision.transforms import InterpolationMode

        # Get band files and reference map path
        band_files, reference_map_path = self.samples[idx]

        # Load and resize bands; each file is closed as soon as its pixels are copied
        band_resize = Resize(self.target_size)
        resized_bands = []
        for band_file in band_files:
            with Image.open(band_file) as band_image:
                band_tensor = to_tensor(band_image)
            resized_bands.append(band_resize(band_tensor).squeeze(0))  # Remove singleton channel dimension
        input_image = torch.stack(resized_bands)

        # Load and resize reference map. The map holds class IDs, so it is resized
        # with nearest-neighbour: interpolating between IDs would create values
        # that are not valid classes.
        label_resize = Resize(self.target_size, interpolation=InterpolationMode.NEAREST)
        with Image.open(reference_map_path) as label_image:
            label_tensor = to_tensor(label_image).squeeze(0)  # Remove channel dimension for label
        label_resized = label_resize(label_tensor.float().unsqueeze(0)).squeeze(0)

        # Apply any additional transforms if provided.
        # NOTE: the transform is called separately on image and label, so a
        # random transform will not be applied consistently to both.
        if self.transform:
            input_image = self.transform(input_image)
            label_resized = self.transform(label_resized)

        return input_image.float(), label_resized


In [6]:
# Root folders: band images and the matching reference maps
input_base_dir = "./SmallEarthNet-S2/"
reference_base_dir = "./Reference_Maps/"

# Base dataset with a 120x120 target size for bands and labels
original_dataset = HyperspectralDataset(input_base_dir, reference_base_dir, target_size=(120, 120))

In [7]:
# Shuffling loader over the base dataset: batches of 4, loaded by 2 workers
dataloader = DataLoader(
    dataset=original_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Look at the first batch only, then stop iterating
for inputs, labels in dataloader:
    print(f"Inputs shape: {inputs.shape}")  # Expected: [batch_size, channels, height, width]
    print(f"Labels shape: {labels.shape}")  # Expected: [batch_size, height, width]
    break


Inputs shape: torch.Size([4, 12, 120, 120])
Labels shape: torch.Size([4, 120, 120])


In [12]:
!rm -r ./expanded_dataset_output

In [13]:
import os
import torch
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import json
from scipy.ndimage import label
import numpy as np
from tqdm import tqdm

class SaveAndExpandHyperspectralDataset:
    """
    Expand a base dataset into one on-disk sample per labelled object.

    Constructing an instance immediately iterates ``base_dataset`` and writes
    every expanded sample below ``output_dir``; there is no separate run step.
    """

    def __init__(self, base_dataset, output_dir):
        """
        Args:
            base_dataset: Sized, iterable dataset yielding ``(bands, reference_map)`` pairs.
            output_dir (str): Folder receiving one sub-folder per expanded sample
                (created if it does not exist).
        """
        self.base_dataset = base_dataset
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # The whole export runs as part of construction.
        self._expand_and_save_dataset()

    def _expand_and_save_dataset(self):
        """Walk the base dataset and save one sample per (class, object) pair."""
        progress = tqdm(total=len(self.base_dataset), desc="Processing dataset")
        with progress:
            for sample_idx, sample in enumerate(self.base_dataset):
                bands, reference_map = sample

                # Per-class object masks and the prompts that belong to them
                class_dict, prompts = process_ref_maps(reference_map)

                for class_value in class_dict:
                    for obj_idx, obj_mask in enumerate(class_dict[class_value]):
                        self._save_sample(
                            sample_idx,
                            class_value,
                            obj_idx,
                            bands,
                            prompts[class_value][obj_idx],
                            obj_mask,
                        )
                progress.update(1)

    def _save_sample(self, sample_idx, class_value, obj_idx, bands, prompt, binary_mask):
        """
        Write one expanded sample into its own folder.

        The folder is named ``sample_<idx>_class_<class>_obj_<obj>`` and holds
        ``bands.pt`` (the band tensor), ``binary_mask.tif`` (the object mask)
        and ``prompt.json`` (the prompt).
        """
        folder_name = f"sample_{sample_idx:06}_class_{class_value}_obj_{obj_idx:03}"
        sample_dir = os.path.join(self.output_dir, folder_name)
        os.makedirs(sample_dir, exist_ok=True)

        # Bands are stored as a single tensor file (not as per-band TIFFs);
        # the same band tensor is written again for every object of a sample.
        torch.save(bands, os.path.join(sample_dir, "bands.pt"))

        # The mask needs a leading channel dimension to be converted to an image.
        mask_image = to_pil_image(binary_mask.unsqueeze(0))
        mask_image.save(os.path.join(sample_dir, "binary_mask.tif"))

        # Prompt goes to disk as indented JSON
        with open(os.path.join(sample_dir, "prompt.json"), "w") as prompt_file:
            json.dump(prompt, prompt_file, indent=4)



def process_ref_maps(mask_tensor, ignore_value=999, rng=None):
    """
    Split a class-label map into connected objects per class and build a
    point prompt for every object.

    Objects are connected components under 8-connectivity (diagonal
    neighbours count as connected).

    Args:
        mask_tensor (torch.Tensor): 2-D label map. Values are cast to int, so
            fractional values are truncated. Must be 2-D because the
            connectivity structure is 3x3.
        ignore_value (int): Label value treated as unlabelled and skipped.
            Defaults to 999.
        rng (numpy.random.Generator, optional): Random source for the
            ``random_point`` prompt. Pass a seeded generator for reproducible
            prompts. When None, NumPy's global random state is used.

    Returns:
        tuple: ``(class_dict, prompts)``, both keyed by the class value (int).
            ``class_dict[c]`` is a list of float32 binary mask tensors, one per
            object. ``prompts[c]`` is a parallel list of dicts with
            ``"centroid"`` (mean [row, col] of the object's pixels, which can
            lie outside a non-convex object) and ``"random_point"`` (the
            [row, col] of one pixel picked from the object).
    """
    # Ensure the mask is in integer format and convert to numpy
    # (detach so that tensors tracked by autograd are accepted as well)
    mask_numpy = mask_tensor.detach().cpu().numpy().astype(int)

    # Get unique class labels in the mask
    unique_values = np.unique(mask_numpy)

    class_dict = {}
    prompts = {}

    # Define the connectivity structure for 8-connectivity
    structure_8 = np.ones((3, 3), dtype=int)

    for each_value in unique_values:
        if each_value == ignore_value:  # Skip unlabelled class
            continue

        class_mask = (mask_numpy == each_value).astype(int)

        # Connected component labelling
        labeled_array, num_features = label(class_mask, structure=structure_8)

        # Extract individual connected component masks
        obj_arr = []
        obj_prompts = []
        for component_id in range(1, num_features + 1):
            obj_mask = (labeled_array == component_id).astype(int)
            obj_arr.append(torch.tensor(obj_mask, dtype=torch.float32))

            # Pixel coordinates of the object, one [row, col] pair per pixel
            coords = np.argwhere(obj_mask)
            centroid = coords.mean(axis=0).tolist()

            # Pick one pixel of the object at random
            if rng is None:
                random_point_idx = np.random.randint(0, len(coords))
            else:
                random_point_idx = int(rng.integers(0, len(coords)))
            random_point = coords[random_point_idx].tolist()

            obj_prompts.append({"centroid": centroid, "random_point": random_point})

        class_dict[int(each_value)] = obj_arr
        prompts[int(each_value)] = obj_prompts

    return class_dict, prompts


# Source folders and the destination of the expanded dataset
input_base_dir = "./SmallEarthNet-S2/"
reference_base_dir = "./Reference_Maps/"
output_dir = "./expanded_dataset_output"

# Base dataset with a 120x120 target size
original_dataset = HyperspectralDataset(input_base_dir, reference_base_dir, target_size=(120, 120))

# Constructing the saver runs the whole export
expanded_saver = SaveAndExpandHyperspectralDataset(base_dataset=original_dataset, output_dir=output_dir)
print(f"Expanded dataset saved to {output_dir}")


Processing dataset: 100%|██████████| 1581/1581 [06:04<00:00,  4.34it/s]

Expanded dataset saved to ./expanded_dataset_output





In [14]:
!zip -r final_dataset.zip expanded_dataset_output

Scanning files ........... ......
  adding: expanded_dataset_output/ (stored 0%)
  adding: expanded_dataset_output/sample_001266_class_242_obj_000/ (stored 0%)
  adding: expanded_dataset_output/sample_001266_class_242_obj_000/prompt.json (deflated 32%)
  adding: expanded_dataset_output/sample_001266_class_242_obj_000/binary_mask.tif (deflated 98%)
  adding: expanded_dataset_output/sample_001266_class_242_obj_000/bands.pt (deflated 55%)
  adding: expanded_dataset_output/sample_000613_class_112_obj_003/ (stored 0%)
  adding: expanded_dataset_output/sample_000613_class_112_obj_003/prompt.json (deflated 39%)
  adding: expanded_dataset_output/sample_000613_class_112_obj_003/binary_mask.tif (deflated 99%)
  adding: expanded_dataset_output/sample_000613_class_112_obj_003/bands.pt (deflated 56%)
  adding: expanded_dataset_output/sample_000454_class_313_obj_002/ (stored 0%)
  adding: expanded_dataset_output/sample_000454_class_313_obj_002/prompt.json (deflated 34%)
  adding: expanded_dataset_ou

In [21]:
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from PIL import Image
import json
import warnings

class HyperspectralExpandedDataset(Dataset):
    """
    Dataset over an expanded export: one folder per sample, each holding
    ``bands.pt``, ``binary_mask.tif`` and ``prompt.json``.
    """

    def __init__(self, root_dir):
        """
        Args:
            root_dir (str): Root directory containing the expanded dataset.
        """
        self.root_dir = root_dir
        self.samples = self._load_samples()

    def _load_samples(self):
        """
        Scans the directory structure to find all saved samples.

        Sample folders are visited in sorted order so that an index always
        refers to the same sample, whatever order the file system lists them
        in. Folders that lack any of the three expected files are skipped.

        Returns:
            list: List of dictionaries containing file paths for each sample
            (keys ``"bands"``, ``"mask"``, ``"prompt"``).
        """
        samples = []
        for sample_name in sorted(os.listdir(self.root_dir)):
            sample_path = os.path.join(self.root_dir, sample_name)
            if not os.path.isdir(sample_path):
                continue

            # Collect file paths for bands, binary mask, and prompt
            sample_files = {
                "bands": os.path.join(sample_path, "bands.pt"),
                "mask": os.path.join(sample_path, "binary_mask.tif"),
                "prompt": os.path.join(sample_path, "prompt.json"),
            }

            if all(os.path.isfile(path) for path in sample_files.values()):
                samples.append(sample_files)
        return samples

    def __len__(self):
        """Return the number of complete samples found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Loads a sample.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (prompt, bands, binary_mask)
        """
        sample = self.samples[idx]

        # Load bands tensor. weights_only=True restricts unpickling to tensors
        # and plain containers, which is all this file is expected to hold, and
        # passing it explicitly avoids torch.load's FutureWarning about the
        # changing default.
        bands = torch.load(sample["bands"], weights_only=True)

        # Load binary mask as tensor; the file is closed once its pixels are copied
        with Image.open(sample["mask"]) as mask_image:
            binary_mask = to_tensor(mask_image).squeeze(0)  # Remove channel dimension

        # Load prompt as a dictionary
        with open(sample["prompt"], "r") as f:
            prompt = json.load(f)

        return prompt, bands, binary_mask


In [22]:
# Folder holding the exported per-object samples
root_dir = "./expanded_dataset_output"

# Index every complete sample folder under the root
dataset = HyperspectralExpandedDataset(root_dir)

# Sanity check: show the first sample, or say that nothing was found
if len(dataset) == 0:
    print("Dataset is empty!")
else:
    prompt, bands, binary_mask = dataset[0]

    print("Prompt:", prompt)
    print("Bands shape:", bands.shape)  # (C, H, W)
    print("Binary mask shape:", binary_mask.shape)  # (H, W)

Prompt: {'centroid': [41.41887417218543, 23.48112582781457], 'random_point': [24, 19]}
Bands shape: torch.Size([12, 120, 120])
Binary mask shape: torch.Size([120, 120])
