In [None]:
# Install Dependencies
# Pinned versions for a reproducible Colab environment.
# numpy is forced back to the 1.x line (1.26.4); the install log below reports a resolver
# conflict with the preinstalled `thinc`, which requires numpy>=2.0.
# NOTE(review): replacing numpy inside a running Colab kernel usually requires a runtime
# restart before the new version is importable — confirm before running the next cells.
# NOTE(review): `%pip` targets the active kernel's environment more reliably than `!pip`.

!pip uninstall -y numpy
!pip install --force-reinstall numpy==1.26.4

!pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 sentence-transformers==2.2.2
!pip install pandas==2.0.0
!pip install transformers==4.41.0 scikit-learn==1.2.0
!pip install huggingface-hub==0.25.2
!pip install nltk==3.8.1 rouge-score==0.1.2 bert-score==0.3.13 -q
!pip install tqdm==4.66.5 -q

Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
  Successfully uninstalled numpy-2.0.2
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m96.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0mSuccessfully installed numpy-1.26.4
Collecting torch==2.2.1
  Downloading torch-2.2.1-cp311-cp311-ma

In [None]:
# Setup and Imports

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BartForConditionalGeneration, BartTokenizer,
    DPRContextEncoder, DPRQuestionEncoder,
    DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
)
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
import os
from google.colab import drive
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import string
import nltk
import json
import pickle

# Fetch the NLTK data packages 'wordnet' and 'punkt' (no-op if already present).
nltk.download('wordnet')
nltk.download('punkt')

# Mount Google Drive (Colab) so the dataset and artifact paths under /content/drive are reachable.
drive.mount('/content/drive')

# Configuration
class Config:
    """Central notebook settings: storage location, compute device, checkpoints and sizes."""

    # --- storage & hardware -------------------------------------------------
    BASE_PATH = "/content/drive/MyDrive/LJMU-Datasets"
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- pretrained checkpoints (passed to `from_pretrained`) ---------------
    BART_MODEL_NAME = "facebook/bart-base"
    DPR_CTX_MODEL_NAME = "facebook/dpr-ctx_encoder-single-nq-base"
    DPR_QUESTION_MODEL_NAME = "facebook/dpr-question_encoder-single-nq-base"

    # --- data loading / training --------------------------------------------
    MAX_LENGTH = 256  # tokenizer max_length used for truncation/padding
    BATCH_SIZE = 8
    NUM_WORKERS = 4
    MAX_EPOCHS = 3

    # --- dataset subset sizes -----------------------------------------------
    SUBSET_SIZE = 500
    HOTPOTQA_MAX_SAMPLES = 1000
    WIKIDATA_SUBSET_SIZE = 30000

# Single shared settings instance used by every later cell.
CONFIG = Config()
print(f"Using device: {CONFIG.DEVICE}")

# Clear GPU memory (asks PyTorch to release its cached GPU allocations).
torch.cuda.empty_cache()

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Mounted at /content/drive
Using device: cuda


In [None]:
# Load Artifacts from Step 2 and Create DataLoaders

import pickle

# Directory containing the artifacts written by Step 2 (pickles, checkpoints, embeddings).
save_path = '/content/drive/MyDrive/bert_retrieval_artifacts_v4'

# Load raw datasets
qa_train_path_v4 = os.path.join(CONFIG.BASE_PATH, "qa_train_v4.csv")
qa_val_path_v4 = os.path.join(CONFIG.BASE_PATH, "qa_val_v4.csv")
triple_train_path_v4 = os.path.join(CONFIG.BASE_PATH, "triple_train_v4.csv")

qa_train_df_v4 = pd.read_csv(qa_train_path_v4)
qa_val_df_v4 = pd.read_csv(qa_val_path_v4)
triple_train_df_v4 = pd.read_csv(triple_train_path_v4)

# Split triple_train_df into train and validation sets (80/20)
# The loaded frame is rebound to its 80% portion; random_state=42 keeps the split reproducible.
triple_train_df_v4, triple_val_df_v4 = train_test_split(triple_train_df_v4, train_size=0.8, random_state=42)
print(f"Triple Train Size (Version 4): {len(triple_train_df_v4)}, Triple Val Size (Version 4): {len(triple_val_df_v4)}")

# Load all_candidates
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with open(os.path.join(save_path, 'all_candidates_v4.pkl'), 'rb') as f:
    all_candidates = pickle.load(f)

