In [22]:
from maml_functorch.model import VinyalsConv
from maml_functorch.trainer import HyperParameters
from maml_functorch.dataset import load_data, DataConfig
from maml_functorch.utils import get_accuracy_from_logits
from maml_functorch.testing_utils import generate_random_batch

from functorch import make_functional_with_buffers, vmap, grad
from torch.nn import functional as F
import torch
from tqdm import tqdm
import wandb
from dataclasses import asdict

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Global 5-way model; its functional form (`fmodel`) is what the later cells call.
model = VinyalsConv(5, embedding_feats=288, track_running_stats=False)
model = model.to(device)

In [23]:
# Synthetic batch for the shape checks below; per the recorded outputs the leading
# dimension indexes 5 tasks (presumably set by the first argument; confirm in testing_utils).
support_set, support_labels, query_set, query_labels = generate_random_batch(5, device)

In [24]:
from functorch import make_functional_with_buffers, vmap, grad


# Stateless version of `model`, called as fmodel(params, buffers, x). With
# disable_autograd_tracking=True the returned params do not require grad;
# gradients are obtained with functorch's `grad` transform instead.
fmodel, params, buffers = make_functional_with_buffers(model, disable_autograd_tracking=True)

In [25]:
# Sanity check on the first task's support set; recorded output: torch.Size([25, 5]).
fmodel(params, buffers, support_set[0]).shape

torch.Size([25, 5])

In [26]:
def compute_meta_loss(params, buffers, support_sample, support_labels):
    """Cross-entropy of the functional model's logits on one batch of samples.

    Although the arguments are named for the support set, the same function is
    also used to score query samples after adaptation.

    Args:
        params: Parameter tensors passed to ``fmodel``.
        buffers: Buffer tensors passed to ``fmodel``.
        support_sample: Input batch fed to ``fmodel``.
        support_labels: Class-index targets for that batch.

    Returns:
        Scalar mean cross-entropy loss.
    """
    return F.cross_entropy(fmodel(params, buffers, support_sample), support_labels)

# Gradient of compute_meta_loss with respect to its first argument (params);
# the result has the same structure as `params`.
compute_meta_grads = grad(compute_meta_loss)

def calculate_next_params(params, buffers, support_sample, support_labels, inner_lr=0.01):
    """Take one inner-loop SGD step on the support loss.

    Args:
        params: Current parameters (sequence of tensors).
        buffers: Model buffers, passed through unchanged.
        support_sample: Support inputs for a single task.
        support_labels: Support targets for that task.
        inner_lr: Inner-loop learning rate (defaults to the previously
            hard-coded 0.01).

    Returns:
        List of updated parameters ``p - inner_lr * grad``.
    """
    grads = compute_meta_grads(params, buffers, support_sample, support_labels)
    return [p - inner_lr * g for p, g in zip(params, grads)]

def compute_logits(params, buffers, support_sample, support_labels, query_sample, num_inner_steps=5):
    """Adapt on a task's support set, then return logits for its query set.

    Args:
        params: Initial (meta-learned) parameters.
        buffers: Model buffers, passed through unchanged.
        support_sample: Support inputs for a single task.
        support_labels: Support targets for that task.
        query_sample: Query inputs to predict after adaptation.
        num_inner_steps: Number of inner-loop steps (defaults to the
            previously hard-coded 5).

    Returns:
        ``fmodel`` output on ``query_sample`` using the adapted parameters.
    """
    adapted_params = params
    for _ in range(num_inner_steps):
        adapted_params = calculate_next_params(adapted_params, buffers, support_sample, support_labels)

    return fmodel(adapted_params, buffers, query_sample)

# Adapt on task 0's support set, then predict its query set; recorded output: torch.Size([50, 5]).
compute_logits(params, buffers, support_set[0], support_labels[0], query_set[0]).shape

torch.Size([50, 5])

In [27]:

# Vectorise compute_logits over tasks: params and buffers are shared (None), while
# support samples, support labels and query samples are mapped along dim 0.
batch_compute_logits = vmap(compute_logits, in_dims=(None, None, 0, 0, 0))


In [28]:
# Query set shape for the random batch above; recorded output: torch.Size([5, 50, 3, 84, 84])
print(query_set.shape)
batch_compute_logits(params, buffers, support_set, support_labels, query_set).shape

torch.Size([5, 50, 3, 84, 84])


torch.Size([5, 50, 5])

