### Imports

In [52]:
# !pip install transformers

In [53]:
from transformers import AutoTokenizer, BertModel, GPT2LMHeadModel, GPT2Tokenizer
import torch.optim as optim

import torch
import torch.nn.functional as F
import math
import time
import sys
import json
import numpy as np


### Train Loop

In [54]:
def train_loop(model, linear, optimizer, tokenizer, train, num_choices, epochs):
    """Fine-tune ``model`` on multiple-choice observations, one question per step.

    Every observation is a sequence of ``num_choices`` ``[text, label]`` pairs,
    where ``label`` is 1 for the correct choice and 0 otherwise. Each choice is
    encoded separately, the hidden state at sequence position 0 is projected by
    ``linear`` to two logits, and the loss is the summed negative
    log-likelihood of every choice's own label.

    Args:
        model: ``nn.Module`` encoder whose output exposes ``last_hidden_state``.
        linear: Projection tensor multiplied onto the position-0 hidden state;
            its second dimension must be 2. Its device decides where the batch
            is placed.
        optimizer: Optimizer stepped once per observation.
        tokenizer: Callable returning a mapping of tensors that is unpacked
            into ``model``.
        train: Sequence of observations as described above.
        num_choices: Number of answer choices read from each observation.
        epochs: Number of full passes over ``train``.

    Returns:
        None. Progress and the per-epoch average loss are printed.
    """
    # Follow the projection's device instead of relying on a notebook-level
    # global that may not exist yet when this cell is run.
    device = linear.device
    train_len = len(train)

    for epoch in range(epochs):
        print(f"Epoch {epoch}")

        # Pretrained models are handed back in eval mode; switch dropout on
        # again for fine-tuning (and after any evaluation between epochs).
        model.train()
        total_loss = 0.0

        for train_i, observation in enumerate(train):
            contexts = [observation[choice_i][0] for choice_i in range(num_choices)]
            labels = torch.tensor(
                [observation[choice_i][1] for choice_i in range(num_choices)],
                dtype=torch.long,
                device=device,
            )

            inputs = tokenizer(contexts, max_length=256, padding="max_length", truncation=True, return_tensors="pt")
            inputs = {key: value.to(device) for key, value in inputs.items()}

            optimizer.zero_grad()
            hidden = model(**inputs)

            logits = torch.matmul(hidden.last_hidden_state[:, 0, :], linear)
            # Same quantity as -sum(log(softmax(logits)[i, label_i])), but
            # computed via log-softmax so a vanishing probability cannot
            # produce log(0) = -inf.
            loss = F.cross_entropy(logits, labels, reduction="sum")

            total_loss += loss.item()

            loss.backward()
            optimizer.step()

            if train_i % 1000 == 0:
                print(train_i, "/", train_len)

        # An empty training set has no meaningful average; report NaN rather
        # than dividing by zero.
        average_loss = total_loss / train_len if train_len else float("nan")
        print(f"Average Loss: {average_loss}")
            


### Test Loop

In [55]:
def test_loop(model, linear, tokenizer, test, num_choices):
    test_len = len(test)
    running_accuracy = 0

    for test_i in range(test_len):
        observation = test[test_i]
        contexts = []
        labels = []
        mask = torch.zeros((4, 2), dtype=float).to(device) ##### if code doesnt work, change to numpy

        for choice_i in range(num_choices):
            context = observation[choice_i][0]
            label = observation[choice_i][1]
            contexts.append(context)
            labels.append(label)
            mask[choice_i][label] = 1
        
        inputs = tokenizer(contexts, max_length=256, padding="max_length", truncation=True, return_tensors="pt")
        inputs = inputs.to(device)

        hidden = model(**inputs)

        logits = torch.matmul(hidden.last_hidden_state[:, 0, :], linear)
        probs = F.softmax(logits, dim=1)[:, 1]

        print(probs)

        pred = torch.argmax(probs)
        real = torch.argmax(labels)
        if pred == real:
            running_accuracy += 1
        
    average_accuracy = running_accuracy / test_len
    print(f"Average Accuracy: {average_accuracy}")


### Find Device

In [56]:
# Use the CUDA GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Main

