In [1]:
from aicspylibczi import CziFile
import czifile
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import cv2
import os
import imageio
import ffmpeg
import time
import pandas as pd
from cellpose import io, models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import json
import glob
import itertools
from PIL import Image
import torch.nn.functional as F
from utils import *
cudnn.benchmark = True
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x7ff46dc88ca0>

In [2]:

class VideoDataMIP:
    def __init__(self, files):
        self.data = {
        }
        
        for category, num in files:
            print(f"Loading in MIP {num}")
            assert category == 'mip', "Can't load non Mip file"
            file = {}
            file['video'] = get_file(category, num)
            
            frames, shp = file['video'].read_image(C=0)
            frames = scale_img(frames.squeeze())
            file['frames'] = frames
            print(f"frames {num}: {frames.shape}")
            file['masks'] = binarize_video(frames)           
    
            self.data[num] = file    
    def extract_all_traces(self, file_num, sequence_length, hist_length=2):
        # hist length is how many frames of history
        frames, masks = self.data[file_num]['frames'], self.data[file_num]['masks']
        N = len(frames)
        s = 0
        all_traces = []
        all_videos = []
        for i in range(N // sequence_length):
            print(f"Extracting traces from {s}:{s+sequence_length}")
            data, videos = extract_traces(frames[s:s+sequence_length], masks[s:s+sequence_length], hist=hist_length)
            s += sequence_length
            all_traces = all_traces + data
            all_videos = all_videos + videos
        
        if(N % sequence_length > 0):
            data, videos = extract_traces(frames[-1*sequence_length:], masks[-1*sequence_length:], hist=hist_length)
            all_traces = all_traces + data
            all_videos = all_videos + videos        
        self.data[file_num]['traces'] = all_traces
        self.data[file_num]['trace_videos'] = all_videos
            

In [3]:
from skimage.measure import centroid
import skimage.measure as skm



box_shape = (180, 180) #TO DO: find the biggest box and set it to this
X = 10

class CellBoxMaskPatch(torch.utils.data.Dataset):
    #input will be a Directory name, function is TO DO
    def __init__(
        self,
        files, 
        X=X):
        
        self.video_extractor = VideoDataMIP(files)

        for i in files:
            self.video_extractor.extract_all_traces(i[1], X)
        
        
        self.cell_dict = []

        for key in self.video_extractor.data:
            entry = self.video_extractor.data[key]["traces"]
            for cell in entry:
                patches = [np.array(p) for p in cell["patches"]]
                boxes = [np.array(b) for b in cell['boxes']]
                masks = [np.array(m) for m in cell['masks']]
                
                self.cell_dict.append((boxes, masks, patches)) #cell dict is a list of 3 types by sequence

        self.num_cells = len(self.cell_dict) #this is a list of how many sequences we have
              
    def __len__(self):
        return self.num_cells
        

    def get_centroids(self, boxes, masks):
        N = len(masks)
        res = []
        centroids = [skm.centroid(binary.astype(np.uint8)) for binary in masks]
        for i in range(N):
            c = centroids[i]
            ymin, xmin = boxes[i][:2]
            res.append([xmin+c[0], ymin+c[1]])
        return(np.array(res) - res[0]) 
   
    def pad_arrays(self, array, pad_amt=200):
    
        pad_width = ((0, pad_amt - array.shape[0]), (0, pad_amt - array.shape[1]))

        padded_array = np.pad(array, pad_width, mode='constant')
        return padded_array

    def __getitem__(self, idx):
        cell_sequences = self.cell_dict[idx]  #this is the first sequence of 10 cells
        boxes = cell_sequences[0]
        masks = cell_sequences[1]
        patches = cell_sequences[2]


        for cell_mask_num in np.arange(len(masks)): #should be sequence length (10) masks
                
                cell_time = np.array(masks[cell_mask_num], dtype=np.int32)
                cell_time = np.where(cell_time >= 0, cell_time, 1)
                cell_time = self.pad_arrays(cell_time)
                masks[cell_mask_num] = cell_time
                cell_time_patch = np.array(patches[cell_mask_num], dtype=np.int32)

                cell_time_patch = self.pad_arrays(cell_time_patch)

                patches[cell_mask_num] = cell_time_patch


        centroids = self.get_centroids(boxes, masks)
    

        return centroids, masks, patches

In [4]:
from torch.utils.data import random_split

mip_video_files = [
    ('mip', 3)
]

dataset = CellBoxMaskPatch(mip_video_files, 10) # file, S, T

train, eval, test = random_split(dataset, [0.4, 0.2, 0.4])

input_datasets = {}
input_datasets["train"] = train
input_datasets["eval"] = eval
input_datasets["test"] = test

Loading in MIP 3
Loading dicty_factin_pip3-03_MIP.czi with dims [{'X': (0, 474), 'Y': (0, 2048), 'C': (0, 2), 'T': (0, 90)}]
frames 3: (90, 2048, 474)
Extracting traces from 0:10
Extracting cell  0
Extracting cell  1
Extracting cell  2
Extracting cell  3
Extracting cell  4
Extracting cell  5
Extracting cell  6
Extracting traces from 10:20
Extracting cell  0
Extracting cell  1
Extracting cell  2
Extracting cell  3
Extracting cell  4
Extracting cell  5
Extracting cell  6
Extracting cell  7
Extracting traces from 20:30
Extracting cell  0
Extracting cell  1
Extracting cell  2
Extracting cell  3
Extracting cell  4
Extracting cell  5
Extracting cell  6
Extracting cell  7
Extracting cell  8
Extracting traces from 30:40
Extracting cell  0
Extracting cell  1
Extracting cell  2
Extracting cell  3
Extracting cell  4
Extracting cell  5
Extracting cell  6
Extracting cell  7
Extracting traces from 40:50
Extracting cell  0
Extracting cell  1
Extracting cell  2
Extracting cell  3
Extracting cell  4
Ex

In [5]:
def collate_fn(batch, mode_box, mode_mask, mode_patch):
    

    # centroids = torch.stack([torch.tensor(b[0], dtype=torch.int) for b in batch], dim=0)
    # print(centroids.shape)
    # current_masks = torch.stack([torch.tensor(b[1], dtype=torch.long) for b in batch], dim=0)
    # current_masks = current_masks.reshape([4, 10, 40000])

    current_centroids = [b[0] for b in batch]
    current_masks = [b[1] for b in batch]
    current_patches = [b[2] for b in batch]

    current_centroids = torch.tensor(np.stack(current_centroids), dtype=torch.float32)
    current_masks = torch.tensor(np.stack(current_masks), dtype=torch.long)
    current_patches = torch.tensor(np.stack(current_patches), dtype=torch.long)

    
    current_patches = current_patches.reshape([4, 10, 40000])
    current_masks = current_masks.reshape([4, 10, 40000])

    print(current_masks.shape)
    print(current_patches.shape)

    selected_tensors = []
    if mode_box:
        selected_tensors.append(current_centroids)
    if mode_mask:
        selected_tensors.append(current_masks)
    if mode_patch:
        selected_tensors.append(current_patches)

    # Concatenate selected tensors along the last dimension
    combined_tensor = torch.cat(selected_tensors, dim=-1)

    # Reshape to add the singleton dimension
    combined_tensor = combined_tensor #.unsqueeze(1)

    return combined_tensor, current_centroids
    


In [6]:

mode_box = True
mode_mask = True
mode_patch = False

dataloaders = {}
dataloaders['train'] = torch.utils.data.DataLoader(
    input_datasets['train'],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=lambda batch: collate_fn(batch, mode_box, mode_mask, mode_patch)
)

dataloaders['test'] = torch.utils.data.DataLoader(
    input_datasets['test'],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=lambda batch: collate_fn(batch, mode_box, mode_mask, mode_patch)
)

dataloaders['eval'] = torch.utils.data.DataLoader(
    input_datasets['eval'],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=lambda batch: collate_fn(batch, mode_box, mode_mask, mode_patch)
)

In [7]:
for batch in dataloaders['train']:

    print("Input:", batch[0].shape, "Centroids", batch[1].shape)
    
    break  

torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
Input: torch.Size([4, 10, 40002]) Centroids torch.Size([4, 10, 2])


In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_dim, num_layers=2):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_dim, num_layers=2, batch_first=True) #stacking 2 LSTMs
        # The linear layer that maps from hidden state space to output space
        # Currently set equal to 2 for the x & y dimension of the centroid
        self.fc = nn.Linear(hidden_dim, 2)
    def forward(self, input):
        #h_0: tensor of shape (D∗num_layers,N,Hout​) containing the initial hidden
        # state for each element in the input sequence. Defaults to zeros if (h_0, c_0) is not provided.
        h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_dim).to(input.device)
        c0 = torch.zeros(self.num_layers, input.size(0), self.hidden_dim).to(input.device)

        # We need to detach as we are doing truncated backpropagation through time (BPTT)
        # If we don't, we'll backprop all the way to the start even after going through another batch
        out, _ = self.lstm(input, (h0.detach(), c0.detach()))

        # Predict coordinate + bound to 0-200.
        out = self.fc(out)
        final = out[:,-1,:]
        out = torch.sigmoid(out) * 200
        return out