# Define RetrievalDataset class
class RetrievalDataset(Dataset):
    """Row-wise dataset producing BART seq2seq tensors and DPR question tensors.

    Each row of ``df`` is read for its ``question``, ``context`` and ``answer`` columns.
    Every text field is tokenized to exactly ``max_length`` tokens (truncate + pad).
    For ``task == "triple"`` with a non-empty ``candidate_objects`` each item also
    carries ``label_idx``: the first index of the answer in ``candidate_objects``, or -1.
    """

    def __init__(self, df: pd.DataFrame, bart_tokenizer: BartTokenizer, dpr_question_tokenizer: DPRQuestionEncoderTokenizer,
                 max_length: int = 256, task: str = "qa", candidate_objects: list = None):
        self.bart_tokenizer = bart_tokenizer
        self.dpr_question_tokenizer = dpr_question_tokenizer
        self.max_length = max_length
        self.task = task
        self.data = df
        self.candidate_objects = candidate_objects
        # answer -> first index, so __getitem__ avoids two O(len(candidates)) list scans per item.
        # Built once from the list passed here; rebuild the dataset if that list is mutated.
        self._candidate_positions = self._build_candidate_positions(candidate_objects)

    @staticmethod
    def _build_candidate_positions(candidate_objects):
        """Map each candidate to its first index (same result as list.index), or None if unavailable."""
        if not candidate_objects:
            return None
        positions = {}
        try:
            for pos, cand in enumerate(candidate_objects):
                positions.setdefault(cand, pos)  # keep the FIRST occurrence, like list.index
        except TypeError:
            # Unhashable candidates: fall back to plain list lookups in _label_index.
            return None
        return positions

    def _label_index(self, answer) -> int:
        """Index of ``answer`` among the candidates, or -1 when absent."""
        if self._candidate_positions is not None:
            return self._candidate_positions.get(answer, -1)
        return self.candidate_objects.index(answer) if answer in self.candidate_objects else -1

    def _tokenize(self, tokenizer, text):
        """Tokenize ``text`` into a batch of one, truncated/padded to ``max_length``."""
        return tokenizer(
            text,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding="max_length"
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized tensors and raw strings for row ``idx``."""
        row = self.data.iloc[idx]
        question = row["question"]
        context = row["context"]
        answer = row["answer"]

        if self.task == "qa":
            bart_input_text = f"question: {question} context: {context}"
        else:
            bart_input_text = f"complete the triple with the exact object: {question} context: {context}"
        bart_inputs = self._tokenize(self.bart_tokenizer, bart_input_text)
        # NOTE(review): label padding positions keep the tokenizer's pad id (not masked, e.g. to -100);
        # if these labels feed a seq2seq loss, padding tokens contribute to it.
        bart_labels = self._tokenize(self.bart_tokenizer, answer)
        dpr_inputs = self._tokenize(self.dpr_question_tokenizer, question)

        item = {
            "task": self.task,
            "bart_input_ids": bart_inputs["input_ids"].squeeze().long(),  # Ensure LongTensor
            "bart_attention_mask": bart_inputs["attention_mask"].squeeze().long(),
            "bart_labels": bart_labels["input_ids"].squeeze().long(),
            "dpr_input_ids": dpr_inputs["input_ids"].squeeze().long(),  # Ensure LongTensor
            "dpr_attention_mask": dpr_inputs["attention_mask"].squeeze().long(),
            "question": question,
            "answer": answer
        }

        if self.task == "triple" and self.candidate_objects:
            item["label_idx"] = self._label_index(answer)

        return item

# Create DataLoaders from raw data
# Tokenizers shared by all RetrievalDatasets (and reused below for the BART QA model).
bart_tokenizer = BartTokenizer.from_pretrained(CONFIG.BART_MODEL_NAME)
dpr_question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(CONFIG.DPR_QUESTION_MODEL_NAME)

qa_train_dataset_v4 = RetrievalDataset(qa_train_df_v4, bart_tokenizer, dpr_question_tokenizer, task="qa", candidate_objects=all_candidates)
qa_val_dataset_v4 = RetrievalDataset(qa_val_df_v4, bart_tokenizer, dpr_question_tokenizer, task="qa", candidate_objects=all_candidates)
triple_train_dataset_v4 = RetrievalDataset(triple_train_df_v4, bart_tokenizer, dpr_question_tokenizer, task="triple", candidate_objects=all_candidates)
triple_val_dataset_v4 = RetrievalDataset(triple_val_df_v4, bart_tokenizer, dpr_question_tokenizer, task="triple", candidate_objects=all_candidates)

qa_train_loader_v4 = DataLoader(qa_train_dataset_v4, batch_size=CONFIG.BATCH_SIZE, shuffle=True, num_workers=CONFIG.NUM_WORKERS)
qa_val_loader_v4 = DataLoader(qa_val_dataset_v4, batch_size=CONFIG.BATCH_SIZE, shuffle=False, num_workers=CONFIG.NUM_WORKERS)
triple_train_loader_v4 = DataLoader(triple_train_dataset_v4, batch_size=CONFIG.BATCH_SIZE, shuffle=True, num_workers=CONFIG.NUM_WORKERS)
triple_val_loader_v4 = DataLoader(triple_val_dataset_v4, batch_size=CONFIG.BATCH_SIZE, shuffle=False, num_workers=CONFIG.NUM_WORKERS)

print(f"Created QA DataLoaders (Version 4): QA Train={len(qa_train_dataset_v4)}, QA Val={len(qa_val_dataset_v4)}")
print(f"Created Triple DataLoaders (Version 4): Triple Train={len(triple_train_dataset_v4)}, Triple Val={len(triple_val_dataset_v4)}")

# Load the fine-tuned BART QA model from Step 2 (bart_tokenizer above is reused).
# map_location lets checkpoints saved on GPU load on whichever device CONFIG selected.
bart_qa_model = BartForConditionalGeneration.from_pretrained(CONFIG.BART_MODEL_NAME).to(CONFIG.DEVICE)
bart_qa_model.load_state_dict(torch.load(os.path.join(save_path, 'bart_qa_v4.pt'), map_location=CONFIG.DEVICE))
bart_qa_model.eval()

# Load DPR models and tokenizers (post-fine-tuning from Step 2)
ctx_encoder_qa = DPRContextEncoder.from_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_ctx_encoder_qa_v4")).to(CONFIG.DEVICE)
question_encoder_qa = DPRQuestionEncoder.from_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_question_encoder_qa_v4")).to(CONFIG.DEVICE)
ctx_encoder_triple = DPRContextEncoder.from_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_ctx_encoder_triple_v4")).to(CONFIG.DEVICE)
question_encoder_triple = DPRQuestionEncoder.from_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_question_encoder_triple_v4")).to(CONFIG.DEVICE)
# NOTE: transformers warns here that this checkpoint declares DPRQuestionEncoderTokenizer
# rather than DPRContextEncoderTokenizer (see the cell output).
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(CONFIG.DPR_CTX_MODEL_NAME)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(CONFIG.DPR_QUESTION_MODEL_NAME)
candidate_embeddings_qa = torch.load(os.path.join(save_path, 'dpr_candidate_embeddings_v4.pt'), map_location=CONFIG.DEVICE).to(CONFIG.DEVICE)
candidate_embeddings_triple = torch.load(os.path.join(save_path, 'dpr_candidate_embeddings_triple_v4.pt'), map_location=CONFIG.DEVICE).to(CONFIG.DEVICE)

# Load sentence transformer
sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')

print("Artifacts loaded from Step 2 and DataLoaders rebuilt for Version 4.")

Triple Train Size (Version 4): 20908, Triple Val Size (Version 4): 5227
Created QA DataLoaders (Version 4): QA Train=294, QA Val=27
Created Triple DataLoaders (Version 4): Triple Train=20908, Triple Val=5227


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


Artifacts loaded from Step 2 and DataLoaders rebuilt for Version 4.


In [None]:
# Define Helper Functions

# Normalize text for evaluation
def normalize_text(text: str) -> str:
    """Lower-case ``text``, strip ASCII punctuation and drop the articles a/an/the.

    Non-string inputs are coerced with ``str()`` first (e.g. ``None`` -> ``"none"``).
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    cleaned = str(text).lower().strip().translate(strip_punct)
    return " ".join(tok for tok in cleaned.split() if tok not in {"a", "an", "the"})

# Compute BLEU score
def compute_bleu(generated: str, reference: str) -> float:
    """Sentence-level BLEU of ``generated`` against the single ``reference``.

    Both strings are tokenized by whitespace; ``sentence_bleu`` is called with its
    default weights and no smoothing.
    """
    hypothesis_tokens = generated.split()
    reference_token_lists = [reference.split()]
    return sentence_bleu(reference_token_lists, hypothesis_tokens)

# Compute ROUGE-L score
def compute_rouge_l(generated: str, reference: str) -> float:
    """ROUGE-L F-measure of ``generated`` scored against ``reference`` (stemming enabled).

    A new ``RougeScorer`` is constructed on every call.
    """
    rouge_l_score = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True).score(reference, generated)['rougeL']
    return rouge_l_score.fmeasure

# Compute BERTScore
def compute_bertscore(generated: str, reference: str) -> float:
    """BERTScore for one generated/reference pair, taken from element 2 of the scorer's result.

    ``bert_score`` is the ``score`` function imported under that alias at the top of the notebook.
    """
    scores = bert_score([generated], [reference], lang="en", verbose=False)
    f1 = scores[2]
    return f1.mean().item()

# Compute MRR for RL reward
def compute_mrr(generated_ranking: torch.Tensor, ref_idx: int) -> float:
    """Reciprocal rank of ``ref_idx`` within ``generated_ranking`` (used for the RL reward).

    Args:
        generated_ranking: candidate indices ordered best-first (assumed 1-D).
        ref_idx: index of the reference candidate.

    Returns:
        ``1 / rank`` with a 1-based rank when ``ref_idx`` occurs in the ranking, otherwise
        ``0.0`` (the standard reciprocal-rank value for a missed item; this also avoids a
        division by zero on an empty ranking).
    """
    positions = (generated_ranking == ref_idx).nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        return 0.0
    # First occurrence, so duplicate entries cannot break .item().
    return 1.0 / (positions[0].item() + 1)

