<a href="https://colab.research.google.com/github/Yash-Kamtekar/Special-Topics-Assignment-4/blob/main/Meta_Learning_Bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:
import json
from random import shuffle
reviews = json.load(open('/content/drive/MyDrive/297/Assignment_4/dataset.json'))

reviews[:5]

[{'text': "GOOD LOOKING KICKS IF YOUR KICKIN IT OLD SCHOOL LIKE ME. AND COMFORTABLE. AND RELATIVELY CHEAP. I'LL ALWAYS KEEP A PAIR OF STAN SMITH'S AROUND FOR WEEKENDS",
  'label': 'positive',
  'domain': 'apparel'},
 {'text': 'These sunglasses are all right. They were a little crooked, but still cool..',
  'label': 'positive',
  'domain': 'apparel'},
 {'text': "I don't see the difference between these bodysuits and the more expensive ones. Fits my boy just right",
  'label': 'positive',
  'domain': 'apparel'},
 {'text': 'Very nice basic clothing. I think the size is fine. I really like being able to find these shades of green, though I have decided the lighter shade is really a feminine color. This is the only brand that I can find these muted greens',
  'label': 'positive',
  'domain': 'apparel'},
 {'text': 'I love these socks. They fit great (my 15 month old daughter has thick ankles) and she can zoom around on the kitchen floor and not take a nose dive into things. :',
  'label': 'p

In [22]:
from collections import Counter
mention_domain = [r['domain'] for r in reviews]
Counter(mention_domain)

Counter({'apparel': 1717,
         'baby': 1107,
         'beauty': 993,
         'books': 921,
         'camera_&_photo': 1086,
         'cell_phones_&_service': 698,
         'dvd': 893,
         'electronics': 1277,
         'grocery': 1100,
         'health_&_personal_care': 1429,
         'jewelry_&_watches': 1086,
         'kitchen_&_housewares': 1390,
         'magazines': 1133,
         'music': 1007,
         'outdoor_living': 980,
         'software': 1029,
         'sports_&_outdoors': 1336,
         'toys_&_games': 1363,
         'video': 1010,
         'automotive': 100,
         'computer_&_video_games': 100,
         'office_products': 100})

## Creating Meta Learning Tasks

In [23]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import collections
import random
import json, pickle
from torch.utils.data import TensorDataset

LABEL_MAP  = {'positive':0, 'negative':1, 0:'positive', 1:'negative'}

class MetaTask(Dataset):
    
    def __init__(self, examples, num_task, k_support, k_query, tokenizer):
        """
        :param samples: list of samples
        :param num_task: number of training tasks.
        :param k_support: number of support sample per task
        :param k_query: number of query sample per task
        """
        self.examples = examples
        random.shuffle(self.examples)
        
        self.num_task = num_task
        self.k_support = k_support
        self.k_query = k_query
        self.tokenizer = tokenizer
        self.max_seq_length = 128
        self.create_batch(self.num_task)
    
    def create_batch(self, num_task):
        self.supports = []  # support set
        self.queries = []  # query set
        for b in range(num_task):  # for each task
            # 1.select domain randomly
            domain = random.choice(self.examples)['domain']
            domainExamples = [e for e in self.examples if e['domain'] == domain]
            
            # 1.select k_support + k_query examples from domain randomly
            selected_examples = random.sample(domainExamples,self.k_support + self.k_query)
            random.shuffle(selected_examples)
            exam_train = selected_examples[:self.k_support]
            exam_test  = selected_examples[self.k_support:]
            
            self.supports.append(exam_train)
            self.queries.append(exam_test)
    def create_feature_set(self,examples):
        all_input_ids      = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
        all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
        all_segment_ids    = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
        all_label_ids      = torch.empty(len(examples), dtype = torch.long)

        for id_,example in enumerate(examples):
            input_ids = tokenizer.encode(example['text'])
            attention_mask = [1] * len(input_ids)
            segment_ids    = [0] * len(input_ids)

            while len(input_ids) < self.max_seq_length:
                input_ids.append(0)
                attention_mask.append(0)
                segment_ids.append(0)

            label_id = LABEL_MAP[example['label']]
            all_input_ids[id_] = torch.Tensor(input_ids).to(torch.long)
            all_attention_mask[id_] = torch.Tensor(attention_mask).to(torch.long)
            all_segment_ids[id_] = torch.Tensor(segment_ids).to(torch.long)
            all_label_ids[id_] = torch.Tensor([label_id]).to(torch.long)

        tensor_set = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)  
        return tensor_set
    
    def __getitem__(self, index):
        support_set = self.create_feature_set(self.supports[index])
        query_set   = self.create_feature_set(self.queries[index])
        return support_set, query_set

    def __len__(self):
        # as we have built up to batchsz of sets, you can sample some small batch size of sets.
        return self.num_task

## Split meta training and meta testing

In [24]:
low_resource_domains = ["office_products", "automotive", "computer_&_video_games"]
train_examples = [r for r in reviews if r['domain'] not in low_resource_domains]
test_examples = [r for r in reviews if r['domain'] in low_resource_domains]
print(len(train_examples), len(test_examples))

21555 300


In [25]:
!pip install transformers==2.5.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [26]:
import torch
from transformers import BertModel, BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)
train = MetaTask(train_examples, num_task = 20, k_support=100, k_query=30, tokenizer = tokenizer)

