In [7]:
import os
from zipfile import ZipFile
import polars as pl
import torch
from torchvision.io import decode_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.multiprocessing as mp
import pandas as pd

In [8]:
# Distinct study IDs listed in config.csv; used below to subset the metadata.
studies = (
    pl.read_csv('./config.csv')
    .get_column('study')
    .unique()
    .to_list()
)

In [9]:
# How many studies config.csv selects.
n_studies = len(studies)
n_studies

62

In [3]:
# Read these metadata columns as plain strings instead of letting polars infer a
# dtype (presumably because they hold ID-like or mixed values — confirm).
OVERRIDES = {
    column: pl.String
    for column in (
        'experiment.well',
        'experiment.plate',
        'microscopy.fov',
        'microscopy.magnification',
        'geometry.depth',
        'geometry.z_slice',
    )
}

In [25]:
# Lift polars' display truncation (rows, columns, table width, list-cell items)
# by passing -1 to each setter. Written as a loop so the cell does not echo the
# meaningless `polars.config.Config` return value.
# Caution: displaying a large frame (e.g. `meta`) will now print every row.
for _set_display_option in (
    pl.Config.set_tbl_rows,
    pl.Config.set_tbl_width_chars,
    pl.Config.set_tbl_cols,
    pl.Config.set_fmt_table_cell_list_len,
):
    _set_display_option(-1)

polars.config.Config

In [7]:
# Metadata table for the 75-dataset sample; OVERRIDES forces the listed columns
# to be read as strings.
meta = pl.read_csv(
    '~/dataset/sampling/75ds_small_meta.csv',
    schema_overrides=OVERRIDES,
)
meta.shape

(1541257, 22)

In [8]:
# Restrict the metadata to the studies listed in config.csv.
is_selected_study = pl.col('experiment.study').is_in(studies)
substudies = meta.filter(is_selected_study)
substudies.shape

(1430983, 22)

In [36]:
# Spot-check: the first ten storage paths recorded for study idr0089.
idr0089_paths = substudies.filter(pl.col('experiment.study') == 'idr0089').get_column('storage.path')
idr0089_paths.head(10).to_list()

