In [1]:
import importlib
import nlp_train_utils
importlib.reload(nlp_train_utils)
from nlp_train_utils import *
import os

[nltk_data] Downloading package punkt to /home/alexisross/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/alexisross/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
import torch
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, AdamW

from nltk.tree import Tree
from nltk.tokenize.treebank import TreebankWordDetokenizer

from transformers import get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
import numpy as np

import requests
from bs4 import BeautifulSoup
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from nltk.corpus import wordnet as wn
from sklearn.metrics import f1_score, precision_score, recall_score


def load_model(device, model_name = 'bert-base-uncased'):
    print("loading model...")
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
    model.to(device)
    model.train()

    tokenizer = BertTokenizer.from_pretrained(model_name)

    print("done.")
    return model, tokenizer

def get_sst_data(file_path):
    with open(file_path, "r") as data_file:
        lines = list(data_file.readlines())
        texts = []
        labels = []
        for i, line in enumerate(lines):
            line = line.strip("\n")
            if not line:
                continue
            parsed_line = Tree.fromstring(line)
            text = (TreebankWordDetokenizer().detokenize(parsed_line.leaves()))
            sentiment = int(parsed_line.label())
            if sentiment < 2:
                sentiment = 0.0
            elif sentiment == 2:
                continue
            else:
                sentiment = 1.0
            texts.append(text)
            labels.append(sentiment)
    return texts, labels

def antonyms(term):
    response = requests.get('https://www.thesaurus.com/browse/{}'.format(term))
    soup = BeautifulSoup(response.text, 'lxml')
    return [span.text for span in soup.findAll('a', {'class': 'css-4elvh4'})] # class = .css-7854fb for less relevant


def get_candidates(model, text, max_candidates = 1):
    words = word_tokenize(text)
    candidates = [None] * max_candidates
    counter = 0
    for word in words:
        if wn.synsets(word) == []:
            continue
        tmp = wn.synsets(word)[0].pos()
        # if not adjective or noun, continue
        if tmp != "a" and tmp != "n":
            continue
        for a in antonyms(word):
            candidates[counter] = (TreebankWordDetokenizer().detokenize([a if x == word else x for x in words]))
            counter += 1
            if counter >= max_candidates:
                return list(filter(None.__ne__, candidates))
    return list(filter(None.__ne__, candidates))

def get_delta_opt(model, tokenizer, device, text):
    cands = get_candidates(model, text)
    max_prob = 0
    found_cand = False
    for c in cands:
        cand_logits, cand_labels, cand_prob = get_pred(model, tokenizer, device, c, 1.0)
        print(cand_prob)
        if cand_prob > max_prob:
            max_cand = c
            max_prob = cand_prob
            max_logits = cand_logits
            max_prob = cand_prob
            found_cand = True
        else:
            del cand_logits
            del cand_labels
            del cand_prob
            torch.cuda.empty_cache()
    if not found_cand:
        max_cand, max_logits, max_prob = get_pred(model, tokenizer, device, text, 1.0)
    return max_cand, max_logits, max_prob

def get_pred(model, tokenizer, device, text, label):
    encoding = tokenizer(text, return_tensors='pt', padding=True, truncation=True)['input_ids']
    input_ids = encoding.to(device)
    labels = torch.LongTensor([label]).to(device)
    outputs = model(input_ids)
    logits = outputs[0]
    pos_prob = torch.nn.Softmax(dim=-1)(logits)[:, -1]
    return logits, labels, pos_prob


