# In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import json
from datetime import datetime
import torch
import gc
import subprocess
import sys
from collections import Counter, defaultdict
import glob

def install_qwen_packages():
    """Pip-install the packages the Qwen pipeline depends on.

    Each requirement is installed in its own pip invocation using the
    current interpreter. A failing install is reported and skipped so the
    remaining requirements are still attempted.
    """
    requirements = (
        "transformers>=4.45.0",
        "accelerate",
        "tiktoken",
        "qwen-vl-utils",
    )

    for requirement in requirements:
        print(f"📦 Installing {requirement}...")
        pip_command = [sys.executable, "-m", "pip", "install", requirement]
        try:
            subprocess.run(pip_command, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            # Best effort: warn and move on to the next requirement.
            print(f"⚠️ Warning installing {requirement}: {e}")
        else:
            print(f"✅ {requirement} installed successfully")

def setup_qwen_model():
    """Install dependencies and load the Qwen 2.5 Vision model.

    Returns:
        A ``(model, processor, device, process_vision_info)`` tuple on
        success, or ``(None, None, None, None)`` if anything fails while
        importing or loading the model.
    """
    print("🧠 Setting up Qwen 2.5 Vision model...")

    # The lazy imports below rely on these packages being present.
    install_qwen_packages()

    # Release cached GPU memory before loading the weights.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    load_failed = (None, None, None, None)
    try:
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
        from qwen_vl_utils import process_vision_info

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        print(f"🖥️ Using device: {device}")

        model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
        load_options = {"torch_dtype": "auto", "device_map": "auto"}

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **load_options)
        processor = AutoProcessor.from_pretrained(model_id)
    except Exception as e:
        print(f"❌ Error loading Qwen model: {e}")
        return load_failed

    print(f"✅ Qwen 2.5 Vision model loaded successfully!")
    return model, processor, device, process_vision_info

def _collect_crop_entries(image_paths, annotations_dict, crop_type):
    """Build dataset rows for one crop type, skipping images that cannot be labelled."""
    entries = []
    for img_path in image_paths:
        filename = os.path.basename(img_path)
        # Expected layout: FRAMEID_PERSONID_<suffix>.ext
        parts = filename.split('_')
        if len(parts) < 3:
            continue

        frame_id = parts[0]
        try:
            person_id = int(parts[1])
        except ValueError:
            # One badly named file must not abort the whole load.
            print(f"⚠️ Skipping {filename}: person id '{parts[1]}' is not an integer")
            continue

        behaviors = annotations_dict.get(frame_id)
        if behaviors is None:
            continue

        # person_id is 1-indexed; reject 0/negatives so they cannot wrap
        # around to the end of the list and pick up someone else's label.
        if not 1 <= person_id <= len(behaviors):
            continue

        label = behaviors[person_id - 1]
        if pd.isna(label):
            # Empty CSV cell: there is no ground truth for this person.
            continue

        entries.append({
            'image_path': img_path,
            'frame_id': frame_id,
            'person_id': person_id,
            'true_behavior': str(label).strip(),
            'crop_type': crop_type,
            'filename': filename
        })
    return entries


def load_dataset_with_annotations(dataset_path):
    """Load images and corresponding annotations.

    Images are discovered with the glob patterns ``*_*_r.*`` (raw crops) and
    ``*_*_x.*`` (extended crops); annotations come from ``*_annotation.csv``
    files whose first row lists one behavior per person.

    Args:
        dataset_path: Directory containing the crops and annotation CSVs.

    Returns:
        ``(df_dataset, annotations_dict)`` where ``df_dataset`` has the columns
        ``image_path, frame_id, person_id, true_behavior, crop_type, filename``
        (present even when no entry was found) and ``annotations_dict`` maps a
        frame id to the list of behaviors read from its CSV.
    """
    print(f"📁 Loading dataset from: {dataset_path}")

    # Sorted so the row order is reproducible across runs and filesystems.
    raw_images = sorted(glob.glob(os.path.join(dataset_path, "*_*_r.*")))
    extended_images = sorted(glob.glob(os.path.join(dataset_path, "*_*_x.*")))
    annotation_files = sorted(glob.glob(os.path.join(dataset_path, "*_annotation.csv")))

    print(f"🖼️ Found {len(raw_images)} raw images")
    print(f"🖼️ Found {len(extended_images)} extended images") 
    print(f"📄 Found {len(annotation_files)} annotation files")

    # Process annotations
    annotations_dict = {}
    for ann_file in annotation_files:
        frame_id = os.path.basename(ann_file).replace('_annotation.csv', '')
        try:
            df = pd.read_csv(ann_file, header=None)
            annotations_dict[frame_id] = df.iloc[0].tolist()  # First row contains behaviors
        except Exception as e:
            # Unreadable or empty annotation file: report and keep going.
            print(f"⚠️ Error reading {ann_file}: {e}")

    # Raw entries first, then extended ones.
    dataset_entries = _collect_crop_entries(raw_images, annotations_dict, 'raw')
    dataset_entries += _collect_crop_entries(extended_images, annotations_dict, 'extended')

    # Explicit columns keep the frame usable downstream even when it is empty.
    df_dataset = pd.DataFrame(
        dataset_entries,
        columns=['image_path', 'frame_id', 'person_id', 'true_behavior', 'crop_type', 'filename'],
    )
    print(f"📊 Dataset loaded: {len(df_dataset)} total entries")

    return df_dataset, annotations_dict

