In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sentiment_data import read_sentiment_examples, read_word_embeddings
from utils import *
from transformer import EncoderModel
import json
import nltk
import random
from tqdm import tqdm
import os

from torch.nn.utils.rnn import pad_sequence

In [17]:
# ---- Run configuration -------------------------------------------------------
seed = 42

# Apply the seed. It used to be assigned but never used, so the random.shuffle
# based splits and the model's initial weights differed on every run.
random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 16        # examples per DataLoader batch
block_size = 32        # width collate_batch truncates / pads every sequence to
learning_rate = 0.005  # step size handed to Adam
n_embd = 64            # forwarded to EncoderModel together with n_head / n_layer
n_head = 2
n_layer = 4

# Not referenced by any cell visible in this notebook.
eval_interval = 100
max_iters = 500
eval_iters = 200

n_input = 64           # not referenced by any cell visible in this notebook
n_hidden = 50          # forwarded to EncoderModel
n_output = 3           # forwarded to EncoderModel; the labels built below take the values 0, 1, 2
epochs_CLS = 50        # number of passes over the training loader
epochs_CLS = 50 

In [4]:
def collate_batch(batch):
    """Merge (sequence, label) pairs into one padded tensor and one label tensor.

    The sequences are zero-padded to a common length, cut to at most the global
    ``block_size`` along dimension 1, and then zero-extended on the right of the
    last dimension by however much dimension 1 falls short of ``block_size``.
    Labels are combined with ``torch.stack``.

    NOTE(review): this presumably receives 1-D token-id tensors, in which case
    the result is (batch, block_size) - confirm against the datasets in use.
    """
    sequences, targets = zip(*batch)
    tokens = pad_sequence(sequences, batch_first=True, padding_value=0)[:, :block_size]
    # After the cut, dimension 1 is never wider than block_size; max() only guards the 0 case.
    shortfall = max(0, block_size - tokens.shape[1])
    tokens = torch.nn.functional.pad(tokens, (0, shortfall), "constant", 0)
    return tokens, torch.stack(targets)

In [5]:
# Load StereoSet and flatten it into two parallel lists: text and integer label.
with open('dev.json') as fp:
    stereoSet = json.load(fp)["data"]

# Integer class id assigned to each gold label.
label_ids = {"unrelated": 0, "stereotype": 1, "anti-stereotype": 2}

all_example = []
all_label = []

inter_set = stereoSet["intersentence"]
intra_set = stereoSet["intrasentence"]

# Both subsets share the same record layout, so one loop handles them in the
# original order (intersentence first, then intrasentence).
for example_set in (inter_set, intra_set):
    for set_example in example_set:
        context = set_example["context"]
        for sen in set_example["sentences"]:
            gold_label = sen["gold_label"]
            if gold_label not in label_ids:
                # Previously the text was appended without a label here, which
                # shifted every later (example, label) pair out of alignment.
                continue
            all_example.append(context + " " + sen["sentence"])
            all_label.append(label_ids[gold_label])

# The two lists are indexed in lock-step by the split cell below.
assert len(all_example) == len(all_label)

print(len(all_example))
print(len(all_label))

12687
12687


In [6]:
import nltk

# Sanity check: print the token list nltk.word_tokenize produces for one sample sentence.
sentence = "Our new boss is Russian. He has a terrible temper and drinks vodka on the job."
words = nltk.word_tokenize(sentence)
print(words)

['Our', 'new', 'boss', 'is', 'Russian', '.', 'He', 'has', 'a', 'terrible', 'temper', 'and', 'drinks', 'vodka', 'on', 'the', 'job', '.']


