In [None]:
!pip install --upgrade transformers torch transformers[torch] tokenizers huggingface_hub pytorch-crf
!pip install protobuf==3.20.3

In [26]:
import torch
torch.cuda.empty_cache()

assert torch.cuda.is_available()

In [27]:
device_name = torch.cuda.get_device_name()
n_gpu = torch.cuda.device_count()
print(f"Found device: {device_name}, n_gpu: {n_gpu}")
device = torch.device("cuda")

Found device: NVIDIA A100-SXM4-80GB MIG 3g.40gb, n_gpu: 1


In [28]:
import random
import numpy as np

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_everything()

In [29]:
import pandas as pd

def read_conll(file_path):
    sentences = []
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line:
                columns = line.split()
                sentences.append([columns[0], columns[-1]])

    df = pd.DataFrame(sentences, columns=['Tokens', 'Labels'])
    return df

In [30]:
def tokenize_and_format(words, tokenizer):
    """
    Parameters:
    words: words to be tokenized and formatted
    tokenizer: tokenizer with which the words will be tokenized
    """
    input_ids = []
    attention_masks = []

    # encode each word by padding and truncating it
    for word in words:
        encoded_dict = tokenizer.encode_plus(
            word,
            add_special_tokens=False,
            max_length=11,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # store the input id and the attention mask of this word
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    # convert input ids and attention_masks to tensors
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)

    return input_ids, attention_masks

In [31]:
from transformers import AutoTokenizer

model_name = 'worldbank/econberta'
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)



In [32]:
label_dict = {
    'O': 0,
    'B-intervention': 1,
    'I-intervention': 2,
    'B-outcome': 3,
    'I-outcome': 4,
    'B-population': 5,
    'I-population': 6,
    'B-effect_size': 7,
    'I-effect_size': 8,
    'B-coreference': 9,
    'I-coreference': 10
}

In [33]:
seed_everything()

def get_dataset(df):

    # get words and labels form the df
    words = df.Tokens.values
    labels = df.Labels.values

    # get input ids and attention masks from words by tokenizing
    input_ids, attention_masks = tokenize_and_format(words, tokenizer)

    # Create an empty label list to store one-hot encoded labels
    label_list = []

    # Iterate through each label and one-hot encode it
    for label in labels:
        # create a list of zeros the length of unique labels
        label_array = np.zeros(len(label_dict))
        # set 1 only at the position where the label matches, encoding the label
        label_array[label_dict[label]] = 1
        # store this one-hot encoded label array to the list where all encoded labels are getting stored
        label_list.append(label_array)

    # convert label_list to a tensor
    labels = torch.tensor(np.array(label_list))

    dataset = [(input_ids[i], attention_masks[i], labels[i]) for i in range(len(df))]

    return dataset, words

In [34]:
seed_everything()

train_df = read_conll('data/econ_ie/train.conll')
val_df = read_conll('data/econ_ie/dev.conll')
test_df = read_conll('data/econ_ie/test.conll')

train_set, train_words = get_dataset(train_df)
val_set, val_words = get_dataset(val_df)
test_set, test_words = get_dataset(test_df)

In [35]:
# Set the hyperparameters according to Table 8
dropout = 0.2
learning_rates = [5e-5, 6e-5, 7e-5]  # Perform hyperparameter search
batch_size = 12
gradient_accumulation_steps = 4
weight_decay = 0
max_epochs = 10
lr_decay = "slanted_triangular"
fraction_of_steps = 0.06
adam_epsilon = 1e-8
adam_beta1 = 0.9
adam_beta2 = 0.999

seed_everything()

In [36]:
def preprocess_entities(labels, words):
    entities = []
    current_entity = None
    for i, (word, label) in enumerate(zip(words, labels)):
        if label != "O":
            prefix, entity_type = label.split("-")
            if prefix == "B":
                if current_entity:
                    entities.append(current_entity)
                current_entity = (entity_type, i, i, words[i])
            elif prefix == "I" and current_entity:
                current_entity = (current_entity[0], current_entity[1], i, current_entity[3] + " " + words[i])
            else:
                # Handle invalid sequences (e.g., I- without B-)
                current_entity = None
        else:
            if current_entity:
                entities.append(current_entity)
            current_entity = None
    if current_entity:
        entities.append(current_entity)
    return entities

