In [1]:
from aicspylibczi import CziFile
import czifile
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import cv2
import os
import imageio
import ffmpeg
import time
import pandas as pd
# from cellpose import io, models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import json
import glob
import itertools
from PIL import Image
import torch.nn.functional as F
from utils import *
cudnn.benchmark = True
from VideoLoaders import *
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x7f2c8fb31410>

In [2]:
from utils import *
import mahotas #: Module("mahotas")


def extract_traces_sparse(frames, masks, hist=2):
    """Track every cell found in the first mask across the given frames.

    The number of cells is taken from ``bounding_boxes(masks[0])``; each cell
    index is then passed to ``track_cells`` with no padding and verbose
    output switched off.

    Args:
        frames: Sequence of frames, forwarded unchanged to ``track_cells``.
        masks: Sequence of masks, forwarded unchanged to ``track_cells``;
            only ``masks[0]`` is used to count the cells.
        hist: Forwarded to ``track_cells`` as ``history_length``.

    Returns:
        list: One ``track_cells`` result per cell index, in index order.
    """
    # Only the cell count is needed here; the boxes and areas are discarded.
    _, cell_count, _ = bounding_boxes(masks[0])
    return [
        track_cells(cell_idx, frames, masks, padding=0, history_length=hist, verbose=False)
        for cell_idx in range(cell_count)
    ]

def shape_features(binary, feature_length=20, num_samples=180):
    """Describe a mask outline by the spectrum of its radial distance profile.

    For ``num_samples`` angles between 0 and 2*pi, a ray is walked from the
    image border towards the centre pixel ``(width // 2, height // 2)`` in
    integer radius steps until it lands on a truthy pixel of ``binary``. The
    resulting radii are transformed with a real FFT and the magnitudes of the
    leading coefficients are normalised so that they sum to one.

    Args:
        binary: 2-D array; truthy entries are treated as foreground.
        feature_length: Maximum number of FFT magnitudes to keep.
        num_samples: Number of angles to sample.

    Returns:
        tuple: ``(features, (distances, fft_coefficients))`` where
        ``features`` holds the normalised magnitudes (all zeros if every
        coefficient is zero), ``distances`` the radius found per angle and
        ``fft_coefficients`` the raw ``np.fft.rfft`` output.
    """
    height, width = binary.shape
    center_x, center_y = width // 2, height // 2

    def is_foreground(r, theta):
        """Return True if the pixel at radius ``r`` along ``theta`` is set."""
        x_test = center_x + r * np.cos(theta)
        y_test = center_y + r * np.sin(theta)
        # Both upper bounds are exclusive: a coordinate equal to the image
        # size would index one past the last row/column.
        if x_test < 0 or y_test < 0 or x_test >= width or y_test >= height:
            return False
        return bool(binary[int(y_test), int(x_test)])

    def radial_distance(theta):
        """Largest integer radius along ``theta`` that hits foreground, else 0."""
        r = max(height, width)
        # Start outside the image and walk inwards. Stopping at r == 0 keeps
        # the scan finite for rays that never meet the mask (previously the
        # radius went negative and, for an empty mask, looped forever).
        while r > 0 and not is_foreground(r, theta):
            r -= 1
        return r

    # NOTE(review): linspace includes both 0 and 2*pi, so the first and last
    # samples point in the same direction. Kept as-is so existing feature
    # values do not shift; consider endpoint=False if that is unintended.
    test_angles = np.linspace(0, 2*np.pi, num_samples)
    distances = np.array([radial_distance(angle) for angle in test_angles])
    fft_coefficients = np.fft.rfft(distances)

    features = np.abs(fft_coefficients[:feature_length])
    total = np.sum(features)
    if total > 0:
        # Skip the division when every magnitude is zero to avoid NaNs.
        features = features / total
    return(features, (distances, fft_coefficients))

def featurize(cell_data, index):
    """Build one flat feature vector for a single time point of a cell trace.

    Three groups of features are concatenated, in this order:

    1. ``mahotas.features.zernike_moments`` of the mask, using half of the
       larger mask dimension as radius and ``degree=8``.
    2. ``mahotas.features.haralick`` of the patch (cast to ``uint16``),
       averaged over axis 0.
    3. The radial shape features returned by ``shape_features(mask, 20)``.

    Args:
        cell_data: Mapping with ``'patches'`` and ``'masks'`` sequences.
        index: Position within those sequences to featurize.

    Returns:
        numpy.ndarray: The concatenated feature vector.
    """
    patch = cell_data['patches'][index]
    mask = cell_data['masks'][index].astype(np.uint8)

    zernike_part = mahotas.features.zernike_moments(mask, max(mask.shape)/2, degree=8)
    haralick_part = mahotas.features.haralick(patch.astype(np.uint16)).mean(axis=0)
    # Only the normalised features are used; the diagnostic tuple is dropped.
    radial_part, _ = shape_features(mask, 20)

    return np.concatenate([zernike_part, haralick_part, radial_part])

