In [1]:
import os
os.chdir("/Users/naveenkumar/Desktop/formula-1-bot")
%pwd

'/Users/naveenkumar/Desktop/formula-1-bot'

# Intent Bot

In [2]:
import torch
from transformers import BartTokenizer, BartForSequenceClassification
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
import logging
from typing import List, Tuple, Dict, Any
import re
import spacy

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class F1IntentDataset(Dataset):
    """Custom dataset for F1 intent classification"""
    
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(label, dtype=torch.float)
        }

class F1NER:
    """Custom NER for F1 entities (drivers, teams, sessions, etc.)"""
    
    def __init__(self):
        # Load spaCy model
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            print("Installing spaCy model...")
            import subprocess
            subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
            self.nlp = spacy.load("en_core_web_sm")
        
        # F1-specific entity patterns
        self.f1_entities = {
            "DRIVER": {
                "verstappen": "Max VERSTAPPEN",
                "hamilton": "Lewis HAMILTON",
                "leclerc": "Charles LECLERC",
                "sainz": "Carlos SAINZ",
                "norris": "Lando NORRIS",
                "russell": "George RUSSELL",
                "alonso": "Fernando ALONSO",
                "albon": "Alexander ALBON",
                "tsunoda": "Yuki TSUNODA",
                "hulkenberg": "Nico HULKENBERG",
                "lawson": "Liam LAWSON",
                "antonelli": "Andrea Kimi ANTONELLI",
                "bortoleto": "Gabriel BORTOLETO",
                "hadjar": "Isack HADJAR",
                "max": "Max VERSTAPPEN",
                "lewis": "Lewis HAMILTON",
                "charles": "Charles LECLERC",
                "carlos": "Carlos SAINZ",
                "lando": "Lando NORRIS",
                "george": "George RUSSELL",
                "fernando": "Fernando ALONSO",
                "alexander": "Alexander ALBON",
                "yuki": "Yuki TSUNODA",
                "nico": "Nico HULKENBERG",
                "liam": "Liam LAWSON"
            },
            "TEAM": {
                "red bull": "Red Bull Racing",
                "ferrari": "Ferrari",
                "mercedes": "Mercedes",
                "mclaren": "McLaren",
                "aston martin": "Aston Martin",
                "williams": "Williams",
                "racing bulls": "Racing Bulls",
                "kick sauber": "Kick Sauber",
                "sauber": "Kick Sauber",
                "bull": "Red Bull Racing"
            },
            "SESSION": {
                "race": "Race",
                "qualifying": "Qualifying",
                "qualifying results": "Qualifying",
                "practice": "Practice",
                "practice 1": "Practice 1",
                "practice 2": "Practice 2",
                "practice 3": "Practice 3",
                "sprint": "Sprint",
                "sprint race": "Sprint"
            },
            "METRIC": {
                "race pace": "lap_times",
                "lap times": "lap_times",
                "pit stops": "pit_stops",
                "pit stop": "pit_stops",
                "tire strategy": "tire_strategy",
                "tire compound": "tire_strategy",
                "qualifying results": "qualifying_results",
                "race results": "race_results",
                "fastest lap": "fastest_laps",
                "position": "position_changes",
                "positions": "position_changes",
                "weather": "weather_conditions",
                "temperature": "weather_conditions",
                "safety car": "race_control",
                "incident": "race_control"
            },
            "TIME_CONTEXT": {
                "last race": "recent",
                "recent": "recent",
                "latest": "recent",
                "this season": "season",
                "season": "season",
                "overall": "season",
                "total": "season",
                "today": "recent",
                "yesterday": "recent"
            }
        }
    
    def extract_entities(self, text: str) -> Dict[str, Any]:
        """Extract F1 entities from text"""
        text_lower = text.lower()
        entities = {
            "drivers": [],
            "teams": [],
            "sessions": [],
            "metrics": [],
            "time_context": "recent",  # default
            "raw_text": text
        }
        
        # Extract drivers
        for driver_key, driver_full in self.f1_entities["DRIVER"].items():
            if driver_key in text_lower:
                entities["drivers"].append(driver_full)
        
        # Extract teams
        for team_key, team_full in self.f1_entities["TEAM"].items():
            if team_key in text_lower:
                entities["teams"].append(team_full)
        
        # Extract sessions
        for session_key, session_full in self.f1_entities["SESSION"].items():
            if session_key in text_lower:
                entities["sessions"].append(session_full)
        
        # Extract metrics
        for metric_key, metric_full in self.f1_entities["METRIC"].items():
            if metric_key in text_lower:
                entities["metrics"].append(metric_full)
        
        # Extract time context
        for time_key, time_value in self.f1_entities["TIME_CONTEXT"].items():
            if time_key in text_lower:
                entities["time_context"] = time_value
                break
        
        return entities
    
    def extract_implicit_intents(self, entities: Dict[str, Any]) -> List[str]:
        """Extract implicit intents from entities"""
        intents = []
        
        # Map metrics to intents
        metric_to_intent = {
            "lap_times": "driver_performance",
            "pit_stops": "pit_stops",
            "tire_strategy": "tire_strategy",
            "qualifying_results": "qualifying_results",
            "race_results": "race_results",
            "fastest_laps": "fastest_laps",
            "position_changes": "position_changes",
            "weather_conditions": "weather_conditions",
            "race_control": "race_control"
        }
        
        for metric in entities["metrics"]:
            if metric in metric_to_intent:
                intents.append(metric_to_intent[metric])
        
        # If no metrics found but drivers mentioned, assume driver performance
        if not intents and entities["drivers"]:
            intents.append("driver_performance")
        
        # If no metrics found but teams mentioned, assume team performance
        if not intents and entities["teams"]:
            intents.append("team_performance")
        
        return list(set(intents))  # Remove duplicates

