In [1]:
import json
import os
import numpy as np
import torch
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

In [2]:
with open('Subtask_1_train.json') as f:
    data = json.load(f)

In [3]:
data[0]

{'conversation_ID': 1,
 'conversation': [{'utterance_ID': 1,
   'text': 'Alright , so I am back in high school , I am standing in the middle of the cafeteria , and I realize I am totally naked .',
   'speaker': 'Chandler',
   'emotion': 'neutral'},
  {'utterance_ID': 2,
   'text': 'Oh , yeah . Had that dream .',
   'speaker': 'All',
   'emotion': 'neutral'},
  {'utterance_ID': 3,
   'text': 'Then I look down , and I realize there is a phone ... there .',
   'speaker': 'Chandler',
   'emotion': 'surprise'},
  {'utterance_ID': 4,
   'text': 'Instead of ... ?',
   'speaker': 'Joey',
   'emotion': 'surprise'},
  {'utterance_ID': 5,
   'text': 'That is right .',
   'speaker': 'Chandler',
   'emotion': 'anger'},
  {'utterance_ID': 6,
   'text': 'Never had that dream .',
   'speaker': 'Joey',
   'emotion': 'neutral'},
  {'utterance_ID': 7,
   'text': 'No .',
   'speaker': 'Phoebe',
   'emotion': 'neutral'},
  {'utterance_ID': 8,
   'text': 'All of a sudden , the phone starts to ring .',
   's

In [4]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline

# Load the pretrained model and tokenizer
model_checkpoint = "Nakul24/Spanbert-emotion-extraction"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

# Add new tokens to the tokenizer
new_tokens = {'additional_special_tokens': ["[SP1]", "[SP2]", "[SP3]", "[SP4]", "[SP5]", "[SP6]", "[SP7]", "[SP8]", "[SEP]"]}
tokenizer.add_special_tokens(new_tokens)

# Resize the model's embedding layer to accommodate the new tokens
# model.resize_token_embeddings(len(tokenizer))

# Use the modified tokenizer and model
# nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)

# Example usage
context = ["[SP1] This is an example sentence.", "[SP2] This is another example sentence."]
question = ["What is an example?", "This is an example sentence."]

# Tokenize with the modified tokenizer
inputs = tokenizer(context, question, return_tensors="pt", padding=True)

print (inputs)

# Forward pass with the modified model
# outputs = model(**inputs)

# Extract answers
# start_logits, end_logits = outputs.start_logits, outputs.end_logits
# start_index = torch.argmax(start_logits)
# end_index = torch.argmax(end_logits) + 1
# answer = tokenizer.decode(inputs["input_ids"][0][start_index:end_index])

# print("Answer:", answer)


{'input_ids': tensor([[  101, 29001,  1142,  1110,  1126,  1859,  5650,   119,   102,  1184,
          1110,  1126,  1859,   136,   102,     0],
        [  101, 29003,  1142,  1110,  1330,  1859,  5650,   119,   102,  1142,
          1110,  1126,  1859,  5650,   119,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
tokenizer.decode(inputs['input_ids'][1])

'[CLS] [SP2] this is another example sentence. [SEP] this is an example sentence. [SEP]'

In [7]:
tokenizer.vocab['[SP1]']

29001

In [5]:
def decode1(conversation):
    emotions = [i['emotion'] for i in conversation['conversation']]
    texts = [i['text'] for i in conversation['conversation']]
    speakers = [i['speaker'] for i in conversation['conversation']]
    speaker_dict = {}
    ind = 1
    for speaker in speakers:
        if speaker not in speaker_dict.keys():
            if ind <= 7:
                speaker_dict[speaker] = '[SP'+str(ind)+']'
                ind = ind+1
            else:
                speaker_dict[speaker] = '[SP8]'
    speakers = [speaker_dict[speaker] for speaker in speakers]

    history = [speakers[i]+conversation['conversation'][i]['text'] for i in range(0,len(texts))]
    history = "[CLS]"+ "[SEP]".join(history)
    conversation_history = [history]*len(texts)
    batch = {'emotion':emotions,"target_utterance":texts,
             "conversation_history":conversation_history}
    return batch

def decode2(conversation):
    ids = [i['utterance_ID'] for i in conversation['conversation']]
    emotions = [i['emotion'] for i in conversation['conversation']]
    texts = [i['text'] for i in conversation['conversation']]
    speakers = [i['speaker'] for i in conversation['conversation']]
    speaker_dict = {}
    ind = 1
    for speaker in speakers:
        if speaker not in speaker_dict.keys():
            if ind <= 7:
                speaker_dict[speaker] = '[SP'+str(ind)+']'
                ind = ind+1
            else:
                speaker_dict[speaker] = '[SP8]'
    speakers = [speaker_dict[speaker] for speaker in speakers]
    #speaker_no = [int(speaker[4]) for speaker in speakers]

    history = [speakers[i]+conversation['conversation'][i]['text'] for i in range(0,len(texts))]
    history = "[SEP]".join(history)
    conversation_history = [history]*len(texts)

    ec_pairs = conversation['emotion-cause_pairs']
    cause_spans = {}
    for index in ids:
        cause_spans[index-1] = []
    for ec in ec_pairs:
        utt_id = int(ec[1][:ec[1].find('_')]) - 1
        utt_text = texts[utt_id]
        text_int = utt_text.find(ec[1][ec[1].find('_')+1:])
        span_start = history.find(utt_text)+text_int
        span_end = history.find(utt_text)+text_int+len(ec[1][ec[1].find('_')+1:])
        cause_spans[int(ec[0][:ec[0].find('_')]) - 1].append((int(ec[1][:ec[1].find('_')]) - 1, span_start, span_end))
    batch = {'emotion':emotions,"target_utterance":texts,
             "conversation_history":conversation_history,"cause_spans":cause_spans}
    return batch


def decode_span(conversation):
  ids = [i['utterance_ID'] for i in conversation['conversation']]
  emotions = [i['emotion'] for i in conversation['conversation']]

  non_neutral_emotions = [i+1 for i in range(len(emotions)) if emotions[i] != 'neutral']

  texts = [i['text'] for i in conversation['conversation']]
  speakers = [i['speaker'] for i in conversation['conversation']]
  speaker_dict = {}
  ind = 1
  for speaker in speakers:
      if speaker not in speaker_dict.keys():
          if ind <= 7:
              speaker_dict[speaker] = '[SP'+str(ind)+']'
              ind = ind+1
          else:
              speaker_dict[speaker] = '[SP8]'
  speakers = [speaker_dict[speaker] for speaker in speakers]

  history = [speakers[i]+conversation['conversation'][i]['text'] for i in range(0,len(texts))]
  history = "[SEP]".join(history)

  ec_pairs = conversation['emotion-cause_pairs']

  ec_dict = {}
  for ind in non_neutral_emotions:
    ec_dict[ind] = []

  for ec in ec_pairs:
    ec_dict[int(ec[0].split('_')[0])].append(ec[1])

  emotion = []
  target_utterence = []
  evidence_utterence = []
  conversation_history = []
  cause_spans = []

  for i in range(0, len(texts)):
    if emotions[i] != 'neutral':
      emotion = emotion + ([emotions[i]]*len(texts))
      target_utterence = target_utterence + ([texts[i]]*len(texts))
      #evidence_utterence = evidence_utterence + texts

      conversation_history_single = [text + ' [SEP] '+ history for text in texts]

      conversation_history = conversation_history + (conversation_history_single)
      relevant_contexts = ec_dict[i+1]
      relevant_contexts_dict = {}
      for context in relevant_contexts:
        relevant_contexts_dict[int(context.split('_')[0])-1] = context.split('_')[1]
      cause_span = []
      for j in range(0,len(texts)):
        if j not in relevant_contexts_dict.keys():
          cause_span.append([0,0,0])
        else:
          span_text = relevant_contexts_dict[j]
          actual_text = texts[j]

          span_start = actual_text.find(span_text)
          span_end = span_start + len(span_text)

          cause_span.append([1,span_start,span_end])
      cause_spans = cause_spans + cause_span

  batch = {"emotion":emotion,"target_utterance":target_utterence, "context":conversation_history,
             "cause_spans":cause_spans}

  return batch

In [6]:
# train_decoded = [decode_span(conv) for conv in train_data]

In [7]:
em_dict = {'anger':0, 'disgust':1, 'fear':2, 'joy':3, 'neutral':6, 'sadness':5, 'surprise':4}

In [8]:
target_utt_list = []
context_list = []
cause_spans_list =[]
emo_list = []
for conv in data:
    conv_decoded = decode_span(conv)
    for i, utt in enumerate(conv_decoded['target_utterance']):
        context = conv_decoded['context'][i]
        target_utt_list.append(utt)
        context_list.append(context)
        cause_spans_list.append(conv_decoded['cause_spans'][i])
        emo_list.append(em_dict[conv_decoded['emotion'][i]])

In [9]:
set(emo_list)

{0, 1, 2, 3, 4, 5}

In [10]:
context_list[10:12]

['Then I look down , and I realize there is a phone ... there . [SEP] [SP1]Alright , so I am back in high school , I am standing in the middle of the cafeteria , and I realize I am totally naked .[SEP][SP2]Oh , yeah . Had that dream .[SEP][SP1]Then I look down , and I realize there is a phone ... there .[SEP][SP3]Instead of ... ?[SEP][SP1]That is right .[SEP][SP3]Never had that dream .[SEP][SP4]No .[SEP][SP1]All of a sudden , the phone starts to ring .',
 'Instead of ... ? [SEP] [SP1]Alright , so I am back in high school , I am standing in the middle of the cafeteria , and I realize I am totally naked .[SEP][SP2]Oh , yeah . Had that dream .[SEP][SP1]Then I look down , and I realize there is a phone ... there .[SEP][SP3]Instead of ... ?[SEP][SP1]That is right .[SEP][SP3]Never had that dream .[SEP][SP4]No .[SEP][SP1]All of a sudden , the phone starts to ring .']

In [11]:
cause_spans_list[:5]

[[1, 91, 121], [0, 0, 0], [1, 0, 61], [0, 0, 0], [0, 0, 0]]

In [12]:
random.seed(1234)
all_batches = list(zip(target_utt_list, context_list, emo_list, cause_spans_list))
random.shuffle(all_batches)


In [13]:
train_target_utt, val_target_utt, train_context, val_context, train_emo, val_emo, train_cause_spans, val_cause_spans \
                      = train_test_split(*zip(*all_batches), test_size=0.2, random_state=1234)

In [14]:
# train_data, val_data = train_test_split(data, test_size = 0.2, random_state=1234)

In [15]:
len(train_target_utt), len(val_target_utt)

(80924, 20232)

In [16]:
# train_decoded[:2]

In [17]:
def char_span_to_token_span(question, context, char_start, char_end, tokenizer):
    """
    Convert character-based span to token-based span in a tokenized question and context.

    Parameters:
    - question (str): The question text.
    - context (str): The original context text.
    - char_start (int): The start index of the span in characters relative to the start of the context.
    - char_end (int): The end index of the span in characters relative to the start of the context.
    - tokenizer: The tokenizer object used to tokenize the question and context.

    Returns:
    - token_start (int): The start index of the span in tokens.
    - token_end (int): The end index of the span in tokens.
    """
    if char_start == 0 and char_end == 0:
        return 0, 0
    # Tokenize the question and context
    char_start = len(context[:char_start].replace(' ', ''))
    char_end = len(context[:char_end].replace(' ', ''))
    # print (char_start, char_end)

    tokens = tokenizer.tokenize(context, return_tensors="pt")
    # tokens = encoding["input_ids"].squeeze().tolist()

    # Get the length of the question
    question_len = len(tokenizer.tokenize(question, return_tensors="pt"))
    # print (question_len)

    # Initialize variables

    token_start = None
    token_end = None
    current_char = 0
    current_token = 0

    # Iterate through tokens and find the corresponding span
    while current_token < len(tokens):
        # print ("*********************************")
        # token = tokenizer.decode(tokens[current_token], skip_special_tokens=False)
        token = tokens[current_token]
        # print (token)
        # print (len(token))

        # Check if the current character is within the span
        if current_char >= char_start and current_char < char_end:
            if token_start is None:
                token_start = current_token + question_len + 2 # Adjust for the question tokens
                # print ('I am start')

            # Check if the next character goes beyond the span
            if current_char + len(token.replace('#', '')) >= char_end:
                token_end = current_token + question_len + 2 # Adjust for the question tokens
                # print ('I am end')
                break

        # Move to the next token and update character position
        current_char += len(token.replace('#', ''))
        # print (current_char)
        current_token += 1
        # print (current_token)

    return token_start, token_end

# Example usage:
# question = "That is one way !"
# context = '[SP1]Oh my God ![SEP][SP2]I know , I know , I am such an idiot .[SEP][SP2]I guess I should have caught on when she started going to the dentist four and five times a week . I mean , how clean can teeth get ?[SEP][SP1]My brother going through that right now , he is such a mess . How did you get through it ?[SEP][SP2]Well , you might try accidentally breaking something valuable of hers , say her ...[SEP][SP1]leg ?[SEP][SP2]That is one way ![SEP][SP2]Me , I ... I went for the watch .[SEP][SP1]You actually broke her watch ?'
question = target_utt_list[651]
context = context_list[651]
char_start = 0
char_end = 28

# tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
token_start, token_end = char_span_to_token_span(question, context, char_start, char_end, tokenizer)

print("Token Start:", token_start)
print("Token End:", token_end)

Token Start: 12
Token End: 18


In [18]:
def create_batches(target_utt_list, context_list, emo_list, cause_spans_list):
    batch_encoded_conversation = []
    batch_cause_spans = []
    batch_start_token_indx = []
    batch_end_token_indx = []

    batch_emo_label = []

    for i in range(0, len(target_utt_list), BATCH_SIZE):
        start_end_token_indx = []
        batch_encoded_conversation.append(tokenizer(target_utt_list[i: i + BATCH_SIZE], context_list[i :i + BATCH_SIZE],
                                                    return_tensors="pt", padding=True, max_length=512, truncation=True,
                                                    return_token_type_ids=True))
        batch_emo_label.append(torch.tensor(emo_list[i: i + BATCH_SIZE], dtype=torch.long))
        batch_cause_spans.append(cause_spans_list[i:i + BATCH_SIZE])
        start_token_indx = []
        end_token_indx = []
        start_token_indx = np.zeros([len(target_utt_list[i: i + BATCH_SIZE]), int(batch_encoded_conversation[-1]['input_ids'].shape[1])])
        end_token_indx = np.zeros([len(target_utt_list[i: i + BATCH_SIZE]), int(batch_encoded_conversation[-1]['input_ids'].shape[1])])
        for j in range(len(target_utt_list[i: i + BATCH_SIZE])):
              se_indx = [char_span_to_token_span(target_utt_list[i + j], context_list[i + j], cause_span[1], cause_span[2], tokenizer) for cause_span in cause_spans_list[i + j]]
              if len(se_indx) > 0:
                  # print (i + j, se_indx)
                  s_indx = [int(x[0]) for x in se_indx if x[0] is not None and x[0] < 512]
                  e_indx = [int(x[1]) for x in se_indx if x[1] is not None and x[1] < 512]
                  # print(s_indx)
                  # print(e_indx)
                  start_token_indx[j, s_indx] = 1.
                  end_token_indx[j, e_indx] = 1.
                  # print(start_token_indx[j,:])

              # start_token_indx.append([x[0] for x in se_indx])
              # end_token_indx.append([x[1] for x in se_indx])
        batch_start_token_indx.append(torch.tensor(start_token_indx, dtype=torch.float))
        batch_end_token_indx.append(torch.tensor(end_token_indx, dtype=torch.float))
    return batch_encoded_conversation, batch_emo_label, batch_start_token_indx, batch_end_token_indx

In [19]:
def create_batches(target_utt_list, context_list, emo_list, cause_spans_list):
    batch_encoded_conversation = []
    batch_cause_spans = []
    batch_start_token_indx = []
    batch_end_token_indx = []

    batch_emo_label = []

    for i in range(0, len(target_utt_list), BATCH_SIZE):
        start_end_token_indx = []
        batch_encoded_conversation.append(tokenizer(target_utt_list[i: i + BATCH_SIZE], context_list[i :i + BATCH_SIZE],
                                                    return_tensors="pt", padding=True, max_length=512, truncation=True,
                                                    return_token_type_ids=False))
        batch_emo_label.append(torch.tensor(emo_list[i: i + BATCH_SIZE], dtype=torch.long))
        batch_cause_spans.append(cause_spans_list[i:i + BATCH_SIZE])
        start_token_indx = []
        end_token_indx = []
        start_token_indx = np.zeros([len(target_utt_list[i: i + BATCH_SIZE]), int(batch_encoded_conversation[-1]['input_ids'].shape[1])])
        end_token_indx = np.zeros([len(target_utt_list[i: i + BATCH_SIZE]), int(batch_encoded_conversation[-1]['input_ids'].shape[1])])
        for j in range(len(target_utt_list[i: i + BATCH_SIZE])):
              se_indx = char_span_to_token_span(target_utt_list[i + j], context_list[i + j], cause_spans_list[i + j][1], cause_spans_list[i + j][2], tokenizer)
              if len(se_indx) > 0:
                  # print (i + j, se_indx)
                  s_indx = se_indx[0]
                  e_indx = se_indx[1]
                  # print(s_indx)
                  # print(e_indx)
                  start_token_indx[j, s_indx] = 1.
                  end_token_indx[j, e_indx] = 1.
                  # print(start_token_indx[j,:])

              # start_token_indx.append([x[0] for x in se_indx])
              # end_token_indx.append([x[1] for x in se_indx])
        batch_start_token_indx.append(torch.tensor(start_token_indx, dtype=torch.float))
        batch_end_token_indx.append(torch.tensor(end_token_indx, dtype=torch.float))
    return batch_encoded_conversation, batch_emo_label, batch_start_token_indx, batch_end_token_indx

In [20]:
BATCH_SIZE = 16


In [21]:
train_batch_conv, train_batch_emo, train_batch_start, train_batch_end = create_batches(train_target_utt, train_context, train_emo, train_cause_spans)

In [22]:
len(train_batch_conv)

5058

In [23]:
val_batch_conv, val_batch_emo, val_batch_start, val_batch_end = create_batches(val_target_utt, val_context, val_emo, val_cause_spans)

In [26]:
for param in model.parameters():
     param.requires_grad = True

NameError: name 'model' is not defined

# New Attempt

In [24]:
# Imports for most of the notebook
import torch
from typing import Dict, List
import random
from tqdm.autonotebook import tqdm

In [25]:
print(torch.cuda.is_available())
# device = torch.device("cpu")
# TODO: Uncomment the below line if you see True in the print statement
device = torch.device("cuda:0")

True


In [26]:
model = AutoModelForQuestionAnswering.from_pretrained("Nakul24/Spanbert-emotion-extraction")
model.resize_token_embeddings(len(tokenizer))
model.to(device).train()

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(29004, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=

In [27]:
# For making predictions at test time
def predict(model: torch.nn.Module, sents: torch.Tensor) -> List:
    logits = model(**sents)
    start_probs = logits.start_logits.softmax(dim=1)
    end_probs = logits.end_logits.softmax(dim=1)
    return start_probs.cpu().detach().numpy(), end_probs.cpu().detach().numpy()

In [28]:
import numpy as np
from numpy import sum as t_sum
from numpy import logical_and


def precision(predicted_labels, true_labels, which_label=1):
    """
    Precision is True Positives / All Positives Predictions
    """
    pred_which = np.array([pred == which_label for pred in predicted_labels])
    true_which = np.array([lab == which_label for lab in true_labels])
    denominator = t_sum(pred_which)
    if denominator:
        return t_sum(logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def recall(predicted_labels, true_labels, which_label=1):
    """
    Recall is True Positives / All Positive Labels
    """
    pred_which = np.array([pred == which_label for pred in predicted_labels])
    true_which = np.array([lab == which_label for lab in true_labels])
    denominator = t_sum(true_which)
    if denominator:
        return t_sum(logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def f1_score(
    predicted_labels: List[int],
    true_labels: List[int],
    which_label: int
):
    """
    F1 score is the harmonic mean of precision and recall
    """
    P = precision(predicted_labels, true_labels, which_label=which_label)
    R = recall(predicted_labels, true_labels, which_label=which_label)

    if P and R:
        return 2*P*R/(P+R)
    else:
        return 0.


def macro_f1(
    predicted_labels: List[int],
    true_labels: List[int],
    possible_labels: List[int],
    label_map=None
):
    converted_prediction = [label_map[int(x)] for x in predicted_labels] if label_map else predicted_labels
    scores = [f1_score(converted_prediction, true_labels, l) for l in possible_labels]
    # Macro, so we take the uniform avg.
    return sum(scores) / len(scores)

In [29]:
def training_loop(
    num_epochs,
    train_features,
    train_labels,
    train_span_starts,
    train_span_ends,
    dev_sents,
    dev_labels,
    dev_span_starts,
    dev_span_ends,
    optimizer,
    model,
):
    
    print("Training...")
    #weight = torch.tensor([2,3,3,2,0.5,2,2])
    # loss_func = torch.nn.NLLLoss(weight = torch.tensor([2.5,4,4,2,0.25,2.5,2]).to(device))
    loss_cross_ent = torch.nn.CrossEntropyLoss()
    batches = list(zip(train_features, train_labels, train_span_starts, train_span_ends))
    random.shuffle(batches)

    train_labels_all = list(torch.concat(list(zip(*batches))[1]).numpy())
    train_span_starts = list(zip(*batches))[2]
    train_span_ends = list(zip(*batches))[3]

    train_starts_all = torch.concat([x.view(-1) for x in train_span_starts]).numpy()
    train_ends_all = torch.concat([x.view(-1) for x in train_span_ends]).numpy()

    dev_span_starts_all = torch.concat([x.view(-1) for x in dev_span_starts]).numpy()
    dev_span_ends_all = torch.concat([x.view(-1) for x in dev_span_ends]).numpy()

    possible_labels = set(train_labels_all)

    for i in range(num_epochs):
        losses = []
        # train_preds = []
        train_start_probs = []
        train_end_probs = []
        for features, labels, span_starts, span_ends in tqdm(batches):
            # Empty the dynamic computation graph
            optimizer.zero_grad()
            preds = model(**features.to(device),
                         start_positions=torch.where(span_starts==1)[1].to(device), 
                         end_positions=torch.where(span_ends==1)[1].to(device))
            start_probs = preds.start_logits.softmax(dim=1)
            end_probs = preds.end_logits.softmax(dim=1)
            # preds, start_probs, end_probs = \
            #       preds.squeeze(1), start_probs.squeeze(2), end_probs.squeeze(2)
            # loss_emo = loss_func(preds, labels.to(device))
            # loss_start = loss_cross_ent(start_probs, span_starts.to(device))
            # loss_end = loss_cross_ent(end_probs, span_ends.to(device))

            loss = preds.loss

            # Backpropogate the loss through our model
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            # train_preds.append(preds)
            train_start_probs.append(start_probs)
            train_end_probs.append(end_probs)

        print(f"epoch {i}, loss: {sum(losses)/len(losses)}")
        print ("Evaluating train...")
        # train_preds = torch.concat(train_preds)
        # print (train_preds.shape)
        # train_preds = list(torch.argmax(train_preds, axis=1).squeeze().cpu().numpy())
        # train_labels_all = list(torch.concat(train_labels_all).cpu().numpy())
        # train_f1 = macro_f1(train_preds, train_labels_all, possible_labels)
        # print(f"Train F1 {train_f1}")

        train_start_probs = torch.concat([x.view(-1) for x in train_start_probs]).cpu().detach().numpy()
        train_start_auc = roc_auc_score(train_starts_all, train_start_probs)
        print(f"Train START AUC: {train_start_auc}")

        train_end_probs = torch.concat([x.view(-1) for x in train_end_probs]).cpu().detach().numpy()
        train_end_auc = roc_auc_score(train_ends_all, train_end_probs)
        print(f"Train END AUC: {train_end_auc}")


        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        # all_preds = []
        all_labels = []
        all_starts = []
        all_ends = []
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            # pred = predict(model, sents).cpu()
            start_probs, end_probs = predict(model, sents.to(device))
            # all_preds.extend(pred)
            # all_labels.extend(list(labels.cpu().numpy()))
            all_starts.append(start_probs)
            all_ends.append(end_probs)

        # dev_f1 = macro_f1(all_preds, all_labels, set(all_labels))
        # print(f"Dev F1 {dev_f1}")

        all_starts = np.concatenate([x.reshape(-1) for x in all_starts])
        dev_start_auc = roc_auc_score(dev_span_starts_all, all_starts)
        print(f"Dev START AUC: {dev_start_auc}")

        all_ends = np.concatenate([x.reshape(-1) for x in all_ends])
        dev_end_auc = roc_auc_score(dev_span_ends_all, all_ends)
        print(f"Dev END AUC: {dev_end_auc}")

    # Return the trained model
    return model

In [33]:
# You can increase epochs if need be
epochs = 10

# TODO: Find a good learning rate and hidden size
LR = 5e-5
hidden_size = 32

possible_labels = set(emo_list)
# model = NLIClassifier(output_size=len(possible_labels), hidden_size=hidden_size)
# model.bert.resize_token_embeddings(len(tokenizer))
# model.to(device)
# optimizer = torch.optim.AdamW([{'params': model.hidden_layer.parameters()},
#                                {'params': model.classifier.parameters()},
#                                {'params': model.bert.parameters(), 'lr': 0.000005}],
#                               lr=LR)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


# batch_tokenizer = BatchTokenizer()

# validation_input_batches = [b for b in chunk_multi(validation_premises, validation_hypotheses, batch_size)]

# Tokenize + encode
# validation_input_batches = [batch_tokenizer(*batch) for batch in validation_input_batches]
# validation_batch_labels = [b for b in chunk(validation_labels, batch_size)]
# validation_batch_labels = [encode_labels(batch) for batch in validation_batch_labels]

trained_model = training_loop(
    epochs,
    train_batch_conv,
    train_batch_emo,
    train_batch_start,
    train_batch_end,
    val_batch_conv,
    val_batch_emo,
    val_batch_start,
    val_batch_end,
    optimizer,
    model,
)

torch.save(model.state_dict(), './Emo_classifier_and_span_QA.pt')

Training...


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 0, loss: 2.9240074851570417
Evaluating train...
Train START AUC: 0.9811815998225388
Train END AUC: 0.981511118703256
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9815464114327338
Dev END AUC: 0.9818631903066006


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 1, loss: 0.588716798964829
Evaluating train...
Train START AUC: 0.9993297027204012
Train END AUC: 0.9993728502311088
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9998979710876108
Dev END AUC: 0.9999224097706342


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 2, loss: 0.21901927951613964
Evaluating train...
Train START AUC: 0.9999283078479304
Train END AUC: 0.999930162107753
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9998869434649074
Dev END AUC: 0.9999353662464454


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 3, loss: 0.15255171110206664
Evaluating train...
Train START AUC: 0.9999544398446708
Train END AUC: 0.9999683646579952
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9999266266622343
Dev END AUC: 0.9999429718031299


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 4, loss: 0.11909757941653759
Evaluating train...
Train START AUC: 0.9999787831804497
Train END AUC: 0.9999801971914074
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.999859495616807
Dev END AUC: 0.9999318738145336


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 5, loss: 0.1023054254299786
Evaluating train...
Train START AUC: 0.9999832353775413
Train END AUC: 0.9999833868637292
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.999864358626451
Dev END AUC: 0.9999159955484023


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 6, loss: 0.08424376451944056
Evaluating train...
Train START AUC: 0.9999853805199005
Train END AUC: 0.9999893807569434
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9998851278401489
Dev END AUC: 0.9999174000709389


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 7, loss: 0.06941495563205212
Evaluating train...
Train START AUC: 0.99998914281649
Train END AUC: 0.9999931793211534
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9999005258778267
Dev END AUC: 0.9997019432268133


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 8, loss: 0.06205525155744498
Evaluating train...
Train START AUC: 0.9999935899649774
Train END AUC: 0.9999941103777358
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9997325649354932
Dev END AUC: 0.9996577668638736


  0%|          | 0/5058 [00:00<?, ?it/s]

epoch 9, loss: 0.05702099177585144
Evaluating train...
Train START AUC: 0.9999941997992364
Train END AUC: 0.9999951745694953
Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

Dev START AUC: 0.9998645551550854
Dev END AUC: 0.9998685951520764


In [102]:
LR = 0.001
model.train()
optimizer = torch.optim.AdamW([{'params': model.hidden_layer.parameters()},
                               {'params': model.classifier.parameters()},
                               {'params': model.bert.parameters(), 'lr': 0.00001}],
                              lr=LR)

trained_model = training_loop(
    2,
    train_batch_conv,
    train_batch_emo,
    train_batch_start,
    train_batch_end,
    val_batch_conv,
    val_batch_emo,
    val_batch_start,
    val_batch_end,
    optimizer,
    model,
)

Training...


  0%|          | 0/681 [00:00<?, ?it/s]

epoch 0, loss: 8.60454346445251
Evaluating train...
Train F1 0.6993372004638662
Train START AUC: 0.8693739188981577
Train END AUC: 0.9129548856640031
Evaluating dev...


  0%|          | 0/171 [00:00<?, ?it/s]

Dev F1 0.414273435107559
Dev START AUC: 0.8675199591520296
Dev END AUC: 0.913759059779703


  0%|          | 0/681 [00:00<?, ?it/s]

epoch 1, loss: 8.484089928314024
Evaluating train...
Train F1 0.7478131279095076
Train START AUC: 0.8730592123899641
Train END AUC: 0.9160338390223043
Evaluating dev...


  0%|          | 0/171 [00:00<?, ?it/s]

Dev F1 0.4134414905028632
Dev START AUC: 0.865236830937049
Dev END AUC: 0.9158337494361248


In [35]:
torch.save(model.state_dict(), './Emo_classifier_and_span.pt')

In [180]:
tokenizer.save_pretrained('./Emo_classifier_and_span_QA.tok')

('./Emo_classifier_and_span_QA.tok/tokenizer_config.json',
 './Emo_classifier_and_span_QA.tok/special_tokens_map.json',
 './Emo_classifier_and_span_QA.tok/vocab.txt',
 './Emo_classifier_and_span_QA.tok/added_tokens.json',
 './Emo_classifier_and_span_QA.tok/tokenizer.json')

In [28]:
model = NLIClassifier(output_size=7, hidden_size=32)
model.bert.resize_token_embeddings(len(tokenizer))
model.to(device)
model.load_state_dict(torch.load('Emo_classifier_and_span.pt'))

Some weights of BertModel were not initialized from the model checkpoint at Nakul24/Spanbert-emotion-extraction and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [31]:
model.load_state_dict(torch.load('./Emo_classifier_and_span_QA.pt'))

<All keys matched successfully>

In [33]:
model.eval()

print("Evaluating dev...")
all_preds = []
all_labels = []
all_starts = []
all_ends = []
for sents, labels in tqdm(zip(val_batch_conv, val_batch_emo), total=len(val_batch_conv)):
    # pred = predict(model, sents).cpu()
    with torch.no_grad():
        start_probs, end_probs = predict(model, sents.to(device))
    # all_preds.extend(pred)
    all_labels.extend(list(labels.cpu().numpy()))
    all_starts.append(start_probs)
    all_ends.append(end_probs)

# dev_f1 = macro_f1(all_preds, all_labels, set(all_labels))
# print(f"Dev F1 {dev_f1}")

# all_starts = np.concatenate([x.reshape(-1) for x in all_starts])
# dev_start_auc = roc_auc_score(dev_span_starts_all, all_starts)
# print(f"Dev START AUC: {train_start_auc}")

# all_ends = np.concatenate([x.reshape(-1) for x in all_ends])
# dev_end_auc = roc_auc_score(dev_span_ends_all, all_ends)
# print(f"Dev END AUC: {dev_end_auc}")


Evaluating dev...


  0%|          | 0/1265 [00:00<?, ?it/s]

In [34]:
val_start_preds = [s.argmax(axis=1) for s in all_starts]
val_end_preds = [s.argmax(axis=1) for s in all_ends]

In [35]:
check_ind = 109

In [36]:
val_start_preds[check_ind], val_end_preds[check_ind]

(array([ 0,  0,  0,  0,  0, 32,  0,  0,  0, 27,  0,  0,  0,  0,  0,  0]),
 array([ 0,  0,  0,  0,  0, 57,  0,  0,  0, 35,  0,  0,  0,  0,  0,  0]))

In [37]:
torch.where(val_batch_start[check_ind]==1)[1], torch.where(val_batch_end[check_ind]==1)[1]

(tensor([ 0,  0,  0,  0,  0, 32,  0,  0,  0, 27,  0,  0,  0,  0,  0,  0]),
 tensor([ 0,  0,  0,  0,  0, 46,  0,  0,  0, 35,  0,  0,  0,  0,  0,  0]))

In [38]:
tokenizer.decode(val_batch_conv[check_ind]['input_ids'][5][32:57])

'he is about to go hit on isabella rosselini. i am just sorry we do not got popcorn'

In [39]:
tokenizer.decode(val_batch_conv[check_ind]['input_ids'][5][32:46])

'he is about to go hit on isabella rosselini'

In [40]:
tokenizer.decode(val_batch_conv[check_ind]['input_ids'][5])

'[CLS] honey, he is about to go hit on isabella rosselini. i am just sorry we do not got popcorn. [SEP] honey, he is about to go hit on isabella rosselini. i am just sorry we do not got popcorn. [SEP] [SP1] okay sir, um... mm, let see if i got this right. [SEP] [SP1] ah, so this is a half... caf, double tall, easy hazel nut, non... fat, no foam, with whip, extra hot latte, right? [SEP] [SP1] okay, great. [SEP] [SP1] you freak. [SEP] [SP2] thank you. [SEP] [SP3] a coffee to go, please. [SEP] [SP2] isabella rosselini. [SEP] [SP4] are you serious? oh my god. [SEP] [SP2] damn! i can not believe i took her off my list. [SEP] [SP4] why? cause otherwise you would go for it? [SEP] [SP2] yeah, maybe. [SEP] [SP1] oh... oh, you lie. [SEP] [SP2] what? you do not think i would go up to her? [SEP] [SP1] ross, it took you ten years to finally admit you liked me. [SEP] [SP2] yeah, well missy, you better be glad that list is laminated. [SEP] [SP1] you know what honey, you go ahead, we will call her an 

In [171]:
validation_predictions_data = {'val_batch_conv': val_batch_conv,
                                'val_batch_emo': val_batch_emo,
                                'val_batch_start': val_batch_start,
                                'val_batch_end': val_batch_end,
                                'val_start_preds': val_start_preds,
                                'val_start_ends': val_end_preds}

In [176]:
import pickle
with open('validation_prediction_data', 'wb') as f:
    pickle.dump(validation_predictions_data, f)

In [177]:
with open('validation_prediction_data', 'rb') as f:
    test = pickle.load(f)

In [106]:
def compute_overlap_strict_f1(true_start, true_end, pred_start, pred_end, return_true_counts=False):
    total_correct = 0.
    total_pred_cnt = 0.
    total_true_cnt = 0.
    for tst, tend, pst, pend in zip(true_start, true_end, pred_start, pred_end):
        total_correct += 1. if tst > 0 and tend > 0 and tst == pst and tend == pend else 0.
        total_pred_cnt += 1. if pst > 0 and pend > 0 else 0.
        total_true_cnt += 1. if tst > 0 and tend > 0 else 0.
        
    P = total_correct/total_pred_cnt
    R = total_correct/total_true_cnt
    F = (2 * P * R) / (P + R)
    return (F, P, R, total_true_cnt) if return_true_counts else (F, P, R)

In [103]:
def compute_overlap_f1(true_start, true_end, pred_start, pred_end):
    total_overlap = 0.
    total_pred_cnt = 0.
    total_true_cnt = 0.
    for tst, tend, pst, pend in zip(true_start, true_end, pred_start, pred_end):
        true_span = list(range(tst, tend + 1))
        pred_span = list(range(pst, pend + 1))
        
        if 0 in true_span: true_span.remove(0)
        if 0 in pred_span: pred_span.remove(0)
        
        overlap_cnt = len(set(true_span).intersection(pred_span))
        pred_cnt = len(pred_span)
        true_cnt = len(true_span)
        total_overlap += overlap_cnt
        total_pred_cnt += pred_cnt
        total_true_cnt += true_cnt
    P = total_overlap/total_pred_cnt
    R = total_overlap/total_true_cnt
    F = (2 * P * R) / (P + R)
    return F, P, R

In [50]:
val_start_preds_list = np.concatenate(val_start_preds).tolist()
val_end_preds_list = np.concatenate(val_end_preds).tolist()

In [56]:
val_start_true_list = torch.concat([torch.where(s==1)[1] for s in val_batch_start]).tolist()
val_end_true_list = torch.concat([torch.where(s==1)[1] for s in val_batch_end]).tolist()


In [85]:
## Proportional Span match F1 Score

f1, precision, recall = compute_overlap_f1(val_start_true_list, val_end_true_list, val_start_preds_list, val_end_preds_list)
print (f1, precision, recall)

0.5987520692728894 0.5769891523094292 0.6222210459453738


In [84]:
## Strict Span match F1 Score

f1, precision, recall = compute_overlap_strict_f1(val_start_true_list, val_end_true_list, val_start_preds_list, val_end_preds_list)
print (f1, precision, recall)

0.46940985381700057 0.46890210924824227 0.46991869918699186


In [88]:
emo_labels_list = torch.concat(val_batch_emo).tolist()

In [119]:
def weighted_f1(true_start, true_end, pred_start, pred_end, labels_list):
    f1_prop_list = []
    f1_strict_list = []
    cnt_list = []
    for label, numbers in zip(*np.unique(np.array(labels_list), return_counts=True)):
        vts = np.array(true_start)[labels_list==label].tolist()
        vte = np.array(true_end)[labels_list==label].tolist()
        vps = np.array(pred_start)[labels_list==label].tolist()
        vpe = np.array(pred_end)[labels_list==label].tolist()
        f1_prop, _, _ = compute_overlap_f1(vts, vte, vps, vpe)
        f1_strict, _, _, true_cnt = compute_overlap_strict_f1(vts, vte, vps, vpe, return_true_counts=True)

        f1_prop_list.append(f1_prop)
        f1_strict_list.append(f1_strict)
        cnt_list.append(true_cnt)

    total_cnt = sum(cnt_list)
    wt_list = [cnt/total_cnt for cnt in cnt_list]
    wt_prop_f1 = sum([f*wt for f, wt in zip(f1_prop_list, wt_list)])
    wt_strict_f1 = sum([f*wt for f, wt in zip(f1_strict_list, wt_list)])
    return wt_prop_f1, wt_strict_f1

In [120]:
wt_prop_f1, wt_strict_f1 = weughted_f1(val_start_true_list, 
                                       val_end_true_list, 
                                       val_start_preds_list, 
                                       val_end_preds_list,
                                       emo_labels_list)

print (wt_prop_f1, wt_strict_f1)

0.6009776165975743 0.46950670489160073