## Take a glance at the first two samples from support set of 1st meta-task

In [27]:
train.supports[0][:2]

[{'text': '...buy it here and save $s!Shipping cost from the UK added $46.95 to the price AND a Foreign Debit Card Transaction fee of $2.51 for a waste of $49.46 that could have been spent getting additional models.Or get the download version and save even more as like someone else pointed out, the included manual is near useless and has not improved since version 4. Practical Poser 6 is still the best book to get if you are serious about doing Poser 6 development',
  'label': 'negative',
  'domain': 'software'},
 {'text': "Looks nice... but beware!Installed Norton 360 on the 6 workstations in our small office. As soon as I did that our accounting system (Peachtree) started to lock up a few times a day. Even after turing off the firewall on all workstations the accounting system still locked up.I've now uninstalled Norton 360 on the server PC and that seems to have solved the problem.I wish now that I had just upgraded from Norton Internet Security 2005 to 2007... Never again.",
  'lab

In [28]:
# Let take a look at the first two samples from support set
train[0][0][:2]

(tensor([[  101,  1012,  1012,  1012,  4965,  2009,  2182,  1998,  3828,  1002,
           1055,   999,  7829,  3465,  2013,  1996,  2866,  2794,  1002,  4805,
           1012,  5345,  2000,  1996,  3976,  1998,  1037,  3097,  2139, 16313,
           4003, 12598,  7408,  1997,  1002,  1016,  1012,  4868,  2005,  1037,
           5949,  1997,  1002,  4749,  1012,  4805,  2008,  2071,  2031,  2042,
           2985,  2893,  3176,  4275,  1012,  2030,  2131,  1996,  8816,  2544,
           1998,  3828,  2130,  2062,  2004,  2066,  2619,  2842,  4197,  2041,
           1010,  1996,  2443,  6410,  2003,  2379, 11809,  1998,  2038,  2025,
           5301,  2144,  2544,  1018,  1012,  6742, 13382,  2099,  1020,  2003,
           2145,  1996,  2190,  2338,  2000,  2131,  2065,  2017,  2024,  3809,
           2055,  2725, 13382,  2099,  1020,  2458,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,   

## Training Meta

In [29]:
import time
import logging
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

def random_seed(value):
    torch.backends.cudnn.deterministic=True
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    np.random.seed(value)
    random.seed(value)

def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
    idxs = list(range(0,len(taskset)))
    if is_shuffle:
        random.shuffle(idxs)
    for i in range(0,len(idxs), batch_size):
        yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]

class TrainingArgs:
    def __init__(self):
        self.num_labels = 2
        self.meta_epoch=10
        self.k_spt=80
        self.k_qry=20
        self.outer_batch_size = 2
        self.inner_batch_size = 12
        self.outer_update_lr = 5e-5
        self.inner_update_lr = 5e-5
        self.inner_update_step = 10
        self.inner_update_step_eval = 40
        self.bert_model = 'bert-base-uncased'
        self.num_task_train = 10
        self.num_task_test = 5

args = TrainingArgs()

Creating Meta Learner

In [30]:
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from transformers import BertForSequenceClassification
from copy import deepcopy
import gc
from sklearn.metrics import accuracy_score
import torch
import numpy as np

class Learner(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args):
        """
        :param args:
        """
        super(Learner, self).__init__()
        
        self.num_labels = args.num_labels
        self.outer_batch_size = args.outer_batch_size
        self.inner_batch_size = args.inner_batch_size
        self.outer_update_lr  = args.outer_update_lr
        self.inner_update_lr  = args.inner_update_lr
        self.inner_update_step = args.inner_update_step
        self.inner_update_step_eval = args.inner_update_step_eval
        self.bert_model = args.bert_model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.model = BertForSequenceClassification.from_pretrained(self.bert_model, num_labels = self.num_labels)
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()
    def forward(self, batch_tasks, training = True):
        """
        batch = [(support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset)]
        
        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
        """
        task_accs = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query   = task[1]
            
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)
            support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                            batch_size=self.inner_batch_size)
            
            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
            fast_model.train()
            
            print('----Task',task_id, '----')
            for i in range(0,num_inner_update_step):
                all_loss = []
                for inner_step, batch in enumerate(support_dataloader):
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, segment_ids, label_id = batch
                    outputs = fast_model(input_ids, attention_mask, segment_ids, labels = label_id)
                    
                    loss = outputs[0]              
                    loss.backward()
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()
                    
                    all_loss.append(loss.item())
                
                if i % 4 == 0:
                    print("Inner Loss: ", np.mean(all_loss))
            
            fast_model.to(torch.device('cpu'))
            
            if training:
                meta_weights = list(self.model.parameters())
                fast_weights = list(fast_model.parameters())

                gradients = []
                for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
                    gradient = meta_params - fast_params
                    if task_id == 0:
                        sum_gradients.append(gradient)
                    else:
                        sum_gradients[i] += gradient
            fast_model.to(self.device)
            fast_model.eval()
            with torch.no_grad():
                query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
                query_batch = iter(query_dataloader).next()
                query_batch = tuple(t.to(self.device) for t in query_batch)
                q_input_ids, q_attention_mask, q_segment_ids, q_label_id = query_batch
                q_outputs = fast_model(q_input_ids, q_attention_mask, q_segment_ids, labels = q_label_id)

                q_logits = F.softmax(q_outputs[1],dim=1)
                pre_label_id = torch.argmax(q_logits,dim=1)
                pre_label_id = pre_label_id.detach().cpu().numpy().tolist()
                q_label_id = q_label_id.detach().cpu().numpy().tolist()

                acc = accuracy_score(pre_label_id,q_label_id)
                task_accs.append(acc)
            
            fast_model.to(torch.device('cpu'))
            del fast_model, inner_optimizer
            torch.cuda.empty_cache()
        if training:
            # Average gradient across tasks
            for i in range(0,len(sum_gradients)):
                sum_gradients[i] = sum_gradients[i] / float(num_task)

            #Assign gradient for original model, then using optimizer to update its weights
            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i]

            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()
            
            del sum_gradients
            gc.collect()
        
        return np.mean(task_accs)

In [31]:
learner = Learner(args)

In [32]:
random_seed(123)
test = MetaTask(test_examples, num_task = 3, k_support=80, k_query=20, tokenizer = tokenizer)
random_seed(int(time.time() % 10))

In [33]:
test.supports[2]

[{'text': 'Love love love the moleskin notebooks in every size. So cool looking and just get cooler as they get beat up. My favorite size is the small one because I can carry it around in my bag so that I can scribble down any random thoughts throughout the day. One of my unexpectedly best purchases!',
  'label': 'positive',
  'domain': 'office_products'},
 {'text': "A friend of mine had this planner and when I saw it I knew I just had to have it! It's great because I can keep track of my week on one side and then make my 'to do' lists or keep notes on the otherside of the page. I recommend this planner to EVERYONE",
  'label': 'positive',
  'domain': 'office_products'},
 {'text': 'Does it make any sense to purchase paper for 30.99 and then pay 27.99 for shipping to recieve it. Someone over there should really look at how realistic that is',
  'label': 'negative',
  'domain': 'office_products'},
 {'text': 'The nib on this pen is far too broad. Most fountain pen manufacturers would prob

## Training

In [34]:
global_step = 0

for epoch in range(args.meta_epoch):
    
    train = MetaTask(train_examples, num_task = 10, k_support=80, k_query=20, tokenizer = tokenizer)
    db = create_batch_of_tasks(train, is_shuffle = True, batch_size = args.outer_batch_size)

    for step, task_batch in enumerate(db):
        
        f = open('log.txt', 'a')
        
        acc = learner(task_batch)
        
        print('Step:', step, '\ttraining Acc:', acc)
        f.write(str(acc) + '\n')
        
        if global_step % 20 == 0:
            random_seed(123)
            print("\n-----------------Testing Mode-----------------\n")
            db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
            acc_all_test = []

            for test_batch in db_test:
                acc = learner(test_batch, training = False)
                acc_all_test.append(acc)

            print('Step:', step, 'Test F1:', np.mean(acc_all_test))
            f.write('Test' + str(np.mean(acc_all_test)) + '\n')
            
            random_seed(int(time.time() % 10))
        
        global_step += 1
        f.close()

----Task 0 ----
Inner Loss:  0.7173990692411151
Inner Loss:  0.06377800021852766
Inner Loss:  0.0037996030878275633
----Task 1 ----
Inner Loss:  0.6026502336774554
Inner Loss:  0.06383104808628559
Inner Loss:  0.004119277931749821
Step: 0 	training Acc: 0.7

-----------------Testing Mode-----------------

----Task 0 ----
Inner Loss:  0.6777281420571464
Inner Loss:  0.014927135647407599
Inner Loss:  0.001665695570409298
Inner Loss:  0.0009195762520123805
Inner Loss:  0.0006575825391337276
Inner Loss:  0.0004932242819839823
Inner Loss:  0.00039933801079834145
Inner Loss:  0.00034068627116669504
Inner Loss:  0.0002873630770149508
Inner Loss:  0.00024793007261385877
----Task 0 ----
Inner Loss:  0.5497982374259404
Inner Loss:  0.017872560104089125
Inner Loss:  0.004533191677182913
Inner Loss:  0.002481567324139178
Inner Loss:  0.0015738967174131954
Inner Loss:  0.0011426725790702871
Inner Loss:  0.0008683831131617938
Inner Loss:  0.0006934135370621723
Inner Loss:  0.000567113988966282
Inner

## Meta_Learner_MAML

In [35]:
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from transformers import BertForSequenceClassification
from copy import deepcopy
import gc
import torch
from sklearn.metrics import accuracy_score
import numpy as np

class Learner(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args):
        """
        :param args:
        """
        super(Learner, self).__init__()
        
        self.num_labels = args.num_labels
        self.outer_batch_size = args.outer_batch_size
        self.inner_batch_size = args.inner_batch_size
        self.outer_update_lr  = args.outer_update_lr
        self.inner_update_lr  = args.inner_update_lr
        self.inner_update_step = args.inner_update_step
        self.inner_update_step_eval = args.inner_update_step_eval
        self.bert_model = args.bert_model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.model = BertForSequenceClassification.from_pretrained(self.bert_model, num_labels = self.num_labels)
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training = True):
        """
        batch = [(support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset)]
        
        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
        """
        task_accs = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query   = task[1]
            
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)
            support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                            batch_size=self.inner_batch_size)
            
            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
            fast_model.train()
            
            print('----Task',task_id, '----')
            for i in range(0,num_inner_update_step):
                all_loss = []
                for inner_step, batch in enumerate(support_dataloader):
                    
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, segment_ids, label_id = batch
                    outputs = fast_model(input_ids, attention_mask, segment_ids, labels = label_id)
                    
                    loss = outputs[0]              
                    loss.backward()
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()
                    
                    all_loss.append(loss.item())
                
                if i % 4 == 0:
                    print("Inner Loss: ", np.mean(all_loss))

            query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
            query_batch = iter(query_dataloader).next()
            query_batch = tuple(t.to(self.device) for t in query_batch)
            q_input_ids, q_attention_mask, q_segment_ids, q_label_id = query_batch
            q_outputs = fast_model(q_input_ids, q_attention_mask, q_segment_ids, labels = q_label_id)
            
            if training:
                q_loss = q_outputs[0]
                q_loss.backward()
                fast_model.to(torch.device('cpu'))
                for i, params in enumerate(fast_model.parameters()):
                    if task_id == 0:
                        sum_gradients.append(deepcopy(params.grad))
                    else:
                        sum_gradients[i] += deepcopy(params.grad)

            q_logits = F.softmax(q_outputs[1],dim=1)
            pre_label_id = torch.argmax(q_logits,dim=1)
            pre_label_id = pre_label_id.detach().cpu().numpy().tolist()
            q_label_id = q_label_id.detach().cpu().numpy().tolist()
            
            acc = accuracy_score(pre_label_id,q_label_id)
            task_accs.append(acc)
            
            del fast_model, inner_optimizer
            torch.cuda.empty_cache()
        
        if training:
            # Average gradient across tasks
            for i in range(0,len(sum_gradients)):
                sum_gradients[i] = sum_gradients[i] / float(num_task)

            #Assign gradient for original model, then using optimizer to update its weights
            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i]

            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()
            
            del sum_gradients
            gc.collect()
        
        return np.mean(task_accs)

