# Continual Task Finetuning Demo

This notebook demonstrates how to pre-train the flexible multi-task foundation model on non-PI polymers, then load the best checkpoint, freeze the shared encoder, and fine-tune newly added tasks for PI polymers using the dynamic task management utilities (`add_task` / `remove_tasks`).

## Data Overview

- **Descriptors**: `data/amorphous_polymer_FFDescriptor_20250730.parquet`
- **Non-PI properties**: `data/amorphous_polymer_non_PI_properties_20250730.parquet`
- **PI properties**: `data/amorphous_polymer_PI_properties_20250730.parquet`
- **Targets**: normalized regression labels for density, Cp, Rg, and linear_expansion

In [None]:
import os
from pathlib import Path

import pandas as pd
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from foundation_model.data.datamodule import CompoundDataModule
from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel
from foundation_model.models.model_config import RegressionTaskConfig

In [None]:
# --- Input tables (relative paths, resolved against the current working directory) ---
DATA_DIR = Path("data")
DESCRIPTOR_PATH = DATA_DIR.joinpath("amorphous_polymer_FFDescriptor_20250730.parquet")
NON_PI_PATH = DATA_DIR.joinpath("amorphous_polymer_non_PI_properties_20250730.parquet")
PI_PATH = DATA_DIR.joinpath("amorphous_polymer_PI_properties_20250730.parquet")

# Task name -> label column read from the property tables for that task.
TARGET_COLUMNS = {
    "density": "density (normalized)",
    "Cp": "Cp (normalized)",
    "Rg": "Rg (normalized)",
    "linear_expansion": "linear_expansion (normalized)",
}

# --- Model layout ---
# Passed to the model as `shared_block_dims`; the last entry is also the input
# width of every regression head built further down.
SHARED_BLOCK_DIMS = [190, 256, 128]
HEAD_HIDDEN = 64  # hidden width for regression heads

# --- Row sub-sampling (None disables sampling for that stage) ---
PRETRAIN_SAMPLE = 6000  # subset for quick demonstration
PI_SAMPLE = None  # use full PI set by default

# --- Output location for checkpoints and logs; created eagerly ---
ARTIFACT_ROOT = Path("notebooks/artifacts/polymers_dynamic_tasks")
ARTIFACT_ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the descriptor table and both property tables from disk.
descriptor_df = pd.read_parquet(DESCRIPTOR_PATH)
non_pi_df = pd.read_parquet(NON_PI_PATH)
pi_df = pd.read_parquet(PI_PATH)

# Label columns pulled from each property table (the same set for both stages).
label_columns = list(TARGET_COLUMNS.values())

# Stage-1 data: keep only index entries present in both the descriptor table
# and the non-PI property table.
common_non_pi = descriptor_df.index.intersection(non_pi_df.index)
pretrain_features = descriptor_df.loc[common_non_pi]
pretrain_targets = non_pi_df.loc[common_non_pi, label_columns]

# Optional down-sampling. Targets are re-selected by the sampled index so that
# features and labels stay row-aligned.
if PRETRAIN_SAMPLE is not None:
    if PRETRAIN_SAMPLE < len(pretrain_features):
        pretrain_features = pretrain_features.sample(n=PRETRAIN_SAMPLE, random_state=42)
        pretrain_targets = pretrain_targets.loc[pretrain_features.index]

# Stage-2 data: same alignment against the PI property table.
common_pi = descriptor_df.index.intersection(pi_df.index)
pi_features = descriptor_df.loc[common_pi]
pi_targets = pi_df.loc[common_pi, label_columns]

if PI_SAMPLE is not None:
    if PI_SAMPLE < len(pi_features):
        pi_features = pi_features.sample(n=PI_SAMPLE, random_state=13)
        pi_targets = pi_targets.loc[pi_features.index]

