In [None]:
# importing modules

import os
import time
import warnings
from functools import partial
from multiprocessing import TimeoutError
from multiprocessing.pool import ThreadPool as Pool

import hdf5storage
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.io as sio
import skimage
import skimage.io as io
import torch
from aicspylibczi import CziFile
from natsort import natsorted
from tqdm import tqdm

warnings.filterwarnings("ignore")

from utils import *

In [None]:
# Optical-flow dataset setup: frames are grouped X at a time (10 here).
from skimage.measure import centroid
import skimage.measure as skm

# Passed as the second argument to CellBoxMaskPatch when the dataset is built.
X = 10

class CellBoxMaskPatch(torch.utils.data.Dataset):
    """Dataset of fixed-length optical-flow clips with per-frame mask centroids.

    For every ``(vid_type, num)`` pair the optical-flow files
    ``opticalflow/<vid_type><num>/<k>.npz`` (k = 0, 1, ...) are read and grouped
    into consecutive clips of ``X`` frames. Each sample is one clip.

    Attributes:
        opticalflows: list of clips; a clip is a list of ``X`` ``[vx, vy]`` pairs.
        centroids: list of clips; a clip is a list of ``X`` mask centroids.
        frames: list of clips; a clip is a list of ``X`` binary-mask frames.

    NOTE(review): ``get_file``, ``scale_img``, ``binarize_video`` and
    ``bounding_boxes`` are not defined in this cell; presumably they come from
    ``from utils import *`` -- confirm.
    """

    def __init__(
        self,
        files,
        X=10):
        """Load flows and masks for every video in ``files`` and cut them into clips.

        Args:
            files: iterable of ``(vid_type, num)`` pairs identifying the videos.
            X: number of consecutive frames per clip; must be >= 1.

        Raises:
            ValueError: if ``X`` is smaller than 1.
        """
        if X < 1:
            raise ValueError(f"X must be a positive number of frames, got {X}")

        self.opticalflows = []  # clips of [vx, vy] flow pairs
        self.centroids = []     # clips of per-frame mask centroids
        self.frames = []        # clips of per-frame binary masks

        for (vid_type, num) in files:
            of_address = os.path.join('opticalflow/', f'{vid_type}{num}')
            # Frames are addressed as '<k>.npz', so only .npz entries count as frames;
            # stray files in the directory must not inflate the frame count.
            n_frames = sum(name.endswith('.npz') for name in os.listdir(of_address))
            if n_frames == 0:
                continue

            # The mask depends only on the video, not on the frame index, so it is
            # computed once per video instead of once per frame.
            img = get_file(vid_type, num)
            img_temp = img.read_image(C=0, Z=50)
            img_temp = scale_img(img_temp[0].squeeze())
            mask = binarize_video(img_temp)  # assumes mask[k] is the mask of frame k -- TODO confirm

            video_flows = []
            video_masks = []
            video_centroids = []
            for k in range(n_frames):
                # NpzFile keeps the archive open until closed; the context manager
                # releases the handle as soon as both arrays have been read.
                with np.load(os.path.join(of_address, f'{k}.npz')) as of:
                    vx = of['vx'][0]
                    vy = of['vy'][0]
                video_flows.append([vx, vy])
                video_masks.append(mask[k])
                # NOTE(review): whole-frame mask centroid used as a stand-in until the
                # per-cell centroids from get_cells() are wired in (open TODO) -- confirm.
                video_centroids.append(skimage.measure.centroid(mask[k]))

            # Clips never span two videos; a trailing remainder shorter than X is dropped
            # so that every sample has exactly X frames.
            self.opticalflows.extend(self._chunk(video_flows, X))
            self.frames.extend(self._chunk(video_masks, X))
            self.centroids.extend(self._chunk(video_centroids, X))

    @staticmethod
    def _chunk(items, size):
        """Split ``items`` into consecutive lists of exactly ``size`` elements, dropping a shorter tail."""
        n_full = len(items) // size
        return [items[i * size:(i + 1) * size] for i in range(n_full)]

    def __len__(self):
        """Return the number of clips in the dataset."""
        return len(self.opticalflows)

    @staticmethod
    def get_cells(img, masked):
        """Crop every bounding box found in ``masked`` out of ``img``.

        Declared static because it uses no instance state; this keeps
        ``CellBoxMaskPatch.get_cells(img, masked)`` working and also allows
        ``self.get_cells(img, masked)``.

        Returns:
            Tuple ``(zoomed_cells, bboxes, num_cells, cell_areas,
            relative_centroids, masked)``. ``zoomed_cells`` holds one
            single-element list per box (the image crop); ``relative_centroids``
            holds the centroid of the mask crop for each box.
        """
        bboxes, num_cells, cell_areas = bounding_boxes(masked)
        zoomed_cells = []
        relative_centroids = []
        for box in bboxes:
            # box is indexed as (row_start, col_start, row_stop, col_stop)
            zoomed_cells.append([img[box[0]:box[2], box[1]:box[3]]])
            relative_centroids.append(skimage.measure.centroid(masked[box[0]:box[2], box[1]:box[3]]))

        return zoomed_cells, bboxes, num_cells, cell_areas, relative_centroids, masked

    def __getitem__(self, idx):
        """Return ``(flows, centroids)`` for clip ``idx``."""
        flows = self.opticalflows[idx]
        centroids = self.centroids[idx]

        return flows, centroids

