# dataset sketch (untested)

In [1]:
import os
try:
    # The ABC aliases were removed from `collections` in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # legacy interpreters without collections.abc
    from collections import Iterable

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from torch.utils.data.dataloader import default_collate as collate

In [2]:
def load_greyscale(img_path, *args, **kwargs):
    """Open the image file at ``img_path`` and return it converted to mode 'L'.

    Args:
        img_path: path of the image file; it is opened in binary mode.
        *args, **kwargs: accepted but not used.

    Returns:
        The result of ``Image.open(...).convert('L')``.
    """
    # NOTE(review): `Image` is not imported in the visible source -- presumably
    # PIL.Image; confirm it is in scope before this function is called.
    with open(img_path, 'rb') as handle:
        # Convert while the file is still open, then hand back the converted image.
        return Image.open(handle).convert('L')

In [3]:
class dataframeDatasetHPA(Dataset):
    """dataframe Dataset class for HPA data.

    Each sample is a dict holding the key 'image' (the per-channel transformed
    images combined with ``torch.stack``) plus one key per metadata column.
    """

    def __init__(self, df,
                 root_dir='/root/aics/modeling/data/HPAscrape',
                 image_cols=('nucleus', 'microtublues', 'antibody'),
                 image_transform=None,
                 metadata_cols=('sequence', 'antibody', 'cellLine')):
        """
        Args:
            df (pandas.DataFrame): dataframe containing the image locations (relative to root_dir) and target data
            root_dir (string): full path to the directory containing all the images
            image_cols (tuple of strings): column names in dataframe containing the paths to the single channel
                                           images to be used as input data channels; channels will be stacked in order
            image_transform (callable): torchvision transform to be applied on a sample;
                                        None (the default) means transforms.Compose([transforms.ToTensor()])
            metadata_cols (tuple of strings): columns whose entries will be returned for each data point;
                                              col names will be used as dict keys in returned sample
        """
        # NOTE(review): the default image column 'microtublues' looks like a typo for
        # 'microtubules'; left unchanged because it must match the caller's dataframe.
        if image_transform is None:
            # Built per instance instead of once at class-definition time.
            image_transform = transforms.Compose([transforms.ToTensor()])

        # Explicit record of the constructor options (everything except df).
        # Column collections are copied into tuples so that later mutation of the
        # caller's object cannot change which columns are read.
        self.opts = {
            'root_dir': root_dir,
            'image_cols': tuple(image_cols),
            'image_transform': image_transform,
            'metadata_cols': tuple(metadata_cols),
        }

        # A fresh 0..n-1 index keeps row positions independent of the caller's index.
        self.df = df.reset_index(drop=True)
        self._image_transform = image_transform

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.df)

    def _get_item(self, idx):
        """Build the sample dict for a single integer row position.

        Negative positions count from the end. An out-of-range position raises
        IndexError (from ``.iloc``), which is what lets ``for sample in dataset``
        terminate instead of failing with a KeyError.
        """
        if getattr(idx, 'ndim', None) == 0:
            # Zero-dimensional tensor / numpy scalar -> plain Python number.
            idx = idx.item()

        root_dir = self.opts['root_dir']
        image_paths = [os.path.join(root_dir, self.df[col].iloc[idx])
                       for col in self.opts['image_cols']]
        channels = [self._image_transform(load_greyscale(path)) for path in image_paths]
        # NOTE(review): torch.stack adds a new leading dimension; if the transform
        # already returns (1, H, W) tensors the result is (C, 1, H, W) -- confirm
        # that this, rather than torch.cat, is intended.
        image = torch.stack(channels)
        metadata = {col: self.df[col].iloc[idx] for col in self.opts['metadata_cols']}
        # A metadata column named 'image' would override the image entry.
        return {'image': image, **metadata}

    def __getitem__(self, idx):
        """Return one sample, or a collated batch for a slice / iterable of positions."""
        if isinstance(idx, slice):
            idx = range(*idx.indices(len(self)))
        # Zero-dimensional tensors/arrays pass the Iterable check but cannot be
        # iterated, so they are routed to the single-item path.
        if isinstance(idx, Iterable) and getattr(idx, 'ndim', 1) != 0:
            return collate([self._get_item(i) for i in idx])
        return self._get_item(idx)

In [4]:
# Read the table that is passed to dataframeDatasetHPA below; 'test.csv' is
# resolved relative to the current working directory.
# TODO: need to actually get the images this df references
df = pd.read_csv('test.csv')

In [5]:
# Build the dataset over df with image paths joined onto './'; the column names
# given here replace the constructor's default image and metadata columns.
dset = dataframeDatasetHPA(df, root_dir='./',
                           image_cols=('nuclear_channel', 'microtubule_channel', 'antibody_channel'),
                           image_transform=transforms.Compose([transforms.ToTensor()]),
                           metadata_cols=('antibody_name','cellLine','protein_id','protein_AA_sequence'))

In [6]:
# Echo the constructor options recorded on the dataset (everything except df).
dset.opts

{'image_cols': ('nuclear_channel', 'microtubule_channel', 'antibody_channel'),
 'image_transform': <torchvision.transforms.transforms.Compose at 0x1138d64e0>,
 'metadata_cols': ('antibody_name',
  'cellLine',
  'protein_id',
  'protein_AA_sequence'),
 'root_dir': './'}