In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.data.generate_synthetic_data import SyntheticDataset
from moment.models.base import BaseModel
from moment.models.moment import MOMENT

In [None]:
# Projection used to map embeddings to 2-D in the synthetic-data experiments
# below; embed_timeseries_in_manifold handles 'tsne' and 'pca'.
DIMENSION_REDUCTION_METHOD = 'tsne' # 'tsne' 'pca'

In [None]:
def embed_timeseries_in_manifold(model: torch.nn.Module, 
                                 y: torch.Tensor, 
                                 device: torch.device, 
                                 input_mask: torch.Tensor = None,
                                 dimension_reduction_method: str = 'tsne',
                                 random_state: int = None):
    """Embed a batch of time series with ``model`` and optionally project to 2-D.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``embed(x_enc=..., input_mask=..., reduction='mean')``
        whose output has an ``embeddings`` attribute. It is moved to
        ``device`` and put in eval mode for the forward pass, then moved back
        to CPU (``Module.cpu`` is in-place, so the caller's model ends up on CPU).
    y : torch.Tensor
        Batch of series, unpacked as ``(n_samples, _, seq_len)``.
    device : torch.device, str or int
        Device on which to run the forward pass.
    input_mask : torch.Tensor or array-like, optional
        ``(n_samples, seq_len)`` mask. Defaults to all ones. A provided mask is
        converted to a tensor and moved to ``device``.
    dimension_reduction_method : {'tsne', 'pca', 'none'} or None
        2-D projection to apply; ``'none'``/``None`` skips the projection.
    random_state : int, optional
        Seed forwarded to TSNE/PCA so 2-D layouts are reproducible. ``None``
        keeps the estimators' default (unseeded) behavior.

    Returns
    -------
    embeddings : np.ndarray
        Embeddings returned by ``model.embed``, copied to CPU.
    embeddings_manifold : np.ndarray or list
        2-D projection, or an empty list when no projection was requested.

    Raises
    ------
    ValueError
        If ``dimension_reduction_method`` is not one of the supported values.
    """
    # Validate before the (expensive) forward pass.
    if dimension_reduction_method not in ('tsne', 'pca', 'none', None):
        raise ValueError(
            f"Unknown dimension_reduction_method {dimension_reduction_method!r}; "
            "expected 'tsne', 'pca' or 'none'.")

    y = y.to(device)
    n_samples, _, seq_len = y.shape
    model = model.to(device)

    if input_mask is None:
        input_mask = torch.ones((n_samples, seq_len), device=device)
    else:
        # The mask must live on the same device as `y`; callers may hand in a
        # CPU tensor or a numpy array.
        input_mask = torch.as_tensor(input_mask).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model.embed(x_enc=y, input_mask=input_mask, reduction='mean')
    embeddings = outputs.embeddings.detach().cpu().numpy()

    if dimension_reduction_method == 'tsne':
        embeddings_manifold = TSNE(
            n_components=2, n_jobs=5, random_state=random_state).fit_transform(embeddings)
    elif dimension_reduction_method == 'pca':
        embeddings_manifold = PCA(
            n_components=2, random_state=random_state).fit_transform(embeddings)
    else:
        embeddings_manifold = []

    # Move the model back to CPU (in-place) to free accelerator memory.
    model.cpu()

    return embeddings, embeddings_manifold

def save_experiment_artifacts(filename : str, 
                              embeddings : npt.NDArray, 
                              y : npt.NDArray, 
                              c : npt.NDArray, 
                              embeddings_manifold : npt.NDArray):
    """Persist one experiment's data and embeddings as an ``.npz`` archive.

    The archive stores four arrays under the keys ``embeddings``, ``y``, ``c``
    and ``embeddings_manifold``. Because ``np.savez`` receives an already
    opened binary file handle, the archive is written to ``filename`` exactly
    as given (no ``.npz`` suffix is appended).

    Example filename:
        "../../assets/results/interpretability/frequency_artifacts.npz"
    """
    payload = dict(
        embeddings=embeddings,
        y=y,
        c=c,
        embeddings_manifold=embeddings_manifold,
    )
    with open(filename, "wb") as out_file:
        np.savez(out_file, **payload)