class VideoDataProcessed:
    """Holds processed videos and the cell traces extracted from them.

    Videos are loaded eagerly in the constructor via ``get_file`` and stored
    in ``self.videos`` keyed by video number. Traces are only extracted when
    ``extract_slice_traces`` / ``extract_planes`` is called and accumulate in
    ``self.all_traces``.

    Attributes:
        data: Empty dict; never populated by this class.
        all_traces: Flat list of every trace extracted so far.
        seq_length: Number of frames per extraction window.
        channel: Channel index passed to ``read_image``.
        videos: Loaded video objects keyed by video number.
        num_vids: Number of videos loaded by the constructor.
    """

    def __init__(self, files, sequence_length=5, channel=0):
        """Load every ``('processed', num)`` entry of ``files``.

        Args:
            files: Iterable of ``(category, num)`` pairs; the category must
                be ``'processed'``.
            sequence_length: Number of frames per extraction window.
            channel: Channel index used when reading frames.

        Raises:
            AssertionError: If an entry has a category other than
                ``'processed'``.
        """
        self.data = {}
        self.all_traces = []
        self.seq_length = sequence_length
        self.channel = channel
        self.videos = {}
        for category, num in files:
            print(f"Loading in processed {num}")
            assert category == 'processed', "Can't load non processed file"
            self.videos[num] = get_file(category, num)
        # Count the videos that were actually loaded. Counting self.data
        # (which is never filled) always gave 0 and made the object falsy.
        self.num_vids = len(self.videos)

    def __len__(self):
        """Return the number of loaded videos."""
        return len(self.videos)

    def extract_planes(self, num, zplanes, hist_length=2):
        """Extract traces from several z-planes of video ``num``.

        Args:
            num: Video number (key of ``self.videos``).
            zplanes: Iterable of z-plane indices.
            hist_length: History length forwarded to the trace extraction.
        """
        for z in zplanes:
            self.extract_slice_traces(num, z, hist_length)

    def extract_slice_traces(self, num, zPlane, hist_length=2):
        """Extract traces from one z-plane and append them to ``all_traces``.

        The frames are cut into consecutive windows of ``seq_length`` frames.
        If the frame count is not a multiple of ``seq_length``, one extra
        window covering the last ``seq_length`` frames is processed, which
        overlaps the preceding window.

        Args:
            num: Video number (key of ``self.videos``).
            zPlane: Z-plane index passed to ``read_image``.
            hist_length: History length forwarded to
                ``extract_traces_sparse``.

        Raises:
            AssertionError: If video ``num`` has not been loaded.
        """
        assert num in self.videos, f"Video {num} not found"

        video = self.videos[num]
        frames, shp = video.read_image(C=self.channel, S=0, Z=zPlane)
        frames = scale_img(frames.squeeze())
        print(f"vid {num} zplane {zPlane} with frames: {frames.shape}")
        masks = binarize_video(frames)

        total = len(frames)
        window = self.seq_length
        remainder = total % window

        for start in range(0, total - remainder, window):
            stop = start + window
            print(f"Extracting traces from {start}:{stop}")
            # extend() avoids re-copying the whole list for every window.
            self.all_traces.extend(
                extract_traces_sparse(frames[start:stop], masks[start:stop], hist=hist_length)
            )

        if remainder > 0:
            # Leftover frames: take the final full-length window.
            self.all_traces.extend(
                extract_traces_sparse(frames[-window:], masks[-window:], hist=hist_length)
            )