def train_nlp(model, tokenizer, weight_dir, thresholds_to_eval, recourse_loss_weight):

    # get data
    device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

    # load model and tokenizer    
    train_texts, train_labels = get_sst_data('data/nlp_data/train.txt')
    dev_texts, dev_labels = get_sst_data('data/nlp_data/dev.txt')

    batch_size = 32
    threshold = 0.5

    lr = 2e-5
    num_warmup_steps = 0
    num_epochs = 3
    num_train_steps = len(train_texts)/batch_size * num_epochs

    optim = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps, num_train_steps)
    loss_fn = torch.nn.CrossEntropyLoss()

    def combined_loss(model, device, logits, labels, delta_logits, loss_fn, recourse_loss_weight):
        normal_loss = loss_fn(logits, labels)
        print("delta logits: ", delta_logits)
        recourse_loss = loss_fn(delta_logits, torch.LongTensor([1.0]).to(device))
        return recourse_loss * recourse_loss_weight + normal_loss

    best_val_loss = 100000000

    for epoch in range(num_epochs):
        batch_loss, train_epoch_loss, train_correct = 0, 0, 0
        
        print("EPOCH: ", epoch)
        model.train()

        for i, (text, label) in tqdm(enumerate(zip(train_texts, train_labels)), total = len(train_texts)):
            logits, labels, pos_prob = get_pred(model, tokenizer, device, text, label)
            _, delta_logits, _ = get_delta_opt(model, tokenizer, device, text)
            batch_loss += combined_loss(model, device, logits, labels, delta_logits, loss_fn, recourse_loss_weight)
                        
            if i % 1000 == 0:
                print(i, " out of ", len(train_texts))
            
            if i % batch_size == 0:    
                model.zero_grad() 
                batch_loss.backward()
                train_epoch_loss += batch_loss.item()
                optim.step()
                scheduler.step()
                del batch_loss
                torch.cuda.empty_cache()
                batch_loss = 0
                
        print("Train acc: ", train_correct/len(train_texts))
        print("Train epoch loss: ", train_epoch_loss/len(train_texts))
        
        model.eval()
            
        val_correct = 0
        epoch_val_loss = 0
        pos_probs = []

        flipped_by_thresh = {thresh: 0 for thresh in thresholds_to_eval}
        negative_by_thresh = {thresh :0 for thresh in thresholds_to_eval}

        for i, (text, label) in tqdm(enumerate(zip(dev_texts, dev_labels)), total = len(dev_texts)):
            logits, labels, pos_prob = get_pred(model, tokenizer, device, text, label)
            _, delta_logits, delta_prob = get_delta_opt(model, tokenizer, device, text)
            epoch_val_loss += combined_loss(model, device, logits, labels, delta_logits, loss_fn, recourse_loss_weight)


            del input_ids
            
            pos_probs.append(pos_prob.item())

            for t in thresholds_to_eval:
                if pos_prob.item() < t:
                    negative_by_thresh[t] += 1
                    if delta_prob >= t:
                        flipped_by_thresh[t] += 1
            
        if epoch_val_los < best_val_loss:
            best_model_name = weight_dir + str(recourse_loss_weight) + 'best_model.pt'
            torch.save(model, best_model_name)
            best_epoch = True

        else:
            best_epoch = False

        # if best epoch, eval
        if best_epoch:
            np_probs = np.array(pos_probs)
            np_labels = np.array(dev_labels)

            f1_by_thresh, recall_by_thresh, precision_by_thresh, acc_by_thresh, flipped_proportion_by_thresh, recourse_proportion_by_thresh = [], [], [], [], [], []

            for t_idx, t in enumerate(thresholds_to_eval):
                label_preds = np.array([0.0 if a < t else 1.0 for a in np_probs])

                f1 = round(f1_score(label_preds, np_labels), 3)
                f1_by_thresh.append(f1) 

                recall = round(recall_score(label_preds, np_labels), 3)
                recall_by_thresh.append(recall)

                prec = round(precision_score(label_preds, np_labels), 3)
                precision_by_thresh.append(prec)

                acc = round(np.sum(label_preds == np_labels)/np_labels.shape[0], 3)
                acc_by_thresh.append(acc) 

                num_neg = negative_by_thresh[t]
                num_pos = len(dev_labels) - num_neg
                assert (num_neg + num_pos) == len(dev_labels)
                flipped = flipped_by_thresh[t]

                if num_neg != 0:
                    flipped_proportion = round(flipped/num_neg, 3)
                else:
                    flipped_proportion = 0

                recourse_proportion = round((flipped + num_pos)/len(dev_labels), 3)

                flipped_proportion_by_thresh.append(flipped_proportion)
                recourse_proportion_by_thresh.append(recourse_proportion)

            
            thresholds_data = {}

            thresholds_data['thresholds'] = thresholds_to_eval
            thresholds_data['precisions'] = precision_by_thresh
            thresholds_data['flipped_proportion'] = flipped_proportion_by_thresh
            thresholds_data['recourse_proportion'] = recourse_proportion_by_thresh
            thresholds_data['f1s'] = f1_by_thresh
            thresholds_data['accs'] = acc_by_thresh
            thresholds_data['recalls'] = recall_by_thresh
            thresholds_data['precisions'] = precision_by_thresh

            thresholds_df = pd.DataFrame(data=thresholds_data)
            best_model_thresholds_file_name = weight_dir + str(recourse_loss_weight) + '_val_thresholds_info.csv'
            thresholds_df.to_csv(best_model_thresholds_file_name, index_label='index')
            
        print("VAL ACC: ", val_correct/len(dev_texts))
        print("+ ", val_preds.count(1.0))
        print("-", val_preds.count(0.0))




