<a href="https://colab.research.google.com/github/BhuvaneswariGanagala/currency_converter/blob/main/Copy_of_zero_shot_using_CRG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, pipeline, BartForSequenceClassification
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import time
import re
import json
import os
import requests
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Configuration
class Config:
    """Shared settings for the stance-detection notebook.

    All values are plain class attributes, so they can be read either from the
    class itself or from the module-level ``config`` instance created below.
    """

    # --- Model identifiers -------------------------------------------------
    # NLI checkpoint used for zero-shot stance detection.
    stance_model_name = "facebook/bart-large-mnli"
    # Sentence encoder used for semantic similarity in the relevance gate.
    embedding_model = "sentence-transformers/all-mpnet-base-v2"
    # Name of a local model for knowledge retrieval.
    ollama_model = "mistral"

    # --- Runtime parameters ------------------------------------------------
    max_length = 512
    batch_size = 16
    # Minimum relevance score for the Context Relevance Gate (CRG).
    crg_threshold = 0.65

    # --- Stance vocabulary -------------------------------------------------
    # Position in this list is the numeric id of the stance.
    stance_labels = ["favor", "against", "neutral"]

    # Numeric id -> canonical lower-case stance name.
    stance_map_to_text = dict(enumerate(stance_labels))

    # Dataset spelling (upper- or lower-case) -> numeric id.
    stance_map_to_num = {
        'FAVOR': 0,
        'AGAINST': 1,
        'NONE': 2,
        'favor': 0,
        'against': 1,
        'neutral': 2,
        'none': 2,
    }


config = Config()

In [None]:
# 1. Data Loading and Preprocessing
def load_data(file_path):
    """Read a CSV file into a DataFrame and report how many rows it holds.

    Args:
        file_path: Location of the CSV file, passed straight to ``pd.read_csv``.

    Returns:
        The DataFrame produced by ``pd.read_csv``.
    """
    frame = pd.read_csv(file_path)
    row_count = len(frame)
    print(f"Loaded {row_count} samples")
    return frame

def preprocess_text(text):
    """Normalise a piece of raw text before it is fed to the models.

    The value is stringified and lower-cased, URLs and @mentions are dropped,
    hashtags keep their word but lose the ``#``, and whitespace runs collapse
    to a single space.

    Args:
        text: Any scalar; ``None`` and pandas missing values are accepted.

    Returns:
        The cleaned string, or ``""`` when the input is missing.
    """
    if text is None or pd.isna(text):
        return ""

    # Applied in order; the whitespace pass must come last because the
    # earlier removals leave gaps behind.
    substitutions = (
        (r'http\S+', ''),      # URLs
        (r'@\w+', ''),         # mentions
        (r'#(\w+)', r'\1'),    # hashtags -> bare word
        (r'\s+', ' '),         # whitespace runs
    )

    cleaned = str(text).lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()


def prepare_dataset(df, text_col='Tweet', target_col='Target', stance_col='Stance'):
    """Return a cleaned copy of ``df`` ready for stance detection.

    Adds ``processed_text`` and ``processed_target`` columns (via
    ``preprocess_text``) and, when the stance column is present, a numeric
    ``label`` column. The input frame is left untouched.

    Args:
        df: Source DataFrame.
        text_col: Column holding the text to classify.
        target_col: Column holding the stance target.
        stance_col: Column holding gold stance labels; may be falsy or absent,
            in which case no ``label`` column is produced.

    Returns:
        A new DataFrame with the extra columns.

    Raises:
        ValueError: If ``text_col`` or ``target_col`` is not a column of ``df``.
    """
    available = df.columns

    # A missing stance column is not fatal: we simply skip label creation.
    labels_available = bool(stance_col) and stance_col in available
    if stance_col and not labels_available:
        print(f"Warning: Stance column '{stance_col}' not found. Running in inference mode only.")

    for col in (text_col, target_col):
        if col not in available:
            raise ValueError(f"Required column '{col}' not found in dataset")

    def to_label(raw):
        # Missing or unrecognised labels fall back to 2 (neutral).
        if pd.isna(raw):
            return 2
        return config.stance_map_to_num.get(raw, 2)

    processed_df = df.copy()
    processed_df['processed_text'] = processed_df[text_col].apply(preprocess_text)
    processed_df['processed_target'] = processed_df[target_col].apply(preprocess_text)

    if labels_available:
        processed_df['label'] = processed_df[stance_col].apply(to_label)

    return processed_df

