In [1]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import matplotlib.patches as mpatches
from scipy.spatial import procrustes
import __init_paths

In [2]:
from config.st_dynamic_system import config, update_config
from models.st_gcn import STDynamicSystem
from dataset.muscle_sequence import get_muscle_sequences
from utils.visualize import VisualizeMuscle

In [3]:
# Work from the parent directory: the relative paths used later (data/,
# outputs/, experiments/) are resolved against it. The notebook's own directory
# is remembered on the first run so re-running this cell does not keep
# climbing further up the directory tree.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))
# Only takes effect if set before CUDA is initialised in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [11]:
def geo_mean(iterable):
    """Return the geometric mean of the values in ``iterable``.

    Computed as ``exp(mean(log(a)))`` rather than ``prod(a) ** (1 / n)`` so a
    long run of very small (or very large) values does not underflow (or
    overflow) the intermediate product.

    :param iterable: Non-negative numbers (list, array, generator, ...).
        Any zero makes the result 0; negative values give NaN.
    :return: The geometric mean as a numpy float.
    :raises ValueError: If ``iterable`` is empty.
    """
    values = np.asarray(list(iterable), dtype=float).ravel()
    if values.size == 0:
        raise ValueError("geo_mean() requires at least one value")
    # log(0) = -inf, and exp(-inf) = 0, so zeros behave as in the product form.
    with np.errstate(divide='ignore'):
        logs = np.log(values)
    return np.exp(logs.mean())

In [6]:
def prepare_input(file_path, in_range=None, reverse=False):
    """Load a muscle-sequence spreadsheet and split it into model inputs.

    The sheet is read with a three-level column header. Columns whose third
    level is ``'calcium'`` form the ``u`` sequence (min-max scaled per column);
    all other columns form the ``x`` sequence, reshaped to [5, T, V]. This
    assumes the five non-calcium features of each muscle sit in adjacent
    columns.

    :param file_path: Path of the Excel file to read.
    :param in_range: Optional ``(start, end)`` frame range, 1-based and
        inclusive (``start`` must be >= 1).
    :param reverse: If True, reverse the time axis and swap the roles of the
        two sequences, so the calcium sequence becomes the state and the
        muscle features become the input.
    :return: ``(x0, x_rest, u_rest)`` as float32 arrays: the first frame of
        the state sequence, the remaining state frames, and the input frames
        aligned with ``x_rest``.
    :raises ValueError: If ``in_range`` starts before frame 1.
    """
    df = pd.read_excel(file_path, header=[0, 1, 2], index_col=0)
    df = df.interpolate()

    if in_range is not None:
        start, end = in_range[0], in_range[1]
        if start <= 0:
            raise ValueError("Start time point must be greater than 0 (frames are 1-based)")
        # copy so the column assignments below do not act on a slice view
        df = df.iloc[start - 1:end].copy()

    # Shift negative angles up by 180 (NaNs stay NaN).
    angle_columns = [c for c in df.columns if c[2] == 'angle']
    for c in angle_columns:
        df[c] = df[c].where(df[c] >= 0, df[c] + 180)

    # x sequence: every non-calcium column, arranged as [5, T, V]
    x_columns = [c for c in df.columns if c[2] != 'calcium']
    x = df[x_columns].to_numpy().reshape(len(df), -1, 5).transpose((2, 0, 1))

    # u sequence: calcium traces min-max scaled per column, arranged as [1, T, V].
    # A constant trace has zero range; dividing by 1 instead maps it to 0, not NaN.
    calcium = df[[c for c in df.columns if c[2] == 'calcium']]
    col_min = calcium.min()
    col_range = calcium.max() - col_min
    col_range = col_range.mask(col_range == 0, 1.0)
    u = ((calcium - col_min) / col_range).to_numpy()[np.newaxis, :, :]

    if reverse:
        # reverse the time axis and exchange the roles of x and u
        data_x, data_u = np.flip(u, 1), np.flip(x, 1)
    else:
        data_x, data_u = x, u

    return (data_x[:, 0].astype(np.float32),
            data_x[:, 1:].astype(np.float32),
            data_u[:, 1:].astype(np.float32))

