In [1]:
import torch
import numpy as np
import random
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import json
import torch.utils.data.sampler
import os
import glob
import time

import configs
import backbone
import data.feature_loader as feat_loader
from data.datamgr import SetDataManager
from methods.maml import MAML
from methods.differentialDKTIXnogpytorch import differentialDKTIXnogpy
from io_utils import model_dict, get_resume_file, parse_args, get_best_file , get_assigned_file

Conv4S
Conv6


In [2]:
def _set_seed(seed, verbose=True):
    if(seed!=0):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False 
        if(verbose): print("[INFO] Setting SEED: " + str(seed))   
    else:
        if(verbose): print("[INFO] Setting SEED: None")

In [3]:
def remap_state_dict(old_state_dict):
    new_state_dict = {}
    for old_key, value in old_state_dict.items():
        # Remap keys to match the new model structure
        if old_key.startswith("feature.trunk"):
            # Convert keys like "feature.trunk.0.trunk.0.weight" to "feature_extractor.0.trunk.0.C.weight"
            parts = old_key.split(".")
            layer_idx = parts[2]  # Extract layer index
            layer_part = parts[3:]  # Rest of the parts
            if "trunk" in layer_part:
                layer_part.remove("trunk")  # Remove extra "trunk" if present
            new_key = f"feature_extractor.{layer_idx}.trunk.0." + ".".join(layer_part)
            new_state_dict[new_key] = value
        elif old_key.startswith("feature.trunk") and "num_batches_tracked" not in old_key:
            # Handle BatchNorm layers without extra "trunk"
            parts = old_key.split(".")
            layer_idx = parts[2]
            layer_part = parts[3:]
            new_key = f"feature_extractor.{layer_idx}.trunk.0." + ".".join(layer_part)
            new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[old_key] = value
    return new_state_dict

In [4]:
device = 'cuda:0'
    
seed = 1
_set_seed(seed)


# First define loaders
n_query = max(1, int(16 * 5 / 5))  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
print(f"n_query : {n_query}")

# Dataloader
image_size = 84

base_file = configs.data_dir['CUB'] + 'base.json'
val_file = configs.data_dir['CUB'] + 'val.json'

train_few_shot_params = dict(n_way=5, n_support=1)
base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params) #n_eposide=100
base_loader = base_datamgr.get_data_loader(base_file, aug=True)

test_few_shot_params = dict(n_way=5, n_support=1)
val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
val_loader = val_datamgr.get_data_loader(val_file, aug=False)
# a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

# WHERE WE NEED TO ADD BIAS
# backbone.ConvBlock.maml = True
# backbone.SimpleBlock.maml = True
# backbone.BottleneckBlock.maml = True
# backbone.ResNet.maml = True            
model = differentialDKTIXnogpy(model_dict['Conv4_maml_diffDKTIX'], **train_few_shot_params)

# Load state_dict
model = model.cuda()


[INFO] Setting SEED: 1
n_query : 16
Conv4_maml_diffDKTIX
Normalization : False


In [5]:
print('Conv4 state dict :\n', model_dict['Conv4_maml_diffDKTIX']().state_dict().keys())
print('')
print('Conv4 backbone state dict:\n', model_dict['Conv4']().state_dict().keys())

Conv4_maml_diffDKTIX
Conv4 state dict :
 odict_keys(['trunk.0.C.weight', 'trunk.0.C.bias', 'trunk.0.BN.weight', 'trunk.0.BN.bias', 'trunk.0.BN.running_mean', 'trunk.0.BN.running_var', 'trunk.0.BN.num_batches_tracked', 'trunk.1.C.weight', 'trunk.1.C.bias', 'trunk.1.BN.weight', 'trunk.1.BN.bias', 'trunk.1.BN.running_mean', 'trunk.1.BN.running_var', 'trunk.1.BN.num_batches_tracked', 'trunk.2.C.weight', 'trunk.2.C.bias', 'trunk.2.BN.weight', 'trunk.2.BN.bias', 'trunk.2.BN.running_mean', 'trunk.2.BN.running_var', 'trunk.2.BN.num_batches_tracked', 'trunk.3.C.weight', 'trunk.3.C.bias', 'trunk.3.BN.weight', 'trunk.3.BN.bias', 'trunk.3.BN.running_mean', 'trunk.3.BN.running_var', 'trunk.3.BN.num_batches_tracked'])