In [None]:
class KnowledgeRetriever:
    """Fetch a short factual summary about a topic from the Gemini REST API.

    Failures never raise: callers always get a string back, either the
    generated summary or one of the two fallback messages below.
    """

    # Returned when the API answers but carries no usable text.
    NO_RESULT_MESSAGE = "No relevant information found."
    # Returned when every attempt failed at the transport/HTTP level.
    FAILURE_MESSAGE = "Failed to retrieve knowledge after multiple attempts."

    def __init__(self, gemini_api_key, max_retries=3, delay=5, timeout=30):
        """Store connection settings.

        Args:
            gemini_api_key: API key used to authenticate each request.
            max_retries: Number of attempts before giving up.
            delay: Seconds to wait between two attempts.
            timeout: Seconds to wait for the HTTP request before treating the
                attempt as failed.
        """
        self.gemini_api_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
        self.gemini_api_key = gemini_api_key
        self.max_retries = max_retries
        self.delay = delay
        self.timeout = timeout

    def _format_prompt(self, query):
        """Build the instruction prompt sent to the model for ``query``."""
        return f"""
        You are an AI assistant. Provide a **concise factual summary** (max 100 words) on the topic:

        "{query}"

        Keep it **neutral and informative**. Avoid opinions.
        """

    @staticmethod
    def _extract_text(data):
        """Pull the first candidate's text out of a decoded API response.

        Returns ``NO_RESULT_MESSAGE`` when the expected
        ``candidates[0].content.parts[0].text`` path is missing.
        """
        try:
            return data["candidates"][0]["content"]["parts"][0]["text"]
        except (KeyError, IndexError, TypeError):
            return KnowledgeRetriever.NO_RESULT_MESSAGE

    def _call_gemini_api(self, query):
        """POST the prompt to Gemini, retrying on request errors.

        Returns:
            The generated text, ``NO_RESULT_MESSAGE`` or ``FAILURE_MESSAGE``.
        """
        headers = {
            "Content-Type": "application/json",
            # Sent as a header rather than a ``?key=`` query parameter so the
            # key is not embedded in the URL that request errors print.
            "x-goog-api-key": self.gemini_api_key,
        }
        payload = {
            "contents": [{"role": "user", "parts": [{"text": self._format_prompt(query)}]}]
        }

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.gemini_api_url,
                    headers=headers,
                    json=payload,
                    timeout=self.timeout,  # without this a stalled socket blocks forever
                )
                response.raise_for_status()  # Raise error if request fails
                return self._extract_text(response.json())
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers a non-JSON body on requests versions whose
                # JSON error is not a RequestException subclass.
                print(f"Error fetching data from Gemini (Attempt {attempt+1}): {e}")

            # Pause only when another attempt will actually follow.
            if attempt < self.max_retries - 1:
                time.sleep(self.delay)

        return self.FAILURE_MESSAGE

    def retrieve(self, topic):
        """Return background knowledge about ``topic`` as a string."""
        return self._call_gemini_api(topic)

