In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
from pathlib import Path
import json
import re

import torch
from torch import Tensor
import torch.optim as optim
from torch.utils.data import Dataset
from torch import nn
from torch import functional as F

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from matplotlib import image as mpl_image
from typing import Type, Any, Callable, Union, List, Optional

from hmmlearn import hmm


In [3]:
from centrilyze import CentrioleImageFiles, ImageDataset, CentrioleImageModel, HMM, constants, image_transform, target_transform, annotate, nest_annotation_keys, get_sequence_matrix, get_transition_count_matrix, get_transition_rate_matrix, get_confusion_matrix, reannotate, save_all_state_figs

# Settings

In [4]:
# All inputs and outputs live under a single root directory, so relocating the
# project only requires editing this one path.
nic_root = Path("/nic")

# Inputs: image folder and trained network weights.
test_folder = nic_root / "data" / "high_low" / "train"
model_file = nic_root / "models" / "model_resnet_18_high_low_affine_149.pyt"

# Auxiliary data files.
annotations_file = nic_root / "annotations.json"
sequences_file = nic_root / "sequences.npy"
emission_matrix_path = nic_root / "emission_matrix.npy"
emission_matrix_path_three_classes = nic_root / "emission_matrix_three_classes.npy"

# Destination for figures written by the notebook.
output_dir = nic_root / "output"

# Upper bound on HMM fitting iterations (passed as n_iter to hmmlearn below).
n_iter = 1000
# Number of images per DataLoader batch.
batch_size = 4

# Data Loading

In [5]:
# Build the image-file index from `test_folder` using the package's
# unannotated-images constructor (no labels are passed in here).
centriole_image_files = CentrioleImageFiles.from_unannotated_images(test_folder)

In [6]:
# centriole_image_files.images

In [None]:
# Wrap the indexed files in a dataset, handing over the package's image and
# target transforms. `testset` feeds the DataLoader and the figure export
# further down.
testset = ImageDataset.from_centriole_image_files(
    centriole_image_files, 
    image_transform, 
    target_transform,
)

In [None]:
# Batch the dataset for the model. shuffle=False keeps samples in dataset
# order, and drop_last=False keeps the final (possibly smaller) batch so that
# no image is skipped.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, 
                                         drop_last=False
                                        )

# Define the model

In [None]:
# NOTE(review): this instance is not used by any visible cell — the trained
# weights are loaded into `image_model`, and the name `model` is rebound to an
# HMM in the "Hidden Markov model" section. Candidate for removal; confirm.
model = CentrioleImageModel()

# Load the trained model params

In [None]:
# Fresh model instance; the trained parameters from `model_file` are loaded
# into it in the following cell.
image_model = CentrioleImageModel()


In [None]:
# Load the trained parameters, mapping them onto the CPU.
# NOTE(review): torch.nn.Module.load_state_dict expects a state-dict mapping
# and has no `map_location` argument, so this call relies on
# CentrioleImageModel providing its own load_state_dict(path, map_location=...)
# — confirm. With a plain nn.Module the file would first have to be read with
# torch.load(model_file, map_location=...).
image_model.load_state_dict(model_file, map_location=torch.device("cpu"))


# Annotate the data with the model

In [None]:
# Annotate the loaded images with the trained model (presumably one prediction
# per image in `testloader`; the structure of the result is defined by
# centrilyze.annotate and is not visible here).
annotations = annotate(image_model, testloader)

In [None]:
# Re-key the flat annotations into a nested mapping. The reannotated version
# of this mapping is iterated below as experiment -> particle -> frame, so the
# nesting presumably follows that order — TODO confirm.
nested_annotations = nest_annotation_keys(annotations)

# Output the confusion matrix

In [None]:
# Tabulate the confusion matrix from the flat annotations.
# NOTE(review): the images were loaded with from_unannotated_images, so confirm
# what the model predictions are being compared against here.
confusion_matrix_test_table = get_confusion_matrix(annotations)

In [None]:
# Display the confusion matrix.
confusion_matrix_test_table

# Get the Sequences

In [None]:
# Arrange the nested annotations into a matrix of label sequences.
# NOTE(review): 20 looks like the number of frames per particle and 7 the
# number of classes (it matches n_components of the first HMM below) — confirm
# against get_sequence_matrix and consider naming both as constants.
sequence_matrix = get_sequence_matrix(nested_annotations, list(range(20)), 7)