### Defaults

In [None]:
# Default MOMENT configuration file (passed below as both the config and its fallback).
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
# Device index used when CUDA is available; otherwise the CPU is used.
GPU_ID = 0
# Run whose pre-trained checkpoint is loaded via BaseModel.load_pretrained_weights.
run_name = "fearless-planet-52" 

### Parse config, build model and load pre-trained weights

In [None]:
# Load the checkpoint of `run_name`. `opt_steps=None` is forwarded to
# BaseModel.load_pretrained_weights (presumably selecting its default
# checkpoint step — TODO confirm).
checkpoint = BaseModel.load_pretrained_weights(run_name=run_name, opt_steps=None)

# The default config doubles as its own fallback.
config = Config(config_file_path=DEFAULT_CONFIG_PATH, default_config_file_path=DEFAULT_CONFIG_PATH).parse()
# Use the configured GPU when available, otherwise fall back to the CPU.
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'

args = parse_config(config)
model = MOMENT(configs=args)
model.load_state_dict(checkpoint["model_state_dict"])

### Frequency

In [None]:
synthetic_dataset = SyntheticDataset(
    n_samples=1024,
    freq=1,
    freq_range=(1, 32),
    noise_mean=0.,
    noise_std=0.1,
    random_seed=13,
)

# Sinusoids whose frequency varies across samples; c[:, 0] is used below as
# each sample's frequency label.
y, c = synthetic_dataset.gen_sinusoids_with_varying_freq()
n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len

In [None]:
# Visualize the data
fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)
axs.flatten()
for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):
    axs[i].plot(y[idx].squeeze().numpy())
    axs[i].set_xticks(
        ticks=np.arange(0, seq_len+1, 128), 
        labels=np.arange(0, seq_len+1, 128), 
        fontdict={"fontsize" : 16})
    axs[i].set_title("Frequency: {:.2f}".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)
axs[0].set_yticks(
        ticks=np.arange(-1.5, 1.5, 0.5), 
        labels=np.arange(-1.5, 1.5, 0.5),
        fontdict={"fontsize" : 16})
plt.savefig("../../assets/figures/interpretability/frequency_timeseries.pdf", bbox_inches='tight') 
plt.show()

In [None]:
# Embed every frequency-varying series and project the embeddings to 2-D.
embeddings, embeddings_manifold = embed_timeseries_in_manifold(
    model, y, args.device,
    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,
)

# save_experiment_artifacts(filename=f"../../assets/results/interpretability/frequency_artifacts_{DIMENSION_REDUCTION_METHOD}.npz", 
#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)

In [None]:
# Raw string: "\s", "\p" and "\e" are invalid escape sequences in a normal
# string literal (SyntaxWarning on Python 3.12+); the mathtext is unchanged.
plt.title(r"$y = \sin(2c \pi x) + \epsilon$", fontsize=20)
plt.scatter(embeddings_manifold[:, 0], 
            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Colour bins of width 1 spanning the generated frequency range.
plt.colorbar(boundaries=np.arange(
    synthetic_dataset.freq_range[0], synthetic_dataset.freq_range[1]+1, 1))
plt.savefig(f"../../assets/figures/interpretability/frequency_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf", 
    bbox_inches='tight') 
plt.show()

### Amplitude

In [None]:
synthetic_dataset = SyntheticDataset(
    n_samples=2048,
    seq_len=512,
    freq=16,
    amplitude_range=(1/4, 4),
    noise_mean=0.,
    noise_std=0.1,
    random_seed=13,
)

# Sinusoids whose amplitude varies across samples; c[:, 0] is used below as
# each sample's amplitude label.
y, c = synthetic_dataset.gen_sinusoids_with_varying_amplitude()
n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len

In [None]:
# Visualize the data
fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)
axs.flatten()
for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):
    axs[i].plot(y[idx].squeeze().numpy())
    axs[i].set_xticks(
        ticks=np.arange(0, seq_len+1, 128), 
        labels=np.arange(0, seq_len+1, 128), 
        fontdict={"fontsize" : 16})
    axs[i].set_title("Amplitude: {:.2f}".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)
