In [1]:
import os
import json
import torch
import numpy as np
import glob

from collections import Counter
from collections import defaultdict
from sklearn.model_selection import train_test_split, KFold
from scipy.special import softmax
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from torch.nn import CrossEntropyLoss, LSTM
from transformers import DataCollatorForTokenClassification, RobertaPreTrainedModel, RobertaModel
from transformers.modeling_outputs import TokenClassifierOutput


from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
)

from torch.utils.data import Dataset
import nltk


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dir = "vua_dataset"
model_name = "roberta-base"

In [3]:
# Download NLTK punkt tokenizer data if you haven't already
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    print("NLTK 'punkt' tokenizer data not found. Downloading...")
    nltk.download('punkt', quiet=True)
    print("NLTK 'punkt' tokenizer data downloaded.")
except Exception as e:
    print(f"An unexpected error occurred during NLTK data check/download: {e}")


In [4]:
# if model_name is something like "roberta-base"
if "roberta" in model_name.lower():
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=True,
        add_prefix_space=True,  # required for pre-tokenized input with RoBERTa
    )
else:
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=True
    )


In [5]:
def load_and_process_data_with_all_features(json_path):
    """
    Loads raw data from a JSONL file, groups it by sentence,
    and processes it to include both POS and FGPOS tags.

    Args:
        json_path (str): The path to the JSONL data file.

    Returns:
        list: A list of dictionaries, each containing "sentence_words", "labels", 
              "pos_tags", and "fgpos_tags".
        set: A set of all unique POS tags.
        set: A set of all unique FGPOS tags.
    """
    data_raw = []
    with open(json_path, "r", encoding="utf-8") as f:
        for line in f:
            data_raw.append(json.loads(line))

    sentence_groups = defaultdict(list)
    for entry in data_raw:
        sentence_groups[entry["sentence"]].append(entry)

    processed_data = []
    all_pos_tags = set()
    all_fgpos_tags = set()
    for sentence, entries in sentence_groups.items():
        entries = sorted(entries, key=lambda x: x["w_index"])
        
        original_words = sentence.split(' ')
        words_for_model = [original_words[e['w_index']] for e in entries]
        
        current_labels = [entry["label"] for entry in entries]
        pos_tags_for_sentence = [entry["POS"] for entry in entries]
        fgpos_tags_for_sentence = [entry["FGPOS"] for entry in entries]
        
        all_pos_tags.update(pos_tags_for_sentence)
        all_fgpos_tags.update(fgpos_tags_for_sentence)

        processed_data.append({
            "sentence_words": words_for_model, 
            "labels": current_labels,
            "pos_tags": pos_tags_for_sentence,
            "fgpos_tags": fgpos_tags_for_sentence
        })

    return processed_data, all_pos_tags, all_fgpos_tags

In [6]:
# --- Load and process TRAIN data ---
train_json_path = os.path.join("vua_dataset", "vua20_metaphor_train.json")
processed_train_data, train_pos_tags, train_fgpos_tags = load_and_process_data_with_all_features(train_json_path)

# --- Load and process TEST data ---
test_json_path = os.path.join("vua_dataset", "vua20_metaphor_test.json")
processed_test_data, test_pos_tags, test_fgpos_tags = load_and_process_data_with_all_features(test_json_path)

# --- Create POS tag vocabulary ---
all_pos_tags = sorted(list(train_pos_tags.union(test_pos_tags)))
pos2id = {tag: i for i, tag in enumerate(all_pos_tags)}
pos_vocab_size = len(pos2id)

# --- Create FGPOS tag vocabulary ---
all_fgpos_tags = sorted(list(train_fgpos_tags.union(test_fgpos_tags)))
fgpos2id = {tag: i for i, tag in enumerate(all_fgpos_tags)}
fgpos_vocab_size = len(fgpos2id)

print(f"POS vocabulary size: {pos_vocab_size}")
print(f"FGPOS vocabulary size: {fgpos_vocab_size}")
print(f"Number of training samples: {len(processed_train_data)}")
print(f"Number of test samples: {len(processed_test_data)}")