class EnhancedBARTIntentClassifier:
    """Enhanced BART-based intent classifier with NER and multi-intent support"""
    
    def __init__(self, model_name: str = "facebook/bart-base"):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.mlb = MultiLabelBinarizer()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.ner = F1NER()
        self.intent_categories = {
            "race_results": "Queries about race winners, positions, and final results",
            "qualifying_results": "Queries about qualifying sessions, pole positions, and grid order",
            "fastest_laps": "Queries about fastest lap times and lap records",
            "driver_performance": "Queries about individual driver performance and statistics",
            "team_performance": "Queries about team performance and comparisons",
            "position_changes": "Queries about position gains/losses and overtakes",
            "tire_strategy": "Queries about tire compounds, stints, and tire strategy",
            "pit_stops": "Queries about pit stop timing, duration, and strategy",
            "lap_times": "Queries about lap times, sector times, and consistency",
            "weather_conditions": "Queries about weather, temperature, and track conditions",
            "race_control": "Queries about incidents, flags, and race control decisions",
            "meeting_schedule": "Queries about race dates, schedules, and event information",
            "multi_table_query": "Complex queries requiring data from multiple tables",
            "general_inquiry": "General or ambiguous questions"
        }
    
    def load_model(self):
        """Load BART model and tokenizer"""
        try:
            logger.info(f"Loading BART model: {self.model_name}")
            self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
            self.model = BartForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.intent_categories),
                problem_type="multi_label_classification"
            )
            self.model.to(self.device)
            logger.info(f"BART model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Error loading BART model: {e}")
            raise
    
    def parse_intent_string(self, intent_str: str) -> List[str]:
        """Parse intent string, handling both single and multi-label formats"""
        intent_str = str(intent_str).strip()
        
        # Remove quotes if present
        intent_str = intent_str.strip('"\'')
        
        # Split by comma and clean up
        if ',' in intent_str:
            intents = [intent.strip() for intent in intent_str.split(',')]
        else:
            intents = [intent_str]
        
        # Clean up any empty strings
        intents = [intent for intent in intents if intent]
        
        return intents
    
    def load_dataset_from_csv(self, csv_path: str) -> Tuple[List[str], List[List[str]]]:
        """Load and prepare dataset from CSV with proper parsing"""
        try:
            logger.info(f"Loading dataset from {csv_path}")
            df = pd.read_csv(csv_path)
            
            texts = []
            labels = []
            multi_label_count = 0
            
            for _, row in df.iterrows():
                question = row['question']
                intent_str = row['intent']
                
                # Parse intents
                intents = self.parse_intent_string(intent_str)
                
                # Count multi-label examples
                if len(intents) > 1:
                    multi_label_count += 1
                    logger.info(f"Multi-label example: {question} -> {intents}")
                
                texts.append(question)
                labels.append(intents)
            
            logger.info(f"Loaded {len(texts)} examples from CSV")
            logger.info(f"Found {multi_label_count} multi-label examples")
            
            # Show intent distribution
            all_intents = []
            for intent_list in labels:
                all_intents.extend(intent_list)
            
            intent_counts = pd.Series(all_intents).value_counts()
            logger.info(f"Intent distribution: {intent_counts.to_dict()}")
            
            return texts, labels
            
        except Exception as e:
            logger.error(f"Error loading CSV: {e}")
            raise
    
    def train(self, csv_path: str, epochs: int = 3, batch_size: int = 8, learning_rate: float = 2e-5):
        """Train the BART model"""
        logger.info("Starting Enhanced BART intent classifier training...")
        
        # Load model
        self.load_model()
        
        # Load dataset
        texts, labels = self.load_dataset_from_csv(csv_path)
        
        # Prepare labels
        self.mlb.fit(labels)
        y = self.mlb.transform(labels)
        
        logger.info(f"Number of unique intents: {len(self.mlb.classes_)}")
        logger.info(f"Intent classes: {list(self.mlb.classes_)}")
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            texts, y, test_size=0.2, random_state=42
        )
        
        # Create datasets
        train_dataset = F1IntentDataset(X_train, y_train, self.tokenizer)
        test_dataset = F1IntentDataset(X_test, y_test, self.tokenizer)
        
        # Create dataloaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)
        
        # Setup training
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        criterion = torch.nn.BCEWithLogitsLoss()
        
        # Training loop
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0
            for batch in train_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(outputs.logits, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            avg_loss = total_loss / len(train_loader)
            logger.info(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
        
        # Evaluate
        self.model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                predictions = torch.sigmoid(outputs.logits) > 0.5
                correct += (predictions == labels).all(dim=1).sum().item()
                total += labels.size(0)
        
        accuracy = correct / total
        logger.info(f"Test Accuracy: {accuracy:.3f}")
        
        return accuracy
    
    def classify_intent(self, question: str, threshold: float = 0.3) -> Tuple[List[str], float]:
        """Classify intent using BART model (original method for backward compatibility)"""
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not trained. Call train() first.")
        
        self.model.eval()
        
        # Tokenize input
        encoding = self.tokenizer(
            question,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        # Get predictions
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = torch.sigmoid(outputs.logits)
        
        # Get intents above threshold
        predictions = (probabilities > threshold).squeeze()
        predicted_indices = torch.where(predictions)[0].cpu().numpy()
        
        # Convert to intent names
        predicted_intents = [self.mlb.classes_[idx] for idx in predicted_indices]
        
        # Calculate confidence (average of predicted probabilities)
        if len(predicted_indices) > 0:
            confidence = probabilities[0][predicted_indices].mean().item()
        else:
            confidence = 0.0
        
        return predicted_intents, confidence
    
    def classify_intent_with_ner(self, question: str, threshold: float = 0.2) -> Dict[str, Any]:
        """Enhanced classification with NER extraction"""
        
        # Extract entities using NER
        entities = self.ner.extract_entities(question)
        
        # Get explicit intents from BART model
        explicit_intents, confidence = self.classify_intent(question, threshold)
        
        # Get implicit intents from NER
        implicit_intents = self.ner.extract_implicit_intents(entities)
        
        # Combine intents (remove duplicates)
        all_intents = list(set(explicit_intents + implicit_intents))
        
        # If no intents found, try semantic search
        if not all_intents:
            semantic_intents = self._semantic_search(question)
            all_intents = semantic_intents
            confidence = 0.3
        
        return {
            "intents": all_intents,
            "confidence": confidence,
            "entities": entities,
            "explicit_intents": explicit_intents,
            "implicit_intents": implicit_intents
        }
    
    def _semantic_search(self, question: str) -> List[str]:
        """Enhanced semantic search with NER"""
        # Use NER to extract potential intents
        entities = self.ner.extract_entities(question)
        implicit_intents = self.ner.extract_implicit_intents(entities)
        
        if implicit_intents:
            return implicit_intents
        
        # Fallback to keyword matching
        question_lower = question.lower()
        
        intent_keywords = {
            "race_results": ["won", "winner", "race", "result", "podium"],
            "qualifying_results": ["qualify", "qualifying", "grid", "pole"],
            "driver_performance": ["perform", "pace", "driver", "lap"],
            "team_performance": ["team", "constructors"],
            "pit_stops": ["pit", "stop", "pitstop"],
            "tire_strategy": ["tire", "tyre", "compound", "strategy"],
            "weather_conditions": ["weather", "rain", "temperature", "wet"],
            "fastest_laps": ["fastest", "lap", "record"],
            "meeting_schedule": ["when", "next", "schedule", "date"],
            "race_control": ["safety", "car", "incident", "flag"]
        }
        
        detected_intents = []
        for intent, keywords in intent_keywords.items():
            if any(keyword in question_lower for keyword in keywords):
                detected_intents.append(intent)
        
        return detected_intents
    
    def save_model(self, filepath: str):
        """Save the trained model with proper serialization"""
        try:
            model_data = {
                'model_state_dict': self.model.state_dict(),
                'tokenizer': self.tokenizer,
                'mlb': self.mlb,
                'intent_categories': self.intent_categories,
                'model_name': self.model_name
            }
            
            # Save with proper serialization
            torch.save(model_data, filepath, _use_new_zipfile_serialization=False)
            logger.info(f"Model saved to {filepath}")
            
        except Exception as e:
            logger.error(f"Error saving model: {e}")
            raise
    
    def load_model_from_file(self, filepath: str):
        """Load a trained model with proper error handling"""
        try:
            # Try loading with weights_only=False for backward compatibility
            model_data = torch.load(filepath, map_location=self.device, weights_only=False)
            
            self.model_name = model_data['model_name']
            self.tokenizer = model_data['tokenizer']
            self.mlb = model_data['mlb']
            self.intent_categories = model_data['intent_categories']
            
            # Load model
            self.load_model()
            self.model.load_state_dict(model_data['model_state_dict'])
            
            logger.info(f"Model loaded from {filepath}")
            
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Create and train BART classifier
enhanced_classifier = EnhancedBARTIntentClassifier()
accuracy = enhanced_classifier.train("research/f1_intent_dataset.csv", epochs=3)
print(f"BART Training completed with accuracy: {accuracy:.3f}")

# Save the model
save_path = "artifacts/models/bart_intent_classifier.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
enhanced_classifier.save_model(save_path)

In [None]:
# Test questions
test_questions = [
    "What was the qualifying result and who won the race?",
    "Tell me about Hamilton's race pace and how many pit stops he made",
    "How did the weather affect Ferrari's tire strategy?",
    "What were the lap times and pit stop durations for Red Bull?",
    "What happened in the race and were there any safety cars?",
    "Tell me about qualifying results, fastest laps",
    "Tell me about driver performance, weather conditions, qualifying results",
]

print("\n=== Testing Basic BART Classification ===")
for question in test_questions:
    intents, confidence = enhanced_classifier.classify_intent(question, threshold=0.2)
    print(f"Question: {question}")
    print(f"Intents: {intents}")
    print(f"Number of intents: {len(intents)}")
    print(f"Confidence: {confidence:.3f}")
    print("-" * 50)

print("\n=== Testing Enhanced NER Classification ===")
for question in test_questions:
    result = enhanced_classifier.classify_intent_with_ner(question, threshold=0.2)
    print(f"Question: {question}")
    print(f"Intents: {result['intents']}")
    print(f"Entities: {result['entities']}")
    print(f"Confidence: {result['confidence']:.3f}")
    print("-" * 50)

# ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Generate SQL Queries

In [25]:
import torch
from typing import Optional

class HybridF1QueryGenerator:
    """Hybrid query generator that combines predefined queries with Ollama generation"""
    
    def __init__(self):
        self.ner = F1NER()
        self.predefined_generator = EnhancedF1QueryGenerator()  # Keep existing logic
        self.db_schema = self._load_db_schema()
        
        self.nl2sql_generator = FineTunedNL2SQLGenerator()
        self.nl2sql_generator.load_model()

    def _try_nl2sql_model(self, question: str, entities: Dict, intents: List[str]) -> Optional[Dict[str, Any]]:
        """Try to generate SQL using the pre-trained NL2SQL model"""
        try:
            generated_sql = self.nl2sql_generator.generate_sql(question)
            
            if self.nl2sql_generator.validate_sql(generated_sql):
                print("🤖 Using Pre-trained NL2SQL model")
                return {
                    "type": "pretrained_nl2sql",
                    "queries": [{
                        "intent": intents[0] if intents else "pretrained_nl2sql",
                        "query": generated_sql,
                        "filters": {"source": "pretrained_model"}
                    }],
                    "intents": intents,
                    "entities": entities,
                    "filters": {"source": "pretrained_model"}
                }
            else:
                print("⚠️ Pre-trained model generated invalid SQL")
                return None
                
        except Exception as e:
            print(f"Pre-trained model failed: {e}")
            return None
    
    
    # MODIFY THIS EXISTING METHOD (add NL2SQL as first option)
    def generate_dynamic_query(self, question: str, classification_result: Dict[str, Any]) -> Dict[str, Any]:
        """Generate query using hybrid approach"""
        
        intents = classification_result["intents"]
        entities = classification_result["entities"]
        
        # ADD THIS: Try NL2SQL model first
        if self.nl2sql_model:
            nl2sql_result = self._try_nl2sql_model(question, entities, intents)
            if nl2sql_result and self._validate_nl2sql_result(nl2sql_result):
                print(" Using NL2SQL model")
                return nl2sql_result
        
        # Keep your existing logic
        if self._can_use_predefined_query(question, entities, intents):
            print("🔧 Using predefined query")
            return self.predefined_generator.generate_dynamic_query(question, classification_result)
        
        print("🤖 Using Ollama-generated query")
        return self._generate_ollama_query(question, entities, intents, classification_result)
    
    # def _prepare_model_input(self, question: str, entities: Dict, intents: List[str]) -> str:
    #     """Prepare input for the NL2SQL model"""
    #     schema_text = self._format_schema_for_model()
    #     return f"Question: {question}\nSchema: {schema_text}\nGenerate SQL:"
    
    # def _format_schema_for_model(self) -> str:
    #     """Format schema for model input"""
    #     schema_lines = []
    #     for table_name, columns in self.schema["tables"].items():
    #         schema_lines.append(f"- Table: {table_name} ({', '.join(columns)})")
    #     return "\n".join(schema_lines)
        
    def _load_db_schema(self):
        """Load database schema for Ollama - ENHANCED VERSION"""
        return """
        -- Formula 1 Database Schema (Detailed)
        
        -- Main Tables with Column Details:
        
        -- drivers_transformed
        -- Columns: id, session_key, meeting_key, driver_number, full_name, team_name, created_at, team_name_encoded
        -- Primary key: (driver_number, session_key)
        
        -- positions_transformed  
        -- Columns: id, session_key, meeting_key, driver_number, position, date, created_at, position_change, position_std, is_leader, position_improved, position_declined, is_outlier
        -- Primary key: (driver_number, session_key)
        
        -- laps_transformed
        -- Columns: id, session_key, meeting_key, driver_number, lap_number, lap_duration, duration_sector_1, duration_sector_2, duration_sector_3, is_pit_out_lap, created_at, lap_time_std, lap_time_mean, lap_time_deviation, total_sector_time, sector_consistency, had_incident, safety_car_lap, is_outlier
        -- Primary key: (driver_number, session_key, lap_number)
        -- Note: is_outlier is BOOLEAN (true/false), not integer
        
        -- sessions_transformed
        -- Columns: session_key, meeting_key, session_name, session_type, date_start, date_end, created_at, session_type_encoded
        -- Primary key: session_key
        
        -- meetings
        -- Columns: meeting_key, meeting_name, country_name, circuit_short_name, date_start, year, created_at
        -- Primary key: meeting_key
        
        -- Additional Tables:
        
        -- pit_stops_transformed
        -- Columns: id, session_key, meeting_key, driver_number, lap_number, pit_duration, created_at, pit_stop_count, pit_stop_timing, normal_pit_stop, long_pit_stop, penalty_pit_stop, is_outlier
        -- Primary key: (driver_number, session_key, lap_number)
        
        -- stints_transformed
        -- Columns: id, session_key, meeting_key, driver_number, compound, lap_start, lap_end, tyre_age_at_start, created_at, stint_duration, tire_age_progression, is_outlier
        -- Primary key: (driver_number, session_key, lap_start)
        
        -- weather_transformed
        -- Columns: id, session_key, meeting_key, air_temperature, track_temperature, humidity, rainfall, date, created_at, temperature_delta, weather_severity, extreme_weather
        -- Primary key: (session_key, meeting_key)
        -- Note: rainfall, extreme_weather are BOOLEAN (true/false)
        
        -- intervals_transformed
        -- Columns: id, session_key, meeting_key, driver_number, gap_to_leader, interval, date, created_at, is_leader, is_lapped, is_outlier
        -- Primary key: (driver_number, session_key)
        -- Note: is_leader, is_lapped, is_outlier are BOOLEAN (true/false)
        
        -- race_control
        -- Columns: id, session_key, meeting_key, driver_number, category, flag, lap_number, message, scope, sector, date, created_at
        -- Primary key: (driver_number, session_key, lap_number)
        
        -- Key Relationships:
        -- All tables join on (driver_number, session_key) or (session_key, meeting_key)
        -- session_type can be 'Race', 'Qualifying', 'Practice 1', 'Practice 2', 'Practice 3', 'Sprint'
        -- Boolean fields use true/false, not 1/0
        -- Use proper table aliases: d for drivers_transformed, p for positions_transformed, l for laps_transformed, s for sessions_transformed, m for meetings
        """
    
    def generate_dynamic_query(self, question: str, classification_result: Dict[str, Any]) -> Dict[str, Any]:
        """Generate query using hybrid approach"""
        
        intents = classification_result["intents"]
        entities = classification_result["entities"]
        
        # Check if we can use predefined query
        if self._can_use_predefined_query(question, entities, intents):
            print("🔧 Using predefined query")
            return self.predefined_generator.generate_dynamic_query(question, classification_result)
        
        # Use Ollama for complex queries
        print("🤖 Using Ollama-generated query")
        return self._generate_ollama_query(question, entities, intents, classification_result)
    
    def _can_use_predefined_query(self, question: str, entities: Dict, intents: List[str]) -> bool:
        """Determine if we can use a predefined query"""
        
        # Simple heuristics for when to use predefined queries
        simple_patterns = [
            "who won", "who got", "what position", "qualifying results", 
            "race results", "fastest lap", "pit stops", "weather",
            "how did", "team performance", "driver performance"
        ]
        
        question_lower = question.lower()
        
        # If it's a simple pattern and we have clear entities/intents
        if any(pattern in question_lower for pattern in simple_patterns):
            if entities.get('drivers') or entities.get('teams') or entities.get('sessions'):
                return True
        
        # If we have multiple intents, it might be complex
        if len(intents) > 2:
            return False
        
        # If question is very long or complex
        if len(question.split()) > 15:
            return False
        
        return True
    
    def _generate_ollama_query(self, question: str, entities: Dict, intents: List[str], classification_result: Dict[str, Any]) -> Dict[str, Any]:
        """Generate query using Ollama - ENHANCED VERSION"""
        
        ollama_prompt = f"""
        You are an expert SQL query generator for Formula 1 data. 
        
        Database Schema:
        {self.db_schema}
        
        Question: {question}
        Extracted Entities: {entities}
        Detected Intents: {intents}
        
        Generate a safe, efficient SQL query that answers this question.
        
        Rules:
        1. Only use SELECT statements (no INSERT, UPDATE, DELETE)
        2. Always include proper JOINs between tables
        3. Use appropriate WHERE clauses for filtering
        4. Limit results to reasonable amounts (max 50 rows)
        5. Return ONLY the SQL query, no explanations, no markdown, no notes
        6. Use these table aliases: d for drivers_transformed, p for positions_transformed, l for laps_transformed, s for sessions_transformed, m for meetings, ps for pit_stops_transformed, st for stints_transformed, w for weather_transformed, i for intervals_transformed, rc for race_control
        7. Include relevant columns like driver_name, team_name, position, lap_times, etc.
        8. Boolean fields use true/false, not 1/0 (e.g., l.is_outlier = false, not l.is_outlier = 0)
        9. Do not add any explanatory text or notes after the query
        10. Always use proper JOIN syntax with ON clauses
        
        SQL Query:
        """
        
        try:
            response = requests.post(
                "http://localhost:11434/api/generate",
                json={
                    "model": "llama3",
                    "prompt": ollama_prompt,
                    "stream": False
                }
            )
            query = response.json().get("response", "").strip()
            
            # Clean up the query (remove markdown, extra text)
            query = self._clean_generated_query(query)
            
            # Validate safety
            if self._is_safe_query(query):
                return {
                    "type": "ollama_generated",
                    "queries": [{
                        "intent": "ollama_generated",
                        "query": query,
                        "filters": {"source": "ollama"}
                    }],
                    "intents": intents,
                    "entities": entities,
                    "filters": {"source": "ollama"}
                }
            else:
                # Fallback to predefined query
                print("⚠️ Ollama query failed safety check, using predefined fallback")
                return self.predefined_generator.generate_dynamic_query(question, classification_result)
                
        except Exception as e:
            print(f"❌ Ollama query generation failed: {e}")
            # Fallback to predefined query
            return self.predefined_generator.generate_dynamic_query(question, classification_result)
    
    def _clean_generated_query(self, query: str) -> str:
        """Clean up generated query - ENHANCED VERSION"""
        # Remove markdown code blocks
        if query.startswith("```sql"):
            query = query[6:]
        if query.endswith("```"):
            query = query[:-3]
        
        # Remove common Ollama response patterns
        lines = query.split('\n')
        cleaned_lines = []
        
        for line in lines:
            line = line.strip()
            # Skip lines that are explanations, notes, or comments
            if (line.startswith('Note:') or 
                line.startswith('--') or 
                line.startswith('#') or
                line.startswith('/*') or
                line.endswith('*/') or
                line.lower().startswith('the above query') or
                line.lower().startswith('this query') or
                line.lower().startswith('assumes') or
                line.lower().startswith('please note') or
                line.lower().startswith('important:') or
                line.lower().startswith('note that') or
                line == ''):
                continue
            cleaned_lines.append(line)
        
        # Join lines back together
        query = '\n'.join(cleaned_lines)
        
        # Remove extra whitespace and newlines
        query = query.strip()
        
        # Ensure it ends with semicolon
        if not query.endswith(';'):
            query += ';'
        
        return query
    
    def _is_safe_query(self, query: str) -> bool:
        """Basic safety check for generated queries"""
        query_upper = query.upper()
        
        # Check for dangerous keywords
        dangerous_keywords = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'EXEC', 'EXECUTE', 'TRUNCATE']
        for keyword in dangerous_keywords:
            if keyword in query_upper:
                return False
        
        # Check for basic SELECT structure
        if not query_upper.startswith('SELECT'):
            return False
        
        # Check for reasonable length
        if len(query) > 2000:
            return False
        
        # Check for required FROM clause
        if 'FROM' not in query_upper:
            return False
        
        return True

In [27]:
import pandas as pd
from typing import List, Dict, Any, Optional
import logging

logger = logging.getLogger(__name__)

class EnhancedF1QueryGenerator:
    """Enhanced query generator with NER integration"""
    
    def __init__(self):
        self.ner = F1NER()
        self.intent_queries = {
            "race_results": self._get_race_results_query,
            "qualifying_results": self._get_qualifying_results_query,
            "fastest_laps": self._get_fastest_laps_query,
            "driver_performance": self._get_driver_performance_query,
            "team_performance": self._get_team_performance_query,
            "tire_strategy": self._get_tire_strategy_query,
            "pit_stops": self._get_pit_stops_query,
            "weather_conditions": self._get_weather_conditions_query,
            "meeting_schedule": self._get_meeting_schedule_query,
            "race_control": self._get_race_control_query,
            "position_changes": self._get_position_changes_query,
            "lap_times": self._get_lap_times_query
        }
    
    def generate_dynamic_query(self, question: str, classification_result: Dict[str, Any]) -> Dict[str, Any]:
        """Generate dynamic query based on NER and intent classification"""
        
        intents = classification_result["intents"]
        entities = classification_result["entities"]
        
        # Build dynamic filters
        filters = self._build_dynamic_filters(entities)
        
        # Generate queries for each intent
        queries = []
        for intent in intents:
            if intent in self.intent_queries:
                query = self.intent_queries[intent](filters)
                queries.append({
                    "intent": intent,
                    "query": query,
                    "filters": filters
                })
        
        return {
            "type": "multi_intent" if len(queries) > 1 else "single_intent",
            "queries": queries,
            "intents": intents,
            "entities": entities,
            "filters": filters
        }
    
    def _build_dynamic_filters(self, entities: Dict[str, Any]) -> Dict[str, Any]:
        """Build dynamic SQL filters from extracted entities"""
        filters = {
            "driver_filter": "",
            "team_filter": "",
            "session_filter": "",
            "time_filter": "",
            "meeting_filter": "",
            "limit": 10
        }
        
        # Driver filter
        if entities["drivers"]:
            driver_names = "', '".join(entities["drivers"])
            filters["driver_filter"] = f"AND d.full_name IN ('{driver_names}')"
        
        # Team filter
        if entities["teams"]:
            team_names = "', '".join(entities["teams"])
            filters["team_filter"] = f"AND d.team_name IN ('{team_names}')"
        
        # Session filter
        if entities["sessions"]:
            session_names = "', '".join(entities["sessions"])
            filters["session_filter"] = f"AND s.session_name IN ('{session_names}')"
        
        # Meeting filter (for specific races)
        if entities.get("meeting_name"):
            filters["meeting_filter"] = f"AND m.meeting_name = '{entities['meeting_name']}'"
        
        # Time filter - FIXED
        if entities["time_context"] == "recent":
            filters["time_filter"] = "AND s.date_start >= NOW() - INTERVAL '30 days'"
        elif entities["time_context"] == "season":
            filters["time_filter"] = "AND EXTRACT(YEAR FROM s.date_start) = EXTRACT(YEAR FROM NOW())"
        elif entities["time_context"] == "last_race":
            filters["time_filter"] = "AND s.date_start = (SELECT MAX(date_start) FROM sessions_transformed WHERE session_type = 'Race')"
        
        return filters
    
    def _get_race_results_query(self, filters: Dict[str, Any]) -> str:
        """Get race results query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                p.position,
                i.gap_to_leader,
                m.meeting_name,
                s.session_name,
                s.date_start
            FROM positions_transformed p
            JOIN drivers_transformed d ON p.driver_number = d.driver_number AND p.session_key = d.session_key
            JOIN sessions_transformed s ON p.session_key = s.session_key
            JOIN meetings m ON p.meeting_key = m.meeting_key
            LEFT JOIN intervals_transformed i ON p.driver_number = i.driver_number AND p.session_key = i.session_key
            WHERE s.session_type = 'Race'
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['time_filter']}
            ORDER BY s.date_start DESC, p.position
            LIMIT {filters['limit']}
        """
    
    def _get_qualifying_results_query(self, filters: Dict[str, Any]) -> str:
        """Get qualifying results query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                p.position,
                l.lap_duration as best_lap_time,
                m.meeting_name,
                s.date_start
            FROM positions_transformed p
            JOIN drivers_transformed d ON p.driver_number = d.driver_number AND p.session_key = d.session_key
            JOIN sessions_transformed s ON p.session_key = s.session_key
            JOIN meetings m ON p.meeting_key = m.meeting_key
            LEFT JOIN laps_transformed l ON p.driver_number = l.driver_number AND p.session_key = l.session_key
            WHERE s.session_type = 'Qualifying'
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['time_filter']}
            ORDER BY s.date_start DESC, p.position
            LIMIT {filters['limit']}
        """
    
    def _get_fastest_laps_query(self, filters: Dict[str, Any]) -> str:
        """Get fastest laps query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                l.lap_duration,
                l.lap_number,
                m.meeting_name,
                s.session_name,
                s.date_start
            FROM laps_transformed l
            JOIN drivers_transformed d ON l.driver_number = d.driver_number AND l.session_key = d.session_key
            JOIN sessions_transformed s ON l.session_key = s.session_key
            JOIN meetings m ON l.meeting_key = m.meeting_key
            WHERE l.is_outlier = false
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['time_filter']}
            ORDER BY l.lap_duration
            LIMIT {filters['limit']}
        """
    
    def _get_team_performance_query(self, filters: Dict[str, Any]) -> str:
        """Get team performance query with dynamic filters"""
        return f"""
            SELECT 
                d.team_name,
                AVG(l.lap_duration) as avg_lap_time,
                COUNT(l.id) as total_laps,
                AVG(p.position) as avg_position,
                COUNT(CASE WHEN i.is_leader = true THEN 1 END) as leading_laps
            FROM drivers_transformed d
            JOIN laps_transformed l ON d.driver_number = l.driver_number AND d.session_key = l.session_key
            JOIN positions_transformed p ON d.driver_number = p.driver_number AND d.session_key = p.session_key
            LEFT JOIN intervals_transformed i ON d.driver_number = i.driver_number AND d.session_key = i.session_key
            JOIN sessions_transformed s ON d.session_key = s.session_key
            WHERE 1=1
                {filters['team_filter']} 
                {filters['time_filter']}
            GROUP BY d.team_name
            ORDER BY avg_lap_time
            LIMIT {filters['limit']}
        """
    
    def _get_tire_strategy_query(self, filters: Dict[str, Any]) -> str:
        """Get tire strategy query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                st.compound,
                COUNT(st.id) as stint_count,
                AVG(st.stint_duration) as avg_stint_duration,
                SUM(st.stint_duration) as total_stint_time
            FROM stints_transformed st
            JOIN drivers_transformed d ON st.driver_number = d.driver_number AND st.session_key = d.session_key
            JOIN sessions_transformed s ON st.session_key = s.session_key
            WHERE s.session_type = 'Race'
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['time_filter']}
            GROUP BY d.full_name, d.team_name, st.compound
            ORDER BY d.full_name, st.compound
            LIMIT {filters['limit']}
        """
    
    def _get_weather_conditions_query(self, filters: Dict[str, Any]) -> str:
        """Get weather conditions query with dynamic filters"""
        return f"""
            SELECT 
                m.meeting_name,
                s.session_name,
                AVG(w.air_temperature) as avg_air_temp,
                AVG(w.track_temperature) as avg_track_temp,
                AVG(w.humidity) as avg_humidity,
                COUNT(CASE WHEN w.rainfall = true THEN 1 END) as rainy_laps,
                COUNT(w.id) as total_laps
            FROM weather_transformed w
            JOIN sessions_transformed s ON w.session_key = s.session_key
            JOIN meetings m ON w.meeting_key = m.meeting_key
            WHERE 1=1
                {filters['time_filter']}
            GROUP BY m.meeting_name, s.session_name, s.date_start
            ORDER BY s.date_start DESC
            LIMIT {filters['limit']}
        """
    
    def _get_race_control_query(self, filters: Dict[str, Any]) -> str:
        """Get race control query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                rc.category,
                rc.flag,
                rc.lap_number,
                rc.message,
                rc.scope,
                m.meeting_name,
                s.session_name
            FROM race_control rc
            JOIN drivers_transformed d ON rc.driver_number = d.driver_number AND rc.session_key = d.session_key
            JOIN sessions_transformed s ON rc.session_key = s.session_key
            JOIN meetings m ON rc.meeting_key = m.meeting_key
            WHERE 1=1
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['time_filter']}
            ORDER BY rc.date DESC
            LIMIT {filters['limit']}
        """
    
    def _get_meeting_schedule_query(self, filters: Dict[str, Any]) -> str:
        """Get meeting schedule query with dynamic filters"""
        return f"""
            SELECT 
                meeting_name,
                country_name,
                circuit_short_name,
                date_start,
                year
            FROM meetings
            WHERE 1=1
                {filters['time_filter']}
            ORDER BY date_start DESC
            LIMIT {filters['limit']}
        """
    
    def _get_lap_times_query(self, filters: Dict[str, Any]) -> str:
        """Get lap times query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                l.lap_number,
                l.lap_duration,
                l.duration_sector_1,
                l.duration_sector_2,
                l.duration_sector_3,
                m.meeting_name,
                s.session_name
            FROM laps_transformed l
            JOIN drivers_transformed d ON l.driver_number = d.driver_number AND l.session_key = d.session_key
            JOIN sessions_transformed s ON l.session_key = s.session_key
            JOIN meetings m ON l.meeting_key = m.meeting_key
            WHERE l.is_outlier = false
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['time_filter']}
            ORDER BY l.lap_duration
            LIMIT {filters['limit']}
        """
    
    def _get_position_changes_query(self, filters: Dict[str, Any]) -> str:
        """Get position changes query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                p.position_change,
                p.position,
                m.meeting_name,
                s.session_name,
                p.date
            FROM positions_transformed p
            JOIN drivers_transformed d ON p.driver_number = d.driver_number AND p.session_key = d.session_key
            JOIN sessions_transformed s ON p.session_key = s.session_key
            JOIN meetings m ON p.meeting_key = m.meeting_key
            WHERE p.position_change IS NOT NULL
                {filters['driver_filter']} 
                {filters['team_filter']}
                {filters['meeting_filter']}
                {filters['time_filter']}
            ORDER BY ABS(p.position_change) DESC
            LIMIT {filters['limit']}
        """
    
    def _get_driver_performance_query(self, filters: Dict[str, Any]) -> str:
        """Get driver performance query with dynamic filters"""
        # Check if we need a simple position query or detailed performance
        if filters.get('meeting_name') and filters.get('session_filter'):
            # Simple position query for specific race
            return f"""
                SELECT 
                    d.full_name,
                    d.team_name,
                    p.position,
                    p.position_change,
                    m.meeting_name,
                    s.session_name,
                    s.date_start
                FROM positions_transformed p
                JOIN drivers_transformed d ON p.driver_number = d.driver_number AND p.session_key = d.session_key
                JOIN sessions_transformed s ON p.session_key = s.session_key
                JOIN meetings m ON p.meeting_key = m.meeting_key
                WHERE 1=1 
                    {filters['driver_filter']} 
                    {filters['team_filter']} 
                    {filters['session_filter']} 
                    {filters['time_filter']}
                    {filters.get('meeting_filter', '')}
                ORDER BY s.date_start DESC, p.position
                LIMIT {filters['limit']}
            """
        else:
            # Detailed performance query
            return f"""
                SELECT 
                    d.full_name,
                    d.team_name,
                    AVG(l.lap_duration) as avg_lap_time,
                    COUNT(l.id) as total_laps,
                    COUNT(DISTINCT s.session_key) as sessions_participated,
                    AVG(p.position) as avg_position,
                    COUNT(CASE WHEN i.is_leader = true THEN 1 END) as leading_laps,
                    m.meeting_name,
                    s.session_name,
                    s.date_start
                FROM drivers_transformed d
                LEFT JOIN laps_transformed l ON d.driver_number = l.driver_number AND d.session_key = l.session_key
                LEFT JOIN positions_transformed p ON d.driver_number = p.driver_number AND d.session_key = p.session_key
                LEFT JOIN intervals_transformed i ON d.driver_number = i.driver_number AND d.session_key = i.session_key
                LEFT JOIN sessions_transformed s ON d.session_key = s.session_key
                LEFT JOIN meetings m ON d.meeting_key = m.meeting_key
                WHERE 1=1 
                    {filters['driver_filter']} 
                    {filters['team_filter']} 
                    {filters['session_filter']} 
                    {filters['time_filter']}
                GROUP BY d.full_name, d.team_name, m.meeting_name, s.session_name, s.date_start
                ORDER BY avg_lap_time ASC
                LIMIT {filters['limit']}
            """
    
    def _get_pit_stops_query(self, filters: Dict[str, Any]) -> str:
        """Get pit stops query with dynamic filters"""
        return f"""
            SELECT 
                d.full_name,
                d.team_name,
                COUNT(ps.id) as pit_stop_count,
                AVG(ps.pit_duration) as avg_pit_duration,
                MIN(ps.pit_duration) as fastest_pit_stop,
                MAX(ps.pit_duration) as slowest_pit_stop,
                m.meeting_name,
                s.session_name,
                s.date_start
            FROM drivers_transformed d
            LEFT JOIN pit_stops_transformed ps ON d.driver_number = ps.driver_number AND d.session_key = ps.session_key
            LEFT JOIN sessions_transformed s ON d.session_key = s.session_key
            LEFT JOIN meetings m ON d.meeting_key = m.meeting_key
            WHERE ps.id IS NOT NULL 
                {filters['driver_filter']} 
                {filters['team_filter']} 
                {filters['session_filter']} 
                {filters['time_filter']}
            GROUP BY d.full_name, d.team_name, m.meeting_name, s.session_name, s.date_start
            ORDER BY pit_stop_count DESC
            LIMIT {filters['limit']}
        """

In [30]:
# ... existing code ...

# Fine-tune Pre-trained Text-to-SQL Model on F1 Dataset
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import logging
from typing import List, Dict, Any, Optional
import gc
import os

logger = logging.getLogger(__name__)

class F1SQLDataset(Dataset):
    """Custom dataset for F1 SQL generation - IMPROVED VERSION"""
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        input_text = item["input"]
        target_text = item["target"]
        
        # Tokenize input and target
        inputs = self.tokenizer(
            input_text, 
            max_length=self.max_length, 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        )
        
        targets = self.tokenizer(
            target_text, 
            max_length=self.max_length, 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        )
        
        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": targets["input_ids"].squeeze()
        }

class ProgressCallback(TrainerCallback):
    """Custom callback to monitor training progress"""
    def __init__(self, log_steps=10):
        self.log_steps = log_steps
        self.step_count = 0
    
    def on_step_end(self, args, state, control, **kwargs):
        self.step_count += 1
        if self.step_count % self.log_steps == 0:
            if state.log_history:
                latest_log = state.log_history[-1]
                loss = latest_log.get('loss', 'N/A')
                print(f"Step {self.step_count}: Loss = {loss:.4f}")
    
    def on_epoch_end(self, args, state, control, **kwargs):
        print(f"Epoch {state.epoch} completed. Total steps: {self.step_count}")

def prepare_f1_training_data(dataset_path: str, max_examples: Optional[int] = None) -> List[Dict[str, Any]]:
    """Prepare F1 training data for fine-tuning - IMPROVED VERSION"""
    print(f"Loading F1 dataset from {dataset_path}")
    
    with open(dataset_path, 'r') as f:
        data = json.load(f)
    
    if max_examples and len(data) > max_examples:
        data = data[:max_examples]
        print(f"Using {len(data)} examples for fine-tuning")
    
    print("Preparing training examples...")
    training_data = []
    
    for i, item in enumerate(data):
        try:
            # Format schema properly for the pre-trained model
            schema_text = "Tables:\n"
            for table_name, columns in item['schema']['tables'].items():
                schema_text += f"{table_name} ({', '.join(columns)})\n"
            
            input_text = f"Question: {item['question']}\n{schema_text}"
            
            # Clean up the SQL - remove leading newlines and indentation
            sql = item['sql'].strip()
            lines = sql.split('\n')
            cleaned_lines = []
            for line in lines:
                line = line.strip()
                if line:
                    cleaned_lines.append(line)
            
            # Join lines back together
            cleaned_sql = ' '.join(cleaned_lines)
            
            # Ensure it starts with SELECT
            if not cleaned_sql.upper().startswith('SELECT'):
                print(f"Warning: SQL doesn't start with SELECT: {cleaned_sql[:100]}...")
                continue
            
            training_data.append({
                'input': input_text,
                'target': cleaned_sql
            })
            
            if i < 2:  # Show first 2 examples
                print(f"\nExample {i+1}:")
                print(f"Input: {input_text[:200]}...")
                print(f"Target: {cleaned_sql[:200]}...")
                
        except Exception as e:
            print(f"Error processing example {i}: {e}")
            continue
    
    print(f"Prepared {len(training_data)} training examples")
    return training_data

def fine_tune_pretrained_model(dataset_path: str, max_examples: int = 500, epochs: int = 3):
    """Fine-tune the pre-trained text-to-SQL model on F1 data"""
    print("=== Fine-tuning Pre-trained Text-to-SQL Model on F1 Dataset ===")
    
    try:
        # Step 1: Prepare F1 training data
        training_data = prepare_f1_training_data(dataset_path, max_examples)
        
        if not training_data:
            raise ValueError("No training data prepared")
        
        # Step 2: Load pre-trained model and tokenizer
        print("\nLoading pre-trained model for fine-tuning...")
        model_name = "juierror/text-to-sql-with-table-schema"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        
        # Set pad token if not set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Step 3: Create dataset
        print("Creating F1 dataset...")
        dataset = F1SQLDataset(training_data, tokenizer)
        
        # Step 4: Training arguments for fine-tuning
        # Step 4: Training arguments - SIMPLIFIED VERSION
        training_args = TrainingArguments(
            output_dir="./f1_nl2sql_finetuned",
            num_train_epochs=epochs,
            per_device_train_batch_size=2,
            learning_rate=3e-5,
            logging_steps=10,
            save_steps=100,
            save_total_limit=3,
            remove_unused_columns=False,
            gradient_accumulation_steps=2,
            dataloader_pin_memory=False,
            dataloader_num_workers=0,
            report_to=None,
            logging_dir="./logs",
            logging_first_step=True,
            save_strategy="steps",
            fp16=False,
            dataloader_drop_last=False
            # Removed: load_best_model_at_end, evaluation_strategy, eval_steps
        )
        
        # Step 5: Create trainer
        print("Setting up trainer for fine-tuning...")
        progress_callback = ProgressCallback(log_steps=10)
        
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            tokenizer=tokenizer,
            callbacks=[progress_callback]
        )
        
        # Step 6: Fine-tune model
        print(f"Starting fine-tuning...")
        print(f"Dataset size: {len(dataset)}")
        print(f"Batch size: {training_args.per_device_train_batch_size}")
        print(f"Epochs: {training_args.num_train_epochs}")
        print(f"Learning rate: {training_args.learning_rate}")
        
        trainer.train()
        
        # Step 7: Save fine-tuned model
        print("Saving fine-tuned model...")
        os.makedirs("artifacts/models", exist_ok=True)
        model.save_pretrained("artifacts/models/f1_nl2sql_finetuned")
        tokenizer.save_pretrained("artifacts/models/f1_nl2sql_finetuned")
        
        # Clean up
        del trainer, model, tokenizer
        gc.collect()
        
        print("✅ Fine-tuned model saved successfully!")
        return True
        
    except Exception as e:
        print(f"❌ Fine-tuning failed: {e}")
        import traceback
        traceback.print_exc()
        return False

