In [87]:
import os
import torch

In [88]:
def compute_gsnr(epoch: int, layer_name: str, base_path: str, epsilon=1e-8):
    """
    Computes GSNR for a given layer at a given epoch.
    """
    layer_dir = os.path.join(
        base_path,
        f"grad_info_per_epoch",
        f"epoch{epoch}",
        layer_name
    )

    grad_list = []
    filenames = sorted([
        f for f in os.listdir(layer_dir) if f.endswith(".pt")
    ])

    for fname in filenames:
        grad = torch.load(os.path.join(layer_dir, fname))  # [D]
        grad_list.append(grad)

    grads = torch.stack(grad_list)  # [T, D]
    grad_mean = grads.mean(dim=0)   # [D]

    signal = grad_mean.mean().item() ** 2
    noise = grad_mean.var().item() + epsilon
    gsnr = signal / noise

    print(f"{layer_name} (epoch {epoch}): GSNR = {gsnr:.6f}")
    return gsnr

In [89]:
def compute_gsnr_over_epochs(layer_name: str, base_path: str, epoch_list: list):
    """
    Computes and prints GSNR for a specific layer across multiple epochs.
    """
    results = {}
    for epoch in epoch_list:
        try:
            gsnr = compute_gsnr(epoch, layer_name, base_path)
            results[epoch] = gsnr
        except Exception as e:
            print(f"[Warning] Epoch {epoch} failed: {e}")
    return results


# MAML

In [90]:
base_path = "../MAML_5way_5shot_filter64_miniImagenet"
layer_name = "layer_layer_dict_conv0_conv_weight"
epochs = list(range(0, 3))  # 예: 0~9 epoch

In [91]:
compute_gsnr(
    epoch=0,
    layer_name="layer_layer_dict_conv0_conv_weight",
    base_path=our_base_path
)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.044766


0.04476636962336588

In [92]:
gsnr_results = compute_gsnr_over_epochs(layer_name, base_path, epochs)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.163490
layer_layer_dict_conv0_conv_weight (epoch 1): GSNR = 0.060648
layer_layer_dict_conv0_conv_weight (epoch 2): GSNR = 0.368776


In [93]:
our_base_path = "../MAML_Prompt_padding_5way_5shot_filter128_miniImagenet"
layer_name = "layer_layer_dict_conv0_conv_weight"

In [94]:
compute_gsnr(
    epoch=0,
    layer_name="layer_layer_dict_conv0_conv_weight",
    base_path=our_base_path
)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.044766


0.04476636962336588

In [95]:
gsnr_results = compute_gsnr_over_epochs(layer_name, our_base_path, epochs)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.044766
layer_layer_dict_conv0_conv_weight (epoch 1): GSNR = 0.224419
layer_layer_dict_conv0_conv_weight (epoch 2): GSNR = 0.453997
