In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import configs

# from data.qmul_loader import get_batch, train_people, test_people
from data.regression_data_loader import data_provider

from io_utils import parse_args_regression, get_resume_file
from methods.UnLiMiTDI_regression import UnLiMiTDI
from methods.UnLiMiTDR_regression import UnLiMiTDR
from methods.UnLiMiTDproj_regression import UnLiMiTDproj
from methods.UnLiMiTDIX_regression import UnLiMiTDIX
from methods.UnLiMiTDprojX_regression import UnLiMiTDprojX
from methods.maml import MAML
from projection import create_random_projection_matrix, proj_sketch
import backbone
import os
import numpy as np

import matplotlib.pyplot as plt

In [2]:
def select_model(dataset, method, model="Conv3"):
    """Build the backbone network for a (dataset, method, model) combination.

    Note that this function also sets the class-level ``maml`` flag on some
    ``backbone`` classes before instantiating them, so it has a side effect
    on the ``backbone`` module that outlives the call.

    Args:
        dataset: ``"QMUL"``, ``"berkeley"`` or ``"argus"``.
        method: Meta-learning method name. ``"DKT"`` and ``"MAML"`` are
            matched exactly; on QMUL any name containing ``"UnLiMiTD"``
            selects the UnLiMiTD configuration. On ``"berkeley"``/``"argus"``
            every method other than ``"DKT"``/``"MAML"`` gets a single-output
            MLP.
        model: Architecture name. ``"Conv3"`` (or ``"Conv3_net"`` for DKT) on
            QMUL; ``"ThreeLayerMLP"`` or ``"SteinwartMLP"`` otherwise.

    Returns:
        The constructed backbone. QMUL backbones are moved to the GPU with
        ``.cuda()``; the MLP backbones are returned as constructed.

    Raises:
        ValueError: If the dataset is unknown, or the model/method pair is
            not supported for the given dataset.
    """
    if dataset == "QMUL":
        if method == "DKT":
            backbone.Conv3.maml = False
            backbone.simple_net_multi_output.maml = False
            bb = backbone.Conv3().cuda()
            if model == "Conv3_net":
                simple_net_multi = backbone.simple_net_multi_output()
                bb = backbone.CombinedNetwork(bb, simple_net_multi).cuda()
        elif model == "Conv3" and method == "MAML":
            backbone.Conv3.maml = True
            backbone.simple_net.maml = True
            bb = backbone.Conv3().cuda()
            bb = backbone.CombinedNetwork(bb, backbone.simple_net()).cuda()  # nn.Linear(2916, 1)
        elif model == "Conv3" and "UnLiMiTD" in method:
            backbone.Conv3.maml = False
            backbone.simple_net.maml = False
            bb = backbone.Conv3().cuda()
            bb = backbone.CombinedNetwork(bb, backbone.simple_net()).cuda()  # nn.Linear(2916, 1)
        else:
            raise ValueError("Model not recognized")

    elif dataset in ("berkeley", "argus"):
        input_dim = 11 if dataset == "berkeley" else 3

        if model == "ThreeLayerMLP":
            mlp_cls = backbone.ThreeLayerMLP
        elif model == "SteinwartMLP":
            mlp_cls = backbone.SteinwartMLP
        else:
            raise ValueError("Model not recognized")

        # Exact comparison on purpose: `method in ("DKT")` is a substring test
        # against the string "DKT" (the parentheses do not make a tuple), so
        # "", "D" or "KT" would wrongly be treated as DKT.
        if method == "MAML":
            # NOTE(review): the flag is set on the class and is not reset for
            # other methods in this branch (the QMUL branch does reset it) —
            # confirm whether a later non-MAML call in the same kernel should
            # set it back to False.
            mlp_cls.maml = True

        # DKT and MAML use a 32-dimensional output; every other method gets a
        # single output.
        output_dim = 32 if method in ("DKT", "MAML") else 1
        bb = mlp_cls(input_dim=input_dim, output_dim=output_dim)

    else:
        raise ValueError("Dataset not recognized")

    return bb

