In [1]:
%load_ext autoreload
%autoreload 2

In [13]:
import os
import json

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
from tqdm.auto import tqdm
from py_vncorenlp import VnCoreNLP

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, DataCollatorWithPadding, get_scheduler

In [3]:
rdrsegmenter = VnCoreNLP(annotators=['wseg'], save_dir='/home/jovyan/caches/vncorenlp')

2025-01-30 13:41:34 INFO  WordSegmenter:24 - Loading Word Segmentation model


In [4]:
MODEL_PATH = '/home/jovyan/data/models/phobert-base-v2'
MAX_SEQ_LENGTH = 256
N_CLASSES = 11

BATCH_SIZE = 16
N_EPOCHS = 5

with open('/home/jovyan/data/datasets/PLVB/topics-classification/20250123-label2id.json') as f:
    label2id = json.load(f)
    
with open('/home/jovyan/data/datasets/PLVB/topics-classification/20250123-id2label.json') as f:
    id2label = json.load(f)

In [5]:
data_files = {
    'train': '/home/jovyan/data/datasets/PLVB/topics-classification/20250123-train.parquet',
    'val': '/home/jovyan/data/datasets/PLVB/topics-classification/20250123-val.parquet'
}

raw_dataset = load_dataset('parquet', data_files=data_files)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)

In [36]:
# def tokenize_fn(examples):
#     # lst to store results
#     batch_input_ids, batch_attention_mask, batch_labels = [], [], []

#     # hyperparameters' configurations
#     chunk_size = MAX_SEQ_LENGTH - 2
#     overlap = 50
#     stride = chunk_size - overlap

#     for sentence, label in zip(examples['sentence'], examples['label']):
#         chunks = []
#         # tokenize the whole document to obtain tokens
#         tokens = tokenizer.tokenize(sentence)

#         # split chunks by token and convert back to string
#         for i in range(0, len(tokens), stride):
#             chunk = tokens[i:i+chunk_size]
#             chunks.append(tokenizer.convert_tokens_to_string(chunk))

#         # tokenizing all chunks
#         encodings = tokenizer(
#             chunks,
#             padding='max_length',
#             max_length=MAX_SEQ_LENGTH,
#             truncation=True,
#             return_tensors='pt'
#         )

#         batch_input_ids.append(encodings['input_ids'])
#         batch_attention_mask.append(encodings['attention_mask'])

#         # convert raw label to one-hot encoding then store
#         label = F.one_hot(torch.tensor(label), num_classes=N_CLASSES)
#         batch_labels.append(label)
    
#     # padding for varying number of chunks
#     max_num_chunks = max(x.size(0) for x in batch_input_ids)
#     padded_input_ids, padded_attention_mask = [], []
#     for input_ids, attention_mask in zip(batch_input_ids, batch_attention_mask):
#         pad_len = max_num_chunks - input_ids.size(0)
#         if pad_len > 0:
#             input_ids = torch.cat([input_ids, torch.zeros([pad_len, MAX_SEQ_LENGTH], dtype=torch.long)], dim=0)
#             attention_mask = torch.cat([attention_mask, torch.zeros([pad_len, MAX_SEQ_LENGTH], dtype=torch.long)], dim=0)
#         padded_input_ids.append(input_ids)
#         padded_attention_mask.append(attention_mask)

#     # stack into final tensor
#     batch_input_ids = torch.stack(padded_input_ids)
#     batch_attention_mask = torch.stack(padded_attention_mask)

#     return {
#         'input_ids': batch_input_ids,
#         'attention_mask': batch_attention_mask,
#         'labels': torch.stack(batch_labels)
#     }

# encoded_dataset = raw_dataset.map(tokenize_fn, batched=True, remove_columns=['_id', 'sentence', 'label'], batch_size=100)
# encoded_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

def tokenize_fn(example):
    """Process a single example without padding"""
    chunk_size = MAX_SEQ_LENGTH - 2
    overlap = 50
    stride = chunk_size - overlap
    
    # Tokenize the sentence
    tokens = tokenizer.tokenize(example['sentence'])
    
    # Create chunks
    chunks = []
    for i in range(0, len(tokens), stride):
        chunk = tokens[i:i+chunk_size]
        if chunk:  # Only add non-empty chunks
            chunks.append(tokenizer.convert_tokens_to_string(chunk))
    
    # Tokenize all chunks
    if chunks:
        encodings = tokenizer(
            chunks,
            padding='max_length',
            max_length=MAX_SEQ_LENGTH,
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encodings['input_ids']
        attention_mask = encodings['attention_mask']
    else:
        # Handle empty documents with a single zero chunk
        input_ids = torch.zeros((1, MAX_SEQ_LENGTH), dtype=torch.long)
        attention_mask = torch.zeros((1, MAX_SEQ_LENGTH), dtype=torch.long)
    
    # Create label tensor
    label = F.one_hot(torch.tensor(example['label']), num_classes=11)
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': label
    }