def analyze_dataset_statistics(df_dataset):
    """Print and plot summary statistics for the loaded dataset.

    Args:
        df_dataset: DataFrame with at least the columns ``frame_id``,
            ``person_id``, ``true_behavior`` and ``crop_type``.

    Returns:
        A Series of entry counts per ``true_behavior`` (empty when the
        dataset has no rows).
    """
    print("\n📊 DATASET STATISTICS")
    print("=" * 50)

    if df_dataset.empty:
        # Nothing to count or plot; also avoids KeyError on a column-less frame.
        print("⚠️ Dataset is empty; nothing to analyze")
        return pd.Series(dtype="int64")

    # Basic statistics
    total_entries = len(df_dataset)
    unique_frames = df_dataset['frame_id'].nunique()
    crop_type_counts = df_dataset['crop_type'].value_counts()
    raw_count = int(crop_type_counts.get('raw', 0))
    extended_count = int(crop_type_counts.get('extended', 0))

    print(f"Total entries: {total_entries}")
    print(f"Unique frames: {unique_frames}")
    print(f"Raw crops: {raw_count}")
    print(f"Extended crops: {extended_count}")

    # Behavior distribution
    print(f"\n🎯 BEHAVIOR DISTRIBUTION:")
    behavior_counts = df_dataset['true_behavior'].value_counts()
    for behavior, count in behavior_counts.items():
        percentage = (count / total_entries) * 100
        print(f"  {behavior}: {count} ({percentage:.1f}%)")

    # Per-frame statistics. Each person can appear once per crop type, so
    # count distinct person ids rather than rows to avoid double counting.
    persons_per_frame = df_dataset.groupby('frame_id')['person_id'].nunique()

    print(f"\n👥 PERSONS PER FRAME:")
    person_count_dist = persons_per_frame.value_counts().sort_index()
    for persons, frames in person_count_dist.items():
        print(f"  {persons} persons: {frames} frames")

    # Visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Behavior distribution
    behavior_counts.plot(kind='bar', ax=axes[0,0], color='skyblue')
    axes[0,0].set_title('Behavior Distribution')
    axes[0,0].set_xlabel('Behavior')
    axes[0,0].set_ylabel('Count')
    axes[0,0].tick_params(axis='x', rotation=45)

    # Crop type distribution
    crop_type_counts.plot(kind='pie', ax=axes[0,1], autopct='%1.1f%%')
    axes[0,1].set_title('Crop Type Distribution')

    # Persons per frame
    person_count_dist.plot(kind='bar', ax=axes[1,0], color='lightgreen')
    axes[1,0].set_title('Persons per Frame Distribution')
    axes[1,0].set_xlabel('Number of Persons')
    axes[1,0].set_ylabel('Number of Frames')

    # Behavior by crop type
    behavior_crop = pd.crosstab(df_dataset['true_behavior'], df_dataset['crop_type'])
    behavior_crop.plot(kind='bar', ax=axes[1,1], width=0.8)
    axes[1,1].set_title('Behaviors by Crop Type')
    axes[1,1].set_xlabel('Behavior')
    axes[1,1].set_ylabel('Count')
    axes[1,1].tick_params(axis='x', rotation=45)
    axes[1,1].legend(title='Crop Type')

    plt.tight_layout()
    plt.show()

    return behavior_counts

def classify_with_qwen(image_path, model, processor, device, process_vision_info):
    """Classify pedestrian behavior in one image using Qwen.

    Args:
        image_path: Path of the image to classify.
        model: Loaded Qwen generation model.
        processor: Processor matching ``model`` (chat template + tensors).
        device: Device the processed inputs are moved to.
        process_vision_info: Helper from ``qwen_vl_utils`` that extracts the
            image/video inputs from the chat messages.

    Returns:
        One of the predefined behavior classes when the response contains it,
        otherwise the lower-cased raw response; ``"error"`` if any step fails.
    """
    behavior_classes = [
        "walking",
        "running",
        "pushing a stroller",
        "biking",
        "standing",
        "skateboarding"
    ]

    # The label set contains a multi-word option, so ask for one *option*
    # rather than one *word*.
    prompt = f"""Look at this image of a person and classify their behavior/activity. Choose from these options:
{', '.join(behavior_classes)}

Important: Respond with ONLY ONE option from the list above, exactly as written. Do not add any other text.

Classification:"""

    try:
        # Context manager closes the file handle once pixels are copied.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        # Prepare messages for Qwen
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply chat template
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision info
        image_inputs, video_inputs = process_vision_info(messages)

        # Process inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Greedy decoding; no temperature is passed because sampling is off.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False
            )

        # Drop the prompt tokens so only the newly generated ones are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Clean up response
        classification = response.strip().lower()

        # Map to predefined classes (first class found in the response wins)
        for behavior in behavior_classes:
            if behavior.lower() in classification:
                return behavior

        return classification

    except Exception as e:
        print(f"❌ Error classifying {image_path}: {e}")
        return "error"

def evaluate_qwen_model(df_dataset, model, processor, device, process_vision_info, crop_type='raw'):
    """Run Qwen over every image of one crop type and collect the outcomes.

    Args:
        df_dataset: Dataset frame with ``image_path``, ``true_behavior``,
            ``filename``, ``frame_id``, ``person_id`` and ``crop_type`` columns.
        model, processor, device, process_vision_info: Forwarded unchanged
            to ``classify_with_qwen``.
        crop_type: Value of the ``crop_type`` column to evaluate.

    Returns:
        ``(predictions, true_labels, detailed_results)``: two parallel lists
        of labels plus one dict per evaluated image.
    """
    print(f"\n🧠 EVALUATING QWEN MODEL ON {crop_type.upper()} CROPS")
    print("=" * 60)

    # Filter dataset by crop type
    df_eval = df_dataset[df_dataset['crop_type'] == crop_type].copy()
    total = len(df_eval)
    print(f"📊 Evaluating on {total} {crop_type} images")

    predictions = []
    true_labels = []
    detailed_results = []

    # Use a running position instead of the DataFrame index label: after
    # filtering, index labels no longer run from 0 to total-1.
    for position, (_, row) in enumerate(df_eval.iterrows(), start=1):
        image_path = row['image_path']
        true_behavior = row['true_behavior']

        print(f"🔄 Processing {position}/{total}: {row['filename']}")

        # Get prediction
        predicted_behavior = classify_with_qwen(
            image_path, model, processor, device, process_vision_info
        )
        is_correct = bool(predicted_behavior == true_behavior)

        predictions.append(predicted_behavior)
        true_labels.append(true_behavior)

        detailed_results.append({
            'filename': row['filename'],
            'frame_id': row['frame_id'],
            'person_id': row['person_id'],
            'true_behavior': true_behavior,
            'predicted_behavior': predicted_behavior,
            'correct': is_correct,
            'crop_type': crop_type
        })

        marker = '✅' if is_correct else '❌'
        print(f"   True: {true_behavior} | Predicted: {predicted_behavior} | {marker}")

        # Clean GPU memory every fifth image
        if torch.cuda.is_available() and position % 5 == 0:
            torch.cuda.empty_cache()

    return predictions, true_labels, detailed_results