# Report (rows, columns) of every table handed to the data modules.
print(
    f"Pre-train feature tensor: {pretrain_features.shape}",
    f"Pre-train targets: {pretrain_targets.shape}",
    sep="\n",
)
print(
    f"Fine-tune feature tensor: {pi_features.shape}",
    f"Fine-tune targets: {pi_targets.shape}",
    sep="\n",
)

## Build Task Configurations

In [None]:
def build_regression_task(name: str, column: str) -> RegressionTaskConfig:
    """Create the regression-head configuration for a single target.

    Args:
        name: Task name stored on the config.
        column: Column name forwarded as ``data_column``.

    Returns:
        A ``RegressionTaskConfig`` whose ``dims`` go from the last entry of
        ``SHARED_BLOCK_DIMS`` through ``HEAD_HIDDEN`` to a single output, with
        ``norm=True`` and ``residual=False``.
    """
    # Head layout: shared-block output width -> hidden width -> one value.
    head_dims = [SHARED_BLOCK_DIMS[-1], HEAD_HIDDEN, 1]
    task_config = RegressionTaskConfig(
        name=name,
        data_column=column,
        dims=head_dims,
        norm=True,
        residual=False,
    )
    return task_config

# One task per entry of TARGET_COLUMNS for pre-training ...
pretrain_task_configs = [
    build_regression_task(*name_and_column) for name_and_column in TARGET_COLUMNS.items()
]
# ... and a "_pi"-suffixed counterpart reading the same column for fine-tuning.
pi_task_configs = [
    build_regression_task(f"{name}_pi", column) for name, column in TARGET_COLUMNS.items()
]

pretrain_task_names = [task.name for task in pretrain_task_configs]
pi_task_names = [task.name for task in pi_task_configs]
print("Pretrain tasks:", pretrain_task_names)
print("PI tasks:", pi_task_names)

## Stage 1 — Pre-train on non-PI polymers

In [None]:
# Stage-1 data module: non-PI descriptors and labels, with 10% of the data each
# requested for validation and test via val_split / test_split.
pretrain_datamodule = CompoundDataModule(
    formula_desc_source=pretrain_features,
    attributes_source=pretrain_targets,
    task_configs=pretrain_task_configs,
    val_split=0.1,
    test_split=0.1,
    batch_size=256,
    num_workers=0,
)

# Model with one regression task per pre-training target.
pretrain_model = FlexibleMultiTaskModel(
    task_configs=pretrain_task_configs,
    shared_block_dims=SHARED_BLOCK_DIMS,
    enable_learnable_loss_balancer=True,
)

# CSV metrics are written under <ARTIFACT_ROOT>/logs/pretrain.
pretrain_logger = CSVLogger(save_dir=ARTIFACT_ROOT / "logs", name="pretrain")

# Keep only the single checkpoint with the lowest "val_final_loss".
pretrain_checkpoint_dir = ARTIFACT_ROOT / "pretrain_checkpoints"
pretrain_checkpoint_dir.mkdir(parents=True, exist_ok=True)
pretrain_ckpt = ModelCheckpoint(
    monitor="val_final_loss",
    mode="min",
    save_top_k=1,
    dirpath=pretrain_checkpoint_dir,
    filename="pretrain-{epoch:02d}-{val_final_loss:.4f}",
)

# Short CPU run: 3 epochs, with train/val batches capped via limit_*_batches.
pretrain_trainer = Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=3,
    limit_train_batches=0.2,
    limit_val_batches=0.5,
    log_every_n_steps=10,
    logger=pretrain_logger,
    callbacks=[pretrain_ckpt],
)

pretrain_trainer.fit(pretrain_model, datamodule=pretrain_datamodule)
# best_model_path is read again in stage 2 to restore the pre-trained weights.
print(f"Best checkpoint: {pretrain_ckpt.best_model_path}")

## Stage 2 — Fine-tune newly added PI tasks

