# Functions

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf

import sys
path_to_software_folder = sys.path[0][:-10] + 'software/'
sys.path.append(path_to_software_folder)
from utils import *
from tqdm import tqdm
from tqdm import tqdm
from pathlib import Path

In [None]:
def FixPredictions(prediction, mode='single'):
    '''
    Enforce physical limits on the prediction.

    Parameters
    ----------
    prediction : array-like
        mode='single': indexable where entries 0, 1, 2 are alpha, K and M
        (any further entries are ignored).
        mode='batch': 2-D array where columns 0, 1, 2 are alpha, K and M.
    mode : {'single', 'batch'}
        Layout of `prediction`.

    Returns
    -------
    numpy.ndarray
        [alpha, K, M] (shape (3,)) in 'single' mode, or an (N, 3) array in
        'batch' mode, with alpha limited to [0, 1.999], K to [1e-12, 1e6] and
        M rounded and limited to [0, 3]. M is forced to 3 wherever the
        limited alpha exceeds 1.9.

    Raises
    ------
    ValueError
        If `mode` is neither 'single' nor 'batch'.
    '''
    if mode == 'single':
        alphas, Ks, Ms = prediction[0], prediction[1], prediction[2]
    elif mode == 'batch':
        alphas, Ks, Ms = prediction[:,0], prediction[:,1], prediction[:,2]
    else:
        # previously an unknown mode only failed later with an unbound-local error
        raise ValueError(f"mode must be 'single' or 'batch', got {mode!r}")

    Ms = np.round(Ms)

    # np.where (rather than np.clip) is kept on purpose: a NaN fails the first
    # comparison and is therefore replaced by the lower limit.
    alphas = np.where(alphas > 0, alphas, 0)
    alphas = np.where(alphas < 1.999, alphas, 1.999)

    Ks = np.where(Ks > 1e-12, Ks, 1e-12)
    Ks = np.where(Ks < 1e6, Ks, 1e6)

    Ms = np.where(Ms > 0, Ms, 0)
    Ms = np.where(Ms < 3, Ms, 3)
    Ms = np.where(alphas > 1.9, 3, Ms)    # if alpha is over 1.9, M must be 3

    if mode == 'single':
        return np.array([alphas, Ks, Ms])
    return np.stack([alphas, Ks, Ms], axis=1)

In [None]:
from sklearn import mixture

def AnalyseEnsembleProperties(Data):
    '''
    Fit Gaussian mixtures with 1 to 4 components and pick the best accepted one.

    Data is first passed through FixPredictions(mode='batch'). A mixture is
    accepted when CheckGaussianOverlap (on the first two features, with
    std_scale=0.5) and CheckWeights both pass; among accepted mixtures the one
    with the lowest BIC is selected.

    Returns
    -------
    opt_num_components : number of components of the selected mixture
    means : first two columns of the selected mixture's means_
    stds : first two columns of the selected mixture's covariances_
    weights : mixture weights of the selected mixture
    best_model : the fitted GaussianMixture object

    Raises
    ------
    ValueError
        If Data is empty, so no mixture can be fitted.
    '''
    Data = FixPredictions(Data, mode='batch')

    n_components = np.arange(1, 5)[:len(Data)]    # never more components than samples
    if len(n_components) == 0:
        raise ValueError('AnalyseEnsembleProperties needs at least one sample')

    models = [mixture.GaussianMixture(n, covariance_type='diag', tol=0.0001, max_iter=1000, n_init=3).fit(Data)
              for n in n_components]

    BICs = np.zeros(len(n_components))
    OverlapFree = np.zeros(len(n_components), dtype=bool)
    GoodWeights = np.zeros(len(n_components), dtype=bool)

    for idx, model in enumerate(models):
        BICs[idx] = model.bic(Data)
        OverlapFree[idx] = CheckGaussianOverlap(model.means_[:,:2], model.covariances_[:,:2], std_scale=0.5)
        GoodWeights[idx] = CheckWeights(model.weights_)

    Good_Models = np.logical_and(OverlapFree, GoodWeights)

    # Rejected models get an infinite BIC so that argmin returns an index into
    # the FULL model list. (Taking argmin of the filtered BIC array, as before,
    # gives a position among accepted models only and can select a rejected one.)
    # If every model is rejected, argmin falls back to index 0, the
    # single-component mixture.
    masked_BICs = np.where(Good_Models, BICs, np.inf)
    best_model_index = int(np.argmin(masked_BICs))
    best_model = models[best_model_index]

    opt_num_components = n_components[best_model_index]
    means = best_model.means_[:,:2]
    # NOTE(review): these are the entries of covariances_ returned unchanged;
    # CheckGaussianOverlap takes their square root to get standard deviations.
    # Confirm whether callers of `stds` expect variances or standard deviations.
    stds = best_model.covariances_[:,:2]
    weights = best_model.weights_

    return opt_num_components, means, stds, weights, best_model


