In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from einops import rearrange
from scipy.signal import find_peaks
from entmax import entmax15

from models import LSTS

# Global matplotlib style for every figure below: 20 pt text and 3 pt axes edge lines.
plt.rcParams.update({'font.size': 20, 'axes.linewidth': 3})

In [None]:
def norm(bvp):
    """Min-max scale an array so its minimum maps to -1 and its maximum to 1.

    Parameters
    ----------
    bvp : numpy.ndarray or torch.Tensor
        Values to rescale. The global min/max over all elements is used.

    Returns
    -------
    Same array type as ``bvp``, linearly rescaled into [-1, 1].
    A constant input (max == min) is returned as all zeros instead of the
    all-NaN result a plain 0/0 division would produce.
    """
    lo, hi = bvp.min(), bvp.max()
    span = hi - lo
    if span == 0:
        # Flat signal: no range to stretch; centre it in the output range.
        # Multiplying by 0.0 keeps the array type/shape and promotes ints to float.
        return 0.0 * bvp
    return -1 + 2 * (bvp - lo) / span

def find_ref(bvp, prominence_scale=1.5):
    """Return the index of the last prominent crest of a 1-D signal.

    The signal is z-scored first, so the prominence threshold is expressed in
    standard deviations: a peak qualifies when its prominence is at least
    ``prominence_scale`` times the std of the z-scored signal (i.e. ~1.5 std
    by default). The last qualifying peak is the reference point.

    Parameters
    ----------
    bvp : array-like
        1-D signal, e.g. the ground-truth wave as a numpy array.
    prominence_scale : float, default 1.5
        Minimum peak prominence in units of standard deviation.

    Returns
    -------
    int
        Index of the last peak meeting the prominence threshold.

    Raises
    ------
    ValueError
        If the signal is constant/non-finite or has no sufficiently
        prominent peak (previously a bare IndexError).
    """
    bvp = np.asarray(bvp, dtype=float)
    std = np.std(bvp)
    if not np.isfinite(std) or std == 0:
        raise ValueError("find_ref: signal is constant or non-finite; no reference peak")
    z = (bvp - np.mean(bvp)) / std
    # Same threshold expression as before: scale * std of the z-scored signal.
    prominence = (prominence_scale * np.std(z), None)
    peaks, _ = find_peaks(z, prominence=prominence)
    if peaks.size == 0:
        raise ValueError(
            f"find_ref: no peak with prominence >= {prominence_scale} std found")
    return int(peaks[-1])

@torch.no_grad()
def similarity(model, fn, data):
    """Cosine similarity of every time step's representation to a reference step.

    The reference index is the last prominent crest of the ground-truth wave
    ``data['waves']`` as located by ``find_ref``. Runs without autograd.

    Parameters
    ----------
    model : object
        Model exposing ``predict(x)``; also forwarded to ``fn``.
    fn : callable
        ``fn(model, x)`` returning representations whose first axis is indexed
        by the reference index and whose last axis is the feature axis
        (e.g. ``lsts_fn``).
    data : dict
        Needs ``'frames'`` (divided by 255 here, so presumably raw pixel
        values) and ``'waves'`` (a tensor supporting ``.numpy()``).

    Returns
    -------
    sim : torch.Tensor
        Cosine similarity of each row of the representation to the reference row.
    pred : torch.Tensor
        ``model.predict`` output flattened to 1-D.

    NOTE(review): the reference index is found on the wave but used on the
    representation's first axis, so both are assumed to share the same length.
    """
    clip = data['frames'].unsqueeze(0) / 255.
    anchor_idx = find_ref(data['waves'].numpy())

    reps = fn(model, clip)
    anchor = reps[anchor_idx:anchor_idx + 1]
    sim = F.cosine_similarity(anchor, reps, dim=-1)
    pred = model.predict(clip).reshape(-1)
    return sim, pred

