In [None]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0
import os, sys, time
sys.path.insert(0, '..')
import lib

import math
import numpy as np
from copy import deepcopy
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline

# Reproducibility: draw one random seed, feed it to every RNG used here,
# and print it so the run can be repeated with the same value.
import random

seed = random.randint(0, 2**32 - 1)
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(seed)

# Reverted in the evaluation section further down.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(seed)

In [None]:
from matplotlib.backends.backend_pdf import PdfPages
# Figure styling: seaborn white-grid theme, font type 42 for both PDF and PS output
plt.style.use('seaborn-whitegrid')
plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Settings

In [None]:
# Passed to lib.MAML as the model type
model_type = 'lstm'

# --- Language model ---
emb_size = 128
sequence_length = 100
hidden_size = 256

# --- Dataset ---
data_dir = './data'
train_batch_size = valid_batch_size = test_batch_size = 128

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

loss_function = F.nll_loss

# --- MAML ---
max_epochs = 3000
maml_batches_in_epoch = 200
maml_epochs = 1
maml_steps = maml_batches_in_epoch * maml_epochs

max_meta_parameters_grad_norm = 10.

loss_interval = 40
first_valid_step = 40

# The validation schedule (first_valid_step, then every loss_interval steps)
# has to end exactly on the last inner step.
assert (maml_steps - first_valid_step) % loss_interval == 0
validation_steps = (maml_steps - first_valid_step) // loss_interval + 1

# --- Inner optimizer ---
optimizer_type = 'adam'
learning_rate = 0.001
momentum = 0.9

# Adam (beta1 reuses the momentum value)
beta1, beta2 = momentum, 0.999
epsilon = 1e-8

# --- Meta optimizer ---
meta_learning_rate = 0.0003
meta_betas = (0.9, 0.997)
meta_decay_interval = max_epochs

checkpoint_steps = 6
recovery_step = None


# Extra keyword arguments forwarded to maml.forward by compute_maml_loss
kwargs = {
    'first_valid_step': first_valid_step,
    'valid_loss_interval': loss_interval,
}

In [None]:
# The experiment name encodes the key hyper-parameters and the seed,
# so every run logs into its own directory.
exp_name = (
    f"LSTM-2_LM_{model_type}_PTb_{optimizer_type}_lr{learning_rate}"
    f"_meta_lr{meta_learning_rate}_steps{maml_steps}_interval{loss_interval}"
    f"_tr_bs{train_batch_size}_val_bs{valid_batch_size}_seed_{seed}"
)

print("Experiment name: ", exp_name)

logs_path = f"./logs/{exp_name}"
# An existing log directory is only acceptable when resuming from a recovery step
assert not (recovery_step is None and os.path.exists(logs_path))
# !rm -rf {logs_path}

### Define tokenizer

In [None]:
import torchtext

data_dir = 'data/'

# Case-preserving field; `tokenize=list` is the tokenizer handed to torchtext
PTb_TEXT = torchtext.data.Field(tokenize=list, lower=False)
PTb_train, PTb_valid, PTb_test = torchtext.datasets.PennTreebank.splits(PTb_TEXT, root=data_dir)
PTb_TEXT.build_vocab(PTb_train)

PTb_voc_size = len(PTb_TEXT.vocab)

# Train and validation iterators share the same options (validation also uses train=True)
bptt_options = dict(train=True, device=device, repeat=True, shuffle=True)
PTb_train_loader = torchtext.data.BPTTIterator(PTb_train, train_batch_size, sequence_length,
                                               **bptt_options)
PTb_valid_loader = torchtext.data.BPTTIterator(PTb_valid, valid_batch_size, sequence_length,
                                               **bptt_options)

# Encode the test text; symbols the vocabulary lookup does not know (None) become id 0
PTb_stoi = PTb_TEXT.vocab.stoi
PTb_test_ids = [PTb_stoi.get(symbol) for symbol in PTb_test.examples[0].text]
PTb_test_ids = [0 if idx is None else idx for idx in PTb_test_ids]
full_PTb_test_ids = torch.tensor(PTb_test_ids)
# First tenth of the test ids, used for the quick evaluations during meta-training
part_PTb_test_ids = torch.tensor(PTb_test_ids)[:len(PTb_test_ids) // 10]

### Utils

In [None]:
import torch.nn.parallel
from concurrent.futures import Future, ThreadPoolExecutor, as_completed, TimeoutError
# Shared pool (at most 16 worker threads) used for all background jobs
GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=16)

def run_in_background(func: callable, *args, **kwargs) -> Future:
    """Submit ``func(*args, **kwargs)`` to the shared thread pool and return its Future."""
    pending = GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
    return pending