print("Helper functions defined.")

Helper functions defined.


In [None]:
# Define RL Helper Functions

# Normalize text for evaluation
def normalize_text(text: str) -> str:
    """Lower-case ``text``, strip ASCII punctuation and drop the articles a/an/the.

    Non-string inputs are coerced with ``str()`` first (e.g. ``None`` -> ``"none"``).
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    cleaned = str(text).lower().strip().translate(strip_punct)
    return " ".join(tok for tok in cleaned.split() if tok not in {"a", "an", "the"})

# Compute BLEU score
def compute_bleu(generated: str, reference: str) -> float:
    """Sentence-level BLEU of ``generated`` against the single ``reference``.

    Both strings are tokenized by whitespace; ``sentence_bleu`` is called with its
    default weights and no smoothing.
    """
    hypothesis_tokens = generated.split()
    reference_token_lists = [reference.split()]
    return sentence_bleu(reference_token_lists, hypothesis_tokens)

# Compute ROUGE-L score
def compute_rouge_l(generated: str, reference: str) -> float:
    """ROUGE-L F-measure of ``generated`` scored against ``reference`` (stemming enabled).

    A new ``RougeScorer`` is constructed on every call.
    """
    rouge_l_score = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True).score(reference, generated)['rougeL']
    return rouge_l_score.fmeasure

# Compute BERTScore
def compute_bertscore(generated: str, reference: str) -> float:
    """BERTScore for one generated/reference pair, taken from element 2 of the scorer's result.

    The notebook binds ``bert_score`` via ``from bert_score import score as bert_score``, so the
    name is already the scoring function and must be called directly (a function has no
    ``.score`` attribute). ``getattr`` keeps this working if the name is rebound to the module.
    """
    scorer = getattr(bert_score, "score", bert_score)
    return scorer([generated], [reference], lang="en", verbose=False)[2].mean().item()

# Compute MRR for reward
def compute_mrr(generated_ranking: torch.Tensor, ref_idx: int) -> float:
    """Reciprocal rank of ``ref_idx`` within ``generated_ranking`` (used for the RL reward).

    Args:
        generated_ranking: candidate indices ordered best-first (assumed 1-D).
        ref_idx: index of the reference candidate.

    Returns:
        ``1 / rank`` with a 1-based rank when ``ref_idx`` occurs in the ranking, otherwise
        ``0.0`` (the standard reciprocal-rank value for a missed item; this also avoids a
        division by zero on an empty ranking).
    """
    positions = (generated_ranking == ref_idx).nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        return 0.0
    # First occurrence, so duplicate entries cannot break .item().
    return 1.0 / (positions[0].item() + 1)

# Compute reward for RL
def compute_reward(generated_ranking: torch.Tensor, reference: str) -> float:
    """Reward for one ranking step: reciprocal rank of the reference plus a 0.5 top-1 bonus.

    The reward pool is the first 100 entries of the global ``all_candidates``; a reference
    outside that pool earns 0.0. ``generated_ranking`` holds pool indices ordered best-first.
    """
    reward_pool = all_candidates[:100]  # Use small candidate pool for efficiency
    if reference not in reward_pool:
        return 0.0
    target_idx = reward_pool.index(reference)
    reciprocal_rank = compute_mrr(generated_ranking, target_idx)
    top1_bonus = 0.5 if generated_ranking[0] == target_idx else 0.0
    return reciprocal_rank + top1_bonus

# Policy Network for DPR (simple MLP to adjust embeddings)
class PolicyNetwork(torch.nn.Module):
    """Two-layer MLP mapping an embedding to a same-sized adjustment in [-1, 1].

    Layout: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> input_dim) -> Tanh.
    The notebook also reuses this class as the PPO value network (taking the output's mean).
    """

    def __init__(self, input_dim=768, hidden_dim=256):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Tanh(),  # bounds every output component to [-1, 1]
        ]
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the bounded adjustment for ``x``."""
        adjustment = self.network(x)
        return adjustment

# RL Environment for DPR ranking
class RankingEnvironment:
    """One-question-per-step RL environment over a DataLoader of DPR-tokenized questions.

    The state is the question encoder's ``pooler_output`` for the current question. An
    action is added to that embedding, candidates are ranked by dot product against
    ``candidate_embeddings`` (best first), and the reward comes from ``compute_reward``.
    """
    def __init__(self, ctx_encoder, question_encoder, candidates, candidate_embeddings, val_loader, task: str = "qa"):
        # NOTE(review): ctx_encoder is stored but not used by reset() or step().
        self.ctx_encoder = ctx_encoder
        self.question_encoder = question_encoder
        # Only used for the answer-membership check in step().
        self.candidates = candidates
        self.candidate_embeddings = candidate_embeddings
        self.val_loader = val_loader
        self.task = task
        self.current_batch = None
        self.current_idx = 0
        self.batch_iterator = iter(val_loader)

    def reset(self):
        """Move to the next batch (restarting the loader when exhausted) and return the first question's embedding.

        NOTE(review): any not-yet-visited samples of the previous batch are skipped.
        """
        try:
            self.current_batch = next(self.batch_iterator)
            self.current_idx = 0
        except StopIteration:
            # Loader exhausted: start a fresh pass over it.
            self.batch_iterator = iter(self.val_loader)
            self.current_batch = next(self.batch_iterator)
            self.current_idx = 0
        question_inputs = {
            "input_ids": self.current_batch["dpr_input_ids"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0),
            "attention_mask": self.current_batch["dpr_attention_mask"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0)
        }
        with torch.no_grad():
            state = self.question_encoder(**question_inputs).pooler_output  # Shape: (1, 768)
        return state

    def step(self, action):
        """Apply ``action`` to the current question's embedding and score the resulting ranking.

        Returns ``(next_state, reward, done)``. ``done`` is True when the answer is not in
        ``self.candidates`` or the loader has no more batches; ``next_state`` is then None.
        ``action`` is presumably broadcastable to the (1, 768) embedding — TODO confirm.
        """
        # Re-encodes the same question whose embedding was returned as the current state.
        question_inputs = {
            "input_ids": self.current_batch["dpr_input_ids"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0),
            "attention_mask": self.current_batch["dpr_attention_mask"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0)
        }
        with torch.no_grad():
            question_embedding = self.question_encoder(**question_inputs).pooler_output  # Shape: (1, 768)
        adjusted_embedding = question_embedding + action  # Adjust embedding
        # Dot-product score against every candidate embedding, then indices ordered best-first.
        similarities = torch.matmul(adjusted_embedding, self.candidate_embeddings.T)  # Shape: (1, num_candidates)
        rankings = torch.argsort(similarities, dim=1, descending=True)
        ref = self.current_batch["answer"][self.current_idx]
        done = False
        if ref not in self.candidates:
            # Answer cannot be ranked at all: zero reward and the episode ends.
            reward = 0.0
            done = True
        else:
            reward = compute_reward(rankings[0].cpu(), ref)
            print(f"Task: {self.task}, Reward: {reward:.4f}")

        # Move to next sample
        self.current_idx += 1
        if self.current_idx >= len(self.current_batch["answer"]):
            try:
                self.current_batch = next(self.batch_iterator)
                self.current_idx = 0
            except StopIteration:
                # No more data in this pass: end the episode (reset() restarts the loader).
                done = True

        # Get next state
        if not done:
            next_question_inputs = {
                "input_ids": self.current_batch["dpr_input_ids"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0),
                "attention_mask": self.current_batch["dpr_attention_mask"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0)
            }
            with torch.no_grad():
                next_state = self.question_encoder(**next_question_inputs).pooler_output
        else:
            next_state = None

        return next_state, reward, done

# PPO Implementation
def compute_gae(rewards, values, next_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for a single trajectory.

    Args:
        rewards: per-step rewards (list).
        values: per-step value estimates, same length as ``rewards`` (list).
        next_value: bootstrap value for the state after the last step.
        gamma: discount factor.
        lam: GAE smoothing factor.

    Returns:
        A new list of advantages aligned with ``rewards``; the inputs are not modified.
    """
    padded_values = values + [next_value]
    decay = gamma * lam
    running_adv = 0
    reversed_advantages = []
    for step in range(len(rewards) - 1, -1, -1):
        td_error = rewards[step] + gamma * padded_values[step + 1] - padded_values[step]
        running_adv = td_error + decay * running_adv
        reversed_advantages.append(running_adv)
    reversed_advantages.reverse()
    return reversed_advantages

def ppo_update(policy, optimizer, states, actions, old_log_probs, advantages, returns, clip_epsilon=0.1, epochs=5):
    """Run ``epochs`` PPO clipped-surrogate updates of ``policy`` on one collected trajectory.

    Args:
        policy: module mapping a state to the mean of a Gaussian action distribution (std 0.1).
        optimizer: optimizer over ``policy``'s parameters.
        states, actions: per-step tensors collected during the rollout.
        old_log_probs: per-step log-probabilities of ``actions`` under the rollout policy.
        advantages: per-step advantage estimates (floats), one per state.
        returns: accepted for interface compatibility; not used by the policy update.
        clip_epsilon: PPO ratio clipping range.
        epochs: number of optimisation passes over the trajectory.

    Returns:
        The policy loss of the final epoch as a float (0.0 when no update ran).
    """
    if not states or epochs <= 0:
        return 0.0

    # Work on the policy's own device rather than a global setting.
    device = next(policy.parameters()).device
    num_steps = len(states)
    # Keep one row per step: log-probs of shape (1,) would otherwise stack to (T, 1) and
    # broadcast against a (T,) advantage vector into a (T, T) matrix, pairing every ratio
    # with every step's advantage.
    old_lp = torch.stack([lp.detach().to(device) for lp in old_log_probs]).reshape(num_steps, -1)
    adv = torch.as_tensor(advantages, dtype=torch.float32, device=device).reshape(num_steps, 1)

    policy.train()
    loss_value = 0.0
    for _ in range(epochs):
        new_lp = []
        for state, action in zip(states, actions):
            mean = policy(state.to(device))
            dist = torch.distributions.Normal(mean, 0.1)
            new_lp.append(dist.log_prob(action.to(device)).sum(dim=-1))
        new_lp = torch.stack(new_lp).reshape(num_steps, -1)

        # Probability ratio and clipped surrogate objective
        ratios = torch.exp(new_lp - old_lp)
        surr1 = ratios * adv
        surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * adv
        policy_loss = -torch.min(surr1, surr2).mean()

        # Update policy
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()
        loss_value = policy_loss.item()
        print(f"Policy Loss: {loss_value:.4f}")

    return loss_value

print("RL helper functions defined.")

RL helper functions defined.


In [None]:
# RL Fine-Tuning for DPR on QA Task

# Fine-tune with PPO on QA data.
# NOTE: PPO trains `policy_qa`, a PolicyNetwork that perturbs question embeddings;
# ctx_encoder_qa / question_encoder_qa are only used for inference and are not updated here.
policy_qa = PolicyNetwork(input_dim=768, hidden_dim=256).to(CONFIG.DEVICE)
# Critic: same architecture; the scalar state value is the mean of its output vector.
value_network_qa = PolicyNetwork(input_dim=768, hidden_dim=256).to(CONFIG.DEVICE)
policy_optimizer_qa = torch.optim.Adam(policy_qa.parameters(), lr=5e-5)
value_optimizer_qa = torch.optim.Adam(value_network_qa.parameters(), lr=5e-5)

# Create environment for QA task
qa_env = RankingEnvironment(ctx_encoder_qa, question_encoder_qa, all_candidates, candidate_embeddings_qa, qa_val_loader_v4, task="qa")

num_episodes = 500  # Reduced number of episodes
max_steps_per_episode = 50  # Maximum steps per episode

print("Fine-tuning DPR with PPO on QA data...")
for episode in range(num_episodes):
    print(f"QA Episode {episode + 1}/{num_episodes}")
    state = qa_env.reset()
    if state.dim() == 1:
        state = state.unsqueeze(0)
    states, actions, rewards, old_log_probs, values = [], [], [], [], []
    done = False
    episode_steps = 0

    # Roll out one episode: Gaussian policy (fixed std 0.1) centred on the policy output.
    while not done and episode_steps < max_steps_per_episode:
        state = state.to(CONFIG.DEVICE)
        with torch.no_grad():
            adjustment = policy_qa(state)
            dist = torch.distributions.Normal(adjustment, 0.1)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = value_network_qa(state).mean()

        next_state, reward, done = qa_env.step(action)

        if next_state is not None and next_state.dim() == 1:
            next_state = next_state.unsqueeze(0)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        old_log_probs.append(log_prob)
        values.append(value.item())

        state = next_state
        episode_steps += 1

    # Bootstrap from the last state if the episode was truncated (0.0 when it terminated).
    with torch.no_grad():
        next_value = value_network_qa(state).mean().item() if state is not None else 0.0
    advantages = compute_gae(rewards, values, next_value)
    # Value targets are the GAE lambda-returns: advantage + the value estimate it was measured from.
    returns = [adv + v for adv, v in zip(advantages, values)]

    # Update policy
    policy_loss = ppo_update(policy_qa, policy_optimizer_qa, states, actions, old_log_probs, advantages, returns)

    # Update value network (mean squared error against the lambda-returns)
    value_loss = 0
    for visited_state, ret in zip(states, returns):
        visited_state = visited_state.to(CONFIG.DEVICE)
        value = value_network_qa(visited_state).mean()
        value_loss += (value - ret) ** 2
    value_loss = value_loss / len(states)
    value_optimizer_qa.zero_grad()
    value_loss.backward()
    value_optimizer_qa.step()
    print(f"QA Value Loss: {value_loss.item():.4f}")

# Save DPR models for the QA task (Version 4); they are unchanged by the PPO loop above.
ctx_encoder_qa.save_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_ctx_encoder_rl_qa_v4"))
question_encoder_qa.save_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_question_encoder_rl_qa_v4"))
# Persist the trained policy too — it is the only component this cell actually learns.
torch.save(policy_qa.state_dict(), os.path.join(CONFIG.BASE_PATH, "dpr_policy_rl_qa_v4.pt"))
print("Saved RL fine-tuned DPR models for QA task.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
QA Value Loss: 0.0275
QA Episode 354/500
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 1.5000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 1.5000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Policy Loss: -1.1257
Policy Loss: -1.2374
Policy Loss: -1.2356
Policy Loss: -1.2326
Policy Loss: -1.2277
QA Value Loss: 0.0275
QA Episode 355/500
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.0000
Task: qa, Reward: 0.000

In [None]:
# RL Fine-Tuning for DPR on Triple Task

# Fine-tune with PPO on triple data.
# NOTE: PPO trains `policy_triple`, a PolicyNetwork that perturbs question embeddings;
# ctx_encoder_triple / question_encoder_triple are only used for inference and are not updated here.
policy_triple = PolicyNetwork(input_dim=768, hidden_dim=256).to(CONFIG.DEVICE)
# Critic: same architecture; the scalar state value is the mean of its output vector.
value_network_triple = PolicyNetwork(input_dim=768, hidden_dim=256).to(CONFIG.DEVICE)
policy_optimizer_triple = torch.optim.Adam(policy_triple.parameters(), lr=5e-5)
value_optimizer_triple = torch.optim.Adam(value_network_triple.parameters(), lr=5e-5)

# Create environment for triple task
triple_env = RankingEnvironment(ctx_encoder_triple, question_encoder_triple, all_candidates, candidate_embeddings_triple, triple_val_loader_v4, task="triple")

num_episodes = 500  # Reduced number of episodes
max_steps_per_episode = 50  # Maximum steps per episode

print("Fine-tuning DPR with PPO on triple data...")
for episode in range(num_episodes):
    print(f"Triple Episode {episode + 1}/{num_episodes}")
    state = triple_env.reset()
    if state.dim() == 1:
        state = state.unsqueeze(0)
    states, actions, rewards, old_log_probs, values = [], [], [], [], []
    done = False
    episode_steps = 0

    # Roll out one episode: Gaussian policy (fixed std 0.1) centred on the policy output.
    while not done and episode_steps < max_steps_per_episode:
        state = state.to(CONFIG.DEVICE)
        with torch.no_grad():
            adjustment = policy_triple(state)
            dist = torch.distributions.Normal(adjustment, 0.1)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = value_network_triple(state).mean()

        next_state, reward, done = triple_env.step(action)

        if next_state is not None and next_state.dim() == 1:
            next_state = next_state.unsqueeze(0)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        old_log_probs.append(log_prob)
        values.append(value.item())

        state = next_state
        episode_steps += 1

    # Bootstrap from the last state if the episode was truncated (0.0 when it terminated).
    with torch.no_grad():
        next_value = value_network_triple(state).mean().item() if state is not None else 0.0
    advantages = compute_gae(rewards, values, next_value)
    # Value targets are the GAE lambda-returns: advantage + the value estimate it was measured from.
    returns = [adv + v for adv, v in zip(advantages, values)]

    # Update policy
    policy_loss = ppo_update(policy_triple, policy_optimizer_triple, states, actions, old_log_probs, advantages, returns)

    # Update value network (mean squared error against the lambda-returns)
    value_loss = 0
    for visited_state, ret in zip(states, returns):
        visited_state = visited_state.to(CONFIG.DEVICE)
        value = value_network_triple(visited_state).mean()
        value_loss += (value - ret) ** 2
    value_loss = value_loss / len(states)
    value_optimizer_triple.zero_grad()
    value_loss.backward()
    value_optimizer_triple.step()
    print(f"Triple Value Loss: {value_loss.item():.4f}")

# Save DPR models for the triple task; they are unchanged by the PPO loop above.
ctx_encoder_triple.save_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_ctx_encoder_rl_triple_v4"))
question_encoder_triple.save_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_question_encoder_rl_triple_v4"))
# Persist the trained policy too — it is the only component this cell actually learns.
torch.save(policy_triple.state_dict(), os.path.join(CONFIG.BASE_PATH, "dpr_policy_rl_triple_v4.pt"))
print("Saved RL fine-tuned DPR models for triple task.")

Fine-tuning DPR with PPO on triple data...
Triple Episode 1/500
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 1.5000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.5000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.5000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.2000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0417
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: triple, Reward: 0.0000
Task: tr

In [None]:
# Save RL Metrics and Artifacts

# The PPO cells print per-episode policy/value losses but do not collect them into lists,
# so these entries are saved empty.
rl_metrics = {
    "qa_policy_losses": [],
    "qa_value_losses": [],
    "triple_policy_losses": [],
    "triple_value_losses": []
}

rl_metrics_path = os.path.join(CONFIG.BASE_PATH, "rl_metrics_v4.json")
with open(rl_metrics_path, "w") as f:
    json.dump(rl_metrics, f)
print(f"Saved RL metrics at {rl_metrics_path}")

# Re-encode a capped candidate pool with both context encoders and save the embeddings.
# NOTE(review): the PPO cells do not modify ctx_encoder_qa / ctx_encoder_triple, and this cell
# rebinds candidate_embeddings_qa / candidate_embeddings_triple to only the first
# EVAL_CANDIDATE_LIMIT candidates for every later cell.
EVAL_CANDIDATE_LIMIT = 1000  # Limit candidate pool size for faster encoding
ENCODE_BATCH_SIZE = 64       # Chunk the forward pass to bound peak GPU memory
with torch.no_grad():
    eval_candidates = all_candidates[:EVAL_CANDIDATE_LIMIT]
    qa_chunks, triple_chunks = [], []
    for start in range(0, len(eval_candidates), ENCODE_BATCH_SIZE):
        candidate_batch = eval_candidates[start:start + ENCODE_BATCH_SIZE]
        candidate_inputs = ctx_tokenizer(candidate_batch, return_tensors="pt", padding=True, truncation=True, max_length=CONFIG.MAX_LENGTH)
        candidate_inputs = {k: v.to(CONFIG.DEVICE) for k, v in candidate_inputs.items()}
        qa_chunks.append(ctx_encoder_qa(**candidate_inputs).pooler_output)
        triple_chunks.append(ctx_encoder_triple(**candidate_inputs).pooler_output)
    candidate_embeddings_qa = torch.cat(qa_chunks, dim=0)
    candidate_embeddings_triple = torch.cat(triple_chunks, dim=0)
torch.save(candidate_embeddings_qa, os.path.join(save_path, 'dpr_candidate_embeddings_rl_qa_v4.pt'))
torch.save(candidate_embeddings_triple, os.path.join(save_path, 'dpr_candidate_embeddings_rl_triple_v4.pt'))

print("Updated candidate embeddings saved for Version 4.")

Saved RL metrics at /content/drive/MyDrive/LJMU-Datasets/rl_metrics_v4.json
Updated candidate embeddings saved for Version 4.


Improve the RL reward: give references outside the reward pool a similarity-scaled penalty instead of a flat 0

In [None]:
# Define RL Helper Functions

# Normalize text for evaluation
def normalize_text(text: str) -> str:
    """Lower-case ``text``, strip ASCII punctuation and drop the articles a/an/the.

    Non-string inputs are coerced with ``str()`` first (e.g. ``None`` -> ``"none"``).
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    cleaned = str(text).lower().strip().translate(strip_punct)
    return " ".join(tok for tok in cleaned.split() if tok not in {"a", "an", "the"})

# Compute BLEU score
def compute_bleu(generated: str, reference: str) -> float:
    """Sentence-level BLEU of ``generated`` against the single ``reference``.

    Both strings are tokenized by whitespace; ``sentence_bleu`` is called with its
    default weights and no smoothing.
    """
    hypothesis_tokens = generated.split()
    reference_token_lists = [reference.split()]
    return sentence_bleu(reference_token_lists, hypothesis_tokens)

# Compute ROUGE-L score
def compute_rouge_l(generated: str, reference: str) -> float:
    """ROUGE-L F-measure of ``generated`` scored against ``reference`` (stemming enabled).

    A new ``RougeScorer`` is constructed on every call.
    """
    rouge_l_score = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True).score(reference, generated)['rougeL']
    return rouge_l_score.fmeasure

# Compute BERTScore
def compute_bertscore(generated: str, reference: str) -> float:
    """BERTScore for one generated/reference pair, taken from element 2 of the scorer's result.

    The notebook binds ``bert_score`` via ``from bert_score import score as bert_score``, so the
    name is already the scoring function and must be called directly (a function has no
    ``.score`` attribute). ``getattr`` keeps this working if the name is rebound to the module.
    """
    scorer = getattr(bert_score, "score", bert_score)
    return scorer([generated], [reference], lang="en", verbose=False)[2].mean().item()

# Compute MRR for reward
def compute_mrr(generated_ranking: torch.Tensor, ref_idx: int) -> float:
    """Reciprocal rank of ``ref_idx`` within ``generated_ranking`` (used for the RL reward).

    Args:
        generated_ranking: candidate indices ordered best-first (assumed 1-D).
        ref_idx: index of the reference candidate.

    Returns:
        ``1 / rank`` with a 1-based rank when ``ref_idx`` occurs in the ranking, otherwise
        ``0.0`` (the standard reciprocal-rank value for a missed item; this also avoids a
        division by zero on an empty ranking).
    """
    positions = (generated_ranking == ref_idx).nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        return 0.0
    # First occurrence, so duplicate entries cannot break .item().
    return 1.0 / (positions[0].item() + 1)

# Compute reward for RL (Updated with semantic similarity for missing references)
# Cache for compute_reward's similarity shaping: the SentenceTransformer model and the
# encoded reward pool are built once instead of on every environment step.
_REWARD_SIMILARITY_CACHE = {}


def _reward_pool_embeddings(candidates: list):
    """Return (model, pool_embeddings) for ``candidates``, loading/encoding at most once per pool."""
    model = _REWARD_SIMILARITY_CACHE.get("model")
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        _REWARD_SIMILARITY_CACHE["model"] = model
    pool_key = tuple(candidates)
    if _REWARD_SIMILARITY_CACHE.get("pool_key") != pool_key:
        _REWARD_SIMILARITY_CACHE["pool_embeddings"] = model.encode(candidates, convert_to_tensor=True).to(CONFIG.DEVICE)
        _REWARD_SIMILARITY_CACHE["pool_key"] = pool_key
    return model, _REWARD_SIMILARITY_CACHE["pool_embeddings"]


def compute_reward(generated_ranking: torch.Tensor, reference: str) -> float:
    """Shaped reward for one ranking step.

    The reward pool is the first 100 entries of the global ``all_candidates``.

    * Reference in the pool: ``2 * MRR`` + 0.5 if it is ranked first + a 0.01 exploration bonus.
    * Reference outside the pool: ``-0.5 * (1 - s)``, where ``s`` is the highest cosine
      similarity between the reference and any pool entry (range [-1.0, 0.0]).
    """
    candidates = all_candidates[:100]  # Use small candidate pool for efficiency
    ref_idx = candidates.index(reference) if reference in candidates else -1
    if ref_idx == -1:
        # NOTE(review): this penalty depends only on `reference`, not on `generated_ranking`,
        # so it carries no signal about the policy's action.
        model, pool_embeddings = _reward_pool_embeddings(candidates)
        ref_embedding = model.encode(reference, convert_to_tensor=True).to(CONFIG.DEVICE)
        similarities = torch.cosine_similarity(ref_embedding.unsqueeze(0), pool_embeddings, dim=1)
        max_similarity = similarities.max().item()
        # Closer to 0 when the reference resembles some pool entry, more negative otherwise.
        penalty = -0.5 * (1.0 - max_similarity)
        return penalty
    mrr = compute_mrr(generated_ranking, ref_idx)
    mrr_scaled = mrr * 2.0  # Scale MRR to provide stronger signal
    exact_match_bonus = 0.5 if generated_ranking[0] == ref_idx else 0.0
    exploration_bonus = 0.01  # Small bonus to encourage exploration
    return mrr_scaled + exact_match_bonus + exploration_bonus

# Policy Network for DPR (simple MLP to adjust embeddings)
class PolicyNetwork(torch.nn.Module):
    """Two-layer MLP mapping an embedding to a same-sized adjustment in [-1, 1].

    Layout: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> input_dim) -> Tanh.
    The notebook also reuses this class as the PPO value network (taking the output's mean).
    """

    def __init__(self, input_dim=768, hidden_dim=256):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Tanh(),  # bounds every output component to [-1, 1]
        ]
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the bounded adjustment for ``x``."""
        adjustment = self.network(x)
        return adjustment

# RL Environment for DPR ranking
class RankingEnvironment:
    """Sequential environment that walks through the questions of a DataLoader.

    Each step adds the action to the current question's DPR embedding, ranks
    `candidate_embeddings` by dot product with the adjusted embedding, and scores
    the ranking with `compute_reward`.

    Args:
        ctx_encoder: Context encoder; stored on the instance, not used by reset/step.
        question_encoder: Encoder whose `pooler_output` is used as the state.
        candidates: Candidate texts; stored on the instance, not used by reset/step.
        candidate_embeddings: Tensor with one row per candidate (matmul'd with the
            adjusted question embedding).
        val_loader: Iterable of batch dicts with keys "dpr_input_ids",
            "dpr_attention_mask" and "answer".
        task: Label printed with each step's reward.
        verbose: If True (default, original behaviour) print every step's reward.
    """

    def __init__(self, ctx_encoder, question_encoder, candidates, candidate_embeddings, val_loader, task: str = "qa",
                 verbose: bool = True):
        self.ctx_encoder = ctx_encoder
        self.question_encoder = question_encoder
        self.candidates = candidates
        self.candidate_embeddings = candidate_embeddings
        self.val_loader = val_loader
        self.task = task
        self.verbose = verbose
        self.current_batch = None
        self.current_idx = 0
        # Embedding of the current question, i.e. the state the policy last saw.
        self.current_state = None
        self.batch_iterator = iter(val_loader)

    def _encode_current(self):
        """Encode the question at `current_idx` of `current_batch` (no gradients)."""
        question_inputs = {
            "input_ids": self.current_batch["dpr_input_ids"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0),
            "attention_mask": self.current_batch["dpr_attention_mask"][self.current_idx].to(CONFIG.DEVICE).unsqueeze(0)
        }
        with torch.no_grad():
            return self.question_encoder(**question_inputs).pooler_output  # Shape: (1, 768)

    def _advance(self):
        """Move to the next question; return False once the loader is exhausted."""
        self.current_idx += 1
        if self.current_idx < len(self.current_batch["answer"]):
            return True
        try:
            self.current_batch = next(self.batch_iterator)
        except StopIteration:
            return False
        self.current_idx = 0
        return True

    def reset(self):
        """Start an episode at the next batch (restarting the loader if exhausted) and return its first state."""
        try:
            self.current_batch = next(self.batch_iterator)
        except StopIteration:
            self.batch_iterator = iter(self.val_loader)
            try:
                self.current_batch = next(self.batch_iterator)
            except StopIteration:
                raise ValueError("val_loader yielded no batches") from None
        self.current_idx = 0
        self.current_state = self._encode_current()
        return self.current_state

    def step(self, action):
        """Apply `action` to the current question's embedding and advance.

        Returns:
            (next_state, reward, done): `next_state` is the next question's
            embedding, or None when the loader is exhausted (then `done` is True).
        """
        # Reuse the embedding the policy acted on instead of running the encoder again.
        question_embedding = self.current_state
        if question_embedding is None:
            question_embedding = self._encode_current()
        adjusted_embedding = question_embedding + action  # Adjust embedding
        similarities = torch.matmul(adjusted_embedding, self.candidate_embeddings.T)  # Shape: (1, num_candidates)
        rankings = torch.argsort(similarities, dim=1, descending=True)
        ref = self.current_batch["answer"][self.current_idx]
        reward = compute_reward(rankings[0].cpu(), ref)
        if self.verbose:
            print(f"Task: {self.task}, Reward: {reward:.4f}")

        done = not self._advance()
        self.current_state = None if done else self._encode_current()
        return self.current_state, reward, done

# PPO Implementation
def compute_gae(rewards, values, next_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation (GAE) for one trajectory.

    Walks the trajectory backwards accumulating
        A_t = delta_t + gamma * lam * A_{t+1},
        delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).

    Args:
        rewards: Per-step rewards r_t.
        values: Per-step value estimates V(s_t) (list; same length as `rewards` expected).
        next_value: Bootstrap value for the state after the final step.
        gamma: Discount factor.
        lam: GAE lambda.

    Returns:
        List of advantages, one per step, in chronological order.
    """
    # Value track with the bootstrap value appended, so index t+1 is always valid.
    value_track = values + [next_value]
    running_advantage = 0
    backwards = []
    for t in range(len(rewards) - 1, -1, -1):
        td_error = rewards[t] + gamma * value_track[t + 1] - value_track[t]
        running_advantage = td_error + gamma * lam * running_advantage
        backwards.append(running_advantage)
    return backwards[::-1]

