[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/crunchdao/quickstarters/blob/master/competitions/structural-break/quickstarters/baseline/baseline.ipynb)

![Banner](https://raw.githubusercontent.com/crunchdao/quickstarters/refs/heads/master/competitions/structural-break/assets/banner.webp)

# ADIA Lab Structural Break Challenge

## Challenge Overview

Welcome to the ADIA Lab Structural Break Challenge! In this challenge, you will analyze univariate time series data to determine whether a structural break has occurred at a specified boundary point.

### What is a Structural Break?

A structural break occurs when the process governing the data generation changes at a certain point in time. These changes can be subtle or dramatic, and detecting them accurately is crucial across various domains such as climatology, industrial monitoring, finance, and healthcare.

![Structural Break Example](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/competitions/structural-break/quickstarters/baseline/images/example.png)

### Your Task

For each time series in the test set, you need to predict a score between `0` and `1`:
- Values closer to `0` indicate no structural break at the specified boundary point;
- Values closer to `1` indicate a structural break did occur.

### Evaluation Metric

The evaluation metric is [ROC AUC (Area Under the Receiver Operating Characteristic Curve)](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html). It measures how well your scores rank series with a break above series without one, so only the ordering of the scores matters, not their calibration.

- ROC AUC around `0.5`: No better than random chance;
- ROC AUC approaching `1.0`: Perfect detection.
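To make "regardless of calibration" concrete: for binary labels, ROC AUC equals the fraction of (positive, negative) pairs in which the positive series receives the higher score, with ties counted as one half. The stdlib-only sketch below computes it that way; `pairwise_auc` is a hypothetical helper for illustration, not part of scikit-learn or the crunch tooling.

```python
from itertools import product

def pairwise_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y]
    neg = [s for y, s in zip(labels, scores) if not y]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))

labels = [True, False, True, False, True]
scores = [0.9, 0.2, 0.4, 0.6, 0.8]

# 5 of the 6 (positive, negative) pairs are ordered correctly (0.4 < 0.6 is the miss)
print(pairwise_auc(labels, scores))  # 0.8333333333333334

# Squaring non-negative scores is strictly increasing, so the ranking and the AUC are unchanged
print(pairwise_auc(labels, [s ** 2 for s in scores]))
```

A constant predictor (every score equal) scores exactly `0.5` under this definition, which is why `0.5` is the "random chance" reference point.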

# Setup

The first steps to get started are:
1. Get the setup command
2. Execute it in the cell below

### >> https://hub.crunchdao.com/competitions/structural-break/submit/notebook

![Reveal token](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/animations/reveal-token.gif)

In [1]:
%pip install crunch-cli --upgrade --quiet --progress-bar off
!crunch setup-notebook structural-break rALxEZ3q09xPxuC4pc57uwTE

crunch-cli, version 7.2.1
main.py: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/submissions/21656/main.py (7619 bytes)
notebook.ipynb: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/submissions/21656/notebook.ipynb (98608 bytes)
requirements.txt: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/submissions/21656/requirements.original.txt (188 bytes)
data/X_train.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/X_train.parquet (204327238 bytes)
data/X_test.reduced.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/X_test.reduced.parquet (2380918 bytes)
data/y_train.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/y_train.parquet (61003 bytes)
data/y_test.reduced.parquet: download from https:crunchdao--c

# Your model

## Setup

In [2]:
%pip install cellpylib

Collecting cellpylib
  Downloading cellpylib-2.4.0.tar.gz (38 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: cellpylib
  Building wheel for cellpylib (setup.py) ... done
  Created wheel for cellpylib: filename=cellpylib-2.4.0-py3-none-any.whl size=37922 sha256=fabe0b515f4807469ba59795652518df5609a101d0955ead936ffd249d1e3b54
  Stored in directory: /root/.cache/pip/wheels/90/db/81/70a63e7c4de08d29f2b1d988ca055bff567bf53439f00e0f3c
Successfully built cellpylib
Installing collected packages: cellpylib
Successfully installed cellpylib-2.4.0


In [3]:
# ==============================================================================
#                      CELL 1: SETUP AND IMPORTS
# ==============================================================================
import os
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import joblib
import math
import hashlib
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import tempfile
import shutil
import cellpylib as cpl
from sklearn.metrics import roc_auc_score

import crunch

# Load the Crunch Toolings
crunch = crunch.load_notebook()

loaded inline runner with module: <module '__main__'>

cli version: 7.2.1
available ram: 12.67 gb
available cpu: 2 core
----


## Core Logic

In [4]:
# ==============================================================================
#            CELL 2: CORE LOGIC AND DEPENDENCY CLASSES
# This cell contains all of our custom classes for data processing,
# model architecture, and training orchestration.
# ==============================================================================

# ------------------------------------------------------------------------------
# 1.1: Data Processing Components
# ------------------------------------------------------------------------------

class PermutationSymbolizer:
    def __init__(self, embedding_dim: int):
        if not isinstance(embedding_dim, int) or embedding_dim <= 1:
            raise ValueError("embedding_dim must be an integer greater than 1.")
        self.embedding_dim = embedding_dim
        self.vocab_size = math.factorial(embedding_dim)

    def symbolize_vector(self, vector: np.ndarray) -> int:
        h = int(hashlib.sha256(vector.tobytes()).hexdigest(), 16)
        seed = h % (2**32)
        rng = np.random.default_rng(seed)
        noise = rng.uniform(low=-1e-12, high=1e-12, size=self.embedding_dim)
        perturbed_vector = vector + noise

        p = np.argsort(perturbed_vector)
        n = len(p)
        res = 0
        for i in range(n):
            res += p[i] * math.factorial(n - 1 - i)
            for j in range(i + 1, n):
                if p[j] > p[i]:
                    p[j] -= 1
        return res

class SeriesProcessor:
    def __init__(self, symbolizer: PermutationSymbolizer, sequence_length: int):
        self.symbolizer = symbolizer
        self.sequence_length = sequence_length
        self.embedding_dim = symbolizer.embedding_dim

    def process(self, series: pd.Series) -> Optional[torch.Tensor]:
        if len(series) < self.embedding_dim:
            return None

        series_values = series.values.copy()
        embedded_vectors = np.lib.stride_tricks.as_strided(
            series_values,
            shape=(len(series_values) - self.embedding_dim + 1, self.embedding_dim),
            strides=(series_values.strides[0], series_values.strides[0])
        )

        symbols = [self.symbolizer.symbolize_vector(v) for v in embedded_vectors]

        if len(symbols) < self.sequence_length:
            return None

        symbols_arr = np.array(symbols)
        sequences = np.lib.stride_tricks.as_strided(
            symbols_arr,
            shape=(len(symbols_arr) - self.sequence_length + 1, self.sequence_length),
            strides=(symbols_arr.strides[0], symbols_arr.strides[0])
        )

        return torch.from_numpy(sequences.copy()).long()

class ECADataGenerator:
    def __init__(self, config: dict):
        self.config = config
        self.rng = np.random.default_rng(config['seed'])

    def generate_training_data(self) -> Tuple[np.ndarray, np.ndarray]:
        all_sequences, all_labels = [], []
        rules_map = self.config['rules_to_use']['base'] + list(self.config['rules_to_use']['composite'].keys())
        for _ in range(self.config['n_reps']):
            for rule_key in rules_map:
                initial_state = self.rng.choice([0, 1], size=(1, self.config['width']))
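                composite_rules = self.config['rules_to_use']['composite']
                # Check composite keys first: a composite key can itself be an int (e.g. 1), and
                # isinstance(rule_key, int) alone would wrongly treat it as the single base rule 1.
                rule_list = composite_rules[rule_key] if rule_key in composite_rules else [rule_key]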
                ca = cpl.evolve(initial_state, timesteps=self.config['steps'] + self.config['warmup'], apply_rule=lambda n, c, t: cpl.nks_rule(n, rule_list[t % len(rule_list)]))
                orbit = ca[self.config['warmup']:]
                num_windows = (orbit.shape[0] - self.config['sequence_length']) // self.config['stride'] + 1
                for i in range(num_windows):
                    start = i * self.config['stride']
                    window = orbit[start : start + self.config['sequence_length']]
                    all_sequences.append(window)
                    all_labels.append(rules_map.index(rule_key))
        return np.array(all_sequences, dtype=np.float32), np.array(all_labels, dtype=np.int64)

# ------------------------------------------------------------------------------
# 1.2: Model Architecture Components
# ------------------------------------------------------------------------------
@dataclass
class HierarchicalArgs:
    eca_input_dim: int
    symbol_vocab_size: int
    num_classes: int
    permutation_embedding_dim: int
    dimensions: List[int] = field(default_factory=lambda: [64, 128, 256])
    layers: List[int] = field(default_factory=lambda: [2, 2, 2])
    max_seqlens: List[int] = field(default_factory=lambda: [128, 64, 32])
    n_heads: int = 4
    d_ff_multiplier: int = 4
    dropout: float = 0.1

    def __post_init__(self):
        self.n_stages = len(self.dimensions)
        self.latent_dim = self.dimensions[-1]

class CausalTransformer(nn.Module):
    """Pre-norm transformer block with causal self-attention."""
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        seq_len = x.shape[1]
        # True above the diagonal: position i may not attend to future positions j > i
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        h = self.norm1(x)  # normalize once and reuse it as query, key and value
        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

class SimpleTransition(nn.Module):
    """Strided down-sampling and linear up-sampling between two U-Net stages."""
    def __init__(self, d_shallow: int, d_deep: int, factor: int):
        super().__init__()
        self.down_proj = nn.Linear(d_shallow, d_deep)
        self.up_proj = nn.Linear(d_deep, d_shallow * factor)
        self.factor = factor

    def down(self, x: torch.Tensor) -> torch.Tensor:
        # Keep every `factor`-th position, then project to the deeper width
        return self.down_proj(x[:, ::self.factor, :])

    def up(self, x: torch.Tensor) -> torch.Tensor:
        x = self.up_proj(x)
        B, S, D = x.shape
        # Split each deep position s into `factor` consecutive shallow positions, so output
        # index s * factor + f comes from deep position s and lines up with the encoder residual.
        return x.reshape(B, S * self.factor, D // self.factor)

class HierarchicalDynamicalEncoder(nn.Module):
    def __init__(self, args: HierarchicalArgs):
        super().__init__()
        self.args = args
        self.levels = nn.ModuleList()
        self.transitions = nn.ModuleList()
        for i in range(args.n_stages):
            d_model, n_layers = args.dimensions[i], args.layers[i]
            self.levels.append(nn.Sequential(*[
                CausalTransformer(d_model, args.n_heads, d_model * args.d_ff_multiplier, args.dropout)
                for _ in range(n_layers)
            ]))
            if i < args.n_stages - 1:
                # Down-sampling factor between consecutive stages, e.g. 128 // 32 = 4
                factor = args.max_seqlens[i] // args.max_seqlens[i + 1]
                self.transitions.append(SimpleTransition(d_shallow=d_model, d_deep=args.dimensions[i + 1], factor=factor))

    def forward(self, x):
        residuals = []
        for i in range(self.args.n_stages - 1):
            x = self.levels[i](x)
            residuals.append(x)  # skip connection consumed by the decoder
            x = self.transitions[i].down(x)
        return self.levels[-1](x), residuals

class HierarchicalDynamicalDecoder(nn.Module):
    def __init__(self, args, transitions):
        super().__init__()
        self.args = args
        self.transitions = transitions  # shared with the encoder
        self.levels = nn.ModuleList()
        for i in range(args.n_stages - 1):
            d_model, n_layers = args.dimensions[i], args.layers[i]
            self.levels.append(nn.Sequential(*[
                CausalTransformer(d_model, args.n_heads, d_model * args.d_ff_multiplier, args.dropout)
                for _ in range(n_layers)
            ]))

    def forward(self, fingerprint_seq, residuals):
        x = fingerprint_seq
        # Walk back up from the deepest stage, adding the matching encoder residual at each level
        for i in range(self.args.n_stages - 2, -1, -1):
            x = self.transitions[i].up(x)
            x = x + residuals.pop()
            x = self.levels[i](x)
        return x

class MDL_AU_Net_Autoencoder(nn.Module):
    def __init__(self, args: HierarchicalArgs):
        super().__init__()
        self.args = args
        # Two input heads: float ECA rows (pre-training) and integer ordinal symbols (fine-tuning)
        self.eca_input_proj = nn.Linear(args.eca_input_dim, args.dimensions[0])
        self.embedding_head = nn.Embedding(args.symbol_vocab_size, args.dimensions[0])
        self.encoder = HierarchicalDynamicalEncoder(args)
        self.decoder = HierarchicalDynamicalDecoder(args, self.encoder.transitions)
        self.recon_head = nn.Linear(args.dimensions[0], args.eca_input_dim)
        self.classification_head = nn.Linear(args.latent_dim, args.num_classes)

    def forward_pretrain(self, x_float):
        x = self.eca_input_proj(x_float)
        fingerprint_seq, residuals = self.encoder(x)
        reconstruction_vecs = self.decoder(fingerprint_seq, residuals)
        recon_logits = self.recon_head(reconstruction_vecs)  # per-cell logits for the ECA rows
        pooled_fingerprint = fingerprint_seq.mean(dim=1)
        rule_logits = self.classification_head(pooled_fingerprint)  # which rule generated the window
        return recon_logits, rule_logits

    def encode_finetune(self, x_int):
        x = self.embedding_head(x_int)
        fingerprint_seq, _ = self.encoder(x)
        return fingerprint_seq

class StructuralBreakClassifier(nn.Module):
    def __init__(self, autoencoder: MDL_AU_Net_Autoencoder, latent_dim: int):
        super().__init__()
        self.autoencoder = autoencoder
        self.classifier_head = nn.Sequential(nn.Linear(latent_dim * 2, latent_dim), nn.ReLU(), nn.Linear(latent_dim, 1))

    def forward(self, before_seqs: torch.Tensor, after_seqs: torch.Tensor) -> torch.Tensor:
        # (n_windows, seq_len) symbols -> encoder -> average over the deepest-stage positions
        fp_before_per_seq = self.autoencoder.encode_finetune(before_seqs).mean(dim=1)
        fp_after_per_seq = self.autoencoder.encode_finetune(after_seqs).mean(dim=1)
        # Average over all windows of each segment: one latent_dim fingerprint per segment
        fp_before_stable = fp_before_per_seq.mean(dim=0)
        fp_after_stable = fp_after_per_seq.mean(dim=0)
        combined_fp = torch.cat([fp_before_stable, fp_after_stable], dim=-1).unsqueeze(0)
        return self.classifier_head(combined_fp).squeeze(-1)  # shape (1,): one logit per series

# ------------------------------------------------------------------------------
# 1.3: Training Pipeline Orchestrators
# ------------------------------------------------------------------------------
class MDLPreTrainer:
    def __init__(self, model, config):
        self.model, self.config = model, config
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config['pretrain_lr'])
        self.recon_criterion = nn.BCEWithLogitsLoss()  # reconstruct the binary ECA cells
        self.class_criterion = nn.CrossEntropyLoss()   # identify the generating rule
        self.device = config['device']
        self.model.to(self.device)

    def run(self):
        print("--- Starting Stage 1: MDL Pre-training ---")
        self.model.train()
        data_generator = ECADataGenerator(self.config['eca_config'])
        X, y = data_generator.generate_training_data()
        loader = DataLoader(
            TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long()),
            batch_size=self.config['pretrain_batch_size'], shuffle=True,
        )
        for epoch in range(self.config['pretrain_epochs']):
            progress_bar = tqdm(loader, desc=f"Pre-train Epoch {epoch+1}")
            for sequences, labels in progress_bar:
                sequences, labels = sequences.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                recon, logits = self.model.forward_pretrain(sequences)
                total_loss = (self.config['alpha_loss'] * self.recon_criterion(recon, sequences)
                              + self.config['beta_loss'] * self.class_criterion(logits, labels))
                total_loss.backward()
                self.optimizer.step()
                progress_bar.set_postfix({'loss': total_loss.item()})
        print("--- MDL Pre-training Complete ---")

class EmbeddingFinetuner:
    def __init__(self, model, config):
        self.model, self.config = model, config
        self.device = config['device']
        self.model.to(self.device)
        self.criterion = nn.BCEWithLogitsLoss()
        print("\nConfiguring layers for Stage 2 (Embedding Tuning)...")
        # Freeze every parameter whose name contains 'encoder' (including the transitions shared
        # with the decoder); embedding_head and classifier_head are the unfrozen parts used in forward().
        for name, param in self.model.autoencoder.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=config['embedding_tune_lr'])
        print("...Froze core U-Net encoder body.")
        print("...Ensured embedding_head and classifier_head are tunable.")

    def run(self, num_epochs):
        print("--- Starting Stage 2: Embedding Head Fine-tuning (on mock data) ---")
        self.model.train()
        X_mock, y_mock = self._create_mock_data()
        symbolizer = PermutationSymbolizer(self.config['series_proc_config']['embedding_dim'])
        processor = SeriesProcessor(symbolizer, self.config['series_proc_config']['sequence_length'])
        for epoch in range(num_epochs):
            progress_bar = tqdm(y_mock.items(), desc=f"Embedding Tune Epoch {epoch+1}")
            for series_id, label in progress_bar:
                self.optimizer.zero_grad()
                series_df = X_mock.loc[series_id]
                before = processor.process(series_df[series_df.period == 0].value)
                after = processor.process(series_df[series_df.period == 1].value)
                if before is None or after is None:
                    continue
                logit = self.model(before.to(self.device), after.to(self.device))
                loss = self.criterion(logit, torch.tensor([label], dtype=torch.float32).to(self.device))
                loss.backward()
                self.optimizer.step()
                progress_bar.set_postfix({'loss': loss.item()})
        print("--- Embedding Head Tuning Complete ---")

    def _create_mock_data(self):
        # 40 synthetic series of 800 + 800 points: even ids switch from sin to cos ("break"),
        # odd ids continue with a slightly phase-shifted sine ("no break")
        ids = [f"sin_cos_{i}" for i in range(40)]
        y = pd.Series([i % 2 == 0 for i in range(40)], index=ids)
        t = np.linspace(0, 20 * np.pi, 800)
        dfs = []
        for idx in ids:
            before = pd.DataFrame({'value': np.sin(t), 'period': 0, 'id': idx})
            after = pd.DataFrame({'value': np.cos(t) if y.loc[idx] else np.sin(t + 0.1), 'period': 1, 'id': idx})
            dfs.extend([before, after])
        return pd.concat(dfs).set_index('id'), y

class FinalClassifierFinetuner:
    def __init__(self, model, config):
        self.model, self.config = model, config
        self.device = config['device']
        self.model.to(self.device)
        self.criterion = nn.BCEWithLogitsLoss()
        print("\nConfiguring layers for Stage 3 (Final Tuning)...")
        for param in self.model.autoencoder.parameters():
            param.requires_grad = False
        for param in self.model.classifier_head.parameters():
            param.requires_grad = True
        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=config['final_tune_lr'])
        print("...Froze entire MDL_AU_Net_Autoencoder (core + all heads).")
        print("...Ensured ONLY the final classifier_head is tunable.")

    def run(self, X_train, y_train, num_epochs):
        print("--- Starting Stage 3: Final Classifier Fine-tuning (on real data) ---")
        self.model.train()
        symbolizer = PermutationSymbolizer(self.config['series_proc_config']['embedding_dim'])
        processor = SeriesProcessor(symbolizer, self.config['series_proc_config']['sequence_length'])
        ids = y_train.index
        for epoch in range(num_epochs):
            progress_bar = tqdm(ids, desc=f"Final Tune Epoch {epoch+1}")
            for series_id in progress_bar:
                self.optimizer.zero_grad()
                series_df = X_train.loc[series_id]  # one series, indexed by `time`
                before = processor.process(series_df[series_df.period == 0].value)
                after = processor.process(series_df[series_df.period == 1].value)
                if before is None or after is None:
                    continue  # segment too short to form a full symbol sequence
                logit = self.model(before.to(self.device), after.to(self.device))
                target = torch.tensor([float(y_train.loc[series_id])], dtype=torch.float32, device=self.device)
                loss = self.criterion(logit, target)
                loss.backward()
                self.optimizer.step()
                progress_bar.set_postfix({'loss': loss.item()})
        print("--- Final Classifier Tuning Complete ---")

class ArtifactSaver:
    def save(self, model: StructuralBreakClassifier, model_args: HierarchicalArgs, path: str):
        os.makedirs(path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(path, "final_classifier.pth"))
        joblib.dump(model_args, os.path.join(path, "model_config.joblib"))  # needed to rebuild the model in infer()
        print(f"\n✅ Final classifier and config saved to '{path}'")
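
The ordinal-pattern symbols produced by `PermutationSymbolizer` can be sanity-checked in isolation. The stdlib-only sketch below reproduces the same Lehmer-code loop as `symbolize_vector`; `ordinal_pattern` and `lehmer_rank` are hypothetical helpers, and ties are broken here by position, whereas the notebook breaks them with tiny hash-seeded noise before `np.argsort`.

```python
from itertools import permutations
from math import factorial

def ordinal_pattern(window):
    """Indices that sort `window` ascending (like np.argsort); ties broken by position."""
    return sorted(range(len(window)), key=lambda k: window[k])

def lehmer_rank(perm):
    """Map a permutation of 0..n-1 to a unique integer in [0, n!), same loop as symbolize_vector."""
    p = list(perm)
    n = len(p)
    rank = 0
    for i in range(n):
        rank += p[i] * factorial(n - 1 - i)
        # Re-index the remaining values so they again form 0..n-2-i
        for j in range(i + 1, n):
            if p[j] > p[i]:
                p[j] -= 1
    return rank

# With embedding_dim = 4 there are 4! = 24 symbols, and every permutation gets its own code
codes = {lehmer_rank(perm) for perm in permutations(range(4))}
print(len(codes), min(codes), max(codes))  # 24 0 23

# Only the relative order inside a window matters: scaling and shifting keep the symbol
w = [0.3, -1.2, 0.8, 0.1]
print(lehmer_rank(ordinal_pattern(w)) == lehmer_rank(ordinal_pattern([10 * x + 5 for x in w])))  # True
```

This invariance is why `vocab_size = math.factorial(embedding_dim)` is enough for `nn.Embedding`, and why the symbols ignore the absolute level and scale of the series.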

## Understanding the Data

The dataset consists of univariate time series, each containing ~2,000-5,000 values with a designated boundary point. For each time series, you need to determine whether a structural break occurred at this boundary point.

The data was downloaded when you set up your local environment and is now available in the `data/` directory.

In [None]:
# Load the data simply
X_train, y_train, X_test = crunch.load_data()

### Understanding `X_train`

The training data is structured as a pandas DataFrame with a MultiIndex:

**Index Levels:**
- `id`: Identifies the unique time series
- `time`: The timestep within each time series

**Columns:**
- `value`: The actual time series value at each timestep
- `period`: A binary indicator where `0` represents the **period before** the boundary point, and `1` represents the **period after** the boundary point
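The `period` column splits each series into two segments, and the model below processes them separately (`series_df.period == 0` and `series_df.period == 1`). The stdlib-only sketch counts how many symbols and symbol sequences `SeriesProcessor.process` would produce for one segment; `window_counts` is a hypothetical helper, and its defaults assume `embedding_dim = 4` and `seq_len = 128` as set in `train()`.

```python
def window_counts(segment_length, embedding_dim=4, sequence_length=128):
    """(n_symbols, n_sequences) for one segment, or None when process() would return None."""
    if segment_length < embedding_dim:
        return None
    n_symbols = segment_length - embedding_dim + 1  # one ordinal pattern per length-4 window
    if n_symbols < sequence_length:
        return None
    n_sequences = n_symbols - sequence_length + 1  # stride-1 windows over the symbol stream
    return n_symbols, n_sequences

print(window_counts(800))  # (797, 670)
print(window_counts(131))  # (128, 1): shortest segment that can be processed
print(window_counts(130))  # None: infer() falls back to a score of 0.5
```

All sequences of a segment are passed to the encoder in a single forward call, so a segment of a few thousand points becomes a batch of a few thousand 128-step sequences.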

In [None]:
X_train

| id | time | value | period |
|---|---|---|---|
| 0 | 0 | 0.001858 | 0 |
| 0 | 1 | -0.001664 | 0 |
| 0 | 2 | -0.004386 | 0 |
| 0 | 3 | 0.000699 | 0 |
| 0 | 4 | -0.002433 | 0 |
| ... | ... | ... | ... |
| 10000 | 1890 | -0.005903 | 1 |
| 10000 | 1891 | 0.007295 | 1 |
| 10000 | 1892 | 0.003527 | 1 |
| 10000 | 1893 | 0.007218 | 1 |


### Understanding `y_train`

This is a simple `pandas.Series` indicating, for each dataset `id`, whether a structural break occurred at its boundary point.

**Index:**
- `id`: the ID of the dataset

**Value:**
- `structural_breakpoint`: Boolean indicating whether a structural break occurred (`True`) or not (`False`)

In [None]:
y_train

id
0         True
1         True
2        False
3         True
4        False
         ...  
9996     False
9997      True
9998     False
9999     False
10000     True
Name: structural_breakpoint, Length: 10001, dtype: bool

### Understanding `X_test`

The test data is provided as a **`list` of `pandas.DataFrame`s** with the same format as [`X_train`](#understanding-X_train).

It is structured as a list to encourage processing records one by one, which will be mandatory in the `infer()` function.

In [None]:
print("Number of datasets:", len(X_test))

Number of datasets: 101


In [None]:
X_test[0]

| id | time | value | period |
|---|---|---|---|
| 10001 | 0 | -0.020657 | 0 |
| 10001 | 1 | -0.005894 | 0 |
| 10001 | 2 | -0.003052 | 0 |
| 10001 | 3 | -0.000590 | 0 |
| 10001 | 4 | 0.009887 | 0 |
| 10001 | ... | ... | ... |
| 10001 | 2517 | 0.005084 | 1 |
| 10001 | 2518 | -0.024414 | 1 |
| 10001 | 2519 | -0.014986 | 1 |
| 10001 | 2520 | 0.012999 | 1 |


## Strategy Implementation

There are multiple approaches you can take to detect structural breaks:

1. **Statistical Tests**: Compare distributions before and after the boundary point;
2. **Feature Engineering**: Extract features from both segments for comparison;
3. **Time Series Modeling**: Detect deviations from expected patterns;
4. **Machine Learning**: Train models to recognize break patterns from labeled examples.

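The implementation below follows the machine-learning route (4): each segment is converted into ordinal-pattern symbols, the symbols are encoded by a hierarchical causal-transformer autoencoder pre-trained on elementary cellular automaton (ECA) data, and a small classifier head compares the resulting before and after fingerprints.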
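For comparison, approach 1 can be as small as a Welch t-statistic on the two segments. The stdlib-only sketch below is illustrative: `welch_t` is a hypothetical helper, it only detects shifts in the mean, and turning it into a p-value would need the t distribution (for example from SciPy), which is omitted here.

```python
import math
import statistics

def welch_t(before, after):
    """Welch's t-statistic for a difference in means, allowing unequal variances."""
    m1, m2 = statistics.fmean(before), statistics.fmean(after)
    v1, v2 = statistics.variance(before), statistics.variance(after)  # sample variances (n - 1)
    return (m2 - m1) / math.sqrt(v1 / len(before) + v2 / len(after))

before = [0.0, 0.1, -0.1, 0.05, -0.05]
after = [0.5, 0.6, 0.4, 0.55, 0.45]
t = welch_t(before, after)
print(round(t, 2))

# A larger |t| is stronger evidence of a mean shift; any increasing map of |t| into [0, 1)
# gives a valid score, since ROC AUC only looks at the ranking.
score = 1 - math.exp(-abs(t))
```

Because the metric is rank-based, such a statistic can be submitted directly as a score without calibrating it into a probability.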

### The `train()` Function

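In this function, you build and train your model for making inferences on the test data. Your model must be stored in `model_directory_path`.

Here, `train()` runs three stages: (1) pre-training of the `MDL_AU_Net_Autoencoder` on synthetic ECA orbits (cached as `pre_trained_autoencoder.pth` and reused on later runs), (2) tuning of the symbol embedding and classifier head on mock sine/cosine series, and (3) tuning of the classifier head alone on `X_train`/`y_train`. The final weights and the `HierarchicalArgs` configuration are then saved for `infer()`.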
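Stage 1 pre-trains on space-time diagrams of elementary cellular automata (ECA), generated by `ECADataGenerator` through `cpl.evolve` and `cpl.nks_rule` with rules 110, 54 and 30 plus an alternating 90/110 composite. The stdlib-only sketch below shows what one such orbit is, assuming Wolfram's rule numbering and periodic boundaries; `eca_step` and `eca_orbit` are hypothetical helpers rather than cellpylib functions, and the exact timestep at which the composite switches rules may differ from cellpylib's `t` convention.

```python
def eca_step(row, rule):
    """One update of an elementary CA with Wolfram rule number `rule` and periodic boundaries."""
    n = len(row)
    out = []
    for i in range(n):
        left, center, right = row[i - 1], row[i], row[(i + 1) % n]  # row[-1] wraps around
        index = 4 * left + 2 * center + right  # neighbourhood (1,1,1) -> 7, ..., (0,0,0) -> 0
        out.append((rule >> index) & 1)        # bit `index` of the rule number is the new state
    return out

def eca_orbit(initial, rules, steps):
    """Rows of the space-time diagram; a list of several rules alternates them by timestep."""
    rows = [list(initial)]
    for t in range(1, steps):
        rows.append(eca_step(rows[-1], rules[t % len(rules)]))
    return rows

# Rule 90: each cell becomes left XOR right, so a single live cell grows a Sierpinski pattern
orbit = eca_orbit([0, 0, 0, 1, 0, 0, 0], rules=[90], steps=4)
for row in orbit:
    print("".join("#" if c else "." for c in row))
```

Each pre-training sample is a 128-row window of such a diagram (64 cells wide in the config), and the autoencoder learns both to reconstruct the window and to name the rule that produced it.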

In [None]:
# ==============================================================================
#                  CELL 3: The `train()` function (UPGRADED WITH CACHING)
# ==============================================================================
def train(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    model_directory_path: str,
):
    """
    Main entrypoint for the training process.
    Orchestrates pre-training, embedding tuning, final fine-tuning,
    and saving the final model artifact.

    *** UPGRADED with caching for the pre-training stage. ***
    """
    # --- SCALED-UP CONFIGURATION ---
    embedding_dim = 4
    vocab_size = math.factorial(embedding_dim)
    seq_len = 128

    config = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'pretrain_lr': 1e-4, 'embedding_tune_lr': 1e-4, 'final_tune_lr': 1e-3,
        'pretrain_batch_size': 128,
        'pretrain_epochs': 10,
        'embedding_tune_epochs': 5,
        'final_tune_epochs': 5,
        'alpha_loss': 1.0, 'beta_loss': 0.5,
        'series_proc_config': {'embedding_dim': embedding_dim, 'sequence_length': seq_len},
        'eca_config': {
            "rules_to_use": {
                "base": [110, 54, 30], "composite": {1: [90, 110]}
            },
            "width": 64, "steps": seq_len * 40, "n_reps": 50, "warmup": 100,
            "seed": 42, "sequence_length": seq_len, "stride": seq_len // 4
        }
    }

    model_args = HierarchicalArgs(
        eca_input_dim=config['eca_config']['width'],
        symbol_vocab_size=vocab_size,
        num_classes=len(config['eca_config']['rules_to_use']['base']) + len(config['eca_config']['rules_to_use']['composite']),
        permutation_embedding_dim=embedding_dim,
        dimensions=[64, 128, 256], layers=[2, 3, 3],
        max_seqlens=[seq_len, seq_len // 4, seq_len // 16]
    )

    # --- Instantiate Model ---
    autoencoder = MDL_AU_Net_Autoencoder(model_args)

    # --- Caching Logic for Stage 1 ---
    os.makedirs(model_directory_path, exist_ok=True)  # torch.save below needs the directory to exist
    pre_trained_path = os.path.join(model_directory_path, "pre_trained_autoencoder.pth")

    if os.path.exists(pre_trained_path):
        print("--- Found cached pre-trained model. Loading from disk. ---")
        autoencoder.load_state_dict(torch.load(pre_trained_path, map_location=config['device']))
    else:
        # --- Run Stage 1: Pre-training ---
        pre_trainer = MDLPreTrainer(autoencoder, config)
        pre_trainer.run()
        print(f"--- Caching pre-trained model to {pre_trained_path} ---")
        torch.save(autoencoder.state_dict(), pre_trained_path)

    # --- Run Stage 2: Embedding Tuning ---
    classifier = StructuralBreakClassifier(autoencoder, model_args.latent_dim)
    embedding_tuner = EmbeddingFinetuner(classifier, config)
    embedding_tuner.run(config['embedding_tune_epochs'])

    # --- Run Stage 3: Final Classifier Tuning ---
    final_tuner = FinalClassifierFinetuner(classifier, config)
    final_tuner.run(X_train, y_train, config['final_tune_epochs'])

    # --- Save Final Artifacts ---
    saver = ArtifactSaver()
    saver.save(classifier, model_args, model_directory_path)

### The `infer()` Function

In the inference function, your trained model (if any) is loaded and used to make predictions on test data.

**Important workflow:**
1. Load your model;
2. Use the `yield` statement to signal readiness to the runner;
3. Process each dataset one by one within the for loop;
4. For each dataset, use `yield prediction` to return your prediction.

**Note:** The datasets can only be iterated once!
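The generator protocol above can be exercised without the platform. In the stdlib-only sketch below, `toy_infer` and `toy_runner` are hypothetical stand-ins: the real crunch runner's internals are not shown in this notebook and may differ, for example in how it feeds datasets.

```python
from typing import Iterable, Iterator, Optional

def toy_infer(datasets: Iterable[list]) -> Iterator[Optional[float]]:
    """Hypothetical infer(): a bare `yield` first, then exactly one score per dataset."""
    threshold = 0.0           # stands in for "load your model"
    yield                     # signal readiness before touching the data
    for values in datasets:   # the iterable is consumed once, in order
        yield 1.0 if max(values) > threshold else 0.0

def toy_runner(infer_fn, datasets):
    """Hypothetical driver: prime the generator, then collect one prediction per dataset."""
    gen = infer_fn(iter(datasets))  # a one-shot iterator, like the real test stream
    if next(gen) is not None:       # runs the setup code up to the bare `yield`
        raise RuntimeError("the first yield must be a bare `yield`")
    return list(gen)

print(toy_runner(toy_infer, [[-1.0, -0.5], [0.2, -0.3], [0.0]]))  # [0.0, 1.0, 0.0]
```

Forgetting the bare `yield` would make the first prediction be consumed as the readiness signal, shifting every later prediction by one dataset.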

In [6]:
# ==============================================================================
#                  CELL 4: The `infer()` function
# ==============================================================================
def infer(
    X_test: typing.Iterable[pd.DataFrame],
    model_directory_path: str,
):
    """
    Main entrypoint for the inference process.
    Loads the trained model and generates predictions for the test set.
    """
    print("\n--- Starting Inference ---")
    config_path = os.path.join(model_directory_path, "model_config.joblib")
    model_path = os.path.join(model_directory_path, "final_classifier.pth")
    if not os.path.exists(config_path) or not os.path.exists(model_path):
        raise FileNotFoundError("Model artifacts not found. Please run the train() function first.")

    model_args = joblib.load(config_path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Recreate the full model structure and load the trained state
    loaded_autoencoder = MDL_AU_Net_Autoencoder(model_args)
    model = StructuralBreakClassifier(loaded_autoencoder, model_args.latent_dim)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Instantiate the data processor using the saved configuration
    symbolizer = PermutationSymbolizer(model_args.permutation_embedding_dim)
    processor = SeriesProcessor(symbolizer, model_args.max_seqlens[0])

    # Signal readiness to the runner
    yield

    # Process each dataset one by one
    with torch.no_grad():
        for series_df in X_test:
            before = processor.process(series_df[series_df.period==0].value)
            after = processor.process(series_df[series_df.period==1].value)

            if before is None or after is None:
                # Default prediction for series too short to process
                yield 0.5
                continue

            logit = model(before.to(device), after.to(device))
            prediction = torch.sigmoid(logit).item()

            # Yield the prediction for the current dataset
            yield prediction

    print("--- Inference Complete ---")

## Local testing

To make sure your `train()` and `infer()` functions are working properly, you can call the `crunch.test()` function, which reproduces the cloud environment locally. <br />
Even if it is not perfect, it should give you a quick idea of whether your model is working properly.

In [None]:
crunch.test(
    # Uncomment to disable the train
    # force_first_train=False,

    # Uncomment to disable the determinism check
    # no_determinism_check=True,
)

04:38:13 no forbidden library found
04:38:13 
04:38:13 started
04:38:13 running local test
04:38:13 internet access isn't restricted, no check will be done
04:38:13 
04:38:14 starting unstructured loop...
04:38:14 executing - command=train


data/X_train.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/X_train.parquet (204327238 bytes)
data/X_train.parquet: already exists, file length match
data/X_test.reduced.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/X_test.reduced.parquet (2380918 bytes)
data/X_test.reduced.parquet: already exists, file length match
data/y_train.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/y_train.parquet (61003 bytes)
data/y_train.parquet: already exists, file length match
data/y_test.reduced.parquet: download from https:crunchdao--competition--production.s3-accelerate.amazonaws.com/data-releases/146/y_test.reduced.parquet (2655 bytes)
data/y_test.reduced.parquet: already exists, file length match
--- Starting Stage 1: MDL Pre-training ---


Pre-train Epoch 1:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 2:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 3:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 4:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 5:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 6:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 7:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 8:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 9:   0%|          | 0/246 [00:00<?, ?it/s]

Pre-train Epoch 10:   0%|          | 0/246 [00:00<?, ?it/s]

--- MDL Pre-training Complete ---

Configuring layers for Stage 2 (Embedding Tuning)...
...Froze core U-Net encoder body.
...Ensured embedding_head and classifier_head are tunable.
--- Starting Stage 2: Embedding Head Fine-tuning (on mock data) ---


Embedding Tune Epoch 1: 0it [00:00, ?it/s]

Embedding Tune Epoch 2: 0it [00:00, ?it/s]

Embedding Tune Epoch 3: 0it [00:00, ?it/s]

Embedding Tune Epoch 4: 0it [00:00, ?it/s]

Embedding Tune Epoch 5: 0it [00:00, ?it/s]

--- Embedding Head Tuning Complete ---

Configuring layers for Stage 3 (Final Tuning)...
...Froze entire MDL_AU_Net_Autoencoder (core + all heads).
...Ensured ONLY the final classifier_head is tunable.
--- Starting Stage 3: Final Classifier Fine-tuning (on real data) ---


Final Tune Epoch 1:   0%|          | 0/10001 [00:00<?, ?it/s]

Final Tune Epoch 2:   0%|          | 0/10001 [00:00<?, ?it/s]

Final Tune Epoch 3:   0%|          | 0/10001 [00:00<?, ?it/s]

Final Tune Epoch 4:   0%|          | 0/10001 [00:00<?, ?it/s]

Final Tune Epoch 5:   0%|          | 0/10001 [00:00<?, ?it/s]

## Results

Once the local tester is done, you can preview the result stored in `data/prediction.parquet`.

In [None]:
prediction = pd.read_parquet("data/prediction.parquet")
prediction

### Local scoring

You can call the function that the system uses to estimate your score locally.

In [None]:
# Load the targets
target = pd.read_parquet("data/y_test.reduced.parquet")["structural_breakpoint"]

# Call the scoring function (roc_auc_score was imported in the setup cell)
roc_auc_score(
    target,
    prediction,
)

# Submit your Notebook

To submit your work, you must:
1. Download your Notebook from Colab
2. Upload it to the platform
3. Create a run to validate it

### >> https://hub.crunchdao.com/competitions/structural-break/submit/notebook

![Download and Submit Notebook](https://raw.githubusercontent.com/crunchdao/competitions/refs/heads/master/documentation/animations/download-and-submit-notebook.gif)