In this notebook we use classic CV features in combination with Deep Learning to predict cell movements. 

In [2]:
import mahotas
import numpy as np
import torch 
import torch.nn as nn
from utils import *
from skimage.measure import centroid
import skimage.measure as skm
from torch.utils.data import random_split
import copy


In [3]:
def shape_features(binary, feature_length=20, num_samples=180):
    def radial_distance(binary, theta):
        height, width = binary.shape
        center = [width // 2, height // 2]
        def test_r(r):
            x_test, y_test = center[0] + r*np.cos(theta), center[1] + r*np.sin(theta)
            if(x_test >= width or y_test > height or x_test < 0 or y_test < 0):
                return(False)
            return(binary[int(y_test), int(x_test)])
        # calculate distance to the nearest pixel
        r = max(height, width)
        while(not test_r(r)): # start from edge come inside until hit cell
            r -= 1
        return(r)

    test_angles = np.linspace(0, 2*np.pi, num_samples)
    distances = np.array([radial_distance(binary, angle) for angle in test_angles])
    fft_coefficients = np.fft.rfft(distances)

    features = np.abs(fft_coefficients[:feature_length])
    features = features / np.sum(features)
    return(features, (distances, fft_coefficients))

In [4]:
def featurize(patches, masks):
    image, binary = patches, masks.astype(np.uint8)
    zernike = mahotas.features.zernike_moments(binary, max(binary.shape)/2, degree=8)
    #zernike = zernike / zernike.sum()
    haralick = mahotas.features.haralick(image.astype(np.uint16)).mean(axis=0)
    #haralick = haralick / haralick.sum()
    shape, info = shape_features(binary, 20)
    #print(f"Zernike: {zernike.shape}, Haralick: {haralick.shape}, Radial Shape: {shape.shape}")
    return(np.concatenate([zernike, haralick, shape]))

In [5]:
def get_centroids(boxes, masks):
        N = len(masks)
        res = []
        centroids = [skm.centroid(binary.astype(np.uint8)) for binary in masks]
        for i in range(N):
            c = centroids[i]
            ymin, xmin = boxes[i][:2]
            res.append([xmin+c[0], ymin+c[1]])
        return(np.array(res) - res[0]) 

In [6]:
def get_velocities(centroids):
    vels = [0] * len(centroids)
    vels[0] = np.array([0,0])
    for i in range(1,len(centroids)):
        vels[i] = centroids[i] - centroids[i-1]
    return vels

In [7]:
max_padding =  300
box_shape = (180, 180)
X = 10

class VelocitiesClassicDataset(torch.utils.data.Dataset):
    #input will be a Directory name, function is TO DO
    def __init__(self,files, X=X):
        self.video_extractor = VideoDataMIP(files)
        self.cell_dict = []

        for i in files:
            self.video_extractor.extract_all_traces(i[1], X)

        for key in self.video_extractor.data:
            entry = self.video_extractor.data[key]["traces"]
            for cell in entry:
                features = [featurize(p,m) for p,m in zip(cell["patches"], cell["masks"])]
                velocities = [np.array(p) for p in get_velocities(get_centroids(cell["boxes"], cell["masks"]))]
                self.cell_dict.append((features, velocities)) #cell dict is a list of 3 types by sequence


    def __len__(self):
        return len(self.cell_dict)

    def __getitem__(self, idx):
        # tuple of features, velocities
        return  self.cell_dict[idx]

In [8]:
mip_video_files = [
    ('mip', 3),
    ('mip', 6),
    ('mip', 9)
]
dataset = VelocitiesClassicDataset(mip_video_files, X) # file, S, T

Loading in MIP 3
Loading dicty_factin_pip3-03_MIP.czi with dims [{'X': (0, 474), 'Y': (0, 2048), 'C': (0, 2), 'T': (0, 90)}]
frames 3: (90, 2048, 474)
Loading in MIP 6
Loading dicty_factin_pip3-06_MIP.czi with dims [{'X': (0, 474), 'Y': (0, 2048), 'C': (0, 2), 'T': (0, 241)}]
frames 6: (241, 2048, 474)
Loading in MIP 9
Loading dicty_factin_pip3-09_MIP.czi with dims [{'X': (0, 474), 'Y': (0, 2048), 'C': (0, 2), 'T': (0, 241)}]
frames 9: (241, 2048, 474)
Extracting traces from 0:10
Extracting traces from 10:20
Extracting traces from 20:30
Extracting traces from 30:40
Extracting traces from 40:50
Extracting traces from 50:60
Extracting traces from 60:70
Extracting traces from 70:80
Extracting traces from 80:90
Extracting traces from 0:10
Extracting traces from 10:20
Extracting traces from 20:30
Extracting traces from 30:40
Extracting traces from 40:50
Extracting traces from 50:60
Extracting traces from 60:70
Extracting traces from 70:80
Extracting traces from 80:90
Extracting traces from 

In [18]:
train, eval, test = random_split(dataset, [0.7, 0.2, 0.1])

input_datasets = {}
input_datasets["train"] = train
input_datasets["eval"] = eval
input_datasets["test"] = test

In [19]:
def collate_fn(batch):
    current_features = [b[0] for b in batch]
    current_offset = [b[1] for b in batch] 
    return torch.tensor(np.stack(current_features)).to(torch.float32), torch.tensor(np.stack(current_offset)).to(torch.float32)


In [20]:
dataloaders = {}
dataloaders['train'] = torch.utils.data.DataLoader(
    input_datasets['train'],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

dataloaders['test'] = torch.utils.data.DataLoader(
    input_datasets['test'],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

dataloaders['eval'] = torch.utils.data.DataLoader(
    input_datasets['eval'],
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

In [21]:
for batch in dataloaders['eval']:
    print("Input:", batch[0].shape, "Velocities", batch[1].shape)
    

Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10, 58]) Velocities torch.Size([4, 10, 2])
Input: torch.Size([4, 10,

In [22]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        #Flatten the input from (4, 9, 2) to (4, 18)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(58, 2) 
        # self.fc2 = nn.Linear(32,16)
        # self.fc3 = nn.Linear(16,8)torch loss functions
        # self.fc4 = nn.Linear(8, 2)

    def forward(self, x):
        x = self.flatten(x) 
        # x = torch.relu(self.fc1(x)) 
        # x = torch.relu(self.fc2(x))
        # x = torch.relu(self.fc3(x))
        x = self.fc1(x)  
        return x

In [37]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_dim, num_layers=2):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_dim, num_layers, batch_first=True) #stacking 2 LSTMs
        # hidden out output
        #  2 bc x y centroid
        self.fc = nn.Linear(hidden_dim, 2)
    def forward(self, input):
        h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_dim).to(input.device)
        c0 = torch.zeros(self.num_layers, input.size(0), self.hidden_dim).to(input.device)
        out, _ = self.lstm(input, (h0, c0))

        out = self.fc(out)
        final = out[:,-1,:]
        out = torch.sigmoid(final) * max_padding
        return out

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
epochs = 1000
sequence_length = 10 #how many frames we process per input

# model = LSTM(input_size, hidden_size, num_layers)
model = LSTM(input_size=58, hidden_dim = 58)
if torch.cuda.is_available():
    model = model.cuda()
# dummy_input_data = torch.randn(batch_size, 10, input_size)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    model.train()
    total_loss = 0
    total_correct = 0
    for batch in dataloaders['train']:
        optimizer.zero_grad()
        inputs, outputs = batch[0], batch[1]
        inputs, outputs = inputs.to(device), outputs.to(device)
        inputs, outputs = inputs.to(device), outputs.to(device)
        pred = model(inputs[:, :sequence_length-1, :])
        # print(inputs[:, sequence_length-1:sequence_length, :].shape)
        # print(pred.shape, outputs[:,-1,:].shape)
        # print(f"pred: {pred}")
        # print(f"outputs: {outputs.data[:,-1,:]}")
        # total_correct += torch.sum(torch.eq(pred, outputs[:,-1,:]))
        loss = criterion(pred, outputs[:,-1,:])
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    # print(f"training loss: {total_loss / len(dataloaders['train'])}, training accuracy: {total_correct / len(dataloaders['eval'])}")
    print(f"training loss: {total_loss / len(dataloaders['train'])}")
    return model

def eval():
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloaders['eval']:
            inputs, outputs = batch[0], batch[1]
            inputs, outputs = inputs.to(device), outputs.to(device)
            pred = model(inputs[:, :sequence_length-1, :])     
            loss = criterion(pred, outputs[:,-1,:])
            total_loss += loss.item()
    print(f"validation loss: {total_loss / len(dataloaders['eval'])}")
    return total_loss / len(dataloaders['eval'])


def train_model():
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(epochs):
        print("Epoch:", epoch)
        train()
        curr_acc = eval()
        if curr_acc > best_acc:
            best_acc = curr_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    return model

In [43]:
best_model = train_model()

Epoch: 0
training loss: 2663.994289452689
validation loss: 841.1459820866585
Epoch: 1
training loss: 672.3464625494821
validation loss: 818.6365978240967
Epoch: 2
training loss: 666.533123588562
validation loss: 810.9657469511033
Epoch: 3
training loss: 665.2561683041708
validation loss: 805.1799789190293
Epoch: 4
training loss: 662.3853049210139
validation loss: 803.1874762773514
Epoch: 5
training loss: 661.9053534439632
validation loss: 801.6749158978462
Epoch: 6
training loss: 661.5863018700055
validation loss: 801.2095357179642
Epoch: 7
training loss: 661.3778661676815
validation loss: 800.1282931089402
Epoch: 8
training loss: 661.2180553981236
validation loss: 798.7664747834206
Epoch: 9
training loss: 661.5320546797344
validation loss: 817.8248416304589
Epoch: 10
training loss: 660.8920964547566
validation loss: 797.0986922502518
Epoch: 11
training loss: 660.779383967604
validation loss: 796.4595306634903
Epoch: 12
training loss: 661.6542941434043
validation loss: 972.401052856445

KeyboardInterrupt: 