In [57]:
def _load_observations(file_name, answers):
    """Read one JSONL split and turn every question into an observation.

    Each non-blank line is parsed as a JSON record with the keys ``fact1``,
    ``answerKey`` and ``question`` (holding ``stem`` and ``choices``). One
    ``[text, label]`` pair is built per answer letter in ``answers``; ``label``
    is 1 for the choice named by ``answerKey`` and 0 otherwise.

    Args:
        file_name: Path of the JSONL file to read.
        answers: Answer letters in choice order, e.g. ``['A', 'B', 'C', 'D']``.

    Returns:
        A list with one observation (list of ``[text, label]`` pairs) per record.

    Raises:
        ValueError: If a record's ``answerKey`` is not in ``answers``.
    """
    observations = []
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(file_name, encoding='utf-8') as json_file:
        for json_str in json_file:
            if not json_str.strip():
                continue  # tolerate blank / trailing lines
            result = json.loads(json_str)

            base = result['fact1'] + ' [SEP] ' + result['question']['stem']
            ans = answers.index(result['answerKey'])
            choices = result['question']['choices']

            obs = []
            for j in range(len(answers)):
                # Separate stem and choice with a space so the last word of the
                # stem and the first word of the choice are not fused together.
                # NOTE(review): the tokenizer adds its own special tokens by
                # default, so the trailing literal ' [SEP]' probably yields a
                # doubled separator — confirm whether that is intended.
                text = base + ' ' + choices[j]['text'] + ' [SEP]'
                label = 1 if j == ans else 0
                obs.append([text, label])
            observations.append(obs)
    return observations


def main():
    """Load the three data splits, fine-tune BERT with a 2-way head, and evaluate.

    Reads ``train_complete.jsonl``, ``dev_complete.jsonl`` and
    ``test_complete.jsonl`` from the working directory, fine-tunes
    ``bert-base-uncased`` together with a projection tensor via
    ``train_loop``, then reports accuracy on the dev and test splits via
    ``test_loop``.
    """
    torch.manual_seed(0)
    answers = ['A','B','C','D']
    num_choices = len(answers)
    epochs = 5

    train = _load_observations('train_complete.jsonl', answers)
    valid = _load_observations('dev_complete.jsonl', answers)
    test = _load_observations('test_complete.jsonl', answers)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    model = model.to(device)

    # The head must be a leaf tensor that requires grad AND be handed to the
    # optimizer; otherwise it stays at its random initial values forever.
    # Sampling on the CPU first keeps the seeded initial values independent of
    # the device; requires_grad_() is applied after the move so the tensor on
    # the target device is the leaf.
    linear = torch.rand(model.config.hidden_size, 2).to(device).requires_grad_()
    optimizer = optim.Adam(list(model.parameters()) + [linear], lr=3e-5)

    train_loop(model, linear, optimizer, tokenizer, train, num_choices, epochs)

    # Disable dropout for evaluation.
    model.eval()

    test_loop(model, linear, tokenizer, valid, num_choices)
    test_loop(model, linear, tokenizer, test, num_choices)

In [58]:
# Run the pipeline defined in main(): load the data splits, fine-tune, then evaluate.
main()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


0  4957
1  4957
2  4957
3  4957
4  4957
5  4957
6  4957
7  4957
8  4957
9  4957
10  4957
11  4957
12  4957
13  4957
14  4957
15  4957
16  4957
17  4957
18  4957
19  4957
20  4957
21  4957
22  4957
23  4957
24  4957
25  4957
26  4957
27  4957
28  4957
29  4957
30  4957
31  4957
32  4957
33  4957
34  4957
35  4957
36  4957
37  4957
38  4957
39  4957
40  4957
41  4957
42  4957
43  4957
44  4957
45  4957
46  4957
47  4957
48  4957
49  4957
50  4957
51  4957
52  4957
53  4957
54  4957
55  4957
56  4957
57  4957
58  4957
59  4957
60  4957
61  4957
62  4957
63  4957
64  4957
65  4957
66  4957
67  4957
68  4957
69  4957
70  4957
71  4957
72  4957
73  4957
74  4957
75  4957
76  4957
77  4957
78  4957
79  4957
80  4957
81  4957
82  4957
83  4957
84  4957
85  4957
86  4957
87  4957
88  4957
89  4957
90  4957
91  4957
92  4957
93  4957
94  4957
95  4957
96  4957
97  4957
98  4957
99  4957
100  4957
101  4957
102  4957
103  4957
104  4957
105  4957
106  4957
107  4957
108  4957
109  4957
110  4957


KeyboardInterrupt: ignored