In [11]:
from typing import Tuple, List

import torch
from intern import array
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

# Custom BossDB Dataset classes for Pytorch DataLoaders

In this notebook we provide example custom `Dataset` classes for use with PyTorch `DataLoader`s. See https://pytorch.org/tutorials/beginner/basics/data_tutorial.html and https://pytorch.org/tutorials/beginner/data_loading_tutorial.html for additional information. These classes can easily be adapted to your specific needs.


### Common Definitions:

**boss_uri**: the BossDB path to the project, channel, and experiment you would like to use, e.g. `'bossdb://kuan_phelps2020/drosophila_brain_120nm/drBrain_120nm_rec'`

**boss_config**: If you are accessing a private dataset, pass your Boss config information (see https://github.com/jhuapl-boss/intern/wiki/Boss-Setup-Tutorial). If you are accessing a public dataset, you do not need a config and can pass `None` (the default in most of the classes below).

**centroid_list_zyx**: A list of coordinate points you would like the images/volumes/arrays to be centered on

**px_radius_\***: The number of pixels to include on either side of the centroid (i.e. the radius) along each axis, given in the order named by the suffix (`yx` or `zyx`). For example, to get 128x128-pixel images, set `px_radius_yx` to `[64, 64]`.

**transform**: The transform or group of transforms you want to apply to the data. (see https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html)

### Dataset class for when you need single image slices from the data 

In [12]:
class BossDBSliceDataset(Dataset):
    """Serve single z-slice image cutouts from a BossDB channel.

    Each item is the region of one z-slice around an entry of
    ``centroid_list_zyx``, extending ``px_radius_yx`` pixels to either side
    of the centroid in y and x, and passed through ``transform`` when one
    is set.
    """

    def __init__(
        self,
        boss_uri: str,
        centroid_list_zyx: List[Tuple[int, int, int]],
        px_radius_yx: Tuple[int, int],
        boss_config: dict = None,
        transform=ToTensor()
    ):
        """Open the remote array and remember the sampling parameters.

        Args:
            boss_uri: BossDB URI of the channel to read from.
            centroid_list_zyx: (z, y, x) points the cutouts are centered on.
            px_radius_yx: (y, x) radius, in pixels, of each cutout.
            boss_config: Optional Boss configuration forwarded to ``array``.
            transform: Callable applied to every cutout; a falsy value
                disables it.
        """
        self.config = boss_config
        self.array = array(boss_uri, boss_config=boss_config)
        self.centroid_list = centroid_list_zyx
        self.px_radius_y, self.px_radius_x = px_radius_yx
        self.transform = transform

    def __len__(self):
        """Return the number of centroids, i.e. the number of items."""
        return len(self.centroid_list)

    def __getitem__(self, key):
        """Return the (optionally transformed) cutout for centroid ``key``."""
        z, y, x = self.centroid_list[key]
        half_h, half_w = self.px_radius_y, self.px_radius_x

        # One z-slice, 2 * radius pixels along y and x.
        window = (
            slice(z, z + 1),
            slice(y - half_h, y + half_h),
            slice(x - half_w, x + half_w),
        )
        cutout = self.array[window]

        if not self.transform:
            return cutout
        return self.transform(cutout)

### Dataset class for when you need single image slices and corresponding segmentation masks from the data 

In [13]:
class BossDBSliceDataset(Dataset):
    """Serve paired single z-slice image and segmentation-mask cutouts.

    Each item is an ``(image, mask)`` tuple read from the same region of
    two BossDB channels: one z-slice around an entry of
    ``centroid_list_zyx``, extending ``px_radius_yx`` pixels to either side
    of the centroid in y and x. The mask is always returned as an int64
    ``torch.Tensor``.

    Note: this class has the same name as the image-only slice dataset
    defined earlier in this notebook, so running both cells leaves only
    this definition bound to ``BossDBSliceDataset``.
    """

    def __init__(
        self,
        image_boss_uri: str,
        mask_boss_uri: str,
        boss_config: dict,
        centroid_list_zyx: List[Tuple[int, int, int]],
        px_radius_yx: Tuple[int, int],
        image_transform=ToTensor(),
        mask_transform=None
    ):
        """Open both remote arrays and remember the sampling parameters.

        Args:
            image_boss_uri: BossDB URI of the image channel.
            mask_boss_uri: BossDB URI of the segmentation channel.
            boss_config: Boss configuration forwarded to ``array`` for both
                channels. It has no default here; pass ``None`` explicitly
                when no configuration is needed.
            centroid_list_zyx: (z, y, x) points the cutouts are centered on.
            px_radius_yx: (y, x) radius, in pixels, of each cutout.
            image_transform: Callable applied to every image cutout; a
                falsy value disables it.
            mask_transform: Callable applied to every mask cutout before
                the int64 conversion; a falsy value disables it.
        """
        self.config = boss_config
        self.image_array = array(image_boss_uri, boss_config=boss_config)
        self.mask_array = array(mask_boss_uri, boss_config=boss_config)
        self.centroid_list = centroid_list_zyx
        rad_y, rad_x = px_radius_yx
        self.px_radius_y = rad_y
        self.px_radius_x = rad_x
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __getitem__(self, key):
        """Return ``(image, mask)`` for centroid ``key``.

        Returns:
            A tuple of the (optionally transformed) image cutout and the
            mask cutout as an int64 ``torch.Tensor``.
        """
        z, y, x = self.centroid_list[key]

        # Build the window once so image and mask are always read from the
        # exact same region.
        window = (
            slice(z, z + 1),
            slice(y - self.px_radius_y, y + self.px_radius_y),
            slice(x - self.px_radius_x, x + self.px_radius_x),
        )
        image_array = self.image_array[window]
        mask_array = self.mask_array[window]

        if self.image_transform:
            image_array = self.image_transform(image_array)

        if self.mask_transform:
            mask_array = self.mask_transform(mask_array)

        # A mask_transform may already have produced a tensor, which has no
        # ``astype``; only array-like masks go through the numpy cast.
        if isinstance(mask_array, torch.Tensor):
            mask_array = mask_array.long()
        else:
            mask_array = torch.from_numpy(mask_array.astype('int64'))

        return image_array, mask_array

    def __len__(self):
        """Return the number of centroids, i.e. the number of items."""
        return len(self.centroid_list)

### Dataset class for when you need 3D image volumes from the data 

In [14]:
class BossDBDataset(Dataset):
    """Serve 3D image sub-volumes from a BossDB channel.

    Each item is the region around an entry of ``centroid_list_zyx``,
    extending ``px_radius_zyx`` pixels to either side of the centroid along
    z, y and x, and passed through ``transform`` when one is set.

    NOTE(review): torchvision documents ``ToTensor`` as treating a 3-D
    ndarray as H x W x C and returning C x H x W, so the default transform
    presumably reorders the axes of a (z, y, x) volume -- confirm this is
    intended before relying on the default.
    """

    def __init__(
        self,
        boss_uri: str,
        centroid_list_zyx: List[Tuple[int, int, int]],
        px_radius_zyx: Tuple[int, int, int],
        transform=ToTensor(),
        boss_config: dict = None,
    ):
        """Open the remote array and remember the sampling parameters.

        Args:
            boss_uri: BossDB URI of the channel to read from.
            centroid_list_zyx: (z, y, x) points the volumes are centered on.
            px_radius_zyx: (z, y, x) radius, in pixels, of each volume.
            transform: Callable applied to every volume; a falsy value
                disables it.
            boss_config: Optional Boss configuration forwarded to ``array``.
        """
        self.config = boss_config
        self.array = array(boss_uri, boss_config=boss_config)
        self.centroid_list = centroid_list_zyx
        self.px_radius_z, self.px_radius_y, self.px_radius_x = px_radius_zyx
        self.transform = transform

    def __len__(self):
        """Return the number of centroids, i.e. the number of items."""
        return len(self.centroid_list)

    def __getitem__(self, key):
        """Return the (optionally transformed) volume for centroid ``key``."""
        z, y, x = self.centroid_list[key]
        half_d, half_h, half_w = (
            self.px_radius_z,
            self.px_radius_y,
            self.px_radius_x,
        )

        # 2 * radius pixels along each of z, y and x.
        window = (
            slice(z - half_d, z + half_d),
            slice(y - half_h, y + half_h),
            slice(x - half_w, x + half_w),
        )
        volume = self.array[window]

        if not self.transform:
            return volume
        return self.transform(volume)
    

### Dataset class for when you need 3D image volumes and corresponding segmentation masks from the data 

In [15]:
class BossDBDataset(Dataset):
    """Serve paired 3D image and segmentation-mask sub-volumes.

    Each item is an ``(image, mask)`` tuple read from the same region of
    two BossDB channels: the region around an entry of
    ``centroid_list_zyx``, extending ``px_radius_zyx`` pixels to either
    side of the centroid along z, y and x. The mask is always returned as
    an int64 ``torch.Tensor``.

    Note: this class has the same name as the image-only volume dataset
    defined earlier in this notebook, so running both cells leaves only
    this definition bound to ``BossDBDataset``.

    NOTE(review): torchvision documents ``ToTensor`` as treating a 3-D
    ndarray as H x W x C and returning C x H x W, so with the default
    ``image_transform`` the image axes presumably no longer line up with
    the (untransformed) mask axes -- confirm before relying on the default.
    """

    def __init__(
        self,
        image_boss_uri: str,
        mask_boss_uri: str,
        centroid_list_zyx: List[Tuple[int, int, int]],
        px_radius_zyx: Tuple[int, int, int],
        image_transform=ToTensor(),
        mask_transform=None,
        boss_config: dict = None,
    ):
        """Open both remote arrays and remember the sampling parameters.

        Args:
            image_boss_uri: BossDB URI of the image channel.
            mask_boss_uri: BossDB URI of the segmentation channel.
            centroid_list_zyx: (z, y, x) points the volumes are centered on.
            px_radius_zyx: (z, y, x) radius, in pixels, of each volume.
            image_transform: Callable applied to every image volume; a
                falsy value disables it.
            mask_transform: Callable applied to every mask volume before
                the int64 conversion; a falsy value disables it.
            boss_config: Optional Boss configuration forwarded to ``array``
                for both channels.
        """
        self.config = boss_config
        self.image_array = array(image_boss_uri, boss_config=boss_config)
        self.mask_array = array(mask_boss_uri, boss_config=boss_config)
        self.centroid_list = centroid_list_zyx
        rad_z, rad_y, rad_x = px_radius_zyx
        self.px_radius_z = rad_z
        self.px_radius_y = rad_y
        self.px_radius_x = rad_x
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __getitem__(self, key):
        """Return ``(image, mask)`` for centroid ``key``.

        Returns:
            A tuple of the (optionally transformed) image volume and the
            mask volume as an int64 ``torch.Tensor``.
        """
        z, y, x = self.centroid_list[key]

        # Build the window once so image and mask are always read from the
        # exact same region.
        window = (
            slice(z - self.px_radius_z, z + self.px_radius_z),
            slice(y - self.px_radius_y, y + self.px_radius_y),
            slice(x - self.px_radius_x, x + self.px_radius_x),
        )
        image_array = self.image_array[window]
        mask_array = self.mask_array[window]

        if self.image_transform:
            # The attribute set in __init__ is ``image_transform``; there is
            # no ``self.transform`` on this class.
            image_array = self.image_transform(image_array)

        if self.mask_transform:
            mask_array = self.mask_transform(mask_array)

        # A mask_transform may already have produced a tensor, which has no
        # ``astype``; only array-like masks go through the numpy cast.
        if isinstance(mask_array, torch.Tensor):
            mask_array = mask_array.long()
        else:
            mask_array = torch.from_numpy(mask_array.astype('int64'))

        return image_array, mask_array

    def __len__(self):
        """Return the number of centroids, i.e. the number of items."""
        return len(self.centroid_list)