In [37]:
# Function to compute entity-level metrics
def compute_entity_level_metrics(true_entities, pred_entities):
    metrics = {
        "EM": 0,  # Exact Match
        "EB": 0,  # Exact Boundary
        "PM": 0,  # Partial Match
        "PB": 0,  # Partial Boundaries
        "ML": 0,  # Missed Label
        "FA": 0   # False Alarm 
    }

    for true_entity, pred_entity in zip(true_entities, pred_entities):
        if true_entity == pred_entity:
            metrics["EM"] += 1
        elif true_entity[0] == pred_entity[0] and is_overlapping(true_entity[1:], pred_entity[1:]):
            if true_entity[1] == pred_entity[1] and true_entity[2] == pred_entity[2]:
                metrics["EB"] += 1
            else:
                metrics["PM"] += 1
        elif is_overlapping(true_entity[1:], pred_entity[1:]):
            metrics["PB"] += 1
        elif pred_entity[0] == "O": 
            metrics["ML"] += 1

    for pred_entity in pred_entities:
        if pred_entity != "O" and pred_entity not in true_entities:  # Exclude "O" for False Alarm
            metrics["FA"] += 1

    return metrics

In [38]:
def is_overlapping(span1, span2):
    entity_type1, start1, entity_text1 = span1
    entity_type2, start2, entity_text2 = span2
    end1 = start1 + len(entity_text1) - 1  # Calculate end position
    end2 = start2 + len(entity_text2) - 1
    return start1 <= end2 and start2 <= end1

In [39]:
from collections import defaultdict
reverse_label_dict = {v: k for k, v in label_dict.items()}

def analyze_generalization(data, words, train_words):
    grouped_entities = defaultdict(lambda: ([], []))  # {group_name: (true_entities, pred_entities)}

    for i, (input_ids, attention_mask, label_tensor) in enumerate(data):
        true_labels = [reverse_label_dict[int(l.item())] for l in label_tensor]  # Use reverse_label_dict
        
        input_ids = input_ids.unsqueeze(0).to(device)
        attention_mask = attention_mask.unsqueeze(0).to(device)
        
        pred_labels = model(input_ids, attention_mask)[0] # No need for [0][0] as labels are already decoded
        pred_labels = [reverse_label_dict[l] for l in pred_labels]  # Use reverse_label_dict

        true_entities = preprocess_entities(true_labels, words[i])
        pred_entities = preprocess_entities(pred_labels, words[i])

        for true_entity, pred_entity in zip(true_entities, pred_entities):
            length = true_entity[2] - true_entity[1] + 1
            seen = true_entity[3] in train_words  # Check if entity text was seen in training

            group_name = f"Length {length} - {'Seen' if seen else 'Unseen'}"
            grouped_entities[group_name][0].append(true_entity)
            grouped_entities[group_name][1].append(pred_entity)

    for group_name, group_data in grouped_entities.items():
        group_true_entities, group_pred_entities = group_data
        metrics = compute_entity_level_metrics(group_true_entities, group_pred_entities)
        print(f"Group: {group_name}, Metrics: {metrics}")

In [40]:
from sklearn.metrics import classification_report

def get_validation_performance(val_set):
    # Put the model in evaluation mode
    model.eval()

    # Tracking variables
    total_eval_loss = 0
    all_pred_labels = []
    all_true_labels = []

    num_batches = int(len(val_set)/batch_size) + 1

    for i in range(num_batches):
        end_index = min(batch_size * (i+1), len(val_set))
        batch = val_set[i*batch_size:end_index]

        if len(batch) == 0: continue

        input_id_tensors = torch.stack([data[0] for data in batch])
        input_mask_tensors = torch.stack([data[1] for data in batch])
        label_tensors = torch.stack([data[2] for data in batch])

        # Move tensors to the GPU
        b_input_ids = input_id_tensors.to(device)
        b_input_mask = input_mask_tensors.to(device)
        b_labels = label_tensors.to(device)
        b_labels = b_labels.long()

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():
            # Forward pass, calculate logit predictions.
            outputs = model(b_input_ids,
                            attention_mask=b_input_mask,
                            labels=b_labels)
            loss = outputs['loss']
            logits = outputs['logits']

            # Accumulate the validation loss.
            total_eval_loss += loss.item()

            # Move logits and labels to CPU
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()

            # Get the predicted labels
            # Get the predicted labels
            pred_labels = np.argmax(logits, axis=2).flatten()
            true_labels = label_ids.flatten()

            all_pred_labels.extend(pred_labels)
            all_true_labels.extend(true_labels)
    
    # Convert labels to their original names
    label_names = list(label_dict.keys())
    all_pred_labels = [label_names[label] for label in all_pred_labels]
    all_true_labels = [label_names[label] for label in all_true_labels]

    # Calculate precision, recall, and F1 score
    report = classification_report(all_true_labels, all_pred_labels, digits=4)

    return report

In [41]:
from torchcrf import CRF
from transformers import AutoModel

