<a href="https://colab.research.google.com/github/aixiuxiuxiu/long-context-classification/blob/main/longcontext_blog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np
import math
from typing import List, Dict, Optional, Tuple
import logging
from tqdm import tqdm
from dataclasses import dataclass
from datasets import load_dataset

In [3]:
class MultiEurLexSegmentedDataset(Dataset):
    def __init__(self, data, tokenizer, num_labels, max_segments=64, max_segment_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_segments = max_segments
        self.max_segment_length = max_segment_length
        self._num_labels = num_labels


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        labels = torch.zeros(self._num_labels)  # Assuming 50 possible labels
        labels[item['labels']] = 1. # creating one-hot vector for labels

        # Split text into words
        words = text.split()
        segments = []
        current_segment = []
        current_length = 0

        # Create segments
        for word in words:
            word_tokens = self.tokenizer.tokenize(word)
            if current_length + len(word_tokens) > self.max_segment_length - 2:  # Account for [CLS] and [SEP]
                if current_segment:
                    segments.append(' '.join(current_segment))
                current_segment = [word]
                current_length = len(word_tokens)
            else:
                current_segment.append(word)
                current_length += len(word_tokens)

        if current_segment:
            segments.append(' '.join(current_segment))

        # Pad or truncate segments
        segments = segments[:self.max_segments]
        segments = segments + [''] * (self.max_segments - len(segments))

        # Tokenize all segments
        tokenized_segments = []
        attention_masks = []

        for segment in segments:
            if segment:
                tokens = self.tokenizer(
                    segment,
                    max_length=self.max_segment_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
            else:
                tokens = self.tokenizer(
                    '[PAD]',
                    max_length=self.max_segment_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )

            tokenized_segments.append(tokens['input_ids'].squeeze(0))
            attention_masks.append(tokens['attention_mask'].squeeze(0))

        return {
            'input_ids': torch.stack(tokenized_segments),
            'attention_mask': torch.stack(attention_masks),
            'labels': labels
        }

In [4]:

@dataclass
class SimpleOutput:
    last_hidden_state: torch.Tensor
    hidden_states: torch.Tensor

def sinusoidal_init(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
    """Initialize sinusoidal positional embeddings."""
    position = torch.arange(0, num_embeddings).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim))
    pos_embedding = torch.zeros(num_embeddings, embedding_dim)
    pos_embedding[:, 0::2] = torch.sin(position * div_term)
    pos_embedding[:, 1::2] = torch.cos(position * div_term)
    return pos_embedding

class HierarchicalBert(nn.Module):
    def __init__(self, encoder, max_segments=64, max_segment_length=128):
        super(HierarchicalBert, self).__init__()

        # Pre-trained segment (token-wise) encoder, e.g., BERT
        self.encoder = encoder
        # Specs for the segment-wise encoder
        self.hidden_size = encoder.config.hidden_size
        self.max_segments = max_segments
        self.max_segment_length = max_segment_length

        # Init sinusoidal positional embeddings
        self.seg_pos_embeddings = nn.Embedding(
            max_segments + 1,
            encoder.config.hidden_size,
            padding_idx=0,
            _weight=sinusoidal_init(max_segments + 1, encoder.config.hidden_size)
        )

        # Init segment-wise transformer-based encoder
        self.seg_encoder = nn.Transformer(
            d_model=encoder.config.hidden_size,
            nhead=encoder.config.num_attention_heads,
            batch_first=True,
            dim_feedforward=encoder.config.intermediate_size,
            activation=encoder.config.hidden_act,
            dropout=encoder.config.hidden_dropout_prob,
            layer_norm_eps=encoder.config.layer_norm_eps,
            num_encoder_layers=2,
            num_decoder_layers=0
        ).encoder

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Hypothetical Example
        # Batch of 4 documents: (batch_size, n_segments, max_segment_length) --> (4, 64, 128)
        # BERT-BASE encoder: 768 hidden units

        # Squash samples and segments into a single axis (batch_size * n_segments, max_segment_length) --> (256, 128)
        input_ids_reshape = input_ids.contiguous().view(-1, input_ids.size(-1))
        attention_mask_reshape = attention_mask.contiguous().view(-1, attention_mask.size(-1))
        if token_type_ids is not None:
            token_type_ids_reshape = token_type_ids.contiguous().view(-1, token_type_ids.size(-1))
        else:
            token_type_ids_reshape = None

        # Encode segments with BERT --> (256, 128, 768)
        encoder_outputs = self.encoder(
            input_ids=input_ids_reshape,
            attention_mask=attention_mask_reshape,
            token_type_ids=token_type_ids_reshape
        )[0]

        # Reshape back to (batch_size, n_segments, max_segment_length, output_size) --> (4, 64, 128, 768)
        encoder_outputs = encoder_outputs.contiguous().view(
            input_ids.size(0),
            self.max_segments,
            self.max_segment_length,
            self.hidden_size
        )

        # Gather CLS outputs per segment --> (4, 64, 768)
        encoder_outputs = encoder_outputs[:, :, 0]

        # Infer real segments, i.e., mask paddings
        seg_mask = (torch.sum(input_ids, 2) != 0).to(input_ids.dtype)
        # Infer and collect segment positional embeddings
        seg_positions = torch.arange(1, self.max_segments + 1).to(input_ids.device) * seg_mask
        # Add segment positional embeddings to segment inputs
        encoder_outputs += self.seg_pos_embeddings(seg_positions)

        # Encode segments with segment-wise transformer
        seg_encoder_outputs = self.seg_encoder(encoder_outputs)

        # Collect document representation
        outputs, _ = torch.max(seg_encoder_outputs, 1)

        return SimpleOutput(last_hidden_state=outputs, hidden_states=outputs)



In [5]:
# Define the classifier model
class MultiLabelHierarchicalBERTClassifier(nn.Module):
    def __init__(self, model_name: str, num_labels: int, max_segments: int = 64, max_segment_length: int = 128):
        super(MultiLabelHierarchicalBERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.hierarchical_bert = HierarchicalBert(
            encoder=self.bert,
            max_segments=max_segments,
            max_segment_length=max_segment_length
        )

        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

        self.num_labels = num_labels
        self.max_segments = max_segments
        self.max_segment_length = max_segment_length

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.hierarchical_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )


        outputs = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(outputs)

        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits, labels.float())

        return {
            'loss': loss,
            'logits': logits
        }



