In [None]:
import torch
from transformers import T5Tokenizer, T5ForSequenceClassification
from tqdm import tqdm


def load_model(model_path, device="cpu"):
    """Load a T5 tokenizer and a single-logit sequence-classification model.

    Args:
        model_path: Checkpoint name or path handed to ``from_pretrained``.
        device: Device the model is moved to (defaults to ``"cpu"``).

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to
        evaluation mode.
    """
    tokenizer = T5Tokenizer.from_pretrained(model_path)
    # num_labels=1 requests a single output logit per input.
    classifier = T5ForSequenceClassification.from_pretrained(
        model_path,
        num_labels=1,
    )
    classifier.to(device)
    classifier.eval()
    return tokenizer, classifier


def evaluate_relevance(tokenizer, model, questions, contexts, device, n_docs=10):
    """Score every question/context pair with a sequence-classification model.

    Each pair is joined as ``"<question> [SEP] <context>"``, tokenized to a
    fixed length of 512 tokens and run through ``model`` without gradient
    tracking.

    Args:
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        model: Callable model whose output exposes a single-element ``logits``
            tensor.
        questions: Iterable of question strings.
        contexts: Iterable of context strings, paired positionally with
            ``questions`` (``zip`` stops at the shorter of the two).
        device: Device the input tensors are moved to before the forward pass.
        n_docs: Currently unused; kept so existing callers keep working.

    Returns:
        A list of floats, one raw logit per pair, in input order.
    """
    scores = []
    for question, context in tqdm(zip(questions, contexts), total=len(questions)):
        input_text = f"{question} [SEP] {context}"

        # truncation=True keeps over-long pairs at max_length; padding alone
        # only lengthens short inputs and lets long ones through untouched.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        with torch.no_grad():
            output = model(
                inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
            )

        # The logits tensor holds exactly one value for a single input with
        # one label; .item() extracts it as a Python float from any device.
        scores.append(output.logits.item())
    return scores


def flag_relevance(scores, threshold1, threshold2):
    """Map each score to a relevance flag.

    Args:
        scores: Iterable of numeric scores.
        threshold1: Scores at or above this value are flagged 2 (high).
        threshold2: Scores at or above this value, but below ``threshold1``,
            are flagged 1 (moderate).

    Returns:
        A list with one flag per score: 2 = high relevance,
        1 = moderate relevance, 0 = low relevance.
    """

    def _flag(value):
        # threshold1 is checked first, so it wins when both thresholds are met.
        if value >= threshold1:
            return 2
        if value >= threshold2:
            return 1
        return 0

    return [_flag(value) for value in scores]


def check_context_relevance(
    model_path, questions, contexts, threshold1, threshold2, device="cpu"
):
    """Run the full pipeline: load the model, score the pairs, flag the scores.

    Args:
        model_path: Checkpoint name or path passed to ``load_model``.
        questions: Question strings.
        contexts: Context strings, paired positionally with ``questions``.
        threshold1: Score cut-off for the "high relevance" flag (2).
        threshold2: Score cut-off for the "moderate relevance" flag (1).
        device: Device used for the model and its inputs.

    Returns:
        The list of flags produced by ``flag_relevance``.
    """
    tokenizer, model = load_model(model_path, device)
    return flag_relevance(
        evaluate_relevance(tokenizer, model, questions, contexts, device),
        threshold1,
        threshold2,
    )


# Example usage
# Define the input questions, contexts, and thresholds.
# questions[i] is scored against contexts[i].
questions = ["What is the capital of France?"]
contexts = [
    "Paris is the capital and largest city of France."
]
# NOTE(review): this checkpoint carries no sequence-classification head, so
# transformers initializes one randomly (see the "newly initialized" warning in
# the cell output). The resulting scores are not meaningful until a checkpoint
# fine-tuned for relevance classification is used instead.
model_path = "google/flan-t5-base"
# Both thresholds are compared against the raw logit returned by the model.
threshold1 = 0.5  # Example threshold for high relevance
threshold2 = 0.1  # Example threshold for moderate relevance