POS vocabulary size: 17
FGPOS vocabulary size: 41
Number of training samples: 10909
Number of test samples: 3601


In [7]:
class MetaphorDatasetWithAllFeatures(Dataset):
    def __init__(self, data, pos2id, fgpos2id):
        self.data = data
        self.pos2id = pos2id
        self.fgpos2id = fgpos2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        sentence_words = entry["sentence_words"]
        word_labels = entry["labels"]
        word_pos_tags = entry["pos_tags"]
        word_fgpos_tags = entry["fgpos_tags"]

        raw_encoding = tokenizer(
            sentence_words,
            truncation=True,
            padding="max_length",
            max_length=128,
            is_split_into_words=True,
        )

        word_ids = raw_encoding.word_ids(batch_index=0)

        labels = []
        pos_ids = []
        fgpos_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                labels.append(-100)
                pos_ids.append(-100)
                fgpos_ids.append(-100)
            elif word_idx != previous_word_idx:
                labels.append(word_labels[word_idx])
                pos_ids.append(self.pos2id[word_pos_tags[word_idx]])
                fgpos_ids.append(self.fgpos2id[word_fgpos_tags[word_idx]])
            else:
                labels.append(-100)
                pos_ids.append(-100)
                fgpos_ids.append(-100)
            previous_word_idx = word_idx

        encoding = {k: torch.tensor(v).squeeze(0) for k, v in raw_encoding.items()}
        encoding["labels"] = torch.tensor(labels, dtype=torch.long)
        encoding["pos_tag_ids"] = torch.tensor(pos_ids, dtype=torch.long)
        encoding["fgpos_tag_ids"] = torch.tensor(fgpos_ids, dtype=torch.long)
        
        return encoding

In [8]:
train_dataset = MetaphorDatasetWithAllFeatures(processed_train_data, pos2id, fgpos2id)
test_dataset = MetaphorDatasetWithAllFeatures(processed_test_data, pos2id, fgpos2id)

In [9]:
class RobertaForTokenClassificationWithLSTM(RobertaPreTrainedModel):
    def __init__(self, config, pos_vocab_size, fgpos_vocab_size, pos_embedding_dim=50, fgpos_embedding_dim=50, lstm_hidden_size=128):
        super().__init__(config)
        self.num_labels = config.num_labels
        
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        
        self.pos_embedding = torch.nn.Embedding(pos_vocab_size, pos_embedding_dim)
        self.fgpos_embedding = torch.nn.Embedding(fgpos_vocab_size, fgpos_embedding_dim)
        
        # The input to the LSTM is the concatenation of RoBERTa's output and the feature embeddings
        lstm_input_size = config.hidden_size + pos_embedding_dim + fgpos_embedding_dim
        
        self.lstm = LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=1, # A single layer is often sufficient on top of RoBERTa
            batch_first=True,
            bidirectional=True
        )
        
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        # The classifier input is the output of the Bi-LSTM (hidden_size * 2 for bidirectional)
        self.classifier = torch.nn.Linear(lstm_hidden_size * 2, config.num_labels)

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        pos_tag_ids=None,
        fgpos_tag_ids=None,
        labels=None,
        **kwargs
    ):
        roberta_output = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        sequence_output = roberta_output[0]

        # --- Get POS embeddings ---
        pos_mask = pos_tag_ids != -100
        cloned_pos_tag_ids = pos_tag_ids.clone()
        cloned_pos_tag_ids[~pos_mask] = 0
        pos_embeddings = self.pos_embedding(cloned_pos_tag_ids)
        pos_embeddings[~pos_mask] = torch.zeros_like(pos_embeddings[~pos_mask])

        # --- Get FGPOS embeddings ---
        fgpos_mask = fgpos_tag_ids != -100
        cloned_fgpos_tag_ids = fgpos_tag_ids.clone()
        cloned_fgpos_tag_ids[~fgpos_mask] = 0
        fgpos_embeddings = self.fgpos_embedding(cloned_fgpos_tag_ids)
        fgpos_embeddings[~fgpos_mask] = torch.zeros_like(fgpos_embeddings[~fgpos_mask])

        # --- Combine embeddings ---
        combined_output = torch.cat([sequence_output, pos_embeddings, fgpos_embeddings], dim=-1)
        
        # --- Pass through Bi-LSTM ---
        lstm_output, _ = self.lstm(combined_output)
        
        # --- Final Classification ---
        dropped_output = self.dropout(lstm_output)
        logits = self.classifier(dropped_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=roberta_output.hidden_states,
            attentions=roberta_output.attentions,
        )