In [3]:
# Experiment configuration shared by the meta-testing and plotting cells.
seed=1  # passed to np.random.seed / torch.manual_seed before each evaluation run
dataset="QMUL"  # selects the data provider, the backbone and the checkpoint sub-directory
param_model="Conv3"  # architecture name handed to select_model and used in checkpoint paths
n_test_epochs=15  # number of times each model's test loop is called
# Passed as the first argument to the models' test loops.
# NOTE(review): -1 looks like a sentinel meaning the loader decides the
# support-set size (the run log reports n_support 5) — confirm.
n_support=-1

print_every = 50  # Spacing, in adaptation steps, between recorded MSE values (also the x-axis spacing in the plot)
total_adapt_steps = 1000  # Total meta-testing-steps

# Maps a series name (e.g. a method name) to a (mean, std) pair of per-step MSE arrays.
mse_statistics_per_step=dict()

provider = data_provider(dataset)

In [4]:
# Meta-test both UnLiMiTD variants and store the mean/std of the recorded
# MSE values under the method name in `mse_statistics_per_step`.
shrinking_factor = 0.01

# Both variants are built from the UnLiMiTDIX class; only the
# `has_scaling_params` switch differs between them.
for method, has_scaling_params in (("UnLiMiTDI", False), ("UnLiMiTDIX", True)):
    checkpoint_dir = '%scheckpoints/%s/%s_%s' % (configs.save_dir, dataset, param_model, method)
    print(checkpoint_dir)

    bb = select_model(dataset=dataset, method=method, model=param_model)
    print(f"Meta-testing {method}")
    model = UnLiMiTDIX(None, bb, has_scaling_params=has_scaling_params).cuda()
    model.load_checkpoint(checkpoint_dir)
    optimizer = None

    # Re-seed after the model is built and loaded so that every method is
    # evaluated from the same RNG state.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # One row per test-loop call, one column per recorded adaptation step.
    mses_per_task = []
    for epoch in range(n_test_epochs):
        recorded_mses = model.test_loop_ft(
            n_support,
            total_adapt_steps,
            provider,
            optimizer,
            shrinking_factor=shrinking_factor,
            print_every=print_every,
        )
        mse_per_step = [float(mse.detach().cpu().numpy()) for mse in recorded_mses]
        mses_per_task.append(mse_per_step)

    # Transpose so that axis 0 indexes the recorded step and axis 1 the run.
    mses_per_step = np.asarray(mses_per_task).T
    mean_mses_per_step = mses_per_step.mean(axis=1)
    std_mses_per_step = mses_per_step.std(axis=1)

    mse_statistics_per_step[method] = (mean_mses_per_step, std_mses_per_step)

./save/checkpoints/QMUL/Conv3_UnLiMiTDI
Meta-testing UnLiMiTDI
Beggining adaptation with n_support 5
Final MSE : 4.492578955250792e-05
Beggining adaptation with n_support 5
Final MSE : 0.1006358414888382
Beggining adaptation with n_support 5
Final MSE : 0.03569720312952995
Beggining adaptation with n_support 5
Final MSE : 0.01994704082608223
Beggining adaptation with n_support 5
Final MSE : 0.04445667192339897
Beggining adaptation with n_support 5
Final MSE : 0.017833024263381958
Beggining adaptation with n_support 5
Final MSE : 0.01662031188607216
Beggining adaptation with n_support 5
Final MSE : 0.013488920405507088
Beggining adaptation with n_support 5
Final MSE : 0.029217975214123726
Beggining adaptation with n_support 5
Final MSE : 0.012969711795449257
Beggining adaptation with n_support 5
Final MSE : 0.044420160353183746
Beggining adaptation with n_support 5
Final MSE : 0.020086204633116722
Beggining adaptation with n_support 5
Final MSE : 0.0007093284511938691
Beggining adaptati