def ppo_update(policy, optimizer, states, actions, old_log_probs, advantages, returns, clip_epsilon=0.1, epochs=5,
               action_std=0.1):
    """Run clipped-surrogate PPO updates on one collected trajectory.

    Args:
        policy: Module mapping a batch of states to the mean of a Normal action distribution.
        optimizer: Optimizer over `policy`'s parameters.
        states: Per-step state tensors, each (1, dim) or (dim,).
        actions: Per-step sampled actions, shaped like `policy(state)`.
        old_log_probs: Per-step log-probs of `actions` under the rollout policy
            (one element per step).
        advantages: Per-step advantages (floats), same length as `states`.
        returns: Accepted for interface compatibility; not used here (the caller
            trains the value network separately).
        clip_epsilon: PPO ratio clipping range.
        epochs: Number of full-batch gradient steps.
        action_std: Std-dev of the Normal policy; must match the one used for sampling.

    Returns:
        The last epoch's policy loss as a float, or 0.0 when there is nothing to update.

    Raises:
        ValueError: if states, actions, old log-probs and advantages disagree in length.
    """
    if len(states) == 0 or epochs <= 0:
        return 0.0

    device = CONFIG.DEVICE
    # Batch the trajectory into (n, dim) states/actions and flat (n,) log-probs and
    # advantages. Keeping the log-probs flat matters: stacking the per-step (1,)
    # log-probs gave an (n, 1) tensor that broadcast against the (n,) advantages
    # into an (n, n) surrogate, pairing every ratio with every step's advantage.
    state_batch = torch.cat([torch.atleast_2d(s) for s in states]).to(device)
    action_batch = torch.cat([torch.atleast_2d(a) for a in actions]).to(device)
    old_log_prob_batch = torch.cat([lp.reshape(-1) for lp in old_log_probs]).to(device).detach()
    advantage_batch = torch.as_tensor(advantages, dtype=torch.float32, device=device).reshape(-1)
    num_steps = advantage_batch.shape[0]
    if not (state_batch.shape[0] == action_batch.shape[0] == old_log_prob_batch.shape[0] == num_steps):
        raise ValueError(
            f"trajectory length mismatch: states={state_batch.shape[0]}, actions={action_batch.shape[0]}, "
            f"old_log_probs={old_log_prob_batch.shape[0]}, advantages={num_steps}"
        )

    policy_loss = None
    for _ in range(epochs):
        policy.train()
        dist = torch.distributions.Normal(policy(state_batch), action_std)
        new_log_prob_batch = dist.log_prob(action_batch).sum(dim=-1)  # (n,)

        # Compute ratio
        ratios = torch.exp(new_log_prob_batch - old_log_prob_batch)

        # Compute surrogate loss
        surr1 = ratios * advantage_batch
        surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantage_batch
        policy_loss = -torch.min(surr1, surr2).mean()

        # Update policy
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()
        print(f"Policy Loss: {policy_loss.item():.4f}")

    return policy_loss.item()