## Functional_Forward_Bert

In [36]:
from torch.nn.functional import gelu, elu
import torch.nn.functional as F
import torch.nn as nn
import math
import torch
from collections import OrderedDict
from transformers import BertModel, BertTokenizer, BertForSequenceClassification

def functional_bert(fast_weights, config, input_ids=None, attention_mask=None, token_type_ids=None,
                    position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None,
                    encoder_attention_mask=None, is_train = True):

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if attention_mask is None:
        attention_mask = torch.ones(input_shape, device=device)
    if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        if config.is_decoder:
            batch_size, seq_length = input_shape
            seq_ids = torch.arange(seq_length, device=device)
            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
            causal_mask = causal_mask.to(torch.long)
            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))

    extended_attention_mask = extended_attention_mask.to(dtype=next((p for p in fast_weights.values())).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if config.is_decoder and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
                                                                                                                           encoder_attention_mask.shape))
        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next((p for p in fast_weights.values())).dtype) 
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
    else:
        encoder_extended_attention_mask = None

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        head_mask = head_mask.to(dtype=next((p for p in fast_weights.values())).dtype)
    else:
        head_mask = [None] * config.num_hidden_layers
    
    embedding_output = functional_embeeding(fast_weights, config, input_ids, position_ids, 
                                            token_type_ids, inputs_embeds, is_train = is_train)
    
    encoder_outputs = functional_encoder(fast_weights, config, embedding_output,
                                   attention_mask=extended_attention_mask,
                                   head_mask=head_mask, encoder_hidden_states=encoder_hidden_states,
                                   encoder_attention_mask=encoder_extended_attention_mask, is_train = is_train)
    
    sequence_output = encoder_outputs
    outputs = (sequence_output,)
    return outputs