def samples_batches(dataloader, num_batches):
    """Collect up to ``num_batches`` batches from ``dataloader``.

    Each batch must expose ``.text`` and ``.target``; both are transposed with
    ``.t()`` before being stored.

    Exactly ``num_batches`` batches are pulled from the loader. The previous loop
    fetched one extra batch and threw it away, which wasted a load and silently
    dropped data when ``dataloader`` was a stateful iterator.

    :param dataloader: iterable of batches
    :param num_batches: maximum number of batches to collect; values <= 0 give empty lists
    :returns: ``(x_batches, y_batches)``, two lists of equal length
    """
    x_batches, y_batches = [], []
    if num_batches <= 0:
        # Do not touch the loader at all when nothing is requested
        return x_batches, y_batches

    for batch_count, batch in enumerate(dataloader, start=1):
        x_batches.append(batch.text.t())
        y_batches.append(batch.target.t())
        # Stop right after the last wanted batch so no extra one is fetched
        if batch_count >= num_batches:
            break
    return x_batches, y_batches


def compute_maml_loss(_maml, x_batches, y_batches, x_val_batches, y_val_batches, device):
    """Run one MAML forward pass on ``_maml`` and return its mean train/validation losses.

    The computation now uses the ``_maml`` argument. It previously reached for the
    global ``maml``, so the per-device replicas handed in by
    ``parallel_train_on_batch`` were ignored.

    Reads the module-level ``kwargs`` (extra arguments for ``forward``) and
    ``first_valid_step``.

    :param _maml: MAML wrapper (or replica) to evaluate
    :param x_batches, y_batches: training inputs / targets for the inner loop
    :param x_val_batches, y_val_batches: validation inputs / targets
    :param device: forwarded to ``_maml.forward``
    :returns: ``(train_loss, valid_loss)``; ``valid_loss`` is ``torch.zeros(1)`` when
        no validation losses were produced
    """
    with lib.training_mode(_maml, is_train=True):
        _maml.resample_parameters()
        updated_model, train_loss_history, valid_loss_history, *etc = \
            _maml.forward(x_batches, y_batches, x_val_batches, y_val_batches,
                          device=device, **kwargs)
        # Train losses recorded before the first validation step are left out of the mean
        train_loss = torch.cat(train_loss_history[first_valid_step:]).mean()
        valid_loss = torch.cat(valid_loss_history).mean() if len(valid_loss_history) > 0 else torch.zeros(1)
    return train_loss, valid_loss

def compute_loss(logp_next, batch_targets, **kwargs):
    """Mean negative log-likelihood of ``batch_targets`` under ``logp_next``.

    ``logp_next`` is flattened to rows of size ``logp_next.shape[-1]`` and
    ``batch_targets`` to a flat index vector; the per-element NLL is then averaged.
    Extra keyword arguments are accepted and ignored.
    """
    num_classes = logp_next.shape[-1]
    flat_logp = logp_next.reshape(-1, num_classes)
    flat_targets = batch_targets.reshape(-1)
    per_token_nll = F.nll_loss(flat_logp, flat_targets, reduction='none')
    return per_token_nll.mean()


@torch.no_grad()
def compute_test_loss(model, loss_function, test_ids, **kwargs):
    """Evaluate ``loss_function`` on next-token prediction over ``test_ids``, on CPU.

    The model is moved to CPU, fed ``test_ids[:-1]`` as a single batch-of-one
    sequence, and scored against ``test_ids[1:]``. Gradients are disabled.

    Afterwards the model is moved back to ``kwargs['device']`` if given, otherwise
    to the device it was on before the call. The move back now also happens when
    the forward pass raises, and no longer depends on a module-level ``device``.

    :param model: module mapping a ``[1, T]`` id tensor to predictions
    :param loss_function: callable ``(predictions, targets) -> loss``
    :param test_ids: 1-D tensor of token ids (expected on CPU)
    :param kwargs: optional ``device`` to restore the model to; other keys are ignored
    :returns: whatever ``loss_function`` returns
    """
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else torch.device('cpu')
    restore_device = kwargs.get('device') or original_device

    model = model.cpu()
    try:
        logp_next = model(test_ids[:-1][None])
        loss = loss_function(logp_next, test_ids[1:][None])
    finally:
        # Never leave the caller's model stranded on CPU
        model = model.to(restore_device)
    return loss

In [None]:
# Pick the in-graph inner-loop optimizer that matches `optimizer_type`
if optimizer_type == 'sgd':
    optimizer = lib.IngraphGradientDescent(learning_rate=learning_rate)
elif optimizer_type == 'momentum':
    optimizer = lib.IngraphMomentum(learning_rate=learning_rate, momentum=momentum)
elif optimizer_type == 'rmsprop':
    optimizer = lib.IngraphRMSProp(learning_rate=learning_rate, momentum=momentum, alpha=0.99, epsilon=epsilon)