print("RL helper functions defined.")

RL helper functions defined.


In [None]:
# Test Updated Reward Function and Re-Run RL Fine-Tuning for Triple Task

# Sanity-check the reward function on a few examples from the triple validation data.
# compute_reward looks references up in its own candidate pool (all_candidates[:100]
# by default). The earlier `candidates = all_candidates[:500]` line only rebound an
# unrelated global and never enlarged that pool, so it has been removed.
print("Testing updated reward function on triple data...")
val_batch = next(iter(triple_val_loader_v4))
test_references = val_batch["answer"][:5]
question_inputs = {
    "input_ids": val_batch["dpr_input_ids"][:5].to(CONFIG.DEVICE),
    "attention_mask": val_batch["dpr_attention_mask"][:5].to(CONFIG.DEVICE)
}
with torch.no_grad():
    question_embeddings = question_encoder_triple(**question_inputs).pooler_output
    similarities = torch.matmul(question_embeddings, candidate_embeddings_triple.T)
    rankings = torch.argsort(similarities, dim=1, descending=True).cpu()

for rank, ref in zip(rankings, test_references):
    reward = compute_reward(rank, ref)
    print(f"Reference: {ref}, Reward: {reward:.4f}")

# Re-run RL fine-tuning for the triple task with the updated reward function
RL_SEED = 42
torch.manual_seed(RL_SEED)  # reproducible network initialisation and action sampling

