In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import os
import random
from pathlib import Path

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torchinfo import summary

from utils import plot_reach, plot_reconstruction_examples, plot_grid_z
from lin_ae_model_behavior import LinearVariationalAutoencoder

# Render inline figures at 100 dots per inch.
plt.rcParams.update({'figure.dpi': 100})

In [3]:
# Project root (shared Google Drive by default). Set the NMA_PROJECT_ROOT
# environment variable to run the notebook from another location/machine.
PATH_ROOT = Path(os.environ.get(
    'NMA_PROJECT_ROOT',
    '/Volumes/GoogleDrive/My Drive/NMA-22/naturalistic_arm_movements_ecog',
))
PATH_DATA = PATH_ROOT / 'data' / 'behavior_data'

# Seed every RNG source used here (stdlib, torch CPU/CUDA, numpy) for reproducibility.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

DTYPE = torch.float
# Prefer Apple-silicon MPS when it is available; fall back to CPU otherwise so
# the notebook does not fail on machines without an MPS backend.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Number of samples kept along the truncated axis (presumably time) of each reach.
N_TIMEPOINTS = 75

# Swap the last two axes, then keep the first N_TIMEPOINTS entries of the new
# axis 1. The remaining last axis has 2 channels (presumably x/y — confirm).
DATASET = np.load(PATH_DATA / "reaches_scales.npy")
DATASET = np.swapaxes(DATASET, 2, 1)[:, :N_TIMEPOINTS, :]

PATH_ROI = PATH_ROOT / "data" / "Naturalistic reach ECoG tfrs ROI"
METADATA = pd.read_csv(PATH_ROI / "power-roi-all-patients-metadata.csv", index_col=0)

# check the dataset shape
assert DATASET.shape == (5984, N_TIMEPOINTS, 2), f"unexpected dataset shape {DATASET.shape}"

## Load the trained model (4-D latent space)

In [10]:
# Last checkpoint of the VAE trained with a 4-D latent space.
chk_path = (
    Path.cwd() / 'tb_logs' / 'LinearVariationalAutoencoder_n_latent=4_lr=0.005'
    / 'version_3' / 'checkpoints' / 'last.ckpt'
)
# n_input is derived from the data instead of a magic 150: presumably one
# flattened sample of DATASET (75 * 2 values given the shape asserted above).
# map_location="cpu" keeps the weights on the same device as the CPU tensors
# built in later cells; .eval() switches the module to inference mode.
model = LinearVariationalAutoencoder.load_from_checkpoint(
    chk_path,
    n_input=DATASET.shape[1] * DATASET.shape[2],
    n_latent=4,
    map_location="cpu",
).eval()


In [19]:
# Probe the decoder at one point of the 4-D latent space: -1 on the first
# latent dimension, 0 on the others. Build it directly with the configured
# dtype and on the model's own device to avoid a device mismatch.
z_test = torch.tensor([[-1.0, 0.0, 0.0, 0.0]], dtype=DTYPE, device=model.device)

model.eval()  # inference mode (matters only if the model has dropout/batch-norm layers)
with torch.no_grad():
    sample = model.decoder(model.decoder_inp(z_test))

## Visualize the 2-D latent space

In [5]:
# Last checkpoint of the VAE trained with a 2-D latent space.
chk_path_2 = (
    Path.cwd() / 'tb_logs' / 'LinearVariationalAutoencoder_n_latent=2_lr=0.005'
    / 'version_9' / 'checkpoints' / 'last.ckpt'
)
# n_input is derived from the data instead of a magic 150 (presumably one
# flattened DATASET sample). Load onto CPU and switch to inference mode.
model_2 = LinearVariationalAutoencoder.load_from_checkpoint(
    chk_path_2,
    n_input=DATASET.shape[1] * DATASET.shape[2],
    n_latent=2,
    map_location="cpu",
).eval()

In [None]:
# Plot the 2-latent model's decodings over a grid of latent values
# (grid semantics of n_ex / max_z live in utils.plot_grid_z — see there).
plot_grid_z(
    model_2,
    n_latent=2,
    z_ids=(0, 1),  # latent dimensions to vary
    n_ex=7,
    max_z=10,
)