In [29]:
def compute_fo_grads(params, buffers, support_sample, support_labels, query_sample, query_labels,
                     num_inner_steps=5, inner_lr=0.01):
    """First-order MAML gradient for a single task.

    Runs ``num_inner_steps`` SGD steps of size ``inner_lr`` on the support loss,
    each step starting from the previous step's parameters. It then returns the
    gradient of the query loss taken with respect to the adapted parameters.
    This is the first-order approximation: the inner loop is not differentiated
    through.

    Args:
        params: Initial (meta-learned) parameters.
        buffers: Model buffers, passed through unchanged.
        support_sample, support_labels: Support data for the task.
        query_sample, query_labels: Query data for the task.
        num_inner_steps: Number of inner-loop steps (default 5).
        inner_lr: Inner-loop learning rate (default 0.01).

    Returns:
        Gradients with the same structure as ``params``.
    """
    adapted_params = params
    for _ in range(num_inner_steps):
        # Differentiate at, and step from, the current adapted parameters so
        # successive inner steps compound instead of restarting from `params`.
        grads = compute_meta_grads(adapted_params, buffers, support_sample, support_labels)
        adapted_params = [p - inner_lr * g for p, g in zip(adapted_params, grads)]

    return compute_meta_grads(adapted_params, buffers, query_sample, query_labels)

In [30]:
def compute_loss(params, buffers, support_sample, support_labels, query_sample, query_labels,
                 num_inner_steps=5, inner_lr=0.01):
    """Query loss after inner-loop adaptation; differentiable w.r.t. ``params``.

    Runs ``num_inner_steps`` SGD steps of size ``inner_lr`` on the support loss,
    each step starting from the previous step's parameters, and evaluates the
    query loss at the adapted parameters. Wrapping this in ``grad`` yields the
    second-order MAML gradient, because the inner updates are differentiated
    through.

    Args:
        params: Initial (meta-learned) parameters.
        buffers: Model buffers, passed through unchanged.
        support_sample, support_labels: Support data for the task.
        query_sample, query_labels: Query data for the task.
        num_inner_steps: Number of inner-loop steps (default 5).
        inner_lr: Inner-loop learning rate (default 0.01).

    Returns:
        Scalar query loss.
    """
    adapted_params = params
    for _ in range(num_inner_steps):
        # Chain each step from the previous adapted parameters.
        grads = compute_meta_grads(adapted_params, buffers, support_sample, support_labels)
        adapted_params = [p - inner_lr * g for p, g in zip(adapted_params, grads)]

    return compute_meta_loss(adapted_params, buffers, query_sample, query_labels)

In [31]:
# Second-order MAML: gradient of the post-adaptation query loss with respect to the
# initial params, back-propagating through the inner-loop updates.
compute_so_grads = grad(compute_loss)
# Per-task gradients: the four data arguments are mapped along dim 0; params/buffers are shared.
batch_so_grads = vmap(compute_so_grads, in_dims=(None, None, 0, 0, 0, 0))

In [32]:
# compute_grads_maml = grad(compute_loss_stateless_model)

# example_support, supprot_labels, example_query, query_labels = generate_random_task()


# compute_fo_grads(params, buffers, example_support, supprot_labels, example_query, query_labels)

In [33]:
# First-order MAML gradients for every task in the batch (data mapped along dim 0; params/buffers shared).
batch_fo_grads = vmap(compute_fo_grads, in_dims=(None, None, 0, 0, 0, 0))

In [34]:
# support_samples, support_labels, query_samples, query_labels = generate_random_batch(4)


# print(support_samples.shape)
# print(support_labels.shape)

# print(query_samples.shape)
# print(query_labels.shape)

In [35]:
# grads = batch_fo_grads(params, buffers, support_samples, support_labels, query_samples, query_labels)

# print(support_samples.shape)
# print(support_labels.shape)

# print(len(params))
# print(len(grads))

# for p, g in zip(params, grads):
#     print(p.shape)
#     print(g.sum(dim=0).shape)

In [36]:
def param_updater_factory(lr: float):
    """Build an outer-loop SGD step over a batch of per-task gradients.

    Args:
        lr: Outer-loop (meta) learning rate.

    Returns:
        ``update_params(params, grads)``. Each ``grads[i]`` carries a leading
        task dimension (as produced by the vmapped gradient functions). It is
        averaged over that dimension, scaled by ``lr``, and subtracted from
        ``params[i]`` in place. The function returns the list of mutated
        parameter tensors.
    """
    def update_params(params, grads):
        with torch.no_grad():
            # Average over the actual number of tasks rather than a fixed batch size.
            return [p.add_(g.mean(dim=0), alpha=-lr) for p, g in zip(params, grads)]

    return update_params