In [10]:
# --- Define compute_metrics function ---
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_labels = []
    predicted_labels = []
    for prediction, label in zip(predictions, labels):
        for p_val, l_val in zip(prediction, label):
            if l_val != -100:
                true_labels.append(l_val)
                predicted_labels.append(p_val)

    true_labels = np.array(true_labels)
    predicted_labels = np.array(predicted_labels)

    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average='binary', pos_label=1, zero_division=0
    )
    accuracy = accuracy_score(true_labels, predicted_labels)

    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

In [11]:
def get_class_weights(train_dataset):
    labels_list = [x['labels'].numpy() for x in train_dataset]
    labels_flat = np.concatenate(labels_list)
    labels_filtered = labels_flat[labels_flat != -100]
    counts = Counter(labels_filtered)
    
    if len(counts) < 2:
        return torch.tensor([1.0, 1.0], dtype=torch.float), counts, 0

    total = sum(counts.values())
    weight_0 = total / counts.get(0, 1)
    weight_1 = total / counts.get(1, 1)
    
    return torch.tensor(
        [weight_0, weight_1], dtype=torch.float
    ), counts, total

In [12]:
class WeightedLossTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        
        if self.class_weights is not None:
            logits = outputs.logits
            labels = inputs.get("labels")
            
            active_loss = labels.view(-1) != -100
            active_logits = logits.view(-1, model.config.num_labels)[active_loss]
            active_labels = labels.view(-1)[active_loss]

            weights = self.class_weights.to(logits.device)
            loss_fct = CrossEntropyLoss(weight=weights)
            loss = loss_fct(active_logits, active_labels)
            
            return (loss, outputs) if return_outputs else loss
        
        return outputs.loss

In [13]:
K = 5  # number of folds
kf = KFold(n_splits=K, shuffle=True, random_state=42)

fold_f1s = []
fold_precisions = []
fold_recalls = []
fold_losses = []

In [14]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

for fold_idx, (train_idx, val_idx) in enumerate(kf.split(processed_train_data)):
    print(f"\n=== Fold {fold_idx + 1}/{K} ===")
    
    train_split = [processed_train_data[i] for i in train_idx]
    val_split = [processed_train_data[i] for i in val_idx]

    train_dataset_fold = MetaphorDatasetWithAllFeatures(train_split, pos2id, fgpos2id)
    val_dataset_fold = MetaphorDatasetWithAllFeatures(val_split, pos2id, fgpos2id)

    class_weights, _, _ = get_class_weights(train_dataset_fold)

    # Instantiate the custom model for each fold
    model = RobertaForTokenClassificationWithLSTM.from_pretrained(
        model_name,
        num_labels=2,
        pos_vocab_size=pos_vocab_size,
        fgpos_vocab_size=fgpos_vocab_size,
        pos_embedding_dim=50,
        fgpos_embedding_dim=50,
        lstm_hidden_size=128
    )

    idx_folder = os.path.join('results_with_lstm', f'fold_{fold_idx + 1}')
    os.makedirs(idx_folder, exist_ok=True)
    
    training_args = TrainingArguments(
        output_dir=idx_folder,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="no",
        learning_rate=2e-5,
        per_device_train_batch_size=16, # May need to reduce if memory is an issue
        per_device_eval_batch_size=4,
        weight_decay=0.01,
        warmup_ratio=0.1,
        logging_steps=50,
        seed=42 + fold_idx,
    )

    trainer = WeightedLossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset_fold,
        eval_dataset=val_dataset_fold,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
        class_weights=class_weights,
    )

    trainer.train()
    metrics = trainer.evaluate()

    fold_f1s.append(metrics["eval_f1"])
    fold_precisions.append(metrics["eval_precision"])
    fold_recalls.append(metrics["eval_recall"])
    fold_losses.append(metrics["eval_loss"])

    trainer.save_model(idx_folder)
    print(f"Saved model for fold {fold_idx + 1} to {idx_folder}")

