# constants

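> Constants: the BLOSUM62 amino-acid substitution matrix and the default token vocabulary used to tokenize sequences.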

In [None]:
#| default_exp constants

In [None]:
#| export

from dataclasses import dataclass
from typing import Optional
from io import StringIO

import pandas as pd

import torch


@dataclass
class SubstitutionMatrix:
    """A named substitution matrix kept as a label-indexed pandas table.

    Scores are looked up by symbol label, e.g. ``mat.loc[row, col]``.
    """

    # Identifier of the matrix, e.g. "BLOSUM62".
    name: str
    # Score table; presumably square with the same symbol labels on both
    # axes (as for BLOSUM62 in this module) — not enforced here.
    mat: pd.DataFrame
    # Expected score associated with the matrix, stored as given.
    expected_value: float


@dataclass
class TokenizedSubstitutionMatrix:
    """A named substitution matrix converted to a dense tensor.

    Unlike ``SubstitutionMatrix``, rows and columns are addressed by integer
    position rather than by symbol label.
    """

    # Identifier of the source matrix, e.g. "BLOSUM62".
    name: str
    # Score tensor; presumably square and indexed by token id on both axes
    # (as produced by get_blosum62_data) — not enforced here.
    mat: torch.Tensor
    # Expected score carried over from the source matrix.
    expected_value: float


# BLOSUM62 substitution matrix over 24 symbols: the 20 standard amino acids,
# the ambiguity codes B, Z and X, and "*". The whitespace-separated table below
# is parsed into a DataFrame whose index (first column) and header (first row)
# both carry the symbol labels, so scores are read as BLOSUM62.mat.loc["A", "R"].
# expected_value (-0.5209) matches the "Expected" figure in the NCBI BLOSUM62
# file header — NOTE(review): confirm units against the consumer of this value.
BLOSUM62 = SubstitutionMatrix(
    name="BLOSUM62",
    mat=pd.read_csv(
        StringIO(
            """   A  R  N  D  C  Q  E  G  H  I  L  K  M  F  P  S  T  W  Y  V  B  Z  X  *
A  4 -1 -2 -2  0 -1 -1  0 -2 -1 -1 -1 -1 -2 -1  1  0 -3 -2  0 -2 -1  0 -4
R -1  5  0 -2 -3  1  0 -2  0 -3 -2  2 -1 -3 -2 -1 -1 -3 -2 -3 -1  0 -1 -4
N -2  0  6  1 -3  0  0  0  1 -3 -3  0 -2 -3 -2  1  0 -4 -2 -3  3  0 -1 -4
D -2 -2  1  6 -3  0  2 -1 -1 -3 -4 -1 -3 -3 -1  0 -1 -4 -3 -3  4  1 -1 -4
C  0 -3 -3 -3  9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4
Q -1  1  0  0 -3  5  2 -2  0 -3 -2  1  0 -3 -1  0 -1 -2 -1 -2  0  3 -1 -4
E -1  0  0  2 -4  2  5 -2  0 -3 -3  1 -2 -3 -1  0 -1 -3 -2 -2  1  4 -1 -4
G  0 -2  0 -1 -3 -2 -2  6 -2 -4 -4 -2 -3 -3 -2  0 -2 -2 -3 -3 -1 -2 -1 -4
H -2  0  1 -1 -3  0  0 -2  8 -3 -3 -1 -2 -1 -2 -1 -2 -2  2 -3  0  0 -1 -4
I -1 -3 -3 -3 -1 -3 -3 -4 -3  4  2 -3  1  0 -3 -2 -1 -3 -1  3 -3 -3 -1 -4
L -1 -2 -3 -4 -1 -2 -3 -4 -3  2  4 -2  2  0 -3 -2 -1 -2 -1  1 -4 -3 -1 -4
K -1  2  0 -1 -3  1  1 -2 -1 -3 -2  5 -1 -3 -1  0 -1 -3 -2 -2  0  1 -1 -4
M -1 -1 -2 -3 -1  0 -2 -3 -2  1  2 -1  5  0 -2 -1 -1 -1 -1  1 -3 -1 -1 -4
F -2 -3 -3 -3 -2 -3 -3 -3 -1  0  0 -3  0  6 -4 -2 -2  1  3 -1 -3 -3 -1 -4
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4  7 -1 -1 -4 -3 -2 -2 -1 -2 -4
S  1 -1  1  0 -1  0  0  0 -1 -2 -2  0 -1 -2 -1  4  1 -3 -2 -2  0  0  0 -4
T  0 -1  0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1  1  5 -2 -2  0 -1 -1  0 -4
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1  1 -4 -3 -2 11  2 -3 -4 -3 -2 -4
Y -2 -2 -2 -3 -2 -1 -2 -3  2 -1 -1 -2 -1  3 -3 -2 -2  2  7 -1 -3 -2 -1 -4
V  0 -3 -3 -3 -1 -2 -2 -3 -3  3  1 -2  1 -1 -2 -2  0 -3 -1  4 -3 -2 -1 -4
B -2 -1  3  4 -3  0  1 -1  0 -3 -4  0 -3 -3 -2  0 -1 -4 -3 -3  4  1 -1 -4
Z -1  0  0  1 -3  3  4 -2  0 -3 -3  1 -1 -3 -1  0 -1 -3 -2 -2  1  4 -1 -4
X  0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2  0  0 -2 -1 -1 -1 -1 -1 -4
* -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4  1
"""
        ),
        # First column holds the row labels; first row holds the column labels
        # (it has one fewer field, which pandas resolves via index_col=0).
        index_col=0,
        header=0,
        # Raw string: "\s" is an invalid escape in a normal string literal
        # (SyntaxWarning on Python 3.12+, slated to become an error).
        sep=r"\s+",
    ),
    expected_value=-0.5209,
)


# Default vocabulary: "-" (gap) at id 0, the 20 standard amino-acid letters in
# alphabetical order at ids 1-20, and "X" at id 21.
DEFAULT_TOKENS = "-ACDEFGHIKLMNPQRSTVWYX"
# Token -> integer id, in the same order as DEFAULT_TOKENS.
DEFAULT_AA_TO_INT = {token: idx for idx, token in enumerate(DEFAULT_TOKENS)}


def get_blosum62_data(
    aa_to_int: Optional[dict[str, int]] = None,
    gaps_as_stars: bool = False,
) -> TokenizedSubstitutionMatrix:
    """Return BLOSUM62 as a tensor whose rows/columns follow a token mapping.

    Args:
        aa_to_int: Mapping from token symbol to integer id. Rows and columns of
            the returned matrix are ordered by ascending id (ties keep the
            mapping's iteration order). Defaults to ``DEFAULT_AA_TO_INT``.
            The caller's mapping is never mutated.
        gaps_as_stars: If True, the gap token ``"-"`` (when present) is scored
            using the BLOSUM62 ``"*"`` row/column while keeping the gap's id.
            If False, a ``"-"`` row and column of zeros is added, so a gap
            scores 0 against every token, including another gap.

    Returns:
        A ``TokenizedSubstitutionMatrix`` named like ``BLOSUM62``, carrying its
        ``expected_value``, whose ``mat`` is a square tensor of dtype
        ``torch.get_default_dtype()``.

    Raises:
        ValueError: If ``gaps_as_stars`` is True and ``aa_to_int`` contains
            both ``"-"`` and ``"*"`` (remapping the gap would collide).
        KeyError: If ``aa_to_int`` contains a symbol that BLOSUM62 does not
            define (raised by the pandas label lookup).
    """
    # Work on a private copy so the caller's mapping (or the module default)
    # is not modified by the gap remapping below.
    aa_to_int = dict(DEFAULT_AA_TO_INT if aa_to_int is None else aa_to_int)

    mat = BLOSUM62.mat.copy()
    if gaps_as_stars:
        if "-" in aa_to_int and "*" in aa_to_int:
            raise ValueError(
                "Cannot have both gaps and stars in `aa_to_int` if `gaps_as_stars` is True"
            )
        # Look gaps up under "*" but keep their id. A vocabulary without a gap
        # token needs no remapping (previously this raised a bare KeyError).
        if "-" in aa_to_int:
            aa_to_int["*"] = aa_to_int.pop("-")
    else:
        # Neutral gap scoring: add a "-" row and column filled with 0.
        mat.loc["-"] = 0
        mat.loc[:, "-"] = 0

    # Order symbols by id so that row/column position follows token id.
    # NOTE(review): position == id only if ids are contiguous from 0 —
    # confirm callers guarantee this.
    tokens = sorted(aa_to_int, key=aa_to_int.__getitem__)
    mat = mat.loc[tokens, tokens]

    return TokenizedSubstitutionMatrix(
        name=BLOSUM62.name,
        mat=torch.tensor(mat.to_numpy(), dtype=torch.get_default_dtype()),
        expected_value=BLOSUM62.expected_value,
    )