def functional_embeeding(fast_weights, config, input_ids, position_ids, 
                         token_type_ids, inputs_embeds = None, is_train = True):

    if input_ids is not None:
        input_shape = input_ids.size()
    else:
        input_shape = inputs_embeds.size()[:-1]

    seq_length = input_shape[1]
    device = input_ids.device if input_ids is not None else inputs_embeds.device
    if position_ids is None:
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(input_shape)
    if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    if inputs_embeds is None:
        inputs_embeds = F.embedding(input_ids, fast_weights['bert.embeddings.word_embeddings.weight'], padding_idx = 0)
    
    position_embeddings = F.embedding(position_ids, fast_weights['bert.embeddings.position_embeddings.weight'])
    token_type_embeddings = F.embedding(token_type_ids, fast_weights['bert.embeddings.token_type_embeddings.weight'])

    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    
    embeddings = F.layer_norm(embeddings, [config.hidden_size], 
                              weight=fast_weights['bert.embeddings.LayerNorm.weight'],
                              bias=fast_weights['bert.embeddings.LayerNorm.bias'],
                              eps=config.layer_norm_eps)

    embeddings = F.dropout(embeddings, p=config.hidden_dropout_prob, training = is_train)
    
    return embeddings

    
def transpose_for_scores(config, x):
    new_x_shape = x.size()[:-1] + (config.num_attention_heads, int(config.hidden_size / config.num_attention_heads))
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