# Aggregate results
mean_f1 = np.mean(fold_f1s)
std_f1 = np.std(fold_f1s)
mean_precision = np.mean(fold_precisions)
mean_recall = np.mean(fold_recalls)
mean_loss = np.mean(fold_losses)

print(f"\nCross-validated results over {K} folds (with all features):")
print(f"F1: {mean_f1:.4f} ± {std_f1:.4f}")
print(f"Precision: {mean_precision:.4f}")
print(f"Recall: {mean_recall:.4f}")
print(f"Validation loss (mean): {mean_loss:.4f}")


=== Fold 1/5 ===


Some weights of RobertaForTokenClassificationWithLSTM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'fgpos_embedding.weight', 'lstm.bias_hh_l0', 'lstm.bias_hh_l0_reverse', 'lstm.bias_ih_l0', 'lstm.bias_ih_l0_reverse', 'lstm.weight_hh_l0', 'lstm.weight_hh_l0_reverse', 'lstm.weight_ih_l0', 'lstm.weight_ih_l0_reverse', 'pos_embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.3023,0.278145,0.858704,0.612979,0.459194,0.921638
2,0.2243,0.235512,0.907838,0.69973,0.578825,0.884479
3,0.1848,0.253156,0.92214,0.72694,0.63299,0.85364


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Saved model for fold 1 to results_with_lstm\fold_1

=== Fold 2/5 ===


Some weights of RobertaForTokenClassificationWithLSTM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'fgpos_embedding.weight', 'lstm.bias_hh_l0', 'lstm.bias_hh_l0_reverse', 'lstm.bias_ih_l0', 'lstm.bias_ih_l0_reverse', 'lstm.weight_hh_l0', 'lstm.weight_hh_l0_reverse', 'lstm.weight_ih_l0', 'lstm.weight_ih_l0_reverse', 'pos_embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.297,0.259239,0.901767,0.68097,0.559711,0.869299
2,0.2122,0.231667,0.917683,0.719447,0.610775,0.875159
3,0.1791,0.23263,0.921002,0.727331,0.623001,0.873631


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Saved model for fold 2 to results_with_lstm\fold_2

=== Fold 3/5 ===


Some weights of RobertaForTokenClassificationWithLSTM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'fgpos_embedding.weight', 'lstm.bias_hh_l0', 'lstm.bias_hh_l0_reverse', 'lstm.bias_ih_l0', 'lstm.bias_ih_l0_reverse', 'lstm.weight_hh_l0', 'lstm.weight_hh_l0_reverse', 'lstm.weight_ih_l0', 'lstm.weight_ih_l0_reverse', 'pos_embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.3014,0.272702,0.894925,0.670167,0.541612,0.87874
2,0.2356,0.255982,0.909977,0.702202,0.58698,0.873709
3,0.1742,0.266278,0.918661,0.720583,0.618316,0.863384


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Saved model for fold 3 to results_with_lstm\fold_3

=== Fold 4/5 ===


Some weights of RobertaForTokenClassificationWithLSTM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'fgpos_embedding.weight', 'lstm.bias_hh_l0', 'lstm.bias_hh_l0_reverse', 'lstm.bias_ih_l0', 'lstm.bias_ih_l0_reverse', 'lstm.weight_hh_l0', 'lstm.weight_hh_l0_reverse', 'lstm.weight_ih_l0', 'lstm.weight_ih_l0_reverse', 'pos_embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.3149,0.26809,0.897267,0.661495,0.535109,0.866044
2,0.2371,0.233916,0.911998,0.702068,0.577734,0.894592
3,0.1774,0.236368,0.921097,0.722905,0.609572,0.888004


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Saved model for fold 4 to results_with_lstm\fold_4