In [None]:
# 3. Context Relevance Gate (CRG)
class ContextRelevanceGate:
    """
    The Context Relevance Gate filters external knowledge to ensure it's relevant
    to both the text and the target before using it for stance detection.
    """
    def __init__(self, embedding_model=config.embedding_model, threshold=config.crg_threshold):
        """Load the sentence encoder and remember the relevance threshold.

        Args:
            embedding_model: Model name handed to ``SentenceTransformer``.
            threshold: Minimum score for knowledge to pass the gate.
        """
        self.embedding_model = SentenceTransformer(embedding_model)
        self.threshold = threshold

    def compute_similarity(self, text1, text2):
        """Compute semantic similarity between two texts using cosine similarity.

        Returns ``0.0`` when either text is empty or when encoding fails, so
        callers can treat the result as a plain float in every case.
        """
        if not text1 or not text2:
            return 0.0

        try:
            # Encode text to embedding vectors
            embeddings1 = self.embedding_model.encode(text1, convert_to_tensor=True)
            embeddings2 = self.embedding_model.encode(text2, convert_to_tensor=True)

            # Calculate cosine similarity
            similarity = util.pytorch_cos_sim(embeddings1, embeddings2).item()
            return float(similarity)
        except Exception as e:
            # Best-effort: a failed comparison is reported and scored as 0.
            print(f"Error computing similarity: {str(e)}")
            return 0.0

    def filter_knowledge(self, text, target, knowledge):
        """
        Filter knowledge based on relevance to text and target.

        The passage is kept whole when its weighted relevance reaches the
        threshold. Otherwise the single sentence closest to ``target`` is kept
        if that sentence alone exceeds the threshold.

        Returns:
            ``(filtered_knowledge, relevance_score)``. When a single sentence
            is returned, the score is still the whole-passage relevance. When
            nothing qualifies the result is ``("", 0.0)``.
        """
        # Check if knowledge is relevant to both text and target
        text_similarity = self.compute_similarity(text, knowledge)
        target_similarity = self.compute_similarity(target, knowledge)

        # Calculate weighted relevance score (target relevance is more important)
        relevance_score = (text_similarity * 0.4) + (target_similarity * 0.6)

        if relevance_score >= self.threshold:
            return knowledge, relevance_score

        # Sentence-level fallback only makes sense for actual text.
        if not isinstance(knowledge, str):
            return "", 0.0

        # Split into sentences and drop empty fragments so they are never
        # sent to the encoder.
        fragments = (piece.strip() for piece in re.split(r'[.!?]+', knowledge))
        # Score every sentence exactly once and keep the score with it.
        scored = [
            (self.compute_similarity(sentence, target), sentence)
            for sentence in fragments
            if sentence
        ]
        if scored:
            best_score, best_sentence = max(scored, key=lambda pair: pair[0])
            if best_score > self.threshold:
                return best_sentence, relevance_score

        # Nothing relevant enough.
        return "", 0.0


In [None]:
# 4. Stance Detection using Zero-Shot NLI
class StanceDetector:
    """Performs stance detection using zero-shot natural language inference"""
    def __init__(self, model_name=config.stance_model_name):
        """Load the NLI model once and wrap it in a zero-shot pipeline.

        Args:
            model_name: Checkpoint name for the tokenizer and the BART
                sequence-classification model.
        """
        # Load models and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = BartForSequenceClassification.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Hand the already-loaded model to the pipeline. Passing the model
        # name here would load a second copy of the same weights.
        self.nli_pipeline = pipeline(
            "zero-shot-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1
        )

        # Define stance labels for NLI: stance name -> full hypothesis sentence.
        self.stance_hypotheses = {
            "favor": "The text is in favor of the target.",
            "against": "The text is against the target.",
            "neutral": "The text is neutral towards the target."
        }
        # Reverse lookup, built once instead of on every prediction.
        self._hypothesis_to_stance = {
            hypothesis: stance for stance, hypothesis in self.stance_hypotheses.items()
        }

    def predict_stance(self, text, target, knowledge=""):
        """
        Predict stance using zero-shot NLI approach.

        Args:
            text: The text whose stance is being judged.
            target: The target the stance is about.
            knowledge: Optional background passage appended to the premise;
                ignored when empty or whitespace-only.

        Returns:
            ``(stance, confidence)`` where ``stance`` is a key of
            ``self.stance_hypotheses``. On any pipeline error the result is
            ``("neutral", 0.0)``.
        """
        # Prepare input with text, target, and filtered knowledge
        if knowledge and knowledge.strip():
            premise = f"Text: {text} Target: {target} Knowledge: {knowledge}"
        else:
            premise = f"Text: {text} Target: {target}"

        # Get candidate labels for NLI
        candidate_labels = list(self.stance_hypotheses.values())

        # Get zero-shot predictions
        try:
            # The candidates are already complete sentences, so use an
            # identity template. The pipeline default "This example is {}."
            # would wrap them into ungrammatical hypotheses.
            result = self.nli_pipeline(
                premise,
                candidate_labels,
                hypothesis_template="{}",
                multi_label=False,
            )

            # Map NLI results to stance labels
            top_prediction = result['labels'][0]
            stance = self._hypothesis_to_stance.get(top_prediction, "neutral")
            confidence = result['scores'][0]

            return stance, confidence
        except Exception as e:
            # Best-effort: report the failure and fall back to neutral.
            print(f"Error in stance prediction: {str(e)}")
            return "neutral", 0.0