def CheckGaussianOverlap(means, covariances, std_scale=1):
    '''
    Test whether any Gaussian's mean lies inside another Gaussian's
    mean +- std_scale * standard-deviation window.

    The comparison is element-wise: each feature is tested on its own, and
    entries whose value equals the tested mean's value are not counted as
    "other". Standard deviations are the square roots of `covariances`.

    Returns True if the GMM is good (no such overlap), False if it is bad.
    '''
    half_widths = std_scale * np.sqrt(covariances)
    lower_bounds = means - half_widths
    upper_bounds = means + half_widths

    for candidate in means:
        inside_window = (lower_bounds < candidate) & (upper_bounds > candidate)
        belongs_to_other = means != candidate

        # a hit is a window of a *different* entry that contains this mean
        if (inside_window & belongs_to_other).any():
            return False

    return True


def CheckWeights(weights, weight_cutoff=0.05):
    '''
    Check that every weight is strictly greater than weight_cutoff.

    Nothing is removed or modified; the function only reports whether all
    weights pass. `weights` may be any array-like (list, tuple, ndarray).

    Returns True when all weights exceed the cutoff (also for empty input),
    False as soon as one weight is at or below it.
    '''
    return np.all(np.asarray(weights) > weight_cutoff)

In [None]:
def PredictAndSplit_M(model, padded_trajs, padding_mask, min_peak_height=0.25):
    '''
    Run the model on all padded trajectories, strip the padding from each
    prediction and cut it into segments at the predicted changepoints.

    Returns (Label_Segments, All_CPs), both lists with one entry per trajectory:
    Label_Segments holds the list of per-segment label arrays (every label
    column except column 0), All_CPs holds the changepoints followed by the
    number of unpadded time steps.
    '''
    # model.predict returns several arrays; join them along axis 2
    stacked_predictions = np.concatenate(model.predict(padded_trajs), axis=2)

    Label_Segments = []
    All_CPs = []

    for idx, (_, labels) in enumerate(zip(padded_trajs, stacked_predictions)):
        real_steps = padding_mask[idx]
        labels = labels[real_steps]    # undo the padding

        # column 0 is handed to LabelToCP to obtain the changepoints
        changepoints = LabelToCP(labels[:,0], min_peak_height=min_peak_height)

        # the unpadded length is appended as a final changepoint
        unpadded_length = np.count_nonzero(real_steps)
        All_CPs.append(np.concatenate((changepoints, [unpadded_length])))

        # remaining columns are split at the changepoints
        Label_Segments.append(np.split(labels[:,1:], changepoints))

    return Label_Segments, All_CPs

In [None]:
def _load_fov_trajectories(csv_path, max_traj_len):
    '''
    Read one FOV csv file and return its trajectories padded to max_traj_len.

    Column 0 of the csv is used as the trajectory index, column 1 as the frame
    number and columns 2:4 as the two values stored per frame.
    NOTE(review): rows are assumed to be grouped by trajectory index in
    increasing order, with consecutive frames below max_traj_len — this is not
    checked here; confirm against the data files.

    Returns (trajs, padding_mask) with shapes (num_trajs, max_traj_len, 2) and
    (num_trajs, max_traj_len); the mask is True on real data, False on padding.
    '''
    FOV = pd.read_csv(csv_path).to_numpy()
    if len(FOV) == 0:
        return np.zeros((0, max_traj_len, 2)), np.full((0, max_traj_len), True)

    _, first_idx = np.unique(FOV[:,0], return_index=True)    # split into diff trajs
    split_trajs = np.split(FOV, first_idx[1:])

    # size the containers from the trajectories actually found, so that every
    # row has exactly one matching address in the caller
    num_trajs = len(split_trajs)
    trajs = np.zeros((num_trajs, max_traj_len, 2))
    padding_mask = np.full((num_trajs, max_traj_len), True)

    for traj_idx, traj in enumerate(split_trajs):
        first_frame, last_frame = int(traj[0,1]), int(traj[-1,1])

        trajs[traj_idx][first_frame:last_frame+1] = traj[:,2:4]    # drop in the traj
        trajs[traj_idx][:first_frame] = traj[0,2:4]    # pad with the first / last value
        trajs[traj_idx][last_frame+1:] = traj[-1,2:4]

        padding_mask[traj_idx][:first_frame] = False    # mark the padding
        padding_mask[traj_idx][last_frame+1:] = False

    return trajs, padding_mask