In [None]:
# Stage-2 data module: PI descriptors and labels wired to the "_pi" tasks.
# No test split is requested here (test_split=0.0).
pi_datamodule = CompoundDataModule(
    formula_desc_source=pi_features,
    attributes_source=pi_targets,
    task_configs=pi_task_configs,
    batch_size=64,
    num_workers=0,
    val_split=0.2,
    test_split=0.0,
)

# The model is first built with the *pre-training* task list, and the
# checkpoint's state dict is then loaded into it (non-strict).
finetune_model = FlexibleMultiTaskModel(
    shared_block_dims=SHARED_BLOCK_DIMS,
    task_configs=pretrain_task_configs,
    enable_learnable_loss_balancer=True,
    strict_loading=False,
)

best_ckpt_path = pretrain_ckpt.best_model_path
if not best_ckpt_path:
    raise RuntimeError("Pre-training did not produce a checkpoint. Check earlier cells for errors.")

# weights_only is passed explicitly: PyTorch >= 2.6 defaults it to True, and a
# Lightning checkpoint can carry pickled non-tensor objects that the restricted
# unpickler rejects. Full unpickling is acceptable here only because the file
# was written by the pre-training cell of this same notebook; do not reuse
# this call for checkpoints from untrusted sources.
state = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)

# strict=False never raises on key mismatches, so surface them instead of
# silently continuing with weights that were not restored.
load_result = finetune_model.load_state_dict(state["state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Model keys missing from checkpoint:", load_result.missing_keys)
    print("Checkpoint keys not used by model:", load_result.unexpected_keys)

# Freeze the shared encoder so its parameters receive no gradients.
for param in finetune_model.encoder.parameters():
    param.requires_grad_(False)

# Swap the task set: drop the pre-training tasks, then register the PI tasks.
finetune_model.remove_tasks(*TARGET_COLUMNS.keys())
for cfg in pi_task_configs:
    finetune_model.add_task(cfg)

print("Trainable task heads:", list(finetune_model.task_heads.keys()))

In [None]:
# CSV metrics are written under <ARTIFACT_ROOT>/logs/finetune.
finetune_logger = CSVLogger(save_dir=ARTIFACT_ROOT / "logs", name="finetune")

# Keep only the single checkpoint with the lowest "val_final_loss".
finetune_checkpoint_dir = ARTIFACT_ROOT / "finetune_checkpoints"
finetune_checkpoint_dir.mkdir(parents=True, exist_ok=True)
finetune_ckpt = ModelCheckpoint(
    monitor="val_final_loss",
    mode="min",
    save_top_k=1,
    dirpath=finetune_checkpoint_dir,
    filename="finetune-{epoch:02d}-{val_final_loss:.4f}",
)

# 5 CPU epochs with train/val batch limits set to 1.0.
finetune_trainer = Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=5,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    log_every_n_steps=5,
    logger=finetune_logger,
    callbacks=[finetune_ckpt],
)

finetune_trainer.fit(finetune_model, datamodule=pi_datamodule)
print(f"Best fine-tuning checkpoint: {finetune_ckpt.best_model_path}")

## Inspect fine-tuned predictions

In [None]:
# Pull one batch from the PI validation loader to eyeball predictions.
pi_datamodule.setup(stage="validate")
val_loader = pi_datamodule.val_dataloader()
example_batch = next(iter(val_loader))

# Switch to inference mode first: without eval(), any train-mode layers
# (e.g. dropout or batch statistics, if the model uses them) would still be
# active and the printed predictions would not reflect inference behaviour.
finetune_model.eval()

# NOTE(review): elements 0 and 3 of the batch are passed as the model's two
# positional inputs — confirm their meaning against CompoundDataModule's
# batch layout.
with torch.no_grad():
    outputs = finetune_model(example_batch[0], example_batch[3])

# Print the first five predictions of every entry in the output mapping.
for name, tensor in outputs.items():
    print(name, tensor[:5].squeeze().cpu().numpy())