# PubMed RCT - Data Exploration

This notebook explores the PubMed 20k RCT dataset for sentence classification in medical abstracts.

**Goal:** Understand data structure, class distribution, and sentence length statistics.

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

## Data Loading

The PubMed RCT dataset contains medical abstracts where each sentence is labeled with one of 5 categories:
- BACKGROUND, OBJECTIVE, METHODS, RESULTS, CONCLUSIONS

File format:
```
###ABSTRACT_ID
LABEL\tSENTENCE_TEXT
...
(blank line between abstracts)
```

In [None]:
def load_pubmed_data(filepath):
    """Load and preprocess PubMed RCT data from a text file.

    Expected layout: a header line starting with ``###``, then one
    ``LABEL<TAB>SENTENCE`` line per sentence, then a blank line that closes
    the abstract.

    An abstract is emitted when its closing blank line, the next ``###``
    header, or the end of the file is reached, so a file that lacks the
    final blank line still yields its last abstract. The buffer is cleared
    after every emission, so repeated blank lines never emit an abstract
    twice.

    Args:
        filepath: Path to the UTF-8 encoded text file.

    Returns:
        A list of dicts, one per sentence, with keys:
            target: the label (the field before the tab).
            text: the sentence, lower-cased.
            line_number: 0-based position of the line inside its abstract.
            total_lines: number of lines in the abstract minus one.

        Lines that do not consist of exactly two tab-separated fields are
        skipped, but they still count toward ``line_number`` and
        ``total_lines``.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    samples = []
    abstract_lines = []  # raw lines of the abstract currently being read

    def flush_abstract():
        """Turn the buffered abstract into samples and empty the buffer."""
        split = "".join(abstract_lines).splitlines()
        for i, al in enumerate(split):
            parts = al.split("\t")
            if len(parts) == 2:
                samples.append({
                    "target": parts[0],
                    "text": parts[1].lower(),
                    "line_number": i,
                    "total_lines": len(split) - 1
                })
        # Clearing here is what prevents duplicates on consecutive blank lines.
        abstract_lines.clear()

    with open(filepath, "r", encoding="utf-8") as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            if line.startswith("###"):
                # A new header also closes an abstract that had no blank line.
                flush_abstract()
            elif line.isspace():
                flush_abstract()
            else:
                abstract_lines.append(line)

    # The file may end without a trailing blank line.
    flush_abstract()

    return samples

In [None]:
# Directory holding the three split files of the dataset
DATA_DIR = "../data/pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/"

# Parse the train / validation / test files, in that order
train_samples, val_samples, test_samples = (
    load_pubmed_data(os.path.join(DATA_DIR, filename))
    for filename in ("train.txt", "dev.txt", "test.txt")
)

# Report how many labeled sentences each split contains
print(
    f"Train: {len(train_samples):,} samples",
    f"Val:   {len(val_samples):,} samples",
    f"Test:  {len(test_samples):,} samples",
    sep="\n",
)

In [None]:
# Build one DataFrame per split from the parsed sample dicts
train_df, val_df, test_df = map(
    pd.DataFrame, (train_samples, val_samples, test_samples)
)

# Preview the first ten training rows
train_df.head(10)

## Class Distribution

Let us check how the classes are distributed in the training set.

In [None]:
# Number of training sentences per label
class_counts = train_df["target"].value_counts()

# Absolute counts first, then each label's share of the training set in percent
print(
    "Class distribution (training set):",
    class_counts,
    "",
    "Percentages:",
    class_counts.div(len(train_df)).mul(100).round(2),
    sep="\n",
)

In [None]:
# Bar chart of the number of training sentences per label
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(class_counts.index, class_counts.values, color="steelblue", edgecolor="black")

# Annotate each bar with its count and its share of the training set.
# va="bottom" anchors the two-line label so it grows upward from the bar top.
for i, val in enumerate(class_counts.values):
    pct = val / len(train_df) * 100
    ax.text(i, val + 50, f"{val:,}\n({pct:.1f}%)", ha="center", va="bottom", fontweight="bold")

# Autoscaling ignores text artists, so leave headroom above the tallest bar;
# otherwise its label runs past the axes and into the title.
ax.set_ylim(top=class_counts.max() * 1.15)

ax.set_xlabel("Label")
ax.set_ylabel("Count")
ax.set_title("Class Distribution - Training Set")
plt.tight_layout()
plt.show()

# Ratio of the most frequent label to the least frequent one
ratio = class_counts.max() / class_counts.min()
print(f"Imbalance ratio (max/min): {ratio:.2f}")

## Sentence Length Analysis

We analyze the number of whitespace-separated words per sentence to choose a `max_length` for padding and truncating inputs to deep learning models.

In [None]:
# Word count of every training sentence (tokens separated by whitespace)
sent_lengths = train_df["text"].map(lambda sentence: len(sentence.split())).tolist()

# Left panel: histogram with mean/median markers; right panel: box plot
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

ax1.hist(sent_lengths, bins=30, color="skyblue", edgecolor="black")
ax1.axvline(
    np.mean(sent_lengths),
    color="red",
    linestyle="--",
    label=f"Mean: {np.mean(sent_lengths):.1f}",
)
ax1.axvline(
    np.median(sent_lengths),
    color="green",
    linestyle="--",
    label=f"Median: {np.median(sent_lengths):.0f}",
)
ax1.set(
    xlabel="Number of words",
    ylabel="Frequency",
    title="Sentence Length Distribution",
)
ax1.legend()

# patch_artist=True makes the box a filled patch so it can be coloured
bp = ax2.boxplot(sent_lengths, vert=True, patch_artist=True)
bp["boxes"][0].set_facecolor("lightblue")
ax2.set(ylabel="Number of words", title="Box Plot")

plt.tight_layout()
plt.show()

In [None]:
# Summary statistics of the words-per-sentence counts in the training set
print("Sentence length statistics:")
print(f"  Mean:            {np.mean(sent_lengths):.1f}")
print(f"  Median:          {np.median(sent_lengths):.0f}")
print(f"  Std:             {np.std(sent_lengths):.1f}")
print(f"  Min:             {np.min(sent_lengths)}")
print(f"  Max:             {np.max(sent_lengths)}")
print(f"  95th percentile: {np.percentile(sent_lengths, 95):.0f}")
print(f"  99th percentile: {np.percentile(sent_lengths, 99):.0f}")
print()
# np.percentile interpolates and may return a fractional value. Round up
# rather than truncate so the recommended length is never below the
# 95th percentile.
p95 = int(np.ceil(np.percentile(sent_lengths, 95)))
print(f"Recommended max_length = {p95} (covers 95% of sentences)")