### Overview of the datasets and models
* Datasets:
  * **soict-dataset-2024**: The original dataset.
  * **soict-dataset-2024-segmented**: The original dataset, word-segmented with underthesea.
  * **cross-encoder-dataset**: Built by taking, for each query, the ground truth, 3 hard negatives returned by the bi-encoder and 1 random negative. The training split has over 500,000 samples and the validation split has over 100,000 samples.
  * **cross-encoder-dataset-segmented**: Same as cross-encoder-dataset, but generated with the new bi-encoder and word-segmented.
* Models:
    * **bi_encoder**: A BiEncoder with https://huggingface.co/bkai-foundation-models/vietnamese-bi-encoder as the base model, trained on **soict-dataset-2024-segmented** for 2 epochs with hard negative mining.
    * **bi_encoder_embedding**: A FAISS vector database holding the embedding vectors of every document in corpus_segmented.csv, produced with the document_encoder of bi_encoder.
    * **cross_encoder_new**: A cross-encoder built with the sentence_transformers library and trained on **soict-dataset-2024-segmented**, used to rerank the results returned by the bi-encoder.
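The two models above form a retrieve-then-rerank pipeline. The bi-encoder shortlists candidates from the FAISS index, and the cross-encoder rescores only that shortlist. The sketch below shows that control flow under a few assumptions. The embeddings are taken to be L2-normalized already. A brute-force NumPy inner product stands in for the FAISS search. A plain scoring callable, the hypothetical `rerank_score`, stands in for the cross-encoder.

```python
import numpy as np


def retrieve_then_rerank(query_vec, doc_vecs, rerank_score, top_k=5, final_k=2):
    """Two-stage search: dense top-k by inner product, then rerank the shortlist.

    query_vec:    (d,) L2-normalized query embedding (stand-in for the bi-encoder output)
    doc_vecs:     (n, d) L2-normalized document embeddings (stand-in for the FAISS index)
    rerank_score: callable(doc_id) -> float (hypothetical stand-in for the cross-encoder)
    """
    # Stage 1: brute-force inner-product search (faiss.IndexFlatIP scores the same way)
    scores = doc_vecs @ query_vec
    top_k = min(top_k, len(doc_vecs))
    candidates = np.argsort(-scores)[:top_k]

    # Stage 2: rescore only the shortlisted candidates and keep the best final_k
    reranked = sorted(candidates.tolist(), key=rerank_score, reverse=True)
    return reranked[:final_k]


docs = np.array([[1.0, 0.0], [0.8, 0.6], [0.6, 0.8], [0.0, 1.0]])
query = np.array([1.0, 0.0])
cross_scores = {0: 0.1, 1: 0.9, 2: 0.5, 3: 1.0}
print(retrieve_then_rerank(query, docs, cross_scores.get, top_k=3, final_k=2))
```

In this example, document 3 has the highest rerank score but never reaches the reranker, because stage 1 did not retrieve it. As a result, the bi-encoder's recall@k caps what the cross-encoder can recover.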


In [None]:
import pandas as pd
df = pd.read_csv("../input/cross-encoder-dataset/train_data.csv")
df[:10]

In [None]:
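# Visualize the first 10 rows of the cross-encoder dataset.
# Rows are grouped per question: ground truth (label 1) first, then hard negatives (label 0);
# the random negative is the last label-0 row before the next ground truth.
for i in range(min(10, len(df))):
    row = df.iloc[i]
    # Guard the look-ahead so the last row of df does not raise an IndexError
    next_is_positive = i + 1 < len(df) and df.iloc[i + 1]['label'] == 1

    if row['label'] == 1:
        print("------------------------------------------------------------")
        print(f" QUERY: {row['question']}")
        print(" - This is a relevant document")
        print(row['document'])
    elif next_is_positive:
        print(" - This is a random negative")
        print(row['document'])
    else:
        print(" - This is a hard negative")
        print(row['document'])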


### What this notebook does:
1. Trains a bi-encoder on XLM-RoBERTa for 1 epoch (**bi_encoder_xlmRoBERTA**). This model was chosen as the baseline because it was pretrained on many languages, including Vietnamese.
2. Uses the trained bi-encoder to run retrieval over the whole dataset. For each sample, it keeps the ground truth and picks 3 hard negatives from the retrieved results, which are the 3 highest-ranked wrong results for that query. It also adds 1 random negative. Together these form a new dataset, **cross-encoder-dataset**, used to train the cross-encoder. A new dataset is needed because the cross-encoder reranks the candidates returned by the bi-encoder. Those candidates are the bi-encoder's top@k results and are therefore already very similar to the query, so the cross-encoder must be trained on the bi-encoder's hard negatives.
3. Uses **bi_encoder_xlmRoBERTA** to extract embeddings for the entire corpus and stores them in a FAISS vector database for storage and fast similarity search. The embedding vectors are stored in **bi_encoder_corpus_embedding**.
4. Defines a basic cross-encoder with the sentence_transformers library, using XLM-RoBERTa as the base model, and trains it on **cross-encoder-dataset**. Training was stopped after about 200,000 samples because the results had stopped changing much. The resulting model, **cross_encoder_ckp**, is used for reranking.
5. Defines a pipeline that loads the trained models for inference and evaluation.
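Step 2 can be sketched per question as follows. This is an illustrative helper, not the notebook's actual dataset-building code. The name `build_cross_encoder_rows` and its arguments are hypothetical, and `retrieved` is assumed to be the bi-encoder's ranked results for the question. The rows are emitted in the order the visualization cell above expects: ground truth, then hard negatives, then the random negative.

```python
import random
from typing import Dict, List, Optional


def build_cross_encoder_rows(question: str, positives: List[str], retrieved: List[str],
                             corpus: List[str], n_hard: int = 3,
                             rng: Optional[random.Random] = None) -> List[Dict]:
    """Ground truth + top-ranked wrong retrievals (hard negatives) + 1 random negative."""
    rng = rng if rng is not None else random.Random(0)
    rows = [{"question": question, "document": d, "label": 1} for d in positives]

    # Hard negatives: the highest-ranked bi-encoder results that are not ground truth
    hard = [d for d in retrieved if d not in positives][:n_hard]
    rows += [{"question": question, "document": d, "label": 0} for d in hard]

    # Random negative: any corpus document that is neither a positive nor already used
    pool = [d for d in corpus if d not in positives and d not in hard]
    if pool:
        rows.append({"question": question, "document": rng.choice(pool), "label": 0})
    return rows


for r in build_cross_encoder_rows("q", ["d1"], ["d2", "d1", "d3", "d4", "d5"],
                                  [f"d{i}" for i in range(1, 8)]):
    print(r)
```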

### Some ideas for improving the models
1. Improving the bi-encoder:
   * Retrain the bi-encoder on a base model trained specifically for Vietnamese, such as PhoBERT https://huggingface.co/vinai/phobert-base, and word-segment the dataset with a library such as underthesea before training.
   * Instead of writing the bi-encoder from scratch, use a bi-encoder already pretrained for Vietnamese, such as https://huggingface.co/bkai-foundation-models/vietnamese-bi-encoder. A pretrained bi-encoder requires new functions to load and train the model, plus word segmentation if needed.
   * Train for more epochs with hard negative mining. The current model uses XLM-RoBERTa and was trained for only 1 epoch without hard negative mining, because mining ran out of RAM.
2. Improving the cross-encoder:
   * Rebuild the cross-encoder training dataset from the improved bi-encoder.
   * The current **cross-encoder-dataset** has only 2 classes: 1 for the ground truth and 0 for both hard and random negatives. A third class, labeled 2, could be added for random negatives to reduce the class imbalance.
   * Look for better ways to build the cross-encoder training dataset.
   * Try PhoBERT as the base model instead of XLM-RoBERTa.
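The third-class idea could look like the sketch below. It assumes rows stay grouped per question exactly as in the visualization cell, with ground truth first and the random negative last. `add_random_negative_class` is a hypothetical helper. A 3-class model would also need a 3-way head (`num_labels=3` in `CrossEncoder`) and a rule for turning class probabilities into a reranking score. Both are left open here.

```python
import pandas as pd


def add_random_negative_class(df: pd.DataFrame) -> pd.DataFrame:
    """Relabel random negatives as class 2 (assumes each question's rows end with its random negative)."""
    out = df.copy()
    next_label = out["label"].shift(-1)
    # A label-0 row is the random negative if the next row starts a new group (label 1)
    # or if it is the final row of the file
    is_random = (out["label"] == 0) & ((next_label == 1) | next_label.isna())
    out.loc[is_random, "label"] = 2
    return out


toy = pd.DataFrame({"label": [1, 0, 0, 0, 0, 1, 0, 0, 0, 0]})
print(add_random_negative_class(toy)["label"].tolist())
```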

### Changes made
1. Used underthesea to word-segment train.csv and corpus.csv, producing the **soict-dataset-2024-segmented** dataset.
2. Used https://huggingface.co/bkai-foundation-models/vietnamese-bi-encoder as the base model to train a new bi-encoder on **soict-dataset-2024-segmented**. The new bi-encoder was trained for 2 epochs with hard negative mining and layer freezing (see the train method of BiEncoderTrainer).
3. Used the new bi-encoder to build a new dataset, **cross-encoder-segmented**, for training a new cross-encoder.
4. Trained the new cross-encoder with the same base model as the bi-encoder.
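The layer-freezing schedule in `BiEncoderTrainer.train` unfreezes more of the top encoder layers as the epoch progresses. The pure function below is a sketch of that schedule with the same thresholds. It is written so the layer count never decreases, which means a later epoch cannot re-freeze layers. `layers_to_unfreeze` is a hypothetical name. A return value of 0 corresponds to `_freeze_layers` training only the pooler and the last encoder layer.

```python
def layers_to_unfreeze(step_in_epoch: int, steps_per_epoch: int, already_unfrozen: int = 0) -> int:
    """Number of top encoder layers to train at a given step; never lower than already_unfrozen."""
    # (step threshold within the epoch, number of top layers to unfreeze)
    schedule = [
        (steps_per_epoch // 4, 2),
        (steps_per_epoch // 2, 4),
        (3 * steps_per_epoch // 4, 6),
        (9 * steps_per_epoch // 10, 8),
    ]
    target = already_unfrozen
    for threshold, n_layers in schedule:
        if step_in_epoch >= threshold:
            target = max(target, n_layers)
    return target


print([layers_to_unfreeze(s, 100) for s in (0, 25, 60, 95)])
```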

**NOTE**: Kaggle currently has a bug that sometimes hides the progress bar or shows a JavaScript error. It does not affect how the code runs, but it is annoying because you sometimes cannot tell how far training has progressed.

## STAGE 1 - Bi-encoder (First Retrieval Stage)

In [None]:
!pip install faiss-gpu

In [None]:
!pip install sentence-transformers

In [None]:
!pip install underthesea

In [None]:
import torch
from torch import nn
from transformers import XLMRobertaModel, XLMRobertaTokenizer, AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Tuple, Union
import faiss
import numpy as np
from sklearn.model_selection import train_test_split

In [None]:
class BiEncoderConfig:
    def __init__(
        self,
        max_length: int = 256,
        batch_size: int = 16,
        learning_rate: float = 1e-5,
        num_epochs: int = 2,
        temperature: float = 0.05,
        embedding_dim: int = 768,

    ):
        self.max_length = max_length
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.temperature = temperature
        self.embedding_dim = embedding_dim

In [None]:
class LegalDataset(Dataset):
    def __init__(self, questions: List[str], contexts: List[str], tokenizer, max_length: int):
        self.questions = questions
        self.contexts = contexts
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.questions)
    
    def __getitem__(self, idx):
        question = self.questions[idx]
        context = self.contexts[idx]
        
        return {
            'question': question,
            'context': context
        }

In [None]:
class BiEncoder(nn.Module):
    def __init__(self, config: BiEncoderConfig):
        super().__init__()
        self.config = config
        
        # Load the pre-trained Vietnamese bi-encoder for both encoders
        self.question_encoder = AutoModel.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")
        self.document_encoder = AutoModel.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")

        # Use the model's tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")
        self.max_length = config.max_length

    def get_device(self):
        # Helper method to get the current device
        if isinstance(self.question_encoder, nn.DataParallel):
            return self.question_encoder.module.device
        return next(self.question_encoder.parameters()).device

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(
            -1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def encode_question(self, questions: List[str], batch_size: int = 32) -> torch.Tensor:
        all_embeddings = []
        device = self.get_device()

        for i in range(0, len(questions), batch_size):
            batch_texts = questions[i:i + batch_size]

            encoded_input = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            with torch.no_grad():
                model_output = self.question_encoder(**encoded_input)

            batch_embeddings = self.mean_pooling(
                model_output, encoded_input['attention_mask'])
            all_embeddings.append(batch_embeddings)

        return torch.cat(all_embeddings, dim=0)

    def encode_document(self, documents: List[str], batch_size: int = 32, disable_progress_bar: bool = False) -> torch.Tensor:
        all_embeddings = []
        device = self.get_device()

        # Calculate total number of batches for progress bar
        num_batches = (len(documents) + batch_size - 1) // batch_size

        for i in tqdm(range(0, len(documents), batch_size), total=num_batches, desc="Encoding documents", disable=disable_progress_bar):
            batch_texts = documents[i:i + batch_size]

            encoded_input = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            ).to(device)

            with torch.no_grad():
                model_output = self.document_encoder(**encoded_input)

            batch_embeddings = self.mean_pooling(
                model_output, encoded_input['attention_mask'])
            all_embeddings.append(batch_embeddings)

        return torch.cat(all_embeddings, dim=0)



In [None]:
class BiEncoderTrainer:

    def get_config(self) -> BiEncoderConfig:
        """Get the trainer's configuration"""
        return self.config
        
    def get_model(self) -> BiEncoder:
        """Get the trainer's model"""
        return self.model
        
    def __init__(self, config: BiEncoderConfig):
        self.config = config
        self.model = BiEncoder(self.config)
        
        # Initialize with smaller learning rate for fine-tuning
        self.config.learning_rate = 1e-5  # Reduced from 2e-5
        
        # Setup multi-GPU
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.model.question_encoder = nn.DataParallel(self.model.question_encoder)
            self.model.document_encoder = nn.DataParallel(self.model.document_encoder)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        
        # Implement gradual unfreezing
        self.unfreeze_layers = 0  # Start with all layers frozen
        self._freeze_layers()
        
        # metrics tracking
        self.best_mrr = 0.0
        self.best_recall = 0.0
        
        # Add new attributes for step-based unfreezing
        self.total_steps = 0
        self.unfreeze_schedule = None  # Will be set in train()

    def _freeze_layers(self):
        """Freeze/unfreeze layers gradually during training"""
        # First freeze all layers
        for param in self.model.question_encoder.parameters():
            param.requires_grad = False
        for param in self.model.document_encoder.parameters():
            param.requires_grad = False
            
        def unfreeze_model_layers(model, num_layers):
            # Always unfreeze the pooler and final layer
            if isinstance(model, nn.DataParallel):
                model = model.module
            
            # Unfreeze pooler
            for param in model.pooler.parameters():
                param.requires_grad = True
            
            # Always keep the final layer unfrozen
            if num_layers == 0:
                for param in model.encoder.layer[-1].parameters():
                    param.requires_grad = True
                return
                
            # Unfreeze specified number of layers from the top
            for layer in list(model.encoder.layer)[-num_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
                
        # Apply unfreezing to both encoders
        unfreeze_model_layers(self.model.question_encoder, self.unfreeze_layers)
        unfreeze_model_layers(self.model.document_encoder, self.unfreeze_layers)

    def set_scores(self, scores: Tuple):
        self.best_mrr = scores[0]
        self.best_recall = scores[1]
        
    def prepare_batch(self, batch: Dict[str, List[str]]) -> Dict[str, torch.Tensor]:
        # tokenize questions
        questions_tokenized = self.model.tokenizer(
            batch['question'],
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors='pt'
        ).to(self.device)
        
        #Tokenize contexts
        contexts_tokenized = self.model.tokenizer(
            batch['context'],
            padding=True,
            truncation=True,
            max_length=self.config.max_length,
            return_tensors='pt'
        ).to(self.device)
        
        return {
            'question_data': questions_tokenized,
            'context_data': contexts_tokenized
        }
    
    def compute_loss(self, q_embeddings: torch.Tensor, d_embeddings: torch.Tensor,
                    hard_negative_embeddings: torch.Tensor = None) -> torch.Tensor:
        """Compute loss with both in-batch and hard negatives"""
        #regular in-batch negative loss
        similarity = torch.matmul(q_embeddings, d_embeddings.t())
        
        if hard_negative_embeddings is not None:
            # hard negative similarities
            hard_similarity = torch.matmul(q_embeddings, hard_negative_embeddings.t())
            similarity = torch.cat([similarity, hard_similarity], dim=1)
        
        # scale by temperature
        similarity = similarity / self.config.temperature
        
        # Create labels for diagonal (positive pairs)
        labels = torch.arange(q_embeddings.size(0)).to(self.device)
        
        # Compute loss
        loss = nn.CrossEntropyLoss()(similarity, labels)
        
        return loss
    
    def evaluate(self, val_dataset: LegalDataset) -> Dict[str, float]:
        """
        Evaluate the model on validation dataset
        Returns MRR@k and Recall@k metrics
        """
        self.model.eval()
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False
        )
        
        all_q_embeddings = []
        all_d_embeddings = []
        
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Evaluating"):
                processed_batch = self.prepare_batch(batch)
                
                # Get embeddings
                q_embeddings = self.model.mean_pooling(
                    self.model.question_encoder(**processed_batch['question_data']),
                    processed_batch['question_data']['attention_mask']
                )
                d_embeddings = self.model.mean_pooling(
                    self.model.document_encoder(**processed_batch['context_data']),
                    processed_batch['context_data']['attention_mask']
                )
                
                # Normalize
                q_embeddings = nn.functional.normalize(q_embeddings, p=2, dim=1)
                d_embeddings = nn.functional.normalize(d_embeddings, p=2, dim=1)
                
                all_q_embeddings.append(q_embeddings)
                all_d_embeddings.append(d_embeddings)
        
        # Concatenate all embeddings
        all_q_embeddings = torch.cat(all_q_embeddings, dim=0)
        all_d_embeddings = torch.cat(all_d_embeddings, dim=0)
        
        # Compute similarity matrix
        similarity = torch.matmul(all_q_embeddings, all_d_embeddings.t())
        
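        # Calculate metrics (query i's only relevant document is document i)
        k_values = [1, 5, 10, 50, 100, 200, 500, 1000]
        metrics = {}
        correct = torch.arange(similarity.size(0), device=similarity.device).unsqueeze(1)
        
        for k in k_values:
            # topk raises an error if k exceeds the number of documents, so cap it
            k_eff = min(k, similarity.size(1))
            _, indices = similarity.topk(k_eff, dim=1)
            hits = indices == correct  # (num_queries, k_eff), at most one True per row
            hit_any = hits.any(dim=1).float()
            
            # Recall@k: fraction of queries whose relevant document is in the top-k
            metrics[f'recall@{k}'] = hit_any.mean().item()
            
            # MRR@k: 1/rank for a hit and 0 for a miss, averaged over ALL queries
            # (argmax is 0 on a miss, so rank is 1 there, but hit_any zeroes it out)
            ranks = hits.float().argmax(dim=1) + 1
            metrics[f'mrr@{k}'] = (hit_any / ranks.float()).mean().item()
            
        return metrics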

    def train(self, train_dataset: LegalDataset, val_dataset: LegalDataset = None):
        # Shuffle dataset *indices* instead of samples so each batch knows which rows it holds;
        # this keeps mined hard negatives aligned with their own questions.
        train_dataloader = DataLoader(
            list(range(len(train_dataset))),
            batch_size=self.config.batch_size,
            shuffle=True
        )
        
        # Calculate total steps and set unfreeze schedule
        steps_per_epoch = len(train_dataloader)
        total_steps = steps_per_epoch * self.config.num_epochs
        
        # Unfreeze more top layers as the epoch progresses (keys are steps within an epoch)
        self.unfreeze_schedule = {
            steps_per_epoch // 4: 2,     # Unfreeze top 2 layers after 25% of steps
            steps_per_epoch // 2: 4,     # Unfreeze top 4 layers after 50% of steps
            3 * steps_per_epoch // 4: 6,  # Unfreeze top 6 layers after 75% of steps
            9 * steps_per_epoch // 10: 8  # Unfreeze top 8 layers after 90% of steps
        }
        
        # Build an optimizer over the currently trainable parameters only
        def get_optimizer(lr: float):
            params = []
            for model in [self.model.question_encoder, self.model.document_encoder]:
                params.extend([p for p in model.parameters() if p.requires_grad])
            return torch.optim.AdamW(params, lr=lr, weight_decay=0.01)
        
        optimizer = get_optimizer(self.config.learning_rate)
        
        # Cosine learning-rate decay over the whole run (stepped once per batch below)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
        
        # Mine hard negatives at the midpoint of each epoch (step 0 is skipped)
        mine_every_n_steps = max(1, steps_per_epoch // 2)
        global_step = 0
        
        for epoch in range(self.config.num_epochs):
            self.model.train()
            total_loss = 0
            progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{self.config.num_epochs}')
            
            hard_negatives = None
            
            for batch_idx, batch_indices in enumerate(progress_bar):
                current_epoch_step = batch_idx
                batch_indices = batch_indices.tolist()
                batch = {
                    'question': [train_dataset.questions[j] for j in batch_indices],
                    'context': [train_dataset.contexts[j] for j in batch_indices]
                }
                
                # Check if it's time to mine hard negatives
                if current_epoch_step % mine_every_n_steps == 0 and current_epoch_step != 0:
                    print(f"\nMining hard negatives at step {current_epoch_step}...")
                    hard_negatives = self.mine_hard_negatives(train_dataset.questions, train_dataset.contexts)
                    self.model.train()  # mining switches the model to eval mode; restore dropout
                
                # Unfreeze on schedule, but never re-freeze layers in later epochs
                target_layers = self.unfreeze_schedule.get(current_epoch_step)
                if target_layers is not None and target_layers > self.unfreeze_layers:
                    print(f"\nUnfreezing layers at step {current_epoch_step}...")
                    self.unfreeze_layers = target_layers
                    self._freeze_layers()
                    # New optimizer for the new trainable params; continue decaying from the current LR
                    current_lr = scheduler.get_last_lr()[0]
                    optimizer = get_optimizer(current_lr)
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                        optimizer, T_max=max(1, total_steps - global_step)
                    )
                
                optimizer.zero_grad()
                
                # Prepare batch data
                processed_batch = self.prepare_batch(batch)
                
                # Get embeddings from separate encoders
                q_embeddings = self.model.mean_pooling(
                    self.model.question_encoder(**processed_batch['question_data']),
                    processed_batch['question_data']['attention_mask']
                )
                d_embeddings = self.model.mean_pooling(
                    self.model.document_encoder(**processed_batch['context_data']),
                    processed_batch['context_data']['attention_mask']
                )
                
                # Encode the hard negatives of exactly the questions in this batch
                hard_negative_embeddings = None
                if hard_negatives is not None:
                    hard_negative_batch = self.model.tokenizer(
                        [hard_negatives[j] for j in batch_indices],
                        padding=True,
                        truncation=True,
                        max_length=self.config.max_length,
                        return_tensors='pt'
                    ).to(self.device)
                    
                    hard_negative_embeddings = self.model.mean_pooling(
                        self.model.document_encoder(**hard_negative_batch),
                        hard_negative_batch['attention_mask']
                    )
                    hard_negative_embeddings = nn.functional.normalize(hard_negative_embeddings, p=2, dim=1)
                
                # Normalize embeddings
                q_embeddings = nn.functional.normalize(q_embeddings, p=2, dim=1)
                d_embeddings = nn.functional.normalize(d_embeddings, p=2, dim=1)
                
                # Compute loss with hard negatives
                loss = self.compute_loss(q_embeddings, d_embeddings, hard_negative_embeddings)
                
                # Backward pass, optimizer step and learning-rate schedule step
                loss.backward()
                optimizer.step()
                scheduler.step()
                global_step += 1
                
                total_loss += loss.item()
                progress_bar.set_postfix({'loss': total_loss / (batch_idx + 1)})
            
            avg_loss = total_loss / len(train_dataloader)
            print(f'Epoch {epoch + 1}/{self.config.num_epochs}, Average Loss: {avg_loss:.4f}')
            
            # Validation step
            if val_dataset is not None:
                metrics = self.evaluate(val_dataset)
                print("Validation metrics:")
                for metric_name, value in metrics.items():
                    print(f"{metric_name}: {value:.4f}")
                
                # Save best model
                if metrics['mrr@10'] > self.best_mrr:
                    self.best_mrr = metrics['mrr@10']
                    self.save_model('best_model.pt')
    
    def save_model(self, path: str):
        # Save both encoders and config
        torch.save({
            'question_encoder': self.model.question_encoder.state_dict(),
            'document_encoder': self.model.document_encoder.state_dict(),
            'config': self.config
        }, path)
    
    def load_model(self, path: str):
        # weights_only=False is needed because the checkpoint also pickles the BiEncoderConfig object
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        question_state_dict = checkpoint['question_encoder']
        document_state_dict = checkpoint['document_encoder']
        
        def strip_prefix(state_dict):
            # Drop the leading 'module.' that nn.DataParallel adds to parameter names
            return {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in state_dict.items()}
        
        def add_prefix(state_dict):
            # Add the leading 'module.' expected by an nn.DataParallel-wrapped encoder
            return {(k if k.startswith('module.') else 'module.' + k): v for k, v in state_dict.items()}
        
        # Match the 'module.' prefix to whether the encoders are wrapped in DataParallel
        if isinstance(self.model.question_encoder, nn.DataParallel):
            question_state_dict = add_prefix(question_state_dict)
            document_state_dict = add_prefix(document_state_dict)
        else:
            question_state_dict = strip_prefix(question_state_dict)
            document_state_dict = strip_prefix(document_state_dict)
        
        # Load the state dictionaries
        self.model.question_encoder.load_state_dict(question_state_dict)
        self.model.document_encoder.load_state_dict(document_state_dict)

    def mine_hard_negatives(self, questions: List[str], documents: List[str], 
                       batch_size: int = 256, k: int = 2) -> List[str]:
        """Mine exactly one hard negative per question, aligned with `questions`.

        documents[i] is treated as the positive of questions[i]; the hard negative is the
        highest-ranked document among the top-(k + 1) whose text differs from that positive.
        """
        import random  # fallback when every candidate duplicates the positive
        self.model.eval()
        device = self.device
    
        # Small chunk sizes to avoid OOM
        chunk_size = 2000
        similarity_chunk_size = 200
    
        with torch.no_grad():
            # Encode questions in chunks, keeping the embeddings on the CPU
            all_q_embeddings = []
            for i in tqdm(range(0, len(questions), chunk_size), desc="Encoding questions"):
                q_chunk = questions[i:i + chunk_size]
                q_emb = self.model.encode_question(q_chunk, batch_size)
                all_q_embeddings.append(q_emb.cpu())
                torch.cuda.empty_cache()
            q_embeddings = torch.cat(all_q_embeddings, dim=0)
        
            # Encode documents in chunks, keeping the embeddings on the CPU
            all_d_embeddings = []
            for i in tqdm(range(0, len(documents), chunk_size), desc="Encoding documents"):
                d_chunk = documents[i:i + chunk_size]
                d_emb = self.model.encode_document(d_chunk, batch_size, disable_progress_bar=True)
                all_d_embeddings.append(d_emb.cpu())
                torch.cuda.empty_cache()
            d_embeddings = torch.cat(all_d_embeddings, dim=0)
            
            # Normalize embeddings (on CPU to save GPU memory)
            q_embeddings = nn.functional.normalize(q_embeddings, p=2, dim=1)
            d_embeddings = nn.functional.normalize(d_embeddings, p=2, dim=1)
            
            # Move the document matrix to the device once, then score questions in chunks
            d_on_device = d_embeddings.to(device)
            num_candidates = min(k + 1, len(documents))
            hard_negatives = []
            
            for i in tqdm(range(0, len(q_embeddings), similarity_chunk_size), desc="Mining negatives"):
                q_chunk = q_embeddings[i:i + similarity_chunk_size].to(device)
                similarity = torch.matmul(q_chunk, d_on_device.t())
                top = similarity.topk(num_candidates, dim=1).indices.cpu().tolist()
                
                for offset, candidates in enumerate(top):
                    positive = documents[i + offset]
                    # Contexts can repeat across questions, so compare text, not just indices
                    negatives = [idx for idx in candidates if documents[idx] != positive]
                    if negatives:
                        hard_negatives.append(documents[negatives[0]])
                    else:
                        hard_negatives.append(documents[random.randrange(len(documents))])
                
                # Free memory
                del similarity, q_chunk
                torch.cuda.empty_cache()
            
            del d_on_device
        
        return hard_negatives
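`compute_loss` is an in-batch-negatives contrastive loss in the InfoNCE style. Row i of the query-document similarity matrix should peak at column i, and mined hard negatives are appended as extra columns. The NumPy sketch below performs the same computation, which makes it easy to check the math outside the training loop. `info_nce_loss` is a hypothetical name, and the inputs are assumed to be L2-normalized.

```python
import numpy as np


def info_nce_loss(q, d, hard=None, temperature=0.05):
    """Mean cross-entropy where query i's positive is document i (the diagonal)."""
    logits = q @ d.T
    if hard is not None:
        # Hard negatives add columns that every query must score below its positive
        logits = np.concatenate([logits, q @ hard.T], axis=1)
    logits = logits / temperature

    # Numerically stable log-softmax over each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    idx = np.arange(len(q))
    return float(-log_probs[idx, idx].mean())


eye = np.eye(2)
print(info_nce_loss(eye, eye, temperature=1.0))
print(info_nce_loss(eye, eye, hard=np.array([[0.6, 0.8]]), temperature=1.0))
```

A lower temperature sharpens the softmax. Well-separated positives then drive the loss toward 0, while near-duplicate negatives are penalized more heavily.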

In [None]:
# import pandas as pd

# original_df = pd.read_csv("../input/soict-dataset-2024-segmented/train_segmented.csv")
# original_df.head(), original_df.shape

In [None]:
# from sklearn.model_selection import train_test_split

# train_set, test_set = train_test_split(original_df, test_size=0.1, random_state=42)
# train_set.shape, test_set.shape

In [None]:
# # Create datasets
# tokenizer = AutoTokenizer.from_pretrained("bkai-foundation-models/vietnamese-bi-encoder")

# train_dataset = LegalDataset(
#         questions=train_set['question'].tolist(),
#         contexts=train_set['context'].tolist(),
#         tokenizer=tokenizer,
#         max_length=432
# )
    
# val_dataset = LegalDataset(
#         questions=test_set['question'].tolist(),
#         contexts=test_set['context'].tolist(),
#         tokenizer=tokenizer,
#         max_length=432
# )

In [None]:
# #Initialize config
# config = BiEncoderConfig(
#     max_length=256,
#     batch_size=16,
#     learning_rate=1e-5,
#     num_epochs=2,
#     temperature=0.05
# )

In [None]:
# Initialize trainer
# trainer = BiEncoderTrainer(config)
    
# # Train model with validation
# trainer.train(train_dataset, val_dataset)

In [None]:
# # Initialize trainer
#base_model = BiEncoderTrainer(config)
#base_model.load_model('best_model.pt')
#base_model.evaluate(val_dataset)

In [None]:
# # Initialize model path
# model_path = "../input/bi-encoder-xlm-roberta/pytorch/default/1/best_model.pt"


# trained_model = BiEncoderTrainer(config)
# trained_model.load_model(model_path)
# trained_model.evaluate(val_dataset)
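`evaluate` reports recall@k and MRR@k, assuming that query i's only relevant document is document i. The small NumPy sketch below computes both metrics so results can be checked by hand. `recall_mrr_at_k` is a hypothetical helper. It averages the reciprocal rank over every query and counts a query with no hit in the top-k as 0, which is the standard MRR@k definition.

```python
import numpy as np


def recall_mrr_at_k(similarity: np.ndarray, k: int):
    """Recall@k and MRR@k when query i's only relevant document is document i."""
    k = min(k, similarity.shape[1])
    top = np.argsort(-similarity, axis=1)[:, :k]
    hits = top == np.arange(len(similarity))[:, None]
    hit_any = hits.any(axis=1)

    recall = hit_any.mean()
    # Reciprocal rank is 1/(position + 1) for a hit and 0 for a miss
    ranks = hits.argmax(axis=1) + 1
    rr = np.where(hit_any, 1.0 / ranks, 0.0)
    return float(recall), float(rr.mean())


sim = np.array([[0.9, 0.1, 0.0],
                [0.8, 0.5, 0.1],
                [0.7, 0.6, 0.2]])
print(recall_mrr_at_k(sim, 2))
```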

## STAGE 2 - Cross-encoder (Reranking Stage):

In [None]:
from sentence_transformers import CrossEncoder
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sentence_transformers.readers import InputExample
from sentence_transformers.cross_encoder.evaluation import CEBinaryAccuracyEvaluator
from sentence_transformers import LoggingHandler
import logging
import pandas as pd


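class LoggingCallback:
    """Called by CrossEncoder.fit after each evaluation with (score, epoch, steps)."""
    def __init__(self):
        self.num_evaluations = 0

    def __call__(self, score: float, epoch: int, steps: int):
        # `score` is the evaluator's metric (accuracy for CEBinaryAccuracyEvaluator), not the loss
        self.num_evaluations += 1
        print(f'Evaluation {self.num_evaluations}, Epoch: {epoch}, Steps: {steps}, Score: {score:.4f}')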


class CrossEncoderWrapper:
    def __init__(
        self,
        model_name: str = "bkai-foundation-models/vietnamese-bi-encoder",
        max_length: int = 256,
        batch_size: int = 16,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length
        
        # Explicitly set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Initialize model with explicit device
        self.model = CrossEncoder(
            model_name,
            max_length=max_length,
            device=self.device
        )

    def train(self, train_dataset: Dataset, val_dataset: Dataset = None,
              num_epochs: int = 1, warmup_ratio: float = 0.02):
        """Train the cross-encoder model"""

        # Set up logging
        logging.basicConfig(format='%(asctime)s - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S',
                            level=logging.INFO,
                            handlers=[LoggingHandler()])

        # Prepare training examples
        train_samples = [
            InputExample(texts=[sample['question'], sample['document']],
                         label=sample['label'])
            for sample in train_dataset
        ]

        # Create data loader
        train_dataloader = DataLoader(
            train_samples,
            shuffle=True,
            batch_size=self.batch_size
        )

        # Create evaluator for validation
        if val_dataset is not None:
            # Prepare validation data in the required format
            val_samples = [
                InputExample(texts=[sample['question'], sample['document']],
                             label=sample['label'])
                for sample in val_dataset
            ]

            evaluator = CEBinaryAccuracyEvaluator.from_input_examples(
                val_samples, name="evaluation")
        else:
            evaluator = None
        # Create callback
        callback = LoggingCallback()

        # Train the model
        self.model.fit(
            train_dataloader=train_dataloader,
            evaluator=evaluator,
            epochs=num_epochs,
            evaluation_steps=2000,  # Evaluate every 2000 steps
            # Warmup is counted in optimizer steps (batches), not in samples
            warmup_steps=int(len(train_dataloader) * num_epochs * warmup_ratio),
            show_progress_bar=True,
            output_path='checkpoints',  # Save checkpoints
            # save_best_model=True,
            callback=callback  # Add the callback
        )

    def evaluate(self, dataset: Dataset) -> Dict[str, float]:
        """Evaluate the model on a dataset"""
        pairs = [
            [sample['question'], sample['document']]
            for sample in dataset
        ]
        labels = [sample['label'] for sample in dataset]

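        # Get predictions; for a single-label CrossEncoder, predict() applies a sigmoid,
        # so scores lie in [0, 1] and 0.5 is the natural decision threshold
        scores = self.model.predict(pairs)
        predictions = (scores > 0.5).astype(int)

        # Calculate metrics (zero_division=0 avoids warnings when a ratio is undefined,
        # e.g. when no pair is predicted as relevant)
        metrics = {
            'accuracy': accuracy_score(labels, predictions),
            'precision': precision_score(labels, predictions, zero_division=0),
            'recall': recall_score(labels, predictions, zero_division=0),
            'f1': f1_score(labels, predictions, zero_division=0),
            'auc_roc': roc_auc_score(labels, scores)
        }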

        return metrics

    def predict(self, questions: List[str], documents: List[str], show_progress_bar=True) -> np.ndarray:
        """Get relevance scores for question-document pairs"""
        try:
            # Ensure model is in eval mode
            self.model.model.eval()
            
            # Create pairs
            pairs = [[q, d] for q, d in zip(questions, documents)]

            # Score the pairs in mini-batches of self.batch_size
            return self.model.predict(
                pairs,
                batch_size=self.batch_size,
                show_progress_bar=show_progress_bar
            )
            
        except Exception as e:
            print(f"Error during prediction: {str(e)}")
            raise

    def save_model(self, path: str):
        self.model.save(path)

    def load_model(self, path: str):
        self.model = CrossEncoder(
            path,
            max_length=self.max_length,
            device=self.device
        )


class CrossEncoderDataset(Dataset):
    def __init__(self, questions: List[str], documents: List[str], labels: List[int]):
        self.questions = questions
        self.documents = documents
        self.labels = labels
        assert len(questions) == len(documents) == len(
            labels), "All inputs must have the same length"

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return {
            'question': self.questions[idx],
            'document': self.documents[idx],
            'label': self.labels[idx]
        }
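In the sentence-transformers 2.x `CrossEncoder.fit` API used by `CrossEncoderWrapper.train`, `warmup_steps` is a number of optimizer steps (one per batch), so a warmup ratio is normally applied to the total number of batches across all epochs rather than to the number of samples. A minimal sketch of that arithmetic, assuming the split used below (230,000 training pairs, batch size 16, one epoch); `warmup_steps_for` is a hypothetical helper, not part of the library:

```python
import math


def warmup_steps_for(num_samples: int, batch_size: int, num_epochs: int,
                     warmup_ratio: float) -> int:
    """Hypothetical helper: warmup length as a fraction of the total optimizer steps."""
    # One optimizer step per batch; with DataLoader's default drop_last=False
    # the last partial batch still counts as a step
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    return int(total_steps * warmup_ratio)


# 230,000 pairs / 16 per batch = 14,375 steps; 2% of that is 287 warmup steps
print(warmup_steps_for(230_000, 16, 1, 0.02))  # 287
# Applying the same ratio to the sample count gives 4,600 steps instead,
# which is about a third of the whole epoch
print(int(230_000 * 0.02))  # 4600
```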


In [None]:
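# Load the saved cross-encoder dataset (question-document pairs mined with the bi-encoder)
df = pd.read_csv("../input/cross-encoder-dataset-segmented/train_data.csv")
# Drop the saved index column before deduplicating; otherwise every row is unique
df = df.drop(columns=['cid', 'Unnamed: 0'])
df = df.drop_duplicates()
df = df.sample(frac=1, random_state=42)  # shuffle reproducibly

df.head(), df.shape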

In [None]:
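# Use most of the pairs for training; 10,000 validation pairs are enough for monitoring.
# Rows are split at random, so pairs from the same question can appear in more than one split.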
train_data = df[:230000]
val_data = df[230000:240000]
test_data = df[240000:]
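Because the split above is done row by row, the same question can contribute pairs to train, val and test at once, which can make held-out scores optimistic. A minimal plain-Python sketch of a question-level split, assuming each row is a dict with a `'question'` key (with the notebook's DataFrame, `df.to_dict('records')` gives rows in that shape); `split_by_question` and its fractions are illustrative, not what this notebook used:

```python
import random
from typing import Dict, List, Tuple


def split_by_question(rows: List[Dict], val_frac: float = 0.05, test_frac: float = 0.05,
                      seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Hypothetical helper: every question (and all of its pairs) lands in exactly one split."""
    # Unique questions in first-seen order, then shuffled reproducibly
    questions = list(dict.fromkeys(row['question'] for row in rows))
    random.Random(seed).shuffle(questions)

    n_val = int(len(questions) * val_frac)
    n_test = int(len(questions) * test_frac)
    val_q = set(questions[:n_val])
    test_q = set(questions[n_val:n_val + n_test])

    train, val, test = [], [], []
    for row in rows:
        if row['question'] in val_q:
            val.append(row)
        elif row['question'] in test_q:
            test.append(row)
        else:
            train.append(row)
    return train, val, test


# Toy data: 20 questions, 5 pairs each (1 positive, 4 negatives)
rows = [{'question': f'q{i}', 'document': f'd{i}_{j}', 'label': int(j == 0)}
        for i in range(20) for j in range(5)]
train, val, test = split_by_question(rows, val_frac=0.1, test_frac=0.1)
print(len(train), len(val), len(test))  # 80 10 10
```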

In [None]:
val_data # Show the data sample

In [None]:
# Create train, val and test dataset
train_dataset = CrossEncoderDataset(
        questions=train_data['question'].tolist(),
        documents=train_data['document'].tolist(),
        labels=train_data['label'].tolist(),
    )

val_dataset = CrossEncoderDataset(
        questions=val_data['question'].tolist(),
        documents=val_data['document'].tolist(),
        labels=val_data['label'].tolist(),
    )

test_dataset = CrossEncoderDataset(
        questions=test_data['question'].tolist(),
        documents=test_data['document'].tolist(),
        labels=test_data['label'].tolist(),
    )



In [None]:
# Let's see what a data sample looks like
val_dataset[4]

In [None]:
# Initialize the cross-encoder from the base model (a JavaScript error may be printed here; it can be ignored)
model = CrossEncoderWrapper(
    model_name="bkai-foundation-models/vietnamese-bi-encoder",
    max_length=256,
    batch_size=16
)

In [None]:
# Take a small slice of the val dataset to see how an untrained
# cross-encoder performs
val_query = [val_dataset[i]['question'] for i in range(100)]
val_documents = [val_dataset[i]['document'] for i in range(100)]
val_labels = [val_dataset[i]['label'] for i in range(100)]
len(val_query), len(val_documents), len(val_labels)

In [None]:
import logging
logging.disable(logging.WARNING)

# Evaluate untrained cross-encoder
scores = model.predict(val_query, val_documents)
scores

In [None]:
# An untrained cross-encoder usually predicts the same label for every pair (all 0 or all 1)
predictions = (scores > 0.5).astype(int)
predictions, np.array(val_labels)

In [None]:
# Evaluate the untrained model. Accuracy is 0.82 here mainly because about 80% of the
# pairs are labeled 0, so check precision, recall, F1 and AUC rather than accuracy alone
metrics = {
    'accuracy': accuracy_score(val_labels, predictions),
    'precision': precision_score(val_labels, predictions, zero_division=0),
    'recall': recall_score(val_labels, predictions, zero_division=0),
    'f1': f1_score(val_labels, predictions, zero_division=0),
    'auc_roc': roc_auc_score(val_labels, scores)
}
metrics
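Accuracy is dominated by the negatives: the dataset pairs each ground-truth document with 3 hard negatives and 1 random negative, so a model that answers 0 for every pair is already about 80% accurate while its precision, recall and F1 are zero. A small stdlib sketch of that majority-class baseline, assuming exactly one positive per five pairs (questions with several ground-truth documents shift the ratio slightly); `binary_metrics` is an illustrative helper:

```python
from typing import Dict, List


def binary_metrics(labels: List[int], preds: List[int]) -> Dict[str, float]:
    """Accuracy/precision/recall/F1 for 0/1 labels; undefined ratios are reported as 0.0."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    correct = sum(1 for y, p in zip(labels, preds) if y == p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {'accuracy': correct / len(labels), 'precision': precision,
            'recall': recall, 'f1': f1}


# 20 questions x (1 positive + 4 negatives), scored by a model that always says "irrelevant"
labels = [1, 0, 0, 0, 0] * 20
all_zero = [0] * len(labels)
print(binary_metrics(labels, all_zero))
# {'accuracy': 0.8, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
```

scikit-learn's `precision_score` behaves the same way in this case: with no positive predictions it returns 0 (and warns unless `zero_division` is set).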

In [None]:
# logging.disable(logging.WARNING)

# # Initialize
# model = CrossEncoderWrapper(
#     model_name="bkai-foundation-models/vietnamese-bi-encoder",
#     max_length=256,
#     batch_size=16
# )

# # Train
# model.train(train_dataset, val_dataset, num_epochs=1)

# # Evaluate
# metrics = model.evaluate(val_dataset)

In [None]:
# Let's load test dataset
test_query = [test_dataset[i]['question'] for i in range(len(test_dataset))]
test_documents = [test_dataset[i]['document'] for i in range(len(test_dataset))]
test_labels = [test_dataset[i]['label'] for i in range(len(test_dataset))]
len(test_query), len(test_documents), len(test_labels)

In [None]:
model = CrossEncoderWrapper(
    model_name="bkai-foundation-models/vietnamese-bi-encoder",
    max_length=256,
    batch_size=16
)

# Replace the base weights with the fine-tuned checkpoint and score the first 2,000 test pairs
model.load_model('/kaggle/input/cross_encoder_new/pytorch/default/1')
scores = model.predict(test_query[:2000], test_documents[:2000])

# Show the first 100 scores
scores[:100]

In [None]:
# Compare the first 20 predicted labels with the true labels
predictions = (scores > 0.5).astype(int)
predictions[0:20], np.array(test_labels)[0:20]

In [None]:
# Metrics of the fine-tuned cross-encoder on the first 2,000 test pairs
metrics = {
    'accuracy': accuracy_score(test_labels[:2000], predictions),
    'precision': precision_score(test_labels[:2000], predictions, zero_division=0),
    'recall': recall_score(test_labels[:2000], predictions, zero_division=0),
    'f1': f1_score(test_labels[:2000], predictions, zero_division=0),
    'auc_roc': roc_auc_score(test_labels[:2000], scores)
}
metrics
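`auc_roc` is the only metric above computed from the raw `scores` instead of the thresholded `predictions`, so it does not depend on the 0.5 cut-off. It equals the probability that a randomly chosen relevant pair is scored higher than a randomly chosen irrelevant one (ties count as half), which is the property that matters for reranking. A quadratic-time stdlib sketch of that definition, fine for a few thousand pairs; `roc_auc_pairwise` is an illustrative name:

```python
from typing import Sequence


def roc_auc_pairwise(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as P(score of a positive > score of a negative), ties counted as 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 0, 0, 1, 0]
scores = [0.9, 0.4, 0.7, 0.6, 0.1]
# Positives 0.9 and 0.6 vs negatives 0.4, 0.7, 0.1: 5 of 6 pairs are ordered correctly
print(roc_auc_pairwise(labels, scores))  # 0.8333333333333334
```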

### Retrieval Pipeline
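`RetrievalPipeline` chains the two stages: a cheap bi-encoder/FAISS search shortlists `top_k` candidates, and only those are rescored by the expensive cross-encoder, keeping the best `rerank_k`. The control flow can be sketched without any models, assuming two toy scoring functions (`word_overlap` and `overlap_ratio` are placeholders, not the notebook's encoders):

```python
from typing import Callable, Dict, List


def two_stage_retrieve(query: str, corpus: Dict[int, str],
                       first_stage_score: Callable[[str, str], float],
                       rerank_score: Callable[[str, str], float],
                       top_k: int = 50, rerank_k: int = 10) -> List[int]:
    """Shortlist document ids with a cheap scorer, then reorder the shortlist with an expensive one."""
    # Stage 1: score every document cheaply and keep the top_k candidates
    shortlist = sorted(corpus, key=lambda cid: first_stage_score(query, corpus[cid]),
                       reverse=True)[:top_k]
    # Stage 2: rescore only the shortlist and keep the best rerank_k
    reranked = sorted(shortlist, key=lambda cid: rerank_score(query, corpus[cid]),
                      reverse=True)
    return reranked[:rerank_k]


def word_overlap(q: str, d: str) -> float:
    """Toy stage-1 scorer: number of shared lowercase words."""
    return len(set(q.lower().split()) & set(d.lower().split()))


def overlap_ratio(q: str, d: str) -> float:
    """Toy stage-2 scorer: shared words divided by document length."""
    return word_overlap(q, d) / len(d.split())


corpus = {
    1: "room rent price list for every district in the city",
    2: "room rent",
    3: "income tax",
}
print(two_stage_retrieve("rent room price", corpus, word_overlap, overlap_ratio,
                         top_k=2, rerank_k=2))  # [2, 1]
```

Stage 2 promotes the short, precise document that stage 1 ranked second. A document that stage 1 leaves outside `top_k` can never be recovered by reranking, so the bi-encoder's recall bounds the recall of the whole pipeline.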

In [None]:
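from typing import Dict, List, Union

import faiss
import pandas as pd
import torch
from tqdm import tqdm

# BiEncoder, BiEncoderConfig and CrossEncoderWrapper are defined in earlier cells


class RetrievalPipeline: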
    def __init__(self, 
                 bi_encoder_path: str,
                 cross_encoder_path: str,
                 faiss_index_path: str,
                 #embeddings_path: str,
                 corpus_df_path: str,
                 top_k: int = 50,
                 rerank_k: int = 10):
        """
        Initialize retrieval pipeline
        Args:
            bi_encoder_path: Path to bi encoder checkpoint
            cross_encoder_path: Path to cross encoder checkpoint
            faiss_index_path: Path to saved FAISS index
            corpus_df_path: Path to corpus CSV file
            top_k: Number of candidates to retrieve from FAISS
            rerank_k: Number of final results after reranking
        """
        # Build the models (CrossEncoderWrapper starts from the base checkpoint)
        self.bi_encoder = BiEncoder(BiEncoderConfig())
        self.cross_encoder = CrossEncoderWrapper()

        # Load the fine-tuned weights
        self.load_models(bi_encoder_path, cross_encoder_path)
        
        # Load FAISS index and embeddings
        self.index = faiss.read_index(faiss_index_path)
        #self.document_embeddings = np.load(embeddings_path)
        
        # Load corpus for retrieving text
        self.corpus_df = pd.read_csv(corpus_df_path)
        
        # Configuration
        self.top_k = top_k
        self.rerank_k = rerank_k
        
        # Move models to GPU if available
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bi_encoder.to(self.device)
        #self.cross_encoder.model.to(torch.device("cuda"))

    def load_models(self, bi_encoder_path: str, cross_encoder_path: str):
        # map_location='cpu' lets a GPU-trained checkpoint load on any machine;
        # the bi-encoder is moved to self.device afterwards in __init__
        bi_encoder_checkpoint = torch.load(bi_encoder_path, map_location='cpu')
        question_state_dict = bi_encoder_checkpoint['question_encoder']
        document_state_dict = bi_encoder_checkpoint['document_encoder']

        # Strip the 'module.' prefix added when the encoders were wrapped in (Distributed)DataParallel
        question_state_dict = {k.replace('module.', ''): v for k, v in question_state_dict.items()}
        document_state_dict = {k.replace('module.', ''): v for k, v in document_state_dict.items()}

        self.bi_encoder.document_encoder.load_state_dict(document_state_dict)
        self.bi_encoder.question_encoder.load_state_dict(question_state_dict)
        #self.bi_encoder.config = bi_encoder_checkpoint['config']

        self.cross_encoder.load_model(cross_encoder_path)


    def retrieve(self, query: str, rerank: bool = True) -> List[Dict[str, Union[str, float, int]]]:
        """
        Retrieve and rerank documents for a query
        Args:
            query: Question string
        Returns:
            List of dicts containing retrieved documents with scores and metadata
        """
        # Clear cache
        torch.cuda.empty_cache()

        # Stage 1: Bi-encoder retrieval
        query_embedding = self.bi_encoder.encode_question([query])
        
        query_embedding = query_embedding.cpu().numpy()
        faiss.normalize_L2(query_embedding)
        
        # Search FAISS index
        scores, doc_indices = self.index.search(query_embedding, self.top_k)
        
        # Get candidate documents
        candidates = []
        for score, doc_idx in zip(scores[0], doc_indices[0]):
            candidates.append({
                'text': self.corpus_df.iloc[doc_idx]['text'],
                'cid': self.corpus_df.iloc[doc_idx]['cid'],
                'bi_encoder_score': float(score)
            })
        if rerank:
            # Stage 2: Cross-encoder reranking
            candidate_texts = [c['text'] for c in candidates]
            #self.bi_encoder.to(torch.device("cuda"))
            batch_size = 16  # Adjust this based on your GPU memory capacity
            cross_encoder_scores = []

            for i in range(0, len(candidate_texts), batch_size):
                batch_texts = candidate_texts[i:i + batch_size]
                batch_scores = self.cross_encoder.predict(
                        [query] * len(batch_texts),
                        batch_texts,
                        show_progress_bar=False
                    )
                cross_encoder_scores.extend(batch_scores)
            #self.cross_encoder.to(torch.device("cpu"))
            #Add cross-encoder scores
            for idx, score in enumerate(cross_encoder_scores):
                candidates[idx]['cross_encoder_score'] = float(score)
                
            candidates.sort(key=lambda x: x['cross_encoder_score'], reverse=True)
            results = candidates[:self.rerank_k]
        
            return results

        else:
            candidates.sort(key=lambda x: x['bi_encoder_score'], reverse=True)
            results = candidates[:self.rerank_k]

            return results

    def batch_retrieve(self, 
                      queries: List[str],
                      batch_size: int = 32) -> List[List[Dict[str, Union[str, float, int]]]]:
        """
        Batch retrieval for multiple queries
        """
        all_results = []
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            batch_results = [self.retrieve(q) for q in batch]
            all_results.extend(batch_results)
        return all_results
    
    
def evaluate_retrieval(pipeline, test_df, top_k=10, re_ranking=True):
    """
    Evaluate the retrieval pipeline on a test set using MAP, MRR, NDCG, and Recall@k.
    
    Args:
        pipeline: The retrieval pipeline instance.
        test_df: DataFrame containing 'question' and 'cid' columns.
        top_k: Number of top results to consider for evaluation.
        
    Returns:
        A dictionary with MAP, MRR, NDCG, and Recall@k scores.
    """
    def parse_cid_string(cid_str):
        """Parse a space-separated string of CIDs into a set of integers."""
        return set(map(int, cid_str.strip('[]').split()))

    true_cids = test_df['cid'].apply(parse_cid_string)
    questions = test_df['question'].tolist()
    
    average_precisions = []
    reciprocal_ranks = []
    ndcg_scores = []
    recall_at_k = []
    
    for question, true_cid_set in tqdm(zip(questions, true_cids), total=len(questions), desc="Evaluating"):
        # Retrieve documents
        results = pipeline.retrieve(question, re_ranking)
        
        # Get retrieved cids
        retrieved_cids = [result['cid'] for result in results[:top_k]]
        
        # Calculate Average Precision
        num_relevant = 0
        precision_sum = 0.0
        for i, cid in enumerate(retrieved_cids):
            if cid in true_cid_set:
                num_relevant += 1
                precision_sum += num_relevant / (i + 1)
        average_precision = precision_sum / len(true_cid_set) if true_cid_set else 0
        average_precisions.append(average_precision)
        
        # Calculate Reciprocal Rank
        reciprocal_rank = 0.0
        for i, cid in enumerate(retrieved_cids):
            if cid in true_cid_set:
                reciprocal_rank = 1.0 / (i + 1)
                break
        reciprocal_ranks.append(reciprocal_rank)
        
        # Calculate NDCG (binary relevance, standard log2 rank discount)
        dcg = 0.0
        idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(true_cid_set), top_k)))
        for i, cid in enumerate(retrieved_cids):
            if cid in true_cid_set:
                dcg += 1.0 / np.log2(i + 2)
        ndcg = dcg / idcg if idcg > 0 else 0
        ndcg_scores.append(ndcg)
        
        # Calculate Recall@k
        relevant_retrieved = len(set(retrieved_cids) & true_cid_set)
        recall = relevant_retrieved / len(true_cid_set) if true_cid_set else 0
        recall_at_k.append(recall)
    
    # Calculate average metrics
    map_score = sum(average_precisions) / len(average_precisions)
    mrr_score = sum(reciprocal_ranks) / len(reciprocal_ranks)
    avg_ndcg = sum(ndcg_scores) / len(ndcg_scores)
    avg_recall_at_k = sum(recall_at_k) / len(recall_at_k)
    
    return {
        'MAP': map_score,
        'MRR': mrr_score,
        'NDCG': avg_ndcg,
        'Recall@k': avg_recall_at_k
    }
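A worked example of the four metrics for a single query helps when checking the numbers above. This is a plain-Python sketch assuming binary relevance and a non-empty set of relevant ids, using the standard NDCG discount 1 / log2(rank + 1); `single_query_metrics` is an illustrative helper:

```python
import math
from typing import Dict, List, Set


def single_query_metrics(retrieved: List[int], relevant: Set[int], k: int) -> Dict[str, float]:
    """AP, reciprocal rank, NDCG and recall over the top-k retrieved ids (relevant must be non-empty)."""
    top = retrieved[:k]
    hits = [1 if cid in relevant else 0 for cid in top]

    # Average precision: precision at each hit, divided by the number of relevant documents
    running, precision_sum = 0, 0.0
    for rank, hit in enumerate(hits, start=1):
        if hit:
            running += 1
            precision_sum += running / rank
    ap = precision_sum / len(relevant)

    # Reciprocal rank of the first hit (0 if nothing relevant was retrieved)
    rr = next((1.0 / rank for rank, hit in enumerate(hits, start=1) if hit), 0.0)

    # NDCG with the log2 discount; the ideal ranking puts all relevant ids first
    dcg = sum(hit / math.log2(rank + 1) for rank, hit in enumerate(hits, start=1))
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, min(len(relevant), k) + 1))
    ndcg = dcg / idcg

    recall = sum(hits) / len(relevant)
    return {'AP': ap, 'RR': rr, 'NDCG': ndcg, 'Recall': recall}


# Two relevant documents (7 and 9); the system returns 7 at rank 2 and 9 at rank 4
m = single_query_metrics([3, 7, 5, 9, 1], relevant={7, 9}, k=5)
print({name: round(value, 4) for name, value in m.items()})
# {'AP': 0.5, 'RR': 0.5, 'NDCG': 0.6509, 'Recall': 1.0}
```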

In [None]:
bi_encoder_path = "/kaggle/input/bi_encoder/pytorch/default/1/bi_encoder.pt"
cross_encoder_path = "/kaggle/input/cross_encoder_new/pytorch/default/1"
faiss_index_path = "/kaggle/input/bi_encoder_embedding/pytorch/default/1/document_index.faiss"
corpus_df_path = "/kaggle/input/soict-dataset-2024-segmented/corpus_segmented.csv"

# Initialize pipeline
pipeline = RetrievalPipeline(
    bi_encoder_path=bi_encoder_path,
    cross_encoder_path=cross_encoder_path,
    faiss_index_path=faiss_index_path,
    #embeddings_path="path/to/embeddings.npy",
    corpus_df_path=corpus_df_path,
    #rerank_k=50
)

In [None]:
from underthesea import word_tokenize

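def segment(text: str) -> str:
    """Word-segment Vietnamese text so queries match the segmented corpus.

    With format="text", underthesea joins the syllables of multi-syllable words with underscores.
    """
    return word_tokenize(text, format="text")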

In [None]:
# Single query
import logging

logging.disable(logging.WARNING)

query = "Cơ sở cho thuê nhà trọ có phải đăng ký kinh doanh không?"
segment_query = segment(query)

results = pipeline.retrieve(segment_query)

In [None]:
# These are the retrieved documents; after some post-processing,
# they are what we feed into the LLM
results

In [None]:
# Let's evaluate the whole pipeline
from sklearn.model_selection import train_test_split

train_df = pd.read_csv("../input/soict-dataset-2024-segmented/train_segmented.csv")
train_data, val_data = train_test_split(train_df, test_size=0.1, random_state=42)

In [None]:
val_data

In [None]:
%%time
# Evaluate the pipeline without reranking; this takes about 8 minutes to run
metrics = evaluate_retrieval(pipeline, val_data, re_ranking=False)
print(metrics)

In [None]:
%%time

# Compare the metrics on the first 1,000 train and val questions (without reranking)
metrics_on_train_data = evaluate_retrieval(pipeline, train_data[:1000], re_ranking=False, top_k=50)
print("metrics on train data:", metrics_on_train_data)
metrics_on_val_data = evaluate_retrieval(pipeline, val_data[:1000], re_ranking=False, top_k=50)
print("metrics on val data:", metrics_on_val_data)

In [None]:
%%time

# Evaluate on val data with cross-encoder reranking. This takes about 26 minutes,
# which is why the cross-encoder only reranks the bi-encoder's top candidates
# instead of scoring the whole corpus directly.
# Reranking improves every reported metric (without -> with reranking):
# MAP: 0.470 -> 0.559
# MRR: 0.489 -> 0.576
# Recall@k: 0.722 -> 0.771
metrics = evaluate_retrieval(pipeline, val_data[:1000], top_k=50, re_ranking=True)
print(metrics)

## Helper functions

### Create embeddings for the corpus

In [None]:
# import pandas as pd
# import torch
# import faiss
# import numpy as np
# from tqdm import tqdm
# from typing import Tuple

# def build_faiss_index(
#     corpus_path: str,
#     model: BiEncoder,
#     batch_size: int = 32,
#     index_path: str = "document_index.faiss",
#     embeddings_path: str = "document_embeddings.npy",
#     nlist: int = 1000,  # Number of clusters/cells
#     nprobe: int = 100,  # Number of cells to visit during search
# ) -> Tuple[faiss.Index, np.ndarray]:
    # """
    # Build optimized FAISS index from corpus documents
    
    # Args:
    #     corpus_path: Path to corpus CSV file
    #     model: Trained BiEncoder model
    #     batch_size: Batch size for encoding
    #     index_path: Where to save the FAISS index
    #     embeddings_path: Where to save the normalized document embeddings (.npy)
    #     nlist: Number of Voronoi cells (clusters)
    #     nprobe: Number of nearest cells to search
    # """
    # print("Reading corpus...")
    # df = pd.read_csv(corpus_path)
    # documents = df['text'].tolist()
    
    # # Prepare model
    # model.eval()
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model.to(device)

    # # Initialize lists to store batched embeddings
    # all_embeddings = []
    
    # # Process in batches to handle memory efficiently
    # print("Encoding documents...")
    
    # progress_bar = tqdm(total=(len(documents) + batch_size - 1) // batch_size,  # number of batches
    #                         desc="Encoding documents",
    #                         ncols=80,
    #                         position=0,  # Force position to 0
    #                         leave=True)  # Keep final result visible
    
    # with torch.no_grad():
    #     for i in range(0, len(documents), batch_size):
    #         batch = documents[i:i + batch_size]
    #         embeddings = model.encode_document(batch, disable_progress_bar= True )
    #         all_embeddings.append(embeddings.cpu())
    #         progress_bar.update(1)
    
    # # Concatenate all embeddings
    # document_embeddings = torch.cat(all_embeddings, dim=0).numpy()
    # dimension = document_embeddings.shape[1]

    # # Normalize embeddings
    # faiss.normalize_L2(document_embeddings)

    # # Create GPU resource
    # res = faiss.StandardGpuResources()
    
    # # Configure index
    # quantizer = faiss.IndexFlatIP(dimension)
    # index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)
    
    # # Transfer to GPU for training and adding vectors
    # gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    
    # # Train index
    # print("Training index...")
    # gpu_index.train(document_embeddings)
    
    # # Add vectors to index
    # print("Adding vectors to index...")
    # gpu_index.add(document_embeddings)
    
    # # Set number of cells to probe during search
    # gpu_index.nprobe = nprobe
    
    # # Transfer back to CPU for saving
    # index = faiss.index_gpu_to_cpu(gpu_index)
    
    # # Save index
    # print("Saving index...")
    # faiss.write_index(index, index_path)
    
    # #  Save Embeddings
    # print("Saving embeddings...")
    # np.save(embeddings_path, document_embeddings)
    
    # print(f"Index built with {index.ntotal} vectors of dimension {dimension}")
    # print(f"Number of clusters: {nlist}, nprobe: {nprobe}")
    # return index, document_embeddings
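`build_faiss_index` L2-normalizes the document embeddings and uses `METRIC_INNER_PRODUCT`, so the scores FAISS returns are cosine similarities. The IVF index only approximates an exhaustive search: it clusters the vectors into `nlist` cells and scans just the `nprobe` cells closest to the query. The exact search it approximates can be sketched in plain Python without FAISS, returning `(scores, ids)` in the same order as `index.search` does for a single query; `normalize` and `exact_search` are illustrative helpers:

```python
import math
from typing import List, Tuple


def normalize(v: List[float]) -> List[float]:
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def exact_search(query: List[float], docs: List[List[float]], k: int) -> Tuple[List[float], List[int]]:
    """Brute-force inner-product search over unit vectors, i.e. top-k cosine similarity."""
    q = normalize(query)
    scored = [(sum(a * b for a, b in zip(q, normalize(d))), i) for i, d in enumerate(docs)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    top = scored[:k]
    return [s for s, _ in top], [i for _, i in top]


docs = [[1.0, 0.0], [3.0, 4.0], [0.0, 2.0]]
scores, ids = exact_search([0.0, 1.0], docs, k=2)
print(ids, [round(s, 2) for s in scores])  # [2, 1] [1.0, 0.8]
```

Raising `nprobe` toward `nlist` makes the IVF results converge to this exhaustive ranking at the cost of scanning more vectors per query.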



In [None]:
# config = BiEncoderConfig(
#     max_length=256,
#     batch_size=16,
#     learning_rate=1e-5,
#     num_epochs=2,
#     temperature=0.05
# )

# trainer = BiEncoderTrainer(config)
# trainer.load_model("/kaggle/input/bi_encoder/pytorch/default/1/bi_encoder.pt")
# model = trainer.get_model()

In [None]:
# index, document_embeddings = build_faiss_index(
#                 corpus_path= "../input/soict-dataset-2024-segmented/corpus_segmented.csv",
#                 model = model
#             )

In [None]:
# def prepare_cross_encoder_data(
#     train_df: pd.DataFrame,
#     bi_encoder_model: BiEncoder,
#     corpus_df: pd.DataFrame,
#     num_hard_negatives: int = 3,
#     num_random_negatives: int = 1,
#     batch_size: int = 32,
#     embeddings_path: str = "document_embeddings.npy"
# ) -> pd.DataFrame:
#     """
#     Prepare training data for cross-encoder with multiple types of negatives using pre-computed document embeddings
    
#     Args:
    #     train_df: Training dataframe with questions and ground truths
    #     bi_encoder_model: Trained bi-encoder model for finding hard negatives
    #     corpus_df: Full corpus dataframe
    #     num_hard_negatives: Number of hard negatives per question
    #     num_random_negatives: Number of random negatives per question
    #     batch_size: Batch size for bi-encoder inference
    #     embeddings_path: Path to saved document embeddings
    # """
    # training_pairs = []
    
    # # Create corpus lookup
    # corpus_lookup = dict(zip(corpus_df['cid'], corpus_df['text']))
    
    
    # def parse_cids(cid_str: str) -> List[int]:
    #     """Parse space-separated CIDs"""
    #     if isinstance(cid_str, str):
    #         # Remove brackets and split by whitespace
    #         return [int(cid) for cid in cid_str.strip('[]').split()]
    #     return cid_str
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # # Load pre-computed document embeddings
    # print("Loading pre-computed document embeddings...")
    # d_embeddings = torch.from_numpy(np.load(embeddings_path)).to(device)

    # # Process questions in batches with accurate progress tracking
    # training_pairs = []
    # total_questions = len(train_df)
    
    # with tqdm(total=total_questions, desc="Processing questions") as pbar:
    #     for start_idx in range(0, len(train_df), batch_size):
    #         batch_df = train_df.iloc[start_idx:start_idx + batch_size]
            
    #         # Get question embeddings for the batch
    #         questions = batch_df['question'].tolist()
    #         q_embeddings = bi_encoder_model.encode_question(questions)
            
    #         # Compute similarities
    #         similarities = torch.matmul(q_embeddings, d_embeddings.t())
            
            # # Process each question in the batch
            # for batch_idx, (_, row) in enumerate(batch_df.iterrows()):
            #     question = row['question']
                
            #     # Parse context and CIDs
            #     correct_docs = row['context']
            #     correct_cids = parse_cids(row['cid'])
                
            #     # Add positive pairs
            #     for doc, cid in zip([correct_docs] * len(correct_cids), correct_cids):
            #         training_pairs.append({
            #             'question': question,
            #             'document': doc,
            #             'label': 1,
            #             'cid': int(cid)
            #         })
                
            #     # Get hard negatives
            #     q_sim = similarities[batch_idx]
            #     _, candidate_indices = q_sim.topk(num_hard_negatives + len(correct_cids))
                
            #     # Filter out correct documents
            #     hard_negative_indices = [
            #         idx.item() for idx in candidate_indices 
            #         if corpus_df.iloc[int(idx.item())]['cid'] not in correct_cids
            #     ][:num_hard_negatives]
                
                # # Add hard negative pairs
                # for neg_idx in hard_negative_indices:
                #     neg_doc = corpus_df.iloc[int(neg_idx)]
                #     training_pairs.append({
                #         'question': question,
                #         'document': neg_doc['text'],
                #         'label': 0,
                #         'cid': neg_doc['cid']
                #     })
                
                # # Add random negatives
                # random_negative_cids = np.random.choice(
                #     [cid for cid in corpus_df['cid'] if cid not in correct_cids],
                #     size=num_random_negatives,
                #     replace=False
                # )
                
    #             for neg_cid in random_negative_cids:
    #                 training_pairs.append({
    #                     'question': question,
    #                     'document': corpus_lookup[neg_cid],
    #                     'label': 0,
    #                     'cid': neg_cid
    #                 })
                
    #             # Update progress bar for each processed question
    #             pbar.update(1)
    # return pd.DataFrame(training_pairs)
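The negative-sampling step of `prepare_cross_encoder_data` can be isolated from the models. A stdlib sketch, assuming a precomputed bi-encoder similarity list for one question and the list of corpus ids in the same order (`mine_negatives` is a hypothetical helper): it ranks documents by similarity, drops the ground-truth ids, keeps `num_hard` hard negatives, then samples random negatives uniformly. One design choice differs from the commented code: the random pool here also excludes the chosen hard negatives, so the two kinds never overlap.

```python
import random
from typing import List, Sequence, Set, Tuple


def mine_negatives(similarities: Sequence[float], corpus_cids: Sequence[int],
                   positive_cids: Set[int], num_hard: int = 3, num_random: int = 1,
                   seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: pick hard negatives by similarity and random negatives uniformly."""
    # Rank corpus positions by bi-encoder similarity, highest first
    order = sorted(range(len(corpus_cids)), key=lambda i: similarities[i], reverse=True)
    # Hard negatives: the most similar documents that are not ground truth
    hard = [corpus_cids[i] for i in order if corpus_cids[i] not in positive_cids][:num_hard]
    # Random negatives: uniform sample from documents that are neither positive nor already hard
    excluded = positive_cids | set(hard)
    pool = [cid for cid in corpus_cids if cid not in excluded]
    rand = random.Random(seed).sample(pool, num_random)
    return hard, rand


cids = [10, 11, 12, 13, 14, 15, 16]
sims = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]
hard, rand = mine_negatives(sims, cids, positive_cids={11}, num_hard=3, num_random=1)
print(hard)  # [10, 12, 13]
```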

In [None]:
# # Load the original dataset to build the cross-encoder training data
# corpus_path = "../input/soict-dataset-2024-segmented/corpus_segmented.csv"
# train_path = "../input/soict-dataset-2024-segmented/train_segmented.csv"
# corpus_df = pd.read_csv(corpus_path)
# train_df = pd.read_csv(train_path)
# corpus_df.shape, train_df.shape

In [None]:
# # Split into train and validation
# val_size = 0.1
# train_data, val_data = train_test_split(train_df, test_size=val_size, random_state=42)
# train_data.shape

In [None]:
# # Prepare training data with negatives
# cross_encoder_data = prepare_cross_encoder_data(
#     train_data[:50000],
#     model,
#     corpus_df,
#     num_hard_negatives=3,
#     num_random_negatives=1,
#     embeddings_path = "/kaggle/input/bi_encoder_embedding/pytorch/default/1/document_embeddings.npy"
# )

In [None]:
#cross_encoder_data.to_csv('train.csv')