[nltk_data] Downloading package punkt to /home/alexisross/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
recourse_loss_weight = 0.0
weight_dir = 'test_nlp/' + str(recourse_loss_weight)
if not os.path.exists(weight_dir):
    os.makedirs(weight_dir)
    
thresholds_to_eval = [0.5, 0.6]
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

model, tokenizer = load_model(device, model_name = 'bert-base-uncased')


loading model...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

done.


In [6]:
train_nlp(model, tokenizer, weight_dir, thresholds_to_eval, recourse_loss_weight)


EPOCH:  0


HBox(children=(FloatProgress(value=0.0, max=6920.0), HTML(value='')))

tensor([0.3690], device='cuda:1', grad_fn=<SelectBackward>)
delta logits:  tensor([[ 0.3082, -0.2282]], device='cuda:1', grad_fn=<AddmmBackward>)
0  out of  6920



RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 1; 10.76 GiB total capacity; 774.00 MiB already allocated; 6.56 MiB free; 864.00 MiB reserved in total by PyTorch)
Exception raised from malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:272 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f4cd777b1e2 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1e64b (0x7f4cd79d164b in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1f464 (0x7f4cd79d2464 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1faa1 (0x7f4cd79d2aa1 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #4: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x11e (0x7f4c87d3890e in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xf33949 (0x7f4c86172949 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xf4d777 (0x7f4c8618c777 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x10e9c7d (0x7f4cc0f28c7d in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x10e9f97 (0x7f4cc0f28f97 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::empty(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0xfa (0x7f4cc1033a1a in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::native::mm_cuda(at::Tensor const&, at::Tensor const&) + 0x6c (0x7f4c87227ffc in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #11: <unknown function> + 0xf22a20 (0x7f4c86161a20 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #12: <unknown function> + 0xa56530 (0x7f4cc0895530 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const + 0xbc (0x7f4cc107d81c in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #14: at::mm(at::Tensor const&, at::Tensor const&) + 0x4b (0x7f4cc0fce6ab in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x2ed0a2f (0x7f4cc2d0fa2f in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0xa56530 (0x7f4cc0895530 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #17: at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const + 0xbc (0x7f4cc107d81c in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #18: at::Tensor::mm(at::Tensor const&) const + 0x4b (0x7f4cc1163cab in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x2d11fbb (0x7f4cc2b50fbb in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::generated::MmBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x25f (0x7f4cc2b6c7df in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x3375bb7 (0x7f4cc31b4bb7 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1400 (0x7f4cc31b0400 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #23: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7f4cc31b0fa1 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #24: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7f4cc31a9119 in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #25: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7f4ce853f4ba in /home/alexisross/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #26: <unknown function> + 0xbd6df (0x7f4cf41d46df in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #27: <unknown function> + 0x76db (0x7f4cf9bf46db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #28: clone + 0x3f (0x7f4cf9f2da3f in /lib/x86_64-linux-gnu/libc.so.6)