In [None]:
from torch.utils.data import random_split

# (video type, video number) pairs handed to CellBoxMaskPatch.
mip_video_files = [
    ('processed', 3),
    ('processed', 6),
    ('processed', 9)
]

dataset = CellBoxMaskPatch(mip_video_files, X)

# Seed the split so train/eval/test membership is the same on every Restart & Run All.
SPLIT_SEED = 42
train, eval, test = random_split(
    dataset,
    [0.7, 0.2, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE: `eval` shadows the builtin from here on. The name is kept so other cells that
# may reference it keep working; prefer input_datasets["eval"] in new code.

input_datasets = {
    "train": train,
    "eval": eval,
    "test": test,
}

In [None]:
def collate_fn(batch, mode_box=None, mode_mask=None, mode_patch=None):
    """Regroup a batch of ``(flows, centroids)`` samples into two parallel lists.

    Args:
        batch: sequence of samples, each indexable as ``(flows, centroids)``.
        mode_box, mode_mask, mode_patch: currently unused. They default to
            ``None`` so the function can be called with ``batch`` alone, which
            is how the DataLoaders call it.

    Returns:
        ``(current_flows, current_centroids)``: two lists with one entry per
        sample, in batch order. Nothing is stacked or converted to tensors.
    """
    current_centroids = [sample[1] for sample in batch]
    current_flows = [sample[0] for sample in batch]

    return current_flows, current_centroids


In [None]:
# One DataLoader per split, all with identical settings.
# The mode_* arguments of collate_fn are bound explicitly: the DataLoader only ever
# passes `batch`, so leaving them unbound raises TypeError on the first batch.
batch_collate = partial(collate_fn, mode_box=None, mode_mask=None, mode_patch=None)

dataloaders = {}
for split in ('train', 'test', 'eval'):
    dataloaders[split] = torch.utils.data.DataLoader(
        input_datasets[split],
        batch_size=4,
        shuffle=True,
        num_workers=0,
        collate_fn=batch_collate,
    )

In [None]:
# Sanity check: iterate the eval loader once and report the size of every batch.
for batch in dataloaders['eval']:
    flows, centroids = batch
    # collate_fn returns plain Python lists, which have no `.shape`; fall back to the
    # number of samples, and still report `.shape` if array-like batches are returned.
    print(
        "Input:", getattr(flows, "shape", len(flows)),
        "Centroids", getattr(centroids, "shape", len(centroids)),
    )