def visualize(sim, y, pred, fps=30):
    """Plot similarity, ground truth and prediction on a shared time axis (seconds).

    Parameters
    ----------
    sim : array-like, shape (T,)
        Similarity curve (e.g. from ``similarity``); plotted unscaled.
    y : numpy.ndarray, shape (T,)
        Ground-truth wave; min-max scaled to [-1, 1] via ``norm``.
    pred : array-like, shape (T,)
        Predicted wave; min-max scaled to [-1, 1] via ``norm``.
    fps : float, default 30
        Sampling rate used to convert sample indices to seconds. The previous
        hard-coded axis (180 samples over 6 s) equals the default for 180 samples,
        but other clip lengths no longer crash.
    """
    series = np.stack([sim, norm(y), norm(pred)], axis=-1)
    n_samples = series.shape[0]
    t = np.arange(n_samples) / fps
    duration = n_samples / fps

    df = pd.DataFrame(series, index=t, columns=['Similarity', 'Ground Truth', 'Prediction'])
    fig, ax = plt.subplots(figsize=(6, 4), dpi=200)
    sns.lineplot(data=df, palette="tab10", linewidth=4, legend='brief', ax=ax)
    ax.set_xlabel('t(s)')
    ax.set_ylim(-1.5, 1.5)
    # Extra 2 s of empty axis on the right, presumably to keep the
    # lower-right legend clear of the curves (matches the old xlim of 8 for 6 s).
    ax.set_xlim(-0.5, duration + 2)
    ax.set_xticks(np.arange(0, np.floor(duration) + 1, 2))
    sns.move_legend(ax, "lower right")
    fig.tight_layout()
    plt.show()

def lsts_fn(lsts, x):
    """Return the LSTS hidden representation just before the final layer.

    Calls the model's submodules directly: preprocess -> patch_embed ->
    pos_embed + pos_drop -> layers -> norm, then attention-pools over the
    spatial tokens with ``out_pooling`` + entmax15. This presumably mirrors
    ``LSTS.forward`` up to its output head (TODO confirm against models.LSTS).

    Parameters
    ----------
    lsts : LSTS
        Model whose submodules are invoked.
    x : torch.Tensor
        Input batch for ``lsts.preprocess``; ``similarity`` passes a batch of one.

    Returns
    -------
    torch.Tensor
        Pooled features of shape (n, d, c) with dim 0 squeezed, i.e. (d, c)
        when the batch size is 1 (``squeeze(0)`` is a no-op otherwise).
    """
    feats = lsts.patch_embed(lsts.preprocess(x))
    feats = lsts.pos_drop(lsts.pos_embed(feats))
    feats = lsts.layers(feats)
    # Channels-last so the norm layer acts on the channel axis.
    feats = lsts.norm(rearrange(feats, 'n c d h w -> n d h w c'))
    # Flatten the spatial grid into one token axis: (n, d, h*w, c).
    tokens = rearrange(feats, 'n d h w c -> n d (h w) c')

    # Sparse (entmax15) attention weights over the spatial tokens (dim 2);
    # presumably out_pooling yields a size-1 last dim that broadcasts over c.
    weights = entmax15(lsts.out_pooling(tokens), dim=2)
    pooled = (tokens * weights).sum(dim=2)
    return pooled.squeeze(0)

def show_data(data, fps=30, seconds=6):
    """Show one input frame per second of the clip, side by side in one row.

    Parameters
    ----------
    data : dict
        Must contain ``'frames'``, a tensor displayed after ``permute(0, 2, 3, 1)``
        (presumably (T, C, H, W) -> (T, H, W, C); TODO confirm layout).
    fps : int, default 30
        Frame stride between snapshots; snapshot ``i`` is frame ``i * fps``,
        titled ``t={i}s``.
    seconds : int, default 6
        Maximum number of snapshots. Snapshots past the end of the clip are
        skipped instead of raising IndexError.
    """
    frames = data['frames'].permute(0, 2, 3, 1).numpy()
    # Only indices that exist in the clip; slicing a range keeps it lazy.
    indices = range(0, len(frames), fps)[:seconds]
    fig = plt.figure(figsize=(6, 4), dpi=800)
    for i, frame_idx in enumerate(indices):
        ax = fig.add_subplot(1, len(indices), i + 1)
        ax.imshow(frames[frame_idx])
        ax.set_axis_off()
        ax.set_title(f't={i}s')
    plt.tight_layout()
    plt.show()

In [None]:
# Load the pretrained LSTS model and one sample, then compute the per-frame
# similarity to the reference frame and the model's prediction.
# map_location='cpu': the model is built on CPU and inputs are never moved to a
# device, and data['waves'].numpy() requires a CPU tensor, so GPU-saved files
# must be mapped to CPU.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files,
# or pass weights_only=True (torch >= 1.13) where the file contents allow it.
lsts = LSTS()
lsts.load_state_dict(torch.load('weights.pt', map_location='cpu'))
lsts.eval()
data = torch.load('data.pt', map_location='cpu')
sim, pred = similarity(lsts, lsts_fn, data)
y = data['waves'].numpy()

In [None]:
# Preview the input clip: frames 0, 30, ..., 150, titled t=0s ... t=5s.
show_data(data)

In [None]:
# Overlay the reference-frame similarity with the min-max-scaled ground truth and prediction.
visualize(sim, y, pred)