In [6]:
# Load the datase
dataset = load_dataset('multi_eurlex', 'en', trust_remote_code=True)
train_dataset = dataset['train'].select(range(100))
test_dataset = dataset['test'].select(range(10))
validation_dataset = dataset['validation'].select(range(10))

README.md:   0%|          | 0.00/47.8k [00:00<?, ?B/s]

multi_eurlex.py:   0%|          | 0.00/138k [00:00<?, ?B/s]

multi_eurlex.tar.gz:   0%|          | 0.00/2.77G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/55000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [7]:
# Training configuration
model_name = "bert-base-uncased"
max_segments = 64
max_segment_length = 128
batch_size = 2
num_epochs = 3
num_labels = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained(model_name)
model = MultiLabelHierarchicalBERTClassifier(
    model_name=model_name,
    num_labels=num_labels,
    max_segments=max_segments,
    max_segment_length=max_segment_length
).to(device)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [8]:
# Create dataset and dataloader

dataset_train =  MultiEurLexSegmentedDataset(data=train_dataset, tokenizer=tokenizer, num_labels= num_labels, max_segments=64, max_segment_length=128)
dataset_test =  MultiEurLexSegmentedDataset(data=test_dataset, tokenizer=tokenizer, num_labels= num_labels, max_segments=64, max_segment_length=128 )
dataset_validation =  MultiEurLexSegmentedDataset(data=validation_dataset, tokenizer=tokenizer, num_labels= num_labels, max_segments=64, max_segment_length=128 )

train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size)
validation_dataloader = DataLoader(dataset_test, batch_size=batch_size)

In [9]:

# Initialize optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Training loop
print("Starting training...")
best_val_f1 = 0.0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Training
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc="Training")

    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs['loss']
        total_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        progress_bar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(train_dataloader)
    print(f"\nAverage training loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for batch in tqdm(validation_dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs['logits']
            preds = (torch.sigmoid(logits) > 0.5).cpu().numpy()

            val_preds.extend(preds)
            val_labels.extend(labels.numpy())

    val_preds = np.array(val_preds)
    val_labels = np.array(val_labels)

    f1 = f1_score(val_labels, val_preds, average='micro')
    precision = precision_score(val_labels, val_preds, average='micro')
    recall = recall_score(val_labels, val_preds, average='micro')

    print(f"\nValidation Metrics:")
    print(f"F1: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")

    if f1 > best_val_f1:
        best_val_f1 = f1
        torch.save(model.state_dict(), 'best_model.pt')

print("\nTraining completed!")





Starting training...

Epoch 1/3


Training: 100%|██████████| 50/50 [02:04<00:00,  2.50s/it, loss=0.229]



Average training loss: 0.3281


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.17s/it]



Validation Metrics:
F1: 0.2182
Precision: 0.6667
Recall: 0.1304

Epoch 2/3


Training: 100%|██████████| 50/50 [02:07<00:00,  2.54s/it, loss=0.174]



Average training loss: 0.1898


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.17s/it]



Validation Metrics:
F1: 0.2759
Precision: 0.6667
Recall: 0.1739

Epoch 3/3


Training: 100%|██████████| 50/50 [02:07<00:00,  2.54s/it, loss=0.196]



Average training loss: 0.1765


Evaluating: 100%|██████████| 5/5 [00:05<00:00,  1.16s/it]



Validation Metrics:
F1: 0.2857
Precision: 0.8000
Recall: 0.1739

Training completed!


In [10]:
model.eval()
with torch.no_grad():
    for batch in test_dataloader:
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      labels = batch['labels'].to(device)  # Shape: [batch_size, num_labels]

      outputs = model(input_ids=input_ids, attention_mask=attention_mask)
      predictions = torch.sigmoid(outputs['logits']) > 0.5

      # Convert to numpy for easier handling
      predictions_np = predictions.cpu().numpy()
      labels_np = labels.cpu().numpy()

      # For each sample in the batch, get the predicted and true label indices
      for sample_idx in range(len(predictions_np)):
          print(f"\nSample {sample_idx + 1}:")

          # Get indices where predictions are 1
          predicted_labels = np.where(predictions_np[sample_idx] == 1)[0]
          true_labels = np.where(labels_np[sample_idx] == 1)[0]

          print(f"Predicted label indices: {predicted_labels}")
          print(f"True label indices: {true_labels}")

          # Calculate metrics for this sample
          correct_predictions = set(predicted_labels) & set(true_labels)
          print(f"Correctly predicted labels: {list(correct_predictions)}")

          # Calculate precision, recall for this sample
          precision = len(correct_predictions) / len(predicted_labels) if len(predicted_labels) > 0 else 0
          recall = len(correct_predictions) / len(true_labels) if len(true_labels) > 0 else 0
          f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

          print(f"Precision: {precision:.2f}")
          print(f"Recall: {recall:.2f}")
          print(f"F1 Score: {f1:.2f}")





Sample 1:
Predicted label indices: [3]
True label indices: [ 5  6 15 18]
Correctly predicted labels: []
Precision: 0.00
Recall: 0.00
F1 Score: 0.00

Sample 2:
Predicted label indices: [3]
True label indices: [ 3 17 18]
Correctly predicted labels: [3]
Precision: 1.00
Recall: 0.33
F1 Score: 0.50

Sample 1:
Predicted label indices: []
True label indices: [ 0  1  6 17 18 20]
Correctly predicted labels: []
Precision: 0.00
Recall: 0.00
F1 Score: 0.00

Sample 2:
Predicted label indices: [3]
True label indices: [ 1  2  5  9 15 18 19]
Correctly predicted labels: []
Precision: 0.00
Recall: 0.00
F1 Score: 0.00

Sample 1:
Predicted label indices: []
True label indices: [ 5  6 15 18]
Correctly predicted labels: []
Precision: 0.00
Recall: 0.00
F1 Score: 0.00

Sample 2:
Predicted label indices: [ 3 17]
True label indices: [ 2  3  6 17]
Correctly predicted labels: [17, 3]
Precision: 1.00
Recall: 0.50
F1 Score: 0.67

Sample 1:
Predicted label indices: [ 3 17]
True label indices: [ 2  3  6 17]
Correctl