# 03 - EDA and Temporal Splits

## Goal

Run EDA to build trust in the data: class balance, year drift, and label co-occurrence. Then create temporal splits (train ≤ 2021, val 2022–2023, test ≥ 2024) to prevent time leakage.


## Why This Step Matters

**Trust in data** comes from understanding it:

- **Class balance:** Are some labels extremely rare?
- **Year trends:** Are study designs changing over time?
- **Co-occurrence:** Do certain labels always appear together?
- **Temporal splits:** Splitting by year prevents leakage (information from later papers influencing predictions about earlier ones).

Without EDA, you're training blind.
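
The year-trend question can be made concrete with a per-year label share table. The sketch below is a minimal illustration on made-up data: it assumes a `year` column and a list-valued `labels` column like the ones used later in this notebook, and the label names (`RCT`, `Review`, `Cohort`) and the helper `label_share_by_year` are hypothetical.

```python
import pandas as pd

# Toy data standing in for the real DataFrame; values and label names are made up.
toy = pd.DataFrame({
    "year": [2020, 2020, 2021, 2021, 2022],
    "labels": [["RCT"], ["RCT", "Review"], ["Review"], ["Cohort"], ["RCT"]],
})

def label_share_by_year(frame: pd.DataFrame) -> pd.DataFrame:
    """Return the fraction of papers per year that carry each label."""
    exploded = frame.explode("labels")              # one row per (paper, label)
    counts = pd.crosstab(exploded["year"], exploded["labels"])
    papers_per_year = frame.groupby("year").size()  # denominator: papers, not labels
    return counts.div(papers_per_year, axis=0)

share = label_share_by_year(toy)
print(share.round(2))
```

A label whose share rises or falls steadily across years is a sign of drift, which is worth knowing before the model is evaluated on later years only.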


In [None]:
# TODO: Import libraries
# Hint: from pathlib import Path
#
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd
# import pandera as pa
# import seaborn as sns
#
# sns.set_style('whitegrid')


In [None]:
# TODO: Load processed parquet
# Hint: df = pd.read_parquet('../data/processed/dental_abstracts.parquet')
#       print(f"Loaded {len(df)} papers")


## Basic Counts

Let's understand the dataset size and temporal distribution.


In [None]:
# TODO: Basic counts
# Hint: print(f"Total papers: {len(df)}")
#       print("\nYear distribution:")
#       print(df['year'].value_counts().sort_index())
#       print("\nTop 10 journals:")
#       print(df['journal'].value_counts().head(10))


## Class Balance

Which labels are common? Which are rare?
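
Raw counts are easier to judge as a fraction of papers. The following is a small sketch, not the notebook's required solution: `label_prevalence` and `RARE_THRESHOLD` are hypothetical names, the 0.3 cutoff is arbitrary, and the label lists are made up.

```python
from collections import Counter

def label_prevalence(label_lists):
    """Map each label to the fraction of papers that carry it."""
    n_papers = len(label_lists)
    # set(labels) guards against a label repeated within one paper
    counts = Counter(label for labels in label_lists for label in set(labels))
    return {label: count / n_papers for label, count in counts.items()}

# Made-up label lists; the label names are placeholders.
toy_labels = [["RCT"], ["RCT", "Review"], ["RCT"], ["Cohort"]]
prevalence = label_prevalence(toy_labels)

RARE_THRESHOLD = 0.3  # hypothetical cutoff, tune it to the real data
rare = sorted(label for label, frac in prevalence.items() if frac < RARE_THRESHOLD)
print(prevalence)
print("Rare labels:", rare)
```

On the real data the same function could be called as `label_prevalence(df['labels'].tolist())`.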


In [None]:
# TODO: Class balance barplot
# Hint: from collections import Counter
# all_labels = [label for labels in df['labels'] for label in labels]
# label_counts = Counter(all_labels)
# 
# plt.figure(figsize=(12, 6))
# labels_sorted = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)
# plt.barh([l[0] for l in labels_sorted], [l[1] for l in labels_sorted])
# plt.xlabel('Count')
# plt.title('Label Distribution (Multi-label)')
# plt.tight_layout()
# plt.show()


## Label Co-occurrence

Do certain labels always appear together?
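
Raw co-occurrence counts do not directly answer whether one label always accompanies another, since frequent labels co-occur with everything. One option is to normalize each row by the label's own count, giving an estimate of P(label j | label i). The sketch below assumes a binary paper-by-label matrix; `conditional_cooccurrence` is a hypothetical helper and the matrix is made up.

```python
import numpy as np

def conditional_cooccurrence(label_matrix: np.ndarray) -> np.ndarray:
    """Entry [i, j] estimates P(label j present | label i present)."""
    co_occur = label_matrix.T @ label_matrix        # diagonal holds per-label counts
    label_counts = np.diag(co_occur).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        cond = co_occur / label_counts[:, None]     # divide each row by its label's count
    return np.nan_to_num(cond)                      # labels never seen give rows of 0

# Rows are papers, columns are labels (A, B, C); values are made up.
toy_matrix = np.array([
    [1, 1, 0],
    [1, 1, 0],
    [0, 1, 1],
    [0, 0, 1],
])
cond = conditional_cooccurrence(toy_matrix)
print(cond.round(2))
```

The result is asymmetric: here every paper with label A also has B (`cond[0, 1]` is 1.0), while only two of the three B papers have A.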


In [None]:
# TODO: Label co-occurrence heatmap
# Hint: build a binary paper-by-label matrix
# unique_labels = sorted(label_counts.keys())
# label_to_col = {label: j for j, label in enumerate(unique_labels)}  # O(1) column lookup
# label_matrix = np.zeros((len(df), len(unique_labels)))
# for i, labels in enumerate(df['labels']):
#     for label in labels:
#         label_matrix[i, label_to_col[label]] = 1
# 
# # Co-occurrence matrix
# co_occur = label_matrix.T @ label_matrix
# np.fill_diagonal(co_occur, 0)  # Zero out diagonal
# 
# plt.figure(figsize=(10, 8))
# sns.heatmap(co_occur, xticklabels=unique_labels, yticklabels=unique_labels, annot=True, fmt='.0f', cmap='Blues')
# plt.title('Label Co-occurrence Matrix')
# plt.tight_layout()
# plt.show()


## Temporal Splits

**Critical:** Split by year to prevent temporal leakage.

- **Train:** ≤ 2021 (~60-70% of data)
- **Val:** 2022–2023 (~15-20%)
- **Test:** ≥ 2024 (~15-20%)

This mimics real-world deployment: predicting future papers based on past patterns.
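
Once a `split` column has been assigned, the no-leakage property can be checked directly by comparing the year range of each split. This is a minimal sketch on made-up data; `check_temporal_order` is a hypothetical helper, and it assumes the split names `train`, `val`, and `test`.

```python
import pandas as pd

def check_temporal_order(frame: pd.DataFrame) -> None:
    """Raise AssertionError if a split contains years that reach into a later split."""
    bounds = frame.groupby("split")["year"].agg(["min", "max"])
    assert bounds.loc["train", "max"] < bounds.loc["val", "min"], "train overlaps val"
    assert bounds.loc["val", "max"] < bounds.loc["test", "min"], "val overlaps test"

# Made-up years and split assignments.
toy = pd.DataFrame({
    "year": [2019, 2021, 2022, 2023, 2024],
    "split": ["train", "train", "val", "val", "test"],
})
check_temporal_order(toy)
print("No year overlap between splits")
```

A random split would typically fail this check, which is the point: the test set should only contain papers published after everything the model was trained on.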


In [None]:
# TODO: Decide splits
# Hint: def assign_split(year):
#     if year <= 2021:
#         return 'train'
#     elif year <= 2023:
#         return 'val'
#     else:
#         return 'test'
# 
# df['split'] = df['year'].apply(assign_split)
# print(df['split'].value_counts())


In [None]:
# TODO: Save split parquets
# Hint: for split_name in ['train', 'val', 'test']:
#     split_df = df[df['split'] == split_name]
#     split_df.to_parquet(f'../data/processed/{split_name}.parquet', index=False)
#     print(f"Saved {split_name}.parquet ({len(split_df)} papers)")


## Schema Validation

Use Pandera to validate data quality before training.


In [None]:
# TODO: Schema validation (pandera)
# Hint: schema = pa.DataFrameSchema({
#     'pmid': pa.Column(str, nullable=False),
#     'title': pa.Column(str, nullable=False, checks=pa.Check.str_length(min_value=1)),
#     'abstract': pa.Column(str, nullable=False, checks=pa.Check.str_length(min_value=10)),
#     'year': pa.Column(int, checks=pa.Check.in_range(2018, 2025)),
#     # element_wise=True so the lambda sees each paper's label list, not the whole column
#     'labels': pa.Column(object, checks=pa.Check(lambda x: len(x) > 0, element_wise=True)),
#     'split': pa.Column(str, checks=pa.Check.isin(['train', 'val', 'test']))
# })
# 
# # Validate
# try:
#     schema.validate(df)
#     print("✅ Schema validation passed!")
# except pa.errors.SchemaError as e:
#     print(f"❌ Schema validation failed: {e}")


## Recommendations

- **Revisit year cutoffs** if the val or test split is too small (aim for at least 1000 papers each)
- **For severe imbalance:** Consider stratified sampling within year windows (stretch goal)
- **Document decisions:** Why these splits? What assumptions are we making?
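
The size recommendation above can be turned into a quick check. This is a sketch under stated assumptions: `undersized_splits` and `MIN_PAPERS` are hypothetical names, the frame is made up, and the 1000-paper default is the target from the first recommendation.

```python
import pandas as pd

MIN_PAPERS = 1000  # target from the recommendation above

def undersized_splits(frame: pd.DataFrame, min_papers: int = MIN_PAPERS) -> dict:
    """Return {split_name: size} for val/test splits smaller than min_papers."""
    sizes = frame["split"].value_counts()
    return {
        name: int(sizes.get(name, 0))   # a missing split counts as size 0
        for name in ("val", "test")
        if sizes.get(name, 0) < min_papers
    }

# Made-up split assignments; a tiny threshold keeps the toy example readable.
toy = pd.DataFrame({"split": ["train"] * 5 + ["val"] * 3 + ["test"] * 1})
print(undersized_splits(toy, min_papers=2))
```

A non-empty result suggests moving a year cutoff earlier so that more papers fall into the flagged split.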

## 🧘 Reflection Log

**What did you learn in this session?**
- 

**What challenges did you encounter?**
- 

**How will this improve Periospot AI?**
- 
