In [1]:
# Enable IPython's autoreload extension in mode 2 so edited project modules
# (e.g. the local model files imported below) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [18]:
# Launch hint kept by the author (raises JupyterLab's IOPub data-rate limit):
# jupyter lab --NotebookApp.iopub_data_rate_limit=1.0e20
import warnings
# NOTE(review): this silences every warning category for the whole session,
# including deprecation notices from the libraries used below — consider a narrower filter.
warnings.filterwarnings('ignore')

## Load data

In [564]:
from pathlib import Path

import matplotlib.pylab as plt
import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchinfo import summary
from pytorch_lightning.loggers import TensorBoardLogger

from lstm_model_behavior import RecurrentAutoencoder
from lin_ae_model_behavior import LinearAutoencoder

# Default resolution for every figure created in this notebook.
plt.rcParams['figure.dpi'] = 100

In [4]:
# Load the reach array and the k-means label vector (later used as a mask over
# the first axis of dataset_full).
# NOTE(review): PATH_DATA is not defined in the visible cells — confirm it is set in a config cell.
dataset_full = np.load(PATH_DATA / "reaches.npy")
k_mean_labels = np.load(PATH_DATA / "k_mean_labels.npy")
# Show the array shape as the cell output.
dataset_full.shape

(5984, 2, 75)

## Prepare data

In [475]:
# `dataset_full` is (n_reaches, 2, 75). The plotting cells read x/y from axis 1 of
# X_train, so axis 1 holds the two coordinates and axis 2 the 75 time samples.
# Build the time-major (n_reaches, 75, 2) copy by swapping the last two axes.
# `reshape` would only reinterpret the flat buffer, interleaving x and y values
# instead of pairing (x_t, y_t) on each row.
dataset = np.ascontiguousarray(np.transpose(dataset_full, (0, 2, 1)))

# Keep only the reaches whose k-means label is 2.
# Alternative masks tried earlier: labels in {2, 8, 6, 0}, or k_mean_labels > 0.
k_mean_mask = k_mean_labels == 2

# The splits are taken from `dataset_full`, i.e. they keep the (n, 2, 75) layout.
# First hold out 4% for testing, then split the remainder 67/33 into train/validation.
X_train, X_test = train_test_split(dataset_full[k_mean_mask], test_size=0.04, random_state=SEED)
X_train, X_val = train_test_split(X_train, test_size=0.33, random_state=SEED)

X_train = torch.tensor(X_train, device=DEVICE, dtype=DTYPE)
X_test = torch.tensor(X_test, device=DEVICE, dtype=DTYPE)
X_val = torch.tensor(X_val, device=DEVICE, dtype=DTYPE)

X_train.shape

torch.Size([678, 2, 75])

## Load model

