# Read SST2 Data

In [271]:
train_text, train_label = [], []
with open('train.tsv', 'r', encoding='utf-8') as f:
    i = 0
    for line in f:
        if i == 0:
            i += 1
            continue
        train_text.append(line.strip().split('\t')[0].strip())
        train_label.append(int(line.strip().split('\t')[1]))
dev_text, dev_label = [], []
with open('dev.tsv', 'r', encoding='utf-8') as f:
    i = 0
    for line in f:
        if i == 0:
            i += 1
            continue
        dev_text.append(line.strip().split('\t')[0].strip())
        dev_label.append(int(line.strip().split('\t')[1].strip()))

# Model

BERT-distil

In [21]:
import torch
import numpy as np
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
config = DistilBertConfig.from_pretrained('bert-distil')
tokenizer = DistilBertTokenizer.from_pretrained('bert-distil')
bert_model = DistilBertModel.from_pretrained('bert-distil', config=config).cuda()
for param in bert_model.base_model.parameters():
    param.requires_grad = False

Bert-distil + Attention

In [273]:
import torch.nn as nn

def bert_embedding(x, bert_model, bert_tokenizer, device):
    encode = bert_tokenizer(x, return_tensors='pt', padding=True, add_special_tokens=False)
    input_ids, attention_mask = encode['input_ids'].to(device), encode['attention_mask'].to(device)
    embed_x = bert_model(input_ids, attention_mask=attention_mask) # [bs, seq_len, 768]
    return embed_x[0]

class Att(nn.Module):
    def __init__(self, device, input_dim=768, out_dim=2):
        super(Att, self).__init__()
        self.proj1 = nn.Linear(input_dim, input_dim)
        self.tanh = nn.Tanh()
        self.u = nn.Parameter(torch.Tensor(input_dim, 1))
        self.proj2 = nn.Linear(input_dim, out_dim)
        self.device = device
        self.init_params()

    def init_params(self):
        nn.init.xavier_uniform_(self.proj1.weight.data)
        nn.init.xavier_uniform_(self.proj2.weight.data)
        nn.init.constant_(self.proj1.bias.data, 0.1)
        nn.init.constant_(self.proj2.bias.data, 0.1)
        nn.init.uniform_(self.u, -0.1, 0.1)
        
    def forward(self, x, bert_model, bert_tokenizer):
        embed_x = bert_embedding(x, bert_model, bert_tokenizer, self.device) # [bs, seq_len, 768]
        ut = self.tanh(self.proj1(embed_x)) # ut: [bs, seq_len, 768]
        alpha = torch.softmax(torch.matmul(ut, self.u), dim=1) # alpha: [bs, seq_len, 1]
        s = torch.sum(alpha * embed_x, dim=1) # s: [bs, 768]
        return self.proj2(s), alpha

# Training and test function

training function

In [108]:
def train(model, train_text, train_label, bs, num_epoch, optimizer, criterion,
          dev_text, dev_label, bert_model, bert_tokenizer, PATH, device):
    bert_model.eval()
    num_batch = len(train_text) // bs
    max_dev_acc = 0.
    for epoch in range(num_epoch):
        tot_loss = 0.
        model.train()
        for i in range(num_batch):
            optimizer.zero_grad()
            x, y = train_text[i*bs:(i+1)*bs], torch.LongTensor(train_label[i*bs:(i+1)*bs]).to(device)
            pred, _ = model(x, bert_model, bert_tokenizer) # [bs, out_dim]
            loss = criterion(pred, y)
            tot_loss += loss.item()
            loss.backward()
            optimizer.step()
        dev_acc = eval(model, dev_text, dev_label, bs, bert_model, bert_tokenizer)
        if dev_acc > max_dev_acc:
            max_dev_acc = dev_acc
            torch.save(model.state_dict(), PATH)
        print(f"Epoch {epoch+1}/{num_epoch}, Total loss: {tot_loss:.4f}, Dev Acc.: {dev_acc:.2%}.")

evaluation function

In [104]:
def eval(model, dev_text, dev_label, bs, bert_model, bert_tokenizer):
    correct, total = 0, 0
    num_batch = len(dev_text) // bs
    model.eval()
    bert_model.eval()
    with torch.no_grad():
        for i in range(num_batch):
            x, y = dev_text[i*bs:(i+1)*bs], dev_label[i*bs:(i+1)*bs]
            pred, _ = model(x, bert_model, bert_tokenizer)
            pred = pred.argmax(dim=-1)
            for j in range(len(y)):
                if pred[j].item() == y[j]:
                    correct += 1
                total += 1
    return correct / total

# Training