class CRFTagger(torch.nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss = -self.crf(logits, labels, mask=attention_mask.byte(), reduction='mean')
            return {'loss': loss, 'logits': logits}  # Return a dictionary with loss and logits
        else:
            decoded_labels = self.crf.decode(logits, mask=attention_mask.byte())
            return decoded_labels

In [42]:
seed_everything()

# Load the pre-trained model
model = CRFTagger(model_name, len(label_dict))
model.dropout = torch.nn.Dropout(dropout)
model.to(device)

CRFTagger(
  (bert): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
              (dropou

In [None]:
from transformers import get_linear_schedule_with_warmup, AdamW

# Calculate the total number of training steps
total_steps = len(train_set) // batch_size * max_epochs

lr = learning_rates[0]
print(f"Current learning rate: {lr}")

# Create the optimizer with the specified hyperparameters
optimizer = AdamW(model.parameters(), lr=lr, eps=adam_epsilon, betas=(adam_beta1, adam_beta2), weight_decay=weight_decay)

# Create the learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(total_steps * fraction_of_steps), num_training_steps=total_steps)

# training loop
for epoch_i in range(max_epochs):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, max_epochs))
    print('Training...')

    total_train_loss = 0
    model.train()

    num_batches = int(len(train_set) / batch_size) + 1

    for i in range(num_batches):
        end_index = min(batch_size * (i + 1), len(train_set))
        batch = train_set[i * batch_size:end_index]

        if len(batch) == 0:
            continue

        input_id_tensors = torch.stack([data[0] for data in batch])
        input_mask_tensors = torch.stack([data[1] for data in batch])
        label_tensors = torch.stack([data[2] for data in batch])

        b_input_ids = input_id_tensors.to(device)
        b_input_mask = input_mask_tensors.to(device)
        b_labels = label_tensors.to(device)
        b_labels = b_labels.long()

        outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs['loss']
        logits = outputs['logits']

        total_train_loss += loss.item()

        # Accumulate gradients
        loss = loss / gradient_accumulation_steps
        loss.backward()

        # Perform optimizer step after accumulating gradients for gradient_accumulation_steps
        if (i + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    print(f"Total loss: {total_train_loss}")
    report = get_validation_performance(val_set)
    print(report)
    analyze_generalization(val_set, val_words, train_words)

print("")
print(f"Training complete at learning rate: {lr}!")

torch.save(model.state_dict(), f'{model_name}-model_lr-{lr}_2.pth')

Current learning rate: 5e-05

Training...




Total loss: 8722.56406621635


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                precision    recall  f1-score   support

B-intervention     0.8869    0.5944    0.7118     23616
     B-outcome     0.0000    0.0000    0.0000         0
             O     0.7082    0.0285    0.0548    236160

      accuracy                         0.0800    259776
     macro avg     0.5317    0.2076    0.2555    259776
  weighted avg     0.7245    0.0800    0.1145    259776

Group: Length 1 - Seen, Metrics: {'EM': 11216, 'EB': 0, 'PM': 0, 'PB': 0, 'ML': 0, 'FA': 113}
Group: Length 1 - Unseen, Metrics: {'EM': 2860, 'EB': 0, 'PM': 0, 'PB': 0, 'ML': 0, 'FA': 115}

Training...
Total loss: 5074.142737869173


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                precision    recall  f1-score   support

B-intervention     0.8767    0.6323    0.7347     23616
     B-outcome     0.0000    0.0000    0.0000         0
             O     0.7737    0.0272    0.0525    236160

      accuracy                         0.0822    259776
     macro avg     0.5501    0.2198    0.2624    259776
  weighted avg     0.7830    0.0822    0.1146    259776

Group: Length 1 - Seen, Metrics: {'EM': 12063, 'EB': 0, 'PM': 0, 'PB': 0, 'ML': 0, 'FA': 136}
Group: Length 1 - Unseen, Metrics: {'EM': 3020, 'EB': 0, 'PM': 0, 'PB': 0, 'ML': 0, 'FA': 141}

Training...
Total loss: 4693.233993748203
                precision    recall  f1-score   support

B-intervention     0.0852    0.9043    0.1557     23616
             O     0.7487    0.0285    0.0549    236160

      accuracy                         0.1081    259776
     macro avg     0.4169    0.4664    0.1053    259776
  weighted avg     0.6884    0.1081    0.0641    259776

Group: Length 1 - Seen, Metrics: {

In [None]:
model_name = model_name
lr = learning_rates[0]

# Load state_dict of the model
model.load_state_dict(torch.load(f'{model_name}-model_lr-{lr}_1.pth'))

get_validation_performance(test_set)
analyze_generalization(test_set, test_words, train_words)