elif optimizer_type == 'adam':
    optimizer = lib.IngraphAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon)
else:
    # The exception name was misspelled before, which raised NameError instead
    raise NotImplementedError("This optimizer is not implemeted")

# Language model sized for the PTB vocabulary, wrapped in MAML and moved to `device`
model = lib.models.language_model.LanguageModel(PTb_voc_size, emb_size, hid_size=hidden_size)
maml = lib.MAML(model, model_type, optimizer=optimizer,
    checkpoint_steps=checkpoint_steps,
    loss_function=compute_loss
).to(device)

In [None]:
class TrainerLM(lib.Trainer):
    """Meta-trainer for the language-model initializers.

    Adds to ``lib.Trainer``: a single-device and a multi-GPU meta-update step, a
    plot comparing learned and untrained initializer distributions, and an
    evaluation that trains a baseline-initialized and a learned-init model side by side.

    Uses ``self.maml``, ``self.meta_optimizer``, ``self.writer``, ``self.logs``,
    ``self.device``, ``self.total_steps`` and ``self.record``; these are assumed
    to come from ``lib.Trainer`` (confirm against the base class).
    """

    def train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """ Performs a single gradient update and reports metrics """
        # Prepare train data
        x_batches, y_batches = samples_batches(train_loader, maml_batches_in_epoch)

        # Validation data is scarce, so all but one of the validation batches
        # are drawn from the train loader and the last one from the valid loader
        x_val_batches, y_val_batches = samples_batches(train_loader, validation_steps - 1)
        x_tmp_batches, y_tmp_batches = samples_batches(valid_loader, 1)
        x_val_batches.extend(x_tmp_batches)
        y_val_batches.extend(y_tmp_batches)

        # Perform step
        self.meta_optimizer.zero_grad()
        train_loss, valid_loss = compute_maml_loss(self.maml, x_batches, y_batches,
                                                   x_val_batches, y_val_batches, self.device)
        valid_loss.backward()

        train_loss_value, valid_loss_value = train_loss.item(), valid_loss.item()
        self._clip_sanitize_and_step(train_loss_value, valid_loss_value, prefix)
        return self.record(train_loss=train_loss_value,
                           valid_loss=valid_loss_value, prefix=prefix)

    def parallel_train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """Multi-GPU variant of ``train_on_batch``.

        Replicates ``self.maml`` onto every visible CUDA device, computes the MAML
        loss of each replica in a background thread, averages the losses and
        performs one meta-update.

        Fixed relative to the previous version: the undefined ``batches_in_epoch``
        is now ``maml_batches_in_epoch``; targets (``y_*``) are passed to the
        replicas instead of the inputs twice; the averaged train loss is already a
        float, so ``.item()`` is no longer called on it.

        :raises RuntimeError: when no CUDA device is available
        """
        replica_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
        if not replica_devices:
            raise RuntimeError("parallel_train_on_batch requires at least one CUDA device")
        replicas = torch.nn.parallel.replicate(self.maml, devices=replica_devices, detach=False)

        # Generate training/validation batches for each device
        replica_inputs = []
        for replica, replica_device in zip(replicas, replica_devices):
            replica.untrained_initializers = lib.switch_initializers_device(replica.untrained_initializers,
                                                                            replica_device)

            x_batches, y_batches = [], []
            for _ in range(maml_epochs):
                x, y = samples_batches(train_loader, maml_batches_in_epoch)
                x_batches.extend(x)
                y_batches.extend(y)

            x_val_batches, y_val_batches = samples_batches(valid_loader, validation_steps)

            replica_inputs.append((replica, x_batches, y_batches,
                                   x_val_batches, y_val_batches, replica_device))

        replica_losses_futures = [run_in_background(compute_maml_loss, *replica_input)
                                  for replica_input in replica_inputs]

        replica_losses = [future.result() for future in replica_losses_futures]
        # The train loss is only reported, so it is reduced to a plain float;
        # the validation loss stays a tensor because it is back-propagated.
        train_loss_value = sum(loss.item() for loss, _ in replica_losses) / len(replica_losses)
        valid_loss = sum(loss.to(self.device) for _, loss in replica_losses) / len(replica_losses)

        self.meta_optimizer.zero_grad()
        valid_loss.backward()

        valid_loss_value = valid_loss.item()
        self._clip_sanitize_and_step(train_loss_value, valid_loss_value, prefix)
        return self.record(train_loss=train_loss_value,
                           valid_loss=valid_loss_value, prefix=prefix)

    def _clip_sanitize_and_step(self, train_loss_value, valid_loss_value, prefix):
        """Clip meta-gradients, zero NaN/inf entries, step the meta-optimizer and log.

        Shared tail of both training steps. Per-parameter gradient norms are
        recorded before non-finite entries are zeroed. Parameters without a
        gradient are skipped.

        :returns: ``(global_grad_norm, grad_norms)``
        """
        global_grad_norm = nn.utils.clip_grad_norm_(list(self.maml.initializers.parameters()),
                                                    max_meta_parameters_grad_norm)
        self.writer.add_scalar(prefix + "global_grad_norm", global_grad_norm, self.total_steps)

        grad_norms = {}
        for name, param in self.maml.initializers.named_parameters():
            if param.grad is None:
                continue
            non_finite = torch.isnan(param.grad) | torch.isinf(param.grad)
            if non_finite.any():
                print("Nan or inf in meta grads")
            grad_norms[name] = param.grad.norm(2).item()
            param.grad = torch.where(non_finite, torch.zeros_like(param.grad), param.grad)

        self.meta_optimizer.step()
        self.logs.append((self.total_steps, train_loss_value, valid_loss_value,
                          global_grad_norm, grad_norms))
        return global_grad_norm, grad_norms

    @torch.no_grad()
    def plot_pdf(self):
        """Plot the untrained (dashed) and learned (green) Normal density of every initializer.

        Each initializer is summarised by the ``.item()`` of its ``mean`` and ``std``.
        The x-range covers +/- 3 std of both distributions. Bias curves are drawn on
        the x-range of the matching weight, as in the original code. The subplot
        reached when the running plot counter equals 12 carries the legend.
        """
        # Subplot slot (1-based, 6x5 grid) for the k-th plot; entry 0 is a placeholder
        # because the counter is incremented before it is first used.
        indices = [-1, 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 6,
                   17, 18, 19, 20, 22, 23, 24, 25, 27, 28, 29, 30]
        # Numeric suffix of an LSTM parameter name -> gate label used in the title
        gate_names = {'0': 'input_gate', '1': 'forget_gate', '2': 'update', '3': 'output_gate'}

        def as_normal(init):
            """Return (mean, std, Normal(mean, std)) for one initializer."""
            mean, std = init.mean.item(), init.std.item()
            return mean, std, torch.distributions.Normal(mean, std)

        def density_grid(base_mean, base_std, maml_mean, maml_std):
            """1000 points spanning +/- 3 std of both distributions."""
            lower = min([base_mean - 3. * base_std, maml_mean - 3. * maml_std])
            upper = max([base_mean + 3. * base_std, maml_mean + 3. * maml_std])
            return np.linspace(lower, upper, 1000)

        def draw_pair(slot, xx, base_dist, maml_dist, title, with_legend=False):
            """Draw both densities into subplot ``indices[slot]``."""
            plt.subplot(6, 5, indices[slot])
            grid = torch.tensor(xx)
            base_style = {'label': 'Kaiming'} if with_legend else {}
            maml_style = {'label': 'DIMAML'} if with_legend else {}
            plt.plot(xx, base_dist.log_prob(grid).exp().numpy(), '--', **base_style)
            plt.plot(xx, maml_dist.log_prob(grid).exp().numpy(), c='g', **maml_style)
            if with_legend:
                leg = plt.legend(loc=4, fontsize=15, frameon=False)
                for line in leg.get_lines():
                    line.set_linewidth(1.6)
            plt.xticks(fontsize=12)
            plt.yticks(fontsize=12)
            plt.title(title, fontsize=14)

        plt.figure(figsize=[22, 26])
        i = 0
        for name, (weight_maml_init, bias_maml_init) in self.maml.initializers.items():
            weight_base_init, bias_base_init = self.maml.untrained_initializers[name]
            if not isinstance(weight_maml_init, nn.ModuleDict):
                base_mean, base_std, base_dist = as_normal(weight_base_init)
                maml_mean, maml_std, maml_dist = as_normal(weight_maml_init)
                xx = density_grid(base_mean, base_std, maml_mean, maml_std)
                i += 1
                draw_pair(i, xx, base_dist, maml_dist, name + '_weight')

                if bias_maml_init is not None:
                    _, _, bias_base_dist = as_normal(bias_base_init)
                    _, _, bias_maml_dist = as_normal(bias_maml_init)
                    i += 1
                    draw_pair(i, xx, bias_base_dist, bias_maml_dist, name + '_bias',
                              with_legend=(i == 12))
            else:
                # A ModuleDict holds one initializer per LSTM parameter chunk
                for weight_name, maml_init in weight_maml_init.items():
                    base_init = weight_base_init[weight_name]

                    weight_name_split = weight_name.split('_')
                    # Fall back to the raw suffix so an unexpected name still gets a title
                    gate_name = gate_names.get(weight_name_split[-1], weight_name_split[-1])

                    base_mean, base_std, base_dist = as_normal(base_init)
                    maml_mean, maml_std, maml_dist = as_normal(maml_init)
                    xx = density_grid(base_mean, base_std, maml_mean, maml_std)

                    if weight_name_split[0] == 'weight':
                        title = '_'.join(['lstm', gate_name] + weight_name_split[:2])
                    else:
                        title = '_'.join(['lstm', gate_name, weight_name_split[0]])

                    i += 1
                    draw_pair(i, xx, base_dist, maml_dist, title)
        plt.show()

    def evaluate_metrics(self, train_loader, test_loader, prefix='val/', **kwargs):
        """Train a baseline-init and a learned-init model with ``eval_model`` and compare them.

        ``test_loader`` is forwarded as the ``test_ids`` argument of ``eval_model``.
        Both loss curves are plotted, and the summed train loss, summed test loss
        and final test loss of the learned-init model are written to the summary writer.
        """
        torch.cuda.empty_cache()

        print('Baseline')
        self.maml.resample_parameters(initializers=self.maml.untrained_initializers, is_final=True)
        base_model = deepcopy(self.maml.model)
        base_train_loss_history, base_test_loss_history = eval_model(base_model, train_loader, test_loader,
                                                                     device=self.device, **kwargs)
        print('Ours')
        self.maml.resample_parameters(is_final=True)
        maml_model = deepcopy(self.maml.model)
        maml_train_loss_history, maml_test_loss_history = eval_model(maml_model, train_loader, test_loader,
                                                                     device=self.device, **kwargs)
        draw_plots(base_train_loss_history, base_test_loss_history,
                   maml_train_loss_history, maml_test_loss_history)

        self.writer.add_scalar(prefix + "train_AUC", sum(maml_train_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_AUC", sum(maml_test_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_loss", maml_test_loss_history[-1], self.total_steps)

In [None]:
##################
# Eval functions #
##################

def adjust_learning_rate(optimizer, epoch, milestone=80, **kwargs):
    """Set the lr of every param group and return it.

    The module-level ``learning_rate`` is used until ``epoch`` reaches
    ``milestone``; from then on a tenth of it. Extra keyword arguments are ignored.
    """
    lr = learning_rate / 10. if milestone <= epoch else learning_rate
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

def eval_model(model, train_loader, test_ids, epochs=3,
               test_loss_interval=20, mode='train', device='cuda', **kwargs):
    """Train ``model`` with a regular torch optimizer and track train/test losses.

    The optimizer is chosen from the module-level ``optimizer_type`` and built
    over ``maml.get_parameters(model)``. The test loss is measured on the first
    iteration and then every ``test_loss_interval`` iterations.

    Fixed relative to the previous version:
      * the RMSprop branch used an undefined ``beta`` keyword; it now passes
        ``alpha=0.99`` and ``momentum``, mirroring the ``lib.IngraphRMSProp`` setup
        (NOTE(review): confirm both optimizers use the same hyper-parameter meaning);
      * the misspelled exception name raised ``NameError``;
      * test losses are stored as floats, like the train losses.

    :param model: model to train (its train/eval mode is restored on exit)
    :param train_loader: iterable of batches with ``.text`` / ``.target``; must support ``len``
    :param test_ids: 1-D tensor of ids passed to ``compute_test_loss``
    :param epochs: stop once this many passes over ``train_loader`` are completed
    :param test_loss_interval: iterations between test evaluations
    :param mode: with ``'train'`` the loop also stops after ``maml_steps`` iterations
    :param device: forwarded to ``compute_test_loss``
    :param kwargs: forwarded to ``adjust_learning_rate`` and ``compute_test_loss``
    :returns: ``(train_loss_history, test_loss_history)`` as lists of floats
    """
    model_params = maml.get_parameters(model)
    if optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(model_params, lr=learning_rate)
    elif optimizer_type == 'momentum':
        optimizer = torch.optim.SGD(model_params, lr=learning_rate, momentum=momentum)
    elif optimizer_type == 'rmsprop':
        optimizer = torch.optim.RMSprop(model_params, lr=learning_rate, alpha=0.99,
                                        momentum=momentum, eps=epsilon)
    elif optimizer_type == 'adam':
        optimizer = torch.optim.Adam(model_params, lr=learning_rate,
                                     betas=(beta1, beta2), eps=epsilon)
    else:
        raise NotImplementedError("{} optimizer is not implemeted".format(optimizer_type))

    # Train loop
    train_loss_history = []
    test_loss_history = []

    training_mode = model.training
    total_iters = 0
    epoch = 0
    model.train()

    for batch in train_loader:
        adjust_learning_rate(optimizer, epoch, **kwargs)
        epoch = (total_iters + 1) // len(train_loader)

        optimizer.zero_grad()
        preds = model(batch.text.t())
        loss = compute_loss(preds, batch.target.t())
        loss.backward()
        optimizer.step()
        train_loss_history.append(loss.item())

        if (total_iters == 0) or (total_iters + 1) % test_loss_interval == 0:
            model.eval()
            test_loss = float(compute_test_loss(model, compute_loss, test_ids, device=device, **kwargs))
            # Convert nats to bits
            bpc = test_loss * math.log2(math.e)
            print(f"Epoch {epoch} | Iteration {total_iters + 1} | Loss {test_loss} | bpc {bpc}")
            test_loss_history.append(test_loss)
            model.train()

        if epoch >= epochs: break
        if mode == 'train' and total_iters >= maml_steps: break
        total_iters += 1

    model.train(training_mode)
    return train_loss_history, test_loss_history

    
def draw_plots(base_train_loss, base_test_loss, maml_train_loss, maml_test_loss):
    """Show baseline vs. learned-init loss curves side by side.

    Left panel: train losses smoothed with ``moving_average(span=10)``.
    Right panel: raw test losses.
    """
    plt.figure(figsize=(20,6))
    panels = (
        ("Train loss",
         moving_average(base_train_loss, span=10),
         moving_average(maml_train_loss, span=10)),
        ("Test loss", base_test_loss, maml_test_loss),
    )
    for position, (title, baseline_curve, ours_curve) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(baseline_curve, label='Baseline')
        plt.plot(ours_curve, c='g', label='Ours')
        plt.legend(fontsize=14)
        plt.title(title, fontsize=14)
    plt.show()

In [None]:
from IPython.display import clear_output
from pandas import DataFrame

def moving_average(x, **kw):
    """Exponentially weighted mean of ``x``; keyword arguments go to ``Series.ewm`` (e.g. ``span=50``)."""
    series = DataFrame({'x': np.asarray(x)}).x
    return series.ewm(**kw).mean().values

# Per-meta-step losses collected by the training loop
train_loss_history = []
valid_loss_history = []

# Meta-trainer around the MAML wrapper; every argument comes from the settings cell
trainer = TrainerLM(
    maml,
    decay_interval=meta_decay_interval,
    meta_lr=meta_learning_rate,
    meta_betas=meta_betas,
    exp_name=exp_name,
    recovery_step=recovery_step,
)

In [None]:
# Meta-training loop: one iteration is one meta-update of the initializers.
# lib.free_memory() presumably releases cached memory before the long run (see lib).
lib.free_memory()
t0 = time.time()

# Runs while total_steps <= max_epochs, i.e. for steps 0..max_epochs inclusive
# when starting from step 0.
while trainer.total_steps <= max_epochs:
    metrics = trainer.train_on_batch(
        PTb_train_loader, PTb_valid_loader, 
        first_valid_step=first_valid_step, 
        valid_loss_interval=loss_interval,
    )
    
    train_loss = metrics['train_loss']
    train_loss_history.append(train_loss)
    
    valid_loss = metrics['valid_loss']
    valid_loss_history.append(valid_loss)
    
    # Every 20 steps: refresh the live loss plots, run the baseline-vs-ours
    # evaluation on the partial PTB test ids, and redraw the initializer densities.
    if trainer.total_steps % 20 == 0:
        clear_output(True)
        print("Step %d | Time: %f | Train Loss %.5f | Valid loss %.5f" % 
              (trainer.total_steps, time.time()-t0, train_loss, valid_loss))
        plt.figure(figsize=[16, 5])
        plt.subplot(1,2,1)
        plt.title('Train Loss over time')
        plt.plot(moving_average(train_loss_history, span=50))
        plt.scatter(range(len(train_loss_history)), train_loss_history, alpha=0.1)
        plt.subplot(1,2,2)
        plt.title('Valid Loss over time')
        plt.plot(moving_average(valid_loss_history, span=50))
        plt.scatter(range(len(valid_loss_history)), valid_loss_history, alpha=0.1)
        plt.show()
        trainer.evaluate_metrics(PTb_train_loader, part_PTb_test_ids, 
                                 batches_in_epoch=maml_batches_in_epoch,
                                 epochs=maml_epochs, test_loss_interval=loss_interval)
        trainer.plot_pdf()
        # Restart the timer so the printed time excludes the evaluation just done
        t0 = time.time()
    
    # Checkpoint every 100 steps
    if trainer.total_steps % 100 == 0:
        trainer.save_model()
    # The step counter is advanced here, not inside train_on_batch
    trainer.total_steps += 1

In [None]:
# Draw the initializer distributions once more after the training loop has finished
trainer.plot_pdf()

# Evaluation

In [None]:
# Reverse the cuDNN settings from the setup cell (deterministic on, benchmark off);
# presumably this trades run-to-run reproducibility for speed during evaluation.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

def gradient_quotient(loss, params, eps=1e-5):
    """Differentiable gradient-quotient statistic of ``loss`` w.r.t. ``params``.

    With ``g`` the gradient and ``p`` the gradient of ``sum(g**2) / 2`` (i.e. the
    Hessian-vector product ``H g``), this returns the mean over all parameter
    elements of ``|(g - p) / (g + eps * sign(g)) - 1|``, where ``sign`` treats
    zero as positive. Both gradient passes keep their graph, so the result can
    itself be differentiated.
    """
    first_order = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)
    half_sq_norm = sum([(g**2).sum() / 2 for g in first_order])
    hess_grad = torch.autograd.grad(half_sq_norm, params, retain_graph=True, create_graph=True)

    total = 0
    for g, hg in zip(first_order, hess_grad):
        # +eps or -eps, matching the sign of g, keeps the denominator away from zero
        signed_eps = eps * (2*(g >= 0).float() - 1).detach()
        total = total + ((g - hg) / (g + signed_eps) - 1).abs().sum()

    num_elements = sum([p.data.nelement() for p in params])
    return total / num_elements

              
