In [None]:
from pathlib import Path
from operator import itemgetter

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
from rich.progress import track
from IPython.display import clear_output

import data_loader.data_loaders as module_data
import model.model as module_arch
from configs import configs
from utils.plot_utils import plot_relavant_features, plot_relevances_amplitudes
from utils.utils import ensure_dir
from notebooks.helpers import import_artifacts_from_runID, load_ids_from_registry, get_plot_name

In [None]:
def load_test_df(run_id):
	"""Average the spectral-branch output of a registered run over its data loader.

	Loads the model, data loader, device and config logged for ``run_id``. Each batch
	is reduced by a mean over dims 2 and 3 and fed through
	``model.spectral`` (fc1 -> act1 -> fc2 -> act2). The resulting activations are
	averaged over every sample seen.

	Note: despite the name, no DataFrame is produced; the name is kept for callers.

	Args:
		run_id: run identifier passed to ``import_artifacts_from_runID``.

	Returns:
		tuple: ``(relevances, config)``.
		- ``relevances`` is a 1-D numpy array with one averaged value per spectral
		  output feature.
		- ``config`` is the config artifact loaded for the run.

	Raises:
		ValueError: if the data loader yields no batches.
	"""
	artifacts = import_artifacts_from_runID(run_id)
	model, data_loader, device, config = itemgetter("model", "data_loader", "device", "config")(
		artifacts
	)
	# Inference only: disable dropout / use running batch-norm statistics, if any.
	model.eval()

	batch_relevances = []
	with torch.no_grad():
		for data, _, _ in track(data_loader, description="Loading data..."):
			data = data.to(device)
			# Mean over dims 2 and 3; assumes data is (batch, channels, H, W) — TODO confirm.
			signature = torch.mean(data, dim=(2, 3))
			output = model.spectral.fc1(signature)
			output = model.spectral.act1(output)
			output = model.spectral.fc2(output)
			output = model.spectral.act2(output)
			# Keep the batch axis (dim 0) rather than flattening it away. Batches of
			# different sizes (e.g. a short final batch) can then be stacked, and
			# the mean below is taken per feature, not per (position-in-batch, feature).
			batch_relevances.append(output.reshape(output.shape[0], -1).cpu().numpy())

	clear_output(wait=True)

	if not batch_relevances:
		raise ValueError(f"Data loader for run {run_id!r} yielded no batches.")

	# Per-sample average: every sample contributes equally, regardless of batch size.
	relevances = np.concatenate(batch_relevances, axis=0).mean(axis=0)
	return relevances, config

In [None]:
# Compute averaged spectral relevances for every registered run, then save one
# amplitude plot per run under <BASE_DIR>/saved/relevances.
run_ids = load_ids_from_registry()
test_data = [load_test_df(run_id) for run_id in run_ids]

save_relevances_dir = configs.BASE_DIR / "saved/relevances"
# Use ensure_dir only for its side effect (creating the directory). Its return value
# may be None (as in pytorch-template's helper), so it is not used as the path.
ensure_dir(save_relevances_dir)

for relevances, config in test_data:
	name = get_plot_name(config)
	plot_relevances_amplitudes(relevances, title=name)
	# plot_relavant_features(relevances)
	# Save before showing: show() may clear/close the current figure.
	plt.savefig(save_relevances_dir / f"{name}.png")
	plt.show()
	# Release the figure so figures do not accumulate across runs in the loop.
	plt.close()