# Retrain CFN Model (Colab)

This notebook retrains the Causal Fusion Network (CFN) using the training techniques listed below and saves the new weights to `models/cfn.pth`.


## Efficient techniques used
- Feature standardization (train-set mean/std)
- Class-imbalance weighting (positive class upweighting)
- Mixed precision on GPU (AMP)
- Cosine learning-rate schedule
- Gradient clipping
- Early stopping on validation loss
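
Several of these techniques reduce to a few lines of arithmetic. The sketch below is a standard-library illustration, not the code in `src.training.train_cfn`. The helper names (`fit_standardizer`, `positive_class_weight`, `cosine_lr`, `EarlyStopper`), the negatives-to-positives weighting ratio, and the `patience` default are all assumptions made for the example. Mixed precision and gradient clipping depend on the deep-learning framework and are not shown.

```python
import math
from statistics import mean, pstdev


def fit_standardizer(train_rows):
    """Compute per-feature mean and std from the training split only."""
    cols = list(zip(*train_rows))
    means = [mean(c) for c in cols]
    # Fall back to 1.0 for constant features to avoid dividing by zero.
    stds = [pstdev(c) or 1.0 for c in cols]
    return means, stds


def standardize(rows, means, stds):
    """Apply train-set statistics to any split (train, validation, test)."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


def positive_class_weight(labels):
    """Negatives / positives: one common way to upweight the positive class."""
    pos = sum(labels)
    neg = len(labels) - pos
    return neg / max(pos, 1)


def cosine_lr(base_lr, epoch, total_epochs, min_lr=0.0):
    """Cosine annealing from base_lr at epoch 0 down to min_lr at total_epochs."""
    cos_term = 1 + math.cos(math.pi * epoch / total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * cos_term


class EarlyStopper:
    """Signal a stop after `patience` epochs without a new best validation loss."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Fitting the statistics on the training split and reusing them for validation data keeps information from the held-out set out of the preprocessing step.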


## 1. Setup
Uncomment the lines below to mount Google Drive (optional) and install dependencies.


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# !pip install -r /content/drive/MyDrive/CausalX/requirements.txt


## 2. Set project paths
Update these paths to match your Drive or Colab workspace.


In [None]:
from pathlib import Path

PROJECT_ROOT = Path('/content/drive/MyDrive/CausalX')
DATA_ROOT = PROJECT_ROOT / 'backend' / 'data'
PROCESSED_CSV = DATA_ROOT / 'processed' / 'causal_multimodal_dataset.csv'
MODEL_DIR = PROJECT_ROOT / 'backend' / 'models'


## 3. (Optional) Build the feature dataset
Run the batch feature extractor if the processed CSV does not exist, or if you have changed the feature definitions.


In [None]:
# %cd /content/drive/MyDrive/CausalX/backend
# !python -m src.preprocessing.batch_feature_extractor
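
To decide whether this step is needed, a small check on the processed CSV can help. This is a minimal sketch; `needs_feature_build` is a hypothetical helper, not part of the project. It only detects a missing or empty file and cannot tell whether the feature definitions have changed since the CSV was written.

```python
from pathlib import Path


def needs_feature_build(csv_path: Path) -> bool:
    """Return True when the processed CSV is missing or empty."""
    return not csv_path.is_file() or csv_path.stat().st_size == 0
```

In this notebook it could be called as `needs_feature_build(PROCESSED_CSV)` once the paths from step 2 are defined.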


## 3a. (Optional) Balance the dataset (Option A)
If your dataset is highly imbalanced, downsample every class to the size of the smallest class before training.


In [None]:
# %cd /content/drive/MyDrive/CausalX/backend
# !python -m src.preprocessing.balance_dataset \
#   --input-csv data/processed/causal_multimodal_dataset.csv \
#   --output-csv data/processed/causal_multimodal_dataset_balanced.csv
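
The idea behind downsampling to the smallest class can be sketched in a few lines. This illustrates the technique and is not necessarily what `src.preprocessing.balance_dataset` does. The function name, the `label` column name, and the fixed seed are assumptions for the example.

```python
import random
from collections import defaultdict


def downsample_to_smallest(rows, label_key="label", seed=0):
    """Randomly keep the same number of rows per class as the smallest class has."""
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[label_key]].append(row)
    if not by_class:
        return []

    n = min(len(group) for group in by_class.values())
    rng = random.Random(seed)  # fixed seed keeps the subset reproducible
    balanced = []
    for group in by_class.values():
        balanced.extend(rng.sample(group, n))
    rng.shuffle(balanced)  # avoid leaving the rows grouped by class
    return balanced
```

Downsampling discards majority-class rows, so it trades data volume for balance; the class-imbalance weighting listed earlier is the alternative that keeps every row.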


## 4. Train the model
The script saves the best checkpoint to `models/cfn.pth`. If you balanced the dataset in step 3a, point `--dataset-csv` at the balanced CSV instead.


In [None]:
# %cd /content/drive/MyDrive/CausalX/backend
# !python -m src.training.train_cfn \
#   --dataset-csv data/processed/causal_multimodal_dataset.csv \
#   --epochs 30 \
#   --batch-size 128 \
#   --lr 1e-3 \
#   --weight-decay 1e-4