In [503]:
# Build a RecurrentAutoencoder with arguments (75, 2, 32, 10) and print a torchinfo
# layer summary for a single (1, 75, 2) input. `rae_test` is not reused in the visible cells.
rae_test = RecurrentAutoencoder(75, 2, 32, 10)
summary(rae_test, (1, 75, 2),
        col_names=["input_size", "output_size", "num_params",
                  ]) # "kernel_size", "mult_adds", "trainable"]

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
RecurrentAutoencoder                     [1, 75, 2]                [1, 75, 2]                --
├─Encoder: 1-1                           [1, 75, 2]                [1, 32]                   --
│    └─LSTM: 2-1                         [1, 75, 2]                [1, 75, 64]               316,928
│    └─LSTM: 2-2                         [1, 75, 64]               [1, 75, 32]               12,544
├─Decoder: 1-2                           [1, 32]                   [1, 75, 2]                --
│    └─LSTM: 2-3                         [1, 75, 32]               [1, 75, 32]               8,448
│    └─LSTM: 2-4                         [1, 75, 32]               [1, 75, 64]               324,608
│    └─Linear: 2-5                       [1, 75, 64]               [1, 75, 2]                130
Total params: 662,658
Trainable params: 662,658
Non-trainable params: 0
Total mult-adds (M): 49.69
Input size (MB

## Train model

In [504]:
#@title Init model
pl.seed_everything(42)
# NOTE(review): X_train is (678, 2, 75) per the output above, so this unpacking binds
# n_times = 2 and n_features = 75 — the reverse of the (75, 2) used for rae_test.
# Confirm which axis is meant to be time before relying on these names.
_, n_times, n_features = X_train.shape
# K and n_layers are passed as the third and fourth constructor arguments below.
K = 32
n_layers = 10

# Initialize the model (the Trainer itself is created in the next cell)
rae = RecurrentAutoencoder(n_times, n_features, K, n_layers)
rae.lr = 1e-2

# Print the full-depth parameter table for the model.
pl.utilities.model_summary.summarize(rae, max_depth=-1)

Global seed set to 42


  | Name                 | Type    | Params
-------------------------------------------------
0 | encoder              | Encoder | 348 K 
1 | encoder.rnn1         | LSTM    | 335 K 
2 | encoder.rnn2         | LSTM    | 12.5 K
3 | decoder              | Decoder | 337 K 
4 | decoder.rnn1         | LSTM    | 8.4 K 
5 | decoder.rnn2         | LSTM    | 324 K 
6 | decoder.output_layer | Linear  | 4.9 K 
-------------------------------------------------
686 K     Trainable params
0         Non-trainable params
686 K     Total params
2.744     Total estimated model params size (MB)

In [505]:
# Training configuration for the recurrent autoencoder.
n_epochs = 100
batch_size = 1

# TensorBoard run name encodes the hyper-parameters of this run.
run_name = f"K_{K}_lr_{rae.lr}_bs_{batch_size}_nl_{n_layers}_grad_acc_T_LR_sch_T_class_2"
logger = TensorBoardLogger("tb_logs", name=run_name)

trainer = pl.Trainer(
    max_epochs=n_epochs,
    accelerator='mps',
    logger=logger,
    # Accumulation schedule passed as {epoch: n_batches}.
    accumulate_grad_batches={0: 8, 4: 4, 8: 1},
)

# Data loaders: only the training loader shuffles; the test loader uses the default batch size.
train_loader = DataLoader(X_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(X_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(X_test, shuffle=False)  # mock_data was used here during debugging

# Fit, then evaluate on the held-out split (the result is shown as the cell output).
trainer.fit(rae, train_loader, val_loader)
trainer.test(rae, test_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: tb_logs/K_32_lr_0.01_bs_1_nl_10_grad_acc_T_LR_sch_T_class_2

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 348 K 
1 | decoder | Decoder | 337 K 
------------------------------------
686 K     Trainable params
0         Non-trainable params
686 K     Total params
2.744     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           1.7157615423202515
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.7157615423202515}]

In [None]:
n_plot = 7
plt.figure(figsize=(20, 10))
# Sample from the (n, 2, 75) layout the autoencoder was trained on (X_train is split
# from dataset_full), so the model sees inputs laid out exactly as during training.
dataset_all = torch.tensor(dataset_full[k_mean_mask], device='cpu')

def plot_reaches_(x, y, duration=0):
    """Draw one reach as a faint line plus points coloured by sample index.

    x, y: 1-D coordinate sequences of equal length.
    duration: optional sample index to highlight in red; 0 disables the marker.
    """
    plt.plot(x, y, '-', alpha = 0.5)
    plt.scatter(x, y, c=np.arange(len(x)))
    if duration != 0:
        plt.scatter(x[duration], y[duration], c = "r")

for i in range(n_plot):
    idx = torch.randint(len(dataset_all), size=())
    reach = dataset_all[idx]  # (2, 75): row 0 = x, row 1 = y
    # Time-major (75, 2) view of the same reach; the loss-exploration cells further
    # down compare `data_` against `rae_recon.T`.
    data_ = reach.T
    with torch.no_grad():
        # Get reconstructed movements from autoencoder
        rae_recon = rae(reach.unsqueeze(0).float())[0]

    # Top row: original reach.
    plt.subplot(2, n_plot, i + 1)
    plot_reaches_(reach[0, :], reach[1, :])
    if i == 0:
        plt.ylabel('Original\nMovements')

    # Bottom row: reconstruction of the same reach.
    plt.subplot(2, n_plot, i + 1 + n_plot)
    plot_reaches_(rae_recon[0, :], rae_recon[1, :])
    if i == 0:
        plt.ylabel(f'LSTM AE\n(K={K})')

plt.show()

## Explore embedding space

Reference (autoencoder embeddings explored with t-SNE): https://www.kaggle.com/code/rohitgr/autoencoders-tsne/notebook

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import pandas as pd

In [None]:
# Load additional per-row parameters from a local CSV (nothing is downloaded here).
# Presumably one row per reach, since it is later indexed with k_mean_mask — verify.
# NOTE(review): PATH_ROOT is not defined in the visible cells.
PATH_ROI = PATH_ROOT / "data" / "Naturalistic reach ECoG tfrs ROI"
data = pd.read_csv(PATH_ROI / "power-roi-all-patients-metadata.csv", index_col=0)
data.head()

In [None]:
# Encode every selected reach. The input uses the (n, 2, 75) layout of `dataset_full`,
# which is the layout X_train had during training; row order matches k_mean_mask.
with torch.no_grad():
    z_x_train = rae.encoder(torch.tensor(dataset_full[k_mean_mask], device='cpu', dtype=DTYPE))
    z_x_train = z_x_train.detach().numpy()

In [None]:
# Number of selected reaches and the per-reach array shape.
dataset[k_mean_mask].shape

In [None]:
# Two seeded 2-D projections of the latent codes: t-SNE first, then PCA.
tsne_2d = TSNE(n_components=2, random_state=SEED)
X_TSNE = tsne_2d.fit_transform(z_x_train)

pca_2d = PCA(n_components=2, random_state=SEED)
X_PCA = pca_2d.fit_transform(z_x_train)

In [None]:
n_clusters = 3
# Seed k-means like the TSNE/PCA calls so cluster ids (and the colours and galleries
# derived from them) stay the same across reruns.
kmeans = KMeans(n_clusters=n_clusters, random_state=SEED)
Y_TSNE = kmeans.fit_predict(X_TSNE)
plt.scatter(X_TSNE[:, 0], X_TSNE[:, 1], c = Y_TSNE)
plt.colorbar()

In [None]:
# One t-SNE scatter per metadata column, coloured by that column's values
# for the rows selected by k_mean_mask.
for par in ('reach_a', 'onset_velocity', 'reach_r', 'reach_duration'):
    colour_values = data.loc[k_mean_mask, par]
    plt.figure(figsize=(20, 10))
    plt.scatter(X_TSNE[:, 0], X_TSNE[:, 1], c=colour_values)
    plt.title(par)
    plt.colorbar()
    plt.show()

In [None]:
n_plot = 7
n_clust = n_clusters
plt.figure(figsize=(15, 10))

def plot_reaches_(x, y, duration=0):
    """Draw one reach (line + points coloured by sample index) on fixed ±0.05 axes.

    duration: optional sample index to highlight in red; 0 disables the marker.
    """
    plt.plot(x, y, '-', alpha = 0.5)
    plt.scatter(x, y, c=np.arange(len(x)))
    if duration != 0:
        plt.scatter(x[duration], y[duration], c = "r")
    plt.xlim(-0.05, 0.05)
    plt.ylim(-0.05, 0.05)

# (n, 2, 75) reaches in the same row order as z_x_train / Y_TSNE. Reading x and y as
# rows 0 and 1 of `dataset_full` gives each coordinate's full time series.
cluster_reaches = dataset_full[k_mean_mask]

# One row of random example reaches per cluster.
for k in range(n_clust):
    k_data = cluster_reaches[Y_TSNE == k]
    for i in range(n_plot):
        idx = int(torch.randint(len(k_data), size=()))
        k_data_ex = k_data[idx]
        plt.subplot(n_clust, n_plot, k * n_plot + i + 1)
        plot_reaches_(k_data_ex[0, :], k_data_ex[1, :])
        if i == 0:
            plt.ylabel(f'Class {k}')
        plt.gca().set_aspect('equal')
        plt.xticks([])
        plt.yticks([])
plt.tight_layout()

plt.show()

In [None]:
n_plot = 3
n_clust = n_clusters
plt.figure(figsize=(20, 6))

def plot_reaches_(x, y, duration=0):
    """Draw one reach (line + points coloured by sample index) on fixed ±0.1 axes.

    duration: optional sample index to highlight in red; 0 disables the marker.
    """
    plt.plot(x, y, '-', alpha = 0.5)
    plt.scatter(x, y, c=np.arange(len(x)))
    if duration != 0:
        plt.scatter(x[duration], y[duration], c = "r")
    plt.xlim(-0.1, 0.1)
    plt.ylim(-0.1, 0.1)

# (n, 2, 75) reaches in the same row order as Y_TSNE; rows 0 and 1 of each reach
# are its x and y time series.
cluster_reaches = dataset_full[k_mean_mask]

# One panel per cluster, overlaying n_plot random reaches from that cluster.
for k in range(n_clust):
    k_data = cluster_reaches[Y_TSNE == k]
    plt.subplot(1, n_clust, k + 1)
    for i in range(n_plot):
        idx = int(torch.randint(len(k_data), size=()))
        k_data_ex = k_data[idx]

        plot_reaches_(k_data_ex[0, :], k_data_ex[1, :])
        if i == 0:
            plt.title(f'Class {k}')
    plt.gca().set_aspect('equal')
    plt.xticks([])
    plt.yticks([])
plt.show()

In [None]:
n_plot = 3
n_clust = n_clusters
plt.figure(figsize=(20, 6))

def plot_reaches_(x, y, duration=0):
    """Draw one reach (line + points coloured by sample index) with automatic axis limits.

    duration: optional sample index to highlight in red; 0 disables the marker.
    """
    plt.plot(x, y, '-', alpha = 0.5)
    plt.scatter(x, y, c=np.arange(len(x)))
    if duration != 0:
        plt.scatter(x[duration], y[duration], c = "r")

# (n, 2, 75) reaches in the same row order as Y_TSNE; rows 0 and 1 of each reach
# are its x and y time series.
cluster_reaches = dataset_full[k_mean_mask]

# One panel per cluster, overlaying n_plot random reaches from that cluster.
for k in range(n_clust):
    k_data = cluster_reaches[Y_TSNE == k]
    plt.subplot(1, n_clust, k + 1)
    for i in range(n_plot):
        idx = int(torch.randint(len(k_data), size=()))
        k_data_ex = k_data[idx]

        plot_reaches_(k_data_ex[0, :], k_data_ex[1, :])
        if i == 0:
            plt.title(f'Class {k}')
    plt.gca().set_aspect('equal')
    plt.xticks([])
    plt.yticks([])
plt.show()

### PCA

In [None]:
# Seed k-means like the PCA projection so the six cluster ids are stable across reruns.
kmeans = KMeans(n_clusters=6, random_state=SEED)
Y_PCA = kmeans.fit_predict(X_PCA)
plt.scatter(X_PCA[:, 0], X_PCA[:, 1], c = Y_PCA)

In [None]:
n_plot = 7
n_clust = 6
plt.figure(figsize=(20, 10))

def plot_reaches_(x, y, duration=0):
    """Draw one reach (line + points coloured by sample index) on fixed ±0.05 axes.

    duration: optional sample index to highlight in red; 0 disables the marker.
    """
    plt.plot(x, y, '-', alpha = 0.5)
    plt.scatter(x, y, c=np.arange(len(x)))
    if duration != 0:
        plt.scatter(x[duration], y[duration], c = "r")
    plt.xlim(-0.05, 0.05)
    plt.ylim(-0.05, 0.05)

# Y_PCA has one label per row of the k_mean_mask subset (the rows that were encoded),
# so the mask must be applied to that same subset. X_train is only a part of it and
# has a different length.
cluster_reaches = dataset_full[k_mean_mask]  # (n, 2, 75)

# One row of random example reaches per PCA cluster.
for k in range(n_clust):
    k_data = cluster_reaches[Y_PCA == k]
    for i in range(n_plot):
        idx = int(torch.randint(len(k_data), size=()))
        k_data_ex = k_data[idx]
        plt.subplot(n_clust, n_plot, k * n_plot + i + 1)
        plot_reaches_(k_data_ex[0, :], k_data_ex[1, :])
        if i == 0:
            plt.ylabel(f'Class {k}')
plt.show()

## Custom loss 

In [None]:
# Pairwise distance matrix between rows of the sample and rows of the transposed
# reconstruction, shown as an image.
pairwise_dist = torch.cdist(data_.double(), rae_recon.T.double())
plt.imshow(pairwise_dist)
plt.colorbar()

In [None]:
# Shape of the transposed reconstruction, for comparison with data_.
rae_recon.T.shape

In [None]:
# Zipping the two one-element batches yields exactly one pair: (data_, rae_recon.T).
d1, d2 = data_, rae_recon.T
cs = torch.nn.functional.cosine_similarity(d1, d2)

# Element-wise squared error and the overall Euclidean norm of the difference.
l2 = torch.nn.functional.mse_loss(data_, rae_recon.T, reduction='none')
dist = torch.norm(data_ - rae_recon.T)

# Reference lines tried earlier (disabled): dist, summed MSE, summed L1 as axhline.
plt.plot(cs)
plt.plot(l2)

In [None]:
# Largest absolute element-wise error between the sample and its reconstruction.
torch.nn.functional.l1_loss(data_, rae_recon.T, reduction='none').max()

In [None]:
# Per-column maximum of the absolute error (returns the values and their row indices).
torch.nn.functional.l1_loss(data_, rae_recon.T, reduction='none').max(dim=0)

In [None]:
# Mean absolute error over all elements.
torch.nn.functional.l1_loss(data_, rae_recon.T, reduction='mean')

In [None]:
# NOTE(review): leftover IPython help lookup (prints the docstring); remove before sharing.
torch.nn.functional.l1_loss?

In [None]:
# One minus the mean cosine similarity.
# NOTE(review): relies on whichever `cs` a previously executed cell left in the namespace.
1 - cs.mean()

In [None]:
# Shape after adding a leading batch dimension.
data_.unsqueeze(0).shape

In [None]:
# The one-element batches pair up as (data_, rae_recon.T); report the largest
# row-wise cosine similarity between them.
d1, d2 = data_, rae_recon.T
cs = torch.nn.functional.cosine_similarity(d1, d2).max()
print(cs)

In [None]:
# Same single pair as before, but the sample is compared with itself:
# mean row-wise cosine similarity of d1 against d1.
d1, d2 = data_, rae_recon.T
cs = torch.nn.functional.cosine_similarity(d1, d1).mean()
print(cs)

In [None]:
# Row-wise cosine similarity of the sample with its negation (-1 for every non-zero row).
torch.nn.functional.cosine_similarity(data_, -data_)

In [None]:
# Shape of the current sample.
data_.shape

In [None]:
# SEE: https://stackoverflow.com/questions/66139651/which-loss-function-calculates-the-distance-between-two-contours
def _as_contour_batch(c):
    """Return `c` with a leading batch axis: (N, D) -> (1, N, D).

    `torch.atleast_3d` appends the new axis instead, turning an (N, D) contour into
    (N, D, 1) — N separate "contours" of D one-dimensional points — so it is only
    used for inputs that are not 2-D.
    """
    return c.unsqueeze(0) if c.dim() == 2 else torch.atleast_3d(c)


def contour_divergence(c1, c2, func = lambda x: x**2):
    """One-sided divergence of contour `c1` from contour `c2`.

    Args:
        c1, c2: tensors of shape (N, D) / (M, D) for a single contour, or
            (B, N, D) / (B, M, D) for a batch of contours.
        func: element-wise penalty applied to the nearest-point distances.

    Returns:
        Tensor of shape (B,) (B = 1 for unbatched input): the penalty integrated
        along `c1`, weighted by the length of each segment of `c1`.
    """
    c1 = _as_contour_batch(c1)
    c2 = _as_contour_batch(c2)
    # Distance from every point of c1 to its nearest point on c2, passed through func.
    f = func(torch.amin(torch.cdist(c1, c2), dim=2))
    # Length of each segment connecting two consecutive points of c1.
    df = torch.sum((c1[:, 1:, :] - c1[:, :-1, :])**2, dim=2)**0.5
    # Trapezoid-style sum over segments (the 1/4 factor is kept from the original).
    return torch.sum((f[:, :-1] + f[:, 1:]) * df, dim=1) / 4.0


def contour_dist(c1, c2, func = lambda x: x**2):
    """Symmetric contour distance: divergence of c1 from c2 plus that of c2 from c1."""
    return contour_divergence(c1, c2, func) + contour_divergence(c2, c1, func)


In [None]:
# Shape of the symmetric contour distance between the sample and its reconstruction.
contour_dist(data_.double(), rae_recon.T.double()).shape

In [None]:
# The one-element batches pair up as (data_, rae_recon.T).
d1, d2 = data_, rae_recon.T
cs = torch.nn.functional.cosine_similarity(d1, d2)
l2 = torch.nn.functional.mse_loss(data_, rae_recon.T, reduction='none')

# Contour distance with a linear penalty scaled by 1000; only this curve is drawn
# (the cs and l2 curves are currently disabled).
cd = contour_dist(data_.double(), rae_recon.T.double(), lambda x: x * 1000)
plt.plot(cd)

## Linear AE

In [None]:
# Build a LinearAutoencoder with arguments (75 * 2, 16) and print a torchinfo layer
# summary for a single flattened (1, 150) input.
lae_test = LinearAutoencoder(75 * 2, 16)
summary(lae_test, (1, 75 * 2),
        col_names=["input_size", "output_size", "num_params",]) 

In [None]:
pl.seed_everything(42)
n_embedding = 48

# Width of one flattened reach, derived from the data instead of a hard-coded 75 * 2;
# it matches the `X_train.view(X_train.size(0), -1)` rows used for training.
n_inputs = X_train.shape[1] * X_train.shape[2]

# Initialize the model (the Trainer is created in the training cell)
lae = LinearAutoencoder(n_inputs, n_embedding)
lae.lr = 1e-2

pl.utilities.model_summary.summarize(lae, max_depth=-1)

In [None]:
# def plot_reaches_(x, y, duration=0):
#     plt.plot(x, y, '-', alpha = 0.5)
#     plt.scatter(x, y, c=np.arange(75))
#     if duration != 0:
#         plt.scatter(x[duration], y[duration], c = "r")
#     # plt.xlim(-0.15, 0.15)
#     # plt.ylim(-0.15, 0.15)
# i = 59
# one_sample_data = X_train[i:i+1, :, :]
# plot_reaches_(one_sample_data[0, 0, :].to('cpu'), one_sample_data[0, 1, :].to('cpu'))

In [None]:
# Training configuration for the linear autoencoder.
n_epochs = 100
batch_size = 1

# TensorBoard run name encodes the hyper-parameters of this run.
run_name = f"LAE_lr_{lae.lr}_bs_{batch_size}_n_emb_{n_embedding}_grad_acc_T_LR_sch_T_class_2"
logger = TensorBoardLogger("tb_logs", name=run_name)

trainer = pl.Trainer(
    max_epochs=n_epochs,
    accelerator='mps',
    logger=logger,
    # Accumulation schedule passed as {epoch: n_batches}.
    accumulate_grad_batches={0: 8, 4: 4, 8: 1},
)

# Each split is flattened to (n_samples, -1) before being wrapped in a loader;
# only the training loader shuffles, and the test loader uses the default batch size.
train_loader = DataLoader(X_train.view(X_train.size(0), -1), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(X_val.view(X_val.size(0), -1), batch_size=batch_size, shuffle=False)
test_loader = DataLoader(X_test.view(X_test.size(0), -1), shuffle=False)

# Fit, then evaluate on the held-out split (the result is shown as the cell output).
trainer.fit(lae, train_loader, val_loader)
trainer.test(lae, test_loader)

In [None]:
n_plot = 7
plt.figure(figsize=(20, 10))
# Sample from the (n, 2, 75) layout of `dataset_full`: flattening one of these reaches
# gives the same element order as the `X_train.view(X_train.size(0), -1)` rows the
# linear autoencoder is trained on.
dataset_all = torch.tensor(dataset_full[k_mean_mask], device='cpu')

def plot_reaches_(x, y, duration=0):
    """Draw one reach as a faint line plus points coloured by sample index.

    x, y: 1-D coordinate sequences of equal length.
    duration: optional sample index to highlight in red; 0 disables the marker.
    """
    plt.plot(x, y, '-', alpha = 0.5)
    plt.scatter(x, y, c=np.arange(len(x)))
    if duration != 0:
        plt.scatter(x[duration], y[duration], c = "r")

for i in range(n_plot):
    idx = torch.randint(len(dataset_all), size=())
    # A local name is used so the `data_` sample examined by the loss-exploration
    # cells is not overwritten.
    reach = dataset_all[idx]  # (2, 75): row 0 = x, row 1 = y
    with torch.no_grad():
        # Get reconstructed movements from autoencoder
        lae_recon = lae(reach.reshape(1, -1).float())[0]
    # Undo the flattening so rows 0 and 1 are the reconstructed x and y series.
    recon_xy = lae_recon.reshape(reach.shape)

    # Top row: original reach.
    plt.subplot(2, n_plot, i + 1)
    plot_reaches_(reach[0, :], reach[1, :])
    if i == 0:
        plt.ylabel('Original\nMovements')

    # Bottom row: reconstruction, labelled with this model's embedding size.
    plt.subplot(2, n_plot, i + 1 + n_plot)
    plot_reaches_(recon_xy[0, :], recon_xy[1, :])
    if i == 0:
        plt.ylabel(f'Linear AE\n(K={n_embedding})')

plt.show()