In [9]:
import csv
import os
import random

import higher
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
## Credits
## This cell includes code from StylePTB (https://github.com/lvyiwei1/StylePTB/tree/master) by Yiwei Lyu et al.,
## which is available under the Creative Commons Attribution 4.0 International License
## (CC BY 4.0, https://creativecommons.org/licenses/by/4.0/).

def lowering(pairs):
    """Lower-case the first two fields of every entry in ``pairs``, in place.

    Each entry must be a mutable sequence whose items 0 and 1 are strings.
    Any further items are left untouched. Nothing is returned.
    """
    for entry in pairs:
        for column in (0, 1):
            entry[column] = entry[column].lower()

def numpreprocess(pairs):
    """Replace number-like tokens with the placeholder ``NUM``, in place.

    For the first two fields of every entry, the text is split on single
    spaces and every token whose first character is an ASCII digit is
    replaced by ``"NUM"``. Empty tokens (produced by consecutive spaces) are
    kept, so the original spacing survives the round trip.
    """
    digits = "0123456789"
    for entry in pairs:
        for column in (0, 1):
            tokens = entry[column].split(' ')
            entry[column] = ' '.join(
                "NUM" if (tok and tok[0] in digits) else tok
                for tok in tokens
            )

def padinput(inputlist, totalpad=80):
    """Right-pad a list of token ids with zeros and build the matching mask.

    Returns a ``(padded_ids, mask)`` tuple. ``mask`` holds 1 for every
    original position and 0 for every padding position. A list that is
    already longer than ``totalpad`` is returned unpadded and untruncated.
    """
    n_tokens = len(inputlist)
    filler = [0] * (totalpad - n_tokens)
    return inputlist + filler, [1] * n_tokens + filler

# create label for training
def labels(inlen, outputlist, totalpad=80):
    """Build a label row of width ``totalpad`` around ``outputlist``.

    The first ``inlen`` positions and every position after ``outputlist`` are
    filled with -100; ``outputlist`` itself sits in between. If
    ``inlen + len(outputlist)`` exceeds ``totalpad`` no trailing fill is added
    and the result is longer than ``totalpad``.
    """
    ignore = -100
    trailing = totalpad - inlen - len(outputlist)
    return [ignore] * inlen + outputlist + [ignore] * trailing

def batchvalid(src, trg, batchsize):
    """Average the validation loss over consecutive full batches.

    ``src`` and ``trg`` are sliced into chunks of ``batchsize`` items and each
    chunk pair is handed to ``valid``. Only ``len(src) // batchsize`` full
    batches are evaluated, so any trailing remainder is ignored. When there
    is not even one full batch the final division raises ZeroDivisionError.
    """
    n_batches = len(src) // batchsize
    running_loss = 0.0
    for start in range(0, n_batches * batchsize, batchsize):
        stop = start + batchsize
        running_loss += valid(list(src[start:stop]), list(trg[start:stop]))
    return running_loss / n_batches

def valid(src, trg):
    """Run one gradient-free forward pass of ``gpt_model`` on a batch.

    Every sequence in ``src`` is padded with ``padinput``; the label rows are
    built with ``labels(len(src[k]), trg[k])``. Only the first ``len(trg)``
    examples are used. Relies on the module-level ``gpt_model`` and
    ``device``. Returns the first element of the model output, which the
    caller treats as the loss.
    """
    n_examples = len(trg)
    padded = list(map(padinput, src))

    token_rows = [padded[k][0] for k in range(n_examples)]
    mask_rows = [padded[k][1] for k in range(n_examples)]
    label_rows = [labels(len(src[k]), trg[k]) for k in range(n_examples)]

    padedin = torch.LongTensor(token_rows).to(device)
    masks = torch.LongTensor(mask_rows).to(device)
    label = torch.LongTensor(label_rows).to(device)

    with torch.no_grad():
        outputs = gpt_model.forward(padedin, attention_mask=masks, labels=label)
    return outputs[0]

