# Pretrain & PI Finetuning Suite

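This notebook orchestrates a configurable number of randomized incremental pretrain runs on non-PI polymers (`NUM_PRETRAIN_RUNS`, currently 10). Within each run, tasks are added one at a time in a shuffled order; after every pretrain stage, each PI property is finetuned separately from that stage's checkpoint with the shared encoder frozen. It combines the continual-task recipes from `dynamic_task_finetuning_demo.ipynb` and `dynamic_task_incremental_finetuning.ipynb`.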


## Data Overview

- **Descriptors**: `data/amorphous_polymer_FFDescriptor_20250730.parquet`
- **Non-PI properties**: `data/amorphous_polymer_non_PI_properties_20250730.parquet`
- **PI properties**: `data/amorphous_polymer_PI_properties_20250730.parquet`
- Pretrain tasks: 15 properties (density through thermal_diffusivity) sampled in random order per run
- PI finetune tasks: density, Rg, r2, self-diffusion, Cp, Cv, linear_expansion, refractive_index, tg


In [1]:
import json
import math
import random
import re
from pathlib import Path
from typing import Any

import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from loguru import logger as fm_logger

from foundation_model.data.datamodule import CompoundDataModule
from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel
from foundation_model.models.model_config import OptimizerConfig, RegressionTaskConfig, TaskType


2025-10-31 11:53:56.442 | INFO     | __init__:<module>:34 - Loguru logger initialized for foundation_model package.


In [2]:
# Input tables, resolved relative to the notebook's working directory.
DATA_DIR = Path("../data")
DESCRIPTOR_PATH = DATA_DIR / "amorphous_polymer_FFDescriptor_20250730.parquet"
NON_PI_PATH = DATA_DIR / "amorphous_polymer_non_PI_properties_20250730.parquet"
PI_PATH = DATA_DIR / "amorphous_polymer_PI_properties_20250730.parquet"
# joblib file loaded into PROPERTY_SCALERS when USE_NORMALIZED_TARGETS is True.
SCALER_PATH = DATA_DIR / "amorphous_polymer_properties_scaler_20250730.pkl.z"

USE_NORMALIZED_TARGETS = True  # If True, target columns carry the " (normalized)" suffix (see target_column)
FINETUNE_FREEZE_SHARED = True  # Passed as freeze_shared_encoder when the finetune model is loaded
QUIET_MODEL_LOGGING = True  # If True, disable loguru output from the "foundation_model" package
ENABLE_LEARNABLE_LOSS = False  # Passed as enable_learnable_loss_balancer to FlexibleMultiTaskModel
KEEP_NORMALIZED_TARGETS = False  # NOTE(review): not referenced in the visible cells — confirm it is still needed

# Properties read from the non-PI table. Each run shuffles this list and adds
# one task per pretrain stage.
PRETRAIN_TASK_NAMES = [
    "density",
    "Rg",
    "r2",
    "self-diffusion",
    "Cp",
    "Cv",
    "bulk_modulus",
    "volume_expansion",
    "linear_expansion",
    "static_dielectric_const",
    "dielectric_const_dc",
    "refractive_index",
    "tg",
    "thermal_conductivity",
    "thermal_diffusivity",
]
# Properties read from the PI table; each is finetuned on its own. Names whose
# column is absent from the PI table are dropped when the data is loaded.
FINETUNE_TASK_NAMES = [
    "density",
    "Rg",
    "r2",
    "self-diffusion",
    "Cp",
    "Cv",
    "linear_expansion",
    "refractive_index",
    "tg",
]

# Sorted union of every task name. Despite the constant's name, entries keep
# their original casing (e.g. "Rg", "Cp").
LOWER_CASE_PROPERTIES = sorted(set(PRETRAIN_TASK_NAMES) | set(FINETUNE_TASK_NAMES))

def target_column(property_name: str) -> str:
    """Return the table column that holds ``property_name``.

    When ``USE_NORMALIZED_TARGETS`` is true the column name is the property
    name followed by ``" (normalized)"``; otherwise it is the bare name.
    """
    if USE_NORMALIZED_TARGETS:
        return f"{property_name} (normalized)"
    return f"{property_name}"

# Resolve every property name to its source column once, then slice that
# mapping per training phase.
TARGET_COLUMNS = dict(zip(LOWER_CASE_PROPERTIES, map(target_column, LOWER_CASE_PROPERTIES)))
PRETRAIN_TARGET_COLUMNS = {task: TARGET_COLUMNS[task] for task in PRETRAIN_TASK_NAMES}
FINETUNE_TARGET_COLUMNS = {task: TARGET_COLUMNS[task] for task in FINETUNE_TASK_NAMES}

# Model architecture.
SHARED_BLOCK_DIMS = [190, 256, 128]
HEAD_HIDDEN = 64

# Every run writes its checkpoints, logs and plots below this directory.
ARTIFACT_ROOT = Path("../artifacts/polymers_pretrain_finetune_runs")
ARTIFACT_ROOT.mkdir(parents=True, exist_ok=True)

# Training schedule.
NUM_PRETRAIN_RUNS = 10  # Run i derives its task order and split seed from RANDOM_SEED_BASE + i
PRETRAIN_MAX_EPOCHS = 200  # Epoch cap per pretrain stage
FINETUNE_MAX_EPOCHS = 200  # Epoch cap per finetune stage
RANDOM_SEED_BASE = 1729
LOG_EVERY_N_STEPS = 20

# Datamodule settings shared by the pretrain and finetune builders.
BATCH_SIZE = 256
NUM_WORKERS = 0
DATAMODULE_RANDOM_SEED = 42
TASK_MASKING_RATIOS: float | dict[str, float] | None = None
SWAP_TRAIN_VAL_SPLIT = 0.0
VAL_SPLIT = 0.1
TEST_SPLIT = 0.1
TEST_ALL = False

# Optional row caps for smoke tests; None keeps every row
# (e.g. PRETRAIN_SAMPLE = 2000 and PI_SAMPLE = 1000 for a quick pass).
PRETRAIN_SAMPLE = None
PI_SAMPLE = None

# Filled from SCALER_PATH by the data-loading cell when targets are normalized.
PROPERTY_SCALERS: dict[str, Any] = {}

# Silence or restore loguru output of the foundation_model package.
(fm_logger.disable if QUIET_MODEL_LOGGING else fm_logger.enable)("foundation_model")


In [3]:
# Load the descriptor matrix and both property tables, validate that every
# configured task has a target column (and a scaler, when normalized targets
# are in use), then align features and targets on their shared index.
descriptor_df = pd.read_parquet(DESCRIPTOR_PATH)
non_pi_df = pd.read_parquet(NON_PI_PATH)
pi_df = pd.read_parquet(PI_PATH)

if USE_NORMALIZED_TARGETS:
    if not SCALER_PATH.exists():
        raise FileNotFoundError(f"Missing scaler file: {SCALER_PATH}")
    # NOTE: joblib.load unpickles arbitrary objects — only point SCALER_PATH at
    # a file from a trusted source.
    PROPERTY_SCALERS = joblib.load(SCALER_PATH)
    missing_scalers = [name for name in LOWER_CASE_PROPERTIES if name not in PROPERTY_SCALERS]
    if missing_scalers:
        raise KeyError(f"Scaler missing entries for: {missing_scalers}")
else:
    # Reset explicitly so re-running the cell with the flag off drops stale scalers.
    PROPERTY_SCALERS = {}

# Pretraining needs every configured column; a missing one is fatal.
missing_pretrain = [PRETRAIN_TARGET_COLUMNS[name] for name in PRETRAIN_TASK_NAMES if PRETRAIN_TARGET_COLUMNS[name] not in non_pi_df.columns]
if missing_pretrain:
    raise KeyError(f"Non-PI table missing columns: {missing_pretrain}")

# Finetuning is more lenient: tasks without a PI column are dropped with a
# warning, and only an empty remaining task list is an error.
missing_finetune = [name for name in FINETUNE_TASK_NAMES if FINETUNE_TARGET_COLUMNS[name] not in pi_df.columns]
if missing_finetune:
    print(f"Warning: PI table missing columns for tasks: {missing_finetune}. They will be skipped.")
available_finetune_tasks = [name for name in FINETUNE_TASK_NAMES if name not in missing_finetune]
if not available_finetune_tasks:
    raise ValueError("No PI finetune tasks remain after filtering missing columns.")
original_finetune_columns = FINETUNE_TARGET_COLUMNS
FINETUNE_TASK_NAMES = available_finetune_tasks
FINETUNE_TARGET_COLUMNS = {name: original_finetune_columns[name] for name in FINETUNE_TASK_NAMES}

# Keep only rows present in both the descriptor and the non-PI property table.
common_non_pi_index = descriptor_df.index.intersection(non_pi_df.index)
if common_non_pi_index.empty:
    # Fail here rather than deep inside training with an empty dataset.
    raise ValueError("Descriptor table and non-PI property table share no index entries.")
pretrain_features = descriptor_df.loc[common_non_pi_index]
pretrain_targets = non_pi_df.loc[common_non_pi_index, [PRETRAIN_TARGET_COLUMNS[name] for name in PRETRAIN_TASK_NAMES]]

# Optional fixed-seed subsample for smoke tests; targets follow the sampled rows.
if PRETRAIN_SAMPLE is not None and PRETRAIN_SAMPLE < len(pretrain_features):
    pretrain_features = pretrain_features.sample(n=PRETRAIN_SAMPLE, random_state=42)
    pretrain_targets = pretrain_targets.loc[pretrain_features.index]

# Same alignment for the PI table.
common_pi_index = descriptor_df.index.intersection(pi_df.index)
if common_pi_index.empty:
    raise ValueError("Descriptor table and PI property table share no index entries.")
pi_features = descriptor_df.loc[common_pi_index]
pi_targets = pi_df.loc[common_pi_index, [FINETUNE_TARGET_COLUMNS[name] for name in FINETUNE_TASK_NAMES]]

if PI_SAMPLE is not None and PI_SAMPLE < len(pi_features):
    pi_features = pi_features.sample(n=PI_SAMPLE, random_state=13)
    pi_targets = pi_targets.loc[pi_features.index]

print(f"Pretrain feature matrix: {pretrain_features.shape}")
print(f"Pretrain target matrix: {pretrain_targets.shape}")
print(f"PI feature matrix: {pi_features.shape}")
print(f"PI target matrix: {pi_targets.shape}")


Pretrain feature matrix: (71725, 190)
Pretrain target matrix: (71725, 15)
PI feature matrix: (1083, 190)
PI target matrix: (1083, 9)


## Helper Utilities


In [4]:
def build_pretrain_datamodule(
    task_names: list[str],
    *,
    batch_size: int = BATCH_SIZE,
    random_seed: int | None = DATAMODULE_RANDOM_SEED,
) -> CompoundDataModule:
    """Build the datamodule for one pretrain stage.

    Pairs the full pretrain descriptor matrix with only the target columns
    that belong to ``task_names``; split, masking and worker settings come
    from the notebook-level constants.

    Args:
        task_names: Pretrain tasks active in this stage.
        batch_size: Batch size handed to the datamodule.
        random_seed: Seed handed to the datamodule.

    Returns:
        A ``CompoundDataModule`` covering the requested tasks.
    """
    wanted_columns = [PRETRAIN_TARGET_COLUMNS[task] for task in task_names]
    stage_targets = pretrain_targets.loc[:, wanted_columns]
    # Notebook-wide settings that are the same for every stage.
    shared_settings = {
        "task_masking_ratios": TASK_MASKING_RATIOS,
        "val_split": VAL_SPLIT,
        "test_split": TEST_SPLIT,
        "test_all": TEST_ALL,
        "swap_train_val_split": SWAP_TRAIN_VAL_SPLIT,
        "num_workers": NUM_WORKERS,
    }
    return CompoundDataModule(
        formula_desc_source=pretrain_features,
        attributes_source=stage_targets,
        task_configs=make_pretrain_task_configs(task_names),
        random_seed=random_seed,
        batch_size=batch_size,
        **shared_settings,
    )

def build_finetune_datamodule(
    task_name: str,
    *,
    batch_size: int = BATCH_SIZE,
    random_seed: int | None = DATAMODULE_RANDOM_SEED,
) -> CompoundDataModule:
    """Build the single-task datamodule for finetuning on one PI property.

    Pairs the PI descriptor matrix with the one target column of
    ``task_name``; split, masking and worker settings come from the
    notebook-level constants.

    Args:
        task_name: PI property to finetune on.
        batch_size: Batch size handed to the datamodule.
        random_seed: Seed handed to the datamodule.

    Returns:
        A ``CompoundDataModule`` with exactly one task.
    """
    column = FINETUNE_TARGET_COLUMNS[task_name]
    # Double brackets keep the selection a one-column DataFrame, not a Series.
    target_frame = pi_targets.loc[:, [column]]
    single_task_configs = [make_finetune_task_config(task_name)]
    # Notebook-wide settings that are the same for every finetune task.
    shared_settings = {
        "task_masking_ratios": TASK_MASKING_RATIOS,
        "val_split": VAL_SPLIT,
        "test_split": TEST_SPLIT,
        "test_all": TEST_ALL,
        "swap_train_val_split": SWAP_TRAIN_VAL_SPLIT,
        "num_workers": NUM_WORKERS,
    }
    return CompoundDataModule(
        formula_desc_source=pi_features,
        attributes_source=target_frame,
        task_configs=single_task_configs,
        random_seed=random_seed,
        batch_size=batch_size,
        **shared_settings,
    )


In [5]:
# Allowlist the project config classes for `torch.load(..., weights_only=True)`,
# which the training loop uses when reloading the best checkpoint. Presumably
# these classes are pickled inside the Lightning checkpoints — TODO confirm.
torch.serialization.add_safe_globals([RegressionTaskConfig, TaskType, OptimizerConfig])


## Pretrain & Finetune Workflow


In [None]:
def _build_stage_trainer(
    output_dir: Path,
    slug: str,
    max_epochs: int,
) -> tuple[Trainer, ModelCheckpoint]:
    """Create the Trainer used for one pretrain or finetune stage.

    Checkpoints are written to ``output_dir / "checkpoints"`` and the CSV and
    TensorBoard logs to ``output_dir / "logs"``. Both the checkpoint callback
    and early stopping (patience 10) monitor ``val_final_loss`` in "min" mode.

    Args:
        output_dir: Stage directory that receives checkpoints and logs.
        slug: Filename prefix for the saved checkpoint.
        max_epochs: Epoch cap for this stage.

    Returns:
        The trainer and its ``ModelCheckpoint`` callback; the callback is
        returned so callers can read ``best_model_path`` after fitting.
    """
    checkpoint_cb = ModelCheckpoint(
        dirpath=output_dir / "checkpoints",
        filename=f"{slug}-{{epoch:02d}}-{{val_final_loss:.4f}}",
        monitor="val_final_loss",
        mode="min",
        save_top_k=1,
    )
    early_stopping = EarlyStopping(monitor="val_final_loss", mode="min", patience=10)
    csv_logger = CSVLogger(save_dir=output_dir / "logs", name="csv")
    tensorboard_logger = TensorBoardLogger(save_dir=output_dir / "logs", name="tensorboard")
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        callbacks=[checkpoint_cb, early_stopping],
        logger=[csv_logger, tensorboard_logger],
        log_every_n_steps=LOG_EVERY_N_STEPS,
    )
    return trainer, checkpoint_cb


def _restore_best_weights(model, checkpoint_path: str) -> None:
    """Load the weights stored at ``checkpoint_path`` into ``model`` in place."""
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    # Accept either a checkpoint dict with a "state_dict" entry or a bare state dict.
    model.load_state_dict(state.get("state_dict", state))


# One entry per completed run; read by the summary cell.
experiment_records: list[dict] = []

for run_idx in range(1, NUM_PRETRAIN_RUNS + 1):
    # The same per-run seed drives the task order and every datamodule split.
    run_seed = RANDOM_SEED_BASE + run_idx
    rng = random.Random(run_seed)
    task_sequence = rng.sample(PRETRAIN_TASK_NAMES, k=len(PRETRAIN_TASK_NAMES))
    run_label = f"run{run_idx:02d}"
    print(f"""
====================
Starting {run_label}
Task order: {task_sequence}
===================="""
)

    run_root = ARTIFACT_ROOT / run_label
    run_root.mkdir(parents=True, exist_ok=True)

    previous_checkpoint: str | None = None
    pretrain_stage_records: list[dict] = []
    finetune_records: list[dict] = []

    for stage_idx, task_name in enumerate(task_sequence, start=1):
        # Incremental pretraining: stage k trains on the first k tasks of the order.
        stage_tasks = task_sequence[:stage_idx]
        datamodule = build_pretrain_datamodule(stage_tasks, random_seed=run_seed)
        task_configs = make_pretrain_task_configs(stage_tasks)

        if previous_checkpoint is None:
            # First stage of the run: start from a fresh model.
            model = FlexibleMultiTaskModel(
                shared_block_dims=SHARED_BLOCK_DIMS,
                task_configs=task_configs,
                enable_learnable_loss_balancer=ENABLE_LEARNABLE_LOSS,
                shared_block_optimizer=OptimizerConfig(lr=5e-2),
            )
        else:
            # Later stages: resume from the previous stage and add heads only
            # for tasks the checkpoint does not already contain.
            model = FlexibleMultiTaskModel.load_from_checkpoint(
                checkpoint_path=previous_checkpoint,
                strict=False,
                enable_learnable_loss_balancer=ENABLE_LEARNABLE_LOSS,
            )
            existing = set(model.task_heads.keys())
            new_configs = [cfg for cfg in task_configs if cfg.name not in existing]
            if new_configs:
                model.add_task(*new_configs)

        stage_dir = run_root / f"pretrain_stage{stage_idx:02d}_{safe_slug(task_name)}"
        stage_dir.mkdir(parents=True, exist_ok=True)

        trainer, checkpoint_cb = _build_stage_trainer(
            stage_dir, safe_slug(task_name), PRETRAIN_MAX_EPOCHS
        )

        trainer.fit(model, datamodule=datamodule)
        best_model_path = checkpoint_cb.best_model_path
        print(f"Run {run_label} stage {stage_idx}: best checkpoint -> {best_model_path}")

        if best_model_path:
            _restore_best_weights(model, best_model_path)
            previous_checkpoint = best_model_path
        else:
            # No best checkpoint was captured. Persist the current weights so that
            # the stage record, the finetune runs below and the next pretrain
            # stage all use this stage's model rather than an older checkpoint.
            print("Warning: no best checkpoint captured; using current weights.")
            fallback_path = stage_dir / "checkpoints" / "current-weights.ckpt"
            fallback_path.parent.mkdir(parents=True, exist_ok=True)
            trainer.save_checkpoint(fallback_path)
            previous_checkpoint = str(fallback_path)

        prediction_dir = stage_dir / "prediction"
        plot_test_predictions(
            model=model,
            datamodule=datamodule,
            phase="pretrain",
            run_id=run_idx,
            stage_num=stage_idx,
            stage_tasks=stage_tasks,
            new_task_name=task_name,
            output_dir=prediction_dir,
        )

        stage_record = {
            "stage": stage_idx,
            "task_name": task_name,
            "task_sequence": list(stage_tasks),
            "checkpoint": previous_checkpoint,
            "stage_dir": stage_dir,
        }

        # Finetune each PI task separately from this stage's checkpoint.
        stage_finetune_records: list[dict] = []
        for finetune_name in FINETUNE_TASK_NAMES:
            finetune_model = FlexibleMultiTaskModel.load_from_checkpoint(
                checkpoint_path=previous_checkpoint,
                strict=False,
                enable_learnable_loss_balancer=ENABLE_LEARNABLE_LOSS,
                freeze_shared_encoder=FINETUNE_FREEZE_SHARED,
                shared_block_optimizer=OptimizerConfig(lr=5e-2),
            )
            # Drop all pretrain heads so only the new PI head is trained.
            active_tasks = list(finetune_model.task_heads.keys())
            if active_tasks:
                finetune_model.remove_tasks(*active_tasks)

            task_config = make_finetune_task_config(finetune_name)
            finetune_model.add_task(task_config)

            datamodule = build_finetune_datamodule(finetune_name, random_seed=run_seed)

            finetune_root = stage_dir / "finetune"
            finetune_stage_dir = finetune_root / safe_slug(finetune_name)
            finetune_stage_dir.mkdir(parents=True, exist_ok=True)

            trainer, checkpoint_cb = _build_stage_trainer(
                finetune_stage_dir, safe_slug(finetune_name), FINETUNE_MAX_EPOCHS
            )

            trainer.fit(finetune_model, datamodule=datamodule)
            finetune_best_path = checkpoint_cb.best_model_path
            print(
                f"Run {run_label} stage {stage_idx} finetune {finetune_name}: "
                f"best checkpoint -> {finetune_best_path}"
            )

            if finetune_best_path:
                _restore_best_weights(finetune_model, finetune_best_path)
            else:
                # Nothing downstream reloads finetune checkpoints, so evaluating
                # the in-memory weights is sufficient here.
                print("Warning: finetune stage missing checkpoint; using current weights.")

            prediction_dir = finetune_stage_dir / "prediction"
            plot_test_predictions(
                model=finetune_model,
                datamodule=datamodule,
                phase="finetune",
                run_id=run_idx,
                stage_num=stage_idx,
                stage_tasks=stage_tasks,
                new_task_name=finetune_name,
                output_dir=prediction_dir,
            )

            finetune_record = {
                "stage": stage_idx,
                "task_name": finetune_name,
                "pretrain_task_sequence": list(stage_tasks),
                "checkpoint": finetune_best_path,
                "stage_dir": finetune_stage_dir,
            }
            stage_finetune_records.append(finetune_record)
            # The run-level list gets its own copy of the record dict.
            finetune_records.append(dict(finetune_record))

        stage_record["finetune"] = stage_finetune_records
        pretrain_stage_records.append(stage_record)

    # Only reachable when the task sequence is empty, i.e. no stage ran at all.
    if previous_checkpoint is None:
        raise RuntimeError(f"Run {run_label} produced no pretrain checkpoint; cannot finetune.")

    experiment_records.append(
        {
            "run": run_label,
            "task_sequence": task_sequence,
            "pretrain": pretrain_stage_records,
            "pretrain_checkpoint": previous_checkpoint,
            "finetune": finetune_records,
        }
    )

print("Completed all pretrain + finetune runs.")


## Run Summary


In [None]:
# Summarise each run: number of pretrain stages, total finetune runs, and a
# per-stage breakdown for stages that have at least one finetune run.
print(f"Recorded {len(experiment_records)} runs.")
for record in experiment_records:
    stages = record["pretrain"]
    per_stage = [(stage, len(stage.get("finetune", []))) for stage in stages]
    print(
        record["run"],
        "pretrain stages:",
        len(stages),
        "finetune stages:",
        sum(n_finetunes for _, n_finetunes in per_stage),
    )
    for stage, n_finetunes in per_stage:
        if n_finetunes:
            print(
                "  ",
                f"stage {stage['stage']:02d} ({stage['task_name']}):",
                f"{n_finetunes} finetune runs",
            )


Recorded 2 runs.
run01 pretrain stages: 3 finetune stages: 3
run02 pretrain stages: 3 finetune stages: 3