axs[0].set_yticks(
        ticks=np.arange(-16, 16, 10), 
        labels=np.arange(-16, 16, 10),
        fontdict={"fontsize" : 16})
plt.savefig("../../assets/figures/interpretability/amplitude_timeseries.pdf", bbox_inches='tight') 
plt.show()

In [None]:
# Embed every amplitude-varying series and project the embeddings to 2-D.
embeddings, embeddings_manifold = embed_timeseries_in_manifold(
    model, y, args.device,
    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,
)

# save_experiment_artifacts(filename=f"../../assets/results/interpretability/amplitude_artifacts_{DIMENSION_REDUCTION_METHOD}.npz", 
#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)

In [None]:
# Raw string: "\s" and "\e" are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+); the mathtext is unchanged.
plt.title(r"$y = c*\sin(32\pi x) + \epsilon$", fontsize=20)
plt.scatter(embeddings_manifold[:, 0], 
            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Colour bins of width 1 starting at the smallest generated amplitude.
plt.colorbar(boundaries=np.arange(synthetic_dataset.amplitude_range[0], synthetic_dataset.amplitude_range[1]+1, 1))
plt.savefig(f"../../assets/figures/interpretability/amplitude_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf", 
    bbox_inches='tight') 
plt.show()

### Trend

In [None]:
synthetic_dataset = SyntheticDataset(
    n_samples=2048,
    freq=16,
    trend_range=(1/8, 8),
    noise_mean=0.,
    noise_std=0.1,
    random_seed=13,
)

y, c = synthetic_dataset.gen_sinusoids_with_varying_trend()
# Rebuild the trend component t**c for plotting; `t` is the second output of
# the private _generate_x helper (presumably the time axis — confirm).
_, t = synthetic_dataset._generate_x()
trend = t ** c
n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len

In [None]:
# Visualize the trends
for i in range(0, len(trend), 32):
    plt.plot(trend[i].squeeze().numpy())

In [None]:
# Visualize the data
fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)
axs.flatten()
for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):
    axs[i].plot(y[idx].squeeze().numpy())
    axs[i].plot(trend[idx].squeeze().numpy())
    axs[i].set_xticks(
        ticks=np.arange(0, seq_len+1, 128), 
        labels=np.arange(0, seq_len+1, 128), 
        fontdict={"fontsize" : 16})
    axs[i].set_title("Trend: {:.2f}".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)
axs[0].set_yticks(
        ticks=np.arange(-2, 2, 10), 
        labels=np.arange(-2, 2, 10),
        fontdict={"fontsize" : 16})
plt.savefig("../../assets/figures/interpretability/trend_timeseries.pdf", bbox_inches='tight') 
plt.show()

In [None]:
# Embed every trend-varying series and project the embeddings to 2-D.
embeddings, embeddings_manifold = embed_timeseries_in_manifold(
    model, y, args.device,
    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,
)

# save_experiment_artifacts(filename=f"../../assets/results/interpretability/trend_artifacts_{DIMENSION_REDUCTION_METHOD}.npz", 
#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)

