In [None]:
import torch
import glob

In [None]:
# Folder written by the denoising run. For each sample index `i`,
# print_denoising_process reads artificial_mask_{i}.pt, noised_{i}.pt,
# original_{i}.pt and one step_{s}_{i}.pt snapshot per saved step `s`.
analysis_dir = 'analysis'

# One artificial mask is saved per sample, so this counts the samples.
n = len(glob.glob(f'{analysis_dir}/artificial_mask_*.pt'))
if n == 0:
    # Without this guard the division below fails with an uninformative ZeroDivisionError.
    raise FileNotFoundError(f"no artificial_mask_*.pt files found in {analysis_dir!r}")

# Step snapshots are saved per sample, so the per-sample count is total // n.
n_steps = len(glob.glob(f'{analysis_dir}/step_*.pt')) // n

# NOTE(review): presumably the total number of diffusion steps of the run; this
# must match the schedule used when the step_* files were saved — confirm.
T = 50

# Reconstruct the saved step numbers (evenly spaced over [0, T-1], highest first).
steps = torch.linspace(0, T - 1, n_steps).flip(0).round().long().tolist()
steps

In [None]:
def print_denoising_process(
    i: int,
    n_th_stock: int,
    n_th_period: int,
    steps: list[int],
    print_n_values: int = 10,
    directory: "str | None" = None,
):
    """Print how one masked row evolves across the saved denoising snapshots.

    Loads the tensors saved for sample ``i``: the artificial mask, the fully
    noised input, one snapshot per entry of ``steps``, and the original data.
    It then selects the entries of ``[n_th_stock, n_th_period, :]`` where the
    mask is non-zero. For each stage it prints the first ``print_n_values``
    selected values, followed by the sum and mean of the absolute
    mask-weighted difference to the original data.

    Args:
        i: Sample index used in the saved file names.
        n_th_stock: Index into the first tensor dimension.
        n_th_period: Index into the second tensor dimension.
        steps: Step numbers whose ``step_{s}_{i}.pt`` snapshots are printed, in order.
        print_n_values: How many selected values to show on each line.
        directory: Folder holding the ``.pt`` files. Defaults to the notebook's
            global ``analysis_dir``.

    Returns:
        None; all output goes to stdout.

    Raises:
        FileNotFoundError: If one of the expected files is missing.
    """
    if directory is None:
        directory = analysis_dir  # fall back to the notebook-level config value

    def load(name):
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
        # can execute code. Only use it on files produced by a trusted run. If these
        # files hold plain tensors, weights_only=True would be safer — confirm first.
        return torch.load(f"{directory}/{name}", weights_only=False, map_location="cpu")

    # Load tensors (same order as before, so the first missing file is reported first).
    mask = load(f"artificial_mask_{i}.pt")
    fully_noise = load(f"noised_{i}.pt")
    step = [load(f"step_{s}_{i}.pt") for s in steps]
    original_data = load(f"original_{i}.pt")

    def row_of(t):
        # The single (stock, period) row inspected by this function.
        return t[n_th_stock, n_th_period, :]

    mask_row = row_of(mask)
    mask_1 = mask_row.bool()  # positions where the mask is non-zero
    original_row = row_of(original_data)

    def format_list(data):
        return [float(f"{x:.4f}") for x in data[:print_n_values]]

    def diff_stats(data):
        # |(data - original) * mask| restricted to masked positions, computed on the
        # single row only. Multiplying by the mask is a no-op for a 0/1 mask, but it
        # is kept so that non-binary masks weight the differences exactly as before.
        diff = ((row_of(data) - original_row) * mask_row)[mask_1].abs()
        # NOTE: if no position is masked, diff is empty and its mean is nan.
        return diff.sum().item(), diff.mean().item()

    # Fully noised data
    noise_sum, noise_mean = diff_stats(fully_noise)
    print(f"{format_list(row_of(fully_noise)[mask_1])}\tFully noised (diff: {noise_sum:.2f}) (mean diff: {noise_mean:.2f})")

    # Each saved denoising step, in the order given by `steps`
    for s, snapshot in zip(steps, step):
        step_sum, step_mean = diff_stats(snapshot)
        print(f"{format_list(row_of(snapshot)[mask_1])}\tStep {s} (diff: {step_sum:.2f}) (mean diff: {step_mean:.2f})")

    # Original data
    print(f"{format_list(original_row[mask_1])}\tOriginal")

# Inspect sample 0, row (stock 0, period 0), showing the first 5 masked values per stage.
print_denoising_process(
    i=0,
    n_th_stock=0,
    n_th_period=0,
    steps=steps,
    print_n_values=5,
)