['75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep1_PoorProbeChan1_Cry2Chan2_DAPIChan3_13_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep1_PoorProbeChan1_Cry2Chan2_DAPIChan3_14_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep1_PoorProbeChan1_Cry2Chan2_DAPIChan3_15_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep1_PoorProbeChan1_Cry2Chan2_DAPIChan3_2_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep1_PoorProbeChan1_Cry2Chan2_DAPIChan3_8_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep2_PoorProbeChan1_Cry2Chan2_DAPIChan3_1_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-converted/AC16_T18d24h_Rep2_PoorProbeChan1_Cry2Chan2_DAPIChan3_10_series-0_z-7_t-0_channel-0.png',
 '75ds_small/idr0089/idr0089-plate_1A-conver

In [37]:
def list_output_paths(base_path: str) -> list:
    """List every file under ``base_path`` as a relative path with image naming.

    Walks the whole tree (all subdirectories) and reports each file relative to
    ``base_path``, with '.safetensors' replaced by '.png'. That makes the result
    directly comparable with the metadata ``storage.path`` column.

    Args:
        base_path: root directory to scan; works with or without a trailing separator.

    Returns:
        List of relative paths, in ``os.walk`` order.
    """
    found = []
    for dirpath, _, filenames in os.walk(base_path):
        for filename in filenames:
            # relpath (not str.replace) avoids a leading '/' when base_path has no
            # trailing separator, and never touches inner path segments.
            rel_path = os.path.relpath(os.path.join(dirpath, filename), base_path)
            found.append(rel_path.replace('.safetensors', '.png'))
    return found


base_path = os.path.expanduser('~/scratch/data/75ds_small_segmentations/')
# Accumulate across ALL directories (previously the list was reset per directory,
# so only the last directory's files survived).
paths = list_output_paths(base_path)

In [5]:
class UnZippedImageArchive(Dataset):
    """Map-style dataset yielding the channel images of one multi-channel acquisition.

    Deprecated: this will no longer be used. Remove when unzipped support is
    added to the IterableImageArchive.

    Rows of the metadata CSV are restricted to the studies listed in config.csv
    and grouped by ``imaging.multi_channel_id``; each group is one item. For an
    item, every channel PNG is decoded (in ``imaging.channel`` order), and the
    segmentation settings (``seg_cfg``, ``diameter``) are looked up in config.csv
    by study and by the sorted, comma-joined channel types.

    Args:
        output_path: directory of existing segmentation outputs whose relative
            '.safetensors' paths mirror the metadata ``storage.path`` values
            ('.png'). ``~`` is expanded.
        overwrite: when False (default), only rows that already have a file under
            ``output_path`` are kept; when True the directory is not scanned and all
            rows of the selected studies are kept.
        config_path: CSV with columns 'study', 'config', 'seg_cfg', 'diameter'.
        meta_path: metadata CSV (``~`` is expanded).
        imgs_base: directory that ``storage.path`` values are relative to.
    """

    def __init__(self, output_path: str, overwrite: bool = False,
                 config_path: str = './config.csv',
                 meta_path: str = '~/dataset/sampling/75ds_small_meta.csv',
                 imgs_base: str = '/scr/data/75ds_small') -> None:
        super().__init__()
        self.overwrite = overwrite
        self.output_path = os.path.expanduser(output_path)
        configs = pl.read_csv(config_path)
        # Studies listed in config.csv; only these are loaded.
        self.ds10 = configs['study'].unique().to_list()
        # Kept as pandas because __getitem__ filters it with boolean masks per item.
        self.configs = configs.to_pandas()
        self.imgs_base = os.path.expanduser(imgs_base)
        self.meta_path = os.path.expanduser(meta_path)
        meta = self.get_dataset()
        # One pandas DataFrame per multi-channel acquisition.
        self.data = [group for _, group in meta.to_pandas().groupby('imaging.multi_channel_id')]
        # Derive the length from the groups actually built: pandas groupby drops
        # null keys while polars' unique() counts null as a value, so counting
        # unique ids could exceed len(self.data) and cause IndexError.
        self.size = len(self.data)

    def __len__(self):
        return self.size

    def get_dataset(self):
        """Load the metadata rows for the configured studies.

        Returns:
            polars DataFrame sorted by ``imaging.multi_channel_id`` and restricted to
            studies in ``self.ds10``. When ``overwrite`` is False, further restricted to
            rows whose ``storage.path`` has a matching file under ``output_path``.
        """
        meta = pl.read_csv(self.meta_path, schema_overrides=OVERRIDES)
        meta = meta.sort('imaging.multi_channel_id').filter(pl.col('experiment.study').is_in(self.ds10))
        # Only scan/filter when not overwriting (previously overwrite=True crashed
        # with UnboundLocalError because `paths` was never bound).
        if not self.overwrite:
            existing = set()
            for dirpath, _, filenames in os.walk(self.output_path):
                for filename in filenames:
                    # relpath avoids a leading '/' when output_path lacks a trailing separator.
                    rel_path = os.path.relpath(os.path.join(dirpath, filename), self.output_path)
                    existing.add(rel_path.replace('.safetensors', '.png'))
            # NOTE(review): this KEEPS rows that already have an output. If the intent of
            # overwrite=False is to skip already-segmented images, use ~is_in(...) instead.
            meta = meta.filter(pl.col('storage.path').is_in(existing))
        print(meta.shape)
        return meta

    def __getitem__(self, idx):
        """Return the images and segmentation settings for the ``idx``-th acquisition.

        Returns:
            ``(images, image_data)``. ``images`` is a list with the first plane of each
            decoded channel PNG. ``image_data`` always holds ``'id'`` (the
            ``imaging.multi_channel_id``) and ``'study'`` (despite the key name, the list
            of ``storage.path`` values, one per channel). Depending on ``seg_cfg``:
              - 'classical': nothing else is added;
              - 'skip': ``'config': 'skip'``;
              - otherwise: ``'config'`` ([1, 2] for two selected channels, else [0, 0]),
                ``'diameter'``, and ``'axis': 1`` for two channels; ``images`` is narrowed
                to the selected channels (1-based indices) unless the first index is 0.

        Raises:
            IndexError: if config.csv has no row for this study / channel-type combination.
        """
        group: pd.DataFrame = self.data[idx].sort_values('imaging.channel')
        sample_id = group['imaging.multi_channel_id'].iloc[0]
        study = group['experiment.study'].iloc[0]
        storage_paths = group['storage.path'].to_list()

        # config.csv keys channel combinations by their sorted, comma-joined types.
        channel_type = ','.join(group.sort_values('imaging.channel_type')['imaging.channel_type'].to_list())
        settings = self.configs[(self.configs['study'] == study) & (self.configs['config'] == channel_type)]
        if settings.empty:
            raise IndexError(f'no config.csv entry for study={study!r}, config={channel_type!r}')
        col_eq = settings['seg_cfg'].iloc[0]
        diameter = settings['diameter'].iloc[0]

        # decode_image returns a channels-first tensor (per torchvision docs); keep plane 0.
        images = [decode_image(os.path.join(self.imgs_base, path))[0] for path in storage_paths]

        image_data = {'id': sample_id, 'study': list(storage_paths)}
        if col_eq == 'classical':
            return images, image_data
        if col_eq == 'skip':
            image_data['config'] = col_eq
            return images, image_data
        if col_eq == 'nucleus':
            # NOTE(review): used below as a 1-based index; confirm imaging.channel is 1-based,
            # otherwise a 0-based nucleus channel selects the wrong image (or all images for 0).
            nucleus_rows = group[group['imaging.channel_type'] == 'nucleus']
            channels = [int(nucleus_rows['imaging.channel'].item())]
        else:
            # Comma-separated channel indices, e.g. '1,2'; str() tolerates a numeric cell.
            channels = [int(part) for part in str(col_eq).split(',')]

        if len(channels) == 2:
            image_data['axis'] = 1
            channels_config = [1, 2]
        else:
            channels_config = [0, 0]
        # presumably consumed as cellpose's `channels` argument — confirm with the caller.
        image_data['config'] = channels_config
        image_data['diameter'] = diameter

        if channels[0] != 0:
            # Channel indices from seg_cfg are 1-based.
            images = [images[channel - 1] for channel in channels]
        return images, image_data
    
# Dataset over acquisitions whose images already have a file under the
# segmentation directory (overwrite defaults to False).
dataset = UnZippedImageArchive(
    output_path=os.path.expanduser('/scr/data/75ds_small_segmentations/'),
)
len(dataset)

(680483, 22)


262186