In [36]:
def inference(model_path, config_file, initial_x, subsequent_u):
    """Make inference based on the provided model, config file and data.

    Builds an ``STDynamicSystem`` from the config, loads the checkpoint,
    normalizes the inputs with the dataset statistics stored in the config,
    runs ``model.predict`` and maps the prediction back to the original scale.

    :param model_path: The trained model parameters (checkpoint) file path.
    :param config_file: The config file of the current task; it is applied to
        the shared ``config`` object via ``update_config``.
    :param initial_x: Initial state, channel-first [C, V]; normalized with
        ``config.DATASET.X_MEAN`` / ``X_STD``.
    :param subsequent_u: Input sequence, channel-first [C_u, T, V]; normalized
        with ``config.DATASET.U_MEAN`` / ``U_STD``.
    :return: numpy array of predictions (channel-first) in the original x scale.
    """
    update_config(config, {'cfg': config_file})
    _, adjacency, _ = get_muscle_sequences(config, is_train=False)

    # load the model
    model = STDynamicSystem(config, adjacency)
    model.cuda()
    model.load_checkpoint(model_path)
    # Inference only: switch layers such as dropout / batch-norm (if the model
    # has any) to their evaluation behaviour.
    model.eval()

    # normalize the input with the per-channel dataset statistics
    x_std = np.asarray(config.DATASET.X_STD)
    x_mean = np.asarray(config.DATASET.X_MEAN)
    u_std = np.asarray(config.DATASET.U_STD)
    u_mean = np.asarray(config.DATASET.U_MEAN)
    initial_x = (initial_x - x_mean[:, np.newaxis]) / x_std[:, np.newaxis]
    subsequent_u = (subsequent_u - u_mean[:, np.newaxis, np.newaxis]) / u_std[:, np.newaxis, np.newaxis]

    # NOTE(review): tensors stay on the CPU here; presumably model.predict moves
    # them to the model's device -- confirm.
    initial_x = torch.from_numpy(initial_x.astype(np.float32))
    subsequent_u = torch.from_numpy(subsequent_u.astype(np.float32))

    # no autograd graph is needed for prediction
    with torch.no_grad():
        pred = model.predict(initial_x, subsequent_u)
    pred = pred.detach().cpu().numpy()
    return pred * x_std[:, np.newaxis, np.newaxis] + x_mean[:, np.newaxis, np.newaxis]

In [None]:
def inference_calcium(model_path, config_file, initial_x, subsequent_u):
    """Make calcium inference based on the provided model, config file and data.

    The procedure is identical to ``inference`` (build the model from the
    config, load the checkpoint, normalize, predict, de-normalize), so this
    delegates to it instead of duplicating the code.

    :param model_path: The trained model parameters (checkpoint) file path.
    :param config_file: The config file of the current task.
    :param initial_x: Initial state for inference.
    :param subsequent_u: The subsequent input sequence for translation.
    :return: The de-normalized prediction returned by ``inference``.
    """
    return inference(model_path, config_file, initial_x, subsequent_u)