=== Fold 5/5 ===


Some weights of RobertaForTokenClassificationWithLSTM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'fgpos_embedding.weight', 'lstm.bias_hh_l0', 'lstm.bias_hh_l0_reverse', 'lstm.bias_ih_l0', 'lstm.bias_ih_l0_reverse', 'lstm.weight_hh_l0', 'lstm.weight_hh_l0_reverse', 'lstm.weight_ih_l0', 'lstm.weight_ih_l0_reverse', 'pos_embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.3124,0.262273,0.893728,0.666537,0.533727,0.887339
2,0.2289,0.240939,0.920048,0.724854,0.61629,0.879845
3,0.2059,0.2631,0.926976,0.738856,0.64591,0.863049


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


Saved model for fold 5 to results_with_lstm\fold_5

Cross-validated results over 5 folds (with all features):
F1: 0.7273 ± 0.0063
Precision: 0.6260
Recall: 0.8683
Validation loss (mean): 0.2503


In [15]:
# --- ENSEMBLE EVALUATION ---

# Load all fold models
model_dirs = sorted(glob.glob(os.path.join("results_with_lstm", "fold_*")))
models = []
for d in model_dirs:
    if os.path.exists(os.path.join(d, "pytorch_model.bin")) or os.path.exists(os.path.join(d, "model.safetensors")):
        model = RobertaForTokenClassificationWithLSTM.from_pretrained(
            d,
            pos_vocab_size=pos_vocab_size,
            fgpos_vocab_size=fgpos_vocab_size
        )
        models.append(model)
    else:
        print(f"Warning: Model not found in {d}, skipping.")

print(f"Loaded {len(models)} models for ensemble prediction.")

# Create a dummy trainer for prediction
if models:
    args = TrainingArguments(output_dir="./inference_tmp_lstm", per_device_eval_batch_size=8)
    predictor = Trainer(model=models[0], args=args, data_collator=data_collator)

# Get predictions
per_model_logits = []
for model in models:
    predictor.model = model.to(predictor.args.device)
    pred_out = predictor.predict(test_dataset)
    per_model_logits.append(pred_out.predictions)

per_model_logits = np.stack(per_model_logits, axis=0)

# --- Analysis by Adjusting Majority Vote Threshold ---
n_models = per_model_logits.shape[0]
per_model_preds = np.argmax(per_model_logits, axis=-1)

labels = np.stack([item['labels'].numpy() for item in test_dataset])
mask = labels != -100
y_true = labels[mask]

print("\n--- Evaluating Ensemble Performance by Adjusting Vote Count ---")
print(f"Required Votes | Precision | Recall    | F1-Score  | Accuracy")
print("---------------------------------------------------------------")

for required_votes in range(int(n_models / 2) + 1, n_models + 1):
    vote_sum = per_model_preds.sum(axis=0)
    y_pred_at_threshold = (vote_sum[mask] >= required_votes).astype(int)

    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred_at_threshold, average="binary", pos_label=1, zero_division=0
    )
    acc = accuracy_score(y_true, y_pred_at_threshold)

    print(f"{required_votes} of {n_models}      | {prec:<9.4f} | {rec:<9.4f} | {f1:<9.4f} | {acc:<9.4f}")

Loaded 5 models for ensemble prediction.


  return forward_call(*args, **kwargs)



--- Evaluating Ensemble Performance by Adjusting Vote Count ---
Required Votes | Precision | Recall    | F1-Score  | Accuracy
---------------------------------------------------------------
3 of 5      | 0.5352    | 0.7361    | 0.6198    | 0.8380   
4 of 5      | 0.5615    | 0.6806    | 0.6153    | 0.8474   
5 of 5      | 0.6004    | 0.6015    | 0.6009    | 0.8567   