In [275]:
device = torch.device('cuda')
model = Att(device=device,
            input_dim=768,
            out_dim=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [276]:
train(model, train_text, train_label, bs=64, num_epoch=10,
      optimizer=optimizer, criterion=criterion,
      dev_text=dev_text, dev_label=dev_label,
      bert_model=bert_model, bert_tokenizer=tokenizer,
      PATH='SST2-model-ATT.pt', device=device)

Epoch 1/10, Total loss: 455.4002, Dev Acc.: 84.01%.
Epoch 2/10, Total loss: 388.0566, Dev Acc.: 82.21%.
Epoch 3/10, Total loss: 352.2763, Dev Acc.: 82.45%.
Epoch 4/10, Total loss: 321.0624, Dev Acc.: 82.81%.
Epoch 5/10, Total loss: 293.0998, Dev Acc.: 81.37%.
Epoch 6/10, Total loss: 271.4165, Dev Acc.: 83.05%.
Epoch 7/10, Total loss: 254.6685, Dev Acc.: 83.77%.
Epoch 8/10, Total loss: 238.5233, Dev Acc.: 81.97%.
Epoch 9/10, Total loss: 224.2089, Dev Acc.: 83.05%.
Epoch 10/10, Total loss: 214.3441, Dev Acc.: 82.93%.


# Extract Sentiment Words

In [277]:
model.load_state_dict(torch.load("SST2-model-ATT.pt"))

<All keys matched successfully>

In [305]:
def analysis(sent, model, bert_model, bert_tokenizer):
    model.eval()
    tok_sent = bert_tokenizer.tokenize(sent)
    new_sent = []
    p = 0
    while p < len(tok_sent):
        if "##" not in tok_sent[p]:
            new_sent.append(tok_sent[p])
        else:
            new_sent[-1] += tok_sent[p][2:]
        p += 1
    with torch.no_grad():
        _, scores = model([sent], bert_model, bert_tokenizer)
        scores = scores.view(-1).cpu().numpy().tolist()
        new_scores = []
        p = 0
        while p < len(scores):
            if "##" not in tok_sent[p]:
                if len(new_scores) > 0:
                    new_scores[-1] /= count
                new_scores.append(scores[p])
                count, p = 1, p + 1
            else:
                new_scores[-1] += scores[p]
                count, p = count + 1, p + 1
        new_scores = np.array(new_scores)
        
        idx = []
        for i in range(min(5, len(new_sent))):
            idx.append(new_scores.argmax())
            new_scores[idx[-1]] = -float('inf')
    
    #for i in range(len(idx)):
    #    print(f"{i+1}. {new_sent[idx[i]]}")
    return new_sent[idx[0]]

Sentiment DIctionary Building

In [310]:
pos, neg = set(), set()
for i in range(len(train_text)):
    if train_label[i] == 0:
        neg.add(analysis(train_text[i], model, bert_model, tokenizer))
    else:
        pos.add(analysis(train_text[i], model, bert_model, tokenizer))
pos_neg = pos & neg
pos = pos - pos_neg
neg = neg - pos_neg
with open('pos.txt', 'w', encoding='utf-8') as f:
    for word in pos:
        f.write(word + ', ')
with open('neg.txt', 'w', encoding='utf-8') as f:
    for word in neg:
        f.write(word + ', ')

# Evaluate Sentiment Dictionary

Build a test set

In [341]:
pred, true, new_text = [], [], []
for i in range(len(dev_text)):
    word = analysis(dev_text[i], model, bert_model, tokenizer)
    f1, f2 = False, False
    if word in pos:
        f1 = True
    if word in neg:
        f2 = True
    if (f1 and f2) or (not f1 and not f2):
        continue
    elif f1:
        true.append(dev_label[i])
        pred.append(1)
        new_text.append(dev_text[i])
    else:
        pred.append(0)
        true.append(dev_label[i])
        new_text.append(dev_text[i])

Baseline

In [340]:
correct, tot = 0, 0
for i in range(len(pred)):
    if true[i] == 1:
        correct += 1
    tot += 1
correct / tot

0.5440414507772021

ATT Performance

In [343]:
eval(model, new_text, true, 32, bert_model, tokenizer)

0.8177083333333334

Dictionary Performance

In [361]:
pred = []
for sent in new_text:
    tok_sent = tokenizer.tokenize(sent)
    new_sent = []
    p = 0
    while p < len(tok_sent):
        if "##" not in tok_sent[p]:
            new_sent.append(tok_sent[p])
        else:
            new_sent[-1] += tok_sent[p][2:]
        p += 1
    res = [0, 0]
    for word in new_sent:
        if word in pos:
            res[1] += 1
        if word in neg:
            res[0] += 1
    if res[0] >= res[1]:
        pred.append(0)
    else:
        pred.append(1)
correct, tot = 0, 0
for i in range(len(pred)):
    if pred[i] == true[i]:
        correct += 1
    tot += 1
correct / tot

0.7564766839378239