In [None]:
# Sanity check: show the dimensions of the sequence matrix.
sequence_matrix.shape

# All classes: Get the naive transition counts and rates

In [None]:
# Count transitions in the label sequences (presumably between consecutive
# frames — TODO confirm); 7 is the same class count passed to
# get_sequence_matrix above.
transition_count_matrix = get_transition_count_matrix(sequence_matrix, 7)

In [None]:
# Print numpy arrays with 3 decimal places and without scientific notation.
# This is a global setting and affects every array printed afterwards.
np.set_printoptions(precision=3, suppress=True)


In [None]:
# Label the raw transition counts with the class names on both axes.
class_names = list(constants.classes.keys())
naive_transition_table = pd.DataFrame(
    transition_count_matrix,
    index=class_names,
    columns=class_names,
)
naive_transition_table

In [None]:
# Convert the transition counts into rates (the normalisation used is defined
# by get_transition_rate_matrix and is not visible here).
transition_rate_matrix = get_transition_rate_matrix(transition_count_matrix)

In [None]:
# Class name -> integer code, generated in code order.
# NOTE(review): this looks like a local copy of constants.classes, which the
# count table above uses directly — confirm the two agree.
classes = {
    name: code
    for code, name in enumerate(
        (
            "Not_Oriented",
            "Oriented",
            "Precieved_Not_Oriented",
            "Precieved_Oriented",
            "Slanted",
            "Unidentified",
            "No_sample",
        )
    )
}

# Label the transition rates with the class names on both axes.
class_names = list(classes)
naive_transition_table = pd.DataFrame(
    transition_rate_matrix,
    index=class_names,
    columns=class_names,
)
naive_transition_table

# Three classes: Get the counts and transition rates

In [None]:
# Map the annotations onto the reduced label set using
# constants.annotation_mapping. The result is iterated below as
# experiment -> particle -> frame, each frame carrying an "assigned" label.
reannotations = reannotate(nested_annotations, constants.annotation_mapping)

In [None]:
# Same sequence construction as for the full label set, on the reduced labels.
# NOTE(review): 5 is presumably the number of reduced label codes (0-4 are
# used in the counting cells below) — confirm.
sequence_3_classes_matrix = get_sequence_matrix(reannotations, list(range(20)), 5)

In [None]:
# Display the reduced-label sequence matrix.
# NOTE(review): this shows the whole array; `.shape` or a slice may be enough.
sequence_3_classes_matrix

In [None]:
# Transition counts for the reduced label set; 5 is the same class count
# passed to get_sequence_matrix for these sequences.
transition_3_classes_count_matrix = get_transition_count_matrix(sequence_3_classes_matrix, 5)

In [None]:
# Label the reduced-class transition counts with the reduced class names.
reduced_class_names = list(constants.classes_reduced.keys())
naive_transition_3_classes_table = pd.DataFrame(
    transition_3_classes_count_matrix,
    index=reduced_class_names,
    columns=reduced_class_names,
)
naive_transition_3_classes_table

In [None]:
# Convert the reduced-class transition counts into rates.
transition_3_classes_rate_matrix = get_transition_rate_matrix(transition_3_classes_count_matrix)

In [None]:

# Label the reduced-class transition rates with the reduced class names.
# Rebinds `naive_transition_table`, which held the all-classes table before.
reduced_class_names = list(constants.classes_reduced.keys())
naive_transition_table = pd.DataFrame(
    transition_3_classes_rate_matrix,
    index=reduced_class_names,
    columns=reduced_class_names,
)
naive_transition_table

# Hidden Markov model

## Load the emission matrix

In [None]:
# Read the precomputed emission matrix for the full label set from disk.
emission_matrix_np = np.load(emission_matrix_path)

## Define the model

In [None]:
# HMM with 7 hidden states. params="st" restricts fitting to the start
# probabilities and the transition matrix, and init_params="st" initialises
# only those two, so the emission matrix assigned in the following cell stays
# fixed during fit().
# NOTE(review): recent hmmlearn releases changed MultinomialHMM and provide
# CategoricalHMM for integer-coded symbol sequences — confirm which one matches
# the installed version. No random_state is set; confirm whether the fit needs
# to be reproducible.
model = hmm.MultinomialHMM(n_components=7, n_iter=n_iter, params="st", init_params="st")