policy_triple = PolicyNetwork(input_dim=768, hidden_dim=256).to(CONFIG.DEVICE)
# Critic with an unbounded scalar output. Reusing PolicyNetwork here (Tanh output,
# then .mean()) confined value estimates to [-1, 1], while a single step's reward can
# reach 2.51 (2 * MRR + 0.5 exact-match bonus + 0.01 exploration bonus) and
# discounted multi-step returns can fall well outside that range.
value_network_triple = torch.nn.Sequential(
    torch.nn.Linear(768, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 1),
).to(CONFIG.DEVICE)
policy_optimizer_triple = torch.optim.Adam(policy_triple.parameters(), lr=5e-5)
value_optimizer_triple = torch.optim.Adam(value_network_triple.parameters(), lr=5e-5)

# Create environment for triple task
triple_env = RankingEnvironment(ctx_encoder_triple, question_encoder_triple, all_candidates, candidate_embeddings_triple, triple_val_loader_v4, task="triple")

num_episodes = 500
max_steps_per_episode = 50

print("Re-running fine-tuning DPR with PPO on triple data...")
triple_policy_losses = []
triple_value_losses = []
for episode in range(num_episodes):
    print(f"Triple Episode {episode + 1}/{num_episodes}")
    state = triple_env.reset()
    if state.dim() == 1:
        state = state.unsqueeze(0)
    states, actions, rewards, old_log_probs, values = [], [], [], [], []
    done = False
    episode_steps = 0

    # Roll out up to max_steps_per_episode questions with the current policy (no gradients).
    while not done and episode_steps < max_steps_per_episode:
        state = state.to(CONFIG.DEVICE)
        with torch.no_grad():
            adjustment = policy_triple(state)
            dist = torch.distributions.Normal(adjustment, 0.1)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            value = value_network_triple(state).mean()

        next_state, reward, done = triple_env.step(action)

        if next_state is not None and next_state.dim() == 1:
            next_state = next_state.unsqueeze(0)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        old_log_probs.append(log_prob)
        values.append(value.item())

        state = next_state
        episode_steps += 1

    # Bootstrap from the state after the last step; 0.0 when the environment returned no next state.
    with torch.no_grad():
        next_value = value_network_triple(state.to(CONFIG.DEVICE)).mean().item() if state is not None else 0.0
    advantages = compute_gae(rewards, values, next_value)
    # Critic targets are the GAE returns R_t = A_t + V(s_t). The previous
    # `r_t + next_value` added the final bootstrap value to every step's immediate
    # reward, ignoring discounting and all later rewards in the episode.
    returns = [adv + v for adv, v in zip(advantages, values)]

    # Update policy
    policy_loss = ppo_update(policy_triple, policy_optimizer_triple, states, actions, old_log_probs, advantages, returns)
    triple_policy_losses.append(policy_loss)

    # Update value network: one MSE step toward the returns.
    value_preds = torch.stack([value_network_triple(s.to(CONFIG.DEVICE)).mean() for s in states])
    value_targets = torch.tensor(returns, dtype=value_preds.dtype, device=value_preds.device)
    value_loss = torch.mean((value_preds - value_targets) ** 2)
    value_optimizer_triple.zero_grad()
    value_loss.backward()
    value_optimizer_triple.step()
    triple_value_losses.append(value_loss.item())
    print(f"Triple Value Loss: {value_loss.item():.4f}")