class FineTunedNL2SQLGenerator(PreTrainedNL2SQLGenerator):
    """Fine-tuned text-to-SQL generator for F1 data"""
    
    def __init__(self, model_path: str = "artifacts/models/f1_nl2sql_finetuned"):
        super().__init__()
        self.model_path = model_path
        self.fine_tuned_model = None
        self.fine_tuned_tokenizer = None
    
    def load_fine_tuned_model(self):
        """Load the fine-tuned model"""
        try:
            print(f"Loading fine-tuned model from {self.model_path}")
            self.fine_tuned_tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.fine_tuned_model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
            self.fine_tuned_model.to(self.device)
            print(f"✅ Fine-tuned model loaded successfully on {self.device}")
        except Exception as e:
            print(f"❌ Error loading fine-tuned model: {e}")
            print("Falling back to pre-trained model...")
            self.load_model()  # Load the original pre-trained model
    
    def generate_sql(self, question: str, max_length: int = 512) -> str:
        """Generate SQL using fine-tuned model if available"""
        if self.fine_tuned_model is None:
            self.load_fine_tuned_model()
        
        # Use fine-tuned model if available, otherwise fall back to pre-trained
        model_to_use = self.fine_tuned_model if self.fine_tuned_model else self.model
        tokenizer_to_use = self.fine_tuned_tokenizer if self.fine_tuned_tokenizer else self.tokenizer
        
        if model_to_use is None or tokenizer_to_use is None:
            raise ValueError("No model available")
        
        # Format input with schema
        schema = self.format_schema_for_model()
        input_text = f"Question: {question}\n{schema}"
        
        # Tokenize input
        inputs = tokenizer_to_use(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        ).to(self.device)
        
        # Generate SQL
        with torch.no_grad():
            outputs = model_to_use.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
                pad_token_id=tokenizer_to_use.pad_token_id,
                eos_token_id=tokenizer_to_use.eos_token_id,
                do_sample=False,
                temperature=1.0,
                no_repeat_ngram_size=3
            )
        
        # Decode output
        generated_sql = tokenizer_to_use.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        
        return generated_sql.strip()

