# PubMed 200k RCT Classifier

This notebook shows how to use the `classifier_core` package to train and evaluate a Transformer-based model that classifies each sentence of a medical abstract by its role (background, objective, methods, results, or conclusions).

## Project Structure
The code lives in the `classifier_core` package, which is organized into:
- **Data Loading**: Downloading and parsing the PubMed 200k RCT dataset.
- **Preprocessing**: Text vectorization and dataset creation.
- **Modeling**: Custom Transformer architecture.
- **Evaluation**: Metrics and reporting.

## 1. Setup
Import the package modules, declare the data and model paths, and check for a GPU.

In [None]:
import sys
import os

# Add project root to path
sys.path.append(os.path.abspath('..'))

from classifier_core import downloader, data_loading, preprocessing, modeling, evaluation, utils
from classifier_core.config import DatasetConfig, ModelConfig, TrainConfig
import tensorflow as tf

# --- Path Declarations ---
DATA_DIR = "../pubmed_rct"
MODEL_DIR = "../classifier_core"
MODEL_NAME = "transformer_model.keras"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Check for GPU
utils.check_gpu()

## 2. Configuration
Define the dataset, model, and training hyperparameters through the config dataclasses.

In [None]:
data_cfg = DatasetConfig(data_dir=DATA_DIR, batch_size=32)
model_cfg = ModelConfig(embed_dim=128, num_heads=4, ff_dim=128)
train_cfg = TrainConfig(epochs=3, model_save_dir=MODEL_DIR, model_name=MODEL_NAME)
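
For readers without access to `classifier_core.config`, the following is a minimal sketch of how config dataclasses of this kind could be laid out. The class names, the defaults for `max_tokens` and `output_seq_len`, and the `model_path` property are assumptions made for illustration; the actual package may define them differently.

```python
import os
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class DatasetConfigSketch:
    """Hypothetical stand-in for DatasetConfig."""
    data_dir: str
    batch_size: int = 32
    max_tokens: int = 68_000   # assumed default, not taken from classifier_core
    output_seq_len: int = 55   # assumed default, not taken from classifier_core


@dataclass(frozen=True)
class ModelConfigSketch:
    """Hypothetical stand-in for ModelConfig."""
    embed_dim: int = 128
    num_heads: int = 4
    ff_dim: int = 128

    def __post_init__(self):
        # Reject obviously invalid sizes early instead of failing at model build time.
        for name, value in asdict(self).items():
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")


@dataclass(frozen=True)
class TrainConfigSketch:
    """Hypothetical stand-in for TrainConfig."""
    epochs: int = 3
    model_save_dir: str = "."
    model_name: str = "model.keras"

    @property
    def model_path(self) -> str:
        # Full path of the saved model, derived from directory and file name.
        return os.path.join(self.model_save_dir, self.model_name)


sketch_train_cfg = TrainConfigSketch(
    epochs=3,
    model_save_dir="../classifier_core",
    model_name="transformer_model.keras",
)
print(sketch_train_cfg.model_path)
```

Freezing the dataclasses keeps a run's hyperparameters from being mutated halfway through a notebook session, which makes results easier to reproduce.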


## 3. Data Preparation
Download the dataset, load the train, validation, and test splits, and build batched datasets together with the text vectorizer.

In [None]:
downloader.download_pubmed_data(data_dir=data_cfg.data_dir)

train_samples, val_samples, test_samples = data_loading.load_data(data_dir=data_cfg.data_dir)

train_ds, val_ds, test_ds, text_vectorizer, class_names, output_seq_len = preprocessing.create_datasets(
    train_samples, val_samples, test_samples,
    batch_size=data_cfg.batch_size,
    max_tokens=data_cfg.max_tokens,
    output_seq_len=data_cfg.output_seq_len
)

print(f"Classes: {class_names}")
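
To make the parsing step concrete, here is a minimal sketch of a parser for the PubMed RCT text layout as it is commonly distributed: a `###<abstract id>` line, then one `LABEL<TAB>sentence` line per sentence, then a blank line. The function name `parse_rct_lines`, the output field names, and the example sentences are assumptions for illustration; `data_loading.load_data` may work differently.

```python
def parse_rct_lines(lines):
    """Parse PubMed RCT-style lines into one dict per sentence."""
    samples = []
    abstract_id = None
    current = []  # (label, text) pairs of the abstract being read

    def flush():
        # Emit the buffered sentences with their position inside the abstract.
        total = len(current)
        for index, (label, text) in enumerate(current):
            samples.append({
                "abstract_id": abstract_id,
                "label": label,
                "text": text,
                "line_number": index,
                "total_lines": total,
            })
        current.clear()

    for raw in lines:
        line = raw.strip()
        if line.startswith("###"):
            flush()                    # close any abstract still open
            abstract_id = line[3:]
        elif not line:
            flush()                    # a blank line ends the abstract
        else:
            label, text = line.split("\t", maxsplit=1)
            current.append((label, text))
    flush()                            # handle a file without a trailing blank line
    return samples


# Made-up example input in the assumed layout.
example_lines = [
    "###0001\n",
    "OBJECTIVE\tTo compare treatment A with placebo .\n",
    "METHODS\tPatients were randomized into two groups .\n",
    "\n",
    "###0002\n",
    "BACKGROUND\tLittle is known about outcome B .\n",
]
example_samples = parse_rct_lines(example_lines)
for sample in example_samples:
    print(sample["abstract_id"], sample["label"], sample["line_number"], sample["total_lines"])
```

Keeping `line_number` and `total_lines` is optional, but positional information of this kind can serve as an extra feature for sentence-role classification.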

## 4. Model Definition
Build the Transformer model.

In [None]:
model = modeling.build_model(
    text_vectorizer=text_vectorizer,
    vocab_size=data_cfg.max_tokens,
    output_seq_len=output_seq_len,
    num_classes=len(class_names),
    embed_dim=model_cfg.embed_dim,
    num_heads=model_cfg.num_heads,
    ff_dim=model_cfg.ff_dim
)

model.summary()

## 5. Training
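Train the model with an `EarlyStopping` callback that monitors validation loss. Note that with `epochs=3` and `patience=3` the callback can never stop training early, because it needs three consecutive epochs without improvement after the best one; raise `epochs` in `TrainConfig` for early stopping to take effect.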

In [None]:
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)
history = model.fit(
    train_ds, epochs=train_cfg.epochs, validation_data=val_ds,
    callbacks=[early_stopping],
)
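
To see how a patience-based stopping rule interacts with the number of epochs, the sketch below simulates the rule on a list of validation losses. It is a simplified model written for illustration (the helper `stopping_epoch` is hypothetical and ignores options such as `min_delta`), not the Keras implementation itself.

```python
def stopping_epoch(val_losses, patience):
    """Return the 0-based epoch at which training would stop early, or None."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Three epochs, loss worsening after the first: patience=3 is never reached.
print(stopping_epoch([0.50, 0.60, 0.70], patience=3))
# A fourth epoch without improvement triggers the stop.
print(stopping_epoch([0.50, 0.60, 0.70, 0.80], patience=3))
```

Under this rule a run needs at least `patience + 1` epochs before the callback can stop it.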

## 6. Evaluation
Plot the training curves, then evaluate the model on the test set and show the confusion matrix.

In [None]:
# Plot training history
evaluation.plot_loss_curves(history)

# Evaluate model and show confusion matrix
evaluation.evaluate_model(model, test_ds, class_names, plot_conf_mat=True)
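
As a reading aid for the confusion matrix, the following sketch shows how per-class precision and recall can be derived from one. It uses plain Python with made-up labels for three classes; the helper names are hypothetical and `evaluation.evaluate_model` may compute and report its metrics differently.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[true_label][predicted_label] += 1
    return matrix


def per_class_precision_recall(matrix):
    """Return a (precision, recall) pair for every class."""
    size = len(matrix)
    results = []
    for k in range(size):
        true_positive = matrix[k][k]
        predicted_as_k = sum(matrix[row][k] for row in range(size))  # column sum
        actually_k = sum(matrix[k])                                  # row sum
        precision = true_positive / predicted_as_k if predicted_as_k else 0.0
        recall = true_positive / actually_k if actually_k else 0.0
        results.append((precision, recall))
    return results


# Made-up integer labels, for illustration only.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(y_true, y_pred, num_classes=3)
metrics = per_class_precision_recall(cm)
accuracy = sum(cm[k][k] for k in range(3)) / len(y_true)

print(cm)
print(metrics)
print(accuracy)
```

Off-diagonal cells show which sentence roles the model confuses with each other, which a single accuracy figure hides.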