In [None]:
# 5. Main Pipeline
class StanceDetectionPipeline:
    """End-to-end pipeline for stance detection with knowledge filtering"""

    # Environment variable consulted when no API key is passed explicitly.
    API_KEY_ENV_VAR = "GEMINI_API_KEY"

    def __init__(self, gemini_api_key=None):
        """Build the retriever, the relevance gate and the stance detector.

        Args:
            gemini_api_key: Gemini API key. When omitted, the key is read from
                the ``GEMINI_API_KEY`` environment variable so that no secret
                has to live in the source.

        Raises:
            ValueError: If no API key is available from either place.
        """
        self.knowledge_retriever = KnowledgeRetriever(self._resolve_api_key(gemini_api_key))
        self.crg = ContextRelevanceGate()
        self.stance_detector = StanceDetector()

    @staticmethod
    def _resolve_api_key(explicit_key=None):
        """Return ``explicit_key`` or the environment's key; raise if neither is set."""
        key = explicit_key or os.environ.get(StanceDetectionPipeline.API_KEY_ENV_VAR)
        if not key:
            raise ValueError(
                "No Gemini API key provided: pass gemini_api_key or set the "
                f"{StanceDetectionPipeline.API_KEY_ENV_VAR} environment variable"
            )
        return key

    def process_single(self, text, target):
        """
        Process a single text-target pair for stance detection.

        Returns:
            A dict with the original ``text`` and ``target``, the predicted
            ``stance`` and ``confidence``, the raw ``knowledge``, the
            ``filtered_knowledge`` that passed the gate, and the
            ``relevance_score``.
        """
        # Step 1: Preprocess input
        clean_text = preprocess_text(text)
        clean_target = preprocess_text(target)

        # Step 2: Retrieve knowledge about the target
        knowledge = self.knowledge_retriever.retrieve(clean_target)

        # Step 3: Filter knowledge using CRG
        filtered_knowledge, relevance_score = self.crg.filter_knowledge(clean_text, clean_target, knowledge)

        # Step 4: Predict stance
        stance, confidence = self.stance_detector.predict_stance(clean_text, clean_target, filtered_knowledge)

        return {
            "text": text,
            "target": target,
            "stance": stance,
            "confidence": confidence,
            "knowledge": knowledge,
            "filtered_knowledge": filtered_knowledge,
            "relevance_score": relevance_score
        }

    def process_dataset(self, df, text_col='Tweet', target_col='Target', stance_col='Stance'):
        """
        Process an entire dataset for stance detection.

        Rows whose cleaned text or target is empty are skipped. When the
        stance column exists, each result also carries ``true_stance``,
        ``true_label`` and ``predicted_label``, and metrics are printed.

        Returns:
            A DataFrame with one row per processed sample.
        """
        # Prepare dataset
        processed_df = prepare_dataset(df, text_col, target_col, stance_col)
        has_labels = bool(stance_col) and stance_col in df.columns

        # Track results
        results = []

        # Process each sample
        for idx, row in tqdm(processed_df.iterrows(), total=len(processed_df), desc="Processing samples"):
            text = row['processed_text']
            target = row['processed_target']

            # Skip empty entries
            if not text or not target:
                continue

            # Process sample
            result = self.process_single(text, target)

            # Add true stance if available
            if has_labels:
                result["true_stance"] = row[stance_col]
                # Use the label prepare_dataset computed: it falls back to
                # neutral for unknown spellings, whereas re-mapping the raw
                # value would yield NaN and break the metric functions.
                result["true_label"] = int(row['label'])

            results.append(result)

        # Convert to DataFrame
        results_df = pd.DataFrame(results)

        # Calculate metrics if true labels available
        if "true_label" in results_df.columns:
            results_df['predicted_label'] = results_df['stance'].map(config.stance_map_to_num)

            # Calculate and print metrics
            self._calculate_metrics(results_df)

        return results_df

    def _calculate_metrics(self, df):
        """Calculate and display performance metrics.

        Prints accuracy, weighted F1 and a classification report, then shows
        a confusion matrix and two distribution plots.

        Returns:
            A dict with ``accuracy``, ``f1``, ``report`` and
            ``confusion_matrix``, or ``None`` when metrics cannot be computed.
        """
        # Ensure we have the required columns
        if 'predicted_label' not in df.columns or 'true_label' not in df.columns:
            print("Cannot calculate metrics: missing predicted or true labels")
            return None

        # Rows without a usable label (e.g. NaN read back from a CSV) cannot
        # be scored and would make the metric functions raise.
        scored = df.dropna(subset=['true_label', 'predicted_label'])
        dropped = len(df) - len(scored)
        if dropped:
            print(f"Skipping {dropped} rows with missing labels")
        if scored.empty:
            print("Cannot calculate metrics: no labelled samples")
            return None

        y_true = scored['true_label'].astype(int)
        y_pred = scored['predicted_label'].astype(int)

        # Pin the full label set so the report and matrix keep three classes
        # even when a class is absent from this sample; without it
        # ``target_names`` may not match the classes actually present.
        label_ids = sorted(config.stance_map_to_text)
        label_names = [config.stance_map_to_text[i] for i in label_ids]

        # Calculate metrics
        accuracy = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='weighted')
        report = classification_report(
            y_true,
            y_pred,
            labels=label_ids,
            target_names=label_names,
            zero_division=0
        )
        conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

        print("\nStance Detection Results:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print("Classification Report:")
        print(report)

        # Plot confusion matrix
        plt.figure(figsize=(10, 8))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                   xticklabels=label_names,
                   yticklabels=label_names)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.title('Confusion Matrix - Stance Detection with CRG')
        plt.show()

        # Analyze CRG impact
        plt.figure(figsize=(10, 6))
        sns.boxplot(x='stance', y='relevance_score', data=scored)
        plt.title('Knowledge Relevance Score Distribution by Predicted Stance')
        plt.xlabel('Predicted Stance')
        plt.ylabel('CRG Relevance Score')
        plt.show()

        # Analyze confidence distribution
        plt.figure(figsize=(10, 6))
        sns.histplot(data=scored, x='confidence', hue='stance', bins=20, kde=True)
        plt.title('Confidence Score Distribution by Predicted Stance')
        plt.xlabel('Confidence Score')
        plt.ylabel('Frequency')
        plt.show()

        return {
            "accuracy": accuracy,
            "f1": f1,
            "report": report,
            "confusion_matrix": conf_matrix
        }