def PhaseOnePredictions(model, data_path, max_traj_len=200, min_peak_height=0.25):
    '''
    Load all data and make predictions using a U-Net.

    Reads <data_path>/track_2/exp_<e>/trajs_fov_<f>.csv for every experiment
    and FOV, pads the trajectories, passes them through DiffTrajs and then
    through PredictAndSplit_M.

    Parameters
    ----------
    model : object passed on to PredictAndSplit_M
    data_path : folder containing 'track_2' (with or without trailing slash)
    max_traj_len : length every trajectory is padded to
    min_peak_height : forwarded to PredictAndSplit_M

    Returns
    -------
    All_Traj_Addresses : array of [exp, fov] rows, one per trajectory
    Label_Segments, All_CPs : outputs of PredictAndSplit_M
    num_exps : number of 'exp_*' entries found
    num_fovs : number of FOV files in the LAST experiment processed

    Raises
    ------
    ValueError
        If no FOV file was found.
    '''
    ### Load data and prepare it for network ###
    All_Trajs = []    # stores all the trajs across all exps and fovs!
    All_Padding_Masks = []    # stores all the padding masks across all exps and fovs!
    All_Traj_Addresses = []    # for each traj, stores what exp and fov its from!

    # os.path.join gives the same path whether or not data_path ends with '/'
    track_dir = os.path.join(data_path, 'track_2')
    # only count experiment folders, mirroring how FOV files are counted below
    num_exps = len([entry for entry in os.listdir(track_dir) if entry.startswith('exp_')])

    for exp in range(num_exps):
        exp_dir = os.path.join(track_dir, f'exp_{exp}')
        num_fovs = len([fov for fov in os.listdir(exp_dir) if fov.startswith('trajs_fov')])
        for fov in range(num_fovs):
            trajs, padding_mask = _load_fov_trajectories(
                os.path.join(exp_dir, f'trajs_fov_{fov}.csv'), max_traj_len)

            All_Trajs.append(trajs)
            All_Padding_Masks.append(padding_mask)
            All_Traj_Addresses.extend([exp, fov] for _ in range(len(trajs)))

    if not All_Trajs:
        raise ValueError(f'No trajs_fov_*.csv files found under {track_dir}')

    All_Trajs = np.concatenate(All_Trajs, axis=0)
    All_Trajs = DiffTrajs(All_Trajs)
    All_Padding_Masks = np.concatenate(All_Padding_Masks, axis=0)
    All_Traj_Addresses = np.array(All_Traj_Addresses)
    ### /Load data and prepare it for network ###

    ### Make predictions ###
    Label_Segments, All_CPs = PredictAndSplit_M(model, All_Trajs, All_Padding_Masks, min_peak_height=min_peak_height)

    return All_Traj_Addresses, Label_Segments, All_CPs, num_exps, num_fovs

In [None]:
def GetModel(exp_model):
    '''
    Return the model name whose position matches the largest entry of exp_model.
    '''
    model_names = ('single_state', 'multi_state', 'dimerization', 'confinement', 'immobile_traps')
    return model_names[np.argmax(exp_model)]

