In [86]:
import torch
from transformers import BertTokenizer, BertModel
import sys
sys.path.append('..')
from probing.utils import get_sentence_repr, get_model_and_tokenizer, get_pos_data

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("device:", device)

device: cuda


# Get data for part-of-speech tagging

In [87]:
train_sentences, train_labels, test_sentences, test_labels, _, _, label2index = get_pos_data("../probing", frac=0.01)
num_labels = len(label2index)
print("Training sentences:", len(train_sentences), "Test sentences:", len(test_sentences))
print("Unique labels:", num_labels)


Training sentences: 125 Test sentences: 21
Unique labels: 15


# Set up model

In [97]:
class Classifier(torch.nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(Classifier, self).__init__()
        
        self.linear = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, input):
        output = self.linear(input)
        return output
    

class NonlinearClassifier(torch.nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(NonlinearClassifier, self).__init__()
        
        self.input2hidden = torch.nn.Linear(input_dim, input_dim)
        self.hidden2output = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, input):
        hidden = self.input2hidden(input)
        output = self.hidden2output(hidden)
        return output
    
    
def build_classifier(emb_dim, num_labels, nonlinear=False):

    if nonlinear:
        classifier = NonlinearClassifier(emb_dim, num_labels)
    else:
        classifier = Classifier(emb_dim, num_labels)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters())

    return classifier, criterion, optimizer


model_name = 'bert-base-cased'
# get model and tokenizer from Transformers
model, tokenizer, sep, emb_dim = get_model_and_tokenizer(model_name, device)
# build classifier
classifier, criterion, optimizer = build_classifier(emb_dim, num_labels)

In [89]:
print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [90]:
print(classifier)

Classifier(
  (linear): Linear(in_features=768, out_features=15, bias=True)
)


# Train 

In [91]:
def train(num_epochs, train_sentences, train_labels, 
          model, tokenizer, sep, model_name, device, 
          classifier, criterion, optimizer, layer=-1):
    
    num_total = sum([len(l) for l in train_labels])
    for i in range(num_epochs):
        total_loss = 0.
        num_correct = 0.
        for sentence, labels in zip(train_sentences, train_labels):
            optimizer.zero_grad()

            sentence_repr = get_sentence_repr(sentence, model, tokenizer, sep, model_name, device)
            # take layer representations
            sentence_repr = sentence_repr[layer]
            loss = 0
            for word_repr, label in zip(sentence_repr, labels):
                out = classifier(word_repr)
                # we'll just just a batch of size 1 for simplicity 
                out = torch.unsqueeze(out, 0)
                pred = out.max(1)[1]
                if pred == label.item():
                    num_correct += 1
                loss += criterion(out, label.unsqueeze(0))
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
#         print('Training epoch: {}, loss: {}, accuracy: {}'.format(i, total_loss/num_total, num_correct/num_total))
    return total_loss/num_total, num_correct/num_total


# Evaluate

In [92]:
def evaluate(test_sentences, test_labels, 
             model, tokenizer, sep, model_name, device, 
             classifier, criterion, layer=-1):
    
    num_correct = 0.
    num_total = sum([len(l) for l in test_labels])
    total_loss = 0.
    with torch.no_grad():
        for sentence, labels in zip(test_sentences, test_labels):
            sentence_repr = get_sentence_repr(sentence, model, tokenizer, sep, model_name, device)
            sentence_repr = sentence_repr[layer]
            for word_repr, label in zip(sentence_repr, labels):
                out = classifier(word_repr)
                out = torch.unsqueeze(out, 0)
                pred = out.max(1)[1]
                if pred == label:
                    num_correct += 1
                total_loss += criterion(out, label.unsqueeze(0))

#     print('Testing loss: {}, accuracy: {}'.format(total_loss/num_total, num_correct/num_total))
    return total_loss/num_total, num_correct/num_total

# Experiment 1: Evaluate representation for POS quality

In [93]:
train_loss, train_accuracy = train(2, train_sentences, train_labels, 
          model, tokenizer, sep, model_name, device, 
          classifier, criterion, optimizer)
test_loss, test_accuracy = evaluate(test_sentences, test_labels, 
         model, tokenizer, sep, model_name, device, 
         classifier, criterion)
print("Train accuracy: {}, Test accuracy: {}".format(train_accuracy, test_accuracy))

Train accuracy: 0.738976377952756, Test accuracy: 0.8425925925925926


# Experiment 2: Compare representation quality across layers

In [94]:
num_layers = 12
for l in range(num_layers):
    classifier, criterion, optimizer = build_classifier(emb_dim, num_labels)
    train_loss, train_accuracy = train(2, train_sentences, train_labels, 
          model, tokenizer, sep, model_name, device, 
          classifier, criterion, optimizer, layer=l)
    test_loss, test_accuracy = evaluate(test_sentences, test_labels, 
         model, tokenizer, sep, model_name, device, 
         classifier, criterion, layer=l)
    print("layer: {}, test accuracy: {}".format(l, test_accuracy))

layer: 0, test accuracy: 0.8726851851851852
layer: 1, test accuracy: 0.9027777777777778
layer: 2, test accuracy: 0.9583333333333334
layer: 3, test accuracy: 0.9467592592592593
layer: 4, test accuracy: 0.9351851851851852
layer: 5, test accuracy: 0.9375
layer: 6, test accuracy: 0.9398148148148148
layer: 7, test accuracy: 0.9375
layer: 8, test accuracy: 0.9282407407407407
layer: 9, test accuracy: 0.9143518518518519
layer: 10, test accuracy: 0.9074074074074074
layer: 11, test accuracy: 0.9050925925925926


# Experiment 3: Non-linear classifier

In [None]:
num_layers = 12
for l in range(num_layers):
    classifier, criterion, optimizer = build_classifier(emb_dim, num_labels, nonlinear=True)
    train_loss, train_accuracy = train(2, train_sentences, train_labels, 
          model, tokenizer, sep, model_name, device, 
          classifier, criterion, optimizer, layer=l)
    test_loss, test_accuracy = evaluate(test_sentences, test_labels, 
         model, tokenizer, sep, model_name, device, 
         classifier, criterion, layer=l)
    print("layer: {}, test accuracy: {}".format(l, test_accuracy))

layer: 0, test accuracy: 0.9004629629629629
layer: 1, test accuracy: 0.8518518518518519
layer: 2, test accuracy: 0.8935185185185185
layer: 3, test accuracy: 0.9282407407407407
layer: 4, test accuracy: 0.9513888888888888
layer: 5, test accuracy: 0.9490740740740741
layer: 6, test accuracy: 0.9513888888888888
layer: 7, test accuracy: 0.9699074074074074
layer: 8, test accuracy: 0.9699074074074074
layer: 9, test accuracy: 0.9722222222222222
layer: 10, test accuracy: 0.9328703703703703