Conv4
Conv4 backbone state dict:
 odict_keys(['trunk.0.C.weight', 'trunk.0.C.bias', 'trunk.0.BN.weight', 'trunk.0.BN.bias', 'trunk.0.BN.running_mean', 'trunk.0.BN.running_var', 'trunk.0.BN.num_batches_tracked', 'trunk.0.trunk.0.weight', 'trunk.0.trunk.0.bias', 'trunk.0.trunk.1.weight'

In [6]:
checkpoint_dir = "./save/checkpoints/CUB/Conv4_maml_aug_5way_1shot"

resume_file = get_resume_file(checkpoint_dir)
# /!\ CAUTION : get_resume_file does not give the same results in testing that get_best_file, that is used in the test.py

tmp = torch.load(resume_file)

old_state_dict = tmp['state']
new_state_dict = remap_state_dict(old_state_dict)

print("Old keys: \n", old_state_dict.keys())
print('')
print('')
print("New keys:  \n", new_state_dict.keys())
print('')
print('')
print('')
print('')


start_epoch = tmp['epoch'] + 1
missing, unexpected = model.load_state_dict(old_state_dict, strict=False)  # /!\ cAUTION, VERY DANGEROUS STRICT = FALSE
print('missing keys ;  \n', missing)
print('')
print('')
print('Unexpected keys ;  \n', unexpected)

Old keys: 
 odict_keys(['feature.trunk.0.C.weight', 'feature.trunk.0.C.bias', 'feature.trunk.0.BN.weight', 'feature.trunk.0.BN.bias', 'feature.trunk.0.BN.running_mean', 'feature.trunk.0.BN.running_var', 'feature.trunk.0.BN.num_batches_tracked', 'feature.trunk.0.trunk.0.weight', 'feature.trunk.0.trunk.0.bias', 'feature.trunk.0.trunk.1.weight', 'feature.trunk.0.trunk.1.bias', 'feature.trunk.0.trunk.1.running_mean', 'feature.trunk.0.trunk.1.running_var', 'feature.trunk.0.trunk.1.num_batches_tracked', 'feature.trunk.1.C.weight', 'feature.trunk.1.C.bias', 'feature.trunk.1.BN.weight', 'feature.trunk.1.BN.bias', 'feature.trunk.1.BN.running_mean', 'feature.trunk.1.BN.running_var', 'feature.trunk.1.BN.num_batches_tracked', 'feature.trunk.1.trunk.0.weight', 'feature.trunk.1.trunk.0.bias', 'feature.trunk.1.trunk.1.weight', 'feature.trunk.1.trunk.1.bias', 'feature.trunk.1.trunk.1.running_mean', 'feature.trunk.1.trunk.1.running_var', 'feature.trunk.1.trunk.1.num_batches_tracked', 'feature.trunk.2.C

In [9]:
print(model)

differentialDKTIXnogpy(
  (feature): Conv4_maml_diffDKTIX_C(
    (trunk): Sequential(
      (0): ConvBlock_MAML_TO_DIFF(
        (C): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): ConvBlock_MAML_TO_DIFF(
        (C): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): ConvBlock_MAML_TO_DIFF(
        (C): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(in

In [29]:
model.eval()
model.test_loop( val_loader, optim_based=True, n_ft=10, lr=0.01, temp=1, return_std = True)
# model.test_loop(val_loader)

Test | Batch 0/100 | Loss 0.000000 | Acc 36.250000
100 Test Acc = 46.64% +- 2.26%


(46.6375, 11.525372824772306)

In [17]:
# Define novel loader :
few_shot_params = dict(n_way = 5, n_support = 1)
datamgr         = SetDataManager(image_size, n_eposide = 600, n_query = 15 , **few_shot_params)
loadfile    = configs.data_dir['CUB'] + 'novel.json'
novel_loader     = datamgr.get_data_loader( loadfile, aug = False)

model.test_loop( val_loader, optim_based=True, n_ft=100, lr=0.01, temp=1, return_std = True)

Test | Batch 0/100 | Loss 0.000000 | Acc 71.250000
100 Test Acc = 57.39% +- 2.48%


(57.3875, 12.633159491987742)