# Save artifacts for the triple task.
# The PPO loop above optimises only `policy_triple` and the value network; the DPR
# encoders are never given to an optimizer, so they are saved unchanged under the
# existing RL paths (kept for downstream loaders). The learned policy weights are
# what the RL run actually produced, so they are saved as well.
ctx_encoder_triple.save_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_ctx_encoder_rl_triple_v4"))
question_encoder_triple.save_pretrained(os.path.join(CONFIG.BASE_PATH, "dpr_question_encoder_rl_triple_v4"))
policy_triple_path = os.path.join(CONFIG.BASE_PATH, "dpr_policy_rl_triple_v4.pt")
torch.save(policy_triple.state_dict(), policy_triple_path)
print(f"Saved DPR encoders (not modified by PPO) and RL policy weights for triple task at {policy_triple_path}.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Task: triple, Reward: -0.2180
Task: triple, Reward: -0.3029
Task: triple, Reward: -0.2903
Task: triple, Reward: -0.2666
Task: triple, Reward: -0.2686
Task: triple, Reward: -0.2615
Task: triple, Reward: -0.3085
Task: triple, Reward: -0.2847
Policy Loss: 2.8493
Policy Loss: 2.5643
Policy Loss: 2.5643
Policy Loss: 2.5643
Policy Loss: 2.5643
Triple Value Loss: 0.0624
Triple Episode 413/500
Task: triple, Reward: -0.3164
Task: triple, Reward: -0.2176
Task: triple, Reward: -0.2122
Task: triple, Reward: -0.4137
Task: triple, Reward: -0.2231
Task: triple, Reward: -0.2047
Task: triple, Reward: -0.1573
Task: triple, Reward: -0.2571
Task: triple, Reward: -0.1791
Task: triple, Reward: -0.3133
Task: triple, Reward: -0.2240
Task: triple, Reward: -0.3104
Task: triple, Reward: -0.2180
Task: triple, Reward: -0.2472
Task: triple, Reward: -0.2775
Task: triple, Reward: -0.2403
Task: triple, Reward: -0.2100
Task: triple, Reward: -0.1730
Task: 

In [None]:
# Save RL Metrics and Artifacts
import json  # imported here so this cell does not rely on an import made in another cell

# Save RL metrics
rl_metrics = {
    "qa_policy_losses": [],  # Placeholder for QA, as only triple task was re-run
    "qa_value_losses": [],   # Placeholder
    "triple_policy_losses": triple_policy_losses,
    "triple_value_losses": triple_value_losses
}

rl_metrics_path = os.path.join(CONFIG.BASE_PATH, "rl_metrics_v4.json")
with open(rl_metrics_path, "w") as f:
    json.dump(rl_metrics, f)
print(f"Saved RL metrics at {rl_metrics_path}")

# Save updated candidate embeddings
EMBED_POOL_SIZE = 1000   # Limit candidate pool size for faster encoding
EMBED_BATCH_SIZE = 64    # Encode in chunks so peak GPU memory stays bounded
eval_candidates = all_candidates[:EMBED_POOL_SIZE]

# Inference only: disable dropout so the exported embeddings are deterministic.
# NOTE: this leaves both encoders in eval mode for later cells.
ctx_encoder_qa.eval()
ctx_encoder_triple.eval()
qa_embedding_chunks, triple_embedding_chunks = [], []
with torch.no_grad():
    for start in range(0, len(eval_candidates), EMBED_BATCH_SIZE):
        chunk = eval_candidates[start:start + EMBED_BATCH_SIZE]
        candidate_inputs = ctx_tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=CONFIG.MAX_LENGTH)
        candidate_inputs = {k: v.to(CONFIG.DEVICE) for k, v in candidate_inputs.items()}
        qa_embedding_chunks.append(ctx_encoder_qa(**candidate_inputs).pooler_output)
        triple_embedding_chunks.append(ctx_encoder_triple(**candidate_inputs).pooler_output)
candidate_embeddings_qa = torch.cat(qa_embedding_chunks)
candidate_embeddings_triple = torch.cat(triple_embedding_chunks)

# NOTE(review): `save_path` is defined in an earlier cell, while the metrics above use
# CONFIG.BASE_PATH — confirm both refer to the same directory.
torch.save(candidate_embeddings_qa, os.path.join(save_path, 'dpr_candidate_embeddings_rl_qa_v4.pt'))
torch.save(candidate_embeddings_triple, os.path.join(save_path, 'dpr_candidate_embeddings_rl_triple_v4.pt'))

print("Updated candidate embeddings saved for Version 4.")

Saved RL metrics at /content/drive/MyDrive/LJMU-Datasets/rl_metrics_v4.json
Updated candidate embeddings saved for Version 4.