class SparseMIPVideo:
    """Loads MIP videos and extracts cell traces window by window.

    All work happens in the constructor: each video is read, scaled,
    binarised and cut into windows of ``sequence_length`` frames, and the
    traces from every window are collected in ``self.all_traces``.

    Attributes:
        data: Empty dict; never populated by this class.
        all_traces: Flat list of every extracted trace.
        featurized_frames: Per-trace feature arrays, filled by
            ``featurize_traces``.
        N: The window length (``sequence_length``).
    """

    def __init__(self, files, sequence_length, hist_length=2):
        """Load every ``('mip', num)`` entry of ``files`` and extract traces.

        If a video's frame count is not a multiple of ``sequence_length``,
        one extra window covering the last ``sequence_length`` frames is
        processed, which overlaps the preceding window.

        Args:
            files: Iterable of ``(category, num)`` pairs; the category must
                be ``'mip'``.
            sequence_length: Number of frames per extraction window.
            hist_length: History length forwarded to
                ``extract_traces_sparse``.

        Raises:
            AssertionError: If an entry has a category other than ``'mip'``.
        """
        self.data = {}
        self.all_traces = []
        # Defined up front so the attribute exists before featurize_traces().
        self.featurized_frames = []
        self.N = sequence_length
        for category, num in files:
            print(f"Loading in MIP {num}")
            assert category == 'mip', "Can't load non Mip file"
            video = get_file(category, num)
            frames, shp = video.read_image(C=0)
            frames = scale_img(frames.squeeze())
            print(f"frames {num}: {frames.shape}")
            masks = binarize_video(frames)

            print(f"Finished loading frames and masks for MIP {num}")

            total = len(frames)
            remainder = total % sequence_length

            for start in range(0, total - remainder, sequence_length):
                stop = start + sequence_length
                print(f"Extracting traces from {start}:{stop}")
                # extend() avoids re-copying the whole list for every window.
                self.all_traces.extend(
                    extract_traces_sparse(frames[start:stop], masks[start:stop], hist=hist_length)
                )

            if remainder > 0:
                # Leftover frames: take the final full-length window.
                self.all_traces.extend(
                    extract_traces_sparse(
                        frames[-sequence_length:], masks[-sequence_length:], hist=hist_length
                    )
                )

    def featurize_traces(self):
        """Compute a feature array for every trace in ``all_traces``.

        Each trace is featurized at every time point it contains (one
        ``featurize`` call per entry of ``trace['patches']``) rather than at a
        fixed five points, so any window length is handled. The results
        replace ``self.featurized_frames``; the trace index is printed every
        100 traces as a progress indicator.
        """
        self.featurized_frames = []
        for i, trace in enumerate(self.all_traces):
            if(i % 100 == 0):
                print(i)
            num_steps = len(trace['patches'])
            trajectory_features = np.array(
                [featurize(trace, index) for index in range(num_steps)]
            )
            self.featurized_frames.append(trajectory_features)

In [None]:
from skimage.measure import centroid
import skimage.measure as skm

# Default side length that pad_arrays() zero-pads masks and patches to.
max_padding = 250

box_shape = (180, 180)  # TODO: find the biggest box and set it to this
# Default sequence length handed to the trace extractor.
X = 10

class CellBoxMaskPatch(torch.utils.data.Dataset):
    """Dataset of per-cell sequences of boxes, masks and image patches.

    Traces are extracted from MIP videos with ``SparseMIPVideo``; every trace
    becomes one dataset item.

    Attributes:
        mips_extractor: The ``SparseMIPVideo`` that produced the traces.
        cell_dict: List of ``(boxes, masks, patches)`` tuples, one per trace;
            each element is a list of numpy arrays, one per time point.
        num_cells: Number of traces (dataset length).
    """
    # TODO: accept a directory name instead of an explicit file list.
    def __init__(
        self,
        files,
        X=X):
        """Extract traces from ``files`` and index them per cell.

        Args:
            files: Iterable of ``(category, num)`` pairs, forwarded to
                ``SparseMIPVideo``.
            X: Sequence length (frames per trace window).
        """
        # SparseMIPVideo needs the sequence length and extracts every trace
        # in its constructor, so no separate extraction pass is required.
        self.mips_extractor = SparseMIPVideo(files, X)

        self.cell_dict = []
        for cell in self.mips_extractor.all_traces:
            patches = [np.array(p) for p in cell["patches"]]
            boxes = [np.array(b) for b in cell['boxes']]
            masks = [np.array(m) for m in cell['masks']]

            # One (boxes, masks, patches) tuple per tracked cell sequence.
            self.cell_dict.append((boxes, masks, patches))

        self.num_cells = len(self.cell_dict)

    def __len__(self):
        """Return the number of cell sequences."""
        return self.num_cells

    def get_centroids(self, boxes, masks):
        """Return each mask's centroid displacement relative to the first.

        The centroid of every mask is computed with ``skm.centroid`` (which
        reports ``(row, col)``) and shifted by the first two entries of its
        box, unpacked as ``(ymin, xmin)``.

        Args:
            boxes: Per-time-point boxes; the first two entries are read as
                ``(ymin, xmin)``.
            masks: Per-time-point 2-D masks aligned with ``boxes``.

        Returns:
            numpy.ndarray: Array of shape ``(len(masks), 2)`` holding
            ``[x, y]`` offsets from the first time point (``(0, 2)`` when
            ``masks`` is empty).
        """
        if len(masks) == 0:
            return np.empty((0, 2))

        positions = []
        for box, binary in zip(boxes, masks):
            row_c, col_c = skm.centroid(np.asarray(binary).astype(np.uint8))
            ymin, xmin = box[:2]
            # Columns belong with x and rows with y.
            positions.append([xmin + col_c, ymin + row_c])
        positions = np.array(positions)
        return positions - positions[0]

    def pad_arrays(self, array, pad_amt=max_padding):
        """Zero-pad a 2-D array at the bottom/right to ``pad_amt`` per side.

        Args:
            array: 2-D array no larger than ``pad_amt`` along either axis.
            pad_amt: Target size of both axes.

        Returns:
            numpy.ndarray: The padded ``(pad_amt, pad_amt)`` array.

        Raises:
            ValueError: If ``array`` is larger than ``pad_amt`` along an axis.
        """
        rows, cols = array.shape[0], array.shape[1]
        if rows > pad_amt or cols > pad_amt:
            raise ValueError(
                f"array of shape {array.shape} exceeds the padding target {pad_amt}"
            )
        pad_width = ((0, pad_amt - rows), (0, pad_amt - cols))
        return np.pad(array, pad_width, mode='constant')

    def __getitem__(self, idx):
        """Return ``(centroids, masks, patches)`` for cell sequence ``idx``.

        Masks and patches are converted to ``int32`` and zero-padded with
        ``pad_arrays``; negative mask values are replaced by 1. The stored
        sequences in ``cell_dict`` are left untouched.

        Returns:
            tuple: ``centroids`` as returned by ``get_centroids`` and two
            lists of padded arrays (masks, patches), one entry per time point.
        """
        boxes, raw_masks, raw_patches = self.cell_dict[idx]

        masks, patches = [], []
        for raw_mask, raw_patch in zip(raw_masks, raw_patches):
            mask = np.array(raw_mask, dtype=np.int32)
            mask = np.where(mask >= 0, mask, 1)
            masks.append(self.pad_arrays(mask))
            # NOTE(review): the int32 cast truncates fractional intensities;
            # confirm the patches are integer-valued.
            patches.append(self.pad_arrays(np.array(raw_patch, dtype=np.int32)))

        # Padding only appends rows/columns at the end, so centroids computed
        # on the padded masks match those of the original masks.
        centroids = self.get_centroids(boxes, masks)

        return centroids, masks, patches

