In [5]:
pip install pandas matplotlib seaborn tqdm

Note: you may need to restart the kernel to use updated packages.


In [6]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tqdm import tqdm

# PubMedQA Labeled Dataset
def load_pubmedqa_labeled(path="/Users/casey/Documents/GitHub/LLM_Healthcare/ori_pqal.json"):
    with open(path, "r") as f:
        data = json.load(f)
    return pd.DataFrame(data)

In [7]:
# Overview of the dataset
def basic_summary(df):
    print("✅ Dataset Dimensions:", df.shape)
    print("✅ Columns:", list(df.columns))
    print("\n✅ Sample Row:")
    print(df.iloc[0])
    print("\n✅ Null Values:")
    print(df.isnull().sum())


In [8]:
# Class Distribution Analysis
def plot_label_distribution(df):
    label_counts = df['final_answer'].value_counts()
    sns.barplot(x=label_counts.index, y=label_counts.values)
    plt.title("Answer Label Distribution")
    plt.xlabel("Final Answer")
    plt.ylabel("Count")
    plt.show()

In [9]:
# Text Length Analysis
def plot_text_lengths(df):
    df['question_len'] = df['question'].apply(lambda x: len(x.split()))
    df['context_len'] = df['context'].apply(lambda x: len(x.split()))
    df['answer_len'] = df['long_answer'].apply(lambda x: len(x.split()))

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    sns.histplot(df['question_len'], ax=axes[0], kde=True)
    axes[0].set_title("Question Length")

    sns.histplot(df['context_len'], ax=axes[1], kde=True)
    axes[1].set_title("Context Length")

    sns.histplot(df['answer_len'], ax=axes[2], kde=True)
    axes[2].set_title("Long Answer Length")
    
    plt.tight_layout()
    plt.show()

In [10]:
# Checking for Duplicates
def check_duplicates(df):
    total = len(df)
    unique_questions = df['question'].nunique()
    print(f"🔍 Unique questions: {unique_questions} / {total}")
    if total != unique_questions:
        duplicates = df[df.duplicated(['question'], keep=False)]
        print(f"⚠️ {len(duplicates)} duplicate entries found.")
        return duplicates
    else:
        print("✅ No duplicate questions found.")
        return None

In [11]:
# Sample Display
def print_random_sample(df, n=3):
    samples = df.sample(n)
    for _, row in samples.iterrows():
        print("\n📌 --------")
        print(f"Question: {row['question']}")
        print(f"Context: {row['context'][:300]}...")  # Displaying first 300 chars
        print(f"Long Answer: {row['long_answer']}")
        print(f"Final Answer: {row['final_answer']}")

In [13]:
# Full EDA Pipeline
def run_eda(path="/Users/casey/Documents/GitHub/LLM_Healthcare/ori_pqal.json"):
    print("📥 Loading dataset...")
    df = load_pubmedqa_labeled(path)
    
    print("\n📊 Basic Summary:")
    basic_summary(df)

    print("\n📊 Label Distribution:")
    plot_label_distribution(df)

    print("\n📊 Text Length Distributions:")
    plot_text_lengths(df)

    print("\n📊 Duplicate Check:")
    duplicates = check_duplicates(df)

    print("\n🔎 Sample Entries:")
    print_random_sample(df)

    return df

if __name__ == "__main__":
    # Run the EDA pipeline
    run_eda()

📥 Loading dataset...

📊 Basic Summary:
✅ Dataset Dimensions: (9, 1000)
✅ Columns: ['21645374', '16418930', '9488747', '17208539', '10808977', '23831910', '26037986', '26852225', '17113061', '10966337', '25432938', '18847643', '18239988', '25957366', '24866606', '26578404', '11729377', '17096624', '22694248', '22990761', '19394934', '11481599', '21669959', '23806388', '17919952', '10966943', '23690198', '17940352', '20537205', '28707539', '7482275', '24183388', '9645785', '26298839', '24153338', '18534072', '8847047', '10575390', '20084845', '15703931', '18269157', '25489696', '14599616', '22537902', '19054501', '16432652', '19504993', '20571467', '24237112', '21402341', '20082356', '26606599', '11340218', '25481573', '25277731', '25475395', '23177368', '17179167', '14612308', '27491658', '19822586', '27643961', '23539689', '22453060', '22227642', '12380309', '22186742', '22188074', '27989969', '18607272', '18235194', '18049437', '15597845', '24996865', '23361217', '17489316', '14518645

KeyError: 'final_answer'