In [19]:
import torch
import torch.nn as nn

import backbone
from methods.hypernets.binary_maml_utils import Binarizer

In [20]:
def update_weight(weight, update_value, subnetwork=False, hm_update_operator="minus"):
    if hm_update_operator == 'multiply' or subnetwork:
        if weight.fast is None:
            weight.fast = weight * update_value
        else:
            weight.fast = weight.fast * update_value
    elif hm_update_operator == 'minus':
        if weight.fast is None:
            weight.fast = weight - update_value
        else:
            weight.fast = weight.fast - update_value
    elif hm_update_operator == 'plus':
        if weight.fast is None:
            weight.fast = weight + update_value
        else:
            weight.fast = weight.fast + update_value

In [21]:
def print_parameters(model: nn.Module):
    print("Printing current params")
    for param in model.parameters():
        if param.fast is None:
            print(param)
        else:
            print(param.fast)

In [22]:
support_embeddings_query = torch.rand((30, 10))
support_data_labels = torch.rand(30, 5)

In [23]:
layer = backbone.Linear_fw(10, 5)
classifier = nn.Sequential(layer)

In [24]:
shapes = []

for weights in classifier.parameters():
    shapes.append(weights.shape)

In [25]:
delta_params_list = []

for shape in shapes:
    delta_params = torch.rand(shape)
    k_val = torch.quantile(delta_params, q=0.3)
    delta_params = Binarizer.apply(delta_params, k_val)
    delta_params_list.append(delta_params)

In [26]:
fast_parameters = []

clf_fast_parameters = list(classifier.parameters())
for weight in classifier.parameters():
    weight.fast = None

classifier.zero_grad()
fast_parameters = clf_fast_parameters

In [28]:
task_update_num = 5
train_lr = 0.1
loss_fn = nn.CrossEntropyLoss()

print_parameters(classifier)

for k, weight in enumerate(classifier.parameters()):
    update_value = delta_params_list[k]
    update_weight(weight, update_value, subnetwork=True)

print_parameters(classifier)

print("###########################")
print("# STARTING MAML UPDATES ###")
print("###########################")

for task_step in range(task_update_num):
    scores = classifier(support_embeddings_query)

    set_loss = loss_fn(scores, support_data_labels)

    grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True, allow_unused=True)

    print("###### BEFORE UPDATE ###########")

    print_parameters(classifier)

    for k, weight in enumerate(classifier.parameters()):
        update_value = (train_lr * grad[k])
        update_weight(weight, update_value)

    print("###### AFTER UPDATE ###########")
    print_parameters(classifier)


Printing current params
tensor([[-0.0000, -0.0000,  0.0000,  0.0000,  0.2580, -0.0000,  0.0000, -0.0000,
          0.0000, -0.0193],
        [ 0.1603,  0.0617, -0.0298,  0.0000,  0.0000, -0.0657,  0.0000,  0.0000,
         -0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1802, -0.1468,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.1033, -0.2335,  0.0000, -0.3149,  0.0000,  0.2077,  0.0000,
          0.0000, -0.0158],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0906,
          0.0000,  0.1201]], grad_fn=<SubBackward0>)
tensor([0.0027, 0.0000, 0.0000, 0.2525, 0.0000], grad_fn=<SubBackward0>)
Printing current params
tensor([[-0.0000, -0.0000,  0.0000,  0.0000,  0.2580, -0.0000,  0.0000, -0.0000,
          0.0000, -0.0193],
        [ 0.1603,  0.0617, -0.0298,  0.0000,  0.0000, -0.0657,  0.0000,  0.0000,
         -0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1802, -0.1468,  0.0000,
          