In [None]:
# 6. Running the System
def main():
    """Main function to run the stance detection pipeline.

    When the dataset file exists, a random sample of up to 100 rows is
    processed and the results are written to a CSV in the working directory.
    Otherwise a single built-in example is processed and printed.
    """
    print("Initializing Stance Detection Pipeline with Context Relevance Gate...")
    # Named ``detector`` so the imported ``pipeline`` factory is not shadowed.
    detector = StanceDetectionPipeline()

    dataset_path = '/kaggle/input/data12/VAST_test.csv'
    results_path = 'stance_detection_results1.csv'

    if os.path.exists(dataset_path):
        print(f"Loading dataset from {dataset_path}")
        df = load_data(dataset_path)
        df = df.sample(frac=1).head(100)

        # process_dataset prints metrics itself when gold labels are present,
        # so the results do not need to be read back and evaluated again.
        results = detector.process_dataset(df)

        # Save results
        results.to_csv(results_path, index=False)
        print(f"Results saved to '{results_path}'")
    else:
        print("No dataset file found. Running example...")
        # Example usage
        example_text = "Nuclear energy is the only feasible solution to combat climate change at scale."
        example_target = "Nuclear energy"

        result = detector.process_single(example_text, example_target)

        # This branch writes no results file, so there is nothing to evaluate.
        print("\nExample Stance Detection:")
        print(f"Text: {result['text']}")
        print(f"Target: {result['target']}")
        print(f"Retrieved knowledge: {result['knowledge']}")
        print(f"Filtered knowledge: {result['filtered_knowledge']}")
        print(f"Relevance score: {result['relevance_score']:.4f}")
        print(f"Predicted stance: {result['stance']} (confidence: {result['confidence']:.4f})")