# Fine-tuning execution
def run_fine_tuning():
    """Run the fine-tuning process"""
    print("=== F1 Text-to-SQL Model Fine-tuning ===")
    
    # Step 1: Fine-tune with moderate dataset
    print("\n1. Fine-tuning with 500 examples...")
    success = fine_tune_pretrained_model(
        "research/f1_training_balanced_dataset.json", 
        max_examples=500, 
        epochs=3
    )
    
    if success:
        print("\n2. Testing the fine-tuned model...")
        fine_tuned_generator = FineTunedNL2SQLGenerator()
        
        test_questions = [
            "Who won the Miami Grand Prix?",
            "What position did Lewis Hamilton get?",
            "What was the fastest lap?",
            "How did Max Verstappen perform?",
            "Show me the qualifying results",
            "What were the weather conditions?"
        ]
        
        for question in test_questions:
            print(f"\n❓ Question: {question}")
            try:
                sql = fine_tuned_generator.generate_sql(question)
                print(f"🤖 Generated SQL: {sql}")
                
                if fine_tuned_generator.validate_sql(sql):
                    print(f"✅ Valid SQL")
                else:
                    print(f"⚠️  Invalid SQL")
            except Exception as e:
                print(f"❌ Error: {e}")
            print("-" * 50)
        
        print("\n3. If successful, you can fine-tune with more data:")
        print("   - Increase max_examples to 1001 (full dataset)")
        print("   - Increase epochs to 5-10")
        print("   - Try different learning rates")
        
    else:
        print("❌ Fine-tuning failed. Check the error messages above.")