In [8]:
def plot_results(pred_all, real_all, save_file, step=1, dpi=600):
    """Plot the result sequences of prediction and ground truth.

    The top row shows the prediction and the bottom row the ground truth; one
    column is drawn for every ``step``-th time step.

    :param pred_all: The predicted sequence [C, T, V].
    :param real_all: The ground-truth sequence [C, T, V].
    :param save_file: Output image path, or None to skip saving.
    :param step: Plot every ``step``-th time step (must be >= 1).
    :param dpi: Resolution used when saving the figure.
    """
    time_indices = range(0, pred_all.shape[1], step)
    n_cols = len(time_indices)
    # Allocate only the columns actually drawn; squeeze=False keeps `axes`
    # 2-D even when a single column is plotted.
    fig, axes = plt.subplots(nrows=2, ncols=n_cols, figsize=(n_cols, 2),
                             sharex=True, squeeze=False)
    axes[0, 0].set_ylabel('Prediction')
    axes[1, 0].set_ylabel('Real')
    for col, t in enumerate(time_indices):
        for row, sequence in enumerate((pred_all, real_all)):
            ax = axes[row, col]
            ax.set_xticks([])
            ax.set_yticks([])
            VisualizeMuscle(sequence[:, t].T).show_motion(ax)
    fig.tight_layout()
    if save_file is not None:
        fig.savefig(save_file, dpi=dpi)

In [9]:
def to_cartesian_coords(data):
    """Transform the [length, width, angle, cx, cy] representation to corner coordinates.

    Each row describes one rectangle. Column 0 is its extent along the
    un-rotated x axis and column 1 along the un-rotated y axis; column 2 is a
    rotation in degrees (counter-clockwise, the convention of matplotlib's
    ``Affine2D.rotate_deg_around``) about the centre ``(cx, cy)`` in columns 3-4.

    Note: reading corners back from a matplotlib ``Rectangle`` does not apply
    its ``transform``, so the rotation is applied explicitly here.

    :param data: [V, C] array-like with C >= 5.
    :return: [V, 4, 2] array of rotated corners per rectangle, ordered
        (-x, -y), (+x, -y), (+x, +y), (-x, +y) relative to the centre before
        rotation.
    """
    polygons = np.asarray(data, dtype=float)
    if polygons.size == 0:
        return np.empty((0, 4, 2))
    half_x = polygons[:, 0] / 2
    half_y = polygons[:, 1] / 2
    theta = np.deg2rad(polygons[:, 2])
    # corner offsets from the centre before rotation, each [V, 4]
    dx = np.stack([-half_x, half_x, half_x, -half_x], axis=1)
    dy = np.stack([-half_y, -half_y, half_y, half_y], axis=1)
    cos_t = np.cos(theta)[:, np.newaxis]
    sin_t = np.sin(theta)[:, np.newaxis]
    x = polygons[:, 3:4] + cos_t * dx - sin_t * dy
    y = polygons[:, 4:5] + sin_t * dx + cos_t * dy
    return np.stack([x, y], axis=-1)

def mae(p1, p2):
    """Return the Procrustes disparity between two poses.

    Despite the name this is not a mean absolute error: both poses are turned
    into rectangle corner coordinates by ``to_cartesian_coords`` and compared
    with ``scipy.spatial.procrustes``, which standardizes and optimally aligns
    the two point sets before summing the squared point-wise differences.

    :param p1: Pose [V, C], C = (length, width, angle, cx, cy).
    :param p2: Pose [V, C], same shape as ``p1``.
    :return: The disparity (a non-negative float).
    """
    p1_corners = to_cartesian_coords(p1).reshape(-1, 2)
    p2_corners = to_cartesian_coords(p2).reshape(-1, 2)
    _, _, disparity = procrustes(p1_corners, p2_corners)
    return disparity

In [10]:
def plot_mae(pred_all, real_all, save_file):
    """Plot the per-time-step pose disparity between prediction and ground truth.

    :param pred_all: Predicted sequence [C, T, V].
    :param real_all: Ground-truth sequence [C, T, V].
    :param save_file: Output image path, or None to skip saving.
    :return: List with one disparity (as computed by ``mae``) per time step.
    """
    # [C, T, V] -> [T, V, C]: iterate over time steps, each pose being [V, C]
    sequential_errs = [mae(p1, p2) for p1, p2 in zip(pred_all.transpose(1, 2, 0),
                                                    real_all.transpose(1, 2, 0))]
    fig, ax = plt.subplots(1, 1, figsize=(10, 2))
    ax.plot(sequential_errs)
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Pose Disparity')
    if save_file is not None:
        fig.savefig(save_file, bbox_inches="tight")
    return sequential_errs

