# PubmedQA Artificial Set (PQA-A) Analysis Script

In [1]:
#!pip install spacy

In [2]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
from collections import Counter
from datetime import datetime
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.util import ngrams
import textstat
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from gensim import corpora
from gensim.models import LdaModel
from gensim.utils import simple_preprocess
from wordcloud import WordCloud
import spacy
from transformers import AutoTokenizer
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Fetch the NLTK resources this notebook relies on.
# 'stopwords' backs the stopwords.words('english') call used for topic modeling.
# 'punkt' is presumably what word_tokenize / sent_tokenize need — TODO confirm for the installed NLTK version.
# NOTE(review): 'wordnet' is not referenced by the code visible here — confirm it is still required.
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to /Users/casey/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /Users/casey/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /Users/casey/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [4]:
# Global plotting configuration applied before any figure is created.
# NOTE(review): sns.set() may re-apply seaborn's own default style on top of the
# matplotlib style selected on the previous line — confirm the resulting look is intended.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set(font_scale=1.2)

# Seed NumPy's global RNG so the np.random.choice samples drawn by the analyses are repeatable
np.random.seed(42)

In [5]:
class PubMedQAAnalyzer:
    def __init__(self, data_path):
        """
        Initialize the PubMedQA analyzer with dataset path
        
        Args:
            data_path (str): Path to the PubMedQA artificial dataset (PQA-A) JSON file
        """
        self.data_path = data_path
        self.df = None
        self.load_data()
        
    def load_data(self):
        """Load the PubMedQA artificial dataset (PQA-A) and convert to DataFrame"""
        print(f"Loading data from {self.data_path}...")
        
        # Load JSON data
        with open(self.data_path, 'r') as f:
            data = json.load(f)
        
        # Convert to DataFrame
        rows = []
        for pmid, item in data.items():
            row = {
                'pmid': pmid,
                'question': item.get('question', ''),
                'context': ' '.join(item.get('context', {}).get('contexts', [])),
                'abstract': item.get('abstract', []),
                'year': self._extract_year(item),
                'final_decision': item.get('final_decision', ''),  # PQA-A has final_decision instead of long_answer
                'mesh_terms': item.get('mesh', [])
            }
            rows.append(row)
            
        self.df = pd.DataFrame(rows)
        self.df['context_length'] = self.df['context'].apply(len)
        self.df['question_length'] = self.df['question'].apply(len)
        self.df['abstract_text'] = self.df['abstract'].apply(lambda x: ' '.join(x) if isinstance(x, list) else '')
        self.df['abstract_length'] = self.df['abstract_text'].apply(len)
        
        # Process final_decision field which is specific to PQA-A
        if 'final_decision' in self.df.columns:
            # Ensure final_decision is standardized
            self.df['final_decision'] = self.df['final_decision'].str.lower()
            
            # Create binary and categorical versions
            self.df['is_yes'] = self.df['final_decision'] == 'yes'
            self.df['is_no'] = self.df['final_decision'] == 'no'
            self.df['is_maybe'] = self.df['final_decision'] == 'maybe'
        
        print(f"Loaded {len(self.df)} question-article pairs from PQA-A dataset")
        
    def _extract_year(self, item):
        """Extract publication year from item metadata"""
        try:
            if 'pubmed' in item and 'content' in item['pubmed']:
                if 'PubmedArticle' in item['pubmed']['content']:
                    article = item['pubmed']['content']['PubmedArticle']
                    if 'MedlineCitation' in article:
                        if 'DateCompleted' in article['MedlineCitation']:
                            return int(article['MedlineCitation']['DateCompleted']['Year'])
                        elif 'Article' in article['MedlineCitation'] and 'Journal' in article['MedlineCitation']['Article']:
                            if 'JournalIssue' in article['MedlineCitation']['Article']['Journal']:
                                if 'PubDate' in article['MedlineCitation']['Article']['Journal']['JournalIssue']:
                                    return int(article['MedlineCitation']['Article']['Journal']['JournalIssue']['PubDate']['Year'])
        except (KeyError, TypeError):
            pass
        
        # Try to extract year from PMID (may not be accurate)
        try:
            pmid_int = int(item.get('pmid', '0'))
            if 1 <= pmid_int <= 100:  # Very early PMIDs
                return 1975  # Approximate
            elif 100 <= pmid_int <= 1000000:  # Old PMIDs
                return 1985  # Approximate
            elif 1000000 <= pmid_int <= 10000000:  # Mid PMIDs
                return 1995  # Approximate
            elif 10000000 <= pmid_int <= 20000000:
                return 2005  # Approximate
            elif 20000000 <= pmid_int <= 30000000:
                return 2015  # Approximate
            else:
                return 2020  # Recent
        except (ValueError, TypeError):
            return None
    
    def dataset_overview(self, output_dir='results'):
        """Produce the overview tables and distribution figure for the dataset.

        Writes ``dataset_overview.csv``, ``dataset_distributions.png``,
        ``descriptive_statistics.csv`` and ``word_count_statistics.csv`` into
        ``output_dir`` and adds a ``question_word_count`` column to ``self.df``.

        Args:
            output_dir (str): Directory receiving the generated files; created if missing.

        Returns:
            dict: Metric name -> value, the same content written to ``dataset_overview.csv``.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        print("Generating dataset overview...")

        frame = self.df
        question_len = frame['question_length']
        context_len = frame['context_length']

        # Headline numbers; insertion order is the row order of the CSV
        stats = {
            'Total QA pairs': len(frame),
            'Unique PMIDs': frame['pmid'].nunique(),
            'Avg question length (chars)': question_len.mean(),
            'Avg context length (chars)': context_len.mean(),
            'Avg abstract length (chars)': frame['abstract_length'].mean(),
            'Min question length': question_len.min(),
            'Max question length': question_len.max(),
            'Min context length': context_len.min(),
            'Max context length': context_len.max()
        }

        # Span of publication years covered by the data
        first_year = frame['year'].min()
        last_year = frame['year'].max()
        stats['Publication years range'] = f"{first_year} - {last_year}"

        # Persist the summary table
        summary_table = pd.DataFrame(list(stats.items()), columns=['Metric', 'Value'])
        summary_table.to_csv(f"{output_dir}/dataset_overview.csv", index=False)

        # 2x2 grid: three character-length histograms plus the year bar chart
        fig, axs = plt.subplots(2, 2, figsize=(16, 14))
        histogram_panels = (
            ('question_length', axs[0, 0], 'Question Length Distribution (characters)'),
            ('context_length', axs[0, 1], 'Context Length Distribution (characters)'),
            ('abstract_length', axs[1, 0], 'Abstract Length Distribution (characters)'),
        )
        for column, panel, title in histogram_panels:
            sns.histplot(frame[column], kde=True, ax=panel)
            panel.set_title(title)
            panel.set_xlabel('Length (characters)')

        year_panel = axs[1, 1]
        per_year = frame['year'].value_counts().sort_index()
        per_year.plot(kind='bar', ax=year_panel)
        year_panel.set_title('Publication Year Distribution')
        year_panel.set_xlabel('Year')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/dataset_distributions.png", dpi=300)

        # Descriptive statistics of the three character-length columns
        length_columns = ['question_length', 'context_length', 'abstract_length']
        frame[length_columns].describe().to_csv(f"{output_dir}/descriptive_statistics.csv")

        def count_words(text):
            return len(word_tokenize(text))

        # Word counts: every question, but only a random sample of contexts for efficiency
        frame['question_word_count'] = frame['question'].apply(count_words)
        question_summary = frame['question_word_count'].describe()

        n_sampled = min(1000, len(frame))
        picked = np.random.choice(len(frame), n_sampled, replace=False)
        sampled_contexts = frame.iloc[picked]['context']
        context_summary = pd.Series([count_words(text) for text in sampled_contexts]).describe()

        # Side-by-side word count statistics
        word_stats = pd.DataFrame({
            'Question Word Count': question_summary,
            'Context Word Count (Sample)': context_summary
        })
        word_stats.to_csv(f"{output_dir}/word_count_statistics.csv")

        print("Dataset overview analysis completed")
        return stats
    
    def distribution_analysis(self, output_dir='results'):
        """Generate data distribution analysis.

        Builds an interactive 2x2 plotly figure (question word counts, sampled
        context word counts, publication years, top 20 MeSH terms) and a
        feature-correlation heatmap, writing ``distribution_analysis.html``,
        ``distribution_analysis.png``, ``feature_correlations.png`` and
        ``feature_correlations.csv`` into ``output_dir``.

        Args:
            output_dir (str): Directory receiving the generated files; created if missing.
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating data distribution analysis...")

        # 'question_word_count' is otherwise only created by dataset_overview();
        # derive it here so this method also works when called on its own.
        if 'question_word_count' not in self.df.columns:
            self.df['question_word_count'] = self.df['question'].apply(lambda x: len(word_tokenize(x)))

        # Create a figure with multiple subplots
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Question Length Distribution (Words)',
                           'Context Length Distribution (Words)',
                           'Publication Year Distribution',
                           'MeSH Terms Distribution (Top 20)')
        )

        # Question word count distribution
        fig.add_trace(
            go.Histogram(x=self.df['question_word_count'], nbinsx=30, opacity=0.7,
                         marker=dict(color='royalblue')),
            row=1, col=1
        )

        # Context length distribution (sampled, since tokenizing every context is slow)
        sample_size = min(1000, len(self.df))
        sample_indices = np.random.choice(len(self.df), sample_size, replace=False)
        context_word_counts = [len(word_tokenize(text)) for text in self.df.iloc[sample_indices]['context']]

        fig.add_trace(
            go.Histogram(x=context_word_counts, nbinsx=30, opacity=0.7,
                        marker=dict(color='green')),
            row=1, col=2
        )

        # Publication year distribution
        year_counts = self.df['year'].value_counts().sort_index()
        years = list(year_counts.index.astype(str))
        counts = list(year_counts.values)

        fig.add_trace(
            go.Bar(x=years, y=counts, marker=dict(color='purple')),
            row=2, col=1
        )

        # MeSH terms distribution (records whose mesh_terms is not a list are ignored)
        mesh_counter = Counter()
        for terms in self.df['mesh_terms']:
            if isinstance(terms, list):
                mesh_counter.update(terms)

        top_terms = mesh_counter.most_common(20)
        term_labels = [label for label, _ in top_terms]
        term_counts = [count for _, count in top_terms]

        fig.add_trace(
            go.Bar(y=term_labels, x=term_counts, marker=dict(color='orangered'), orientation='h'),
            row=2, col=2
        )

        # Update layout
        fig.update_layout(
            height=800,
            width=1200,
            showlegend=False,
        )

        # Save the figure
        # NOTE(review): static image export in plotly depends on an extra rendering
        # engine (Kaleido) being installed — confirm it is available where this runs.
        fig.write_html(f"{output_dir}/distribution_analysis.html")
        fig.write_image(f"{output_dir}/distribution_analysis.png")

        # Calculate correlations between features
        correlations = self.df[['question_length', 'context_length', 'abstract_length', 'question_word_count', 'year']].corr()
        plt.figure(figsize=(10, 8))
        sns.heatmap(correlations, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
        plt.title('Feature Correlations')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/feature_correlations.png", dpi=300)

        # Save correlation matrix
        correlations.to_csv(f"{output_dir}/feature_correlations.csv")

        print("Data distribution analysis completed")
    
    def temporal_analysis(self, output_dir='results'):
        """Analyze how dataset characteristics change over time.

        Aggregates article counts and average question/context lengths per
        publication year, writes ``yearly_statistics.csv`` and
        ``temporal_analysis.png``, and — when at least five years have 50 or
        more articles — fits a polynomial trend to the average question length
        (``question_length_trend.png``, ``question_length_trend_analysis.csv``).

        Rows without a year are excluded from this analysis only; ``self.df``
        keeps all of its rows.

        Args:
            output_dir (str): Directory receiving the generated files; created if missing.
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating temporal analysis...")

        # 'question_word_count' is otherwise only created by dataset_overview();
        # derive it here so this method also works when called on its own.
        if 'question_word_count' not in self.df.columns:
            self.df['question_word_count'] = self.df['question'].apply(lambda x: len(word_tokenize(x)))

        # Work on a filtered view rather than overwriting self.df, so records
        # without a year stay available to the other analyses.
        dated = self.df.dropna(subset=['year'])

        # Group by year and calculate statistics
        yearly_stats = dated.groupby('year').agg({
            'pmid': 'count',
            'question_length': 'mean',
            'context_length': 'mean',
            'question_word_count': 'mean'
        }).reset_index()

        yearly_stats.columns = ['Year', 'Count', 'Avg Question Length',
                                'Avg Context Length', 'Avg Question Word Count']

        # Save statistics to CSV
        yearly_stats.to_csv(f"{output_dir}/yearly_statistics.csv", index=False)

        # Create temporal plots
        fig, axs = plt.subplots(2, 2, figsize=(16, 12))

        # Number of articles per year
        axs[0, 0].bar(yearly_stats['Year'], yearly_stats['Count'], color='skyblue')
        axs[0, 0].set_title('Number of Articles by Year')
        axs[0, 0].set_xlabel('Year')
        axs[0, 0].set_ylabel('Count')

        # The three line charts differ only in column, colour and labels
        line_panels = (
            (axs[0, 1], 'Avg Question Length', 'green',
             'Average Question Length by Year', 'Average Question Length (chars)'),
            (axs[1, 0], 'Avg Context Length', 'purple',
             'Average Context Length by Year', 'Average Context Length (chars)'),
            (axs[1, 1], 'Avg Question Word Count', 'red',
             'Average Question Word Count by Year', 'Average Word Count'),
        )
        for ax, column, color, title, ylabel in line_panels:
            ax.plot(yearly_stats['Year'], yearly_stats[column],
                    marker='o', color=color, linestyle='-')
            ax.set_title(title)
            ax.set_xlabel('Year')
            ax.set_ylabel(ylabel)

        plt.tight_layout()
        plt.savefig(f"{output_dir}/temporal_analysis.png", dpi=300)

        # Advanced temporal analysis with trend detection
        # Filter to years with sufficient data points
        min_year_count = 50  # Minimum number of articles per year to include
        filtered_years = yearly_stats[yearly_stats['Count'] >= min_year_count]

        if len(filtered_years) >= 5:  # At least 5 years with sufficient data
            # groupby() returns the years in ascending order, so index 0 is the
            # earliest year and index -1 the latest.
            years = filtered_years['Year'].values
            q_lengths = filtered_years['Avg Question Length'].values
            year_span = years.max() - years.min()

            # Normalize years to [0, 1] for better numerical stability
            years_norm = (years - years.min()) / year_span

            # Fit polynomial regression
            degree = min(3, len(years) - 1)  # Up to cubic regression depending on data points
            coeffs = np.polyfit(years_norm, q_lengths, degree)
            poly = np.poly1d(coeffs)

            # Generate trend line and map its x values back to calendar years
            years_seq = np.linspace(years_norm.min(), years_norm.max(), 100)
            trend = poly(years_seq)
            years_seq = years_seq * year_span + years.min()

            # Plot trend
            plt.figure(figsize=(10, 6))
            plt.scatter(years, q_lengths, color='blue', label='Actual data')
            plt.plot(years_seq, trend, color='red', label=f'Trend (degree {degree})')
            plt.title('Question Length Trend Analysis')
            plt.xlabel('Year')
            plt.ylabel('Average Question Length')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.savefig(f"{output_dir}/question_length_trend.png", dpi=300)

            # Overall direction comes from the slope of a straight-line fit.
            # np.polyfit returns the highest power first, so coeffs[0] of the
            # polynomial above is the leading (e.g. cubic) coefficient and does
            # not describe whether the series rises or falls overall.
            linear_slope = np.polyfit(years_norm, q_lengths, 1)[0]
            trend_direction = "increasing" if linear_slope > 0 else "decreasing"

            # Divide by elapsed years, not by the number of retained data points:
            # years dropped by the count filter leave gaps in the sequence.
            avg_annual_change = (q_lengths[-1] - q_lengths[0]) / year_span

            # Save trend analysis results
            trend_results = pd.DataFrame({
                'Metric': ['Trend direction', 'Average annual change', 'Polynomial degree'],
                'Value': [trend_direction, avg_annual_change, degree]
            })
            trend_results.to_csv(f"{output_dir}/question_length_trend_analysis.csv", index=False)

        print("Temporal analysis completed")
    
    def topic_domain_analysis(self, output_dir='results', n_topics=10):
        """Analyze topics and domains in the dataset.

        Part 1 counts MeSH terms and writes ``top_mesh_terms.csv`` plus, when any
        terms exist, ``top_mesh_terms.png`` and ``mesh_wordcloud.png``.
        Part 2 runs LDA topic modeling on a sample of up to 2000 non-empty
        abstracts (only when more than 100 are available) and writes
        ``lda_topics.csv``, ``lda_topics_visualized.png`` and
        ``doc_topic_distribution.png``.

        Args:
            output_dir (str): Directory receiving the generated files; created if missing.
            n_topics (int): Number of LDA topics to fit. Only the first 10 are drawn
                in the per-topic bar chart figure.
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating topic/domain analysis...")

        # MeSH terms analysis (records whose mesh_terms is not a list are ignored)
        mesh_counts = Counter()
        for terms in self.df['mesh_terms']:
            if isinstance(terms, list):
                mesh_counts.update(terms)

        top_mesh = mesh_counts.most_common(30)

        # Save MeSH terms data
        mesh_df = pd.DataFrame(top_mesh, columns=['MeSH Term', 'Count'])
        mesh_df.to_csv(f"{output_dir}/top_mesh_terms.csv", index=False)

        if mesh_counts:
            # Plot top MeSH terms
            plt.figure(figsize=(12, 8))
            sns.barplot(x='Count', y='MeSH Term', data=mesh_df.head(20))
            plt.title('Top 20 MeSH Terms')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/top_mesh_terms.png", dpi=300)

            # Create word cloud of MeSH terms
            wordcloud = WordCloud(width=800, height=400, background_color='white',
                                  max_words=100, contour_width=3, contour_color='steelblue')
            wordcloud.generate_from_frequencies(mesh_counts)

            plt.figure(figsize=(12, 8))
            plt.imshow(wordcloud, interpolation='bilinear')
            plt.axis('off')
            plt.title('MeSH Terms Word Cloud')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/mesh_wordcloud.png", dpi=300)
        else:
            # Nothing to draw; skip the plots rather than feeding an empty
            # frequency table to the word cloud generator.
            print("No MeSH terms found; skipping MeSH plots")

        # Topic modeling using abstracts
        print("Performing topic modeling...")

        # Filter out empty abstracts. Missing abstracts are stored as '' (not NaN)
        # in abstract_text, so dropna() alone would keep them.
        abstracts = [text for text in self.df['abstract_text'].dropna() if text.strip()]

        if len(abstracts) > 100:  # Only perform if we have sufficient data
            # Sample abstracts for efficiency
            sample_size = min(2000, len(abstracts))
            abstract_sample = np.random.choice(abstracts, sample_size, replace=False)

            # Preprocess text: tokenize, lowercase, remove stopwords and short words
            stop_words = set(stopwords.words('english'))
            processed_abstracts = []
            for abstract in abstract_sample:
                tokens = word_tokenize(abstract.lower())
                kept = [token for token in tokens
                        if token.isalpha() and token not in stop_words and len(token) > 3]
                processed_abstracts.append(' '.join(kept))

            # Create TF-IDF matrix
            vectorizer = TfidfVectorizer(max_features=5000, min_df=5, max_df=0.7)
            tfidf_matrix = vectorizer.fit_transform(processed_abstracts)

            # LDA Topic Modeling
            lda = LatentDirichletAllocation(n_components=n_topics, random_state=42, max_iter=20)
            lda.fit(tfidf_matrix)

            # Get feature names
            feature_names = vectorizer.get_feature_names_out()

            # For each topic keep its top words together with their own weights,
            # so the chart below never mixes one topic's labels with another
            # topic's indices.
            n_top_words = 15
            topics = []
            for topic_idx, topic in enumerate(lda.components_):
                top_words_idx = topic.argsort()[:-n_top_words - 1:-1]
                top_words = [feature_names[i] for i in top_words_idx]
                topics.append((topic_idx, top_words, topic[top_words_idx]))

            # Save topic words
            topic_df = pd.DataFrame([(idx, ', '.join(words)) for idx, words, _ in topics],
                                   columns=['Topic ID', 'Top Words'])
            topic_df.to_csv(f"{output_dir}/lda_topics.csv", index=False)

            # Visualize topics (the grid has room for the first 10 only)
            fig, axes = plt.subplots(5, 2, figsize=(16, 20))
            axes = axes.flatten()

            for ax, (topic_idx, top_words, weights) in zip(axes, topics):
                positions = range(len(top_words))
                ax.barh(positions, weights, align='center')
                ax.set_yticks(positions)
                ax.set_yticklabels(top_words)
                ax.invert_yaxis()
                ax.set_title(f'Topic {topic_idx+1}')

            # Blank out panels that received no topic (n_topics < 10)
            for ax in axes[len(topics):]:
                ax.axis('off')

            plt.tight_layout()
            plt.savefig(f"{output_dir}/lda_topics_visualized.png", dpi=300)

            # Document-Topic Distribution
            doc_topic_distr = lda.transform(tfidf_matrix)

            # Create a heatmap of document-topic distribution for a sample
            sample_docs = min(50, doc_topic_distr.shape[0])
            sample_indices = np.random.choice(doc_topic_distr.shape[0], sample_docs, replace=False)

            plt.figure(figsize=(12, 10))
            sns.heatmap(doc_topic_distr[sample_indices], cmap='YlGnBu',
                       xticklabels=[f'Topic {i+1}' for i in range(n_topics)])
            plt.title('Document-Topic Distribution (Sample)')
            plt.ylabel('Document')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/doc_topic_distribution.png", dpi=300)

        print("Topic/domain analysis completed")
    
    def answer_distribution_analysis(self, output_dir='results'):
        """Analyze the distribution of answers in the PQA-A dataset.

        Writes the overall answer counts/percentages (CSV, bar chart, pie chart),
        the per-year answer proportions and the answer proportions per
        context-length bin into ``output_dir``. Adds a ``context_length_bin``
        column to ``self.df``.

        Args:
            output_dir (str): Directory receiving the generated files; created if missing.
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating answer distribution analysis...")

        # Check if we have final_decision column (specific to PQA-A)
        if 'final_decision' not in self.df.columns:
            print("No final_decision column found. This analysis is specific to PQA-A dataset.")
            return

        # Get answer distribution
        answer_counts = self.df['final_decision'].value_counts()
        answer_percentages = 100 * answer_counts / len(self.df)

        # Save to CSV
        answer_df = pd.DataFrame({
            'Answer': answer_counts.index,
            'Count': answer_counts.values,
            'Percentage': answer_percentages.values
        })
        answer_df.to_csv(f"{output_dir}/answer_distribution.csv", index=False)

        # Plot answer distribution
        plt.figure(figsize=(10, 6))
        sns.barplot(x='Answer', y='Count', data=answer_df, palette='viridis')
        plt.title('Answer Distribution in PQA-A Dataset')
        plt.ylabel('Count')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/answer_distribution.png", dpi=300)

        # Plot as pie chart
        plt.figure(figsize=(10, 8))
        plt.pie(answer_counts, labels=answer_counts.index, autopct='%1.1f%%',
                startangle=90, colors=sns.color_palette('viridis', len(answer_counts)))
        plt.axis('equal')
        plt.title('Answer Distribution in PQA-A Dataset')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/answer_distribution_pie.png", dpi=300)

        # Analyze temporal trends in answers if we have year data
        if 'year' in self.df.columns:
            # Group by year and calculate percentage of each answer type
            yearly_answers = pd.crosstab(self.df['year'], self.df['final_decision'], normalize='index')
            yearly_answers.to_csv(f"{output_dir}/yearly_answer_distribution.csv")

            # DataFrame.plot opens a figure of its own unless it is given an axes,
            # so the axes is created explicitly to make the figsize take effect
            # and to avoid leaving an empty figure behind.
            fig, ax = plt.subplots(figsize=(14, 8))
            yearly_answers.plot(kind='line', marker='o', ax=ax)
            ax.set_title('Answer Distribution Trends Over Time')
            ax.set_xlabel('Year')
            ax.set_ylabel('Proportion of Answers')
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            fig.savefig(f"{output_dir}/answer_trends.png", dpi=300)

        # Analyze relationship between context length and answer
        bins = [0, 500, 1000, 2000, 3000, 5000, 10000, float('inf')]
        labels = ['0-500', '501-1000', '1001-2000', '2001-3000', '3001-5000', '5001-10000', '10000+']

        # include_lowest=True puts zero-length contexts into '0-500'; with the
        # default right-closed bins a length of exactly 0 would match no bin.
        self.df['context_length_bin'] = pd.cut(self.df['context_length'], bins=bins, labels=labels,
                                               include_lowest=True)

        # Create cross-tabulation
        context_vs_answer = pd.crosstab(self.df['context_length_bin'], self.df['final_decision'], normalize='index')
        context_vs_answer.to_csv(f"{output_dir}/context_length_vs_answer.csv")

        # Plot relationship (explicit axes for the same reason as above)
        fig, ax = plt.subplots(figsize=(14, 8))
        context_vs_answer.plot(kind='bar', stacked=True, ax=ax)
        ax.set_title('Answer Distribution by Context Length')
        ax.set_xlabel('Context Length (characters)')
        ax.set_ylabel('Proportion')
        ax.legend(title='Answer')
        fig.tight_layout()
        fig.savefig(f"{output_dir}/context_length_vs_answer.png", dpi=300)

        print("Answer distribution analysis completed")
    
    def text_complexity_analysis(self, output_dir='results', sample_size=1000):
        """Analyze text complexity of questions and contexts.

        On a random sample of the dataset this computes textstat readability
        scores for the questions, sentence/word length statistics for up to 200
        contexts, and entity counts for up to 50 of those contexts, writing the
        CSV tables and figures into ``output_dir``.

        Args:
            output_dir (str): Directory receiving the generated files; created if missing.
            sample_size (int): Maximum number of records sampled from ``self.df``.
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating text complexity analysis...")

        # Sample data for efficiency
        df_sample = self.df.sample(min(sample_size, len(self.df)), random_state=42)

        # Calculate readability metrics for questions
        print("Calculating readability metrics for questions...")
        metric_functions = {
            'flesch_reading_ease': textstat.flesch_reading_ease,
            'flesch_kincaid_grade': textstat.flesch_kincaid_grade,
            'smog_index': textstat.smog_index,
            'automated_readability_index': textstat.automated_readability_index,
            'coleman_liau_index': textstat.coleman_liau_index
        }

        readability_rows = []
        for q in df_sample['question']:
            try:
                # Score the question with every metric before recording anything,
                # so a failure part-way through cannot leave the columns with
                # different lengths.
                scores = {name: score(q) for name, score in metric_functions.items()}
            except Exception:
                # Best effort: skip questions textstat cannot score (e.g. too short)
                continue
            readability_rows.append(scores)

        # Create DataFrame for readability metrics
        metrics_df = pd.DataFrame(readability_rows, columns=list(metric_functions))
        metrics_df.dropna(inplace=True)  # Remove any NaN values

        # Calculate descriptive statistics for readability metrics
        read_stats = metrics_df.describe()
        read_stats.to_csv(f"{output_dir}/question_readability_metrics.csv")

        # Plot readability distributions
        plt.figure(figsize=(12, 8))
        metrics_df.boxplot()
        plt.title('Question Readability Metrics Distribution')
        plt.grid(False)
        plt.tight_layout()
        plt.savefig(f"{output_dir}/question_readability_boxplot.png", dpi=300)

        # Sample contexts for analysis
        context_sample = df_sample['context'].sample(min(200, len(df_sample)), random_state=42)

        print("Analyzing sentence complexity...")
        # Sentence complexity analysis
        sentence_lengths = []
        word_lengths = []

        for text in context_sample:
            for sentence in sent_tokenize(text):
                words = word_tokenize(sentence)
                if words:  # Only if there are words
                    sentence_lengths.append(len(words))
                    word_lengths.extend(len(word) for word in words if word.isalpha())

        # Summarize each list on its own: there are far more words than sentences,
        # and truncating the word list to the sentence count would base the word
        # statistics on only the first few contexts.
        sentence_stats = pd.DataFrame({
            'Sentence Lengths': pd.Series(sentence_lengths, dtype=float).describe(),
            'Word Lengths': pd.Series(word_lengths, dtype=float).describe()
        })
        sentence_stats.to_csv(f"{output_dir}/sentence_complexity.csv")

        # Plot sentence and word length distributions
        fig, axs = plt.subplots(1, 2, figsize=(16, 6))

        sns.histplot(sentence_lengths, kde=True, ax=axs[0])
        axs[0].set_title('Sentence Length Distribution (words)')
        axs[0].set_xlabel('Words per Sentence')

        sns.histplot(word_lengths, kde=True, ax=axs[1])
        axs[1].set_title('Word Length Distribution (characters)')
        axs[1].set_xlabel('Characters per Word')

        plt.tight_layout()
        plt.savefig(f"{output_dir}/sentence_word_lengths.png", dpi=300)

        # Technical terminology analysis
        print("Analyzing medical terminology density...")

        # Use spaCy to identify entities.
        # NOTE(review): `nlp` is not defined in this method; it is expected to be a
        # spaCy pipeline created at module level — confirm it exists before this runs.
        medical_entity_counts = []
        entity_types = Counter()

        for text in context_sample.iloc[:50]:  # Limit for processing time
            doc = nlp(text)
            medical_entity_counts.append(len(doc.ents))

            # Count entity types
            entity_types.update(ent.label_ for ent in doc.ents)

        # Entity type distribution
        entity_type_df = pd.DataFrame(list(entity_types.items()), columns=['Entity Type', 'Count'])
        entity_type_df = entity_type_df.sort_values('Count', ascending=False)
        entity_type_df.to_csv(f"{output_dir}/entity_type_distribution.csv", index=False)

        # Plot entity type distribution
        plt.figure(figsize=(12, 8))
        sns.barplot(x='Count', y='Entity Type', data=entity_type_df.head(15))
        plt.title('Top 15 Entity Types')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/entity_types.png", dpi=300)

        # Medical terminology density
        plt.figure(figsize=(10, 6))
        sns.histplot(medical_entity_counts, kde=True)
        plt.title('Medical Entity Count Distribution')
        plt.xlabel('Entities per Context')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/medical_entity_counts.png", dpi=300)

        print("Text complexity analysis completed")
        
    def classwise_feature_analysis(self, output_dir='results'):
        """Analyze features across different answer classes in the PQA-A dataset.

        Compares question/context size features between the values of
        ``final_decision``, runs a one-way ANOVA per feature and, when the
        dataset has at least 1000 rows, fits a random forest on the same
        features to gauge how predictive they are of the answer class.

        Files written to ``output_dir``:
            class_feature_statistics.csv, class_feature_boxplots.png,
            feature_anova_results.csv and, for >= 1000 rows,
            classification_report.csv, confusion_matrix.png,
            feature_importance.csv, feature_importance.png.

        Args:
            output_dir (str): Directory receiving the outputs; created
                (including parents) when it does not exist.

        Returns:
            None. Returns early, after printing a notice, when ``self.df``
            has no ``final_decision`` column.
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating class-wise feature analysis...")

        # Check if we have final_decision column (specific to PQA-A)
        if 'final_decision' not in self.df.columns:
            print("No final_decision column found. This analysis is specific to PQA-A dataset.")
            return

        # Create a feature matrix for analysis
        feature_df = self.df[['question_length', 'context_length', 'question_word_count',
                              'final_decision']].copy()

        # The text metrics below are comparatively slow, so they are computed
        # on a random sample only; rows outside the sample stay NaN.
        sample_size = min(1000, len(self.df))
        # Draw index *labels* (not positions): the sample is consumed through
        # .loc, which is label-based and would otherwise select the wrong rows
        # (or raise KeyError) whenever the index is not the default 0..n-1.
        sampled_indices = np.random.choice(self.df.index.to_numpy(), sample_size, replace=False)
        sampled_questions = self.df.loc[sampled_indices, 'question']

        def mean_alpha_word_length(question):
            """Return the mean length of the alphabetic tokens, or 0 if there are none."""
            lengths = [len(token) for token in word_tokenize(question) if token.isalpha()]
            # Guard on the filtered list: np.mean([]) would yield NaN plus a warning.
            return np.mean(lengths) if lengths else 0

        # Average word length in questions
        feature_df.loc[sampled_indices, 'avg_word_length'] = sampled_questions.apply(
            mean_alpha_word_length
        )

        # Calculate readability of questions for the sample; questions of 50
        # characters or fewer are left as NaN.
        feature_df.loc[sampled_indices, 'flesch_kincaid_grade'] = sampled_questions.apply(
            lambda q: textstat.flesch_kincaid_grade(q) if len(q) > 50 else np.nan
        )

        # Group statistics by answer class
        summary_stats = ['mean', 'median', 'std']
        class_stats = feature_df.groupby('final_decision').agg({
            'question_length': summary_stats,
            'context_length': summary_stats,
            'question_word_count': summary_stats,
            'avg_word_length': summary_stats,
            'flesch_kincaid_grade': summary_stats
        })

        # Save class statistics
        class_stats.to_csv(f"{output_dir}/class_feature_statistics.csv")

        # Plot class-wise feature comparisons
        features_to_plot = ['question_length', 'context_length', 'question_word_count']

        # Create boxplots for each feature by class
        fig, axes = plt.subplots(len(features_to_plot), 1, figsize=(12, 4*len(features_to_plot)))

        for ax, feature in zip(axes, features_to_plot):
            sns.boxplot(x='final_decision', y=feature, data=feature_df, ax=ax)
            ax.set_title(f'{feature} by Answer Class')
            ax.set_xlabel('Answer')

        fig.tight_layout()
        fig.savefig(f"{output_dir}/class_feature_boxplots.png", dpi=300)

        # Statistical significance tests
        print("Performing statistical significance tests...")

        from scipy import stats

        # Perform ANOVA for each feature
        answer_classes = feature_df['final_decision'].unique()
        anova_results = {}
        for feature in features_to_plot:
            # Get data for each class
            groups = [feature_df[feature_df['final_decision'] == cls][feature].dropna()
                      for cls in answer_classes]

            # Perform ANOVA. This stays best-effort (a failing feature is
            # recorded as NaN), but no longer swallows KeyboardInterrupt /
            # SystemExit and reports why the test could not be run.
            try:
                f_stat, p_value = stats.f_oneway(*groups)
            except Exception as exc:
                print(f"ANOVA failed for {feature}: {exc}")
                anova_results[feature] = {
                    'f_statistic': np.nan,
                    'p_value': np.nan,
                    'significant': False
                }
            else:
                anova_results[feature] = {
                    'f_statistic': f_stat,
                    'p_value': p_value,
                    'significant': p_value < 0.05
                }

        # Save ANOVA results
        anova_df = pd.DataFrame(anova_results).T
        anova_df.to_csv(f"{output_dir}/feature_anova_results.csv")

        # Model predictiveness analysis
        if len(feature_df) >= 1000:  # Only if we have enough data
            print("Analyzing feature predictiveness for answer classification...")

            from sklearn.model_selection import train_test_split
            from sklearn.ensemble import RandomForestClassifier
            from sklearn.metrics import classification_report, confusion_matrix
            from sklearn.preprocessing import StandardScaler

            # Prepare features and target
            X = feature_df[features_to_plot].dropna()
            y = feature_df.loc[X.index, 'final_decision']

            # Train-test split
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

            # Scale features (the scaler is fitted on the training split only)
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            # Train classifier
            clf = RandomForestClassifier(n_estimators=100, random_state=42)
            clf.fit(X_train_scaled, y_train)

            # Evaluate
            y_pred = clf.predict(X_test_scaled)

            # Save classification report
            report = classification_report(y_test, y_pred, output_dict=True)
            report_df = pd.DataFrame(report).T
            report_df.to_csv(f"{output_dir}/classification_report.csv")

            # Plot confusion matrix. The matrix rows/columns and the tick
            # labels are driven by one explicit label list; without it the
            # matrix only covers labels present in y_test/y_pred, which can
            # differ from clf.classes_ and mislabel the heatmap axes.
            cm_labels = np.unique(np.concatenate([clf.classes_, np.asarray(y_test)]))
            cm = confusion_matrix(y_test, y_pred, labels=cm_labels)
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=cm_labels, yticklabels=cm_labels)
            plt.title('Confusion Matrix')
            plt.xlabel('Predicted')
            plt.ylabel('Actual')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/confusion_matrix.png", dpi=300)

            # Feature importance
            feature_imp = pd.DataFrame({
                'Feature': features_to_plot,
                'Importance': clf.feature_importances_
            }).sort_values('Importance', ascending=False)

            feature_imp.to_csv(f"{output_dir}/feature_importance.csv", index=False)

            # Plot feature importance
            plt.figure(figsize=(10, 6))
            sns.barplot(x='Importance', y='Feature', data=feature_imp)
            plt.title('Feature Importance for Answer Classification')
            plt.tight_layout()
            plt.savefig(f"{output_dir}/feature_importance.png", dpi=300)

        print("Class-wise feature analysis completed")


In [6]:
def question_type_analysis(self, output_dir='results'):
    """Analyze question types and patterns.

    Classifies every question in ``self.df`` by its opening word, relates
    question type and question length to ``final_decision`` (when that
    column exists) and counts the entities found in a sample of questions.

    Side effects on ``self.df``: the columns ``question_type`` and
    ``question_word_count`` are (re)written, and ``word_count_bin`` is added
    when ``final_decision`` is present.

    Args:
        output_dir (str): Directory receiving the CSV and PNG outputs;
            created (including parents) when it does not exist.

    Returns:
        None. Results are written to ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Generating question type analysis...")

    # Openers are matched with their trailing space so that e.g. "whose ..."
    # is not counted as a "who" question.
    yes_no_openers = ('is ', 'are ', 'does ', 'do ', 'can ', 'could ')
    wh_openers = ('what', 'how', 'why', 'when', 'where', 'which', 'who')

    # Function to classify question type
    def classify_question(q):
        """Return a coarse question type based on the first word of ``q``."""
        q = q.lower().strip()

        # Basic question classification
        if q.startswith(yes_no_openers):
            return 'Yes/No'
        for opener in wh_openers:
            if q.startswith(opener + ' '):
                return opener.capitalize()
        return 'Other'

    # Classify each question
    self.df['question_type'] = self.df['question'].apply(classify_question)

    # Get question type distribution
    q_type_counts = self.df['question_type'].value_counts()

    # Save results
    q_type_df = pd.DataFrame(q_type_counts).reset_index()
    q_type_df.columns = ['Question Type', 'Count']
    q_type_df.to_csv(f"{output_dir}/question_type_distribution.csv", index=False)

    # Plot question type distribution
    plt.figure(figsize=(12, 8))
    sns.barplot(x='Count', y='Question Type', data=q_type_df)
    plt.title('Question Type Distribution')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/question_type_distribution.png", dpi=300)

    # Analyze question complexity (token count; punctuation tokens included)
    self.df['question_word_count'] = self.df['question'].apply(lambda x: len(word_tokenize(x)))

    # Plot question complexity by type
    plt.figure(figsize=(14, 8))
    sns.boxplot(x='question_type', y='question_word_count', data=self.df)
    plt.title('Question Word Count by Question Type')
    plt.xlabel('Question Type')
    plt.ylabel('Word Count')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f"{output_dir}/question_complexity_by_type.png", dpi=300)

    # If we have final_decision data (specific to PQA-A dataset)
    if 'final_decision' in self.df.columns:
        # Analyze relationship between question type and answer
        q_type_answer = pd.crosstab(self.df['question_type'], self.df['final_decision'])
        q_type_answer.to_csv(f"{output_dir}/question_type_vs_answer.csv")

        # Normalize for percentage
        q_type_answer_pct = pd.crosstab(self.df['question_type'], self.df['final_decision'], normalize='index')
        q_type_answer_pct.to_csv(f"{output_dir}/question_type_vs_answer_pct.csv")

        # Plot relationship. The axes are passed explicitly: DataFrame.plot
        # opens its own default-sized figure when no `ax` is given, which
        # would leave the figure created here empty and ignore its size.
        fig, ax = plt.subplots(figsize=(14, 10))
        q_type_answer_pct.plot(kind='bar', stacked=True, colormap='viridis', ax=ax)
        ax.set_title('Answer Distribution by Question Type')
        ax.set_xlabel('Question Type')
        ax.set_ylabel('Percentage')
        ax.legend(title='Answer')
        fig.tight_layout()
        fig.savefig(f"{output_dir}/answer_by_question_type.png", dpi=300)

        # Analyze question complexity impact on answer. Bins are
        # right-inclusive; the last one is open-ended so that arbitrarily
        # long questions still land in '30+'.
        word_count_bins = [0, 5, 10, 15, 20, 25, 30, float('inf')]
        bin_labels = ['1-5', '6-10', '11-15', '16-20', '21-25', '26-30', '30+']

        self.df['word_count_bin'] = pd.cut(self.df['question_word_count'],
                                           bins=word_count_bins,
                                           labels=bin_labels)

        complexity_vs_answer = pd.crosstab(self.df['word_count_bin'], self.df['final_decision'], normalize='index')
        complexity_vs_answer.to_csv(f"{output_dir}/complexity_vs_answer.csv")

        # Plot relationship (explicit axes for the same reason as above)
        fig, ax = plt.subplots(figsize=(14, 8))
        complexity_vs_answer.plot(kind='bar', stacked=True, colormap='viridis', ax=ax)
        ax.set_title('Answer Distribution by Question Complexity')
        ax.set_xlabel('Question Word Count')
        ax.set_ylabel('Percentage')
        ax.legend(title='Answer')
        fig.tight_layout()
        fig.savefig(f"{output_dir}/answer_by_complexity.png", dpi=300)

    # Extract key medical entities from questions
    print("Extracting medical entities from questions...")

    # Sample for efficiency
    sample_size = min(1000, len(self.df))
    sampled_questions = self.df['question'].sample(sample_size, random_state=42)

    # NOTE(review): `nlp` is not defined in this function; presumably a
    # spaCy pipeline loaded elsewhere in the notebook -- confirm it is in
    # scope before this runs.
    # Count entity occurrences (case-insensitive on the entity text)
    entity_counter = Counter()
    for question in sampled_questions:
        doc = nlp(question)
        entity_counter.update(ent.text.lower() for ent in doc.ents)
    top_entities = entity_counter.most_common(30)

    # Save top entities
    entity_df = pd.DataFrame(top_entities, columns=['Entity', 'Count'])
    entity_df.to_csv(f"{output_dir}/top_question_entities.csv", index=False)

    # Plot top entities
    plt.figure(figsize=(14, 10))
    sns.barplot(x='Count', y='Entity', data=entity_df.head(20))
    plt.title('Top 20 Medical Entities in Questions')
    plt.tight_layout()
    plt.savefig(f"{output_dir}/top_question_entities.png", dpi=300)

    print("Question type analysis completed")