def functional_self_attention(fast_weights, config, layer_idx,
                              hidden_states, attention_mask, head_mask, 
                              encoder_hidden_states, encoder_attention_mask,
                              is_train = True):
    
    attention_head_size = int(config.hidden_size / config.num_attention_heads)
    all_head_size = config.num_attention_heads * attention_head_size
    
    mixed_query_layer = F.linear(hidden_states,
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.query.weight'],
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.query.bias'])
    
    if encoder_hidden_states is not None:
        mixed_key_layer = F.linear(encoder_hidden_states,
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.key.weight'],
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.key.bias'])
        mixed_value_layer = F.linear(encoder_hidden_states,
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.value.weight'],
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.value.bias'])
        attention_mask = encoder_attention_mask
    else:
        mixed_key_layer   = F.linear(hidden_states,
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.key.weight'],
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.key.bias'])
        mixed_value_layer = F.linear(hidden_states,
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.value.weight'],
                                fast_weights['bert.encoder.layer.'+layer_idx+'.attention.self.value.bias'])

    query_layer = transpose_for_scores(config, mixed_query_layer)
    key_layer   = transpose_for_scores(config, mixed_key_layer)
    value_layer = transpose_for_scores(config, mixed_value_layer)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(attention_head_size)
    if attention_mask is not None:
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask
        
    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

    if is_train:
        attention_probs = F.dropout(attention_probs, p= config.attention_probs_dropout_prob)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask
    
    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)
    
    outputs = context_layer
    return outputs
    
def functional_out_attention(fast_weights, config, layer_idx,
                              hidden_states, input_tensor,
                              is_train = True):
    
    hidden_states = F.linear(hidden_states,
                            fast_weights['bert.encoder.layer.'+layer_idx+'.attention.output.dense.weight'],
                            fast_weights['bert.encoder.layer.'+layer_idx+'.attention.output.dense.bias'])

    hidden_states = F.dropout(hidden_states, p=config.hidden_dropout_prob, training = is_train)
    hidden_states = F.layer_norm(hidden_states + input_tensor, [config.hidden_size],
                              weight=fast_weights['bert.encoder.layer.'+layer_idx+'.attention.output.LayerNorm.weight'],
                              bias=fast_weights['bert.encoder.layer.'+layer_idx+'.attention.output.LayerNorm.bias'],
                              eps=config.layer_norm_eps)
    
    return hidden_states    


def functional_attention(fast_weights, config, layer_idx,
                         hidden_states, attention_mask=None, head_mask=None,
                         encoder_hidden_states=None, encoder_attention_mask=None,
                         is_train = True):
    
    self_outputs = functional_self_attention(fast_weights, config, layer_idx,
                                             hidden_states, attention_mask, head_mask, 
                                             encoder_hidden_states, encoder_attention_mask, is_train)
    
    attention_output = functional_out_attention(fast_weights, config, layer_idx,
                                                self_outputs, hidden_states, is_train)
    return attention_output

def functional_intermediate(fast_weights, config, layer_idx, hidden_states, is_train = True):
    weight_name = 'bert.encoder.layer.' + layer_idx + '.intermediate.dense.weight'
    bias_name   = 'bert.encoder.layer.' + layer_idx + '.intermediate.dense.bias'
    hidden_states = F.linear(hidden_states, fast_weights[weight_name], fast_weights[bias_name])
    hidden_states = gelu(hidden_states)
    
    return hidden_states


def functional_output(fast_weights, config, layer_idx, hidden_states, input_tensor, is_train = True):

    hidden_states = F.linear(hidden_states, 
                             fast_weights['bert.encoder.layer.'+layer_idx+'.output.dense.weight'], 
                             fast_weights['bert.encoder.layer.'+layer_idx+'.output.dense.bias'])
    
    hidden_states = F.dropout(hidden_states, p=config.hidden_dropout_prob, training = is_train)
    hidden_states = F.layer_norm(hidden_states + input_tensor, [config.hidden_size],
                              weight=fast_weights['bert.encoder.layer.'+layer_idx+'.output.LayerNorm.weight'],
                              bias=fast_weights['bert.encoder.layer.'+layer_idx+'.output.LayerNorm.bias'],
                              eps=config.layer_norm_eps)
    return hidden_states

def functional_layer(fast_weights, config, layer_idx, hidden_states, attention_mask,
                     head_mask, encoder_hidden_states, encoder_attention_mask, is_train = True):
    
    self_attention_outputs = functional_attention(fast_weights, config, layer_idx,
                                                  hidden_states, attention_mask, head_mask,
                                                  encoder_hidden_states, encoder_attention_mask,is_train)
    
    attention_output = self_attention_outputs
    intermediate_output = functional_intermediate(fast_weights, config, layer_idx, attention_output, is_train)
    layer_output = functional_output(fast_weights, config, layer_idx, 
                                     intermediate_output, attention_output, is_train)
    
    return layer_output
    

def functional_encoder(fast_weights, config , hidden_states, attention_mask,
                       head_mask, encoder_hidden_states, encoder_attention_mask, is_train = True):
    
    for i in range(0,config.num_hidden_layers):
        layer_outputs = functional_layer(fast_weights, config, str(i),
                                         hidden_states, attention_mask, head_mask[i], 
                                         encoder_hidden_states, encoder_attention_mask, is_train)
        hidden_states = layer_outputs
        
    outputs = hidden_states
    return outputs

if __name__ == '__main__':
    
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    fast_weights = OrderedDict(model.named_parameters())
    
    input_ids = torch.Tensor([[  101,  1303,  1110,  1199,  3087,  1106,  4035, 13775,   102],
                              [  101,   178,  1274,  1204,  1176,  1115,  4170,   182,   102]]).to(torch.long)
    token_type_ids = torch.Tensor([[0,  0,  0,  0,  0,  1,  1, 1, 1],
                                   [0,  0,  0,  0,  0,  1,  1, 1, 1]]).to(torch.long)
    attention_mask = torch.Tensor([[1,  1,  1,  1,  1,  1,  1, 1, 1],
                                   [1,  1,  1,  1,  1,  1,  1, 1, 1]]).to(torch.long)
    
    print(functional_bert(fast_weights, model.config, input_ids=input_ids, attention_mask=attention_mask, 
                    token_type_ids=token_type_ids,is_train = True))

(tensor([[[-0.1576,  0.3370, -0.1221,  ..., -0.3356,  0.3942,  0.4703],
         [ 0.8703,  0.3434,  0.3536,  ..., -0.3917,  0.7031,  1.0644],
         [ 0.1744,  1.0315,  0.2739,  ..., -0.1700,  0.6216, -0.0226],
         ...,
         [ 0.3732,  0.2495,  0.4426,  ..., -0.1940,  0.3933,  0.6892],
         [-0.0438,  0.3092,  0.0463,  ...,  0.1204,  0.3048, -0.0807],
         [ 0.5353,  0.2694, -0.4089,  ...,  0.5772, -0.6564, -0.3425]],

        [[-0.4248, -0.3055,  0.3123,  ...,  0.0534,  0.3531,  0.3712],
         [ 0.1217,  0.0371,  0.1795,  ..., -0.1514,  0.6819,  0.1585],
         [ 0.9209,  0.1281,  0.2641,  ..., -0.2881,  0.5027,  0.3020],
         ...,
         [ 0.1037, -0.6846,  0.0244,  ..., -0.3091,  0.2962,  0.5372],
         [-0.3636, -0.3467, -0.0437,  ...,  0.1067,  0.0879, -0.0273],
         [ 0.4633,  0.0210, -0.3303,  ...,  0.2166, -0.4741, -0.0168]]],
       grad_fn=<NativeLayerNormBackward0>),)
