### Training a new model involves the following steps:
1. Prepare dataset
2. Define coarse-graining scheme
3. Prepare features based on coarse-graining scheme
4. Train model

#### 1. Prepare dataset

Download the datasets from https://doi.org/10.5281/zenodo.18506092 and extract them into two folders under `data`: one for the CIF structures and one for the DSSP annotations.

In [None]:
%%bash
# Download the CATH CIF and DSSP archives from Zenodo and unpack each into its own folder.
# Stop at the first failing command so a broken download is never extracted.
set -e

# `tar -C` requires the target directories to exist. Use the same locations
# that the feature-preparation cells read from.
mkdir -p ../data/cath-cif ../data/cath-dssp

# URLs are quoted so the shell never treats `?` as a glob character;
# `-f` makes curl fail on HTTP errors instead of saving the error page.
curl -fL -o ../data/cifs.tar.gz "https://zenodo.org/records/18506092/files/cath-cif.tar.gz?download=1"
tar -xzf ../data/cifs.tar.gz -C ../data/cath-cif --strip-components=1
rm ../data/cifs.tar.gz

curl -fL -o ../data/dssp.tar.gz "https://zenodo.org/records/18506092/files/cath-dssp.tar.gz?download=1"
tar -xzf ../data/dssp.tar.gz -C ../data/cath-dssp --strip-components=1
rm ../data/dssp.tar.gz

#### 2. Define coarse-graining scheme

In [None]:
# For example, to coarse-grain a structure to a C-alpha representation, we can implement the following class:

from corssa.coarse_graining import CoarseGrainModel
from Bio.PDB.Polypeptide import PPBuilder
import pandas as pd

class CAlpha(CoarseGrainModel):
    """Coarse-graining scheme that uses C-alpha atoms as beads."""

    def scheme(self, structure, postfix, **kwargs):
        """
        Build a coordinate table of C-alpha beads for a structure.

        Only the first polypeptide returned by ``PPBuilder`` is used.

        Args:
            structure: Structure object handed to ``PPBuilder.build_peptides``.
            postfix: Suffix appended to the ``x``/``y``/``z`` column names.
            **kwargs: Accepted but not used by this scheme.

        Returns:
            DataFrame with one row per C-alpha atom and the columns
            ``x<postfix>``, ``y<postfix>``, ``z<postfix>``.
        """
        first_peptide = PPBuilder().build_peptides(structure)[0]

        ca_coords = [atom.get_coord() for atom in first_peptide.get_ca_list()]
        axis_labels = [f"{axis}{postfix}" for axis in ("x", "y", "z")]

        return pd.DataFrame(data=ca_coords, columns=axis_labels)

#### 3. Prepare features based on the coarse-graining scheme

In [None]:
from pathlib import Path

# Collect every CIF structure file. The glob results are sorted because
# directory iteration order is filesystem-dependent; a stable order keeps
# downstream processing reproducible across machines and runs.
cif_files = sorted(Path("../data/cath-cif").glob("*.cif"))

# Fetch corresponding DSSP files: same stem as each CIF file with a ".dssp"
# extension, listed in the same order as cif_files.
dssp_dir = Path("../data/cath-dssp")
dssp_files = [dssp_dir / f"{cif.stem}.dssp" for cif in cif_files]

In [None]:
# Postfix the representation with '_ca' to distinguish it from other coarse-grained representations.
# The scheme class defined earlier in this notebook is `CAlpha`; no `CAlphaRep` exists here.
calpha_rep = CAlpha(postfix='_ca')

# Run the scheme over all CIF files, passing the DSSP files (built in the same order) alongside.
ca_df = calpha_rep.process_batch(cif_files, dssp_filepaths=dssp_files)

In [None]:
from corssa.featurizer import Featurizer

# Build a featurizer over the output of process_batch, using the same '_ca'
# postfix that was given to the coarse-grained representation.
featr = Featurizer(ca_df, postfix="_ca")
# Compute the features; the result is what gets labelled and split below.
ca_features = featr.extract()

In [None]:
from corssa.datautils import map_dssp_to_3state, split_data

# Map 9-state DSSP to 3-state; the returned object replaces ca_features.
# NOTE(review): presumably this adds the 'dssp3' column used by split_data below — confirm.
ca_features = map_dssp_to_3state(ca_features)

In [None]:
# Split the features into train and test sets.
# NOTE(review): 'dssp3' is passed as `feature_col`, but the y_* outputs are later
# used as targets in model.fit, so it looks like the label column — confirm against split_data.
X_train, X_test, y_train, y_test = split_data(ca_features, feature_col='dssp3')

#### 4. Train model

In [None]:
from corssa.model import CORSSA
# Create a CORSSA model with no constructor arguments.
model = CORSSA()

In [None]:
# Fit the model on the training split.
model.fit(X_train, y_train)

In [None]:
# Predict labels for the held-out test split; y_pred is scored in the next cell.
y_pred = model.predict(X_test)

In [None]:
from sklearn.metrics import classification_report

# Print scikit-learn's text classification report comparing the test labels with the predictions.
print(classification_report(y_test, y_pred))

In [None]:
from corssa.evalutils import plot_confusion_matrix

# Plot the confusion matrix for the test split. Reuse `y_pred` from the prediction
# cell instead of running inference on X_test a second time, so the plot shows
# exactly the predictions scored in the classification report.
plot_confusion_matrix(model, y_test, y_pred, model_name='CORSSA - C-alpha')

In [None]:
# Save the trained model to a file in the current working directory for subsequent use.
model.save_model("corssa_calpha_model.cbm")

In [None]:
# Load model for inference: create a fresh CORSSA instance, then load the
# file written by save_model into it.
loaded_model = CORSSA()
loaded_model.load_model(fname="corssa_calpha_model.cbm")