def metainit(model, criterion, x_size, y_size, lr=0.1, momentum=0.9, steps=150, eps=1e-5):
    """Rescale the norms of ``model``'s weight tensors to reduce the gradient quotient.

    For ``steps`` iterations: draw random id sequences, compute the loss and its
    gradient quotient, then nudge the norm of every trainable parameter with at
    least two dimensions by a signed, momentum-smoothed step of size ``lr``.
    Parameters are modified in place; nothing is returned.

    Changes relative to the previous version: the random batch is moved to the
    model's own device instead of a hard-coded ``.cuda()``, so the CPU fallback
    works; zero-norm parameters are skipped instead of dividing by zero; the
    trainable-parameter list is built once.

    :param model: module mapping an id tensor ``[x_size[0], x_size[1]]`` to predictions
    :param criterion: callable ``(predictions, targets) -> scalar loss``
    :param x_size: ``(batch, length)`` of the random input
    :param y_size: exclusive upper bound of the random ids
    :param lr, momentum: step size and momentum of the norm update
    :param steps: number of iterations
    :param eps: stabiliser passed to ``gradient_quotient``
    """
    model.eval()
    trainable = [p for p in model.parameters() if p.requires_grad]
    params = [p for p in trainable if len(p.size()) >= 2]
    memory = [0] * len(params)
    model_device = trainable[0].device if trainable else torch.device('cpu')

    for i in range(steps):
        # One extra position so inputs and targets are the same sequence shifted by one
        sequences = torch.randint(0, y_size, torch.Size([x_size[0], x_size[1] + 1])).to(model_device)
        inputs, target = sequences[:, :-1], sequences[:, 1:]
        loss = criterion(model(inputs), target)
        gq = gradient_quotient(loss, trainable, eps)

        grad = torch.autograd.grad(gq, params)
        for j, (p, g_all) in enumerate(zip(params, grad)):
            norm = p.data.norm().item()
            if norm == 0:
                # An all-zero tensor cannot be rescaled
                continue
            # Only the sign of d(gq)/d(norm) is used
            g = torch.sign((p.data * g_all).sum() / norm)
            memory[j] = momentum * memory[j] - lr * g.item()
            new_norm = norm + memory[j]
            p.data.mul_(new_norm / norm)
        print("%d/GQ = %.2f" % (i, gq.item()))
              
              
