### Overview of the datasets and models
* Datasets:
  * **soict-dataset-2024**: The original dataset.
  * **soict-dataset-2024-segmented**: The original dataset, word-segmented with underthesea.
  * **cross-encoder-dataset**: Built by taking, for each query, the ground truth plus 3 hard negatives returned by the bi-encoder plus 1 random negative. The train split has more than 500,000 samples and the validation split has more than 100,000.
  * **cross-encoder-dataset-segmented**: Same as cross-encoder-dataset, but built with the new bi-encoder and with word segmentation applied.
* Models:
  * **bi_encoder**: A BiEncoder model whose base model is https://huggingface.co/bkai-foundation-models/vietnamese-bi-encoder, trained on **soict-dataset-2024-segmented** for 2 epochs with hard-negative mining.
  * **bi_encoder_embedding**: A FAISS vector database holding the embedding vectors of every document in corpus_segmented.csv, produced by the document_encoder of bi_encoder.
  * **cross_encoder_new**: A cross-encoder built with the sentence_transformer library and trained on **soict-dataset-2024-segmented**; it reranks the results returned by the bi-encoder.
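
To make the division of labor between these artifacts concrete, here is a minimal, dependency-free sketch of the retrieve-then-rerank flow. It is not the notebook's pipeline: the toy vectors and the names `retrieve`, `rerank`, and `overlap_score` are hypothetical stand-ins (a brute-force inner product in place of the FAISS index, and a word-overlap score in place of the cross-encoder).

```python
from typing import Callable, List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def retrieve(query_vec: Sequence[float],
             doc_vecs: List[Sequence[float]],
             top_k: int) -> List[int]:
    """Stage 1: score every document cheaply and keep the top_k ids."""
    scored = [(doc_id, dot(query_vec, vec)) for doc_id, vec in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [doc_id for doc_id, _ in scored[:top_k]]


def rerank(query: str,
           candidate_ids: List[int],
           docs: List[str],
           score_fn: Callable[[str, str], float]) -> List[int]:
    """Stage 2: re-score only the candidates with a more expensive scorer."""
    scored = [(doc_id, score_fn(query, docs[doc_id])) for doc_id in candidate_ids]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [doc_id for doc_id, _ in scored]


def overlap_score(query: str, doc: str) -> float:
    """Hypothetical stand-in for a cross-encoder: count shared words."""
    return float(len(set(query.split()) & set(doc.split())))


# Toy corpus with made-up 2-d embeddings (assumption, for illustration only)
docs = ["tax law", "traffic law", "labor law"]
doc_vecs = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
query = "traffic fines"
query_vec = [0.9, 0.3]

candidates = retrieve(query_vec, doc_vecs, top_k=2)   # cheap first pass
final = rerank(query, candidates, docs, overlap_score)  # precise second pass
print(candidates, final)
```

The first stage only has to get the right document into the candidate list; the second stage is responsible for putting it at the top.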


### What has been done in this notebook:
1. Trained a bi-encoder based on XLM-RoBERTa for 1 epoch (**bi_encoder_xlmRoBERTA**). This model was chosen as the baseline because it was pretrained on many languages, including Vietnamese.
2. Used the newly trained bi-encoder to run retrieval over the whole dataset. For each sample, we keep the ground truths, pick 3 hard negatives from the returned results (that is, for each query, the 3 highest-ranked wrong results from bi-encoder retrieval), and add 1 random negative. Together these form a new dataset, **cross-encoder-dataset**, used to train the cross-encoder. A new dataset is needed because the cross-encoder reranks the candidates returned by the bi-encoder; those candidates are the bi-encoder's top-k, so they are all very similar to the query, and the cross-encoder therefore has to be trained on hard negatives returned by the bi-encoder.
3. Used **bi_encoder_xlmRoBERTA** to extract features for the whole corpus and stored them in a FAISS vector database for storage and fast similarity search. The embedding vectors are kept in **bi_encoder_corpus_embedding**.
4. Defined a basic cross-encoder with the sentence_transformer library, using XLM-RoBERTa as the base model, and trained it on **cross-encoder-dataset**. The results were not changing much during training, so training was stopped after about 200,000 samples. The resulting model, **cross_encoder_ckp**, is used for reranking.
5. Defined a pipeline that loads the trained models for inference and evaluates how well they perform.
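
The sampling rule in step 2 can be sketched as follows. This is an illustration under assumptions, not the notebook's code: `build_cross_encoder_rows` and its parameters are hypothetical names, document ids are plain integers, and the random negative is assumed to be drawn from documents that are neither ground truth nor already chosen as hard negatives.

```python
import random
from typing import Dict, List, Sequence, Set, Union

Row = Dict[str, Union[str, int]]


def build_cross_encoder_rows(
    query: str,
    ranked_doc_ids: Sequence[int],   # bi-encoder results, best first
    ground_truth: Set[int],          # ids of the relevant documents
    corpus_ids: Sequence[int],       # every document id in the corpus
    rng: random.Random,
    num_hard: int = 3,
    num_random: int = 1,
) -> List[Row]:
    """Return (query, doc_id, label) rows for one query."""
    # Positives: every ground-truth document gets label 1
    rows: List[Row] = [
        {"query": query, "doc_id": d, "label": 1} for d in sorted(ground_truth)
    ]

    # Hard negatives: the highest-ranked retrieved documents that are wrong
    hard = [d for d in ranked_doc_ids if d not in ground_truth][:num_hard]
    rows.extend({"query": query, "doc_id": d, "label": 0} for d in hard)

    # Random negatives: sampled from everything not used so far
    excluded = set(ground_truth) | set(hard)
    pool = [d for d in corpus_ids if d not in excluded]
    for d in rng.sample(pool, min(num_random, len(pool))):
        rows.append({"query": query, "doc_id": d, "label": 0})
    return rows


rows = build_cross_encoder_rows(
    query="example query",
    ranked_doc_ids=[7, 3, 9, 4, 1],
    ground_truth={3},
    corpus_ids=list(range(10)),
    rng=random.Random(0),
)
for row in rows:
    print(row)
```

With one ground-truth document this yields five rows per query: one positive, three hard negatives, and one random negative.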

### Some ideas for improving the models
1. Improving the bi-encoder:
   * Retrain the bi-encoder with a base model trained specifically for Vietnamese, such as PhoBERT https://huggingface.co/vinai/phobert-base, and use a library such as underthesea to word-segment the dataset before training.
   * Instead of hand-coding the bi-encoder, use one that is already pretrained for Vietnamese, such as https://huggingface.co/bkai-foundation-models/vietnamese-bi-encoder. This requires writing a new function to load and train the model, plus word segmentation if needed.
   * Train for more epochs and mine hard negatives. The current model uses XLM-RoBERTa and was trained for only 1 epoch without hard-negative mining because it ran out of RAM.
2. Improving the cross-encoder:
   * Rebuild the cross-encoder training dataset from the improved bi-encoder.
   * The current **cross-encoder-dataset** has only 2 classes: 1 for ground truth and 0 for both hard and random negatives. We could add a third class, with label 2, for random negatives to reduce the class imbalance.
   * Look for better ways to build the cross-encoder training dataset.
   * Try PhoBERT as the base model instead of XLM-RoBERTa.

### Changes that have been made
1. Used underthesea to word-segment train.csv and corpus.csv, producing **soict-dataset-2024-segmented**.
2. Used https://huggingface.co/bkai-foundation-models/vietnamese-bi-encoder as the base model and trained a new bi-encoder on **soict-dataset-2024-segmented** for 2 epochs with hard-negative mining and layer freezing (see the train method of BiEncoderTrainer).
3. Used the new bi-encoder to build a new dataset, **cross-encoder-dataset-segmented**, for training a new cross-encoder.
4. Trained a new cross-encoder with the same base model as the bi-encoder.

## STAGE 1 - Bi-encoder (First Retrieval Stage)

In [1]:
!pip install --quiet faiss-gpu sentence-transformers underthesea

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.9/20.9 MB[0m [31m68.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m657.8/657.8 kB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
from torch import nn
from transformers import XLMRobertaModel, XLMRobertaTokenizer, AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Tuple, Union
import faiss
import numpy as np
from sklearn.model_selection import train_test_split

class BiEncoderConfig:
    def __init__(
        self,
        max_length: int = 256,
        batch_size: int = 16,
        learning_rate: float = 1e-5,
        num_epochs: int = 2,
        temperature: float = 0.05,
        embedding_dim: int = 768,

    ):
        self.max_length = max_length
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.temperature = temperature
        self.embedding_dim = embedding_dim

class LegalDataset(Dataset):
    def __init__(self, questions: List[str], contexts: List[str], tokenizer, max_length: int):
        self.questions = questions
        self.contexts = contexts
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.questions)
    
    def __getitem__(self, idx):
        question = self.questions[idx]
        context = self.contexts[idx]
        
        return {
            'question': question,
            'context': context
        }
class BiEncoder(nn.Module):
    def __init__(self, config: BiEncoderConfig):
        super().__init__()
        self.config = config
        
        # Load the pre-trained Vietnamese bi-encoder for both encoders
        self.question_encoder = AutoModel.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")
        self.document_encoder = AutoModel.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")

        # Use the model's tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")
        self.max_length = config.max_length

    def get_device(self):
        # Helper method to get the current device
        if isinstance(self.question_encoder, nn.DataParallel):
            return self.question_encoder.module.device
        return next(self.question_encoder.parameters()).device

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(
            -1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def encode_question(self, questions: List[str], batch_size: int = 32) -> torch.Tensor:
        all_embeddings = []
        device = self.get_device()

        for i in range(0, len(questions), batch_size):
            batch_texts = questions[i:i + batch_size]

            encoded_input = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            with torch.no_grad():
                model_output = self.question_encoder(**encoded_input)

            batch_embeddings = self.mean_pooling(
                model_output, encoded_input['attention_mask'])
            all_embeddings.append(batch_embeddings)

        return torch.cat(all_embeddings, dim=0)

    def encode_document(self, documents: List[str], batch_size: int = 32, disable_progress_bar: bool = False) -> torch.Tensor:
        all_embeddings = []
        device = self.get_device()

        # Calculate total number of batches for progress bar
        num_batches = (len(documents) + batch_size - 1) // batch_size

        for i in tqdm(range(0, len(documents), batch_size), total=num_batches, desc="Encoding documents", disable=disable_progress_bar):
            batch_texts = documents[i:i + batch_size]

            encoded_input = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            with torch.no_grad():
                model_output = self.document_encoder(**encoded_input)

            batch_embeddings = self.mean_pooling(
                model_output, encoded_input['attention_mask'])
            all_embeddings.append(batch_embeddings)

        return torch.cat(all_embeddings, dim=0)



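The `mean_pooling` method above averages token embeddings while ignoring padding positions. As a small worked example of that idea, the sketch below does the same computation for a single sequence with plain Python lists instead of tensors; `masked_mean_pool` is a hypothetical helper written for illustration, not part of the notebook.

```python
from typing import List


def masked_mean_pool(token_embeddings: List[List[float]],
                     attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose mask value is 1; skip padding (0)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for j, value in enumerate(vec):
                total[j] += value
    # Same guard as the clamp(min=1e-9) in mean_pooling: avoid dividing by zero
    denom = max(count, 1e-9)
    return [t / denom for t in total]


# Two real tokens followed by one padding token
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = masked_mean_pool(tokens, mask)
print(pooled)
```

The padding vector is large on purpose: it would dominate an unmasked average, but it has no effect on the masked one.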
In [3]:
%%writefile cross_encoder.py
import torch
from torch import nn
from transformers import XLMRobertaModel, XLMRobertaTokenizer, AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Tuple, Union
import faiss
import numpy as np
from sklearn.model_selection import train_test_split

class BiEncoderConfig:
    def __init__(
        self,
        max_length: int = 256,
        batch_size: int = 16,
        learning_rate: float = 1e-5,
        num_epochs: int = 2,
        temperature: float = 0.05,
        embedding_dim: int = 768,

    ):
        self.max_length = max_length
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.temperature = temperature
        self.embedding_dim = embedding_dim

class LegalDataset(Dataset):
    def __init__(self, questions: List[str], contexts: List[str], tokenizer, max_length: int):
        self.questions = questions
        self.contexts = contexts
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.questions)
    
    def __getitem__(self, idx):
        question = self.questions[idx]
        context = self.contexts[idx]
        
        return {
            'question': question,
            'context': context
        }
    
class BiEncoder(nn.Module):
    def __init__(self, config: BiEncoderConfig):
        super().__init__()
        self.config = config
        
        # Load the pre-trained Vietnamese bi-encoder for both encoders
        self.question_encoder = AutoModel.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")
        self.document_encoder = AutoModel.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")

        # Use the model's tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")
        self.max_length = config.max_length

    def get_device(self):
        # Helper method to get the current device
        if isinstance(self.question_encoder, nn.DataParallel):
            return self.question_encoder.module.device
        return next(self.question_encoder.parameters()).device

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(
            -1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def encode_question(self, questions: List[str], batch_size: int = 32) -> torch.Tensor:
        all_embeddings = []
        device = self.get_device()

        for i in range(0, len(questions), batch_size):
            batch_texts = questions[i:i + batch_size]

            encoded_input = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            with torch.no_grad():
                model_output = self.question_encoder(**encoded_input)

            batch_embeddings = self.mean_pooling(
                model_output, encoded_input['attention_mask'])
            all_embeddings.append(batch_embeddings)

        return torch.cat(all_embeddings, dim=0)

    def encode_document(self, documents: List[str], batch_size: int = 32, disable_progress_bar: bool = False) -> torch.Tensor:
        all_embeddings = []
        device = self.get_device()

        # Calculate total number of batches for progress bar
        num_batches = (len(documents) + batch_size - 1) // batch_size

        for i in tqdm(range(0, len(documents), batch_size), total=num_batches, desc="Encoding documents", disable=disable_progress_bar):
            batch_texts = documents[i:i + batch_size]

            encoded_input = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            with torch.no_grad():
                model_output = self.document_encoder(**encoded_input)

            batch_embeddings = self.mean_pooling(
                model_output, encoded_input['attention_mask'])
            all_embeddings.append(batch_embeddings)

        return torch.cat(all_embeddings, dim=0)


class BiEncoderTrainer:

    def get_config(self) -> BiEncoderConfig:
        """Get the trainer's configuration"""
        return self.config
        
    def get_model(self) -> BiEncoder:
        """Get the trainer's model"""
        return self.model
        
    def __init__(self, config: BiEncoderConfig):
        self.config = config
        self.model = BiEncoder(self.config)
        
        # Initialize with smaller learning rate for fine-tuning
        self.config.learning_rate = 1e-5  # Reduced from 2e-5
        
        # Setup multi-GPU
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.model.question_encoder = nn.DataParallel(self.model.question_encoder)
            self.model.document_encoder = nn.DataParallel(self.model.document_encoder)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        
        # Implement gradual unfreezing
        self.unfreeze_layers = 0  # Start with only the pooler and the final encoder layer trainable
        self._freeze_layers()
        
        # metrics tracking
        self.best_mrr = 0.0
        self.best_recall = 0.0
        
        # Add new attributes for step-based unfreezing
        self.total_steps = 0
        self.unfreeze_schedule = None  # Will be set in train()

    def _freeze_layers(self):
        """Freeze/unfreeze layers gradually during training"""
        # First freeze all layers
        for param in self.model.question_encoder.parameters():
            param.requires_grad = False
        for param in self.model.document_encoder.parameters():
            param.requires_grad = False
            
        def unfreeze_model_layers(model, num_layers):
            # Always unfreeze the pooler and final layer
            if isinstance(model, nn.DataParallel):
                model = model.module
            
            # Unfreeze pooler
            for param in model.pooler.parameters():
                param.requires_grad = True
            
            # Always keep the final layer unfrozen
            if num_layers == 0:
                for param in model.encoder.layer[-1].parameters():
                    param.requires_grad = True
                return
                
            # Unfreeze specified number of layers from the top
            for layer in list(model.encoder.layer)[-num_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
                
        # Apply unfreezing to both encoders
        unfreeze_model_layers(self.model.question_encoder, self.unfreeze_layers)
        unfreeze_model_layers(self.model.document_encoder, self.unfreeze_layers)

    def set_scores(self, scores: Tuple):
        self.best_mrr = scores[0]
        self.best_recall = scores[1]
        
    def prepare_batch(self, batch: Dict[str, List[str]]) -> Dict[str, torch.Tensor]:
        # tokenize questions
        questions_tokenized = self.model.tokenizer(
            batch['question'],
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors='pt'
        ).to(self.device)
        
        #Tokenize contexts
        contexts_tokenized = self.model.tokenizer(
            batch['context'],
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors='pt'
        ).to(self.device)
        
        return {
            'question_data': questions_tokenized,
            'context_data': contexts_tokenized
        }
    
    def compute_loss(self, q_embeddings: torch.Tensor, d_embeddings: torch.Tensor,
                    hard_negative_embeddings: torch.Tensor = None) -> torch.Tensor:
        """Compute loss with both in-batch and hard negatives"""
        #regular in-batch negative loss
        similarity = torch.matmul(q_embeddings, d_embeddings.t())
        
        if hard_negative_embeddings is not None:
            # hard negative similarities
            hard_similarity = torch.matmul(q_embeddings, hard_negative_embeddings.t())
            similarity = torch.cat([similarity, hard_similarity], dim=1)
        
        # scale by temperature
        similarity = similarity / self.config.temperature
        
        # Create labels for diagonal (positive pairs)
        labels = torch.arange(q_embeddings.size(0)).to(self.device)
        
        # Compute loss
        loss = nn.CrossEntropyLoss()(similarity, labels)
        
        return loss
    
    def evaluate(self, val_dataset: LegalDataset) -> Dict[str, float]:
        """
        Evaluate the model on validation dataset
        Returns MRR@k and Recall@k metrics
        """
        self.model.eval()
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False
        )
        
        all_q_embeddings = []
        all_d_embeddings = []
        
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Evaluating"):
                processed_batch = self.prepare_batch(batch)
                
                # Get embeddings
                q_embeddings = self.model.mean_pooling(
                    self.model.question_encoder(**processed_batch['question_data']),
                    processed_batch['question_data']['attention_mask']
                )
                d_embeddings = self.model.mean_pooling(
                    self.model.document_encoder(**processed_batch['context_data']),
                    processed_batch['context_data']['attention_mask']
                )
                
                # Normalize
                q_embeddings = nn.functional.normalize(q_embeddings, p=2, dim=1)
                d_embeddings = nn.functional.normalize(d_embeddings, p=2, dim=1)
                
                all_q_embeddings.append(q_embeddings)
                all_d_embeddings.append(d_embeddings)
        
        # Concatenate all embeddings
        all_q_embeddings = torch.cat(all_q_embeddings, dim=0)
        all_d_embeddings = torch.cat(all_d_embeddings, dim=0)
        
        # Compute similarity matrix
        similarity = torch.matmul(all_q_embeddings, all_d_embeddings.t())
        
        # Calculate metrics
        k_values = [1, 5, 10, 50,100, 200, 500, 1000]
        metrics = {}
        
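        num_queries, num_docs = similarity.shape
        for k in k_values:
            # Get top-k indices (k cannot exceed the number of documents)
            _, indices = similarity.topk(min(k, num_docs), dim=1)
            
            # Calculate Recall@k: fraction of queries whose paired document is in the top k
            correct = torch.arange(num_queries).unsqueeze(1).expand_as(indices).to(self.device)
            hits = indices == correct
            recall_at_k = hits.float().sum(dim=1).mean().item()
            metrics[f'recall@{k}'] = recall_at_k
            
            # Calculate MRR@k: queries with no hit in the top k contribute 0,
            # so divide by the number of queries rather than the number of hits
            rank = hits.nonzero()[:, 1] + 1
            mrr = (1.0 / rank.float()).sum().item() / num_queries
            metrics[f'mrr@{k}'] = mrr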
            
        return metrics

    def train(self, train_dataset: LegalDataset, val_dataset: LegalDataset = None):
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )
        
        # Calculate total steps and set unfreeze schedule
        total_steps = len(train_dataloader) * self.config.num_epochs
        steps_per_epoch = len(train_dataloader)
        
        # Adjust unfreeze schedule to be more frequent within the epoch
        self.unfreeze_schedule = {
            steps_per_epoch // 4: 2,     # Unfreeze top 2 layers after 25% steps
            steps_per_epoch // 2: 4,     # Unfreeze top 4 layers after 50% steps
            3 * steps_per_epoch // 4: 6,  # Unfreeze top 6 layers after 75% steps
            9 * steps_per_epoch // 10: 8  # Unfreeze top 8 layers after 90% steps
        }
        
        # Use different optimizers for frozen/unfrozen parameters
        def get_optimizer():
            params = []
            for model in [self.model.question_encoder, self.model.document_encoder]:
                params.extend([p for p in model.parameters() if p.requires_grad])
            return torch.optim.AdamW(
                params,
                lr=self.config.learning_rate,
                weight_decay=0.01
            )
        
        optimizer = get_optimizer()
        
        # Add learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=total_steps
        )
        
        # Mine hard negatives every half epoch. Step 0 is skipped in the loop below,
        # so this fires once or twice per epoch; max(1, ...) avoids a modulo by zero
        mine_every_n_steps = max(1, steps_per_epoch // 2)
        
        for epoch in range(self.config.num_epochs):
            self.model.train()
            total_loss = 0
            progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{self.config.num_epochs}')
            
            hard_negatives = None
            current_epoch_step = 0
            
            for batch_idx, batch in enumerate(progress_bar):
                current_epoch_step = batch_idx
                
                # Check if it's time to mine hard negatives
                if current_epoch_step % mine_every_n_steps == 0 and current_epoch_step != 0:
                    print(f"\nMining hard negatives at step {current_epoch_step}...")
                    questions = train_dataset.questions
                    contexts = train_dataset.contexts
                    hard_negatives = self.mine_hard_negatives(questions, contexts)
                    # mine_hard_negatives() puts the model in eval mode; switch back
                    self.model.train()
                
                # Check unfreeze schedule based on current epoch step
                if current_epoch_step in self.unfreeze_schedule:
                    print(f"\nUnfreezing layers at step {current_epoch_step}...")
                    self.unfreeze_layers = self.unfreeze_schedule[current_epoch_step]
                    self._freeze_layers()
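                    optimizer = get_optimizer()  # Reinitialize optimizer with new trainable params
                    # Rebind the LR scheduler to the new optimizer over the remaining steps
                    remaining_steps = max(1, total_steps - (epoch * steps_per_epoch + current_epoch_step))
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=remaining_steps)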
                
                optimizer.zero_grad()
                
                #Prepare batch data
                processed_batch = self.prepare_batch(batch)
                
                #Get embeddings from separate encoders
                q_embeddings = self.model.mean_pooling(
                    self.model.question_encoder(**processed_batch['question_data']),
                    processed_batch['question_data']['attention_mask']
                )
                d_embeddings = self.model.mean_pooling(
                    self.model.document_encoder(**processed_batch['context_data']),
                    processed_batch['context_data']['attention_mask']
                )
                
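                #Process hard negatives if available
                hard_negative_embeddings = None
                if hard_negatives:
                    # len(batch) would count the dict keys (2), so use the number of questions.
                    # The loader shuffles, so this slice is not matched to the questions in
                    # this batch; the mined documents simply act as extra negatives.
                    batch_len = len(batch['question'])
                    start = (batch_idx * self.config.batch_size) % len(hard_negatives)
                    hard_negative_batch = self.model.tokenizer(
                        hard_negatives[start:start + batch_len],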
                        padding=True,
                        truncation=True,
                        max_length=self.config.max_length,
                        return_tensors='pt'
                    ).to(self.device)
                    
                    hard_negative_embeddings = self.model.mean_pooling(
                        self.model.document_encoder(**hard_negative_batch),
                        hard_negative_batch['attention_mask']
                    )
                    hard_negative_embeddings = nn.functional.normalize(hard_negative_embeddings, p=2, dim=1)
                
                #Normalize embeddings
                q_embeddings = nn.functional.normalize(q_embeddings, p=2, dim=1)
                d_embeddings = nn.functional.normalize(d_embeddings, p=2, dim=1)
                
                #Compute loss with hard negatives
                loss = self.compute_loss(q_embeddings, d_embeddings, hard_negative_embeddings)
                
                #Backward pass
                loss.backward()
                optimizer.step()
                scheduler.step()  # Advance the cosine schedule once per optimizer step
                
                total_loss += loss.item()
                progress_bar.set_postfix({'loss': total_loss / (progress_bar.n + 1)})
            
            avg_loss = total_loss / len(train_dataloader)
            print(f'Epoch {epoch + 1}/{self.config.num_epochs}, Average Loss: {avg_loss:.4f}')
            
            # Validation step
            if val_dataset is not None:
                metrics = self.evaluate(val_dataset)
                print("Validation metrics:")
                for metric_name, value in metrics.items():
                    print(f"{metric_name}: {value:.4f}")
                
                # Save best model
                if metrics['mrr@10'] > self.best_mrr:
                    self.best_mrr = metrics['mrr@10']
                    self.save_model('best_model.pt')
    
    def save_model(self, path: str):
        # Save both encoders and config
        torch.save({
            'question_encoder': self.model.question_encoder.state_dict(),
            'document_encoder': self.model.document_encoder.state_dict(),
            'config': self.config
        }, path)
    
    def load_model(self, path: str):
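        # The checkpoint also pickles the BiEncoderConfig object, so it cannot be read
        # with a weights-only load; only load checkpoints from a trusted source
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)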
        question_state_dict = checkpoint['question_encoder']
        document_state_dict = checkpoint['document_encoder']
        
        # Remove 'module.' prefix if it exists and model is not using DataParallel
        if not isinstance(self.model.question_encoder, nn.DataParallel):
            print("remove module.")
            question_state_dict = {k.replace('module.', ''): v for k, v in question_state_dict.items()}
            document_state_dict = {k.replace('module.', ''): v for k, v in document_state_dict.items()}
        # Add 'module.' prefix if model is using DataParallel but saved model wasn't
        elif not any(k.startswith('module.') for k in question_state_dict):
            print("add module.")
            question_state_dict = {'module.' + k: v for k, v in question_state_dict.items()}
            document_state_dict = {'module.' + k: v for k, v in document_state_dict.items()}
        
        # Load the state dictionaries
        try:
            self.model.question_encoder.load_state_dict(question_state_dict)
            self.model.document_encoder.load_state_dict(document_state_dict)
        except RuntimeError as e:
            print(f"Error loading state dict: {e}")
            print("Attempting alternative loading method...")
            
            # If the first attempt fails, try the opposite approach
            if isinstance(self.model.question_encoder, nn.DataParallel):
                question_state_dict = {k.replace('module.', ''): v for k, v in question_state_dict.items()}
                document_state_dict = {k.replace('module.', ''): v for k, v in document_state_dict.items()}
            else:
                question_state_dict = {'module.' + k: v for k, v in question_state_dict.items()}
                document_state_dict = {'module.' + k: v for k, v in document_state_dict.items()}
            
            self.model.question_encoder.load_state_dict(question_state_dict)
            self.model.document_encoder.load_state_dict(document_state_dict)

    def mine_hard_negatives(self, questions: List[str], documents: List[str], 
                       batch_size: int = 256) -> List[str]:  # Reduced batch size
        """Mine hard negatives using embeddings similarity"""
        self.model.eval()
        device = self.device
    
        # Significantly reduced chunk sizes to avoid OOM
        chunk_size = 2000  # Reduced from 5000
        similarity_chunk_size = 200  # Reduced from 500
    
        with torch.no_grad():
            # Process questions in smaller chunks
            all_q_embeddings = []
            for i in tqdm(range(0, len(questions), chunk_size), desc="Encoding questions"):
                q_chunk = questions[i:i + chunk_size]
                q_emb = self.model.encode_question(q_chunk, batch_size)
                all_q_embeddings.append(q_emb.cpu())  # Move to CPU after processing
                torch.cuda.empty_cache()
            q_embeddings = torch.cat(all_q_embeddings, dim=0)
        
            # Process documents in smaller chunks
            all_d_embeddings = []
            for i in tqdm(range(0, len(documents), chunk_size), desc="Encoding documents"):
                d_chunk = documents[i:i + chunk_size]
                d_emb = self.model.encode_document(d_chunk, batch_size)
                all_d_embeddings.append(d_emb.cpu())  # Move to CPU after processing
                torch.cuda.empty_cache()
            d_embeddings = torch.cat(all_d_embeddings, dim=0)
            
            # Normalize embeddings (on CPU to save GPU memory)
            q_embeddings = nn.functional.normalize(q_embeddings, p=2, dim=1)
            d_embeddings = nn.functional.normalize(d_embeddings, p=2, dim=1)
            
            # Compute similarity in smaller chunks
            hard_negative_indices = []
            k = 2  # Reduced number of hard negatives per question
            
            for i in tqdm(range(0, len(q_embeddings), similarity_chunk_size), desc="Mining negatives"):
                # Move only the current chunks to GPU
                q_chunk = q_embeddings[i:i + similarity_chunk_size].to(device)
                d_chunk = d_embeddings.to(device)
                
                similarity = torch.matmul(q_chunk, d_chunk.t())
                
                # Get top-k most similar but incorrect documents
                values, indices = similarity.topk(k + 1, dim=1)
                
                # Drop each question's own positive (assumed to be documents[j] for
                # questions[j]) and keep exactly k negatives per question, even when
                # the positive is not among the top-(k + 1) hits
                for offset, row in enumerate(indices.cpu().tolist()):
                    negatives = [idx for idx in row if idx != i + offset][:k]
                    hard_negative_indices.extend(negatives)
                
                # Free memory
                del similarity, values, indices, q_chunk, d_chunk
                torch.cuda.empty_cache()
        
        return [documents[idx] for idx in hard_negative_indices]

Writing cross_encoder.py
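
The mining logic above can be illustrated on toy data. The following is a minimal pure-Python sketch, not the notebook's implementation: it assumes, as `mine_hard_negatives` does, that `documents[j]` is the positive for `questions[j]`. The helper names `dot` and `mine_hard_negative_indices` and the toy vectors are hypothetical.

```python
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def mine_hard_negative_indices(q_embs: List[List[float]],
                               d_embs: List[List[float]],
                               k: int = 2) -> List[List[int]]:
    """For each question, return the k most similar documents that are
    not its positive (assumed to be the document with the same index)."""
    negatives = []
    for q_idx, q in enumerate(q_embs):
        # Rank all documents by similarity to the question, best first
        ranked = sorted(range(len(d_embs)),
                        key=lambda d_idx: dot(q, d_embs[d_idx]),
                        reverse=True)
        # Drop the positive, then keep the k highest-ranked wrong documents
        negatives.append([d for d in ranked if d != q_idx][:k])
    return negatives


# Toy, already-normalised 2-D embeddings (hypothetical values)
q_embs = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
d_embs = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]

hard_negatives = mine_hard_negative_indices(q_embs, d_embs, k=1)
print(hard_negatives)  # [[2], [2], [1]]
```

Each question ends up with exactly `k` negatives, whether or not its positive appears among the top hits.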


In [4]:
%%writefile dual_encoder.py
from sentence_transformers import CrossEncoder
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sentence_transformers.readers import InputExample
from sentence_transformers.cross_encoder.evaluation import CEBinaryAccuracyEvaluator
from sentence_transformers import LoggingHandler
import logging
import pandas as pd
import torch


class LoggingCallback:
    def __init__(self):
        self.current_step = 0

    def __call__(self, score: float, epoch: int, steps: int):
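        # Called by CrossEncoder.fit after each evaluation; `score` is the
        # evaluator's metric (binary accuracy here), not the training loss
        self.current_step += 1
        print(f'Evaluation: {self.current_step}, Epoch: {epoch}, Steps: {steps}, Score: {score:.4f}')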


class CrossEncoderWrapper:
    def __init__(
        self,
        model_name: str = "bkai-foundation-models/vietnamese-bi-encoder",
        max_length: int = 256,
        batch_size: int = 16,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length
        
        # Explicitly set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Initialize model with explicit device
        self.model = CrossEncoder(
            model_name,
            max_length=max_length,
            device=self.device
        )

    def train(self, train_dataset: Dataset, val_dataset: Dataset = None,
              num_epochs: int = 1, warmup_ratio: float = 0.02):
        """Train the cross-encoder model"""

        # Setup logging
        logging.basicConfig(format='%(asctime)s - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S',
                            level=logging.INFO,
                            handlers=[LoggingHandler()])

        # Prepare training examples
        train_samples = [
            InputExample(texts=[sample['question'], sample['document']],
                         label=sample['label'])
            for sample in train_dataset
        ]

        # Create data loader
        train_dataloader = DataLoader(
            train_samples,
            shuffle=True,
            batch_size=self.batch_size
        )

        # Create evaluator for validation
        if val_dataset is not None:
            # Prepare validation data in the required format
            val_samples = [
                InputExample(texts=[sample['question'], sample['document']],
                             label=sample['label'])
                for sample in val_dataset
            ]

            evaluator = CEBinaryAccuracyEvaluator.from_input_examples(
                val_samples, name="evaluation")
        else:
            evaluator = None
        # Create callback
        callback = LoggingCallback()

        # Train the model
        self.model.fit(
            train_dataloader=train_dataloader,
            evaluator=evaluator,
            epochs=num_epochs,
            evaluation_steps=2000,  # Evaluate every 2000 steps
            # Warmup is counted in optimizer steps (batches), not in samples
            warmup_steps=int(len(train_dataloader) * num_epochs * warmup_ratio),
            show_progress_bar=True,
            output_path='checkpoints',  # Save checkpoints
            # save_best_model=True,
            callback=callback  # Add the callback
        )

    def evaluate(self, dataset: Dataset) -> Dict[str, float]:
        """Evaluate the model on a dataset"""
        pairs = [
            [sample['question'], sample['document']]
            for sample in dataset
        ]
        labels = [sample['label'] for sample in dataset]

        # Get predictions
        scores = self.model.predict(pairs)
        predictions = (scores > 0.5).astype(int)

        # Calculate metrics
        metrics = {
            'accuracy': accuracy_score(labels, predictions),
            'precision': precision_score(labels, predictions),
            'recall': recall_score(labels, predictions),
            'f1': f1_score(labels, predictions),
            'auc_roc': roc_auc_score(labels, scores)
        }

        return metrics

    def predict(self, questions: List[str], documents: List[str], show_progress_bar=True) -> np.ndarray:
        """Get relevance scores for question-document pairs"""
        try:
            # Ensure model is in eval mode
            self.model.model.eval()
            
            # Create pairs
            pairs = [[q, d] for q, d in zip(questions, documents)]

            # Make prediction with smaller batch size
            return self.model.predict(
                pairs,
                batch_size=16,  # Smaller batch size
                show_progress_bar=show_progress_bar
            )
            
        except Exception as e:
            print(f"Error during prediction: {str(e)}")
            raise

    def save_model(self, path: str):
        self.model.save(path)

    def load_model(self, path: str):
        self.model = CrossEncoder(
            path,
            max_length=self.max_length,
            device=self.device
        )


class CrossEncoderDataset(Dataset):
    def __init__(self, questions: List[str], documents: List[str], labels: List[int]):
        self.questions = questions
        self.documents = documents
        self.labels = labels
        assert len(questions) == len(documents) == len(
            labels), "All inputs must have the same length"

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return {
            'question': self.questions[idx],
            'document': self.documents[idx],
            'label': self.labels[idx]
        }


Writing dual_encoder.py
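
The `evaluate` method thresholds the cross-encoder scores at 0.5 and hands the result to scikit-learn. As a rough illustration of what those numbers mean, here is a hand-rolled sketch of the same thresholding and of accuracy, precision, recall and F1. The helper `binary_metrics` and the toy scores are hypothetical; the notebook itself relies on `sklearn.metrics`.

```python
from typing import Dict, List


def binary_metrics(labels: List[int], scores: List[float],
                   threshold: float = 0.5) -> Dict[str, float]:
    """Threshold relevance scores and compute accuracy, precision, recall, F1."""
    preds = [1 if s > threshold else 0 for s in scores]
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {
        "accuracy": (tp + tn) / len(labels),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# One query with 1 positive, 3 hard negatives and 1 random negative
labels = [1, 0, 0, 0, 0]
scores = [0.9, 0.7, 0.2, 0.1, 0.05]  # hypothetical cross-encoder scores
metrics = binary_metrics(labels, scores)
print(metrics)
```

With four negatives per positive, accuracy alone can look high even when a hard negative slips past the threshold, which is why precision and recall are reported alongside it.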


In [5]:
%%writefile pipeline.py
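# Explicit imports for the names this module uses directly, rather than
# relying on whatever the star imports below happen to re-export
from typing import Dict, List, Union

import faiss
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from underthesea import word_tokenize

# Note: despite the file names, cross_encoder.py holds the bi-encoder code
# (BiEncoder, BiEncoderConfig) and dual_encoder.py holds CrossEncoderWrapper
from cross_encoder import *
from dual_encoder import *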


class RetrievalPipeline:
    def __init__(self,
                 bi_encoder_path: str,
                 cross_encoder_path: str,
                 faiss_index_path: str,
                 # embeddings_path: str,
                 corpus_df_path: str,
                 top_k: int = 50,
                 rerank_k: int = 10):
        """
        Initialize retrieval pipeline
        Args:
            bi_encoder_path: Path to bi encoder checkpoint
            cross_encoder_path: Path to cross encoder checkpoint
            faiss_index_path: Path to saved FAISS index
            corpus_df_path: Path to corpus CSV file
            top_k: Number of candidates to retrieve from FAISS
            rerank_k: Number of final results after reranking
        """
        # Load models
        self.bi_encoder = BiEncoder(BiEncoderConfig())
        self.cross_encoder = CrossEncoderWrapper()

        # Load model
        self.load_models(bi_encoder_path, cross_encoder_path)

        # Load FAISS index and embeddings
        self.index = faiss.read_index(faiss_index_path)
        # self.document_embeddings = np.load(embeddings_path)

        # Load corpus for retrieving text
        self.corpus_df = pd.read_csv(corpus_df_path)

        # Configuration
        self.top_k = top_k
        self.rerank_k = rerank_k

        # Move models to GPU if available
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.bi_encoder.to(self.device)
        # self.cross_encoder.model.to(torch.device("cuda"))

    def load_models(self, bi_encoder_path: str, cross_encoder_path: str):
        bi_encoder_checkpoint = torch.load(bi_encoder_path)
        question_state_dict = bi_encoder_checkpoint['question_encoder']
        document_state_dict = bi_encoder_checkpoint['document_encoder']

        question_state_dict = {
            k.replace('module.', ''): v for k, v in question_state_dict.items()}
        document_state_dict = {
            k.replace('module.', ''): v for k, v in document_state_dict.items()}

        self.bi_encoder.document_encoder.load_state_dict(
            document_state_dict)
        self.bi_encoder.question_encoder.load_state_dict(
            question_state_dict)
        # self.bi_encoder.config = bi_encoder_checkpoint['config']

        self.cross_encoder.load_model(cross_encoder_path)

    def retrieve(self, query: str, rerank: bool = True) -> List[Dict[str, Union[str, float, int]]]:
        """
        Retrieve and rerank documents for a query
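        Args:
            query: Question string (word-segmented, like the corpus)
            rerank: If True, rerank the FAISS candidates with the cross-encoder;
                otherwise order them by bi-encoder score
        Returns:
            List of at most rerank_k dicts with 'text', 'cid', 'bi_encoder_score'
            and, when rerank is True, 'cross_encoder_score'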
        """
        # Clear cache
        torch.cuda.empty_cache()

        # Stage 1: Bi-encoder retrieval
        query_embedding = self.bi_encoder.encode_question([query])

        query_embedding = query_embedding.cpu().numpy()
        faiss.normalize_L2(query_embedding)

        # Search FAISS index
        scores, doc_indices = self.index.search(query_embedding, self.top_k)

        # Get candidate documents
        candidates = []
        for score, doc_idx in zip(scores[0], doc_indices[0]):
            candidates.append({
                'text': self.corpus_df.iloc[doc_idx]['text'],
                'cid': self.corpus_df.iloc[doc_idx]['cid'],
                'bi_encoder_score': float(score)
            })
        if rerank:
            # Stage 2: Cross-encoder reranking
            candidate_texts = [c['text'] for c in candidates]
            # self.bi_encoder.to(torch.device("cuda"))
            batch_size = 16  # Adjust this based on your GPU memory capacity
            cross_encoder_scores = []

            for i in range(0, len(candidate_texts), batch_size):
                batch_texts = candidate_texts[i:i + batch_size]
                batch_scores = self.cross_encoder.predict(
                    [query] * len(batch_texts),
                    batch_texts,
                    show_progress_bar=False
                )
                cross_encoder_scores.extend(batch_scores)
            # self.cross_encoder.to(torch.device("cpu"))
            # Add cross-encoder scores
            for idx, score in enumerate(cross_encoder_scores):
                candidates[idx]['cross_encoder_score'] = float(score)

            candidates.sort(
                key=lambda x: x['cross_encoder_score'], reverse=True)
            results = candidates[:self.rerank_k]

            return results

        else:
            candidates.sort(key=lambda x: x['bi_encoder_score'], reverse=True)
            results = candidates[:self.rerank_k]

            return results

    def batch_retrieve(self,
                       queries: List[str],
                       batch_size: int = 32) -> List[List[Dict[str, Union[str, float, int]]]]:
        """
        Batch retrieval for multiple queries
        """
        all_results = []
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            batch_results = [self.retrieve(q) for q in batch]
            all_results.extend(batch_results)
        return all_results


def evaluate_retrieval(pipeline, test_df, top_k=10, re_ranking=True):
    """
    Evaluate the retrieval pipeline on a test set using MAP, MRR, NDCG, and Recall@k.

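    Args:
        pipeline: The retrieval pipeline instance.
        test_df: DataFrame containing 'question' and 'cid' columns.
        top_k: Number of top results to consider for evaluation. The pipeline
            returns at most `rerank_k` results, so larger values have no effect.
        re_ranking: Whether to rerank candidates with the cross-encoder.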

    Returns:
        A dictionary with MAP, MRR, NDCG, and Recall@k scores.
    """
    def parse_cid_string(cid_str):
        """Parse a space-separated string of CIDs into a set of integers."""
        return set(map(int, cid_str.strip('[]').split()))

    true_cids = test_df['cid'].apply(parse_cid_string)
    questions = test_df['question'].tolist()

    average_precisions = []
    reciprocal_ranks = []
    ndcg_scores = []
    recall_at_k = []

    for question, true_cid_set in tqdm(zip(questions, true_cids), total=len(questions), desc="Evaluating"):
        # Retrieve documents
        results = pipeline.retrieve(question, re_ranking)

        # Get retrieved cids
        retrieved_cids = [result['cid'] for result in results[:top_k]]

        # Calculate Average Precision
        num_relevant = 0
        precision_sum = 0.0
        for i, cid in enumerate(retrieved_cids):
            if cid in true_cid_set:
                num_relevant += 1
                precision_sum += num_relevant / (i + 1)
        average_precision = precision_sum / \
            len(true_cid_set) if true_cid_set else 0
        average_precisions.append(average_precision)

        # Calculate Reciprocal Rank
        reciprocal_rank = 0.0
        for i, cid in enumerate(retrieved_cids):
            if cid in true_cid_set:
                reciprocal_rank = 1.0 / (i + 1)
                break
        reciprocal_ranks.append(reciprocal_rank)

        # Calculate NDCG with binary relevance and the standard log2 rank discount
        dcg = 0.0
        idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(true_cid_set), top_k)))
        for i, cid in enumerate(retrieved_cids):
            if cid in true_cid_set:
                dcg += 1.0 / np.log2(i + 2)
        ndcg = dcg / idcg if idcg > 0 else 0
        ndcg_scores.append(ndcg)

        # Calculate Recall@k
        relevant_retrieved = len(set(retrieved_cids) & true_cid_set)
        recall = relevant_retrieved / len(true_cid_set) if true_cid_set else 0
        recall_at_k.append(recall)

    # Calculate average metrics
    map_score = sum(average_precisions) / len(average_precisions)
    mrr_score = sum(reciprocal_ranks) / len(reciprocal_ranks)
    avg_ndcg = sum(ndcg_scores) / len(ndcg_scores)
    avg_recall_at_k = sum(recall_at_k) / len(recall_at_k)

    return {
        'MAP': map_score,
        'MRR': mrr_score,
        'NDCG': avg_ndcg,
        'Recall@k': avg_recall_at_k
    }


bi_encoder_path = "/kaggle/working/bi_encoder.pt"
cross_encoder_path = "/kaggle/input/cross_encoder_new/pytorch/default/1"
faiss_index_path = "/kaggle/input/bi_encoder_embedding/pytorch/default/1/document_index.faiss"
corpus_df_path = "/kaggle/input/soict-dataset-2024-segmented/corpus_segmented.csv"

# Initialize pipeline
pipeline = RetrievalPipeline(
    bi_encoder_path=bi_encoder_path,
    cross_encoder_path=cross_encoder_path,
    faiss_index_path=faiss_index_path,
    # embeddings_path="path/to/embeddings.npy",
    corpus_df_path=corpus_df_path,
    # rerank_k=50
)


def segment(text: str) -> str:
    return word_tokenize(text, format="text")


def format_retrieval_results(results: List[Dict[str, Union[str, float, int]]]) -> str:
    """
    Reformat and concatenate retrieved documents into a single string.

    Args:
        results: List of dictionaries containing retrieved documents with their metadata

    Returns:
        str: Formatted and concatenated document string
    """
    # Remove quotes and underscores, then join documents with newlines
    formatted_docs = []

    for doc in results:
        # Remove quotes and underscores from text
        text = doc['text'].strip('"\'')  # Remove quotes
        text = text.replace('_', ' ')    # Replace underscores with spaces

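        # Prefer the cross-encoder score; fall back to the bi-encoder score
        # when the results were produced with rerank=False
        score = doc.get('cross_encoder_score', doc['bi_encoder_score'])
        formatted_docs.append(f"Document (relevance score: {score:.2f}):\n{text}")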

    # Join all documents with double newlines for better readability
    return "\n\n".join(formatted_docs)


def retrieval_legal_documents(query: str) -> str:
    """
    Retrieve the 10 legal documents most relevant to the query.
    Example usage:
    query: Quy định về kinh doanh như thế nào?
    """
    segment_query = segment(query)
    results = pipeline.retrieve(segment_query)
    return format_retrieval_results(results)


Writing pipeline.py
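
To make the metrics in `evaluate_retrieval` concrete, the sketch below computes them for a single ranked list. It is a simplified illustration, not the notebook's function: `ranking_metrics` and the toy ids are hypothetical, relevance is binary, and NDCG uses the conventional `1 / log2(rank + 1)` discount.

```python
import math
from typing import Dict, List, Set


def ranking_metrics(retrieved: List[int], relevant: Set[int]) -> Dict[str, float]:
    """Compute AP, RR, NDCG and recall for one ranked list of unique ids."""
    hits, precision_sum, rr, dcg = 0, 0.0, 0.0, 0.0
    for rank, cid in enumerate(retrieved, start=1):
        if cid in relevant:
            hits += 1
            precision_sum += hits / rank          # precision at this hit
            dcg += 1.0 / math.log2(rank + 1)      # discounted gain
            if rr == 0.0:
                rr = 1.0 / rank                   # first relevant hit only
    # Ideal DCG: all relevant documents ranked first
    ideal_hits = min(len(relevant), len(retrieved))
    idcg = sum(1.0 / math.log2(r + 1) for r in range(1, ideal_hits + 1))
    return {
        "AP": precision_sum / len(relevant) if relevant else 0.0,
        "RR": rr,
        "NDCG": dcg / idcg if idcg > 0 else 0.0,
        "Recall": hits / len(relevant) if relevant else 0.0,
    }


# Hypothetical ranking: the two relevant documents sit at ranks 2 and 4
m = ranking_metrics(retrieved=[7, 3, 9, 4], relevant={3, 4})
print(m)
```

Here AP and RR are both 0.5 and recall is 1.0, while NDCG is about 0.65 because both relevant documents are found but neither is ranked first.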


In [6]:
# Create a script to clean the checkpoint
import torch

# Load the checkpoint
checkpoint = torch.load("/kaggle/input/bi_encoder/pytorch/default/1/bi_encoder.pt")

# Create new checkpoint with only the model states
clean_checkpoint = {
    'question_encoder': checkpoint['question_encoder'],
    'document_encoder': checkpoint['document_encoder']
}

# Save the clean checkpoint
torch.save(clean_checkpoint, 'bi_encoder.pt')
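
Checkpoints saved from a model wrapped in `nn.DataParallel` carry a `module.` prefix on every parameter name, which is why the loading code strips it. The sketch below shows that key rewrite on a plain dictionary; `strip_module_prefix` is a hypothetical helper, and unlike `k.replace('module.', '')` it only removes the prefix at the start of a key.

```python
from typing import Any, Dict


def strip_module_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Remove a leading 'module.' from every key, leaving other keys untouched."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical state dict; strings stand in for tensors
wrapped = {
    "module.embeddings.weight": "tensor-0",
    "module.encoder.layer.0.weight": "tensor-1",
}
cleaned = strip_module_prefix(wrapped)
print(sorted(cleaned))  # ['embeddings.weight', 'encoder.layer.0.weight']
```

The function is idempotent, so it is safe to apply to a checkpoint that was saved without `DataParallel`.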



## STAGE 2 - Cross-encoder (Reranking Stage):
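
The two-stage flow used by `RetrievalPipeline.retrieve` can be sketched without any models: a cheap first-stage score selects `top_k` candidates, and a second score reorders them and keeps `rerank_k`. Everything below is a toy illustration under stated assumptions; `retrieve_then_rerank`, `word_overlap` and `overlap_ratio` are hypothetical stand-ins for the FAISS search and the cross-encoder.

```python
from typing import Callable, Dict, List


def retrieve_then_rerank(query: str,
                         corpus: List[str],
                         first_stage: Callable[[str, str], float],
                         second_stage: Callable[[str, str], float],
                         top_k: int = 50,
                         rerank_k: int = 10) -> List[Dict[str, object]]:
    """Keep the top_k documents by first-stage score, then reorder them by
    second-stage score and return the best rerank_k."""
    candidates = sorted(corpus, key=lambda d: first_stage(query, d),
                        reverse=True)[:top_k]
    scored = [{"text": d, "score": second_stage(query, d)} for d in candidates]
    scored.sort(key=lambda c: c["score"], reverse=True)
    return scored[:rerank_k]


def word_overlap(query: str, doc: str) -> float:
    """Toy first-stage score: number of shared words."""
    return float(len(set(query.split()) & set(doc.split())))


def overlap_ratio(query: str, doc: str) -> float:
    """Toy second-stage score: shared words relative to document length."""
    return word_overlap(query, doc) / len(doc.split())


corpus = [
    "business tax law and traffic rules and other unrelated text",
    "business tax",
    "traffic law",
]
results = retrieve_then_rerank("business tax rules", corpus,
                               word_overlap, overlap_ratio,
                               top_k=2, rerank_k=1)
print(results[0]["text"])  # business tax
```

The long document wins the first stage on raw overlap, but the second stage promotes the shorter, more focused one, which is the kind of correction the cross-encoder is meant to provide.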

## Interface

In [7]:
!pip install pydantic==2.10


In [8]:
!pip install --quiet pyngrok chainlit langchain_openai langchain langchain_core langchain_community

In [9]:
# Replace YOUR_NGROK_AUTHTOKEN with your own token; never publish a real token in a notebook
!ngrok config add-authtoken YOUR_NGROK_AUTHTOKEN

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml                                


In [10]:
%%writefile interface.py
import os
from typing import List
from datetime import datetime
from pipeline import retrieval_legal_documents
import chainlit as cl
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.tools import tool
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from langchain.callbacks.base import BaseCallbackHandler

# System prompt for the legal assistant agent

LEGAL_SYSTEM_PROMPT = """Bạn là một trợ lý pháp lý chuyên nghiệp, 
chuyên về luật pháp Việt Nam. Nhiệm vụ của bạn là cung cấp thông tin chính xác, 
rõ ràng và hữu ích dựa trên các văn bản pháp luật được cung cấp. Bạn cũng có công cụ để
tìm tin tức trong trường hợp người dùng yêu cầu. Bạn cũng có công cụ để lấy ngày tháng hiện tại
hãy sử dụng kết hợp các công cụ để hỗ trợ người dùng

NẾU CÂU HỎI CÓ LIÊN QUAN ĐẾN LUẬT THÌ TRẢ LỜI THEO FORMAT SAU:
Nhiệm vụ chính:
1. CHỈ trả lời dựa trên ngữ cảnh pháp lý được cung cấp. Không được dùng thông tin ngoài tài liệu được cung cấp.
2. Đảm bảo tính chính xác về mặt pháp lý.
3. Duy trì giọng điệu chuyên nghiệp, khách quan.
4. Sử dụng một cách thông minh các công cụ để hỗ trợ người dùng

Cấu trúc câu trả lời:
1. Câu trả lời trực tiếp: Trả lời ngắn gọn
2. Giải thích chi tiết:
    * Giải thích rõ ràng, nên ghi chi tiết nhất có thể dựa vào tài liệu bạn có
    * Ghi ra thật cụ thể thông tin cần thiết từ tài liệu bạn được cung cấp
    * Phải ghi cụ thể chứ không được trả lời chung chung
3. Căn cứ pháp lý: Trích dẫn các điều luật liên quan
4. Lưu ý thêm: Các điểm cần lưu ý nếu có
Lưu ý: Bạn được cung cấp một công cụ để truy vấn tài liệu dựa trên câu hỏi của người dùng
sẽ có 10 tài liệu được trả về nhưng thường chỉ có 1 hoặc 2 tài liệu có thông tin chính
xác, do đó hãy xem xét cẩn thận từng tài liệu trước khi trả lời, nếu nhận thấy không
có tài liệu nào có thể trả lời được câu hỏi của người dùng hãy bảo họ là bạn không có
thông tin
"""


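# Custom callback handler for token streaming. It is not attached to the LLM
# anywhere in this file, and `container.write` expects a Streamlit-style container.
class StreamHandler(BaseCallbackHandler):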
    def __init__(self, container):
        self.container = container

    def on_llm_new_token(self, token: str, **kwargs):
        self.container.write(token)


@tool
def get_current_day() -> str:
    """
    Dùng để lấy ngày tháng năm hiện tại dưới dạng dd/mm/yyyy
    """
    # Get the current date
    current_date = datetime.now()

    # Format the date as dd/mm/yyyy
    formatted_date = current_date.strftime("%d/%m/%Y")
    return formatted_date

@tool
def get_news(query: str, date: str) -> str:
    """
    Dùng để tìm kiếm tin tức. Ví dụ:
    query: "Thông tin thời tiết hôm nay?"
    date: "01/01/2003"
    """
    if date == "19/12/2024":
        return """
        Ngày 3.11, Ngân hàng Nhà nước giảm nhẹ tỷ giá trung tâm 1 đồng/USD, 
        xuống còn 23.687 đồng/USD. Các ngân hàng tăng giá bán USD lên mức kịch trần.
        Eximbank bán USD lên 24.870 đồng, mua vào 24.680 - 24.700 đồng. Vietcombank bán
        USD với giá 24.872 đồng, mua vào 24.562 - 24.592 đồng… Trên thị trường liên
        ngân hàng, tỷ giá chốt phiên với mức 24.850 đồng/USD, giảm 27 đồng.
        """
    else:
        return "Không có tin tức về chủ đề này"

@tool
def retrieve_documents(query: str) -> str:
    """
    Sử dụng công cụ này để truy vấn 10 tài liệu luật liên quan nhất đến câu query
    Ví dụ cách sử dụng:
    query: Quy định về kinh doanh như thế nào?
    """
    return retrieval_legal_documents(query)


@cl.on_chat_start
async def start_chat():
    """Initialize the chat session"""
    # Initialize the ChatOpenAI model
    llm = ChatOpenAI(
        base_url="https://openrouter.ai/api/v1",
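        # Read the OpenRouter key from the environment instead of hard-coding it;
        # set os.environ["OPENAI_API_KEY"] in the notebook before launching Chainlit
        api_key=os.environ["OPENAI_API_KEY"],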
        model="openai/gpt-4o-mini",
        #streaming=True,
        #model_name="gpt-4o-mini"
    )

    # Create a conversation memory
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="output"
    )

    # Create the prompt with memory
    prompt = ChatPromptTemplate.from_messages([
        ("system", LEGAL_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])

    # Create the agent
    agent = create_tool_calling_agent(
        llm=llm,
        tools=[get_current_day, get_news, retrieve_documents],
        prompt=prompt
    )

    # Create the agent executor with memory
    agent_executor = AgentExecutor(
        agent=agent,
        tools=[get_current_day, get_news, retrieve_documents],
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
    )

    # Store the agent executor in the user session
    cl.user_session.set("agent", agent_executor)

    await cl.Message(content="Xin chào tôi là AI tư vấn luật, tôi có thể giúp gì cho bạn").send()


@cl.on_message
async def main(message: cl.Message):
    """Handle incoming messages"""

    # Get the agent executor from user session
    agent_executor = cl.user_session.get("agent")

    # Create a message container
    msg = cl.Message(content="")
    await msg.send()

    # Run the agent with streaming
    async for chunk in agent_executor.astream(
        {"input": message.content},
    ):
        # Check if 'output' is in the chunk to handle different streaming behaviors
        if 'output' in chunk:
            await msg.stream_token(chunk["output"])
        elif isinstance(chunk, str):
            await msg.stream_token(chunk)

    await msg.update()


Writing interface.py
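
The message handler above keeps only the streamed chunks that carry an `output` key. The sketch below reproduces that filtering with a fake stream so it can be run without LangChain or Chainlit. The chunk shapes (`actions`, `steps`, then `output`) are an assumption about what the agent executor emits, and `fake_agent_stream` and `collect_output` are hypothetical names.

```python
import asyncio
from typing import AsyncIterator, Dict, List, Union


async def fake_agent_stream() -> AsyncIterator[Union[Dict[str, str], str]]:
    """Stand-in for agent_executor.astream: intermediate chunks without an
    'output' key, followed by the final answer."""
    yield {"actions": "retrieve_documents"}
    yield {"steps": "10 documents returned"}
    yield {"output": "Final answer"}


async def collect_output() -> List[str]:
    """Keep only the chunks that carry user-visible text."""
    tokens: List[str] = []
    async for chunk in fake_agent_stream():
        if isinstance(chunk, dict) and "output" in chunk:
            tokens.append(chunk["output"])
        elif isinstance(chunk, str):
            tokens.append(chunk)
    return tokens


tokens = asyncio.run(collect_output())
print(tokens)  # ['Final answer']
```

Run this as a plain script; inside a notebook cell, where an event loop is already running, use `await collect_output()` instead of `asyncio.run`. Under this assumed chunk shape the user sees the answer in one piece once the agent finishes, rather than token by token.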


In [11]:
# import os
# os.environ["OPENAI_API_KEY"] = ""

In [12]:
from pyngrok import ngrok

# Set up an ngrok tunnel to the Chainlit server (port 8000 by default)
public_url = ngrok.connect(8000)
print(f"Public URL: {public_url}")

Public URL: NgrokTunnel: "https://4467-35-237-228-176.ngrok-free.app" -> "http://localhost:8000"


In [None]:
!chainlit run interface.py -w


In [None]:
ngrok.kill()