def compute_metrics(true_labels, predictions, detailed_results, crop_type, model_name="Qwen"):
    """Compute, print and plot evaluation metrics for one model/crop type.

    Args:
        true_labels: List of ground-truth labels.
        predictions: List of predicted labels, parallel to ``true_labels``.
        detailed_results: Per-sample dicts containing at least
            ``true_behavior``, ``predicted_behavior`` and ``correct``.
        crop_type: Crop type name, used in headings and plot titles.
        model_name: Model name, used in headings and plot titles.

    Returns:
        Dict with overall scores (``accuracy``, macro/weighted precision,
        recall and F1), per-class arrays, ``support``, ``confusion_matrix``,
        ``labels`` and the unchanged ``detailed_results``. With no samples,
        scores are 0.0 and the arrays are empty.
    """
    print(f"\n📈 COMPUTING METRICS FOR {model_name} ({crop_type.upper()} CROPS)")
    print("=" * 60)

    if len(true_labels) == 0:
        # No samples for this crop type: skip scoring and plotting instead of
        # failing on undefined metrics or a division by zero.
        print("⚠️ No samples to evaluate; returning empty metrics")
        return {
            'accuracy': 0.0,
            'macro_precision': 0.0,
            'macro_recall': 0.0,
            'macro_f1': 0.0,
            'weighted_precision': 0.0,
            'weighted_recall': 0.0,
            'weighted_f1': 0.0,
            'per_class_precision': np.array([], dtype=float),
            'per_class_recall': np.array([], dtype=float),
            'per_class_f1': np.array([], dtype=float),
            'support': np.array([], dtype=int),
            'confusion_matrix': np.zeros((0, 0), dtype=int),
            'labels': [],
            'detailed_results': detailed_results
        }

    # Every label seen in either list, so prediction-only labels get a row too
    all_labels = sorted(list(set(true_labels + predictions)))

    # Basic metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        true_labels, predictions, labels=all_labels, average=None, zero_division=0
    )

    # Macro averages
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)

    # Weighted averages (same label set as the per-class call)
    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        true_labels, predictions, labels=all_labels, average='weighted', zero_division=0
    )

    print(f"🎯 OVERALL METRICS:")
    print(f"   Accuracy: {accuracy:.3f}")
    print(f"   Macro Precision: {macro_precision:.3f}")
    print(f"   Macro Recall: {macro_recall:.3f}")
    print(f"   Macro F1: {macro_f1:.3f}")
    print(f"   Weighted Precision: {weighted_precision:.3f}")
    print(f"   Weighted Recall: {weighted_recall:.3f}")
    print(f"   Weighted F1: {weighted_f1:.3f}")

    # Per-class metrics
    print(f"\n📊 PER-CLASS METRICS:")
    for i, label in enumerate(all_labels):
        print(f"   {label}:")
        print(f"     Precision: {precision[i]:.3f}")
        print(f"     Recall: {recall[i]:.3f}")
        print(f"     F1-Score: {f1[i]:.3f}")
        print(f"     Support: {support[i]}")

    # Confusion Matrix
    cm = confusion_matrix(true_labels, predictions, labels=all_labels)

    # Visualizations
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Confusion Matrix
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=all_labels, yticklabels=all_labels, ax=axes[0])
    axes[0].set_title(f'{model_name} Confusion Matrix ({crop_type} crops)')
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')

    # Metrics comparison
    metrics_df = pd.DataFrame({
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1
    }, index=all_labels)

    metrics_df.plot(kind='bar', ax=axes[1])
    axes[1].set_title(f'{model_name} Per-Class Metrics ({crop_type} crops)')
    axes[1].set_xlabel('Behavior')
    axes[1].set_ylabel('Score')
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].legend()
    axes[1].set_ylim(0, 1)

    plt.tight_layout()
    plt.show()

    # Classification Report
    print(f"\n📋 DETAILED CLASSIFICATION REPORT:")
    print(classification_report(true_labels, predictions, labels=all_labels, zero_division=0))

    # Error Analysis
    df_results = pd.DataFrame(detailed_results)
    errors = df_results[~df_results['correct']]

    print(f"\n❌ ERROR ANALYSIS:")
    print(f"   Total errors: {len(errors)}/{len(df_results)} ({len(errors)/len(df_results)*100:.1f}%)")

    if len(errors) > 0:
        print(f"\n   Most common error patterns:")
        error_patterns = errors.groupby(['true_behavior', 'predicted_behavior']).size().sort_values(ascending=False)
        for (true_b, pred_b), count in error_patterns.head(5).items():
            print(f"     {true_b} → {pred_b}: {count} times")

    return {
        'accuracy': accuracy,
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'macro_f1': macro_f1,
        'weighted_precision': weighted_precision,
        'weighted_recall': weighted_recall,
        'weighted_f1': weighted_f1,
        'per_class_precision': precision,
        'per_class_recall': recall,
        'per_class_f1': f1,
        'support': support,
        'confusion_matrix': cm,
        'labels': all_labels,
        'detailed_results': detailed_results
    }