In [12]:
def calculate_mae(pred_all, real_all, time_steps):
    """Geometric-mean pose disparity over consecutive windows of ``time_steps`` frames.

    The per-frame disparity comes from ``mae``; the sequence is split into
    consecutive, non-overlapping windows of ``time_steps`` frames and each
    window is summarised by ``geo_mean``.

    :param pred_all: Predicted sequence [C, T, V].
    :param real_all: Ground-truth sequence [C, T, V].
    :param time_steps: Window length; T must be a multiple of it.
    :return: Array with one value per window (T // time_steps values).
    :raises ValueError: If ``time_steps`` is not positive or does not divide T.
    """
    total_steps = pred_all.shape[1]
    if time_steps <= 0 or total_steps % time_steps != 0:
        raise ValueError(
            "sequence length {} is not a multiple of time_steps={}".format(total_steps, time_steps))
    # [C, T, V] -> [T, V, C]: one [V, C] pose per frame
    sequential_errs = np.asarray([mae(p1, p2) for p1, p2 in zip(pred_all.transpose(1, 2, 0),
                                                               real_all.transpose(1, 2, 0))])
    grouped_errs = sequential_errs.reshape(-1, time_steps)
    return np.asarray([geo_mean(window) for window in grouped_errs])

In [17]:
# load data: dorsal-9 recording, frames 66-130 (1-based, inclusive), with
# reverse=True so time runs backwards and the calcium trace becomes the state
x0, x, u = prepare_input(
    file_path='data/larva/muscle_sequenece/dorsal_single/dorsal-9/dorsal-9.xlsx',
    in_range=(66, 130),
    reverse=True,
)
print(x0.shape, x.shape, u.shape)

(1, 38) (1, 64, 38) (5, 64, 38)


In [33]:
# perform inference with the Muscle2Calcium "dorsal_64steps" checkpoint and config
pred = inference(
    model_path='outputs/STDynamicSystem/Muscle2Calcium/dorsal_64steps/checkpoint_latest.pth',
    config_file='experiments/st_dynamic_system/muscle2calcium/dorsal_64steps.yaml',
    initial_x=x0,
    subsequent_u=u,
)

In [34]:
def save_calcium_seq(pred_seq, real_seq, save_file=None):
    """Tabulate ground-truth and predicted sequences side by side.

    Columns use a two-level index (``type``, ``Muscle ID``): all ground-truth
    channels first, then all predicted channels; rows follow the first axis of
    the inputs.

    :param pred_seq: Predicted sequence [T, V].
    :param real_seq: Ground-truth sequence [T, V], same shape as ``pred_seq``.
    :param save_file: Excel path to write, or None to skip writing.
    :return: The assembled DataFrame (returned whether or not it is saved).
    :raises ValueError: If the two sequences have different shapes.
    """
    pred_seq = np.asarray(pred_seq)
    real_seq = np.asarray(real_seq)
    if pred_seq.shape != real_seq.shape:
        raise ValueError("pred_seq shape {} does not match real_seq shape {}".format(
            pred_seq.shape, real_seq.shape))
    num_channels = pred_seq.shape[1]
    multi_index = pd.MultiIndex.from_product([['ground truth', 'prediction'],
                                              [str(i) for i in range(num_channels)]],
                                             names=['type', 'Muscle ID'])
    df = pd.DataFrame(data=np.concatenate((real_seq, pred_seq), axis=1),
                      columns=multi_index)
    if save_file is not None:
        df.to_excel(save_file)
    return df

In [35]:
# Write ground truth (x[0]) next to the prediction (pred[0]) to an Excel file;
# the trailing ';' suppresses display of any returned value in the notebook.
save_calcium_seq(pred[0], x[0], save_file='dorsal_turn_calcium.xlsx');