In [43]:
input_size = 40002 
hidden_size = 100
num_layers = 2
epochs = 1
sequence_length = 10 #how many frames we process per input
batch_size = 4

model = LSTM(input_size, hidden_size, num_layers)
if torch.cuda.is_available():
    model = model.cuda()
dummy_input_data = torch.randn(batch_size, 10, input_size)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    model.train()
    total_loss = 0
    total_correct = 0
    for batch in dataloaders['train']:
        optimizer.zero_grad()
        inputs, outputs = batch[0], batch[1]
        inputs, outputs = inputs.to(device), outputs.to(device)
        pred = model(inputs[:, :sequence_length-1, :])
        # print(f"pred: {pred}")
        # print(f"outputs: {outputs.data[:,-1,:]}")
        total_correct += torch.sum(pred[:,-1,:] == outputs.data[:,-1,:])
        loss = criterion(pred[:,-1,:], outputs[:,-1,:])
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    print(f"training loss: {total_loss / len(dataloaders['train'])}, training accuracy: {total_correct / len(dataloaders['train'])}")
    return model

def eval():
    model.eval()
    total_loss = 0
    total_correct = 0
    with torch.no_grad():
        for batch in dataloaders['dev']:
            inputs, outputs = batch[0], batch[1]
            inputs, outputs = inputs.to(device), outputs.to(device)
            pred = model(inputs[:, :sequence_length-1, :])
            total_correct += torch.sum(pred[:,-1,:] == outputs.data[:,-1,:])
            loss = criterion(pred[:,-1,:], outputs[:,-1,:])
            total_loss += loss.item()
    print(f"validation loss: {total_loss / len(dataloaders['train'])}, validation accuracy: {total_correct / len(dataloaders['train'])}")
    return total_correct / len(dataloaders['train'])

def train_model():
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(epochs):
        train()
        curr_acc = eval()
        if curr_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model

In [44]:
train_model()

torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])
torch.Size([4, 10, 40000])


RuntimeError: shape '[4, 10, 40000]' is invalid for input of size 1200000