In [None]:
# Raw string: "\s" and "\e" are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+); the mathtext is unchanged.
plt.title(r"$y = x^c + \sin(32\pi x) + \epsilon$", fontsize=20)
plt.scatter(embeddings_manifold[:, 0], 
            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Colour bins of width 1 starting at the smallest generated exponent.
plt.colorbar(boundaries=np.arange(synthetic_dataset.trend_range[0], synthetic_dataset.trend_range[1]+1, 1))
plt.savefig(f"../../assets/figures/interpretability/trend_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf", bbox_inches='tight') 
plt.show()

### Baseline shift

In [None]:
synthetic_dataset = SyntheticDataset(
    n_samples=2048,
    freq=16,
    baseline_range=(-2, 2),
    noise_mean=0.,
    noise_std=0.1,
    random_seed=13,
)

# Sinusoids shifted by a per-sample baseline; c[:, 0] is used below as the
# baseline label.
y, c = synthetic_dataset.gen_sinusoids_with_varying_baseline()
n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len

In [None]:
# Visualize the data
fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)
axs.flatten()
for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):
    axs[i].plot(y[idx].squeeze().numpy())
    axs[i].plot(c[idx].squeeze().numpy())
    axs[i].set_xticks(
        ticks=np.arange(0, seq_len+1, 128), 
        labels=np.arange(0, seq_len+1, 128), 
        fontdict={"fontsize" : 16})
    axs[i].set_title("Baseline: {:.2f}".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)
axs[0].set_yticks(
        ticks=np.arange(-2, 2, 10), 
        labels=np.arange(-2, 2, 10),
        fontdict={"fontsize" : 16})
plt.savefig("../../assets/figures/interpretability/baseline_timeseries.pdf", bbox_inches='tight') 
plt.show()

In [None]:
# Embed every baseline-shifted series and project the embeddings to 2-D.
embeddings, embeddings_manifold = embed_timeseries_in_manifold(
    model=model, y=y, device=args.device, dimension_reduction_method=DIMENSION_REDUCTION_METHOD)

# Uncomment to persist the baseline artifacts. The file name uses "baseline"
# so it does not overwrite the trend results.
# save_experiment_artifacts(filename=f"../../assets/results/interpretability/baseline_artifacts_{DIMENSION_REDUCTION_METHOD}.npz", 
#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)

In [None]:
# Raw string: "\s" and "\e" are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+); the mathtext is unchanged.
plt.title(r"$y = c + \sin(32\pi x) + \epsilon$", fontsize=20)
plt.scatter(embeddings_manifold[:, 0], 
            embeddings_manifold[:, 1], c=c[:, 0].squeeze().numpy(), cmap='magma')
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Colour bins of width 1 spanning the generated baseline range.
plt.colorbar(boundaries=np.arange(synthetic_dataset.baseline_range[0], synthetic_dataset.baseline_range[1]+1, 1))
plt.savefig(f"../../assets/figures/interpretability/baseline_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf", bbox_inches='tight') 
plt.show()

### Auto-correlation (varying phase offset)

In [None]:
# Build four groups of 512 phase-shifted sinusoids, one per base frequency.
# Every group shares the same noise settings and seed; only `freq` differs.
# After the loop, `synthetic_dataset` refers to the last (freq=5) generator.
groups = []
for freq in (1, 2, 3, 5):
    synthetic_dataset = SyntheticDataset(n_samples=512, freq=freq, baseline_range=(-2, 2), 
                                         noise_mean=0., noise_std=0.1, random_seed=13)
    groups.append(synthetic_dataset.gen_sinusoids_with_varying_correlation())
(y_1, c_1), (y_2, c_2), (y_3, c_3), (y_4, c_4) = groups

n_samples = len(groups) * synthetic_dataset.n_samples
seq_len = synthetic_dataset.seq_len

# Stack the groups in frequency order: rows [0, 512) are f=1, [512, 1024) f=2, ...
y = torch.cat([y_1, y_2, y_3, y_4], dim=0)
c = torch.cat([c_1, c_2, c_3, c_4], dim=0)