# Run the relevance checking pipeline
relevance_flags = check_context_relevance(
    model_path, questions, contexts, threshold1, threshold2
)
print(
    relevance_flags
)  # One flag per question/context pair: 2 = high, 1 = moderate, 0 = low

Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at google/flan-t5-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 1/1 [00:07<00:00,  7.97s/it]

[-0.10978049039840698]
[0]





In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Initialize the T5 model and tokenizer.
# `tokenizer`, `model` and `device` are module-level names that
# check_relevance_with_threshold() below reads directly.
model_name = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the prompt template
def generate_relevance_prompt(question, document):
    """Build the grading prompt for one question/document pair.

    Args:
        question: The user question to embed in the prompt.
        document: The retrieved document to embed in the prompt.

    Returns:
        A single-line prompt asking for a relevance rating from 0 to 10.
    """
    segments = [
        "You are a grader assessing the relevance of a retrieved document to a user question.",
        f"User question: {question}",
        f"Retrieved document: {document}",
        "Rate the relevance on a scale from 0 to 10 where 0 means not relevant at all and 10 means highly relevant.",
    ]
    # The segments are separated by exactly one space.
    return " ".join(segments)

# Function to check relevance and apply a threshold
def check_relevance_with_threshold(question, context, threshold=50):
    """Ask the generative model for a relevance rating and apply a threshold.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` and on
    ``generate_relevance_prompt`` to build the prompt.

    Args:
        question: The user question.
        context: The retrieved document to grade against the question.
        threshold: Minimum score (inclusive) for the pair to count as relevant.
            NOTE(review): the prompt requests a rating from 0 to 10, so the
            default of 50 is never reached by an in-range answer. It is kept
            for backward compatibility; pass an explicit threshold.

    Returns:
        A ``(relevance_score, is_relevant)`` tuple. ``relevance_score`` is 0.0
        when no number can be found in the model's reply.
    """
    prompt = generate_relevance_prompt(question, context)
    # truncation=True keeps over-long prompts at max_length; padding alone only
    # lengthens short inputs.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        # The input is padded to 512 tokens, so hand the attention mask over
        # explicitly instead of leaving generate() to work it out on its own.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=10,  # Increase max_length if the response is longer
        )

    # Decode the generated tokens and turn the reply into a numeric score.
    relevance_score_text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    relevance_score = _parse_relevance_score(relevance_score_text)

    # Check if the score meets the threshold for relevance
    is_relevant = relevance_score >= threshold
    return relevance_score, is_relevant


def _parse_relevance_score(text):
    """Return the numeric score found in ``text``, or 0.0 when there is none."""
    import re

    try:
        # Fast path: the reply is a bare number such as "7" or "7.5".
        return float(text)
    except ValueError:
        pass

    # Fall back to the first number in replies such as "8 out of 10" or "10/10".
    match = re.search(r"[-+]?\d+(?:\.\d+)?", text)
    return float(match.group()) if match else 0.0

# Example question and context
# NOTE: these rebind the module-level `questions` / `contexts` names that the
# earlier example cell also assigns.
questions = ["What is the capital of France?"]
contexts = ["Paris is the Capital city of France."]

# Define the threshold on the 0-10 scale requested by the prompt; it is passed
# explicitly below, overriding the function's default.
threshold = 5

# Check relevance for each question-context pair (paired positionally by zip)
for question, context in zip(questions, contexts):
    relevance_score, is_relevant = check_relevance_with_threshold(question, context, threshold)
    print(f"Question: '{question}'\nDocument: '{context}'\nRelevance Score: {relevance_score}\nIs Relevant: {is_relevant}\n")


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Question: 'What is the capital of France?'
Document: 'Paris is the Capital city of France.'
Relevance Score: 10.0
Is Relevant: True

