In [None]:
# EDA for MTS-Dialog Dataset

import os
from pathlib import Path
import pandas as pd

# Default paths (adjust if needed). The data root is relative to the notebook's working
# directory; set the MTS_DIALOG_DATA_DIR environment variable to point at a different
# checkout of the corpus' `data` directory without editing this cell.
BASE_DIR = Path(os.environ.get(
    'MTS_DIALOG_DATA_DIR',
    'data/primary/mts-dialog/clinical_visit_note_summarization_corpus-main/data',
))
MTS_DIR = BASE_DIR / 'mts-dialog'

# Candidates are tried in order and the first one that exists is loaded:
# the MTS-Dialog training split first, then its validation and test splits,
# and finally the aci-bench challenge training split as a last resort.
CANDIDATE_FILES = [
    MTS_DIR / 'MTS_Dataset_TrainingSet.csv',
    MTS_DIR / 'MTS_Dataset_ValidationSet.csv',
    MTS_DIR / 'MTS_Dataset_Final_200_TestSet_1.csv',
    MTS_DIR / 'MTS_Dataset_Final_200_TestSet_2.csv',
    BASE_DIR / 'aci-bench' / 'challenge_data' / 'train.csv',
]

def first_existing(paths):
    """Return the first element of `paths` whose `.exists()` is true, or None if none match.

    Candidates are checked in order and the search stops at the first hit,
    so later paths are never touched once a match is found.
    """
    return next((candidate for candidate in paths if candidate.exists()), None)

train_path = first_existing(CANDIDATE_FILES)
if train_path is None:
    # Raise instead of `assert` so the check survives `python -O`, and say where we looked.
    searched = '\n  '.join(str(p) for p in CANDIDATE_FILES)
    raise FileNotFoundError(f'No dataset CSV found in expected locations:\n  {searched}')

print(f'Using dataset: {train_path}')
if train_path != CANDIDATE_FILES[0]:
    # The fallback is in effect: everything below describes this file, not the preferred split.
    print(f'NOTE: preferred file {CANDIDATE_FILES[0]} not found; using fallback instead.')

df = pd.read_csv(train_path)
print('\nColumns:', list(df.columns))
# Last expression in the cell -> rendered as a rich table by Jupyter.
df.head(3)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Identify input/output text columns heuristically (column names are matched case-insensitively).
# Common schemas:
# - MTS-Dialog: columns include ['ID','section_header','section_text','dialogue']
# - ACI-bench: columns include ['dataset','encounter_id','dialogue','note']

cols = {c.lower(): c for c in df.columns}

dialogue_col = cols.get('dialogue')
# `note` (ACI-bench) is preferred; `section_text` (MTS-Dialog) is the fallback target column.
note_col = cols.get('note') or cols.get('section_text')

# Explicit raises (not `assert`) so the checks survive `python -O` and show what is available.
if dialogue_col is None:
    raise KeyError(f'Could not find dialogue column. Available columns: {list(df.columns)}')
if note_col is None:
    raise KeyError(f'Could not find note/section_text column. Available columns: {list(df.columns)}')

print(f'Using dialogue column: {dialogue_col}')
print(f'Using note/summary column: {note_col}')

# Compute basic stats. Missing cells are counted as length 0 (fillna('')), so the number of
# missing cells is reported alongside: a large value means mean/median are pulled down.
for label, col in [('dialogue', dialogue_col), ('note', note_col)]:
    n_missing = int(df[col].isna().sum())
    lengths = df[col].fillna('').astype(str).str.len()
    print(
        f"\n{label} length (chars): count={lengths.size}, missing={n_missing}, "
        f"mean={lengths.mean():.1f}, median={lengths.median():.1f}, "
        f"p90={lengths.quantile(0.9):.1f}, max={lengths.max()}"
    )



In [None]:
# Plot distributions of character lengths. Lengths are computed once and reused by both figures.
dialogue_len = df[dialogue_col].fillna('').astype(str).str.len()
note_len = df[note_col].fillna('').astype(str).str.len()

# KDE view. clip=(0, None) keeps the smoothed density from leaking below zero,
# since a text length can never be negative.
fig_kde, ax_kde = plt.subplots(figsize=(10, 5))
sns.kdeplot(dialogue_len, ax=ax_kde, label='dialogue (input)', linewidth=2, clip=(0, None))
sns.kdeplot(note_len, ax=ax_kde, label='note/summary (target)', linewidth=2, clip=(0, None))
ax_kde.set_title('Length Distribution (characters)')
ax_kde.set_xlabel('length (chars)')
ax_kde.set_ylabel('density')
ax_kde.legend()
fig_kde.tight_layout()
plt.show()

# Histograms for a different perspective: raw counts rather than a smoothed density.
fig_hist, axes = plt.subplots(1, 2, figsize=(12, 4))
hist_panels = [
    (axes[0], dialogue_len, 'Dialogue length (chars)', '#4e79a7'),
    (axes[1], note_len, 'Note/Summary length (chars)', '#f28e2b'),
]
for panel_ax, panel_lengths, panel_title, panel_color in hist_panels:
    panel_ax.hist(panel_lengths, bins=50, color=panel_color)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('length (chars)')
    panel_ax.set_ylabel('count')
fig_hist.tight_layout()
plt.show()


In [None]:
# Show up to SAMPLE_N random (dialogue, note) pairs to manually review note/summary style.
# The sample seed is fixed so the same rows are shown on every re-run.
SAMPLE_N = 10
SAMPLE_SEED = 42
PREVIEW_CHARS = 2000  # per-field cap so very long dialogues don't flood the output


def preview_text(value, limit=PREVIEW_CHARS):
    """Return `value` as display text, truncated to `limit` characters with a visible marker.

    Missing values (NaN/None) are rendered as '<missing>' rather than the string 'nan',
    and truncated text ends with a note saying how many characters were cut.
    """
    if pd.isna(value):
        return '<missing>'
    text = str(value)
    if len(text) <= limit:
        return text
    return f'{text[:limit]}… [truncated, {len(text) - limit} more chars]'


sample_df = df.sample(min(SAMPLE_N, len(df)), random_state=SAMPLE_SEED)[[dialogue_col, note_col]]
for i, row in sample_df.iterrows():
    print('=' * 80)
    print(f'Row index: {i}')
    print('- Dialogue -')
    print(preview_text(row[dialogue_col]))
    print('\n- Note/Summary -')
    print(preview_text(row[note_col]))