In [11]:
class MAML_GPT():
    """MAML-style meta-trainer for a causal language model.

    Each task is a list of ``(source_ids, target_ids)`` token-id pairs. For
    every task the model is adapted on ``K`` sampled pairs with a
    differentiable inner optimiser (via ``higher``), evaluated on ``K`` other
    pairs, and the summed evaluation losses drive the meta update.

    Parameters
    ----------
    gpt_model : torch.nn.Module
        Model to meta-train. It is called as
        ``model(input_ids, attention_mask=..., labels=...)`` and the first
        element of its output is used as the loss.
    tasks : list
        One list of ``(source_ids, target_ids)`` pairs per task.
    gpt_tokenizer : optional
        Stored on the instance; not used by the training loop itself.
    inner_lr : float
        Learning rate of the inner (task adaptation) SGD optimiser.
    meta_lr : float
        Learning rate of the outer (meta) Adam optimiser.
    K : int
        Number of pairs in the support batch and in the query batch.
    multi_batch_iter : int
        How many times every task is sampled per meta iteration.
    inner_steps : int
        Number of adaptation steps per task.
    epochs : int
        Number of meta iterations run by ``train``.
    early_stop : int
        Training stops once more than this many iterations in a row fail to
        improve on the lowest meta loss seen so far.
    model_save_name : str
        File-name prefix of the saved checkpoints.
    save_dir : str
        Directory the checkpoints are written to (created on demand).
    """

    def __init__(self, gpt_model, tasks, gpt_tokenizer=None, inner_lr=2e-4, meta_lr=2e-5, K=10,
                 multi_batch_iter=1, inner_steps=1, epochs=10000, early_stop=50, model_save_name="maml_gpt",
                 save_dir="MAML_GPT_models"):
        self.tasks = tasks
        self.model = gpt_model
        self.gpt_tokenizer = gpt_tokenizer
        # Outer optimiser: updates the real model weights from the meta loss.
        self.meta_optimiser = torch.optim.Adam(self.model.parameters(), meta_lr)
        # Inner optimiser: only a template from which `higher` builds a
        # differentiable copy; it never steps the real weights itself.
        self.inner_optimiser = torch.optim.SGD(self.model.parameters(), lr=inner_lr)
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.K = K
        self.inner_steps = inner_steps
        self.plot_every = 1
        self.print_every = 10
        self.meta_losses = []
        self.epochs = epochs
        self.early_stop = early_stop
        self.meta_batch_size = len(tasks)
        self.model_save_name = model_save_name
        self.multi_batch_iter = multi_batch_iter
        self.save_dir = save_dir

    def _build_batch(self, pairs):
        """Turn ``(source_ids, target_ids)`` pairs into padded model inputs.

        Source and target are concatenated into one sequence and right-padded
        with 0 to the longest sequence in the batch. Labels are -100 on the
        source and padding positions so that only target tokens contribute to
        the loss.
        NOTE(review): the original code left batching as a TODO; this layout
        follows the pad value / ignore index used by ``padinput`` and
        ``labels`` in this notebook - confirm it matches the intended
        training format.

        Returns ``(input_ids, attention_mask, labels)`` as LongTensors on the
        model's device.
        """
        sequences = [list(src) + list(trg) for src, trg in pairs]
        width = max(len(seq) for seq in sequences)
        input_rows, mask_rows, label_rows = [], [], []
        for (src, trg), seq in zip(pairs, sequences):
            pad = width - len(seq)
            input_rows.append(seq + [0] * pad)
            mask_rows.append([1] * len(seq) + [0] * pad)
            label_rows.append([-100] * len(src) + list(trg) + [-100] * pad)
        device = next(self.model.parameters()).device
        return (
            torch.tensor(input_rows, dtype=torch.long, device=device),
            torch.tensor(mask_rows, dtype=torch.long, device=device),
            torch.tensor(label_rows, dtype=torch.long, device=device),
        )

    def inner_loop(self, task):
        """Adapt to one task and return the post-adaptation query loss.

        ``2 * K`` distinct pairs are sampled from ``task``: the first ``K``
        form the support batch used for ``inner_steps`` adaptation steps, the
        remaining ``K`` form the query batch the returned loss is computed on.

        Raises
        ------
        ValueError
            If ``K`` is smaller than 1 or the task holds fewer than ``2 * K``
            pairs.
        """
        needed = 2 * self.K
        if self.K < 1 or len(task) < needed:
            raise ValueError(
                f"Each task needs at least {needed} pairs (2 * K) with K >= 1, got {len(task)} pairs and K={self.K}"
            )
        selected = random.sample(task, needed)
        support_ids, support_mask, support_labels = self._build_batch(selected[:self.K])
        query_ids, query_mask, query_labels = self._build_batch(selected[self.K:])

        # copy_initial_weights=False keeps the real model parameters at the
        # start of the unrolled graph, so the query loss can back-propagate
        # into them during the meta update.
        with higher.innerloop_ctx(self.model, self.inner_optimiser,
                                  copy_initial_weights=False) as (fmodel, diffopt):
            for _ in range(self.inner_steps):
                support_loss = fmodel(support_ids, attention_mask=support_mask, labels=support_labels)[0]
                diffopt.step(support_loss)
            return fmodel(query_ids, attention_mask=query_mask, labels=query_labels)[0]

    def main_loop(self, num_iterations):
        """Run up to ``num_iterations`` meta updates with early stopping.

        Every iteration sums the query losses of all tasks. When that sum is
        a new minimum the model is saved to ``save_dir`` before the update is
        applied; otherwise the early-stop counter advances and training
        returns once it exceeds ``early_stop``. Per-task average losses are
        appended to ``meta_losses`` every ``plot_every`` iterations.

        Raises
        ------
        ValueError
            If there are no tasks or ``multi_batch_iter`` is smaller than 1.
        """
        if not self.tasks or self.multi_batch_iter < 1:
            raise ValueError("main_loop needs at least one task and multi_batch_iter >= 1")
        os.makedirs(self.save_dir, exist_ok=True)

        min_loss = float("inf")
        early_stop_count = 0
        print_loss = 0.0
        for iteration in range(1, num_iterations + 1):
            meta_loss = 0
            for _ in range(self.multi_batch_iter):
                for task in self.tasks:
                    meta_loss = meta_loss + self.inner_loop(task)

            # Track the running minimum as a plain float so no autograd graph
            # is kept alive between iterations.
            current_loss = meta_loss.item()
            if current_loss < min_loss:
                min_loss = current_loss
                early_stop_count = 0
                print("New lowest loss found!")
                checkpoint = os.path.join(self.save_dir, f"{self.model_save_name}_epoch{iteration}.pt")
                torch.save(self.model, checkpoint)
            else:
                early_stop_count += 1
                if early_stop_count > self.early_stop:
                    print(f"Early stop at epoch {iteration} because no lower loss is found in {self.early_stop} epochs")
                    return

            self.meta_optimiser.zero_grad()
            meta_loss.backward()
            self.meta_optimiser.step()

            print_loss += current_loss / self.meta_batch_size
            if iteration % self.print_every == 0:
                print(f"Epoch {iteration}/{num_iterations}. loss: {print_loss / self.print_every}")
                print_loss = 0.0
            if iteration % self.plot_every == 0:
                self.meta_losses.append(current_loss / self.meta_batch_size)

    def train(self):
        """Run the meta-training loop for ``self.epochs`` iterations."""
        self.main_loop(self.epochs)

