# Interactive Model Test and Debugging

This notebook walks through testing and debugging `FlexibleMultiTaskModel` and `CompoundDataModule` step by step. It loads the configuration, initializes the data module and model, fetches one batch of data, and manually reproduces the core logic of `training_step`.

## 1. Setup and Imports

In [3]:
import os
import sys
import torch
import pandas as pd
from omegaconf import OmegaConf
import yaml  # For loading the raw config if needed
import pprint
import logging

# Configure basic logging for the notebook
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

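# Resolve the project root; it is used below to locate the config file.
# Assumes the notebook is in a subdirectory of the project root (e.g., 'notebooks/').
# The sys.path insertion is left commented out; uncomment it if the package is not installed.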
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)
#     logger.info(f"Added {project_root} to sys.path")

try:
    from foundation_model.models.flexible_multi_task_model import FlexibleMultiTaskModel
    from foundation_model.data.datamodule import CompoundDataModule
    from foundation_model.configs.model_config import TaskType  # For potential inspection

    logger.info("Successfully imported project modules.")
except ImportError as e:
    logger.error(
        f"Error importing project modules: {e}. Ensure PYTHONPATH is set correctly or notebook is in the correct location."
    )
    raise

2025-05-14 04:43:49,056 - INFO - Successfully imported project modules.


## 2. Load Configuration

Load the model and data configurations from `samples/generated_configs/generated_model_config.yaml`.

In [9]:
config_path = os.path.join(project_root, "samples/generated_configs/generated_model_config.yaml")
logger.info(f"Loading configuration from: {config_path}")

try:
    cfg = OmegaConf.load(config_path)
    logger.info("Configuration loaded successfully.")
    # Pretty print the loaded configuration (optional)
    # logger.info(OmegaConf.to_yaml(cfg))
except FileNotFoundError:
    logger.error(f"Configuration file not found at {config_path}")
    raise
except Exception as e:
    logger.error(f"Error loading OmegaConf configuration: {e}")
    raise

# Extract model and data specific configurations

model_cfg = cfg.model
data_cfg = cfg.data

pp = pprint.PrettyPrinter(indent=2)
logger.info("Model Configuration:")
pp.pprint(OmegaConf.to_container(model_cfg, resolve=True))  # resolve=True to see interpolated values
logger.info("Data Configuration:")
pp.pprint(OmegaConf.to_container(data_cfg, resolve=True))

2025-05-14 04:51:03,812 - INFO - Loading configuration from: /data/foundation_model/samples/generated_configs/generated_model_config.yaml
2025-05-14 04:51:03,849 - INFO - Configuration loaded successfully.
2025-05-14 04:51:03,850 - INFO - Model Configuration:
2025-05-14 04:51:03,857 - INFO - Data Configuration:


{ 'class_path': 'foundation_model.models.FlexibleMultiTaskModel',
  'init_args': { 'enable_self_supervised_training': False,
                 'loss_weights': { 'contrastive': 1.0,
                                   'cross_recon': 1.0,
                                   'mfm': 1.0},
                 'mask_ratio': 0.15,
                 'modality_dropout_p': 0.3,
                 'norm_shared': True,
                 'residual_shared': False,
                 'shared_block_dims': [256, 128, 64],
                 'shared_block_optimizer': { 'betas': [0.9, 0.999],
                                             'eps': 1e-06,
                                             'factor': 0.1,
                                             'freeze_parameters': False,
                                             'lr': 0.001,
                                             'min_lr': 1e-06,
                                             'mode': 'min',
                                             'monitor': 'val_

## 3. Initialize DataModule

Instantiate `CompoundDataModule`, prepare data, and set up for the 'fit' stage to get training and validation dataloaders.

In [None]:
logger.info("Initializing CompoundDataModule...")
# data_cfg.init_args.task_configs is defined as the interpolation ${model.init_args.task_configs}.
# Converting with resolve=True should expand it to the model's task configs;
# the check below is a fallback in case the key is missing or resolves to None.

datamodule_args = OmegaConf.to_container(data_cfg.init_args, resolve=True)

# Ensure task_configs are passed correctly (OmegaConf should handle the interpolation)
if "task_configs" not in datamodule_args or datamodule_args["task_configs"] is None:
    logger.info("Manually assigning task_configs to datamodule_args from model_cfg")
    datamodule_args["task_configs"] = OmegaConf.to_container(model_cfg.init_args.task_configs, resolve=True)

datamodule = CompoundDataModule(**datamodule_args)

logger.info("Preparing data...")
datamodule.prepare_data()  # Downloads or verifies data, typically no-op if local

logger.info("Setting up DataModule for 'fit' stage...")
datamodule.setup(stage="fit")

train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()

logger.info(f"Number of training batches: {len(train_dataloader)}")
logger.info(f"Number of validation batches: {len(val_dataloader)}")

## 4. Initialize Model

Instantiate `FlexibleMultiTaskModel` using the loaded model configuration.

In [None]:
logger.info("Initializing FlexibleMultiTaskModel...")
model_init_args = OmegaConf.to_container(model_cfg.init_args, resolve=True)
model = FlexibleMultiTaskModel(**model_init_args)
logger.info("Model initialized successfully.")
logger.info(f"Model structure:\n{model}")
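
Printing the full module can be verbose. A parameter count split by `requires_grad` is a quick way to check whether options such as `freeze_parameters` had the intended effect. The helper below is a sketch on a toy `nn.Sequential`; `count_parameters` and `toy_model` are hypothetical names, and the same function could be called on `model`.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> dict:
    """Count trainable and frozen parameters of a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return {"trainable": trainable, "frozen": frozen, "total": trainable + frozen}


# Toy stand-in for the real model.
toy_model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze the first layer to simulate a frozen block.
for p in toy_model[0].parameters():
    p.requires_grad_(False)

param_counts = count_parameters(toy_model)
print(param_counts)
```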

## 5. Fetch a Batch of Data

Get one batch from the training dataloader to simulate what the model receives during training.

In [None]:
logger.info("Fetching one batch from train_dataloader...")
batch = next(iter(train_dataloader))
batch_idx = 0  # For simulation

x, y_dict_batch, task_masks_batch, task_sequence_data_batch = batch

logger.info("Batch details:")
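# Accept lists as well as tuples: the default collate function can return a list here,
# and section 6.1 checks for (list, tuple) too.
if isinstance(x, (list, tuple)):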
    logger.info(
        f"  x_formula shape: {x[0].shape if x[0] is not None else 'None'}, dtype: {x[0].dtype if x[0] is not None else 'None'}"
    )
    logger.info(
        f"  x_struct shape: {x[1].shape if x[1] is not None else 'None'}, dtype: {x[1].dtype if x[1] is not None else 'None'}"
    )
else:
    logger.info(f"  x shape: {x.shape}, dtype: {x.dtype}")

logger.info("  y_dict_batch keys: " + str(list(y_dict_batch.keys())))
for task_name, tensor in y_dict_batch.items():
    logger.info(f"    {task_name} target shape: {tensor.shape}, dtype: {tensor.dtype}")

logger.info("  task_masks_batch keys: " + str(list(task_masks_batch.keys())))
for task_name, tensor in task_masks_batch.items():
    logger.info(f"    {task_name} mask shape: {tensor.shape}, dtype: {tensor.dtype}")

logger.info("  task_sequence_data_batch keys: " + str(list(task_sequence_data_batch.keys())))
for task_name, tensor in task_sequence_data_batch.items():
    logger.info(f"    {task_name} sequence data shape: {tensor.shape}, dtype: {tensor.dtype}")
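
When the batch layout is uncertain (nested tuples, dictionaries, or `None` entries), a small recursive helper can summarize it without assuming a fixed structure. This is a sketch only; `describe` and `toy_batch` are hypothetical names, and the toy batch merely imitates the four-part layout unpacked above.

```python
import torch


def describe(obj, indent=0):
    """Return a list of lines summarizing tensors nested in lists, tuples, and dicts."""
    pad = " " * indent
    if obj is None:
        return [f"{pad}None"]
    if isinstance(obj, torch.Tensor):
        return [f"{pad}Tensor shape={tuple(obj.shape)} dtype={obj.dtype}"]
    if isinstance(obj, dict):
        lines = []
        for key, value in obj.items():
            lines.append(f"{pad}{key}:")
            lines.extend(describe(value, indent + 2))
        return lines
    if isinstance(obj, (list, tuple)):
        lines = []
        for i, value in enumerate(obj):
            lines.append(f"{pad}[{i}]:")
            lines.extend(describe(value, indent + 2))
        return lines
    return [f"{pad}{type(obj).__name__}: {obj!r}"]


# Toy batch imitating (x, y_dict, task_masks, task_sequence_data).
toy_batch = (
    [torch.zeros(4, 8), None],
    {"task_a": torch.zeros(4, 1)},
    {"task_a": torch.ones(4, 1, dtype=torch.bool)},
    {},
)

summary = describe(toy_batch)
print("\n".join(summary))
```

Calling `describe(batch)` on the real batch would give a similar overview.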

## 6. Manual `training_step` Walkthrough

This section replicates the logic inside the model's `training_step` method cell by cell to inspect intermediate values and the final loss.

### 6.1 Unpack Batch and Determine Input Modalities

In [None]:
logs = {}

# Determine input modalities (simplified from model's training_step)
x_formula = None
original_x_struct = None  # Keep original structure input for potential cross-recon target

if model.with_structure and isinstance(x, (list, tuple)):
    x_formula, original_x_struct = x
    if x_formula is None:
        raise ValueError("Formula input (x_formula) cannot be None in multi-modal mode.")
elif not model.with_structure and isinstance(x, torch.Tensor):
    x_formula = x
elif model.with_structure and isinstance(x, torch.Tensor):
    x_formula = x
    # original_x_struct remains None
else:
    raise TypeError(f"Unexpected input type/combination. with_structure={model.with_structure}, type(x)={type(x)}")

logger.info(f"x_formula device: {x_formula.device if x_formula is not None else 'N/A'}")
total_loss = torch.tensor(0.0, device=x_formula.device if x_formula is not None else "cpu")
logger.info(f"Initial total_loss: {total_loss}, requires_grad: {total_loss.requires_grad}, device: {total_loss.device}")

# Modality Dropout (skipped as SSL is disabled in current config)
x_struct_for_processing = original_x_struct
if model.enable_self_supervised_training and model.with_structure and original_x_struct is not None:
    logger.info("SSL and structure are enabled, modality dropout would be considered here.")
    # Placeholder for modality dropout logic if it were active
    pass

# Self-Supervised Learning (SSL) Calculations (skipped as SSL is disabled)
if model.enable_self_supervised_training:
    logger.info("SSL is enabled, SSL losses would be calculated here.")
    # Placeholder for SSL loss calculations
    pass
else:
    logger.info("SSL is disabled, skipping SSL loss calculations.")
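
The modality dropout step is only a placeholder in the cell above, and the model's actual implementation is not shown in this notebook. As an illustration of the general idea, the sketch below zeroes the structure input for a random subset of samples with probability `p`. The function `modality_dropout`, its signature, and the zero-fill strategy are all assumptions, and the demo runs on CPU tensors.

```python
import torch


def modality_dropout(x_struct, p, generator=None):
    """Zero out the structure input for each sample independently with probability p.

    Illustrative only; this is not FlexibleMultiTaskModel's implementation.
    """
    # One keep/drop decision per sample in the batch.
    keep = torch.rand(x_struct.shape[0], generator=generator) >= p
    # Broadcast the per-sample decision over the feature dimension.
    mask = keep.to(x_struct.dtype).unsqueeze(-1)
    return x_struct * mask, keep


gen = torch.Generator().manual_seed(0)
x_struct_demo = torch.ones(6, 3)
dropped, keep = modality_dropout(x_struct_demo, p=0.3, generator=gen)

print("kept samples:", keep.tolist())
print(dropped)
```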

### 6.2 Supervised Task Calculations: Forward Pass

In [None]:
# Prepare input for the standard forward pass
if model.with_structure:
    forward_input = (x_formula, x_struct_for_processing)
else:
    forward_input = x_formula

logger.info("Performing forward pass...")
# Call the model's forward method directly
# Ensure model is in training mode if it has dropout/batchnorm layers that behave differently
model.train()
preds = model(forward_input, task_sequence_data_batch)

logger.info("Predictions (preds) keys: " + str(list(preds.keys())))
for task_name, pred_tensor in preds.items():
    logger.info(
        f"  {task_name} prediction shape: {pred_tensor.shape}, dtype: {pred_tensor.dtype}, requires_grad: {pred_tensor.requires_grad}"
    )

### 6.3 Supervised Task Calculations: Loss Computation

In [None]:
logger.info("Calculating supervised task losses...")
for name, pred_tensor in preds.items():
    if name not in y_dict_batch or not model.task_configs_map[name].enabled:
        logger.info(f"Skipping loss for task {name} (no target or disabled).")
        continue

    head = model.task_heads[name]
    target = y_dict_batch[name]
    sample_mask = task_masks_batch.get(name)

    if sample_mask is None:
        logger.warning(f"Mask not found for task {name}. Assuming all samples are valid.")
        sample_mask = torch.ones_like(target, dtype=torch.bool, device=target.device)

    loss, _ = head.compute_loss(pred_tensor, target, sample_mask)
    task_weight = model.w.get(name, 1.0)
    weighted_loss = task_weight * loss

    logger.info(f"Task: {name}")
    logger.info(f"  Raw loss: {loss.item()}, requires_grad: {loss.requires_grad}")
    logger.info(f"  Task weight: {task_weight}")
    logger.info(f"  Weighted loss: {weighted_loss.item()}, requires_grad: {weighted_loss.requires_grad}")

    total_loss += weighted_loss

    logs[f"train_{name}_loss"] = loss.detach()
    logs[f"train_{name}_loss_weighted"] = weighted_loss.detach()

logger.info(
    f"Final total_loss: {total_loss.item()}, requires_grad: {total_loss.requires_grad}, grad_fn: {total_loss.grad_fn}"
)
logs["train_total_loss"] = total_loss.detach()

# At this point, you can inspect logs or total_loss
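
The internals of `head.compute_loss` are not shown in this notebook. To make the masking and weighting pattern concrete, the sketch below uses a masked mean squared error as an assumed stand-in and accumulates a weighted total over two toy tasks. All names ending in `_demo`, and `masked_mse` itself, are hypothetical.

```python
import torch


def masked_mse(pred, target, mask):
    """Mean squared error over the entries where mask is True."""
    mask = mask.to(pred.dtype)
    squared_error = (pred - target) ** 2 * mask
    # clamp avoids division by zero when no sample is valid.
    return squared_error.sum() / mask.sum().clamp(min=1.0)


torch.manual_seed(0)
preds_demo = {
    "task_a": torch.randn(4, 1, requires_grad=True),
    "task_b": torch.randn(4, 1, requires_grad=True),
}
targets_demo = {name: torch.zeros(4, 1) for name in preds_demo}
masks_demo = {
    "task_a": torch.tensor([[True], [True], [False], [False]]),
    "task_b": torch.ones(4, 1, dtype=torch.bool),
}
weights_demo = {"task_a": 2.0}  # tasks without an entry default to 1.0

total = torch.tensor(0.0)
for name, pred in preds_demo.items():
    loss = masked_mse(pred, targets_demo[name], masks_demo[name])
    total = total + weights_demo.get(name, 1.0) * loss

total.backward()
print("total:", total.item(), "requires_grad:", total.requires_grad)
print("task_a grad:", preds_demo["task_a"].grad.squeeze(-1).tolist())
```

In this sketch the masked rows of `task_a` receive zero gradient, which is the behavior a sample mask is meant to provide.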

### 6.4 Manual Backward Pass and Optimizer Step (Conceptual)

This section demonstrates how the backward pass and optimizer step would be called. Note that `FlexibleMultiTaskModel` uses manual optimization and can have multiple optimizers. Inside a Lightning `Trainer`, the model would call `manual_backward`; because no trainer is attached in this notebook, the cell below calls `total_loss.backward()` directly instead. A full optimizer step would involve iterating through `model.optimizers()`, which is left as commented-out code.

In [None]:
if total_loss.requires_grad:
    logger.info("total_loss requires grad. Proceeding with conceptual backward pass.")
    # In a real scenario with a Lightning Trainer, trainer.strategy.backward would be called via model.manual_backward()
    # For this notebook, we can try to call backward directly on the loss if no trainer is involved.
    # However, model.manual_backward(total_loss) is the correct way if simulating the model's own logic.

    # To simulate model's internal call if it were part of a Trainer:
    # model.manual_backward(total_loss) # This would require a trainer instance to be set on the model.

    # Direct backward call for demonstration (if no trainer context):
    # This will populate .grad attributes of tensors that were part of the computation graph and require grad.
    try:
        total_loss.backward()  # Computes gradients
        logger.info("total_loss.backward() called successfully.")

        # Conceptual optimizer step (actual model has multiple optimizers)
        # optimizers = model.configure_optimizers() # This returns a list of optimizers/schedulers
        # For example, taking the first optimizer if it exists and is a plain optimizer:
        # if optimizers:
        #     opt0_config = optimizers[0]
        #     if isinstance(opt0_config, torch.optim.Optimizer):
        #         opt0_config.step()
        #         opt0_config.zero_grad(set_to_none=True)
        #         logger.info("Conceptual step and zero_grad for the first optimizer.")
        #     elif isinstance(opt0_config, dict) and 'optimizer' in opt0_config:
        #         opt0_config['optimizer'].step()
        #         opt0_config['optimizer'].zero_grad(set_to_none=True)
        #         logger.info("Conceptual step and zero_grad for the first optimizer from dict.")
        logger.info("Gradients would now be populated. Optimizer step would follow.")
        # Example: check grad of a parameter from the first layer of the shared encoder
        if hasattr(model, "shared") and hasattr(model.shared, "0") and hasattr(model.shared[0], "weight"):
            logger.info(f"Gradient of model.shared[0].weight: {model.shared[0].weight.grad}")
        else:
            logger.info("Could not access model.shared[0].weight.grad for inspection.")

    except RuntimeError as e:
        logger.error(f"RuntimeError during backward pass: {e}")
        logger.error("This likely means an issue with the computation graph or requires_grad status.")
else:
    logger.warning(
        "total_loss does not require grad and has no grad_fn. Skipping backward pass. "
        "This might indicate all parameters are frozen or loss contributions are zero."
    )
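
After a backward pass, it can help to check every parameter rather than a single weight: a `None` gradient usually points to a frozen parameter or one that did not take part in the loss. The helper below is a sketch on a toy network; `grad_report` and `toy` are hypothetical names, and the same function could be applied to `model`.

```python
import torch
import torch.nn as nn


def grad_report(module):
    """Map each parameter name to its gradient norm, or None if no gradient exists."""
    report = {}
    for name, param in module.named_parameters():
        report[name] = None if param.grad is None else param.grad.norm().item()
    return report


torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
toy[0].bias.requires_grad_(False)  # simulate a frozen parameter

loss = toy(torch.randn(5, 4)).pow(2).mean()
loss.backward()

report = grad_report(toy)
for name, norm in report.items():
    print(f"{name}: {norm}")
```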

## 7. Prediction Step Walkthrough (Optional)

This section runs the model's `predict_step` on one batch from the predict dataloader.

In [None]:
logger.info("Setting up DataModule for 'predict' stage...")
datamodule.setup(stage="predict")  # Re-setup for the predict dataloader
predict_dataloader = datamodule.predict_dataloader()

if len(predict_dataloader) > 0:
    logger.info("Fetching one batch from predict_dataloader...")
    predict_batch = next(iter(predict_dataloader))
    predict_batch_idx = 0

    # predict_step expects batch[0] to be x_formula and batch[3] to be
    # task_sequence_data_batch (if present).
    # With predict_set=True, CompoundDataset yields a 4-tuple:
    #   (model_input_x, sample_y_dict, sample_task_masks_dict, sample_task_sequence_data_dict)
    # where model_input_x is x_formula (or (x_formula, None) if with_structure).

    logger.info("Performing prediction_step...")
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Ensure no gradients are computed
        predictions = model.predict_step(batch=predict_batch, batch_idx=predict_batch_idx, additional_output=True)

    logger.info("Predictions output:")
    pp.pprint(predictions)
else:
    logger.info("Predict dataloader is empty, skipping prediction step walkthrough.")
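
The prediction cell combines `model.eval()` with `torch.no_grad()`. The sketch below shows on a toy network with dropout why both are used: evaluation mode makes repeated forward passes deterministic, and `no_grad` keeps the outputs out of the autograd graph. The names `net` and `inputs` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
inputs = torch.randn(3, 4)

net.eval()  # disables dropout
with torch.no_grad():  # no computation graph is recorded
    out_a = net(inputs)
    out_b = net(inputs)

print("identical outputs:", torch.equal(out_a, out_b))
print("requires_grad:", out_a.requires_grad)
```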

## End of Notebook

This notebook provides a basic framework for interactively testing the model. You can expand on this by:
- Modifying configurations.
- Testing specific parts of the model (e.g., individual task heads, encoder blocks).
- Visualizing weights, activations, or gradients.
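
As a starting point for inspecting activations, a forward hook can capture the output of any submodule during a forward pass. The sketch below uses a toy encoder; `save_activation`, `activations`, and `encoder` are hypothetical names, and the layer to hook in the real model depends on its structure.

```python
import torch
import torch.nn as nn

activations = {}


def save_activation(name):
    """Build a forward hook that stores the module output under the given name."""

    def hook(module, inputs, output):
        activations[name] = output.detach()

    return hook


encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Register the hook on the ReLU layer and run one forward pass.
handle = encoder[1].register_forward_hook(save_activation("relu"))
_ = encoder(torch.randn(5, 8))
handle.remove()  # remove the hook once the activation is captured

print(activations["relu"].shape)
```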