In [97]:
import os
import torch

In [98]:
def compute_gsnr(epoch: int, layer_name: str, base_path: str, epsilon=1e-8):
    """
    Computes GSNR for a given layer at a given epoch.
    """
    layer_dir = os.path.join(
        base_path,
        f"grad_info_per_epoch",
        f"epoch{epoch}",
        layer_name
    )

    grad_list = []
    filenames = sorted([
        f for f in os.listdir(layer_dir) if f.endswith(".pt")
    ])

    for fname in filenames:
        grad = torch.load(os.path.join(layer_dir, fname))  # [D]
        grad_list.append(grad)

    grads = torch.stack(grad_list)  # [T, D]
    grad_mean = grads.mean(dim=0)   # [D]

    signal = grad_mean.mean().item() ** 2
    noise = grad_mean.var().item() + epsilon
    gsnr = signal / noise

    print(f"{layer_name} (epoch {epoch}): GSNR = {gsnr:.6f}")
    return gsnr

In [99]:
def compute_gsnr_over_epochs(layer_name: str, base_path: str, epoch_list: list):
    """
    Computes and prints GSNR for a specific layer across multiple epochs.
    """
    results = {}
    for epoch in epoch_list:
        try:
            gsnr = compute_gsnr(epoch, layer_name, base_path)
            results[epoch] = gsnr
        except Exception as e:
            print(f"[Warning] Epoch {epoch} failed: {e}")
    return results


# MAML

In [100]:
base_path = "../MAML_5way_5shot_filter64_miniImagenet"
epochs = list(range(0, 3))  # 예: 0~9 epoch

In [101]:
compute_gsnr(
    epoch=0,
    layer_name="layer_layer_dict_conv0_conv_weight",
    base_path=our_base_path
)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.044766


0.04476636962336588

In [102]:
layer_name = "layer_layer_dict_conv0_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, base_path, epochs)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.163490
layer_layer_dict_conv0_conv_weight (epoch 1): GSNR = 0.060648
layer_layer_dict_conv0_conv_weight (epoch 2): GSNR = 0.368776


In [103]:
layer_name = "layer_layer_dict_conv1_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, base_path, epochs)

layer_layer_dict_conv1_conv_weight (epoch 0): GSNR = 0.015686
layer_layer_dict_conv1_conv_weight (epoch 1): GSNR = 0.044903
layer_layer_dict_conv1_conv_weight (epoch 2): GSNR = 0.083064


In [104]:
layer_name = "layer_layer_dict_conv2_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, base_path, epochs)

layer_layer_dict_conv2_conv_weight (epoch 0): GSNR = 0.016997
layer_layer_dict_conv2_conv_weight (epoch 1): GSNR = 0.007760
layer_layer_dict_conv2_conv_weight (epoch 2): GSNR = 0.013834


In [105]:
layer_name = "layer_layer_dict_conv3_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, base_path, epochs)

layer_layer_dict_conv3_conv_weight (epoch 0): GSNR = 0.001702
layer_layer_dict_conv3_conv_weight (epoch 1): GSNR = 0.000643
layer_layer_dict_conv3_conv_weight (epoch 2): GSNR = 0.024073


In [106]:
layer_name = "layer_layer_dict_linear_weights"
gsnr_results = compute_gsnr_over_epochs(layer_name, base_path, epochs)

layer_layer_dict_linear_weights (epoch 0): GSNR = 0.000000
layer_layer_dict_linear_weights (epoch 1): GSNR = 0.000000
layer_layer_dict_linear_weights (epoch 2): GSNR = 0.000000


# Our

In [107]:
our_base_path = "../MAML_Prompt_padding_5way_5shot_filter128_miniImagenet"

In [108]:
compute_gsnr(
    epoch=0,
    layer_name="layer_layer_dict_conv0_conv_weight",
    base_path=our_base_path
)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.044766


0.04476636962336588

In [109]:
layer_name = "layer_layer_dict_conv0_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, our_base_path, epochs)

layer_layer_dict_conv0_conv_weight (epoch 0): GSNR = 0.044766
layer_layer_dict_conv0_conv_weight (epoch 1): GSNR = 0.224419
layer_layer_dict_conv0_conv_weight (epoch 2): GSNR = 0.453997


In [110]:
layer_name = "layer_layer_dict_conv1_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, our_base_path, epochs)

layer_layer_dict_conv1_conv_weight (epoch 0): GSNR = 0.005074
layer_layer_dict_conv1_conv_weight (epoch 1): GSNR = 0.028651
layer_layer_dict_conv1_conv_weight (epoch 2): GSNR = 0.081176


In [111]:
layer_name = "layer_layer_dict_conv2_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, our_base_path, epochs)

layer_layer_dict_conv2_conv_weight (epoch 0): GSNR = 0.001356
layer_layer_dict_conv2_conv_weight (epoch 1): GSNR = 0.009594
layer_layer_dict_conv2_conv_weight (epoch 2): GSNR = 0.009650


In [112]:
layer_name = "layer_layer_dict_conv3_conv_weight"
gsnr_results = compute_gsnr_over_epochs(layer_name, our_base_path, epochs)

layer_layer_dict_conv3_conv_weight (epoch 0): GSNR = 0.007721
layer_layer_dict_conv3_conv_weight (epoch 1): GSNR = 0.010351
layer_layer_dict_conv3_conv_weight (epoch 2): GSNR = 0.018344


In [113]:
layer_name = "layer_layer_dict_linear_weights"
gsnr_results = compute_gsnr_over_epochs(layer_name, our_base_path, epochs)

layer_layer_dict_linear_weights (epoch 0): GSNR = 0.000000
layer_layer_dict_linear_weights (epoch 1): GSNR = 0.000000
layer_layer_dict_linear_weights (epoch 2): GSNR = 0.000000