In [13]:
# Load the pre-trained GPT-2 model and tokenizer, then read and pre-process
# the training split of every meta-training task.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
gpt_tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
gpt_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(device)
meta_train_task_name_list = ["ARR", "TFU", "ATP", "PPR", "TPA"]
meta_test_task_name_list = ["PTA", "SBR", "TPR"]

train_task_pair_list = []  # Nested list: one list of encoded (source, target) pairs per task
for meta_train_task_name in meta_train_task_name_list:
    # os.path.join keeps the path valid on every platform, not only Windows.
    train_path = os.path.join("Data", "Meta_training", meta_train_task_name, "train.tsv")
    # The context manager closes the file; newline='' is what the csv module expects.
    with open(train_path, 'r', newline='', encoding='utf-8') as f:
        # Rows with fewer than two columns (e.g. blank lines) cannot form a
        # source/target pair and would break the helpers below, so drop them.
        pairs = [row for row in csv.reader(f, delimiter='\t') if len(row) >= 2]
    lowering(pairs)
    numpreprocess(pairs)
    # Both sides get an explicit end-of-text marker before tokenisation.
    pairsEncode = [
        (gpt_tokenizer.encode(row[0] + " <|endoftext|>"),
         gpt_tokenizer.encode(row[1] + " <|endoftext|>"))
        for row in pairs
    ]
    train_task_pair_list.append(pairsEncode)

In [None]:
# Build the meta-learner over the encoded meta-training tasks.
# This cell only constructs the object; it does not start training.
maml_gpt = MAML_GPT(
    gpt_model=gpt_model,
    tasks=train_task_pair_list,
    gpt_tokenizer=gpt_tokenizer,
    multi_batch_iter=1,
    early_stop=50,
    model_save_name="maml_gpt",
)