# Incremental Task Finetuning Demo

This notebook incrementally adds regression tasks (density -> Cp -> Rg -> linear_expansion) using the flexible multi-task foundation model. Each stage reloads the previous checkpoint, adds a new task head, and continues training on the combined task set.


## Data Overview

- **Descriptors**: `data/amorphous_polymer_FFDescriptor_20250730.parquet`
- **Target properties**: `data/amorphous_polymer_non_PI_properties_20250730.parquet`
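- Sequential task order: density -> Cp -> Rg -> linear_expansion
- The descriptor and property tables are aligned on their shared index labels before the train/val/test split.
- Targets are used in their raw units unless `USE_NORMALIZED_TARGETS` is enabled, in which case the `(normalized)` variant of each target column is used instead.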


In [1]:
import os
from pathlib import Path
import math
import json
import re

import pandas as pd
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

from foundation_model.data.datamodule import CompoundDataModule
from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel
from foundation_model.models.model_config import RegressionTaskConfig, TaskType, OptimizerConfig


[32m2025-10-31 08:38:22.889[0m | [1mINFO    [0m | [36m__init__[0m:[36m<module>[0m:[36m34[0m - [1mLoguru logger initialized for foundation_model package.[0m


In [2]:
DATA_DIR = Path('../data')
DESCRIPTOR_PATH = DATA_DIR / 'amorphous_polymer_FFDescriptor_20250730.parquet'
PROPERTY_PATH = DATA_DIR / 'amorphous_polymer_non_PI_properties_20250730.parquet'

# When True, every task reads the '(normalized)' variant of its target column.
USE_NORMALIZED_TARGETS = False
TASK_SEQUENCE = ['density', 'Cp', 'Rg', 'linear_expansion']
# Derived from TASK_SEQUENCE so the task list and the column mapping cannot drift apart.
_TARGET_SUFFIX = '(normalized)' if USE_NORMALIZED_TARGETS else ''
TARGET_COLUMNS = {name: f'{name}{_TARGET_SUFFIX}' for name in TASK_SEQUENCE}

# Presumably the first entry is the model input width; it matches the 190 descriptor columns.
SHARED_BLOCK_DIMS = [190, 256, 128]
HEAD_HIDDEN = 64
ARTIFACT_ROOT = Path('../artifacts/polymers_incremental_tasks')
ARTIFACT_ROOT.mkdir(parents=True, exist_ok=True)

TRAIN_SAMPLE = None  # set to an int for faster smoke runs
BATCH_SIZE = 256
NUM_WORKERS = 0
MAX_EPOCHS = 20
LOG_EVERY_N_STEPS = 5

# Seed torch's global RNG so model initialisation and DataLoader shuffling repeat across
# Restart & Run All (GPU/MPS kernels may still introduce small nondeterminism).
SEED = 42
torch.manual_seed(SEED)


In [3]:
descriptor_df = pd.read_parquet(DESCRIPTOR_PATH)
property_df = pd.read_parquet(PROPERTY_PATH)

missing = [col for col in TARGET_COLUMNS.values() if col not in property_df.columns]
if missing:
    raise KeyError(f'Missing target columns in property table: {missing}')

# Duplicate labels would make the two .loc selections below return frames of different
# lengths, silently misaligning features and targets.
for table_name, table in (('descriptor', descriptor_df), ('property', property_df)):
    if not table.index.is_unique:
        raise ValueError(f'The {table_name} table has duplicate index labels; cannot align rows.')

# Keep only samples present in both tables so features and targets line up row-for-row.
common_index = descriptor_df.index.intersection(property_df.index)
if common_index.empty:
    raise ValueError('Descriptor and property tables share no index labels.')
feature_frame = descriptor_df.loc[common_index]
target_frame = property_df.loc[common_index, [TARGET_COLUMNS[name] for name in TASK_SEQUENCE]]

# Fail fast here rather than deep inside training if the descriptor width does not match
# the first shared-block dimension configured above.
if feature_frame.shape[1] != SHARED_BLOCK_DIMS[0]:
    raise ValueError(
        f'Descriptor table has {feature_frame.shape[1]} columns but '
        f'SHARED_BLOCK_DIMS[0] is {SHARED_BLOCK_DIMS[0]}.'
    )

if TRAIN_SAMPLE is not None and TRAIN_SAMPLE < len(feature_frame):
    feature_frame = feature_frame.sample(n=TRAIN_SAMPLE, random_state=42)
    target_frame = target_frame.loc[feature_frame.index]

print(f'Feature matrix: {feature_frame.shape}')
print(f'Target matrix: {target_frame.shape}')
print(f'First targets: {list(target_frame.columns)}')


Feature matrix: (71725, 190)
Target matrix: (71725, 4)
First targets: ['density', 'Cp', 'Rg', 'linear_expansion']


## Helper Functions


In [4]:
def build_regression_task(name: str, column: str) -> RegressionTaskConfig:
    """Create the regression-head configuration for a single task.

    The head's layer sizes run from the last shared-block width, through
    ``HEAD_HIDDEN``, down to one output value.

    Args:
        name: Task identifier used as the head key.
        column: Target column in the property table that supplies labels.
    """
    head_dims = [SHARED_BLOCK_DIMS[-1], HEAD_HIDDEN, 1]
    return RegressionTaskConfig(
        name=name,
        data_column=column,
        dims=head_dims,
        norm=True,
        residual=False,
    )

def make_task_configs(task_names: list[str]) -> list[RegressionTaskConfig]:
    """Return one regression config per task name, preserving the given order."""
    configs = []
    for task_name in task_names:
        target_column = TARGET_COLUMNS[task_name]
        configs.append(build_regression_task(task_name, target_column))
    return configs

def build_datamodule(
    task_configs: list[RegressionTaskConfig],
    *,
    batch_size: int = BATCH_SIZE,
) -> CompoundDataModule:
    """Build a datamodule exposing only the target columns the given tasks need.

    Features come from the notebook-level ``feature_frame``; targets are the
    subset of ``target_frame`` named by each config's ``data_column``.
    """
    wanted_columns = [cfg.data_column for cfg in task_configs]
    return CompoundDataModule(
        formula_desc_source=feature_frame,
        attributes_source=target_frame.loc[:, wanted_columns],
        task_configs=task_configs,
        batch_size=batch_size,
        num_workers=NUM_WORKERS,
    )


In [5]:
def _select_inference_device() -> str:
    """Return the preferred torch device name: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return 'cuda'
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


def _flatten_to_cpu(value) -> torch.Tensor:
    """Detach ``value`` (a tensor or a list of tensors) to CPU as a single 1-D tensor."""
    if isinstance(value, list):
        return torch.cat([item.detach().cpu().reshape(-1) for item in value])
    return value.detach().cpu().reshape(-1)


def _collect_test_outputs(model, test_loader, device: str) -> dict[str, dict[str, list[torch.Tensor]]]:
    """Run ``model`` over ``test_loader`` and gather masked, flattened CPU outputs per task.

    Batches are unpacked as ``(x, y_dict, mask_dict, t_sequences)``. Only tasks present in
    both the model output and ``y_dict`` are kept; when a mask exists for a task, entries
    whose mask is zero/False are dropped. A task contributes nothing for a batch in which
    no entries survive masking.
    """
    aggregated: dict[str, dict[str, list[torch.Tensor]]] = {}
    with torch.no_grad():
        for x, y_dict, mask_dict, t_sequences in test_loader:
            # NOTE(review): t_sequences is passed through without a device move; confirm it
            # holds no tensors that must live on cuda/mps.
            preds = model(x.to(device), t_sequences)
            for name, pred_tensor in preds.items():
                if name not in y_dict:
                    continue
                target_flat = _flatten_to_cpu(y_dict[name])
                pred_flat = pred_tensor.detach().cpu().reshape(-1)
                mask_tensor = mask_dict.get(name)
                if mask_tensor is not None:
                    keep = _flatten_to_cpu(mask_tensor).bool()
                    target_flat = target_flat[keep]
                    pred_flat = pred_flat[keep]
                if target_flat.numel() == 0:
                    continue
                entry = aggregated.setdefault(name, {'preds': [], 'targets': []})
                entry['preds'].append(pred_flat)
                entry['targets'].append(target_flat)
    return aggregated


def _regression_metrics(preds: torch.Tensor, targets: torch.Tensor) -> dict[str, float | int | None]:
    """Compute sample count, MAE, MSE, RMSE and R² for 1-D tensors.

    ``r2`` is None when the targets have zero variance (R² is undefined).
    """
    diff = preds - targets
    mse = torch.mean(diff ** 2).item()
    ss_res = torch.sum(diff ** 2).item()
    ss_tot = torch.sum((targets - targets.mean()) ** 2).item()
    return {
        'samples': int(targets.numel()),
        'mae': torch.mean(torch.abs(diff)).item(),
        'mse': mse,
        'rmse': math.sqrt(mse),
        'r2': 1.0 - ss_res / ss_tot if ss_tot > 0 else None,
    }


def _metrics_annotation(task_metrics: dict[str, float | int | None]) -> str:
    """Text for a parity-plot box: MAE, plus R² when it is defined."""
    mae = task_metrics['mae']
    r2 = task_metrics['r2']
    if r2 is None:
        return f"MAE: {mae:.3f}"
    return rf"MAE: {mae:.3f} $R^2$: {r2:.3f}"


def _draw_parity(ax, targets: torch.Tensor, preds: torch.Tensor, *, title: str, annotation: str) -> None:
    """Scatter predicted vs. actual on ``ax`` with a dashed y = x reference line.

    Both axes share limits spanning the two series, padded by 5% of the range
    (or by 0.1 when every value is identical).
    """
    targets_np = targets.numpy()
    preds_np = preds.numpy()
    lo = float(min(preds_np.min(), targets_np.min()))
    hi = float(max(preds_np.max(), targets_np.max()))
    pad = 0.05 * (hi - lo) if hi > lo else 0.1
    lo, hi = lo - pad, hi + pad

    ax.scatter(targets_np, preds_np, s=12, alpha=0.6, edgecolors='none')
    ax.plot([lo, hi], [lo, hi], '--', color='tab:red', linewidth=1)
    ax.text(
        0.05,
        0.95,
        annotation,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment='top',
        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.6),
    )
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel('Actual')
    ax.set_ylabel('Predicted')
    ax.set_title(title)
    ax.grid(alpha=0.2)
    ax.set_aspect('equal', adjustable='box')


def plot_test_predictions(
    model: FlexibleMultiTaskModel,
    datamodule: CompoundDataModule,
    *,
    stage_num: int,
    stage_tasks: list[str],
    new_task_name: str,
    prediction_dir: Path | str | None = None,
) -> None:
    """Evaluate ``model`` on the test split and persist parity plots, predictions and metrics.

    Files written to ``prediction_dir`` (default
    ``ARTIFACT_ROOT / f'Stage{stage_num}_{new_task_name}' / 'prediction'``):
    ``tasks.txt`` (task order), ``Stage{stage_num}_overview.png`` (one panel per task),
    ``{task}_pred.png`` per task, ``predictions.parquet`` (per-sample actual/predicted for
    every task the model scored) and ``metrics.json`` (MAE/MSE/RMSE/R² per listed task).

    The model is temporarily moved to the preferred device and put in eval mode. Its
    original device and training flag are restored afterwards, even if inference raises.

    Args:
        model: Model called as ``model(x, t_sequences)``, returning a mapping from task
            name to prediction tensor.
        datamodule: Must provide ``setup(stage='test')`` and ``test_dataloader()``.
        stage_num: Stage number used in titles and file names.
        stage_tasks: Tasks to plot and score, in display order.
        new_task_name: Task added in this stage; only used for the default directory.
        prediction_dir: Optional output directory override.

    Raises:
        RuntimeError: If the datamodule returns no test dataloader.
    """
    if prediction_dir is None:
        prediction_dir = ARTIFACT_ROOT / f'Stage{stage_num}_{new_task_name}' / 'prediction'
    prediction_dir = Path(prediction_dir)
    prediction_dir.mkdir(parents=True, exist_ok=True)

    metrics_path = prediction_dir / 'metrics.json'
    predictions_path = prediction_dir / 'predictions.parquet'
    (prediction_dir / 'tasks.txt').write_text(' -> '.join(stage_tasks) + '\n', encoding='utf-8')

    datamodule.setup(stage='test')
    test_loader = datamodule.test_dataloader()
    if test_loader is None:
        raise RuntimeError(f'Stage {stage_num} datamodule does not define a test_dataloader().')

    device = _select_inference_device()
    original_device = next(model.parameters()).device
    was_training = model.training
    try:
        model.to(device)
        model.eval()
        aggregated = _collect_test_outputs(model, test_loader, device)
    finally:
        # Leave the caller's model exactly as it was handed in, even on failure.
        model.to(original_device)
        if was_training:
            model.train()

    if not aggregated:
        print(f'No test predictions available for Stage {stage_num}.')
        return

    # Concatenate per-batch chunks once per task.
    results = {
        name: (torch.cat(chunks['preds']), torch.cat(chunks['targets']))
        for name, chunks in aggregated.items()
    }
    # Plots and metrics follow the stage's task order and ignore tasks it does not list.
    ordered_names = [name for name in stage_tasks if name in results]
    metrics = {name: _regression_metrics(*results[name]) for name in ordered_names}

    if ordered_names:
        cols = 2 if len(ordered_names) > 1 else 1
        rows = math.ceil(len(ordered_names) / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4.5, rows * 4.5))
        try:
            axes_list = list(axes.flat) if hasattr(axes, 'flat') else [axes]
            for ax, name in zip(axes_list, ordered_names):
                preds, targets = results[name]
                _draw_parity(
                    ax,
                    targets,
                    preds,
                    title=f'Stage {stage_num}: {name}',
                    annotation=_metrics_annotation(metrics[name]),
                )
            # Blank out unused grid cells when the task count is odd.
            for ax in axes_list[len(ordered_names):]:
                ax.axis('off')
            fig.tight_layout()
            fig.savefig(prediction_dir / f'Stage{stage_num}_overview.png', dpi=180)
        finally:
            plt.close(fig)

        for name in ordered_names:
            preds, targets = results[name]
            fig_single, ax_single = plt.subplots(figsize=(5, 5))
            try:
                _draw_parity(
                    ax_single,
                    targets,
                    preds,
                    title=f'Stage {stage_num}: {name}',
                    annotation=_metrics_annotation(metrics[name]),
                )
                fig_single.tight_layout()
                # File-system-safe, lowercase stem derived from the task name.
                safe_name = re.sub(r'[^a-z0-9]+', '_', name.lower()).strip('_') or 'task'
                fig_single.savefig(prediction_dir / f'{safe_name}_pred.png', dpi=180)
            finally:
                plt.close(fig_single)

    # sample_index counts surviving (masked-in) test samples per task, in loader order.
    predictions = pd.concat(
        [
            pd.DataFrame(
                {
                    'stage': stage_num,
                    'task': name,
                    'sample_index': range(preds.numel()),
                    'actual': targets.to(torch.float64).numpy(),
                    'predicted': preds.to(torch.float64).numpy(),
                }
            )
            for name, (preds, targets) in results.items()
        ],
        ignore_index=True,
    )
    predictions.to_parquet(predictions_path, index=False)
    print(f'Saved predictions to {predictions_path}')

    metrics_payload = {
        'stage': stage_num,
        'task_sequence': list(stage_tasks),
        'metrics': metrics,
    }
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics_payload, f, indent=2)
    print(f'Saved metrics to {metrics_path}')


In [6]:
# Allow-list the project config classes so torch.load(..., weights_only=True) in the
# training loop can unpickle them if they appear in a saved checkpoint.
torch.serialization.add_safe_globals(
    [
        RegressionTaskConfig,
        TaskType,
        OptimizerConfig,
    ]
)

## Incremental Training


In [None]:
stage_records: list[dict] = []
previous_checkpoint: str | None = None

for stage_idx, task_name in enumerate(TASK_SEQUENCE, start=1):
    # Each stage trains on every task introduced so far (a growing prefix of TASK_SEQUENCE).
    stage_task_names = TASK_SEQUENCE[:stage_idx]
    stage_label = f'Stage {stage_idx}'
    print(f"=== {stage_label} ({task_name}) ===")

    task_configs = make_task_configs(stage_task_names)
    datamodule = build_datamodule(task_configs)

    if previous_checkpoint is None:
        # First stage: fresh model with only the initial task head.
        model = FlexibleMultiTaskModel(
            shared_block_dims=SHARED_BLOCK_DIMS,
            task_configs=task_configs,
            enable_learnable_loss_balancer=True,
            shared_block_optimizer=OptimizerConfig(lr=5e-2),
        )
    else:
        # Later stages: resume from the previous stage's checkpoint, then attach heads for
        # tasks that the reloaded model does not have yet.
        # NOTE(review): strict=False tolerates state-dict key mismatches and can silently
        # skip weights; confirm it is still required.
        model = FlexibleMultiTaskModel.load_from_checkpoint(
            checkpoint_path=previous_checkpoint,
            strict=False,
            enable_learnable_loss_balancer=True,
        )
        existing_tasks = set(model.task_heads.keys())
        new_task_configs = [cfg for cfg in task_configs if cfg.name not in existing_tasks]
        if new_task_configs:
            model.add_task(*new_task_configs)

    stage_root = ARTIFACT_ROOT / f'Stage{stage_idx}_{task_name}'
    stage_root.mkdir(parents=True, exist_ok=True)

    checkpoint_cb = ModelCheckpoint(
        dirpath=stage_root / 'checkpoints',
        filename=f"{task_name}-{{epoch:02d}}-{{val_final_loss:.4f}}",
        monitor='val_final_loss',
        mode='min',
        save_top_k=1,
    )
    early_stopping = EarlyStopping(monitor='val_final_loss', mode='min', patience=10)
    csv_logger = CSVLogger(save_dir=stage_root / 'logs', name='csv')
    tensorboard_logger = TensorBoardLogger(save_dir=stage_root / 'logs', name='tensorboard')

    trainer = Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator='auto',
        devices='auto',
        callbacks=[checkpoint_cb, early_stopping],
        logger=[csv_logger, tensorboard_logger],
        log_every_n_steps=LOG_EVERY_N_STEPS,
    )

    trainer.fit(model, datamodule=datamodule)
    best_model_path = checkpoint_cb.best_model_path

    if best_model_path:
        print(f'Best checkpoint: {best_model_path}')
        # Evaluate and hand off the best-validation weights, not the last-epoch weights.
        state = torch.load(best_model_path, map_location='cpu', weights_only=True)
        state_dict = state.get('state_dict', state)
        model.load_state_dict(state_dict)
    else:
        # ModelCheckpoint recorded no best file (e.g. the monitored metric was never logged).
        # Persist the final weights so the next stage resumes from them instead of calling
        # load_from_checkpoint('') and failing.
        fallback_path = stage_root / 'checkpoints' / f'{task_name}-last.ckpt'
        fallback_path.parent.mkdir(parents=True, exist_ok=True)
        trainer.save_checkpoint(fallback_path)
        best_model_path = str(fallback_path)
        print(f'No best checkpoint recorded; saved final weights to: {best_model_path}')

    prediction_dir = stage_root / 'prediction'
    plot_test_predictions(
        model,
        datamodule,
        stage_num=stage_idx,
        stage_tasks=stage_task_names,
        new_task_name=task_name,
        prediction_dir=prediction_dir,
    )

    stage_records.append({
        'stage': stage_idx,
        'label': stage_label,
        'task_names': stage_task_names,
        'new_task_name': task_name,
        'checkpoint': best_model_path,
        'prediction_dir': prediction_dir,
        'datamodule': datamodule,
        'model': model,
    })

    previous_checkpoint = best_model_path


[32m2025-10-31 08:38:30.192[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m__init__[0m:[36m160[0m - [1mInitializing CompoundDataModule...[0m
[32m2025-10-31 08:38:30.192[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m__init__[0m:[36m192[0m - [1m--- Loading Data ---[0m
[32m2025-10-31 08:38:30.193[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m_load_data[0m:[36m432[0m - [1mUsing provided pd.DataFrame for 'formula_desc' data.[0m
[32m2025-10-31 08:38:30.205[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m_load_data[0m:[36m439[0m - [1mSuccessfully loaded 'formula_desc'. Shape: (71725, 190)[0m
[32m2025-10-31 08:38:30.205[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m__init__[0m:[36m197[0m - [1mInitial loaded formula_df length: 71725[0m
[32m2025-10-31 08:38:30.220[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m__init__[0m:[36m204[0m - [1mFormula_df length after initial dropna: 71725. This index is now the master reference.[0m
[32m2025-10-3

=== Stage 1 (density) ===


[32m2025-10-31 08:38:30.433[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m334[0m - [1m[train_dataset] CompoundDataset initialization complete. Processed 1 enabled tasks.[0m
[32m2025-10-31 08:38:30.436[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m_build_fit_datasets[0m:[36m483[0m - [1mCreating val_dataset with 7173 samples.[0m
[32m2025-10-31 08:38:30.441[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m99[0m - [1m[val_dataset] Initializing CompoundDataset...[0m
[32m2025-10-31 08:38:30.442[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m133[0m - [1m[val_dataset] Final x_formula shape: torch.Size([7173, 190])[0m
[32m2025-10-31 08:38:30.442[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[val_dataset] Processing enabled task 'density' (type: REGRESSION)[0m
[32m2025-10-31 08:38:30.454[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m334[0m - [1m[val_dataset] Compoun

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 225/225 [00:01<00:00, 118.15it/s, v_num=0, train_final_loss_step=-2.26, val_final_loss=-2.51, train_final_loss_epoch=-2.47] 

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 225/225 [00:01<00:00, 117.03it/s, v_num=0, train_final_loss_step=-2.26, val_final_loss=-2.51, train_final_loss_epoch=-2.47]
Best checkpoint: /Users/liuchang/projects/foundation_model/artifacts/polymers_incremental_tasks/Stage1_density/checkpoints/density-epoch=19-val_final_loss=-2.5144.ckpt


[32m2025-10-31 08:39:09.907[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m551[0m - [1m--- Setting up DataModule for stage: test ---[0m
[32m2025-10-31 08:39:09.908[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m561[0m - [1mTotal samples available before splitting (from attributes_df index): 71725[0m
[32m2025-10-31 08:39:09.908[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m591[0m - [1mData split strategy: Performing random train/val/test splits based on full_idx (derived from attributes_df).[0m
[32m2025-10-31 08:39:09.908[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m594[0m - [1mTest split ratio: 0.1, Validation split ratio (of non-test): 0.1[0m
[32m2025-10-31 08:39:09.911[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m606[0m - [1mSplit full data (71725) into train_val (64552) and test (7173) using seed 24.[0m
[32m2025-10-31 08:39:09.912[0m | [1mINFO    [0m | [36mdatamodu

Saved predictions to ../artifacts/polymers_incremental_tasks/Stage1_density/prediction/predictions.parquet
Saved metrics to ../artifacts/polymers_incremental_tasks/Stage1_density/prediction/metrics.json
=== Stage 2 (Cp) ===


[32m2025-10-31 08:39:10.430[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[train_dataset] Processing enabled task 'Cp' (type: REGRESSION)[0m
[32m2025-10-31 08:39:10.526[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m334[0m - [1m[train_dataset] CompoundDataset initialization complete. Processed 2 enabled tasks.[0m
[32m2025-10-31 08:39:10.528[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m_build_fit_datasets[0m:[36m483[0m - [1mCreating val_dataset with 7173 samples.[0m
[32m2025-10-31 08:39:10.533[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m99[0m - [1m[val_dataset] Initializing CompoundDataset...[0m
[32m2025-10-31 08:39:10.534[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m133[0m - [1m[val_dataset] Final x_formula shape: torch.Size([7173, 190])[0m
[32m2025-10-31 08:39:10.534[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[val_dataset] Processing

                                                                           

/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 225/225 [00:03<00:00, 74.47it/s, v_num=0, train_final_loss_step=3.62e+5, val_final_loss=3.89e+5, train_final_loss_epoch=3.94e+5]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 225/225 [00:03<00:00, 73.88it/s, v_num=0, train_final_loss_step=3.62e+5, val_final_loss=3.89e+5, train_final_loss_epoch=3.94e+5]
Best checkpoint: /Users/liuchang/projects/foundation_model/artifacts/polymers_incremental_tasks/Stage2_Cp/checkpoints/Cp-epoch=19-val_final_loss=388704.4688.ckpt


[32m2025-10-31 08:40:10.598[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m551[0m - [1m--- Setting up DataModule for stage: test ---[0m
[32m2025-10-31 08:40:10.599[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m561[0m - [1mTotal samples available before splitting (from attributes_df index): 71725[0m
[32m2025-10-31 08:40:10.599[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m591[0m - [1mData split strategy: Performing random train/val/test splits based on full_idx (derived from attributes_df).[0m
[32m2025-10-31 08:40:10.599[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m594[0m - [1mTest split ratio: 0.1, Validation split ratio (of non-test): 0.1[0m
[32m2025-10-31 08:40:10.601[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m606[0m - [1mSplit full data (71725) into train_val (64552) and test (7173) using seed 24.[0m
[32m2025-10-31 08:40:10.602[0m | [1mINFO    [0m | [36mdatamodu

Saved predictions to ../artifacts/polymers_incremental_tasks/Stage2_Cp/prediction/predictions.parquet
Saved metrics to ../artifacts/polymers_incremental_tasks/Stage2_Cp/prediction/metrics.json
=== Stage 3 (Rg) ===


[32m2025-10-31 08:40:11.157[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[train_dataset] Processing enabled task 'Cp' (type: REGRESSION)[0m
[32m2025-10-31 08:40:11.251[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[train_dataset] Processing enabled task 'Rg' (type: REGRESSION)[0m
[32m2025-10-31 08:40:11.344[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m334[0m - [1m[train_dataset] CompoundDataset initialization complete. Processed 3 enabled tasks.[0m
[32m2025-10-31 08:40:11.347[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m_build_fit_datasets[0m:[36m483[0m - [1mCreating val_dataset with 7173 samples.[0m
[32m2025-10-31 08:40:11.352[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m99[0m - [1m[val_dataset] Initializing CompoundDataset...[0m
[32m2025-10-31 08:40:11.353[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m133[0m - [1m[val_dataset] Final x

                                                                           

/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 225/225 [00:03<00:00, 58.47it/s, v_num=0, train_final_loss_step=1.52e+4, val_final_loss=1.62e+4, train_final_loss_epoch=1.65e+4]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 225/225 [00:03<00:00, 58.01it/s, v_num=0, train_final_loss_step=1.52e+4, val_final_loss=1.62e+4, train_final_loss_epoch=1.65e+4]


[32m2025-10-31 08:41:29.055[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m551[0m - [1m--- Setting up DataModule for stage: test ---[0m
[32m2025-10-31 08:41:29.055[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m561[0m - [1mTotal samples available before splitting (from attributes_df index): 71725[0m
[32m2025-10-31 08:41:29.055[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m591[0m - [1mData split strategy: Performing random train/val/test splits based on full_idx (derived from attributes_df).[0m
[32m2025-10-31 08:41:29.055[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m594[0m - [1mTest split ratio: 0.1, Validation split ratio (of non-test): 0.1[0m
[32m2025-10-31 08:41:29.057[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m606[0m - [1mSplit full data (71725) into train_val (64552) and test (7173) using seed 24.[0m


Best checkpoint: /Users/liuchang/projects/foundation_model/artifacts/polymers_incremental_tasks/Stage3_Rg/checkpoints/Rg-epoch=19-val_final_loss=16172.6553.ckpt


[32m2025-10-31 08:41:29.058[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m632[0m - [1mSplit train_val (64552) into train (57379) and val (7173) using seed 42, effective_val_split 0.111.[0m
[32m2025-10-31 08:41:29.058[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m658[0m - [1mFinal dataset sizes after splitting: Train=57379, Validation=7173, Test=7173[0m
[32m2025-10-31 08:41:29.059[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m666[0m - [1m--- Creating 'test' stage dataset ---[0m
[32m2025-10-31 08:41:29.059[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m668[0m - [1mCreating test_dataset with 7173 samples.[0m
[32m2025-10-31 08:41:29.070[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m99[0m - [1m[test_dataset] Initializing CompoundDataset...[0m
[32m2025-10-31 08:41:29.071[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m133[0m - [1m[test_dataset] Final x_formula

Saved predictions to ../artifacts/polymers_incremental_tasks/Stage3_Rg/prediction/predictions.parquet
Saved metrics to ../artifacts/polymers_incremental_tasks/Stage3_Rg/prediction/metrics.json
=== Stage 4 (linear_expansion) ===


[32m2025-10-31 08:41:29.782[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[train_dataset] Processing enabled task 'Cp' (type: REGRESSION)[0m
[32m2025-10-31 08:41:29.876[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[train_dataset] Processing enabled task 'Rg' (type: REGRESSION)[0m
[32m2025-10-31 08:41:29.969[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m148[0m - [1m[train_dataset] Processing enabled task 'linear_expansion' (type: REGRESSION)[0m
[32m2025-10-31 08:41:30.061[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m334[0m - [1m[train_dataset] CompoundDataset initialization complete. Processed 4 enabled tasks.[0m
[32m2025-10-31 08:41:30.064[0m | [1mINFO    [0m | [36mdatamodule[0m:[36m_build_fit_datasets[0m:[36m483[0m - [1mCreating val_dataset with 7173 samples.[0m
[32m2025-10-31 08:41:30.069[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m99

Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 32.11it/s]

/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

/Users/liuchang/projects/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 225/225 [00:04<00:00, 49.17it/s, v_num=0, train_final_loss_step=1112.0, val_final_loss=995.0, train_final_loss_epoch=1.01e+3]   

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 225/225 [00:04<00:00, 48.79it/s, v_num=0, train_final_loss_step=1112.0, val_final_loss=995.0, train_final_loss_epoch=1.01e+3]


[32m2025-10-31 08:43:03.872[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m551[0m - [1m--- Setting up DataModule for stage: test ---[0m
[32m2025-10-31 08:43:03.872[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m561[0m - [1mTotal samples available before splitting (from attributes_df index): 71725[0m
[32m2025-10-31 08:43:03.873[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m591[0m - [1mData split strategy: Performing random train/val/test splits based on full_idx (derived from attributes_df).[0m
[32m2025-10-31 08:43:03.873[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m594[0m - [1mTest split ratio: 0.1, Validation split ratio (of non-test): 0.1[0m
[32m2025-10-31 08:43:03.874[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m606[0m - [1mSplit full data (71725) into train_val (64552) and test (7173) using seed 24.[0m


Best checkpoint: /Users/liuchang/projects/foundation_model/artifacts/polymers_incremental_tasks/Stage4_linear_expansion/checkpoints/linear_expansion-epoch=19-val_final_loss=994.6105.ckpt


[32m2025-10-31 08:43:03.876[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m632[0m - [1mSplit train_val (64552) into train (57379) and val (7173) using seed 42, effective_val_split 0.111.[0m
[32m2025-10-31 08:43:03.876[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m658[0m - [1mFinal dataset sizes after splitting: Train=57379, Validation=7173, Test=7173[0m
[32m2025-10-31 08:43:03.876[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m666[0m - [1m--- Creating 'test' stage dataset ---[0m
[32m2025-10-31 08:43:03.876[0m | [1mINFO    [0m | [36mdatamodule[0m:[36msetup[0m:[36m668[0m - [1mCreating test_dataset with 7173 samples.[0m
[32m2025-10-31 08:43:03.882[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m99[0m - [1m[test_dataset] Initializing CompoundDataset...[0m
[32m2025-10-31 08:43:03.883[0m | [1mINFO    [0m | [36mdataset[0m:[36m__init__[0m:[36m133[0m - [1m[test_dataset] Final x_formula

Saved predictions to ../artifacts/polymers_incremental_tasks/Stage4_linear_expansion/prediction/predictions.parquet
Saved metrics to ../artifacts/polymers_incremental_tasks/Stage4_linear_expansion/prediction/metrics.json