def lstm_orthogonal(cell, gain=1):
    """Reset ``cell`` and give each recurrent weight block an orthogonal init.

    ``cell`` must expose exactly four parameters; the second one is taken to be
    the hidden-to-hidden weight. It is processed in row blocks of
    ``cell.hidden_size`` rows.
    """
    cell.reset_parameters()

    # The second of the four parameters is the recurrent weight
    _, recurrent, _, _ = list(cell.parameters())
    block = cell.hidden_size
    for start in range(0, recurrent.size(0), block):
        torch.nn.init.orthogonal_(recurrent[start:start + block], gain=gain)

# Eval WikiText2 (D_test)

In [None]:
WikiTEXT = torchtext.data.Field(lower=False, tokenize=list)

# load corpora, each dataset only contains one long "example" with all text in that example
wikitext2_train, wikitext2_valid, wikitext2_test = torchtext.datasets.WikiText2.splits(WikiTEXT, root=data_dir,
                                                         train='wiki.train.raw',
                                                         validation='wiki.valid.raw',
                                                         test='wiki.test.raw')
WikiTEXT.build_vocab(wikitext2_train)
wikitext2_voc_size = len(WikiTEXT.vocab)

wikitext2_train_loader = torchtext.data.BPTTIterator(wikitext2_train, train_batch_size, sequence_length,
                                           train=True, device=device, repeat=True, shuffle=True)