In [7]:
class StereoSet(Dataset):
    """Dataset of tokenised sentences with integer labels, stored at a fixed length.

    Each sentence is split with ``nltk.word_tokenize`` and every token is mapped
    through ``indexer.index_of``. Tokens for which the indexer returns ``-1`` are
    replaced by the id of ``"UNK"``. Each id list is then cut to ``max_len``
    entries or right-filled with the id of ``"PAD"`` up to ``max_len``.
    """

    def __init__(self, sentences, labels, indexer, max_len=64):
        """Tokenise, index and pad/truncate every sentence up front.

        Args:
            sentences: sequence of raw sentence strings.
            labels: sequence of integer labels, parallel to ``sentences``.
            indexer: object exposing ``index_of(word)``, which returns ``-1``
                for words it does not know.
            max_len: number of token ids kept per sentence (default 64, the
                value that used to be hard-coded).
        """
        self.sentences = sentences
        self.labels = labels
        self.indexer = indexer
        self.max_len = max_len

        # These two ids never change, so look them up once rather than per token.
        unk_idx = self.indexer.index_of("UNK")
        pad_idx = self.indexer.index_of("PAD")

        self.indices = []
        for sent in self.sentences:
            sent_indices = []
            for word in nltk.word_tokenize(sent):
                word_idx = self.indexer.index_of(word)
                sent_indices.append(word_idx if word_idx != -1 else unk_idx)
            # Truncate first; the extend is then a no-op for full-length rows.
            sent_indices = sent_indices[:max_len]
            sent_indices.extend([pad_idx] * (max_len - len(sent_indices)))
            self.indices.append(sent_indices)

        self.indices = torch.tensor(self.indices, dtype=torch.long)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` tensors for position ``idx``."""
        return self.indices[idx], self.labels[idx]

In [8]:
# Shuffle the example positions, then carve them into train (70%),
# validation (15%) and test (whatever remains).
indices = list(range(len(all_example)))
random.shuffle(indices)

train_split = int(len(indices) * 0.7)
val_split = int(len(indices) * 0.15)

train_idx, val_idx, test_idx = (
    indices[:train_split],
    indices[train_split:train_split + val_split],
    indices[train_split + val_split:],
)

# Materialise the text and label lists of each partition, in shuffled order.
train_sentences, val_sentences, test_sentences = (
    [all_example[i] for i in part] for part in (train_idx, val_idx, test_idx)
)
train_labels, val_labels, test_labels = (
    [all_label[i] for i in part] for part in (train_idx, val_idx, test_idx)
)

In [9]:
# Spot-check one training pair: the sentence text, then its integer label.
print(train_sentences[117])
print(train_labels[117])

My sister is a brat. She wants everything, and gets everything.
1


In [12]:
# Build the vocabulary: the special tokens first, then every token in the
# training sentences, then every word listed in the GloVe file.
# NOTE(review): collate_batch pads with 0, so "PAD" is presumably expected to
# receive id 0 here - confirm against Indexer.
word_indexer = Indexer()
word_indexer.add_and_get_index("PAD")
word_indexer.add_and_get_index("UNK")

for sentence in train_sentences:
    for word in nltk.word_tokenize(sentence):
        word_indexer.add_and_get_index(word)

# `with` closes the file even if a line fails to process.
with open("glove.6B.300d-relativized.txt") as f:
    for line in f:
        if line.strip() != "":
            # The word is everything before the first space. The old
            # `line[:line.find(' ')]` chopped the final character whenever a
            # line had no space and no trailing newline.
            word = line.split(" ", 1)[0].rstrip("\n")
            word_indexer.add_and_get_index(word)

vocab_size = len(word_indexer)
print(vocab_size)

21941


In [11]:
# Index each split with the shared vocabulary.
train_dataset = StereoSet(train_sentences, train_labels, word_indexer)
val_dataset = StereoSet(val_sentences, val_labels, word_indexer)
test_dataset = StereoSet(test_sentences, test_labels, word_indexer)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_batch,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_batch,
)

In [12]:
# Build the StereoSet classifier and place it on the target device *before*
# creating the optimizer, so the optimizer is constructed over parameters that
# already live where they will be trained.
encoder = EncoderModel(vocab_size, n_embd, block_size, n_head, n_layer, n_hidden, n_output, device)
encoder.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)

# Kept as the cell's last expression so the cell still displays the loss module.
loss_fn.to(device)

CrossEntropyLoss()

In [None]:
# Train the StereoSet classifier for `epochs_CLS` epochs and evaluate on the
# validation loader after every epoch.
for epoch in range(epochs_CLS):
    # ---- training pass ----
    encoder.train()
    size, num_batches = len(train_loader.dataset), len(train_loader)
    train_loss = correct = 0

    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)

        pred = encoder(xb)
        loss = loss_fn(pred, yb)

        # Running totals: summed batch loss and number of correct argmax predictions.
        train_loss += loss.item()
        correct += (pred.argmax(dim=1) == yb).float().sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    average_train_loss = train_loss / num_batches
    accuracy = correct / size

    # ---- validation pass (gradients disabled) ----
    encoder.eval()
    with torch.no_grad():
        size, num_batches = len(val_loader.dataset), len(val_loader)
        eval_loss = test_correct = 0
        for xt, yt in val_loader:
            xt = xt.to(device)
            yt = yt.to(device)
            pred = encoder(xt)
            eval_loss += loss_fn(pred, yt).item()
            test_correct += (pred.argmax(dim=1) == yt).float().sum().item()

        average_eval_loss = eval_loss / num_batches
        test_accuracy = test_correct / size

    # Report the first epoch and then every fifth one.
    if epoch == 0 or epoch % 5 == 4:
        print(f'Epoch #{epoch + 1}: train loss {average_train_loss:.3f}, train accuracy {accuracy:.3f}, val loss {average_eval_loss:.3f}, val accuracy {test_accuracy:.3f}')

In [None]:
# Final evaluation of the StereoSet classifier on the held-out test loader.
encoder.eval()
with torch.no_grad():
    size, num_batches = len(test_loader.dataset), len(test_loader)
    eval_loss = test_correct = 0

    for xt, yt in test_loader:
        xt = xt.to(device)
        yt = yt.to(device)
        pred = encoder(xt)
        loss = loss_fn(pred, yt)
        eval_loss += loss.item()
        # Count predictions whose argmax class equals the label.
        test_correct += (pred.argmax(dim=1) == yt).float().sum().item()

    test_accuracy = test_correct / size
    print(f'test accuracy {test_accuracy:.3f}')

In [5]:
def _list_nested_files(root):
    """Return the path of every entry found one folder level below ``root``.

    Entries directly under ``root`` that are not directories are skipped;
    calling ``os.listdir`` on them would raise. The result follows
    ``os.listdir`` order, which is unspecified.
    """
    paths = []
    for folder in os.listdir(root):
        folder_path = os.path.join(root, folder)
        if not os.path.isdir(folder_path):
            continue
        for filename in os.listdir(folder_path):
            paths.append(os.path.join(folder_path, filename))
    return paths


# One shared helper replaces two copy-pasted loops and avoids rebinding the
# builtin `dir` at notebook scope.
annotations_path = _list_nested_files('./BASIL/annotations/')
articles_path = _list_nested_files('./BASIL/articles/')

In [6]:
# Sort both listings so the i-th annotation file lines up with the i-th article
# file; later cells pair the two lists purely by position.
annotations_path = sorted(annotations_path)
articles_path = sorted(articles_path)

print(annotations_path)
print(articles_path)

# Positional pairing only makes sense when both listings have the same length.
assert len(annotations_path) == len(articles_path), (
    f"{len(annotations_path)} annotation files vs {len(articles_path)} article files"
)

annotations = []
articles = []
# The lists are already sorted above, so they are not sorted again here.
for path in annotations_path:
    with open(path, 'r') as fp:
        annotations.append(json.load(fp))

for path in articles_path:
    with open(path, 'r') as fp:
        articles.append(json.load(fp))

['./BASIL/annotations/2010\\2b95d2cf-e979-4f9c-ae27-9a5370934f23_1_ann.json', './BASIL/annotations/2010\\2b95d2cf-e979-4f9c-ae27-9a5370934f23_2_ann.json', './BASIL/annotations/2010\\2b95d2cf-e979-4f9c-ae27-9a5370934f23_3_ann.json', './BASIL/annotations/2010\\38f7cbb7-5d6a-4c89-bcbd-8e164144172a_1_ann.json', './BASIL/annotations/2010\\38f7cbb7-5d6a-4c89-bcbd-8e164144172a_2_ann.json', './BASIL/annotations/2010\\38f7cbb7-5d6a-4c89-bcbd-8e164144172a_3_ann.json', './BASIL/annotations/2010\\45bd61bc-c356-4450-9e3a-cbfc862b09fd_1_ann.json', './BASIL/annotations/2010\\45bd61bc-c356-4450-9e3a-cbfc862b09fd_2_ann.json', './BASIL/annotations/2010\\45bd61bc-c356-4450-9e3a-cbfc862b09fd_3_ann.json', './BASIL/annotations/2010\\6b541575-99b1-40d2-8730-9bb868ee38ed_1_ann.json', './BASIL/annotations/2010\\6b541575-99b1-40d2-8730-9bb868ee38ed_2_ann.json', './BASIL/annotations/2010\\6b541575-99b1-40d2-8730-9bb868ee38ed_3_ann.json', './BASIL/annotations/2010\\6f95dcb9-e960-45ac-8c0e-91b85724c909_1_ann.json'

In [7]:
all_sentences = []
all_labels = []

# Label written for each annotated polarity; every other sentence keeps label 0.
polarity_to_label = {'neg': 1, 'pos': 2}

for i, article in enumerate(articles):
    # Flatten the article's paragraphs into a single list of sentences.
    sentences = [sent for para in article["body-paragraphs"] for sent in para]
    labels = [0] * len(sentences)

    for annot in annotations[i]["phrase-level-annotations"]:
        annot_id = annot["id"]
        # Only ids starting with 'p' are used. The digits after the prefix are
        # taken directly as a position in the flattened sentence list.
        if annot_id.startswith('p'):
            sent_idx = int(annot_id[1:])  # named so the builtin `id` is not shadowed
            label = polarity_to_label.get(annot['polarity'])
            if label is not None:
                labels[sent_idx] = label

    all_sentences.append(sentences)
    all_labels.append(labels)

# Collapse the per-article lists into corpus-wide parallel lists.
sentence_data = [sent for sublist in all_sentences for sent in sublist]
label_data = [label for sublist in all_labels for label in sublist]

In [8]:
# Shuffle BASIL sentence positions and cut them 70% / 15% / remainder into
# train / validation / test.
indices = [*range(len(sentence_data))]
random.shuffle(indices)

train_split = int(len(indices) * 0.7)
val_split = int(len(indices) * 0.15)

train_idx = indices[:train_split]
val_idx = indices[train_split:train_split + val_split]
test_idx = indices[train_split + val_split:]

# Look up the sentence and the label at each selected position.
train_sentences = list(map(sentence_data.__getitem__, train_idx))
train_labels = list(map(label_data.__getitem__, train_idx))
val_sentences = list(map(sentence_data.__getitem__, val_idx))
val_labels = list(map(label_data.__getitem__, val_idx))
test_sentences = list(map(sentence_data.__getitem__, test_idx))
test_labels = list(map(label_data.__getitem__, test_idx))

In [9]:
class BasilDataset(Dataset):
    """Dataset of tokenised BASIL sentences with integer labels, at a fixed length.

    Each sentence is split with ``nltk.word_tokenize`` and every token is mapped
    through ``indexer.index_of``. Tokens for which the indexer returns ``-1`` are
    replaced by the id of ``"UNK"``. Each id list is then cut to ``max_len``
    entries or right-filled with the id of ``"PAD"`` up to ``max_len``.
    """

    def __init__(self, sentences, labels, indexer, max_len=128):
        """Tokenise, index and pad/truncate every sentence up front.

        Args:
            sentences: sequence of raw sentence strings.
            labels: sequence of integer labels, parallel to ``sentences``.
            indexer: object exposing ``index_of(word)``, which returns ``-1``
                for words it does not know.
            max_len: number of token ids kept per sentence (default 128, the
                value that used to be hard-coded).
        """
        self.sentences = sentences
        self.labels = labels
        self.indexer = indexer
        self.max_len = max_len

        # These two ids never change, so look them up once rather than per token.
        unk_idx = self.indexer.index_of("UNK")
        pad_idx = self.indexer.index_of("PAD")

        self.indices = []
        for sent in self.sentences:
            sent_indices = []
            for word in nltk.word_tokenize(sent):
                word_idx = self.indexer.index_of(word)
                sent_indices.append(word_idx if word_idx != -1 else unk_idx)
            # Truncate first; the extend is then a no-op for full-length rows.
            sent_indices = sent_indices[:max_len]
            sent_indices.extend([pad_idx] * (max_len - len(sent_indices)))
            self.indices.append(sent_indices)

        self.indices = torch.tensor(self.indices, dtype=torch.long)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` tensors for position ``idx``."""
        return self.indices[idx], self.labels[idx]

In [13]:
# Index each BASIL split with the shared vocabulary.
train_basil_dataset = BasilDataset(train_sentences, train_labels, word_indexer)
val_basil_dataset = BasilDataset(val_sentences, val_labels, word_indexer)
test_basil_dataset = BasilDataset(test_sentences, test_labels, word_indexer)

# Batches of 16; only the training loader shuffles.
train_basil_loader = DataLoader(
    train_basil_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)
val_basil_loader = DataLoader(
    val_basil_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_batch,
)
test_basil_loader = DataLoader(
    test_basil_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_batch,
)

In [14]:
# Build the BASIL classifier and place it on the target device *before*
# creating the optimizer, so the optimizer is constructed over parameters that
# already live where they will be trained.
encoder_basil = EncoderModel(vocab_size, 128, block_size, n_head, n_layer, n_hidden, n_output, device)
encoder_basil.to(device)

# NOTE: rebinds `loss_fn` and `optimizer`, replacing the StereoSet ones.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(encoder_basil.parameters(), lr=learning_rate)

# Kept as the cell's last expression so the cell still displays the loss module.
loss_fn.to(device)

CrossEntropyLoss()

In [None]:
# Train the BASIL classifier for `epochs_CLS` epochs and evaluate on the
# validation loader after every epoch.
for epoch in range(epochs_CLS):
    # ---- training pass ----
    encoder_basil.train()
    size = len(train_basil_loader.dataset)
    num_batches = len(train_basil_loader)
    train_loss = correct = 0

    for xb, yb in train_basil_loader:
        xb = xb.to(device)
        yb = yb.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        pred = encoder_basil(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        optimizer.step()

        # Running totals: summed batch loss and number of correct argmax predictions.
        train_loss += loss.item()
        correct += (pred.argmax(dim=1) == yb).float().sum().item()

    average_train_loss = train_loss / num_batches
    accuracy = correct / size

    # ---- validation pass (gradients disabled) ----
    encoder_basil.eval()
    with torch.no_grad():
        size = len(val_basil_loader.dataset)
        num_batches = len(val_basil_loader)
        eval_loss = test_correct = 0
        for xt, yt in val_basil_loader:
            xt = xt.to(device)
            yt = yt.to(device)
            pred = encoder_basil(xt)
            eval_loss += loss_fn(pred, yt).item()
            test_correct += (pred.argmax(dim=1) == yt).float().sum().item()

        average_eval_loss = eval_loss / num_batches
        test_accuracy = test_correct / size

    # Report the first epoch and then every tenth one.
    if epoch == 0 or epoch % 10 == 9:
        print(f'Epoch #{epoch + 1}: train loss {average_train_loss:.3f}, train accuracy {accuracy:.3f}, val loss {average_eval_loss:.3f}, val accuracy {test_accuracy:.3f}')

In [None]:
# Final evaluation of the BASIL classifier on the held-out test loader.
encoder_basil.eval()
with torch.no_grad():
    size, num_batches = len(test_basil_loader.dataset), len(test_basil_loader)
    eval_loss = test_correct = 0

    for xt, yt in test_basil_loader:
        xt = xt.to(device)
        yt = yt.to(device)
        pred = encoder_basil(xt)
        loss = loss_fn(pred, yt)
        eval_loss += loss.item()
        # Count predictions whose argmax class equals the label.
        test_correct += (pred.argmax(dim=1) == yt).float().sum().item()

    test_accuracy = test_correct / size
    print(f'test accuracy {test_accuracy:.3f}')