In [1]:
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
import numpy as np
import cv2
from torch.utils.data import DataLoader, Dataset
import os
from statistics import median, mean, mode, stdev


In [None]:
class PalmCanopyDataset(Dataset):
    """Dataset pairing palm-canopy images with their binary segmentation masks.

    Each item is an ``(image, target)`` tuple. ``target`` is a dict with:

    * ``"masks"``: ``uint8`` tensor of shape ``(1, H, W)`` holding one binary
      instance mask (1 = canopy, 0 = background).
    * ``"labels"``: ``int64`` tensor of shape ``(1,)`` holding the single
      class id ``1``.

    NOTE(review): torchvision's Mask R-CNN training targets also expect a
    ``"boxes"`` entry, which this dataset does not produce -- confirm that the
    training code adds it.

    Args:
        image_paths: Sequence of image file paths, indexed in parallel with
            ``mask_paths``.
        mask_paths: Sequence of mask file paths; each is read as 8-bit
            grayscale.
        transforms: Optional callable applied to the loaded image only (the
            mask is not transformed).
    """

    def __init__(self, image_paths, mask_paths, transforms=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms

    @staticmethod
    def _build_target(mask):
        """Turn an 8-bit grayscale mask array into the target dict.

        Pixels are binarized at the midpoint of the 0-255 range (values
        above 127 are foreground). Dividing by 255 and casting to ``uint8``
        would truncate instead of round, so only pixels that are exactly 255
        would survive and near-white pixels would silently become background.
        """
        foreground = torch.as_tensor(mask) > 127
        return {
            # unsqueeze(0) adds the leading instance axis: (H, W) -> (1, H, W).
            "masks": foreground.to(torch.uint8).unsqueeze(0),
            # Only one class: palm canopy.
            "labels": torch.ones((1,), dtype=torch.int64),
        }

    def __getitem__(self, idx):
        """Load the image/mask pair at ``idx``.

        Returns:
            ``(image, target)`` where ``image`` is the array returned by
            ``cv2.imread`` (or the output of ``transforms`` when one is set)
            and ``target`` is the dict described in the class docstring.

        Raises:
            OSError: If the image or the mask file cannot be read.
        """
        image_path = self.image_paths[idx]
        # IMREAD_UNCHANGED keeps every channel as stored on disk
        # (the images are assumed to be 4-channel multispectral).
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {image_path!r}")

        mask_path = self.mask_paths[idx]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask file: {mask_path!r}")

        target = self._build_target(mask)

        if self.transforms:
            image = self.transforms(image)

        return image, target

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)