In [None]:
def PhaseTwoPredictions(All_Traj_Addresses, Label_Segments, CPs, 
                        num_exps, num_fovs,
                        output_name,
                        predict_ensemble_properties=True):
    '''
    Take the outputs from phase one and use them to make phase two predictions.

    For every experiment, writes under <cwd>/<output_name>/track_2/exp_<e>/:
      * ensemble_labels.txt (only if predict_ensemble_properties), and
      * fov_<f>.txt with one line per trajectory:
        traj_idx, then K, alpha, M, changepoint for each segment.

    Parameters
    ----------
    All_Traj_Addresses : array of [exp, fov] rows, one per trajectory
    Label_Segments : per trajectory, the list of per-segment label arrays
    CPs : per trajectory, the changepoints (last entry closes the last segment)
    num_exps : number of experiments to loop over
    num_fovs : unused; kept so existing callers keep working
    output_name : output folder name, inserted between the cwd and 'track_2'
    predict_ensemble_properties : whether to run the ensemble analysis

    Returns
    -------
    The list of flattened per-experiment label arrays when
    predict_ensemble_properties is True, otherwise None.
    '''
    ALL_EXP_LABELS = []
    for exp in tqdm(range(num_exps)):
        ### Collect all segments for this experiment ###
        exp_mask = All_Traj_Addresses[:,0] == exp
        exp_Traj_Addresses = All_Traj_Addresses[exp_mask]
        exp_Labels = [seg_lab for e_mask, seg_lab in zip(exp_mask, Label_Segments) if e_mask]
        exp_CPs = [cp for e_mask, cp in zip(exp_mask, CPs) if e_mask]
        ### /Collect all segments for this experiment ###

        ### create the correct file structure ###
        # plain concatenation is kept on purpose: output_name may start with '/',
        # which would make os.path.join discard the cwd part
        results_dir_path = os.getcwd() + f'/{output_name}/track_2/exp_{exp}/'
        Path(results_dir_path).mkdir(parents=True, exist_ok=True)
        ### /create the correct file structure ###

        ### If needed, do ensemble level analysis ###
        if predict_ensemble_properties:
            flat_traj_labels = [np.concatenate(traj_segments, axis=0) for traj_segments in exp_Labels]
            flat_exp_labels = np.concatenate(flat_traj_labels, axis=0)
            ALL_EXP_LABELS.append(flat_exp_labels)

            # only consider the alpha K and diff type
            num_components, means, stds, weights, _ = AnalyseEnsembleProperties(flat_exp_labels[:,:3])

            # Average the model columns (3 onwards) over all time steps. The
            # previous row slice [3:] averaged everything into one scalar, so
            # GetModel always answered 'single_state'.
            model_scores = flat_exp_labels[:,3:]
            if model_scores.shape[1] > 0:
                exp_model = GetModel(np.mean(model_scores, axis=0))
            else:
                # no model columns available: keep the label produced before the fix
                exp_model = GetModel(np.zeros(1))

            ### write to file ###
            ensemble_lines = [
                f'model: {exp_model}; num_state: {num_components} ',
                "; ".join(means[:,0].astype('str')),    # all alpha means
                "; ".join(stds[:,0].astype('str')),    # all alpha spreads, as returned by AnalyseEnsembleProperties
                "; ".join(means[:,1].astype('str')),    # all K means
                "; ".join(stds[:,1].astype('str')),    # all K spreads, as returned by AnalyseEnsembleProperties
                "; ".join(weights.astype('str')),    # weights
            ]
            with open(results_dir_path + f'ensemble_labels.txt', 'w') as file:
                file.write('\n'.join(ensemble_lines))
            ### /write to file ###
        ### /If needed, do ensemble level analysis ###

        ### loop over all fovs in this experiment and write their info to different files ###
        for fov in np.unique(exp_Traj_Addresses[:,1]):
            ### collect all info for this FOV ###
            fov_mask = exp_Traj_Addresses[:,1] == fov
            fov_Labels = [seg_lab for f_mask, seg_lab in zip(fov_mask, exp_Labels) if f_mask]
            fov_CPs = [cp for f_mask, cp in zip(fov_mask, exp_CPs) if f_mask]
            ### /collect all info for this FOV ###

            ### write to file ###
            with open(results_dir_path + f'fov_{fov}.txt', 'w') as file:
                for traj_idx, (Traj_Labels, Traj_CPs) in enumerate(zip(fov_Labels, fov_CPs)):
                    fields = [str(traj_idx)]
                    for seg_label, cp in zip(Traj_Labels, Traj_CPs):
                        # convert time-step wise prediction to single values
                        seg_label = FixPredictions(np.mean(seg_label, axis=0))
                        fields += [str(seg_label[1]),    # Ks
                                   str(seg_label[0]),    # alphas
                                   str(seg_label[2]),    # Ms
                                   str(cp)]
                    file.write(','.join(fields) + '\n')
            ### /write to file ###
        ### /loop over all fovs in this experiment and write their info to different files ###

    if predict_ensemble_properties:
        return ALL_EXP_LABELS