# Apply the tokenization function to the dataset
encoded_dataset = raw_dataset.map(
    tokenize_fn,
    batched=False,
    remove_columns=['_id', 'sentence', 'label'],
)

# Set the format to PyTorch tensors
encoded_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/17068 [00:00<?, ? examples/s]

Map:   0%|          | 0/1421 [00:00<?, ? examples/s]

In [37]:
def collate_fn(batch):
    """
    Custom collate function to pad sequences in a batch to the same length
    """
    # Find max number of chunks in this batch
    max_chunks = max(x['input_ids'].size(0) for x in batch)
    
    # Lists to store padded tensors
    batch_input_ids = []
    batch_attention_masks = []
    batch_labels = []
    
    # Pad each example to max_chunks
    for example in batch:
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']
        current_chunks = input_ids.size(0)
        
        # Calculate padding needed
        pad_chunks = max_chunks - current_chunks
        
        if pad_chunks > 0:
            # Create padding tensors
            input_pad = torch.zeros(pad_chunks, MAX_SEQ_LENGTH, dtype=torch.long)
            mask_pad = torch.zeros(pad_chunks, MAX_SEQ_LENGTH, dtype=torch.long)
            
            # Add padding
            input_ids = torch.cat([input_ids, input_pad], dim=0)
            attention_mask = torch.cat([attention_mask, mask_pad], dim=0)
        
        batch_input_ids.append(input_ids)
        batch_attention_masks.append(attention_mask)
        batch_labels.append(example['labels'])
    
    # Stack all tensors
    return {
        'input_ids': torch.stack(batch_input_ids),
        'attention_mask': torch.stack(batch_attention_masks),
        'labels': torch.stack(batch_labels)
    }

train_loader = DataLoader(encoded_dataset['train'], batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(encoded_dataset['val'], batch_size=16, collate_fn=collate_fn)

In [9]:
class Classifier(nn.Module):
    def __init__(self, num_labels):
        super().__init__()

        self.encoder = AutoModel.from_pretrained(MODEL_PATH)
        self.aggregator = nn.MultiheadAttention(embed_dim=768, num_heads=8, dropout=0.1, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(768, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        batch_size, num_chunks, seq_length = input_ids.size()
        
        # Reshape for processing all chunks
        flat_input_ids = input_ids.view(-1, seq_length)
        flat_attention_mask = attention_mask.view(-1, seq_length)
        
        # Get PhoBERT encodings for all chunks
        outputs = self.encoder(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask
        )
        
        # Get CLS token representations for each chunk
        chunk_encodings = outputs.last_hidden_state[:, 0]  # Use CLS token
        chunk_encodings = chunk_encodings.view(batch_size, num_chunks, -1)
        
        # Create chunk-level attention mask (1 if chunk exists, 0 if padding)
        chunk_mask = (attention_mask.sum(dim=-1) > 0)
        
        # Convert boolean mask to float mask where False = -inf, True = 0.0
        attn_mask = torch.zeros_like(chunk_mask, dtype=torch.float)
        attn_mask.masked_fill_(~chunk_mask, float('-inf'))
        
        # Aggregate chunk encodings using multi-head attention
        # Using self-attention: query, key, and value are all chunk_encodings
        doc_encoding, _ = self.chunk_attention(
            query=chunk_encodings,
            key=chunk_encodings,
            value=chunk_encodings,
            key_padding_mask=~chunk_mask  # PyTorch expects mask to be False for valid positions
        )
        
        # Pool the attention outputs to get document representation
        # Use the mask to get mean of non-padded chunks
        mask_expanded = chunk_mask.unsqueeze(-1).float()
        doc_encoding = (doc_encoding * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1e-9)
        
        # Classification
        logits = self.classifier(doc_encoding)
        
        return logits

In [11]:
model = Classifier(11)
optimizer = AdamW(model.parameters(), lr=3e-5)
num_training_steps = N_EPOCHS * len(train_loader)
lr_scheduler = get_scheduler(
    name='cosine_with_restarts', optimizer=optimizer, num_warmup_steps=50, num_training_steps=num_training_steps
)

Some weights of RobertaModel were not initialized from the model checkpoint at /home/jovyan/data/models/phobert-base-v2 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [39]:
pbar = tqdm(train_loader, total=num_training_steps)

for index, batch in enumerate(pbar):
    if index == 0:
        print(batch)

  0%|          | 0/5335 [00:00<?, ?it/s]

{'input_ids': tensor([[[    0,    20,  8859,  ...,  2325,  6692,     2],
         [    0,   292,     4,  ...,  8917,  2665,     2],
         [    0,    12,  1391,  ...,     1,     1,     1],
         ...,
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0]],

        [[    0,   404,  7068,  ...,    11,   369,     2],
         [    0,    11,  2488,  ..., 12123,  1986,     2],
         [    0,   556,  4253,  ...,    39,  7068,     2],
         ...,
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0]],

        [[    0, 42680,   605,  ...,     1,     1,     1],
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0],
         ...,
         [    0,     0,     0,  ...,     0,     0,     