In [None]:
# Visualize the data
fig, axs = plt.subplots(1, 5, figsize=(30, 6), sharey=True)
axs.flatten()
for i, idx in enumerate(np.arange(0, n_samples+1, n_samples//4-1)):
    axs[i].plot(y[idx].squeeze().numpy())
    axs[i].set_xticks(
        ticks=np.arange(0, seq_len+1, 128), 
        labels=np.arange(0, seq_len+1, 128), 
        fontdict={"fontsize" : 16})
    axs[i].set_title("Offset: {:.2f}".format(c[:, 0][idx].squeeze().numpy(), ), fontsize=16)
axs[0].set_yticks(
        ticks=np.arange(-2, 2, 10), 
        labels=np.arange(-2, 2, 10),
        fontdict={"fontsize" : 16})
plt.savefig(f"../../assets/figures/interpretability/correlation_timeseries.pdf", bbox_inches='tight') 
plt.show()

In [None]:
# Embed the concatenated phase-shift groups and project the embeddings to 2-D.
embeddings, embeddings_manifold = embed_timeseries_in_manifold(
    model, y, args.device,
    dimension_reduction_method=DIMENSION_REDUCTION_METHOD,
)

# save_experiment_artifacts(filename=f"../../assets/results/interpretability/correlation_artifacts_{DIMENSION_REDUCTION_METHOD}.npz", 
#     embeddings=embeddings, y=y, c=c, embeddings_manifold=embeddings_manifold)

In [None]:
# Raw string: "\s", "\e", "\ " and "\i" are invalid escape sequences in a
# normal string literal (SyntaxWarning on Python 3.12+); mathtext is unchanged.
plt.title(r"$y = \sin(2\pi f x + c) + \epsilon, \ c \in [0, 2\pi]$", fontsize=20)

# Each base frequency occupies one contiguous, equally sized block of rows in
# `c` (stacked in the data-generation cell); give each group its own marker.
group_markers = {"1": 'o', "2": 'x', "3": '^', "5": '*'}
group_size = len(c) // len(group_markers)
wave_groups = {
    freq_label: slice(k * group_size, (k + 1) * group_size)
    for k, freq_label in enumerate(group_markers)
}

for freq_label, rows in wave_groups.items():
    # Shared vmin/vmax so every group, and the single colorbar below (which
    # reflects the last scatter), use the same phase-to-colour mapping.
    plt.scatter(embeddings_manifold[rows, 0], 
                embeddings_manifold[rows, 1], 
                c=c[rows, 0].squeeze().numpy(), cmap='magma',
                vmin=0, vmax=2*np.pi, marker=group_markers[freq_label])

# Hand-placed label positions for a particular 2-D layout. They must be
# re-tuned if the layout changes (TSNE layouts vary between unseeded runs).
if DIMENSION_REDUCTION_METHOD == 'tsne':
    label_positions = {"1": (18, -30), "2": (18, 20), "3": (-30, 25), "5": (-30, -22)}
else:  # PCA
    label_positions = {"1": (0.2, -0.2), "2": (0.3, 0.2), "3": (-0.5, 0.3), "5": (-0.5, -0.2)}
for freq_label, (x_pos, y_pos) in label_positions.items():
    plt.text(x_pos, y_pos, f"$f={freq_label}$", fontsize=16, color='darkred')

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.colorbar(boundaries=np.arange(0, 2*np.pi+1, 1))
# plt.savefig(f"../../assets/figures/interpretability/autocorrelation_artifacts_{DIMENSION_REDUCTION_METHOD}.pdf", 
#     bbox_inches='tight') 
plt.show()

In [None]:
from scipy import stats


# Learned mask embedding as a flat 1-D sample (probplot and kstest expect 1-D
# input; ravel is a no-op if it already is).
mask_embedding = model.patch_embedding.mask_embedding.detach().cpu().numpy().ravel()
# Q-Q plot against N(0, 1); `r` is the correlation coefficient of the fitted line.
_, (slope, intercept, r) = stats.probplot(
    x=mask_embedding, sparams=(0, 1), dist=stats.norm, fit=True, rvalue=False, plot=plt)
# dist = [stats.logistic, stats.norm, stats.t, stats.cauchy]

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.xlabel("Theoretical quantiles", fontsize=16)
plt.ylabel("Observed Values", fontsize=16)
plt.title("Probability Plot of Mask Embeddings", fontsize=20)
# Anchor the R^2 annotation in axes coordinates (top-left), instead of at a
# data x-position that happened to equal R^2 itself.
plt.text(0.05, 0.9, f"$R^2 = {r**2:.4f}$", fontsize=16, transform=plt.gca().transAxes)
plt.savefig("../../assets/figures/interpretability/mask_embeddings.pdf", bbox_inches='tight')
plt.show()

### Kolmogorov-Smirnov Test

In [None]:
# Two-sided Kolmogorov-Smirnov test of the mask-embedding values against a
# standard normal ("norm" with scipy's default loc=0, scale=1).
test_result = stats.kstest(
    mask_embedding,
    "norm",
    alternative="two-sided",
)
print(test_result)

### Input embeddings

### Output embeddings

In [None]:
from tqdm import tqdm, trange
from moment.utils.short_univariate_classification_datasets import short_univariate_classification_datasets
from moment.data.dataloader import get_timeseries_dataloader

In [None]:
def get_test_dataloader(args):
    """Return the test-split dataloader for ``args.full_file_path_and_name``.

    Note: mutates ``args`` in place, setting ``dataset_names`` to the file
    path and ``data_split`` to ``'test'``, before passing it to
    ``get_timeseries_dataloader``.
    """
    args.dataset_names = args.full_file_path_and_name
    args.data_split = 'test'
    return get_timeseries_dataloader(args=args)

In [None]:
# Configuration for the UCR classification embedding experiment.
# NOTE: re-binds DEFAULT_CONFIG_PATH and GPU_ID (same values as the defaults
# cell) and overwrites `args`, so every later cell that reads `args` (including
# the reconstruction analysis) uses this configuration.
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
GPU_ID = 0
# UCR datasets whose test-split embeddings are visualized below.
dataset_names = ['Crop', 'ElectricDevices', 'Wafer', 'ECG5000', 'ChlorineConcentration']
config = Config(config_file_path="../../configs/classification/unsupervised_representation_learning.yaml", 
                default_config_file_path=DEFAULT_CONFIG_PATH).parse()
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

In [None]:
# Root of the UCR archive; each test split lives at <root>/<name>/<name>_TEST.ts.
UCR_ROOT = '/XXXX-14/project/public/XXXX-9/TimeseriesDatasets/classification/UCR'

for dataset_name in dataset_names:   
    args.full_file_path_and_name = f'{UCR_ROOT}/{dataset_name}/{dataset_name}_TEST.ts'
    test_dataloader = get_test_dataloader(args)

    # Embed batch by batch without projecting; the 2-D projection is fitted
    # once on the whole test split below.
    embeddings = []
    labels = []
    for batch_x in tqdm(test_dataloader, desc=dataset_name):
        timeseries = batch_x.timeseries.float().to(args.device)
        input_mask = batch_x.input_mask.long().to(args.device)
        
        _embeddings, _ = embed_timeseries_in_manifold(
            model=model, y=timeseries, device=args.device, input_mask=input_mask, 
            dimension_reduction_method='none')

        embeddings.append(_embeddings)
        labels.append(batch_x.labels)

    embeddings = np.concatenate(embeddings, axis=0)
    labels = np.concatenate(labels, axis=0).squeeze()

    for dimension_reduction_method in ['tsne', 'pca']:
        if dimension_reduction_method == 'tsne':
            embeddings_manifold = TSNE(n_components=2, n_jobs=10).fit_transform(embeddings)
        else:
            embeddings_manifold = PCA(n_components=2).fit_transform(embeddings)

        plt.title(f"{dataset_name}", fontsize=20)
        # Colour points by class label.
        plt.scatter(embeddings_manifold[:, 0], 
                    embeddings_manifold[:, 1], c=labels)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        # Hide ticks and tick labels on every side for a clean scatter.
        plt.tick_params(axis='both', which='both', bottom=False, top=False, 
                        labelbottom=False, right=False, left=False, labelleft=False)
        # Remove the box 
        ax = plt.gca()
        ax.set_frame_on(False)
        plt.savefig(f"../../assets/figures/interpretability/{dataset_name}_{dimension_reduction_method}.pdf", bbox_inches='tight') 
        plt.show()

## Frequency analysis

In [None]:
def reconstruct_timeseries(model: torch.nn.Module, 
                           y: torch.Tensor, 
                           device: torch.device, 
                           input_mask: torch.Tensor = None):
    """Reconstruct a batch of time series with ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``reconstruct(x_enc=..., input_mask=...)`` whose output
        has a ``reconstruction`` attribute. It is moved to ``device`` and put
        in eval mode for the forward pass, then moved back to CPU in-place.
    y : torch.Tensor
        Batch of series, unpacked as ``(n_samples, _, seq_len)``.
    device : torch.device, str or int
        Device on which to run the forward pass.
    input_mask : torch.Tensor or array-like, optional
        ``(n_samples, seq_len)`` mask. Defaults to all ones. A provided mask is
        converted to a tensor and moved to ``device``.

    Returns
    -------
    y : np.ndarray
        The input series, copied to CPU as a numpy array.
    reconstruction : np.ndarray
        The model's reconstruction, copied to CPU as a numpy array.
    """
    y = y.to(device)
    n_samples, _, seq_len = y.shape
    model = model.to(device)

    if input_mask is None:
        input_mask = torch.ones((n_samples, seq_len), device=device)
    else:
        # The mask must live on the same device as `y`.
        input_mask = torch.as_tensor(input_mask).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model.reconstruct(x_enc=y, input_mask=input_mask)
    reconstruction = outputs.reconstruction.detach().cpu().numpy()

    # Move the model back to CPU (in-place) to free accelerator memory.
    model.cpu()

    return y.detach().cpu().numpy(), reconstruction

In [None]:
synthetic_dataset = SyntheticDataset(
    n_samples=1024,
    freq=1,
    freq_range=(1, 96),
    noise_mean=0.,
    noise_std=0.1,
    random_seed=13,
)

# Frequency-varying sinusoids over a wider range (1..96) for the
# reconstruction-error analysis; c[:, 0] is the per-sample frequency label.
y, c = synthetic_dataset.gen_sinusoids_with_varying_freq()
n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len

In [None]:
# Reconstruct every series; both outputs come back as CPU numpy arrays.
timeseries, reconstruction = reconstruct_timeseries(model=model, y=y, device=args.device)

In [None]:
# Compare one series with its reconstruction (any idx in [0, n_samples) works).
idx = 512
# rf-string: "\s", "\p" and "\e" are invalid escapes in a normal string literal
# (SyntaxWarning on Python 3.12+); the rendered title is unchanged.
plt.title(rf"$y = \sin(2*{c[idx][0]:.1f} \pi x) + \epsilon$", fontsize=20)
plt.plot(timeseries[idx].squeeze(), color='darkblue', label='True')
plt.plot(reconstruction[idx].squeeze(), color='red', linestyle='dashed', label='Reconstruction')
plt.legend()
plt.show()

In [None]:
# Per-sample reconstruction MSE, averaged over the last (time) axis. `error`
# stays in sample order for the autocorrelation plot that follows.
error = np.mean((timeseries - reconstruction)**2, axis=-1).squeeze()
freqs = c[:, 0].squeeze().numpy()
# Sort by frequency so the line plot does not zig-zag if samples are not
# already ordered by c (stable sort keeps ties in sample order).
order = np.argsort(freqs, kind='stable')
plt.plot(freqs[order], error[order], c='darkblue')
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Raw string: "\s" and "\e" are invalid escapes in a normal string literal.
plt.xlabel(r"$y = \sin(2*c \pi x) + \epsilon$", fontsize=16)
plt.ylabel("MSE", fontsize=16)
plt.show()

In [None]:
import pandas as pd
# Autocorrelation of the per-sample reconstruction error, taken in sample order.
pd.plotting.autocorrelation_plot(error)