In [None]:
# 7. Interactive Demo
def interactive_demo():
    """Run a read-eval-print loop for trying stance detection by hand.

    Repeatedly asks for a text and a target, runs them through the pipeline
    and prints the outcome. Typing ``quit`` at either prompt ends the loop.
    """
    print("Initializing Stance Detection Pipeline...")
    detector = StanceDetectionPipeline()

    for banner_line in (
        "\nZero-Shot Stance Detection with Context Relevance Gate",
        "======================================================",
        "Enter text and target to detect stance (type 'quit' to exit)",
    ):
        print(banner_line)

    def ask(prompt):
        # None signals that the user asked to leave.
        answer = input(prompt)
        return None if answer.lower() == 'quit' else answer

    while True:
        text = ask("\nEnter text: ")
        if text is None:
            break

        target = ask("Enter target: ")
        if target is None:
            break

        print("\nProcessing...")
        outcome = detector.process_single(text, target)

        summary = (
            "\nResults:",
            f"Text: {outcome['text']}",
            f"Target: {outcome['target']}",
            f"Retrieved knowledge: {outcome['knowledge']}",
            f"Filtered knowledge: {outcome['filtered_knowledge']}",
            f"Relevance score: {outcome['relevance_score']:.4f}",
            f"Predicted stance: {outcome['stance']} (confidence: {outcome['confidence']:.4f})",
        )
        for line in summary:
            print(line)

# Run the main pipeline or interactive demo
if __name__ == "__main__":
    # Menu shown to the user; the answer selects which entry point runs.
    for menu_line in (
        "Running Stance Detection System...",
        "1. Process dataset",
        "2. Run interactive demo",
    ):
        print(menu_line)

    entry_points = {"1": main, "2": interactive_demo}
    choice = input("Enter your choice (1/2): ")
    selected = entry_points.get(choice)

    if selected is None:
        print("Invalid choice. Exiting.")
    else:
        selected()

Running Stance Detection System...
1. Process dataset
2. Run interactive demo
Enter your choice (1/2): 2
Initializing Stance Detection Pipeline...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Device set to use cpu



Zero-Shot Stance Detection with Context Relevance Gate
Enter text and target to detect stance (type 'quit' to exit)

Enter text: nuzvid is famous for mangoes
Enter target: nuzvid

Processing...

Results:
Text: nuzvid is famous for mangoes
Target: nuzvid
Retrieved knowledge: Nuzvid is a city in the Krishna district of the Indian state of Andhra Pradesh. It serves as the headquarters of the Nuzvid mandal. The city is known for its mango production and educational institutions, including the Rajiv Gandhi University of Knowledge Technologies (RGUKT), also known as IIIT Nuzvid. Agriculture and education are key economic activities in the region. Nuzvid is located near National Highway 65 and is accessible by road and rail.

Filtered knowledge: Nuzvid is a city in the Krishna district of the Indian state of Andhra Pradesh. It serves as the headquarters of the Nuzvid mandal. The city is known for its mango production and educational institutions, including the Rajiv Gandhi University of Know