In [2]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.preprocessing import MultiLabelBinarizer
import torch
import torch.nn as nn
from transformers import BertModel

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')


# Custom multi-task BERT model
class MultiTaskBERT(nn.Module):
    def __init__(self, pretrained_model_name, num_main_roles=3, num_fine_roles=22):
        super(MultiTaskBERT, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)

        # Main role classification head (single-label classification)
        self.main_role_classifier = nn.Linear(self.bert.config.hidden_size, num_main_roles)

        # Fine-grained role classification head (multi-label classification)
        self.fine_role_classifier = nn.Linear(self.bert.config.hidden_size, num_fine_roles)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # Use the pooled output from BERT

        # Predict main role and fine-grained roles
        main_role_logits = self.main_role_classifier(pooled_output)
        fine_role_logits = self.fine_role_classifier(pooled_output)

        return main_role_logits, fine_role_logits

# Instantiate the modified model
model = MultiTaskBERT('bert-base-multilingual-cased', num_main_roles=3, num_fine_roles=22)


# Custom function to handle multiple fine-grained roles
def load_annotations(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split('\t')
            # Handle cases where there are multiple fine-grained roles
            article_id, entity, start, end, main_role = fields[:5]
            fine_roles = fields[5:]  # Collect remaining fields as fine-grained roles
            data.append([article_id, entity, int(start), int(end), main_role, fine_roles])

    # Convert to DataFrame
    annotations_df = pd.DataFrame(data, columns=["article_id", "entity", "start", "end", "main_role", "fine_roles"])
    return annotations_df

# Load the annotations with the custom loader
annotations = load_annotations("EN/annotations/subtask-1-annotations.txt")

# Display a few rows to confirm it loaded correctly
print(annotations.head())

# Step 2: Read all articles dynamically from the folder
def load_all_articles(raw_documents_folder):
    articles = {}
    for filename in os.listdir(raw_documents_folder):
        if filename.endswith(".txt"):
            article_id = filename.split('.')[0]
            with open(os.path.join(raw_documents_folder, filename), 'r', encoding='utf-8') as f:
                articles[article_id] = f.read()
    return articles

raw_documents_folder = "EN/raw-documents"  # Adjust path as needed
articles = load_all_articles(raw_documents_folder)

# Step 3: Handle multiple fine-grained roles and tokenize data
def preprocess_data(annotations, articles):
    data = []
    mlb = MultiLabelBinarizer()  # Multi-label encoder for fine-grained roles

    # Fit the MultiLabelBinarizer on all fine roles
    all_roles = [role for roles in annotations['fine_roles'] for role in roles]
    mlb.fit([all_roles])

    for _, row in annotations.iterrows():
        article_id = row['article_id'].split('.')[0]
        if article_id in articles:
            text = articles[article_id]
            entity_text = text[int(row['start']):int(row['end'])]
            marked_text = (
                text[:int(row['start'])] + "[ENTITY]" + entity_text + "[/ENTITY]" + text[int(row['end']):]
            )
            inputs = tokenizer(marked_text, padding='max_length', max_length=512, truncation=True, return_tensors="pt")

            # Encode main role and fine-grained roles
            main_role_label = {"Protagonist": 0, "Antagonist": 1, "Innocent": 2}[row['main_role']]
            fine_role_labels = mlb.transform([row['fine_roles']])[0]  # Use the list directly

            data.append((inputs, main_role_label, fine_role_labels))

    return data

train_data = preprocess_data(annotations, articles)


# Step 4: Create a Dataset class for PyTorch
class EntityFramingDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        inputs, main_label, fine_labels = self.data[idx]
        return (
            inputs['input_ids'].squeeze(0),
            inputs['attention_mask'].squeeze(0),
            torch.tensor(main_label),
            torch.tensor(fine_labels, dtype=torch.float)
        )


# Step 5: DataLoader for batching
from sklearn.model_selection import train_test_split

# Split data into training and validation sets
train_data, val_data = train_test_split(train_data, test_size=0.25)

# Create DataLoaders for training and validation
train_dataset = EntityFramingDataset(train_data)
val_dataset = EntityFramingDataset(val_data)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Optimizer and loss functions
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
loss_fn_main = nn.CrossEntropyLoss()
loss_fn_fine = nn.BCEWithLogitsLoss()

# Training loop with validation and gradient monitoring
for epoch in range(1):
    model.train()
    total_loss = 0

    print(f"\nStarting Epoch {epoch + 1}")

    # Training step
    for step, (input_ids, attention_mask, main_label, fine_labels) in enumerate(train_loader):

        optimizer.zero_grad()
        main_role_logits, fine_role_logits = model(input_ids, attention_mask)

        # Compute losses
        loss_main = loss_fn_main(main_role_logits, main_label)
        loss_fine = loss_fn_fine(fine_role_logits, fine_labels)

        loss = loss_main + loss_fine
        loss.backward()

        # # Check for vanishing or exploding gradients
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"{name} - Grad Norm: {param.grad.norm().item()}")

        optimizer.step()
        total_loss += loss.item()

        # Log every 10 steps
        if step % 10 == 0:
            print(f"Epoch {epoch + 1}, Step {step}, Loss: {loss.item()}")

    print(f"Epoch {epoch + 1} completed. Average Training Loss: {total_loss / len(train_loader)}")

    # Validation step
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for input_ids, attention_mask, main_label, fine_labels in val_loader:

            main_role_logits, fine_role_logits = model(input_ids, attention_mask)

            # Compute validation losses
            loss_main = loss_fn_main(main_role_logits, main_label)
            loss_fine = loss_fn_fine(fine_role_logits, fine_labels)

            val_loss += (loss_main + loss_fine).item()

    print(f"Validation Loss after Epoch {epoch + 1}: {val_loss / len(val_loader)}")



# Save the fine-tuned model
model.save_pretrained("mbert_entity_framing")
tokenizer.save_pretrained("mbert_entity_framing")

         article_id        entity  start   end    main_role  \
0  EN_UA_103861.txt       Chinese    791   797   Antagonist   
1  EN_UA_103861.txt         China   1516  1520   Antagonist   
2  EN_UA_103861.txt         Hamas   2121  2125   Antagonist   
3  EN_UA_103861.txt  Donald Trump   4909  4920  Protagonist   
4  EN_UA_021270.txt        Yermak    667   672   Antagonist   

               fine_roles  
0                   [Spy]  
1            [Instigator]  
2             [Terrorist]  
3  [Peacemaker, Guardian]  
4           [Incompetent]  

Starting Epoch 1
Epoch 1, Step 0, Loss: 1.9463112354278564
Epoch 1, Step 10, Loss: 2.162278890609741
Epoch 1, Step 20, Loss: 1.2604279518127441
Epoch 1, Step 30, Loss: 1.070824146270752
Epoch 1 completed. Average Training Loss: 1.4355713740373268
Validation Loss after Epoch 1: 1.0696790694044187

Starting Epoch 2
Epoch 2, Step 0, Loss: 0.7335984110832214
Epoch 2, Step 10, Loss: 1.1520800590515137
Epoch 2, Step 20, Loss: 0.8655332326889038
Epoch 2, 

KeyboardInterrupt: 