def _to_jsonable(value):
    """Convert NumPy arrays and scalars to plain Python objects for json.dump."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


def main_qwen_evaluation(dataset_path='/kaggle/input/pedestrian-cropped-annot',
                         output_path='/kaggle/working/qwen_evaluation_results'):
    """Run the full Qwen evaluation and write the results to JSON.

    Loads the dataset, prints dataset statistics, loads the model, evaluates
    raw and extended crops, compares the two and writes everything to
    ``qwen_evaluation_complete_results.json`` inside ``output_path``.

    Args:
        dataset_path: Directory with the crops and annotation CSVs.
        output_path: Directory the results file is written to (created if
            missing).

    Returns:
        None. Returns early, after printing a message, when the dataset is
        empty or cannot be loaded, or when the model fails to load.
    """
    # Configuration (kept under the original names for readability)
    DATASET_PATH = dataset_path
    OUTPUT_PATH = output_path

    print("🧠 QWEN 2.5 VISION PEDESTRIAN BEHAVIOR EVALUATION")
    print("=" * 70)
    print(f"📁 Dataset path: {DATASET_PATH}")
    print(f"📁 Output path: {OUTPUT_PATH}")

    # Create output directory
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    # Load dataset
    try:
        df_dataset, annotations_dict = load_dataset_with_annotations(DATASET_PATH)
        if len(df_dataset) == 0:
            print("❌ No valid dataset entries found!")
            return
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return

    # Analyze dataset statistics
    behavior_counts = analyze_dataset_statistics(df_dataset)

    # Setup Qwen model
    model, processor, device, process_vision_info = setup_qwen_model()
    if model is None:
        print("❌ Failed to load Qwen model!")
        return

    # Evaluate on raw crops
    print(f"\n🚀 Starting evaluation on RAW crops...")
    predictions_raw, true_labels_raw, detailed_results_raw = evaluate_qwen_model(
        df_dataset, model, processor, device, process_vision_info, crop_type='raw'
    )

    metrics_raw = compute_metrics(
        true_labels_raw, predictions_raw, detailed_results_raw, 'raw', 'Qwen'
    )

    # Evaluate on extended crops
    print(f"\n🚀 Starting evaluation on EXTENDED crops...")
    predictions_ext, true_labels_ext, detailed_results_ext = evaluate_qwen_model(
        df_dataset, model, processor, device, process_vision_info, crop_type='extended'
    )

    metrics_ext = compute_metrics(
        true_labels_ext, predictions_ext, detailed_results_ext, 'extended', 'Qwen'
    )

    # Compare raw vs extended
    print(f"\n🔄 RAW vs EXTENDED CROPS COMPARISON:")
    print("=" * 50)
    comparison_metrics = ['accuracy', 'macro_f1', 'weighted_f1']
    for metric in comparison_metrics:
        raw_val = metrics_raw[metric]
        ext_val = metrics_ext[metric]
        improvement = ext_val - raw_val
        print(f"{metric.replace('_', ' ').title()}:")
        print(f"  Raw: {raw_val:.3f}")
        print(f"  Extended: {ext_val:.3f}")
        print(f"  Improvement: {improvement:+.3f} ({'✅' if improvement > 0 else '❌'})")

    # Distinct person ids per frame: a person appears once per crop type, so
    # counting rows would overstate the number of persons.
    persons_per_frame = df_dataset.groupby('frame_id')['person_id'].nunique()

    # Save comprehensive results
    final_results = {
        'evaluation_info': {
            'model': 'Qwen/Qwen2.5-VL-3B-Instruct',
            'dataset_path': DATASET_PATH,
            'total_samples': len(df_dataset),
            'raw_samples': len(df_dataset[df_dataset['crop_type'] == 'raw']),
            'extended_samples': len(df_dataset[df_dataset['crop_type'] == 'extended']),
            'evaluation_date': datetime.now().isoformat()
        },
        'dataset_statistics': {
            'behavior_distribution': {k: int(v) for k, v in behavior_counts.items()},
            'total_frames': int(df_dataset['frame_id'].nunique()),
            'avg_persons_per_frame': float(persons_per_frame.mean())
        },
        'raw_crops_results': {
            'metrics': {k: _to_jsonable(v)
                       for k, v in metrics_raw.items() if k != 'detailed_results'},
            'detailed_results': detailed_results_raw
        },
        'extended_crops_results': {
            'metrics': {k: _to_jsonable(v)
                       for k, v in metrics_ext.items() if k != 'detailed_results'},
            'detailed_results': detailed_results_ext
        }
    }

    # Save results
    results_file = os.path.join(OUTPUT_PATH, 'qwen_evaluation_complete_results.json')
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=2)

    print(f"\n🎉 QWEN EVALUATION COMPLETE!")
    print(f"📁 Results saved to: {results_file}")

    # Final cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("🧹 GPU memory cleaned up")

# Run the full Qwen evaluation only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main_qwen_evaluation()

# In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import json
from datetime import datetime
import torch
import gc
import subprocess
import sys
from collections import Counter, defaultdict
import glob

def _load_blip_bart_models(device, blip_model_id):
    """Import transformers lazily and load the BLIP captioner and BART classifier."""
    from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline

    # Load BLIP for image captioning
    print(f"📥 Loading BLIP: {blip_model_id}")
    blip_processor = BlipProcessor.from_pretrained(blip_model_id)
    blip_model = BlipForConditionalGeneration.from_pretrained(
        blip_model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True
    )

    if device == "cuda":
        blip_model = blip_model.to(device)

    # Load BART for classification
    print(f"📥 Loading BART classifier...")
    classifier = pipeline(
        "zero-shot-classification",
        model="facebook/bart-large-mnli",
        device=0 if device == "cuda" else -1
    )
    return blip_model, blip_processor, classifier


def setup_blip_bart_models():
    """Initialize BLIP for captioning and BART for classification.

    If the first load attempt fails, the required packages are pip-installed
    and the load is attempted once more.

    Returns:
        ``(blip_model, blip_processor, classifier, device)`` on success, or
        ``(None, None, None, None)`` when both attempts fail.
    """
    print("🤖 Setting up BLIP + BART models...")

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    # Resolved before the first attempt so the retry path can always use
    # them, even when the very first statement inside the try block fails.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    blip_model_id = "Salesforce/blip-image-captioning-large"
    print(f"🖥️ Using device: {device}")

    try:
        blip_model, blip_processor, classifier = _load_blip_bart_models(device, blip_model_id)
        print(f"✅ BLIP + BART models loaded successfully!")
        return blip_model, blip_processor, classifier, device

    except Exception as e:
        print(f"❌ Error loading models: {e}")
        print("🔄 Installing required packages...")

        try:
            # Inside the try so a failed pip run is reported, not raised.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "transformers>=4.45.0", "torch"],
                check=True,
            )
            blip_model, blip_processor, classifier = _load_blip_bart_models(device, blip_model_id)
            print(f"✅ Models loaded after installing dependencies!")
            return blip_model, blip_processor, classifier, device

        except Exception as e2:
            print(f"❌ Still failed: {e2}")
            return None, None, None, None

def _collect_crop_entries(image_paths, annotations_dict, crop_type):
    """Build dataset rows for one crop type, skipping images that cannot be labelled."""
    entries = []
    for img_path in image_paths:
        filename = os.path.basename(img_path)
        # Expected layout: FRAMEID_PERSONID_<suffix>.ext
        parts = filename.split('_')
        if len(parts) < 3:
            continue

        frame_id = parts[0]
        try:
            person_id = int(parts[1])
        except ValueError:
            # One badly named file must not abort the whole load.
            print(f"⚠️ Skipping {filename}: person id '{parts[1]}' is not an integer")
            continue

        behaviors = annotations_dict.get(frame_id)
        if behaviors is None:
            continue

        # person_id is 1-indexed; reject 0/negatives so they cannot wrap
        # around to the end of the list and pick up someone else's label.
        if not 1 <= person_id <= len(behaviors):
            continue

        label = behaviors[person_id - 1]
        if pd.isna(label):
            # Empty CSV cell: there is no ground truth for this person.
            continue

        entries.append({
            'image_path': img_path,
            'frame_id': frame_id,
            'person_id': person_id,
            'true_behavior': str(label).strip(),
            'crop_type': crop_type,
            'filename': filename
        })
    return entries


def load_dataset_with_annotations(dataset_path):
    """Load images and corresponding annotations (same as Qwen cell).

    Images are discovered with the glob patterns ``*_*_r.*`` (raw crops) and
    ``*_*_x.*`` (extended crops); annotations come from ``*_annotation.csv``
    files whose first row lists one behavior per person.

    Args:
        dataset_path: Directory containing the crops and annotation CSVs.

    Returns:
        ``(df_dataset, annotations_dict)`` where ``df_dataset`` has the columns
        ``image_path, frame_id, person_id, true_behavior, crop_type, filename``
        (present even when no entry was found) and ``annotations_dict`` maps a
        frame id to the list of behaviors read from its CSV.
    """
    print(f"📁 Loading dataset from: {dataset_path}")

    # Sorted so the row order is reproducible across runs and filesystems.
    raw_images = sorted(glob.glob(os.path.join(dataset_path, "*_*_r.*")))
    extended_images = sorted(glob.glob(os.path.join(dataset_path, "*_*_x.*")))
    annotation_files = sorted(glob.glob(os.path.join(dataset_path, "*_annotation.csv")))

    print(f"🖼️ Found {len(raw_images)} raw images")
    print(f"🖼️ Found {len(extended_images)} extended images") 
    print(f"📄 Found {len(annotation_files)} annotation files")

    # Process annotations
    annotations_dict = {}
    for ann_file in annotation_files:
        frame_id = os.path.basename(ann_file).replace('_annotation.csv', '')
        try:
            df = pd.read_csv(ann_file, header=None)
            annotations_dict[frame_id] = df.iloc[0].tolist()  # First row contains behaviors
        except Exception as e:
            # Unreadable or empty annotation file: report and keep going.
            print(f"⚠️ Error reading {ann_file}: {e}")

    # Raw entries first, then extended ones.
    dataset_entries = _collect_crop_entries(raw_images, annotations_dict, 'raw')
    dataset_entries += _collect_crop_entries(extended_images, annotations_dict, 'extended')

    # Explicit columns keep the frame usable downstream even when it is empty.
    df_dataset = pd.DataFrame(
        dataset_entries,
        columns=['image_path', 'frame_id', 'person_id', 'true_behavior', 'crop_type', 'filename'],
    )
    print(f"📊 Dataset loaded: {len(df_dataset)} total entries")

    return df_dataset, annotations_dict

def classify_with_blip_bart(image_path, blip_model, blip_processor, classifier, device):
    """Caption an image with BLIP, then classify the caption with BART.

    1. Generate caption using BLIP
    2. Classify caption using BART

    Args:
        image_path: Path of the image to classify.
        blip_model: Loaded BLIP captioning model.
        blip_processor: Processor matching ``blip_model``.
        classifier: Zero-shot classification pipeline.
        device: Device the BLIP inputs are moved to.

    Returns:
        ``(label, caption, score)`` for the top-ranked behavior class, or
        ``("error", <error message>, 0.0)`` if any step fails.
    """
    behavior_classes = [
        "walking",
        "running",
        "pushing a stroller",
        "biking",
        "standing",
        "skateboarding"
    ]

    try:
        # Context manager closes the file handle once pixels are copied.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        # Step 1: Generate caption with BLIP
        inputs = blip_processor(image, return_tensors="pt")
        # Floating-point inputs must match the model's dtype: the processor
        # emits float32, which a float16 model on GPU rejects.
        model_dtype = getattr(blip_model, "dtype", None)
        inputs = {
            name: (
                tensor.to(device=device, dtype=model_dtype)
                if model_dtype is not None and torch.is_floating_point(tensor)
                else tensor.to(device)
            )
            for name, tensor in inputs.items()
        }

        with torch.no_grad():
            generated_ids = blip_model.generate(**inputs, max_length=50, do_sample=False, num_beams=5)

        caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)

        # Step 2: Classify the caption using BART
        classification_result = classifier(caption, behavior_classes)

        # The pipeline ranks labels by score, best first. The candidates are
        # already the plain behavior names, so no label remapping is needed.
        best_label = classification_result['labels'][0]
        best_score = classification_result['scores'][0]

        return best_label, caption, best_score

    except Exception as e:
        print(f"❌ Error processing {image_path}: {e}")
        return "error", str(e), 0.0

def evaluate_blip_bart_model(df_dataset, blip_model, blip_processor, classifier, device, crop_type='raw'):
    """Run BLIP+BART over every image of one crop type and collect the outcomes.

    Args:
        df_dataset: Dataset frame with ``image_path``, ``true_behavior``,
            ``filename``, ``frame_id``, ``person_id`` and ``crop_type`` columns.
        blip_model, blip_processor, classifier, device: Forwarded unchanged
            to ``classify_with_blip_bart``.
        crop_type: Value of the ``crop_type`` column to evaluate.

    Returns:
        ``(predictions, true_labels, detailed_results)``: two parallel lists
        of labels plus one dict per evaluated image (including the caption
        and confidence).
    """
    print(f"\n🤖 EVALUATING BLIP+BART MODEL ON {crop_type.upper()} CROPS")
    print("=" * 60)

    # Filter dataset by crop type
    df_eval = df_dataset[df_dataset['crop_type'] == crop_type].copy()
    total = len(df_eval)
    print(f"📊 Evaluating on {total} {crop_type} images")

    predictions = []
    true_labels = []
    detailed_results = []

    # Use a running position instead of the DataFrame index label: after
    # filtering, index labels no longer run from 0 to total-1.
    for position, (_, row) in enumerate(df_eval.iterrows(), start=1):
        image_path = row['image_path']
        true_behavior = row['true_behavior']

        print(f"🔄 Processing {position}/{total}: {row['filename']}")

        # Get prediction
        predicted_behavior, caption, confidence = classify_with_blip_bart(
            image_path, blip_model, blip_processor, classifier, device
        )
        is_correct = bool(predicted_behavior == true_behavior)

        predictions.append(predicted_behavior)
        true_labels.append(true_behavior)

        detailed_results.append({
            'filename': row['filename'],
            'frame_id': row['frame_id'],
            'person_id': row['person_id'],
            'true_behavior': true_behavior,
            'predicted_behavior': predicted_behavior,
            'caption': caption,
            'confidence': float(confidence),
            'correct': is_correct,
            'crop_type': crop_type
        })

        marker = '✅' if is_correct else '❌'
        print(f"   Caption: '{caption}'")
        print(f"   True: {true_behavior} | Predicted: {predicted_behavior} | Conf: {confidence:.3f} | {marker}")

        # Clean GPU memory every fifth image
        if torch.cuda.is_available() and position % 5 == 0:
            torch.cuda.empty_cache()

    return predictions, true_labels, detailed_results

def compute_metrics_blip(true_labels, predictions, detailed_results, crop_type, model_name="BLIP+BART"):
    """Compute, print and plot evaluation metrics for one evaluation run.

    Args:
        true_labels: Sequence of ground-truth labels.
        predictions: Sequence of predicted labels, aligned with
            ``true_labels``.
        detailed_results: List of per-sample dicts. Each dict must provide
            the keys ``correct``, ``confidence``, ``true_behavior``,
            ``predicted_behavior`` and ``caption``.
        crop_type: Name of the evaluated crop type; used in headings and
            plot titles only.
        model_name: Display name used in headings and plot titles.

    Returns:
        A dict with the keys ``accuracy``, ``macro_*`` and ``weighted_*``
        (precision, recall, f1), ``per_class_precision``,
        ``per_class_recall``, ``per_class_f1``, ``support``,
        ``confusion_matrix``, ``labels`` and ``detailed_results``.
        When there are no samples, all scalar metrics are ``0.0``, the
        arrays are empty and nothing is plotted.
    """
    print(f"\n📈 COMPUTING METRICS FOR {model_name} ({crop_type.upper()} CROPS)")
    print("=" * 60)

    # Materialize as lists so tuples, arrays or Series are accepted as well.
    true_labels = list(true_labels)
    predictions = list(predictions)

    # An empty split (e.g. a crop type missing from the dataset) would
    # otherwise fail inside the metric functions and on the error-rate
    # division further down.
    if not true_labels:
        print("⚠️ No samples to evaluate - returning empty metrics")
        return {
            'accuracy': 0.0,
            'macro_precision': 0.0,
            'macro_recall': 0.0,
            'macro_f1': 0.0,
            'weighted_precision': 0.0,
            'weighted_recall': 0.0,
            'weighted_f1': 0.0,
            'per_class_precision': np.zeros(0),
            'per_class_recall': np.zeros(0),
            'per_class_f1': np.zeros(0),
            'support': np.zeros(0, dtype=int),
            'confusion_matrix': np.zeros((0, 0), dtype=int),
            'labels': [],
            'detailed_results': detailed_results
        }

    # Every label that occurs either as ground truth or as prediction
    all_labels = sorted(set(true_labels) | set(predictions))

    # Basic metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        true_labels, predictions, labels=all_labels, average=None, zero_division=0
    )

    # Macro averages
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)

    # Weighted averages
    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted', zero_division=0
    )

    print(f"🎯 OVERALL METRICS:")
    print(f"   Accuracy: {accuracy:.3f}")
    print(f"   Macro Precision: {macro_precision:.3f}")
    print(f"   Macro Recall: {macro_recall:.3f}")
    print(f"   Macro F1: {macro_f1:.3f}")
    print(f"   Weighted Precision: {weighted_precision:.3f}")
    print(f"   Weighted Recall: {weighted_recall:.3f}")
    print(f"   Weighted F1: {weighted_f1:.3f}")

    # Per-class metrics
    print(f"\n📊 PER-CLASS METRICS:")
    for i, label in enumerate(all_labels):
        print(f"   {label}:")
        print(f"     Precision: {precision[i]:.3f}")
        print(f"     Recall: {recall[i]:.3f}")
        print(f"     F1-Score: {f1[i]:.3f}")
        print(f"     Support: {support[i]}")

    # Confusion Matrix
    cm = confusion_matrix(true_labels, predictions, labels=all_labels)

    # Visualizations: confusion matrix on the left, per-class bars on the right
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=all_labels, yticklabels=all_labels, ax=axes[0])
    axes[0].set_title(f'{model_name} Confusion Matrix ({crop_type} crops)')
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')

    metrics_df = pd.DataFrame({
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1
    }, index=all_labels)

    metrics_df.plot(kind='bar', ax=axes[1])
    axes[1].set_title(f'{model_name} Per-Class Metrics ({crop_type} crops)')
    axes[1].set_xlabel('Behavior')
    axes[1].set_ylabel('Score')
    axes[1].tick_params(axis='x', rotation=45)
    axes[1].legend()
    axes[1].set_ylim(0, 1)

    plt.tight_layout()
    plt.show()

    # Classification Report
    print(f"\n📋 DETAILED CLASSIFICATION REPORT:")
    print(classification_report(true_labels, predictions, labels=all_labels, zero_division=0))

    # Error and confidence analysis share one DataFrame of per-sample results.
    df_results = pd.DataFrame(detailed_results)

    if df_results.empty:
        # Without per-sample records there is no 'correct' column to analyze.
        print(f"\n❌ ERROR ANALYSIS:")
        print("   No detailed results available")
    else:
        correct_samples = df_results[df_results['correct']]
        errors = df_results[~df_results['correct']]

        print(f"\n❌ ERROR ANALYSIS:")
        print(f"   Total errors: {len(errors)}/{len(df_results)} ({len(errors)/len(df_results)*100:.1f}%)")

        if len(errors) > 0:
            print(f"\n   Most common error patterns:")
            error_patterns = errors.groupby(['true_behavior', 'predicted_behavior']).size().sort_values(ascending=False)
            for (true_b, pred_b), count in error_patterns.head(5).items():
                print(f"     {true_b} → {pred_b}: {count} times")

            # Show some example captions for errors
            print(f"\n   Sample error captions:")
            for _, row in errors.head(3).iterrows():
                print(f"     {row['true_behavior']} → {row['predicted_behavior']}: '{row['caption']}'")

        print(f"\n🎯 CONFIDENCE ANALYSIS:")
        # The gap is only meaningful when both groups are non-empty.
        if len(correct_samples) > 0 and len(errors) > 0:
            avg_conf_correct = correct_samples['confidence'].mean()
            avg_conf_incorrect = errors['confidence'].mean()

            print(f"   Average confidence (correct): {avg_conf_correct:.3f}")
            print(f"   Average confidence (incorrect): {avg_conf_incorrect:.3f}")
            print(f"   Confidence gap: {avg_conf_correct - avg_conf_incorrect:.3f}")

    return {
        'accuracy': accuracy,
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'macro_f1': macro_f1,
        'weighted_precision': weighted_precision,
        'weighted_recall': weighted_recall,
        'weighted_f1': weighted_f1,
        'per_class_precision': precision,
        'per_class_recall': recall,
        'per_class_f1': f1,
        'support': support,
        'confusion_matrix': cm,
        'labels': all_labels,
        'detailed_results': detailed_results
    }

def analyze_dataset_statistics_simple(df_dataset):
    """Print a compact dataset overview and return the behavior counts.

    Reports the total number of rows, how many rows have ``crop_type``
    equal to ``'raw'`` and to ``'extended'``, and the frequency of every
    value in the ``true_behavior`` column.

    Args:
        df_dataset: DataFrame with the columns ``crop_type`` and
            ``true_behavior``.

    Returns:
        The ``value_counts()`` Series of the ``true_behavior`` column.
    """
    print("\n📊 DATASET OVERVIEW")
    print("=" * 30)

    crop_column = df_dataset['crop_type']
    overview = (
        ("Total entries", len(df_dataset)),
        ("Raw crops", int((crop_column == 'raw').sum())),
        ("Extended crops", int((crop_column == 'extended').sum())),
    )
    for title, amount in overview:
        print(f"{title}: {amount}")

    # Frequency of each behavior label, most frequent first
    label_frequencies = df_dataset['true_behavior'].value_counts()
    print("\n🎯 Behavior Distribution:")
    for label, frequency in zip(label_frequencies.index, label_frequencies.values):
        print(f"  {label}: {frequency}")

    return label_frequencies

def main_blip_bart_evaluation(
    dataset_path='/kaggle/input/pedestrian-cropped-annot',
    output_path='/kaggle/working/blip_bart_evaluation_results',
):
    """Run the complete BLIP+BART evaluation and save the results as JSON.

    Loads the dataset, prints basic statistics, sets up the models,
    evaluates the ``raw`` and the ``extended`` crops, prints a comparison
    of both runs and writes everything to
    ``blip_bart_evaluation_complete_results.json`` inside ``output_path``.
    The function returns early (after printing a message) when the dataset
    cannot be loaded, is empty, or the models fail to load.

    Args:
        dataset_path: Directory passed to ``load_dataset_with_annotations``.
        output_path: Directory for the result file; created if missing.

    Returns:
        None.
    """

    def _json_default(value):
        """Convert NumPy values that ``json`` cannot encode natively."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    print("🤖 BLIP + BART PEDESTRIAN BEHAVIOR EVALUATION")
    print("=" * 70)
    print(f"📁 Dataset path: {dataset_path}")
    print(f"📁 Output path: {output_path}")
    print(f"🔧 Method: BLIP captioning → BART zero-shot classification")

    # Create output directory
    os.makedirs(output_path, exist_ok=True)

    # Load dataset; the second return value is not needed here
    try:
        df_dataset, _ = load_dataset_with_annotations(dataset_path)
        if len(df_dataset) == 0:
            print("❌ No valid dataset entries found!")
            return
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return

    # Analyze dataset statistics (lighter version)
    behavior_counts = analyze_dataset_statistics_simple(df_dataset)

    # Setup BLIP+BART models
    blip_model, blip_processor, classifier, device = setup_blip_bart_models()
    if blip_model is None:
        print("❌ Failed to load BLIP+BART models!")
        return

    # Evaluate on raw crops
    print(f"\n🚀 Starting evaluation on RAW crops...")
    predictions_raw, true_labels_raw, detailed_results_raw = evaluate_blip_bart_model(
        df_dataset, blip_model, blip_processor, classifier, device, crop_type='raw'
    )

    metrics_raw = compute_metrics_blip(
        true_labels_raw, predictions_raw, detailed_results_raw, 'raw', 'BLIP+BART'
    )

    # Evaluate on extended crops
    print(f"\n🚀 Starting evaluation on EXTENDED crops...")
    predictions_ext, true_labels_ext, detailed_results_ext = evaluate_blip_bart_model(
        df_dataset, blip_model, blip_processor, classifier, device, crop_type='extended'
    )

    metrics_ext = compute_metrics_blip(
        true_labels_ext, predictions_ext, detailed_results_ext, 'extended', 'BLIP+BART'
    )

    # Compare raw vs extended
    print(f"\n🔄 RAW vs EXTENDED CROPS COMPARISON:")
    print("=" * 50)
    comparison_metrics = ['accuracy', 'macro_f1', 'weighted_f1']
    for metric in comparison_metrics:
        raw_val = metrics_raw[metric]
        ext_val = metrics_ext[metric]
        improvement = ext_val - raw_val
        print(f"{metric.replace('_', ' ').title()}:")
        print(f"  Raw: {raw_val:.3f}")
        print(f"  Extended: {ext_val:.3f}")
        print(f"  Improvement: {improvement:+.3f} ({'✅' if improvement > 0 else '❌'})")

    # Caption analysis
    print(f"\n📝 CAPTION ANALYSIS:")
    print("=" * 30)

    # Show the first few raw-crop captions together with their correctness
    df_raw = pd.DataFrame(detailed_results_raw)

    print(f"Sample captions from RAW crops:")
    for _, row in df_raw.head(3).iterrows():
        status = "✅" if row['correct'] else "❌"
        print(f"  {status} {row['true_behavior']}: '{row['caption']}'")

    # Save comprehensive results
    final_results = {
        'evaluation_info': {
            'captioning_model': 'Salesforce/blip-image-captioning-large',
            'classification_model': 'facebook/bart-large-mnli',
            'method': 'BLIP captioning + BART zero-shot classification',
            'dataset_path': dataset_path,
            'total_samples': len(df_dataset),
            'raw_samples': len(df_dataset[df_dataset['crop_type'] == 'raw']),
            'extended_samples': len(df_dataset[df_dataset['crop_type'] == 'extended']),
            'evaluation_date': datetime.now().isoformat()
        },
        'dataset_statistics': {
            'behavior_distribution': behavior_counts.to_dict(),
            'total_frames': df_dataset['frame_id'].nunique(),
            'avg_persons_per_frame': df_dataset.groupby('frame_id')['person_id'].count().mean()
        },
        'raw_crops_results': {
            'metrics': {k: v.tolist() if isinstance(v, np.ndarray) else v
                       for k, v in metrics_raw.items() if k != 'detailed_results'},
            'detailed_results': detailed_results_raw
        },
        'extended_crops_results': {
            'metrics': {k: v.tolist() if isinstance(v, np.ndarray) else v
                       for k, v in metrics_ext.items() if k != 'detailed_results'},
            'detailed_results': detailed_results_ext
        }
    }

    # Save results. The default hook converts stray NumPy scalars (e.g. ids
    # or flags in the per-sample records) so that a finished evaluation is
    # not lost to a serialization error.
    results_file = os.path.join(output_path, 'blip_bart_evaluation_complete_results.json')
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=2, default=_json_default)

    print(f"\n🎉 BLIP+BART EVALUATION COMPLETE!")
    print(f"📁 Results saved to: {results_file}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("🧹 GPU memory cleaned up")
# Entry point: run the BLIP+BART evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main_blip_bart_evaluation()