In [5]:
# Meta-test the MAML checkpoints trained with 1 and 3 inner steps and store
# the mean/std of the recorded MSE values under "maml_<inner steps>".
# NOTE(review): the recorded run of this cell failed in load_checkpoint with a
# state_dict key mismatch (model expects networks.0.conv*, checkpoint holds
# networks.0.layer*) — confirm the checkpoint matches the current backbone.
method = "MAML"
for inn_steps in (1, 3):
    checkpoint_dir = '%scheckpoints/%s/%s_%s_%s_inn_steps' % (configs.save_dir, dataset, param_model, method, inn_steps)
    print(checkpoint_dir)

    bb = select_model(dataset=dataset, method=method, model=param_model)
    # n_support is -1 because it's directly implemented in loader
    model = MAML(bb, -1, problem="regression").cuda()
    model.load_checkpoint(checkpoint_dir)
    # Adapt for the full meta-testing budget rather than the training-time
    # number of inner steps.
    model.task_update_num = total_adapt_steps
    optimizer = None

    print(f"Meta-testing MAML {inn_steps} inner steps")

    # Re-seed after the model is built and loaded so that every method is
    # evaluated from the same RNG state.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # One row per test-loop call, one column per recorded adaptation step.
    mses_per_task = []
    for epoch in range(n_test_epochs):
        recorded_mses = model.test_loop(
            n_support,
            provider,
            optimizer=optimizer,
            print_every=print_every,
        )
        mse_per_step = [float(mse.detach().cpu().numpy()) for mse in recorded_mses]
        mses_per_task.append(mse_per_step)

    # Transpose so that axis 0 indexes the recorded step and axis 1 the run.
    mses_per_step = np.asarray(mses_per_task).T
    mean_mses_per_step = mses_per_step.mean(axis=1)
    std_mses_per_step = mses_per_step.std(axis=1)

    mse_statistics_per_step[f"maml_{inn_steps}"] = (mean_mses_per_step, std_mses_per_step)

./save/checkpoints/QMUL/Conv3_MAML_1_inn_steps


RuntimeError: Error(s) in loading state_dict for CombinedNetwork:
	Missing key(s) in state_dict: "networks.0.conv1.weight", "networks.0.conv1.bias", "networks.0.conv2.weight", "networks.0.conv2.bias", "networks.0.conv3.weight", "networks.0.conv3.bias". 
	Unexpected key(s) in state_dict: "networks.0.layer1.weight", "networks.0.layer1.bias", "networks.0.layer2.weight", "networks.0.layer2.bias", "networks.0.layer3.weight", "networks.0.layer3.bias". 

In [None]:
# Plot mean MSE (with std error bars) against the adaptation step for every
# series stored in `mse_statistics_per_step`.
fig, ax = plt.subplots(figsize=(10, 6))

# (key in mse_statistics_per_step, legend label, colour), in drawing order.
series_specs = [
    ("maml_1", "MAML 1 inner step", "blue"),
    ("maml_3", "MAML 3 inner steps", "red"),
]
# The UnLiMiTDI curve is only drawn for QMUL.
if dataset == "QMUL":
    series_specs.append(("UnLiMiTDI", "OursI", "green"))
series_specs.append(("UnLiMiTDIX", "OursIX", "orange"))

tick_steps = None   # x positions of the first series actually drawn
missing_keys = []   # series whose evaluation cell did not produce results
for key, label, color in series_specs:
    if key not in mse_statistics_per_step:
        # Draw what is available instead of failing on the first absent
        # series (e.g. when a checkpoint could not be loaded).
        missing_keys.append(key)
        continue
    mean_mse, std_mse = mse_statistics_per_step[key]
    # Values are recorded every `print_every` adaptation steps.
    steps = print_every * np.arange(len(mean_mse))
    ax.errorbar(
        x=steps,
        y=mean_mse,
        yerr=std_mse,
        label=label,
        fmt='-o',
        color=color,
        capsize=5,
    )
    if tick_steps is None:
        tick_steps = steps

if tick_steps is None:
    raise KeyError(
        f"No results to plot; expected at least one of {[spec[0] for spec in series_specs]} "
        "in mse_statistics_per_step"
    )
if missing_keys:
    print(f"Skipping series with no results: {missing_keys}")

# Customize the plot
ax.set_title("MSE Per Adaptation Step: MAML (1 vs. 3 Inner Steps) and Ours")
ax.set_xlabel("Adaptation Step")
ax.set_ylabel("Mean MSE")
ax.set_xticks(tick_steps)
ax.legend()

# Clip the y-axis so the curves stay readable.
ax.set_ylim(0, .04)
ax.grid(True)
fig.tight_layout()
fig.savefig(f'meta_test_adaptation_maml_ours_{dataset}.png', dpi=300)
plt.show()