In [None]:
def GetModels(exp_model, threshold=0.0):
    '''
    Return the names of the top-scoring models, best first.

    Scores are ranked in descending order and names are collected up to the
    first place where the drop to the next score is larger than `threshold`;
    if no drop is that large, every model name is returned.
    '''
    model_names = ('single_state', 'multi_state', 'dimerization', 'confinement', 'immobile_traps')

    ranking = np.argsort(exp_model)[::-1]    # indices from highest to lowest score
    ranked_scores = np.array([exp_model[position] for position in ranking])

    # drop between each score and the next one in the ranking
    drops = np.diff(-ranked_scores)
    # trailing True guarantees a cut is found even when no drop is large enough
    is_cut = np.append(drops > threshold, True)
    num_selected = int(np.argmax(is_cut)) + 1    # argmax -> first True

    return [model_names[position] for position in ranking[:num_selected]]

# Main

### Build the Network 

In [None]:
import numpy as np
import sys
import tensorflow as tf
from UNet3P_var_M2 import *
from UNet_Blocks import *
from utils import DiffTrajs

# Make the model
# max_traj_len is passed to the network as input_len and, further down, to
# PhaseOnePredictions as the length every trajectory is padded to.
max_traj_len = 224
filters = [16, 32, 64, 64, 128, 128]

# Keyword settings handed to UNet3P_var_M for its convolution blocks
ConvBlockParams = {'num_filters': 128,
                   'kernel_size': 3,
                   'strides': 1,
                   'padding': 'same'}

# Keyword settings handed to UNet3P_var_M for its skip blocks
SkipBlockParams = {'num_filters': 512,
                   'kernel_size': 3,
                   'strides': 1,
                   'padding': 'same'}

# Keyword settings handed to UNet3P_var_M for its decoder blocks
DecoderBlockParams = {'num_filters': 512,
                      'kernel_size' :3,
                      'strides': 1,
                      'padding': "same"}

model = UNet3P_var_M(filters, ConvBlockSimple, ConvBlockParams, SkipBlockParams, DecoderBlockParams, input_len=max_traj_len)

# Path of the weights file to LOAD (nothing is saved here).
# NOTE(review): [:-10] drops the last 10 characters of sys.path[0] — presumably
# the name of the folder this notebook runs from; this breaks silently if that
# folder is renamed. TODO confirm.
path = sys.path[0][:-10] + 'ChallengeNets/GeneralistNet/Model.weights.h5'
model.load_weights(path)
print('Network weights loaded')

### Apply Network to Local Challenge Data

In [None]:
# Folder that holds the 'track_2' data directory read by PhaseOnePredictions.
# NOTE(review): hard-coded absolute path of one user's machine — adjust locally.
data_path = '/home/cs-solomon.asghar/AnDi_2024/public_data_challenge_v0/'
# Output folder name; PhaseTwoPredictions places it under the current working directory.
output_name_template = "/GeneralistNet_Predictions/"

In [None]:
# Phase one: load every trajectory under data_path and run the network on it.
# min_peak_height is not passed, so the function's default is used.
All_Traj_Addresses, Label_Segments, All_CPs, num_exps, num_fovs = PhaseOnePredictions(model, data_path, max_traj_len=max_traj_len)

In [None]:
output_name_nosegnet = output_name_template

# Work on copies of the phase-one outputs. For the two lists returned by phase
# one (Label_Segments, All_CPs) .copy() is shallow: the per-trajectory arrays
# inside them are still shared with the originals.
All_Traj_Addresses_a = All_Traj_Addresses.copy()
Label_Segments_a = Label_Segments.copy()
All_CPs_a = All_CPs.copy()

# Phase two: write the per-experiment and per-FOV prediction files.
# predict_ensemble_properties is left at its default (True), so `labs` receives
# the list of flattened per-experiment label arrays.
labs = PhaseTwoPredictions(All_Traj_Addresses_a, Label_Segments_a, All_CPs_a, num_exps, num_fovs,
                           output_name=output_name_nosegnet)