In [None]:
# Load processed video 3 and extract traces from z-plane 50.
# NOTE: this uses VideoDataProcessed from the VideoLoaders module, not the
# class of the same name defined earlier in this notebook.
# NOTE(review): this cell repeats the executed cell further down (In [3]);
# one of the two can probably be removed.
from torch.utils.data import random_split
import VideoLoaders

X = 10
processed_video_files = [
    ('processed', 3),
]
processed_dataset = VideoLoaders.VideoDataProcessed(processed_video_files)
processed_dataset.extract_slice_traces(3, 50)

# Disabled train/val/test split, kept for reference:
# train, val, test = random_split(processed_dataset, [0.7, 0.2, 0.1])

# input_datasets = {}
# input_datasets["train"] = train
# input_datasets["val"] = val
# input_datasets["test"] = test

In [None]:
# Show the dataset size as reported by its __len__.
len(processed_dataset)

In [3]:
# Load processed video 3 and extract traces from z-plane 50.
# NOTE: this uses VideoDataProcessed from the VideoLoaders module, not the
# class of the same name defined earlier in this notebook.
import VideoLoaders

X = 10
processed_video_files = [
    ('processed', 3),
]
processed_dataset = VideoLoaders.VideoDataProcessed(processed_video_files)
processed_dataset.extract_slice_traces(3, 50)

Loading in processed 3
Loading dicty_factin_pip3-03_processed.czi with dims [{'X': (0, 475), 'Y': (0, 2048), 'Z': (0, 114), 'C': (0, 2), 'T': (0, 90), 'S': (0, 1)}]
vid 3 zplane 50 with frames: (90, 2048, 475)
Extracting traces from 0:5
Extracting traces from 5:10
Extracting traces from 10:15
Extracting traces from 15:20
Extracting traces from 20:25
Extracting traces from 25:30
Extracting traces from 30:35
Extracting traces from 35:40
Extracting traces from 40:45
Extracting traces from 45:50
Extracting traces from 50:55
Extracting traces from 55:60
Extracting traces from 60:65
Extracting traces from 65:70
Extracting traces from 70:75
Extracting traces from 75:80
Extracting traces from 80:85
Extracting traces from 85:90


In [9]:
# Display the dataset object; the default repr only shows its type and address.
processed_dataset

<VideoLoaders.VideoDataProcessed at 0x7f2c8fbbcd90>