In [1]:
import zarr
import numpy as np
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
def print_data_info(data, indent=''):
    if isinstance(data, dict):
        for key, value in data.items():
            print(f"{indent}{key}:")
            if isinstance(value, dict):
                if 'data' in value and isinstance(value['data'], torch.Tensor):
                    print(f"{indent}  Data shape: {value['data'].shape}")
                    print(f"{indent}  Data type: {value['data'].dtype}")
                if 'metadata' in value:
                    print(f"{indent}  Metadata:")
                    for meta_key, meta_value in value['metadata'].items():
                        print(f"{indent}    {meta_key}: {meta_value}")
            else:
                print_data_info(value, indent + '  ')

In [3]:
# define a custom dataset class
def read_zarr_data(
        parent_dir,
        gene_name,
        barcode_name,
        channels = [0, 1, 2, 3],
        cell_cycle_stages = "interphase",
        transform = None,
        crop_size = None,
        crop_position = None, 
        print_tree = False):
    
    # Look for zattrs file in the path
    def read_zattrs(path):
        zattrs = {}
        zattrs_path = os.path.join(path, ".zattrs")
        if os.path.exists(zattrs_path):
            with open(zattrs_path, "r") as f:
                zattrs = json.load(f)
        return zattrs
    
    # Check if parent_dir is a valid directory
    if not os.path.isdir(parent_dir):
        raise ValueError(f"Directory {parent_dir} does not exist")
    
    # Check if gene_name exists
    zarr_file_gene = os.path.join(parent_dir, gene_name + ".zarr")
    if not os.path.isdir(zarr_file_gene):
        raise ValueError(f"Gene {zarr_file_gene} does not exist")
    
    # Load the Zarr file
    zarr_data = zarr.open(zarr_file_gene, mode='r')

    # Print the tree structure of the Zarr file
    if print_tree:
        print(zarr_data.tree())
    
    # Check if barcode_name exists
    zarr_file_barcode = os.path.join(parent_dir, gene_name + ".zarr", barcode_name)
    if not os.path.isdir(zarr_file_barcode):
        raise ValueError(f"Barcode {zarr_file_barcode} does not exist")

    # Check if stage exists
    zarr_file_stage = os.path.join(parent_dir, gene_name + ".zarr", barcode_name, cell_cycle_stages)
    if not os.path.isdir(zarr_file_stage):
        raise ValueError(f"Stage {zarr_file_stage} does not exist")
    
    # Check for zattrs file at zarr_file_stage
    zattrs = read_zattrs(zarr_file_stage)

    # Load images (selecting channels), cells, and nuclei
    images = zarr_data[barcode_name][cell_cycle_stages]['images'][:, channels, :, :]
    cells = zarr_data[barcode_name][cell_cycle_stages]['cells']
    nuclei = zarr_data[barcode_name][cell_cycle_stages]['nuclei']

    # Check if the number of images, cells, and nuclei are the same
    if len(images) != len(cells) or len(images) != len(nuclei):
        raise ValueError("Number of images, cells, and nuclei are not the same")

    # Conver the images, cells, and nuclei to torch tensors
    images = torch.tensor(images)
    cells = torch.tensor(cells)
    nuclei = torch.tensor(nuclei)

    # Print the shape of the images, cells, and nuclei
    print("Images shape:", images.shape)
    print("Cells shape:", cells.shape)
    print("Nuclei shape:", nuclei.shape)

    # Apply the transform to the images
    if transform == "masks":
        cell_images = images * cells
        nuclei_images = images * nuclei

    return cell_images, nuclei_images


    
    

    

In [4]:
# Usage example:
parent_dir = '/mnt/efs/dlmbl/S-md/'
gene_name = 'AAAS'  # Replace with the gene you want to analyze
barcode_name = 'ATATGAGCACAATAACGAGC'  # Replace with the barcode you want, or set to None for all
channels = [0, 1, 2]  # Replace with the channels you want, or set to None for all
cell_cycle_stages = 'mitotic'  # Replace with the stages you want, or set to None for all
transform = "masks"  # Replace with the transform you want to apply
crop_size = None  # Replace with the crop size you want
crop_position = None  # Replace with the crop position you want
print_tree = True  # Set to True to print the tree structure of the Zarr file

