In [None]:
from aicspylibczi import CziFile
import czifile
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import cv2
import os
import imageio
import ffmpeg
import time
import pandas as pd
# from cellpose import io, models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import json
import glob
import itertools
from PIL import Image
import torch.nn.functional as F
from utils import *
cudnn.benchmark = True
from VideoLoaders import *
plt.ion()   # interactive mode

In [None]:
class VideoDataMIP:
    """Loads 'mip' category videos and extracts per-cell traces from them.

    ``self.data[num]`` holds, per loaded video number: ``'video'`` (the handle from
    ``get_file``), ``'frames'`` (channel 0, squeezed and passed through
    ``scale_img``) and ``'masks'`` (``binarize_video(frames)``). After
    ``extract_all_traces`` it also holds ``'traces'`` and ``'trace_videos'``.
    """

    def __init__(self, files):
        """Load frames and binarized masks for every ``(category, num)`` entry.

        Args:
            files: iterable of ``(category, num)`` tuples.

        Raises:
            AssertionError: if any entry's category is not ``'mip'``.
        """
        self.data = {}
        for category, num in files:
            print(f"Loading in MIP {num}")
            assert category == 'mip', "Can't load non Mip file"
            video = get_file(category, num)
            frames, _ = video.read_image(C=0)
            frames = scale_img(frames.squeeze())
            print(f"frames {num}: {frames.shape}")
            self.data[num] = {
                'video': video,
                'frames': frames,
                'masks': binarize_video(frames),
            }

    def extract_all_traces(self, file_num, sequence_length, hist_length=2):
        """Trace cells in consecutive ``sequence_length``-frame windows of one video.

        Results of ``extract_traces`` for every window are concatenated and stored
        in ``self.data[file_num]['traces']`` and ``['trace_videos']``.

        Args:
            file_num: key of a video already loaded into ``self.data``.
            sequence_length: number of frames per window.
            hist_length: frames of history, forwarded to ``extract_traces`` as ``hist``.
        """
        entry = self.data[file_num]
        frames, masks = entry['frames'], entry['masks']
        n_frames = len(frames)
        all_traces, all_videos = [], []

        full_span = (n_frames // sequence_length) * sequence_length
        for s in range(0, full_span, sequence_length):
            print(f"Extracting traces from {s}:{s+sequence_length}")
            data, videos = extract_traces(frames[s:s+sequence_length], masks[s:s+sequence_length], hist=hist_length)
            # extend() keeps accumulation linear; `list + list` re-copied everything per window.
            all_traces.extend(data)
            all_videos.extend(videos)

        # Frames left over after the last full window: re-trace the final
        # `sequence_length` frames (overlapping the previous window) so they are covered.
        if n_frames % sequence_length > 0:
            data, videos = extract_traces(frames[-sequence_length:], masks[-sequence_length:], hist=hist_length)
            all_traces.extend(data)
            all_videos.extend(videos)

        entry['traces'] = all_traces
        entry['trace_videos'] = all_videos
            

In [None]:
from skimage.measure import centroid
import skimage.measure as skm

# Side length (pixels) that CellBoxMaskPatch.pad_arrays zero-pads masks/patches to.
max_padding = 300

# TODO: find the biggest box and set it to this.
# NOTE(review): not referenced by any other visible cell.
box_shape = (180, 180)

# Default sequence length (frames per trace window) for CellBoxMaskPatch.
X = 10

class CellBoxMaskPatch(torch.utils.data.Dataset):
    """Torch dataset of traced cell sequences built from 'mip' videos.

    Item ``idx`` is ``(centroids, masks, patches)`` for one traced cell sequence:
    ``centroids`` is an ``(N, 2)`` array of per-frame centroids relative to the
    first frame; ``masks`` and ``patches`` are lists of ``N`` int32 arrays
    zero-padded on the bottom/right to ``max_padding x max_padding``.
    """

    # TODO: accept a directory name as input.
    def __init__(
        self,
        files,
        X=X):
        """
        Args:
            files: iterable of ``('mip', num)`` tuples to load.
            X: sequence length (frames per window) passed to
                ``VideoDataMIP.extract_all_traces``.
        """
        files = list(files)  # iterated more than once below

        # A single MIP extractor does both loading and trace extraction. Building a
        # VideoDataProcessed from the same list cannot succeed: it asserts
        # category == 'processed' while VideoDataMIP asserts 'mip'.
        self.video_extractor = VideoDataMIP(files)
        self.mips_extractor = self.video_extractor  # alias kept for compatibility

        for _, num in files:
            self.video_extractor.extract_all_traces(num, X)

        # One (boxes, masks, patches) tuple of per-frame arrays per traced sequence.
        self.cell_dict = []
        for entry in self.video_extractor.data.values():
            for cell in entry["traces"]:
                boxes = [np.array(b) for b in cell['boxes']]
                masks = [np.array(m) for m in cell['masks']]
                patches = [np.array(p) for p in cell["patches"]]
                self.cell_dict.append((boxes, masks, patches))

        self.num_cells = len(self.cell_dict)  # number of sequences

    def __len__(self):
        return self.num_cells

    def get_centroids(self, boxes, masks):
        """Return per-frame mask centroids offset by their boxes, relative to frame 0.

        Each mask's ``skm.centroid`` is shifted by the first two entries of the
        matching box; the first frame's result is then subtracted from every row.

        NOTE(review): skimage documents ``centroid`` as returning (row, col). Here
        c[0] is added to the value unpacked as ``xmin`` and c[1] to ``ymin``; that is
        only axis-consistent if ``boxes[i][:2]`` is really ``(xmin, ymin)`` — confirm
        against ``bounding_boxes``.
        """
        N = len(masks)
        res = []
        centroids = [skm.centroid(binary.astype(np.uint8)) for binary in masks]
        for i in range(N):
            c = centroids[i]
            ymin, xmin = boxes[i][:2]
            res.append([xmin + c[0], ymin + c[1]])
        return np.array(res) - res[0]

    def pad_arrays(self, array, pad_amt=max_padding):
        """Zero-pad a 2-D ``array`` on the bottom/right to ``(pad_amt, pad_amt)``.

        Raises:
            ValueError: if either dimension already exceeds ``pad_amt``.
        """
        if array.shape[0] > pad_amt or array.shape[1] > pad_amt:
            raise ValueError(
                f"array of shape {array.shape} does not fit in a {pad_amt}x{pad_amt} pad"
            )
        pad_width = ((0, pad_amt - array.shape[0]), (0, pad_amt - array.shape[1]))
        return np.pad(array, pad_width, mode='constant')

    def __getitem__(self, idx):
        """Return ``(centroids, padded_masks, padded_patches)`` for sequence ``idx``."""
        boxes, masks, patches = self.cell_dict[idx]

        # Build fresh lists so indexing never mutates the cached self.cell_dict entries.
        padded_masks = list(masks)
        padded_patches = list(patches)
        for t in range(len(masks)):  # one entry per frame of the sequence
            mask = np.array(masks[t], dtype=np.int32)
            # Negative values are replaced by 1 (intent not visible here — TODO confirm).
            mask = np.where(mask >= 0, mask, 1)
            padded_masks[t] = self.pad_arrays(mask)
            padded_patches[t] = self.pad_arrays(np.array(patches[t], dtype=np.int32))

        centroids = self.get_centroids(boxes, padded_masks)
        return centroids, padded_masks, padded_patches

In [None]:
# from torch.utils.data import random_split

# MIP videos to load, as (category, file number) pairs.
mip_video_files = [('mip', video_num) for video_num in (3, 6, 9)]

# dataset = CellBoxMaskPatch(mip_video_files, X) # file, S, T

# train, val, test = random_split(dataset, [0.7, 0.2, 0.1])

# input_datasets = {}
# input_datasets["train"] = train
# input_datasets["val"] = val
# input_datasets["test"] = test

In [None]:
from utils import *
import mahotas #: Module("mahotas")


def extract_traces_sparse(frames, masks, hist=2):
    """Track every cell counted by ``bounding_boxes(masks[0])`` through the clip.

    Args:
        frames: frames of one window.
        masks: binarized masks aligned with ``frames``.
        hist: history length forwarded to ``track_cells`` as ``history_length``.

    Returns:
        list: one ``track_cells`` result per cell index, in index order.
    """
    _, cell_count, _ = bounding_boxes(masks[0])
    return [
        track_cells(cell_idx, frames, masks, padding=0, history_length=hist, verbose=False)
        for cell_idx in range(cell_count)
    ]

def shape_features(binary, feature_length=20, num_samples=180):
    """Radial-profile shape descriptor of a 2-D binary mask.

    ``num_samples`` rays are cast from the array centre at angles from
    ``np.linspace(0, 2*pi, num_samples)`` (both endpoints included). For each ray
    the largest integer radius whose pixel is set is recorded. The feature vector
    is the magnitudes of the first ``feature_length`` rFFT coefficients of that
    distance profile, normalised to sum to 1.

    Args:
        binary: 2-D array; non-zero pixels belong to the cell.
        feature_length: number of leading rFFT magnitudes to keep.
        num_samples: number of ray angles.

    Returns:
        ``(features, (distances, fft_coefficients))``. ``features`` sums to 1, or
        is all zeros when every distance is 0 (e.g. an empty mask).
    """
    height, width = binary.shape
    cx, cy = width // 2, height // 2
    max_r = max(height, width)

    def radial_distance(theta):
        # Walk inward from beyond the edge; the first set pixel hit bounds the cell.
        dx, dy = np.cos(theta), np.sin(theta)
        for r in range(max_r, -1, -1):
            x, y = cx + r * dx, cy + r * dy
            # Half-open bounds on both axes; `y > height` let y == height through
            # and raised IndexError on binary[height, ...].
            if 0 <= x < width and 0 <= y < height and binary[int(y), int(x)]:
                return r
        # Nothing set along the ray (e.g. empty mask): stop at 0 rather than
        # continuing to negative radii, which never terminates on an empty mask.
        return 0

    test_angles = np.linspace(0, 2 * np.pi, num_samples)
    distances = np.array([radial_distance(angle) for angle in test_angles])
    fft_coefficients = np.fft.rfft(distances)

    features = np.abs(fft_coefficients[:feature_length])
    total = features.sum()
    if total > 0:  # avoid 0/0 -> NaN for an all-zero profile
        features = features / total
    return features, (distances, fft_coefficients)

def featurize(cell_data, index):
    """Build one feature vector for frame ``index`` of a traced cell.

    Concatenates, in this order:
      * Zernike moments (degree 8) of the binary mask, radius = half its larger side;
      * Haralick texture features of the image patch (cast to uint16), averaged over axis 0;
      * the radial-shape rFFT features from ``shape_features(mask, 20)``.

    Args:
        cell_data: mapping with per-frame ``'patches'`` and ``'masks'`` sequences.
        index: frame index into both sequences.
    """
    mask = cell_data['masks'][index].astype(np.uint8)
    patch = cell_data['patches'][index]

    radius = max(mask.shape) / 2
    zernike = mahotas.features.zernike_moments(mask, radius, degree=8)
    haralick = mahotas.features.haralick(patch.astype(np.uint16)).mean(axis=0)
    radial_shape, _ = shape_features(mask, 20)

    return np.concatenate([zernike, haralick, radial_shape])

class VideoDataProcessed:
    """Holds 'processed' videos and extracts sparse cell traces from single z-planes.

    Traces from every ``extract_slice_traces`` call are appended to
    ``self.all_traces``.
    """

    def __init__(self, files, sequence_length=5, channel=0):
        """
        Args:
            files: iterable of ``(category, num)`` tuples.
            sequence_length: frames per extraction window.
            channel: channel index passed to ``read_image`` as ``C``.

        Raises:
            AssertionError: if an entry's category is not ``'processed'``.
        """
        self.data = {}  # NOTE(review): never populated by this class
        self.all_traces = []
        self.seq_length = sequence_length
        self.channel = channel
        self.videos = {}
        for category, num in files:
            print(f"Loading in processed {num}")
            assert category == 'processed', "Can't load non processed file"
            self.videos[num] = get_file(category, num)
        # Count the loaded videos; self.data (previously counted) is always empty.
        self.num_vids = len(self.videos)

    def __len__(self):
        """Number of loaded videos."""
        return len(self.videos)

    def extract_planes(self, num, zplanes, hist_length):
        """Run ``extract_slice_traces`` on video ``num`` for each z-plane in ``zplanes``."""
        for z in zplanes:
            self.extract_slice_traces(num, z, hist_length)

    def extract_slice_traces(self, num, zPlane, hist_length=2):
        """Trace cells in one z-plane (scene 0, ``self.channel``) of video ``num``.

        Frames are squeezed, scaled with ``scale_img`` and binarized; every full
        ``seq_length`` window, plus one trailing window covering the last
        ``seq_length`` frames when frames are left over, is passed to
        ``extract_traces_sparse`` and the results are appended to ``self.all_traces``.

        Raises:
            AssertionError: if ``num`` was not loaded.
        """
        assert num in self.videos, f"Video {num} not found"

        video = self.videos[num]
        frames, _ = video.read_image(C=self.channel, S=0, Z=zPlane)
        frames = scale_img(frames.squeeze())
        print(f"vid {num} zplane {zPlane} with frames: {frames.shape}")
        masks = binarize_video(frames)

        seq = self.seq_length
        n_frames = len(frames)
        for s in range(0, (n_frames // seq) * seq, seq):
            print(f"Extracting traces from {s}:{s+seq}")
            # extend() keeps accumulation linear (`list + list` re-copied every window).
            self.all_traces.extend(
                extract_traces_sparse(frames[s:s+seq], masks[s:s+seq], hist=hist_length)
            )

        # Leftover frames: re-trace the final `seq` frames (overlapping the previous window).
        if n_frames % seq > 0:
            self.all_traces.extend(
                extract_traces_sparse(frames[-seq:], masks[-seq:], hist=hist_length)
            )


class SparseMIPVideo:
    """Loads 'mip' videos and eagerly extracts sparse cell traces from each one.

    Each video is cut into consecutive ``sequence_length``-frame windows. The
    ``extract_traces_sparse`` results of all windows of all videos are accumulated
    in ``self.all_traces``.
    """

    def __init__(self, files, sequence_length, hist_length=2):
        """
        Args:
            files: iterable of ``(category, num)`` tuples.
            sequence_length: frames per extraction window (stored as ``self.N``).
            hist_length: history length forwarded to ``extract_traces_sparse``.

        Raises:
            AssertionError: if an entry's category is not ``'mip'``.
        """
        self.data = {}  # NOTE(review): never populated by this class
        self.all_traces = []
        self.N = sequence_length
        for category, num in files:
            print(f"Loading in MIP {num}")
            assert category == 'mip', "Can't load non Mip file"
            video = get_file(category, num)
            frames, _ = video.read_image(C=0)
            frames = scale_img(frames.squeeze())
            print(f"frames {num}: {frames.shape}")
            masks = binarize_video(frames)

            print(f"Finished loading frames and masks for MIP {num}")
            self._extract_windows(frames, masks, sequence_length, hist_length)

    def _extract_windows(self, frames, masks, sequence_length, hist_length):
        """Append sparse traces for every full window, plus one tail window, to all_traces."""
        n_frames = len(frames)
        for s in range(0, (n_frames // sequence_length) * sequence_length, sequence_length):
            print(f"Extracting traces from {s}:{s+sequence_length}")
            # extend() keeps accumulation linear (`list + list` re-copied every window).
            self.all_traces.extend(
                extract_traces_sparse(frames[s:s+sequence_length], masks[s:s+sequence_length], hist=hist_length)
            )
        # Leftover frames: re-trace the final `sequence_length` frames, overlapping the
        # previous window, so trailing frames are not dropped.
        if n_frames % sequence_length > 0:
            self.all_traces.extend(
                extract_traces_sparse(frames[-sequence_length:], masks[-sequence_length:], hist=hist_length)
            )

    def featurize_traces(self, num_frames=5):
        """Featurize frames ``0..num_frames-1`` of every trace into ``self.featurized_frames``.

        Args:
            num_frames: frames per trace to featurize. The default 5 preserves the
                previously hard-coded value. NOTE(review): traces presumably span
                ``self.N`` frames — confirm before passing ``self.N``.
        """
        self.featurized_frames = []
        for i, trace in enumerate(self.all_traces):
            if i % 100 == 0:
                print(i)  # progress
            trajectory_features = np.array([featurize(trace, index) for index in range(num_frames)])
            self.featurized_frames.append(trajectory_features)

In [None]:
from torch.utils.data import random_split
import VideoLoaders

# NOTE(review): X is re-bound here but not read by any later visible cell.
X = 10

PROCESSED_VIDEO_NUM = 3
Z_PLANE = 50

processed_video_files = [('processed', PROCESSED_VIDEO_NUM)]
processed_dataset = VideoLoaders.VideoDataProcessed(processed_video_files)
# Trace cells in a single z-plane of the processed video.
processed_dataset.extract_slice_traces(PROCESSED_VIDEO_NUM, Z_PLANE)

# train, val, test = random_split(processed_dataset, [0.7, 0.2, 0.1])

# input_datasets = {}
# input_datasets["train"] = train
# input_datasets["val"] = val
# input_datasets["test"] = test

In [None]:
# Display the length reported by the VideoLoaders.VideoDataProcessed instance above.
# NOTE(review): confirm VideoLoaders' __len__ counts the loaded videos; a version
# that counts a never-populated dict would always show 0 here.
len(processed_dataset)

In [None]:
def collate_fn(batch, mode_box, mode_mask, mode_patch):
    """Collate CellBoxMaskPatch items into one model input plus centroid targets.

    Args:
        batch: sequence of ``(centroids, masks, patches)`` items. ``centroids`` is
            ``(T, 2)``; ``masks`` and ``patches`` are ``T`` equally shaped 2-D arrays.
        mode_box, mode_mask, mode_patch: which per-frame features to concatenate,
            in that order, along the last axis.

    Returns:
        ``(inputs, centroids)``: ``inputs`` is float32 ``(B, T, F)``, where F is the
        sum of the selected widths (2, H*W, H*W). ``centroids`` is float32 ``(B, T, 2)``.

    Raises:
        ValueError: if no feature mode is enabled.
    """
    if not (mode_box or mode_mask or mode_patch):
        raise ValueError("collate_fn: at least one of mode_box/mode_mask/mode_patch must be True")

    centroids = torch.tensor(np.stack([item[0] for item in batch]), dtype=torch.float32)

    selected = []
    if mode_box:
        selected.append(centroids)
    # Masks/patches are only tensorised when selected; (B, T, H, W) -> (B, T, H*W)
    # takes T and H*W from the data rather than hard-coding 10 frames and max_padding**2.
    if mode_mask:
        masks = torch.tensor(np.stack([item[1] for item in batch]), dtype=torch.long)
        selected.append(masks.flatten(start_dim=2).to(torch.float32))
    if mode_patch:
        patches = torch.tensor(np.stack([item[2] for item in batch]), dtype=torch.long)
        selected.append(patches.flatten(start_dim=2).to(torch.float32))

    inputs = torch.cat(selected, dim=-1)
    return inputs, centroids


In [None]:
# Feature switches read by collate_fn:
#   mode_box   -> per-frame cell centroid (2 values; offset by the cell's bounding box)
#   mode_mask  -> flattened, padded binary cell mask
#   mode_patch -> flattened, padded image patch (fluorescence)
mode_box = False
mode_mask = False
mode_patch = True

# Per-frame input width implied by the switches (bools count as 0/1).
# NOTE(review): the training cell later re-binds input_size = 2.
input_size = 2 * mode_box + (mode_mask + mode_patch) * max_padding * max_padding

import functools

# NOTE(review): `input_datasets` must exist before this cell runs; the cells that
# build it (random_split into "train"/"val"/"test") are currently commented out.

# Bind the mode flags once. Unlike a lambda, a functools.partial can be pickled,
# so these loaders keep working if num_workers is raised above 0. The flags are
# captured when this cell runs, so re-run it after changing them.
_collate_selected = functools.partial(
    collate_fn, mode_box=mode_box, mode_mask=mode_mask, mode_patch=mode_patch
)

# Loader name -> split name. The splits are "train"/"val"/"test", so the eval
# loader must read "val" (there is no "eval" split).
_loader_splits = {'train': 'train', 'test': 'test', 'eval': 'val'}

dataloaders = {
    name: torch.utils.data.DataLoader(
        input_datasets[split],
        batch_size=4,
        # Only training benefits from shuffling; a fixed order keeps eval/test repeatable.
        shuffle=(name == 'train'),
        num_workers=0,
        collate_fn=_collate_selected,
    )
    for name, split in _loader_splits.items()
}

In [None]:
# Sanity check: show the collated (inputs, centroids) shapes for each eval batch.
eval_loader = dataloaders['eval']
for batch in eval_loader:
    input_shape = batch[0].shape
    centroid_shape = batch[1].shape
    print("Input:", input_shape, "Centroids", centroid_shape)
    

In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class LSTM(nn.Module):
    """Stacked unidirectional LSTM regressing one (x, y) point from a sequence.

    Only the last time step's projection is kept; it is squashed into
    ``[0, max_padding]`` with a scaled sigmoid.
    """

    def __init__(self, input_size, hidden_dim, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first=True: input is (batch, seq, input_size).
        self.lstm = nn.LSTM(input_size, hidden_dim, num_layers, batch_first=True)
        # Two outputs: the x and y of the centroid.
        self.fc = nn.Linear(hidden_dim, 2)

    def forward(self, input):
        # Explicit zero initial states of shape (num_layers, batch, hidden_dim).
        state_shape = (self.num_layers, input.size(0), self.hidden_dim)
        h0 = torch.zeros(state_shape).to(input.device)
        c0 = torch.zeros(state_shape).to(input.device)

        sequence_out, _ = self.lstm(input, (h0, c0))
        last_step = self.fc(sequence_out)[:, -1, :]
        return torch.sigmoid(last_step) * max_padding

In [None]:
class ImprovedLSTM(nn.Module):
    """Bidirectional stacked LSTM with inter-layer dropout and a 2-unit linear head.

    NOTE(review): forward returns a prediction for every time step (shape
    (batch, seq, 2) for batched input), not only the last one as ``LSTM`` does;
    the training loop compares predictions with ``outputs[:, -1, :]`` — confirm
    before using this model there.
    """

    def __init__(self, input_size, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        # Forward and backward states are concatenated -> 2 * hidden_dim features.
        self.fc = nn.Linear(hidden_dim * 2, 2)

    def forward(self, input):
        # One zero initial state per (layer, direction) pair.
        state_shape = (self.num_layers * 2, input.size(0), self.hidden_dim)
        h0 = torch.zeros(state_shape).to(input.device)
        c0 = torch.zeros(state_shape).to(input.device)

        sequence_out, _ = self.lstm(input, (h0.detach(), c0.detach()))
        return torch.sigmoid(self.fc(sequence_out)) * max_padding


In [None]:
class MultiLayerLSTM(nn.Module):
    """Bidirectional stacked LSTM followed by a two-layer ReLU MLP with 2 outputs.

    Outputs are squashed into ``[0, max_padding]`` and cast to float32.
    NOTE(review): a prediction is produced for every time step, not just the last.
    """

    def __init__(self, input_size, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2)

    def forward(self, input):
        # Zero initial states: one per (layer, direction) pair.
        state_shape = (self.num_layers * 2, input.size(0), self.hidden_dim)
        h0 = torch.zeros(state_shape).to(input.device)
        c0 = torch.zeros(state_shape).to(input.device)

        sequence_out, _ = self.lstm(input, (h0, c0))
        hidden = torch.relu(self.fc1(sequence_out))
        scaled = torch.sigmoid(self.fc2(hidden)) * max_padding
        return scaled.to(torch.float32)



In [None]:
class MultiLayerGRU(nn.Module):
    """Bidirectional stacked GRU followed by a two-layer ReLU MLP with 2 outputs.

    Outputs are squashed into ``[0, max_padding]`` and cast to float32.
    NOTE(review): a prediction is produced for every time step, not just the last.
    """

    def __init__(self, input_size, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.gru = nn.GRU(
            input_size,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2)

    def forward(self, input):
        # Zero initial hidden state: one per (layer, direction) pair.
        h0 = torch.zeros(self.num_layers * 2, input.size(0), self.hidden_dim).to(input.device)

        sequence_out, _ = self.gru(input, h0.detach())
        hidden = torch.relu(self.fc1(sequence_out))
        scaled = torch.sigmoid(self.fc2(hidden)) * max_padding
        return scaled.to(torch.float32)


In [None]:
class MultiLayerComplexGRU(nn.Module):
    """Two chained bidirectional stacked GRUs followed by a three-layer ReLU MLP.

    The second GRU consumes the first one's concatenated (2 * hidden_dim)
    outputs. Final values are squashed into ``[0, max_padding]`` and cast to float32.
    NOTE(review): a prediction is produced for every time step, not just the last.
    """

    def __init__(self, input_size, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        gru_kwargs = dict(num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.gru1 = nn.GRU(input_size, hidden_dim, **gru_kwargs)
        self.gru2 = nn.GRU(hidden_dim * 2, hidden_dim, **gru_kwargs)
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 2)

    def forward(self, input):
        # Separate zero initial states for each GRU, one per (layer, direction) pair.
        state_shape = (self.num_layers * 2, input.size(0), self.hidden_dim)
        h0_first = torch.zeros(state_shape).to(input.device)
        h0_second = torch.zeros(state_shape).to(input.device)

        first_out, _ = self.gru1(input, h0_first.detach())
        second_out, _ = self.gru2(first_out, h0_second.detach())

        # MLP head, then scale into [0, max_padding].
        hidden = torch.relu(self.fc1(second_out))
        hidden = torch.relu(self.fc2(hidden))
        scaled = torch.sigmoid(self.fc3(hidden)) * max_padding
        return scaled.to(torch.float32)


In [None]:
class TrivialLSTM(nn.Module):
    """Single-layer LSTM baseline over 2-feature sequences.

    Maps a ``(batch, seq, 2)`` input to a ``(batch, 2)`` prediction taken from
    the last time step; the output is not range-limited.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=1000, num_layers=1, batch_first=True)
        # Project hidden states back to 2 values, the same width as the input features.
        self.linear = nn.Linear(1000, 2)

    def forward(self, x):
        per_step = self.linear(self.lstm(x)[0])
        return per_step[:, -1, :]

In [None]:
class SimpleNN(nn.Module):
    """Four-layer ReLU MLP regressing 2 values from a flattened input.

    Args:
        in_features: number of values per sample after flattening every
            non-batch dimension. The default 810000 equals 9 * 300 * 300, i.e.
            ``sequence_length - 1`` = 9 frames of patch-only features padded to
            ``max_padding`` = 300 as configured elsewhere in this notebook. Pass a
            different value for other sequence lengths, paddings or feature modes.
    """

    def __init__(self, in_features=810000):
        super().__init__()
        # (batch, *dims) -> (batch, prod(dims)), e.g. (4, 9, 90000) -> (4, 810000).
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, 500)
        self.fc2 = nn.Linear(500, 1000)
        self.fc3 = nn.Linear(1000, 500)
        self.fc4 = nn.Linear(500, 2)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        return self.fc4(x)

In [None]:
# Training hyper-parameters.
input_size = 2
hidden_size = 2000
num_layers = 2
epochs = 1000
sequence_length = 10  # frames per sample; the model is fed the first sequence_length - 1

# Alternative: model = LSTM(input_size, hidden_size, num_layers)
model = SimpleNN()
model = model.cuda() if torch.cuda.is_available() else model

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

def train():
    """Run one training epoch over ``dataloaders['train']``.

    For each batch the model sees the first ``sequence_length - 1`` frames of the
    inputs and is scored with ``criterion`` against the last frame of the targets
    (the collated centroids). Prints the mean per-batch loss and returns the
    global ``model``.
    """
    model.train()
    loader = dataloaders['train']
    running_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        inputs = batch[0].to(device)
        targets = batch[1].to(device)

        history = inputs[:, :sequence_length - 1, :]
        loss = criterion(model(history), targets[:, -1, :])
        running_loss += loss.item()

        loss.backward()
        optimizer.step()
    print(f"training loss: {running_loss / len(loader)}")
    return model

def eval():
    """Mean per-batch ``criterion`` loss over ``dataloaders['eval']``, without gradients.

    Uses the same input/target slicing as ``train``. Prints and returns the mean
    loss (lower is better). Note: this name shadows the builtin ``eval``.
    """
    model.eval()
    loader = dataloaders['eval']
    running_loss = 0
    with torch.no_grad():
        for batch in loader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            pred = model(inputs[:, :sequence_length - 1, :])
            running_loss += criterion(pred, targets[:, -1, :]).item()
    mean_loss = running_loss / len(loader)
    print(f"validation loss: {mean_loss}")
    return mean_loss


def train_model():
    """Train for ``epochs`` epochs, then restore the weights with the lowest validation loss.

    Returns:
        The global ``model`` with the best (lowest ``eval()`` loss) weights loaded.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    # eval() returns a loss, so the best epoch is the one with the *minimum* value;
    # starting from +inf guarantees the first epoch is recorded.
    best_loss = float('inf')
    for epoch in range(epochs):
        print("Epoch:", epoch)
        train()
        curr_loss = eval()
        if curr_loss < best_loss:
            best_loss = curr_loss
            best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Run the full training loop; as the cell's last expression, the returned model is displayed.
train_model()