<a href="https://colab.research.google.com/github/dbamman/nlp21/blob/main/HW6/HW_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SETUP**

In [None]:
!pip install transformers

In [5]:
import json
import torch
import torch.nn as nn
import transformers
import random
import numpy as np
from tqdm import tqdm
from transformers import BertTokenizer, BertModel


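#Sets random seeds for reproducibility
#(including Python's `random`, which read_data uses to sample non-argument spans)
seed=0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True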

In [None]:
print(torch.__version__)
print(transformers.__version__)

In [None]:
!wget https://raw.githubusercontent.com/dbamman/nlp21/main/HW6/train.txt
!wget https://raw.githubusercontent.com/dbamman/nlp21/main/HW6/dev.txt

# **IMPORTANT**: GPU is not enabled by default

The next cell prints the device the notebook is running on. If it prints `Running on cpu`, the GPU is not enabled and you must switch runtime environments:

Go to Runtime > Change runtime type > Hardware accelerator > GPU

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on {}".format(device))

# Tip: Indexing into `torch.Tensor`

In this section, we briefly walk through some examples of indexing into 2D and 3D tensors that will be useful for the homework that follows.


A [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html) is a multi-dimensional array that can have more than two dimensions. For example, you can create a 3D tensor (of size 2 x 2 x 2) like this:

In [None]:

T_data = [[[1., 2.], [3., 4.]],
          [[5., 6.], [7., 8.]]]
T = torch.tensor(T_data)
print(T.size())
print(T)

torch.Size([2, 2, 2])
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])


You can index into this tensor and get a 2D matrix: 


In [None]:
print(T[0].size())
print(T[0]) 

torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.]])


Here's an example of a 4 x 3 matrix (2D tensor):

In [None]:
mat = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
print(mat.size())
mat

torch.Size([4, 3])


tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

Here are several ways to index into that tensor:

In [None]:
#Case1: Select 1, 5, 9, 10 (the (row, column) pairs (0,0), (1,1), (2,2), (3,0))
print("Case1: " + str(mat[(0,1,2,3),(0,1,2,0)])) #() or [] both work

#Case2: Select the first row
print("Case2: " + str(mat[0,]))
#also mat[0]
#also mat[0,:]

#Case3: Select all entries of the second column
print("Case3: " + str(mat[:,1]))
#also mat[torch.arange(mat.size(0)), 1]
#also mat[(0,1,2,3), 1]

#Case4: Select the first three entries of the third column
print("Case4: " + str(mat[(0,1,2), 2]))

Case1: tensor([ 1,  5,  9, 10])
Case2: tensor([1, 2, 3])
Case3: tensor([ 2,  5,  8, 11])
Case4: tensor([3, 6, 9])
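Case 1 passes two index sequences at once. PyTorch pairs them element by element, so `mat[(0,1,2,3),(0,1,2,0)]` picks `mat[0,0]`, `mat[1,1]`, `mat[2,2]`, and `mat[3,0]`; it does not take a 4 x 4 sub-grid. A small sketch checking this against an explicit loop (the variable names are illustrative only):

```python
import torch

mat = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
rows = [0, 1, 2, 3]
cols = [0, 1, 2, 0]

# Advanced indexing with two sequences pairs them up: (rows[i], cols[i]).
paired = mat[rows, cols]

# The same selection written as an explicit loop over the index pairs.
looped = torch.stack([mat[r, c] for r, c in zip(rows, cols)])

print(paired)                       # tensor([ 1,  5,  9, 10])
print(torch.equal(paired, looped))  # True

# To get a sub-grid (every selected row crossed with every selected column),
# index one dimension at a time instead.
grid = mat[rows][:, [0, 2]]
print(grid.size())                  # torch.Size([4, 2])
```

This element-wise pairing is exactly what makes the batched selection below work: one index sequence says which sentence, the other says which token in that sentence.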


Now for the case most relevant to the homework: a three-dimensional tensor of size batch size x number of WordPiece tokens x 768 (the BERT hidden size). Each entry along the first dimension is a sentence, the second dimension indexes the WordPiece tokens within that sentence, and the third dimension holds the BERT embedding. Suppose we want to pick out a different token for each sentence: for example, the predicate might be at WordPiece position #3 in the first sentence and at position #1 in the second. The result should be a batch size x 768 matrix containing just the BERT embedding at each sentence's predicate position. Here's how to do that for a sample tensor of size 2 x 4 x 3, where the result is a 2 x 3 matrix:

In [None]:
mat = torch.tensor([
    [
        [1,2,3],[4,5,6],[7,8,9],[10,11,12]
    ],
    [
        [13,14,15],[16,17,18],[19,20,21],[22,23,24]
    ]
])

print(mat.size())
print(mat)

torch.Size([2, 4, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18],
         [19, 20, 21],
         [22, 23, 24]]])


In [None]:
# e.g., the predicate WP index is #3 in the first sentence in the batch, and #1 in the second sentence
indexes=[3,1]
selected=mat[torch.arange(mat.size(0)),indexes]
print(selected)
print(selected.size())

tensor([[10, 11, 12],
        [16, 17, 18]])
torch.Size([2, 3])
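The deliverable below needs three vectors per sentence (argument start, argument end, predicate start) rather than one. A minimal sketch on the same toy tensor applies the trick once per position and joins the results along the last dimension with `torch.cat`; the position values are made up for illustration.

```python
import torch

# Toy "hidden states": 2 sentences x 4 tokens x 3 dimensions.
mat = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]],
])

# Hypothetical positions, one per sentence, for each of three roles.
start_idx = torch.tensor([0, 2])
end_idx = torch.tensor([1, 3])
pred_idx = torch.tensor([3, 1])

batch = torch.arange(mat.size(0))  # tensor([0, 1]): one row index per sentence

start_vecs = mat[batch, start_idx]  # (2, 3)
end_vecs = mat[batch, end_idx]      # (2, 3)
pred_vecs = mat[batch, pred_idx]    # (2, 3)

# Concatenate along the feature dimension: (2, 3 + 3 + 3) = (2, 9).
combined = torch.cat([start_vecs, end_vecs, pred_vecs], dim=-1)
print(combined.size())  # torch.Size([2, 9])
print(combined)
```

Each row of `combined` is sentence-specific: row 0 holds tokens 0, 1, and 3 of the first sentence, and row 1 holds tokens 2, 3, and 1 of the second.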


# Deliverable 1: Predict the semantic role of arguments for predicate-argument pairs in a sentence

In this section, we will train a BERT-based classifier to assign proto-role labels (`ARG0`, `ARG1`, or `O`) to the argument in each predicate-argument pair.

The `BERTSRLClassifier` class is provided for you below, along with code to read in the data and train the model. This BERT-based classifier takes the words of a sentence together with the positions of an argument span (its beginning and end) and of the predicate, and classifies that argument, relative to that predicate, as `ARG0`, `ARG1`, or `O` (neither).

See the writeup for a full description of the parts of the model you should implement. To summarize, the `forward` function in the `BERTSRLClassifier` class concatenates the 768-dimensional BERT vectors that are indexed by:
1. the start WordPiece token position of the argument span,
2. the end WordPiece token position of the argument span,
3. and the start WordPiece token position of the predicate

and passes the concatenated 3 x 768 = 2304-dimensional vector through a linear transformation (`self.fc`) into the 3-dimensional output space (one score for each of the labels `ARG0`, `ARG1`, `O`). Your deliverable is to complete the indicated section of the `forward` function in the `BERTSRLClassifier` class: extract the BERT vectors at the positions described above from the final BERT layer, concatenate them, and pass them through the linear transformation.
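To make the shapes concrete, here is a sketch in which random numbers stand in for the concatenated BERT vectors (no BERT model is loaded, and the batch size of 4 is arbitrary). It feeds a batch of 3 x 768 vectors through an `nn.Linear(3*768, 3)` layer like `self.fc`, then scores the resulting logits with `nn.CrossEntropyLoss` against integer labels using the same ids as the `labels` dictionary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 768   # BERT-base hidden size
num_labels = 3      # ARG0, ARG1, O
batch_size = 4      # arbitrary, for illustration

# Random stand-in for the concatenated [start; end; predicate] vectors.
features = torch.randn(batch_size, 3 * hidden_size)

fc = nn.Linear(3 * hidden_size, num_labels)
logits = fc(features)
print(logits.size())  # torch.Size([4, 3])

# Gold labels use the same ids as the labels dict (ARG0=0, ARG1=1, O=2).
gold = torch.tensor([0, 1, 2, 0])
loss = nn.CrossEntropyLoss()(logits, gold)
print(loss.item())    # one scalar for the whole batch

# Predicted label per example: the highest-scoring column.
predictions = torch.argmax(logits, dim=1)
print(predictions.size())  # torch.Size([4])
```

`nn.CrossEntropyLoss` expects unnormalized scores of shape (batch, num_labels) and applies log-softmax internally, which is why `forward` can return the raw output of the linear layer.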

In [None]:
max_toks=56
tokenizer=BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False, do_basic_tokenize=False)

labels={"ARG0":0, "ARG1":1, "O":2}


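#return the WordPiece span (start, end) of the word at position `index`:
#start is inclusive, end is exclusive (one past the word's last WordPiece),
#and both are offset by 1 for the [CLS] token that get_batches prepends
def get_wp_position_for_token(words, index):
    cur = 1  # position 0 is reserved for [CLS]
    for idx, word in enumerate(words):
        target = tokenizer.tokenize(word)
        if idx == index:
            return cur, cur + len(target)
        cur += len(target)
    return None  # index is out of range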

def read_data(filename, max_toks=max_toks):

	x=[]
	y=[]
	m=[]

	max_num=5000
	idx=0

	with open(filename) as file:
		for line in file:
			cols=line.rstrip().split("\t")
			words=cols[0].split(" ")
			predicate=int(cols[1])
			arguments=json.loads(cols[2])
			candidates=json.loads(cols[3])
			
			nonargs=[]

			for cat, start, end in arguments:
				if cat == "ARG0" or cat == "ARG1":
					x.append(words)
					y.append(labels[cat])
					#store WordPiece positions for: start of span, end of span (exclusive), and start of predicate
					m.append((get_wp_position_for_token(words, start)[0], get_wp_position_for_token(words, end)[1], get_wp_position_for_token(words, predicate)[0]))

					idx+=1

				else:
					nonargs.append((cat, start, end))


			# select random non-ARG0 or ARG1
			cat, start, end=random.choice(nonargs)
			x.append(words)
			y.append(labels["O"])
			#store WordPiece positions for: start of span, end of span (exclusive), and start of predicate
			m.append((get_wp_position_for_token(words, start)[0], get_wp_position_for_token(words, end)[1], get_wp_position_for_token(words, predicate)[0]))

			idx+=1

			if idx >= max_num:
				break

	return x, y, m
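Because `get_batches` prepends `[CLS]`, WordPiece positions start at 1, and a word split into several WordPieces occupies several positions. The sketch below reproduces the offset arithmetic of `get_wp_position_for_token` with a hypothetical toy tokenizer (`toy_tokenize`, which just cuts words into chunks of at most 3 characters; it is not BERT's WordPiece algorithm) so that it runs without downloading a model. The helper `wp_span` is likewise hypothetical.

```python
def toy_tokenize(word, chunk=3):
    """Hypothetical stand-in for tokenizer.tokenize: split into 3-char pieces."""
    return [word[i:i + chunk] for i in range(0, len(word), chunk)]


def wp_span(words, index, tokenize=toy_tokenize):
    """Same offset logic as get_wp_position_for_token, with a pluggable tokenizer.

    Returns (start, end): start inclusive, end exclusive. Position 0 is
    reserved for [CLS], so the first word starts at position 1.
    """
    cur = 1
    for idx, word in enumerate(words):
        pieces = tokenize(word)
        if idx == index:
            return cur, cur + len(pieces)
        cur += len(pieces)
    return None


words = ["I", "ate", "the", "cheesecake"]
# toy pieces: ["I"], ["ate"], ["the"], ["che", "ese", "cak", "e"]
for i, w in enumerate(words):
    print(w, wp_span(words, i))
# I (1, 2)
# ate (2, 3)
# the (3, 4)
# cheesecake (4, 8)
```

For the span "the cheesecake" (words 2 to 3), `read_data` would store start = 3 and end = 8. Since the end is exclusive, position 8 is the token right after the span, which here is where `[SEP]` sits.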

In [None]:
class BERTSRLClassifier(nn.Module):

    def __init__(self):
        super().__init__()
            
        self.tokenizer = tokenizer
        self.bert = BertModel.from_pretrained("bert-base-cased")
        self.num_labels = 3    
        self.fc = nn.Linear(3*768, self.num_labels)

    def get_batches(self, all_x, all_y, all_m, batch_size=32, max_toks=max_toks):
        """
        Get batches for input x, y, and m, with data tokenized according to the
        #BERT tokenizer (and limited to a maximum number of WordPiece tokens)
        """

        batches_x=[]
        batches_y=[]
        batches_m=[]

        #Each input sentence is encoded as [CLS] <WordPiece tokens> [SEP]
        for i in range(0, len(all_x), batch_size):

            current_batch=[]

            xb=all_x[i:i+batch_size]

            current_batch_input_ids=[]
            current_batch_attention_mask=[]

            for s, sent in enumerate(xb):

                sent_wp_tokens=[self.tokenizer.convert_tokens_to_ids("[CLS]")]
                attention_mask=[1]

                for word in sent:
                    toks = self.tokenizer.tokenize(word)
                    toks = self.tokenizer.convert_tokens_to_ids(toks)
                    sent_wp_tokens.extend(toks)
                    attention_mask.extend([1]*len(toks))
                
                sent_wp_tokens.append(self.tokenizer.convert_tokens_to_ids("[SEP]"))
                attention_mask.append(1)

                current_batch_input_ids.append(sent_wp_tokens)
                current_batch_attention_mask.append(attention_mask)

            max_len = max([len(s) for s in current_batch_input_ids])

            for j, sent in enumerate(current_batch_input_ids):
                for k in range(len(current_batch_input_ids[j]), max_len):
                    current_batch_input_ids[j].append(0)
                    current_batch_attention_mask[j].append(0)

            batch_y=all_y[i:i+batch_size]
            batch_m=all_m[i:i+batch_size]

            batches_x.append((torch.LongTensor(current_batch_input_ids).to(device), torch.LongTensor(current_batch_attention_mask).to(device)))
            batches_y.append(torch.LongTensor(batch_y).to(device))
            batches_m.append(torch.LongTensor(batch_m).to(device))
                
        return batches_x, batches_y, batches_m
      
	
	

    def forward(self, batch_x, batch_m): 

            bert_output = self.bert(input_ids=batch_x[0],
                                                attention_mask=batch_x[1],
                                                output_hidden_states=True, return_dict=True)

            bert_hidden_states = bert_output['hidden_states']

            out = bert_hidden_states[-1]

            start_span_indexes=batch_m[:,0]
            end_span_wp_indexes=batch_m[:,1]
            predicate_wp_indexes=batch_m[:,2]

            '''
            Extract the representation of the WP token in the start and end WP position of each argument 
            and the start WP position of the predicate from the last layer output

            Then, concatenate the vectors and pass them through a linear transformation
            '''
            # YOUR CODE STARTS HERE

            ## YOUR CODE ENDS HERE

            return out.squeeze()

    def evaluate(self, batch_x, batch_y, batch_m):
			
        self.eval()
        corr = 0.
        total = 0.

        with torch.no_grad():

                for x, y, m in zip(batch_x, batch_y, batch_m):
                    y_preds = self.forward(x, m)
                    for idx, y_pred in enumerate(y_preds):
                        prediction=torch.argmax(y_pred)
                        if prediction == y[idx]:
                            corr += 1.
                        total+=1                          
        return corr/total


In [None]:
classifier=BERTSRLClassifier()
classifier.to(device)

trainData='train.txt'
devData='dev.txt'

x,y,m=read_data(trainData)
train_x, train_y, train_m=classifier.get_batches(x,y,m)

In [None]:
x,y,m=read_data(devData)
dev_x, dev_y, dev_m=classifier.get_batches(x,y,m)


optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-5)
cross_entropy=nn.CrossEntropyLoss()

num_epochs=5

#accuracy before training
accuracy=classifier.evaluate(dev_x, dev_y, dev_m)
print("Accuracy: %.3f" % accuracy)
	
for epoch in range(num_epochs):
	classifier.train()

	# Train
	for x, y, m in zip(train_x, train_y, train_m):
		y_pred = classifier.forward(x, m)
		loss = cross_entropy(y_pred.view(-1, classifier.num_labels), y.view(-1))
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	# Evaluate
	accuracy=classifier.evaluate(dev_x, dev_y, dev_m)
	print("Accuracy: %.3f" % accuracy)




# Deliverable 2: Find `ARG0` from a list of all non-terminal phrases in a parse tree 

Now, you will be given a sentence with all of its non-terminal phrases (candidates), where each candidate is given by `[syntactic category, start token position, end token position]`. For simplicity, the number of candidates for the given sentence doesn't exceed the default batch size of `get_batches` (32), so all candidates fit in a single batch.


In this section, you will identify the most likely `ARG0` among all the candidates using our trained classifier, `BERTSRLClassifier`. To do this, compute an `ARG0` score for every candidate with the classifier and select the candidate with the highest score (the highest likelihood of being `ARG0`). The basic setup has been provided for you, using functions from the previous section**\***. Your task is to finish the `predict_arg0` function so that it returns the **start and end position** of the `ARG0` predicted by your model. For example, if the sentence is "I ate the cake" (as given in the HW writeup), your function should return [0,0].

\* *Note that we don't have y-labels for the list of candidates in this test case. In the provided code, we create a list of labels `test_y` and arbitrarily fill it with `O`, because `get_batches` requires a label list. The values in this list don't matter, since you won't need the resulting labels (`t_y`) to compute your result.*
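Once the classifier has produced one row of label scores per candidate, choosing the `ARG0` comes down to normalizing each row into probabilities, taking the `ARG0` column (index 0 in `labels`), and picking the candidate with the largest value. The sketch below uses made-up logits for three made-up candidates instead of real model output; it only illustrates the selection step.

```python
import torch

labels = {"ARG0": 0, "ARG1": 1, "O": 2}

# Made-up candidates in the [category, start, end] format used below.
candidates = [["NP", 0, 0], ["VP", 1, 3], ["NP", 2, 3]]

# Made-up classifier scores: one row per candidate, one column per label.
logits = torch.tensor([
    [2.0, 0.5, 0.1],
    [0.2, 0.3, 1.5],
    [0.4, 1.8, 0.2],
])

# Softmax over the label dimension turns each row into P(label | candidate).
probs = torch.softmax(logits, dim=-1)
arg0_probs = probs[:, labels["ARG0"]]

best = int(torch.argmax(arg0_probs))
_, start, end = candidates[best]
print([start, end])  # [0, 0]
```

Ranking by the `ARG0` probability rather than the raw `ARG0` logit is a deliberate choice: softmax normalizes each candidate's scores separately, so the two rankings can differ.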



In [12]:
#This is the sentence you will be working with
sent = "Also , can animals remember images on TV like us , humans ?"
verb = 4 #predicate position
candidate_list = '''[
                    ["RB", 0, 0], 
                    ["ADVP", 0, 0],
                    [",", 1, 1],
                    ["MD", 2, 2], 
                    ["NNS", 3, 3], 
                    ["NP", 3, 3], 
                    ["VB", 4, 4], 
                    ["NNS", 5, 5], 
                    ["NP", 5, 5], 
                    ["IN", 6, 6], 
                    ["NN", 7, 7], 
                    ["NP", 7, 7], 
                    ["PP", 6, 7], 
                    ["NP", 5, 7], 
                    ["IN", 8, 8], 
                    ["PRP", 9, 9], 
                    ["NP", 9, 9], 
                    [",", 10, 10], 
                    ["NNS", 11, 11],
                    ["NP", 11, 11],
                    ["NP", 9, 11], 
                    ["PP", 8, 11],
                    ["VP", 4, 11],
                    [".", 12, 12], 
                    ["SQ", 0, 12], 
                    ["TOP", 0, 12]
                    ] '''

In [None]:
#read in the sentence information
#we don't have labels (y) for this task, but will create an arbitrary list in order to run get_batches()
test_x = []
test_m = []
test_y = [] 

words = sent.split(" ")
predicate = int(verb)
candidates = json.loads(candidate_list)


for synt, start, end in candidates: 
  test_x.append(words)
  test_y.append(labels["O"]) #fill in with "O" (can be ARG0 or ARG1, doesn't matter)
  #store position of WordPiece tokens for start of span, end of span, and start of predicate) 
  test_m.append((get_wp_position_for_token(words, start)[0], get_wp_position_for_token(words, end)[1], get_wp_position_for_token(words, predicate)[0]))

In [None]:
t_x, t_y, t_m = classifier.get_batches(test_x,test_y,test_m) 

In [None]:
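def predict_arg0(batch_x, batch_m, cand_list):
    """
    Return [start, end]: the start and end token positions of the candidate
    in cand_list that the classifier predicts is most likely to be ARG0.
    """
    positions = None  # placeholder; set this to [start, end] in your code
    with torch.no_grad():
        #YOUR CODE STARTS HERE
        pass

        ##YOUR CODE ENDS HERE

    return positions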

In [None]:
#Run this cell to print your prediction
#The printed output should look like [a, b]
predict_arg0(t_x, t_m, candidates)