dataset = read_zarr_data(parent_dir, gene_name, barcode_name, channels, cell_cycle_stages, transform, crop_size, crop_position, print_tree)

/
 ├── ATATGAGCACAATAACGAGC
 │   ├── interphase
 │   │   ├── cells (671, 256, 256) bool
 │   │   ├── images (671, 4, 256, 256) uint16
 │   │   └── nuclei (671, 256, 256) bool
 │   └── mitotic
 │       ├── cells (16, 256, 256) bool
 │       ├── images (16, 4, 256, 256) uint16
 │       └── nuclei (16, 256, 256) bool
 ├── CCACACCAACAAGTTTGCAG
 │   ├── interphase
 │   │   ├── cells (1403, 256, 256) bool
 │   │   ├── images (1403, 4, 256, 256) uint16
 │   │   └── nuclei (1403, 256, 256) bool
 │   └── mitotic
 │       ├── cells (38, 256, 256) bool
 │       ├── images (38, 4, 256, 256) uint16
 │       └── nuclei (38, 256, 256) bool
 ├── CCCAGGGTGAGACAGCACTT
 │   ├── interphase
 │   │   ├── cells (921, 256, 256) bool
 │   │   ├── images (921, 4, 256, 256) uint16
 │   │   └── nuclei (921, 256, 256) bool
 │   └── mitotic
 │       ├── cells (15, 256, 256) bool
 │       ├── images (15, 4, 256, 256) uint16
 │       └── nuclei (15, 256, 256) bool
 └── TGGGCATTGACAAAGTACAG
     ├── interphase
     │ 

  cells = torch.tensor(cells)
  cells = torch.tensor(cells)
  nuclei = torch.tensor(nuclei)


Images shape: torch.Size([16, 3, 256, 256])
Cells shape: torch.Size([16, 256, 256])
Nuclei shape: torch.Size([16, 256, 256])


RuntimeError: The size of tensor a (3) must match the size of tensor b (16) at non-singleton dimension 1

In [67]:
dataset[2]

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [

In [None]:
class OPSZarrDataset(Dataset):
    def __init__(self, zarr_file, sequence, phase='interphase', transform=None):
        self.zarr_home = zarr_file
        self.zarr_data = zarr.open(zarr_file, mode='r')
        self.sequence = sequence
        self.phase = phase
        self.transform = transform
        
        # Assuming 'images', 'cells', and 'nuclei' are the keys in your Zarr structure
        self.images = self.zarr_data[sequence][phase]['images']
        self.cells = self.zarr_data[sequence][phase]['cells']
        self.nuclei = self.zarr_data[sequence][phase]['nuclei']
        
    def __len__(self):
        return self.images.shape[0]
    
    def __getitem__(self, idx):
        image = self.images[idx]
        cell = self.cells[idx]
        nucleus = self.nuclei[idx]
        
        # Convert to torch tensors
        image = torch.from_numpy(image).float()
        cell = torch.from_numpy(cell).float()
        nucleus = torch.from_numpy(nucleus).float()
        
        if self.transform:
            image = self.transform(image)
            cell = self.transform(cell)
            nucleus = self.transform(nucleus)
        
        return {'image': image, 'cell': cell, 'nucleus': nucleus}

# Usage example:
zarr_file = '/mnt/efs/dlmbl/S-md/AAAS.zarr'
sequence = 'ATATGAGCACAATAACGAGC'  # Replace with an actual sequence from your data

# Create dataset
dataset = ZarrDataset(zarr_file, sequence)

# Create dataloader
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Example of iterating through the data
for batch in dataloader:
    images = batch['image']
    cells = batch['cell']
    nuclei = batch['nucleus']
    
    print(f"Batch shapes: Images {images.shape}, Cells {cells.shape}, Nuclei {nuclei.shape}")
    
    # Your training or processing code here
    break  # Remove this to process all batches