wikitext2_valid_loader = torchtext.data.BPTTIterator(wikitext2_valid, valid_batch_size, sequence_length,
                                           train=False, device=device, repeat=False, shuffle=True)

# Encode the test text; symbols the vocabulary lookup does not know (None) become id 0
wikitext2_test_ids = list(map(WikiTEXT.vocab.stoi.get, wikitext2_test.examples[0].text))
wikitext2_test_ids = list(map(lambda x: x if x is not None else 0, wikitext2_test_ids))
full_wikitext2_test_ids = torch.tensor(wikitext2_test_ids)
# The rerun cell evaluates on `part_wikitext2_test_ids`, which was never defined.
# It is built here like `part_PTb_test_ids` (first tenth of the test ids).
# NOTE(review): confirm the intended fraction.
part_wikitext2_test_ids = torch.tensor(wikitext2_test_ids)[:len(wikitext2_test_ids) // 10]

# For MetaInit: one real batch, whose shape is used for the random inputs
batch = next(iter(wikitext2_train_loader))
text, target = batch.text.t(), batch.target.t()

# Swap the vocabulary-dependent layers for WikiText-2 sized ones, then re-initialise
maml.model.emb_vectors = nn.Embedding(wikitext2_voc_size, emb_size).to(device)
maml.model.logits = nn.Linear(hidden_size, wikitext2_voc_size).to(device)
maml.model.init_weights()

In [None]:
num_reruns = 10
wikitext2_batches_in_epoch = len(wikitext2_train_loader)

reruns_base, reruns_orthogonal = [], []
reruns_metainit, reruns_dimaml = [], []

# All four initialisation schemes are trained and evaluated with the same settings:
# 100 epochs, test loss measured every 10 epochs.
eval_settings = dict(epochs=100, device=device, mode='eval',
                     test_loss_interval=10*wikitext2_batches_in_epoch)

for rerun_id in range(num_reruns):
    print(f"Rerun #{rerun_id}")

    # Learned initializers
    print('DIMAML')
    maml.resample_parameters(is_final=True)
    maml_model = deepcopy(maml.model)
    maml_train_loss_history, maml_test_loss_history = eval_model(
        maml_model, wikitext2_train_loader, part_wikitext2_test_ids, **eval_settings)
    reruns_dimaml.append(maml_test_loss_history)

    # Untrained initializers
    print('Baseline')
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = deepcopy(maml.model)
    base_train_loss_history, base_test_loss_history = eval_model(
        base_model, wikitext2_train_loader, part_wikitext2_test_ids, **eval_settings)
    reruns_base.append(base_test_loss_history)

    # Untrained initializers, then orthogonal recurrent weights in both LSTM layers
    print("Orthogonal")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    orthogonal_model = deepcopy(maml.model)
    for lstm_layer in (orthogonal_model.lstm1, orthogonal_model.lstm2):
        lstm_orthogonal(lstm_layer)
    orthogonal_train_loss_history, orthogonal_test_loss_history = eval_model(
        orthogonal_model, wikitext2_train_loader, part_wikitext2_test_ids, **eval_settings)
    reruns_orthogonal.append(orthogonal_test_loss_history)

    # Untrained initializers, then MetaInit norm rescaling
    print("Metainit")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    metainit_model = deepcopy(maml.model)
    metainit(metainit_model, compute_loss, text.shape, wikitext2_voc_size)
    metainit_train_loss_history, metainit_test_loss_history = eval_model(
        metainit_model, wikitext2_train_loader, part_wikitext2_test_ids, **eval_settings)
    reruns_metainit.append(metainit_test_loss_history)

# Convert the collected losses from nats to bits
bits_per_nat = math.log2(math.e)
reruns_base = np.array(reruns_base) * bits_per_nat
reruns_dimaml = np.array(reruns_dimaml) * bits_per_nat
reruns_metainit = np.array(reruns_metainit) * bits_per_nat
reruns_orthogonal = np.array(reruns_orthogonal) * bits_per_nat

In [None]:
# Mean and sample standard deviation (ddof=1) across reruns, for each method.
# Column k of a rerun array is the test loss after 10*k epochs (column 0 is the
# measurement taken on the first iteration).
# The MetaInit rows used the undefined `fixed_reruns_metainit`; they now read
# `reruns_metainit`, the array produced by the rerun cell.
results_by_method = [
    ("Baseline", reruns_base),
    ("DIMAML", reruns_dimaml),
    ("MetaInit", reruns_metainit),
    ("Orthogonal", reruns_orthogonal),
]
for position, (method, reruns) in enumerate(results_by_method):
    if position > 0:
        print()
    mean, std = reruns.mean(0), reruns.std(0, ddof=1)
    for epoch, column in ((10, 1), (50, 5), (100, 10)):
        print(f"{method} {epoch} epoch: ", mean[column], std[column])