# Defining dataset

In [1]:
#| default_exp nb_03_dataset

In [2]:
#|export
from pathlib import Path 

In [3]:
#|export
import pandas as pd

In [4]:
#| export
import matplotlib.pyplot as plt

In [5]:
#|export
from ml.nb_02_patching import *

## Data

In [6]:
# Load the pickled dataframe; the path is relative to the current working directory.
# NOTE(review): unpickling can execute arbitrary code, so only load files from a trusted source.
df = pd.read_pickle("data/df_all_2022_10_06.pkl")
# Last expression of the cell: shows the (rows, columns) shape as the cell output.
df.shape

(1180, 95)

In [7]:
#| export
import torch.utils.data as data_utils

In [8]:
#| export
from PIL import Image

In [9]:
#| export
import torchvision.transforms as transforms

In [10]:
#| export
import numpy as np

In [11]:
#| export
import torch

In [12]:
#| export
import torchvision

In [13]:
#| export
import tqdm

## Dataset

In [14]:
#| export
class PatchedDataSet(data_utils.Dataset):

    """Map-style dataset returning the `N` patches of one image and its target.

    Patch files are looked up as ``<img_path>/<TMA_ID>_<TMASpot>_<i>.png`` for
    ``i`` in ``range(N)``, so `df` must provide `TMA_ID` and `TMASpot` columns.
    Each item is a tuple ``(patches, target)`` where ``patches`` stacks the
    `N` normalized patch tensors along a new leading dimension and ``target``
    is the value of `y_col` in the row at position ``idx``.
    """

    def __init__(self,
                 img_path, # Path with images
                 df, # pandas dataframe
                 y_col, # df column for target
                 stime_col, #df column with survival time
                 N, # number of patches
                 mean, #mean for normalization
                 std,  #std for normalization
                 trfms=None #callable transform or list of transforms
                ):
        self.img_path = img_path
        # `reset_index` returns a new frame, so the result has to be kept.
        # `drop=True` avoids adding the old index as a column, and working on
        # the returned copy leaves the caller's dataframe untouched.
        self.df = df.reset_index(drop=True)
        self.y_col = y_col
        self.stime_col = stime_col
        self.N = N
        self.mean, self.std = mean, std
        # Accept a single callable as well as a list/tuple of transforms;
        # a sequence is chained into one callable.
        if isinstance(trfms, (list, tuple)):
            trfms = transforms.Compose(list(trfms))
        self.trfm = trfms
        self.img_ids = self.get_img_ids()

    def get_img_ids(self):
        """Return, per dataframe row, the patch path prefix ``<img_path>/<TMA_ID>_<TMASpot>``."""
        # `str(...)` works for both `Path` and plain string inputs
        # (`Path` has no `.str()` method).
        root = str(self.img_path)
        # Cast to str so non-string id/spot columns can still be concatenated.
        col = root + "/" + self.df.TMA_ID.astype(str) + "_" + self.df.TMASpot.astype(str)
        return col.tolist()

    def __getitem__(self, idx):
        """Load, transform and normalize the `N` patches of item `idx`.

        Returns a tuple ``(patches, target)``; raises whatever `Image.open`
        raises if a patch file is missing.
        """
        img_id = self.img_ids[idx]

        # Build the fixed pipeline stages once per item rather than once per patch.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(self.mean, self.std)

        patches = []
        for i in range(self.N):
            # The context manager closes the file handle as soon as the pixel
            # data has been converted to a tensor.
            with Image.open(f"{img_id}_{i}.png") as img:
                patch = to_tensor(img)

            if self.trfm is not None:
                patch = self.trfm(patch)

            patches.append(normalize(patch))

        # `torch.stack` already allocates a new tensor, so no extra clone is needed.
        patches = torch.stack(patches, dim=0)

        return patches.detach(), self.df.iloc[idx][self.y_col]

    def __len__(self):
        """Number of images (dataframe rows) in the dataset."""
        return len(self.img_ids)

In [15]:
#| export
# Directories holding the patched images (x) and patched masks (y);
# presumably the output of the patching notebook -- verify against nb_02_patching.
p_outx, p_outy = (
    Path("/media/dimi/TOSHIBA EXT/patched_images"),
    Path("/media/dimi/TOSHIBA EXT/patched_masks"),
)

In [16]:
#| export
# Three-element mean/std lists for the images (presumably one value per
# colour channel -- confirm against how the statistics were computed).
mean_img = [0.8868493, 0.7803772, 0.87521]
std_img = [0.07292725, 0.09504553, 0.05757239]
# Mask statistics: the same value is repeated for all three entries.
mean_mask = [0.04432359, 0.04432359, 0.04432359]
std_mask = [0.02483896, 0.02483896, 0.02483896]

In [17]:
#| hide
# Run nbdev's export so the cells marked `#| export` are written to the
# module named by the `default_exp` directive at the top of the notebook.
import nbdev; nbdev.nbdev_export()