In [37]:
def cross_entropy(logits, labels):
    """Mean cross-entropy for one task; a named wrapper so it can be vmapped.

    Args:
        logits: Unnormalised class scores.
        labels: Class-index targets.

    Returns:
        Scalar loss from ``F.cross_entropy`` with its default (mean) reduction.
    """
    loss = F.cross_entropy(logits, labels)
    return loss

# Per-task loss: logits and labels are mapped along dim 0, giving one mean cross-entropy per task.
batch_cross_entropy = vmap(cross_entropy, in_dims=(0, 0))

In [38]:

def _batch_to_device(batch):
    """Move one episodic batch onto ``device``.

    Args:
        batch: Mapping holding 'support_set', 'support_labels', 'query_set' and
            'query_labels' tensors (the keys read from the data loaders).

    Returns:
        Tuple ``(support_set, support_labels, query_set, query_labels)`` on ``device``.
    """
    return tuple(
        batch[key].to(device, non_blocking=True)
        for key in ('support_set', 'support_labels', 'query_set', 'query_labels')
    )


def train(args: HyperParameters):
    """Meta-train a VinyalsConv model with MAML and log validation metrics to wandb.

    Each iteration draws one batch of tasks from the training loader and computes
    per-task meta-gradients: second-order if ``args.second_order``, otherwise
    first-order. It then applies an outer SGD step of size ``args.beta``. Every
    200 iterations (excluding iteration 0), a test batch is adapted and scored,
    and its accuracy and summed per-task loss are logged.

    Args:
        args: Hyperparameters. This function reads ``ways``, ``shots``,
            ``batch_size``, ``query_size``, ``embedding_feats``, ``beta``,
            ``epochs`` and ``second_order``.
    """
    model = VinyalsConv(args.ways, embedding_feats=args.embedding_feats, track_running_stats=False).to(device)
    # NOTE(review): the batched helpers (batch_so_grads, batch_fo_grads,
    # batch_compute_logits) evaluate the module-level `fmodel` from an earlier
    # cell. Only params/buffers come from this model, so its architecture must
    # match that global model (built with 5 ways and 288 embedding features).
    _, params, buffers = make_functional_with_buffers(model, disable_autograd_tracking=True)

    train_iter, test_iter = load_data(DataConfig(True, args.ways, args.shots, args.batch_size, args.query_size))

    update_params = param_updater_factory(args.beta)

    # TODO(review): args.alpha and args.num_meta_learn_loop are not read; the
    # inner-loop learning rate and step count come from the helper functions.
    wandb.init(config=asdict(args))

    try:
        for step in tqdm(range(args.epochs)):
            support_samples, support_labels, query_samples, query_labels = _batch_to_device(next(train_iter))

            batch_grads = batch_so_grads if args.second_order else batch_fo_grads
            grads = batch_grads(params, buffers, support_samples, support_labels, query_samples, query_labels)

            params = update_params(params, grads)

            if step % 200 == 0 and step != 0:
                with torch.no_grad():
                    support_samples, support_labels, query_samples, query_labels = _batch_to_device(next(test_iter))

                    logits = batch_compute_logits(
                        params,
                        buffers,
                        support_samples,
                        support_labels,
                        query_samples,
                    )

                    valid_loss = batch_cross_entropy(logits, query_labels)
                    valid_acc = get_accuracy_from_logits(logits, query_labels)

                    wandb.log({'valid_task_acc': valid_acc, 'valid_task_loss': valid_loss.sum()})
    finally:
        # Close the wandb run even when training is interrupted (e.g. KeyboardInterrupt).
        wandb.finish()
            

# Launch a 20,000-iteration second-order MAML run (5-way, 1-shot) with these settings.
hparams = HyperParameters(
    epochs=20_000,
    alpha=0.01,
    beta=0.01,
    batch_size=4,
    num_meta_learn_loop=5,
    second_order=True,
    shots=1,
    ways=5,
    embedding_feats=288,
    query_size=10
)
train(hparams)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mathecoder[0m ([33mdest[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666813326666367, max=1.0)…

  1%|          | 171/20000 [00:09<18:24, 17.95it/s]


KeyboardInterrupt: 