In [None]:
# Fix the emission probabilities to the loaded matrix; they are excluded from
# both initialisation and training by the "st" settings on the model.
model.emissionprob_ = emission_matrix_np

## Fit the model

In [None]:
# hmmlearn takes all sequences stacked into one column of observations plus a
# list of per-sequence lengths. `sequence_matrix` is 2-D (one row per
# sequence), so every length equals the number of columns.
n_sequences, sequence_length = sequence_matrix.shape
model.fit(
    sequence_matrix.reshape(-1, 1),
    [sequence_length] * n_sequences,
)

## Format the model output

In [None]:

# Label the fitted HMM transition matrix with the class names on both axes.
class_names = list(constants.classes.keys())
transition_table = pd.DataFrame(
    model.transmat_,
    index=class_names,
    columns=class_names,
)


In [None]:
# Display the fitted transition matrix for the full label set.
transition_table

# Hidden Markov model: Three Classes

## Load the emission matrix

In [None]:
# Read the emission matrix for the reduced label set.
# Rebinds `emission_matrix_np`; the full-label matrix loaded earlier is no
# longer reachable under this name.
emission_matrix_np = np.load(emission_matrix_path_three_classes)

## Define the model

In [None]:
# HMM for the reduced label set, configured like the 7-state model: only the
# start probabilities and transition matrix are initialised and trained.
# Rebinds `model`, replacing the 7-state HMM.
# NOTE(review): 5 states are used although the section is titled "Three
# Classes" — presumably the reduced set has 5 codes (0-4); confirm.
model = hmm.MultinomialHMM(n_components=5, n_iter=n_iter, params="st", init_params="st")

In [None]:
# Fix the emission probabilities of the reduced-label HMM to the loaded matrix.
model.emissionprob_ = emission_matrix_np

## Fit the model

In [None]:
# Stack the reduced-label sequences into one observation column and pass one
# length per row of the 2-D sequence matrix, as hmmlearn expects.
n_sequences_3_classes, sequence_length_3_classes = sequence_3_classes_matrix.shape
model.fit(
    sequence_3_classes_matrix.reshape(-1, 1),
    [sequence_length_3_classes] * n_sequences_3_classes,
)

## Format the model output

In [None]:

# Label the fitted reduced-label transition matrix with the reduced class
# names. Rebinds `transition_table` from the full-label section.
reduced_class_names = list(constants.classes_reduced.keys())
transition_table = pd.DataFrame(
    model.transmat_,
    index=reduced_class_names,
    columns=reduced_class_names,
)


In [None]:
# Display the fitted transition matrix for the reduced label set.
transition_table

# Changing sequences

In [None]:
# Group particles by the combination of states they visit.
# `states` maps the sorted tuple of distinct "assigned" labels seen across a
# particle's frames to the set of (experiment, particle) pairs that show
# exactly that combination.
states = {}
for experiment, particles in reannotations.items():
    for particle, frames in particles.items():
        assigned_labels = [annotation["assigned"] for annotation in frames.values()]
        if not assigned_labels:
            # A particle without frames visits no state. Skip it rather than
            # filing it under the previous particle's labels.
            continue
        # np.unique returns the distinct labels in sorted order, so the key
        # does not depend on the order in which the states were visited.
        state_key = tuple(np.unique(np.array(assigned_labels)))
        states.setdefault(state_key, set()).add((experiment, particle))

In [None]:
# Export figures for the grouped particles, passing the dataset, the output
# directory and the reduced annotations (what is drawn and the file layout are
# defined by centrilyze.save_all_state_figs and are not visible here).
save_all_state_figs(states, testset, output_dir, reannotations,)


In [None]:
# List the top-level keys of the reduced annotations (the experiments, going
# by how the mapping is iterated elsewhere in this notebook).
reannotations.keys()

In [None]:
# Re-index the reduced annotations from experiment -> particle -> frame to
# experiment -> frame -> particle. Each experiment starts with empty slots for
# frames 0..19; a frame number outside that range raises KeyError.
annotations_by_frame = {}
for experiment, particles in reannotations.items():
    per_frame = annotations_by_frame[experiment] = {
        frame_index: {} for frame_index in range(20)
    }
    for particle, frames in particles.items():
        for frame_number, frame in frames.items():
            per_frame[frame_number][particle] = frame

