In [None]:
from pathlib import Path
from operator import itemgetter

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
from rich.progress import track

import data_loader.data_loaders as module_data
import model.model as module_arch
from configs import configs
from utils.plot_utils import plot_relavant_features, plot_relevances_amplitudes
from utils.utils import read_json
from notebooks.helpers import import_artifacts_from_runID

In [None]:
# Fetch the artifacts of a single run through the project helper and unpack
# the four entries this cell uses.
run_id = "fe7c023d3ad94aeabb2f6fd23a1d031f"
artifacts = import_artifacts_from_runID(run_id)
model, data_loader, device, config = itemgetter("model", "data_loader", "device", "config")(
	artifacts
)

# One array per batch, each kept as (batch_size, n_features) so batches of
# different sizes (e.g. a shorter last batch) can still be combined.
relevances_list = []
signatures_list = []
with torch.no_grad():
	# Only the first element of each loader item is used; the other two are ignored.
	for data, _, _ in track(data_loader, description="Loading data..."):
		data = data.to(device)
		# Average every sample over dims 2 and 3.
		# NOTE(review): assumes data is (batch, bands, height, width) — confirm.
		signature = torch.mean(data, dim=(2, 3))
		# Run the averaged signature through the model's spectral sub-network only.
		output = model.spectral.fc1(signature)
		output = model.spectral.act1(output)
		output = model.spectral.fc2(output)
		output = model.spectral.act2(output)
		# Keep the batch axis instead of flattening it away: flattening a batch
		# with more than one sample would mix the sample and feature axes.
		# No .detach() is needed because gradients are disabled in this context.
		relevances_list.append(output.reshape(output.shape[0], -1).cpu().numpy())
		signatures_list.append(signature.reshape(signature.shape[0], -1).cpu().numpy())

if not relevances_list:
	# Fail loudly rather than plotting NaNs computed from an empty list.
	raise ValueError("data_loader yielded no batches; nothing to plot")

# Per-feature mean over every sample in the dataset (not over batches), so a
# smaller final batch is weighted correctly.
relevances = np.concatenate(relevances_list, axis=0).mean(axis=0)
signature = np.concatenate(signatures_list, axis=0).mean(axis=0)
plot_relevances_amplitudes(relevances)
# Optional overlay of the mean signature on the band axis:
# plt.plot(configs.BANDS, signature, color="black", linewidth=2)
plot_relavant_features(relevances)
plt.show()