# Run fine-tuning
run_fine_tuning()

=== F1 Text-to-SQL Model Fine-tuning ===

1. Fine-tuning with 500 examples...
=== Fine-tuning Pre-trained Text-to-SQL Model on F1 Dataset ===
Loading F1 dataset from research/f1_training_balanced_dataset.json
Using 500 examples for fine-tuning
Preparing training examples...

Example 1:
Input: Question: What was Alex ALBON's finishing position in the {race_name}?
Tables:
positions_transformed (driver_number, position, session_key, meeting_key, date, position_change, is_leader)
drivers_trans...
Target: SELECT d.full_name, d.team_name, p.position, m.meeting_name, s.session_name, s.date_start FROM positions_transformed p JOIN drivers_transformed d ON p.driver_number = d.driver_number AND p.session_key...

Example 2:
Input: Question: What position did Lando NORRIS end up in the {race_name}?
Tables:
positions_transformed (driver_number, position, session_key, meeting_key, date, position_change, is_leader)
drivers_transfor...
Target: SELECT d.full_name, d.team_name, p.position, m.meeting_name

  trainer = Trainer(


Starting fine-tuning...
Dataset size: 500
Batch size: 2
Epochs: 3
Learning rate: 3e-05


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
1,17.1971
10,10.6136
20,3.3606
30,1.4791
40,1.0556
50,0.725
60,0.5425
70,0.4147
80,0.3316
90,0.2598


Step 10: Loss = 17.1971
Step 20: Loss = 10.6136
Step 30: Loss = 3.3606
Step 40: Loss = 1.4791
Step 50: Loss = 1.0556
Step 60: Loss = 0.7250
Step 70: Loss = 0.5425
Step 80: Loss = 0.4147
Step 90: Loss = 0.3316
Step 100: Loss = 0.2598
Step 110: Loss = 0.2350
Step 120: Loss = 0.1827
Epoch 1.0 completed. Total steps: 125
Step 130: Loss = 0.1498
Step 140: Loss = 0.1378
Step 150: Loss = 0.1107
Step 160: Loss = 0.0969
Step 170: Loss = 0.0965
Step 180: Loss = 0.1002
Step 190: Loss = 0.0795
Step 200: Loss = 0.0756
Step 210: Loss = 0.0703
Step 220: Loss = 0.0706
Step 230: Loss = 0.0594
Step 240: Loss = 0.0632
Step 250: Loss = 0.0627
Epoch 2.0 completed. Total steps: 250
Step 260: Loss = 0.0530
Step 270: Loss = 0.0467
Step 280: Loss = 0.0543
Step 290: Loss = 0.0526
Step 300: Loss = 0.0520
Step 310: Loss = 0.0455
Step 320: Loss = 0.0447
Step 330: Loss = 0.0493
Step 340: Loss = 0.0427
Step 350: Loss = 0.0434
Step 360: Loss = 0.0406
Step 370: Loss = 0.0456
Epoch 3.0 completed. Total steps: 375
Savin

# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
import requests
import pandas as pd
from typing import List, Dict, Any, Optional
import logging

logger = logging.getLogger(__name__)

# Initialize database connection
from src.formula_one.entity.config_entity import DatabaseConfig
from src.formula_one.utils.database_utils import DatabaseUtils

db_config = DatabaseConfig()
db_utils = DatabaseUtils(db_config)

# Load the trained Enhanced BART model
enhanced_classifier = EnhancedBARTIntentClassifier()
enhanced_classifier.load_model_from_file("artifacts/models/bart_intent_classifier.pth")

# Initialize the hybrid query generator
query_generator = HybridF1QueryGenerator()

def execute_query_with_data_fixed(db_utils, query):
    """Execute query and return the data directly using the correct method"""
    try:
        # Use execute_query_with_result instead of execute_query
        rows = db_utils.execute_query_with_result(query)
        
        # Get column names by executing a modified query
        conn = db_utils.connect_to_db()
        cursor = conn.cursor()
        cursor.execute(query)
        columns = [desc[0] for desc in cursor.description] if cursor.description else []
        cursor.close()
        conn.close()
        
        return rows, columns
    except Exception as e:
        print(f"Error executing query: {e}")
        return [], []

def enhanced_f1_pipeline(question: str, enhanced_classifier, query_generator, db_utils):
    """Complete enhanced F1 pipeline with NER and multi-intent support"""
    
    # Step 1: Classify intent with NER
    classification_result = enhanced_classifier.classify_intent_with_ner(question)
    
    # Step 2: Generate dynamic queries
    query_result = query_generator.generate_dynamic_query(question, classification_result)
    
    # Step 3: Execute queries
    results = []
    for query_info in query_result["queries"]:
        try:
            rows, columns = execute_query_with_data_fixed(db_utils, query_info["query"])
            results.append({
                "intent": query_info["intent"],
                "query": query_info["query"],
                "data": {"rows": rows, "columns": columns},
                "filters": query_info["filters"],
                "success": True
            })
        except Exception as e:
            print(f"Error executing {query_info['intent']} query: {e}")
            results.append({
                "intent": query_info["intent"],
                "query": query_info["query"],
                "data": f"Error: {str(e)}",
                "filters": query_info["filters"],
                "success": False
            })
    
    return {
        "question": question,
        "classification": classification_result,
        "query_result": query_result,
        "results": results,
        "success": len([r for r in results if r["success"]]) > 0
    }

def display_enhanced_results(result):
    """Display enhanced query results in a readable format"""
    print(f"Question: {result['question']}")
    print(f"Extracted Entities: {result['classification']['entities']}")
    print(f"Detected Intents: {result['classification']['intents']}")
    print(f"Confidence: {result['classification']['confidence']:.3f}")
    print(f"Query Type: {result['query_result']['type']}")
    print(f"Overall Success: {result['success']}")
    
    if result['success']:
        for res in result['results']:
            print(f"\n--- {res['intent']} ---")
            print(f"Filters Applied: {res['filters']}")
            print(f"Success: {res['success']}")
            
            if res['success']:
                data = res['data']
                rows = data['rows']
                columns = data['columns']
                
                if rows:
                    print(f"Columns: {columns}")
                    print(f"Number of rows: {len(rows)}")
                    print("First 3 rows:")
                    for i, row in enumerate(rows[:3]):
                        print(f"  Row {i+1}: {row}")
                else:
                    print("No data returned")
            else:
                print(f"Error: {res['data']}")
    else:
        print("No successful queries executed")
    
    print("-" * 50)

# Test the hybrid pipeline with questions that should trigger Ollama
test_questions = [
    # Simple questions (should use predefined queries)
    "Who won the last race?",
    "How did Verstappen perform?",
    "What were the qualifying results?",
    
    # Complex questions (should use Ollama)
    "Which drivers had the best tire strategy in wet conditions during qualifying?",
    "How did the weather affect lap times and pit stop strategies?",
    "Compare the performance of Red Bull and Ferrari in different weather conditions",
    "What was the correlation between tire compounds and lap times in the last race?",
    "Which teams had the most consistent performance across different sessions?",
    "How did track temperature affect qualifying performance?",
    "What was the impact of weather conditions on pit stop strategies?",
    
    # Edge cases
    "What position did Hamilton get at the Miami race this year?",
    "How did McLaren perform this season?",
    "What was the fastest lap in qualifying?"
]

print("=== Testing Hybrid Query Generator ===")
print("�� = Predefined Query")
print("�� = Ollama Generated Query")
print("=" * 60)

for question in test_questions:
    print(f"\n❓ Question: {question}")
    result = enhanced_f1_pipeline(question, enhanced_classifier, query_generator, db_utils)
    display_enhanced_results(result)

# Interactive testing
print("\n=== Interactive Testing ===")
print("Type 'quit' to exit")
print("Try complex questions to see Ollama in action!")

while True:
    try:
        question = input("\nEnter your question: ")
        if question.lower() == 'quit':
            break
        
        result = enhanced_f1_pipeline(question, enhanced_classifier, query_generator, db_utils)
        display_enhanced_results(result)
        
    except KeyboardInterrupt:
        print("\nExiting...")
        break
    except Exception as e:
        print(f"Error: {e}")

# F1 QA BOT

In [None]:
import requests
import pandas as pd
from datetime import datetime, timedelta

def analyze_context_gaps(question, entities, intents):
    """Analyze what contextual information is missing"""
    gaps = []
    
    # Check for missing time context
    if not entities.get('time_context') or entities['time_context'] == 'recent':
        gaps.append({
            'type': 'time_context',
            'message': 'Which time period are you interested in? (e.g., "last race", "this season", "Austrian GP")',
            'priority': 'high'
        })
    
    # Check for missing session context for certain intents
    session_specific_intents = ['qualifying_results', 'race_results', 'pit_stops', 'tire_strategy']
    if any(intent in intents for intent in session_specific_intents):
        if not entities.get('sessions'):
            gaps.append({
                'type': 'session_context',
                'message': 'Which session are you asking about? (e.g., "qualifying", "race", "practice")',
                'priority': 'medium'
            })
    
    # Check for missing driver/team context for performance queries
    performance_intents = ['driver_performance', 'team_performance', 'fastest_laps']
    if any(intent in intents for intent in performance_intents):
        if not entities.get('drivers') and not entities.get('teams'):
            gaps.append({
                'type': 'entity_context',
                'message': 'Which driver or team are you asking about?',
                'priority': 'medium'
            })
    
    # Check for specific race context
    if not entities.get('meeting_name'):
        gaps.append({
            'type': 'race_context',
            'message': 'Which race or Grand Prix are you referring to?',
            'priority': 'low'
        })
    
    return gaps

def enhance_entities_with_context(entities, question_lower):
    """Enhance entities with additional context from the question"""
    enhanced_entities = entities.copy()
    
    # Extract race names from question
    race_keywords = {
        'australian': 'Australian Grand Prix',
        'chinese': 'Chinese Grand Prix',
        'japanese': 'Japanese Grand Prix',
        'bahrain': 'Bahrain Grand Prix',
        'saudi': 'Saudi Arabian Grand Prix',
        'miami': 'Miami Grand Prix',             # (American GP at Miami)
        'italian emilia-romagna': 'Emilia‑Romagna Grand Prix',
        'monaco': 'Monaco Grand Prix',
        'spanish': 'Spanish Grand Prix',
        'canadian': 'Canadian Grand Prix',
        'austrian': 'Austrian Grand Prix',
        'british': 'British Grand Prix'
    }
    
    for keyword, race_name in race_keywords.items():
        if keyword in question_lower and 'meeting_name' not in enhanced_entities:
            enhanced_entities['meeting_name'] = race_name
    
    # Extract time context
    time_keywords = {
        'last race': 'last_race',
        'last grand prix': 'last_race',
        'this season': 'season',
        'this year': 'season',
        'today': 'today',
        'yesterday': 'yesterday',
        'weekend': 'recent'
    }
    
    for keyword, time_context in time_keywords.items():
        if keyword in question_lower:
            enhanced_entities['time_context'] = time_context
            break
    
    return enhanced_entities

def generate_contextual_prompt(question, entities, intents, gaps):
    """Generate a contextual prompt for Ollama"""
    context_parts = [f"Question: {question}"]
    context_parts.append(f"Detected Intents: {intents}")
    
    # Add entity information
    if entities.get('drivers'):
        context_parts.append(f"Drivers mentioned: {entities['drivers']}")
    if entities.get('teams'):
        context_parts.append(f"Teams mentioned: {entities['teams']}")
    if entities.get('sessions'):
        context_parts.append(f"Sessions mentioned: {entities['sessions']}")
    if entities.get('time_context'):
        context_parts.append(f"Time context: {entities['time_context']}")
    
    # Add context gaps
    if gaps:
        context_parts.append("\nContext Gaps Identified:")
        for gap in gaps:
            context_parts.append(f"- {gap['message']}")
        
        context_parts.append("\nInstructions: If the data doesn't provide enough context to answer the question completely, acknowledge the gaps and provide the best answer possible with the available information. Suggest what additional context would help provide a more complete answer.")
    
    return "\n".join(context_parts)

def run_f1_qa_with_contextual_awareness(question, classifier, query_generator, db_utils):
    print(f"\n❓ Question: {question}")
    print("=" * 60)

    # Step 1: Intent classification
    classification_result = classifier.classify_intent_with_ner(question)
    intents = classification_result['intents']
    entities = classification_result['entities']
    
    print(f"🎯 Detected Intents: {intents}")
    print(f"��️  Confidence Score: {classification_result.get('confidence')}")
    print(f"📦 Entities: {entities}")
    
    # Step 2: Enhance entities with additional context
    enhanced_entities = enhance_entities_with_context(entities, question.lower())
    if enhanced_entities != entities:
        print(f"�� Enhanced Entities: {enhanced_entities}")
    
    # Step 3: Analyze context gaps
    gaps = analyze_context_gaps(question, enhanced_entities, intents)
    if gaps:
        print(f"\n⚠️  Context Gaps Identified:")
        for gap in gaps:
            print(f"   - {gap['message']}")
    
    print()

    # Step 4: SQL generation and execution
    query_result = query_generator.generate_dynamic_query(question, classification_result)

    all_data = []
    print("🛠️  Generated Queries & Results:\n")

    for query_info in query_result["queries"]:
        intent = query_info["intent"]
        sql_query = query_info["query"]

        print(f"🔎 Intent: {intent}")
        print("📄 SQL Query:")
        print(sql_query)
        print()

        try:
            rows, columns = execute_query_with_data_fixed(db_utils, sql_query)
            
            if rows and columns:
                df = pd.DataFrame(rows, columns=columns)
                all_data.append({
                    "intent": intent,
                    "query": sql_query,
                    "data": df,
                    "success": True
                })
                print(f"✅ Successfully executed query for intent '{intent}'")
                print(f"📊 Retrieved {len(rows)} rows with {len(columns)} columns")
            else:
                print(f"⚠️  No data returned for intent '{intent}'")
                all_data.append({
                    "intent": intent,
                    "query": sql_query,
                    "data": pd.DataFrame(),
                    "success": True
                })

        except Exception as e:
            print(f"❌ Error running query for intent '{intent}': {e}\n")
            all_data.append({
                "intent": intent,
                "query": sql_query,
                "data": None,
                "error": str(e),
                "success": False
            })

    # Step 5: Build enhanced context for Ollama
    summary_parts = []
    
    # Add contextual prompt
    contextual_prompt = generate_contextual_prompt(question, enhanced_entities, intents, gaps)
    summary_parts.append(contextual_prompt)
    
    # Add data results
    for result in all_data:
        if result["success"] and result["data"] is not None and not result["data"].empty:
            summary_parts.append(f"\nIntent: {result['intent']}")
            summary_parts.append(result["data"].head(3).to_markdown(index=False))
        elif result["success"] and result["data"] is not None and result["data"].empty:
            summary_parts.append(f"\nIntent: {result['intent']} - No data available")

    full_context = "\n".join(summary_parts)
    print("�� Final Context Sent to Ollama:\n")
    print(full_context)
    print()

    # Step 6: Call Ollama with enhanced prompt
    try:
        enhanced_prompt = f"""You're an expert F1 assistant with contextual awareness. 

Based on the following data and context analysis, provide a comprehensive and accurate answer. 

If there are context gaps identified, acknowledge them and:
1. Provide the best answer possible with available data
2. Mention what additional context would help
3. Make reasonable assumptions when appropriate
4. Ask clarifying questions if the answer would be significantly different

Context and Data:
{full_context}

Please provide a detailed, contextual response:"""

        response = requests.post(
            "http://localhost:11434/api/generate",
            json={
                "model": "llama3",
                "prompt": enhanced_prompt,
                "stream": False
            }
        )
        result = response.json()
        answer = result.get("response", "No response from Ollama.")
    except Exception as e:
        answer = f"Ollama summary failed: {e}"

    print("\n🧠 LLM Summary (Ollama):")
    print(answer)
    print("-" * 80)

# # Test with questions that need context
# contextual_questions = [
#     "What is the fastest pit stop?",
#     "Who had the best qualifying performance?",
#     "What was the weather like?",
#     "How did the team perform?",
#     "Show me the fastest lap times",
#     "What were the qualifying results for the Austrian Grand Prix?",
#     "How did Verstappen perform in the last race?",
#     "What was the tire strategy in qualifying?"
# ]

# print("=== Testing Contextual Awareness ===")
# for q in contextual_questions:
#     run_f1_qa_with_contextual_awareness(q, enhanced_classifier, query_generator, db_utils)
    
while True:
    try:
        question = input("\n❓ Enter your F1 question: ")
        
        if question.lower() in ['quit', 'exit', 'q']:
            print("👋 Thanks for using the F1 QA Bot!")
            break
        
        if not question.strip():
            print("Please enter a question.")
            continue
        
        # Run the contextual QA system
        run_f1_qa_with_contextual_awareness(question, enhanced_classifier, query_generator, db_utils)
        
    except KeyboardInterrupt:
        print("\n👋 Thanks for using the F1 QA Bot!")
        break
    except Exception as e:
        print(f"❌ Error: {e}")
        print("Please try again with a different question.")