In [None]:
# Inspect the entry stored under frame key 19.
# NOTE(review): `frames` is not defined in this cell; it is whatever a loop
# variable of an earlier cell was last bound to (hidden kernel state), so the
# particle shown depends on execution order — confirm this debugging cell is
# still wanted.
frames[19]

In [None]:
# Print, for each experiment and frame, how many particles carry each label.
# Label codes as compared below: 1 = oriented, 0 = unoriented, 2 = slanted,
# 3 = unidentified. The printed total covers these four labels only.
for experiment, frames in annotations_by_frame.items():
    print(experiment)
    for frame, particles in frames.items():
        print(f"\tFrame: {frame}")

        # Collect the labels once, then count each code.
        assigned = [annotation['assigned'] for annotation in particles.values()]
        num_oriented = assigned.count(1)
        num_unoriented = assigned.count(0)
        num_slanted = assigned.count(2)
        num_unidentified = assigned.count(3)

        num_particles = num_slanted + num_unoriented + num_oriented + num_unidentified
        print(f"\t\tTotal: {num_particles}")
        print(f"\t\t\tOriented: {num_oriented}")
        print(f"\t\t\tUnoriented: {num_unoriented}")
        print(f"\t\t\tSlanted: {num_slanted}")
        print(f"\t\t\tUnidentified: {num_unidentified}")

        

In [None]:
# Plot, for every experiment, the fraction of particles in each state per frame.
# Legend label -> label code, as compared against annotation['assigned'].
state_codes = {
    "Oriented": 1,
    "Unoriented": 0,
    "slanted": 2,
    "Unidentified": 3,
    "Missing": 4,
}

n_experiments = len(annotations_by_frame)
# squeeze=False makes plt.subplots return a 2-D array even for one experiment
# (otherwise a bare Axes comes back and indexing it fails); the single column
# is then taken so `axs` is always a 1-D array of Axes.
fig, axs_grid = plt.subplots(
    nrows=n_experiments,
    figsize=(25, 5 * n_experiments),
    squeeze=False,
)
axs = axs_grid[:, 0]

for ax, (experiment, frames) in zip(axs, annotations_by_frame.items()):
    fractions = {label: [] for label in state_codes}
    for frame, particles in frames.items():
        assigned = [annotation['assigned'] for annotation in particles.values()]
        counts = {label: assigned.count(code) for label, code in state_codes.items()}
        total = sum(counts.values())
        for label, count in counts.items():
            # A frame without particles has no defined fraction; NaN leaves a
            # gap in the line instead of raising ZeroDivisionError.
            fractions[label].append(count / total if total else float("nan"))

    ax.set_title(f"Experiment: {experiment}")
    for label, series in fractions.items():
        ax.plot(series, label=label)
    ax.set_xlabel("Frame")
    ax.set_ylabel("Fraction of particles")
    ax.legend()

In [None]:
# Plot, for every experiment, the ratio of oriented (label 1) to unoriented
# (label 0) particles per frame.
n_experiments = len(annotations_by_frame)
# squeeze=False makes plt.subplots return a 2-D array even for one experiment
# (otherwise a bare Axes comes back and indexing it fails); the single column
# is then taken so `axs` is always a 1-D array of Axes.
fig, axs_grid = plt.subplots(
    nrows=n_experiments,
    figsize=(25, 5 * n_experiments),
    squeeze=False,
)
axs = axs_grid[:, 0]

for ax, (experiment, frames) in zip(axs, annotations_by_frame.items()):
    num_fractions = []
    for frame, particles in frames.items():
        assigned = [annotation['assigned'] for annotation in particles.values()]
        num_oriented = assigned.count(1)
        num_unoriented = assigned.count(0)
        # The ratio is undefined when no particle is unoriented (including
        # frames with no particles at all); NaN leaves a gap in the line
        # instead of raising ZeroDivisionError.
        num_fractions.append(
            num_oriented / num_unoriented if num_unoriented else float("nan")
        )

    ax.set_title(f"Experiment: {experiment}")
    ax.plot(num_fractions, label="Oriented/unoriented")
    ax.set_xlabel("Frame")
    ax.set_ylabel("Oriented / unoriented")
    ax.legend()