In [2]:
import os
import json
from stelaro.data import format, ncbi
import numpy as np
import random

DATA_DIRECTORY = "../data/"
SUMMARY_DIRECTORY = DATA_DIRECTORY + "ncbi_genome_summaries/"
NCBI_TAXONOMY_DIRECTORY = DATA_DIRECTORY + "ncbi_taxonomy/"
BERTAX_DIRECTORY = DATA_DIRECTORY + "bertax/final/"
BERTAX_DATASET_DIRECTORY = BERTAX_DIRECTORY + "final_model_data_seperate_fasta_per_superkingdom/data/fass2/projects/fk_read_classification/dna_sequences/fragments/genomic_fragments_80_big/"
BERTAX_DOMAINS = (
    "Archaea_db.fa",
    "Bacteria_db.fa",
    "Eukaryota_db.fa",
    "Viruses_db.fa",
)
BERTAX_STATISTIC_DIRECTORY = BERTAX_DIRECTORY + "statistics/"
PROCESSED_PRETRAINING_DATA = DATA_DIRECTORY + "bertax/pretraining/processed/"
SEQUENCE_LENGTH = 1500
N_MINIMUM_READS_PER_TAXON = 10_000


def mkdir(path: str) -> None:
    """Create a directory if it does not exist."""
    if not os.path.exists(path):
        os.makedirs(path)


mkdir(BERTAX_STATISTIC_DIRECTORY)

# One-Shot

In [3]:
BERTAX_TRAIN = BERTAX_DIRECTORY + "train/"
BERTAX_VALIDATION = BERTAX_DIRECTORY + "validation/"
BERTAX_TEST = BERTAX_DIRECTORY + "test/"

import numpy as np
from torch.utils.data import DataLoader
import json
from torch.optim import Adam
import matplotlib.pyplot as plt
from time import time
from torch import nn, exp

from stelaro.data import format
from stelaro import models

LENGTH = 1500
BATCH_SIZE = 128


train_data = DataLoader(
    models.SyntheticTetramerDataset(BERTAX_TRAIN),
    batch_size=BATCH_SIZE,
    shuffle=True
)
validation_data = DataLoader(
    models.SyntheticTetramerDataset(BERTAX_VALIDATION),
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_data = DataLoader(
    models.SyntheticTetramerDataset(BERTAX_TEST),
    batch_size=BATCH_SIZE,
    shuffle=True
)

with open(BERTAX_DIRECTORY + "statistics/map.json", "r") as f:
    mapping = json.load(f)


def benchmark(
        classifier: models.BaseClassifier,
        name: str,
        max_epochs: int = 20,
        loss_fn = nn.CrossEntropyLoss(),
        evaluation_interval=5000,
        learning_rate: float = 0.001
    ):
    parameters = classifier.get_parameters()
    if parameters:
        optimizer = Adam(classifier.get_parameters(), lr=0.001)
        total_params = sum(param.numel() for param in parameters)
        print(f"Number of parameters: {total_params:_}")
    else:
        optimizer = None
    a = time()
    losses, f1, validation_losses = classifier.train(
        train_data,
        validation_data,
        optimizer,
        max_n_epochs=max_epochs,
        patience=3,
        loss_function=loss_fn,
        loss_function_type="supervised",
        evaluation_interval=evaluation_interval
    )
    b = time()
    print(f"Training took {(b - a):.3f} s.")
    if losses:
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
        x = list(range(len(losses)))
        ax[0].plot(x, losses, label="Training")
        ax[0].plot(x, validation_losses, label="Validation")
        ax[0].set(xlabel='Epochs', ylabel='Loss')
        ax[0].set_title("Normalized Loss Against Epochs")
        ax[0].legend()
        ax[1].set(xlabel='Epochs', ylabel="f1")
        ax[1].set_title("F1 Score")
        r = 0
        for f in f1:
            ax[1].plot(x, f, label=f'Rank {r}')
            r += 1
        ax[1].legend()
        fig.suptitle(f"Classification Training for {name}")
        plt.show()
    result = models.evaluate(classifier, test_data, "cuda", mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"F1: {rounded_result}")
    result = models.evaluate_precision(classifier, test_data, "cuda", mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"Precision: {rounded_result}")
    return classifier

In [None]:
from torch.optim import Adam
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
from mamba_ssm import Mamba
from torch import nn

# 256, 128, 2, 16, 8, 2, 0.1, batch 128, noClip, CEL
# F1: [0.94682, 0.88288]
# Precision: [0.94012, 0.88521]

# 256, 64, 2, 16, 8, 2, 0.1, batch 128, noClip, CEL
# F1: [0.92456, 0.8333]
# Precision: [0.93198, 0.84133]

# 256, 128, 2, 16, 8, 2, 0.1, batch 64, noClip, CEL
# F1: [0.94531, 0.88431]
# Precision: [0.94236, 0.88785]

# 256, 128, 2, 16, 8, 3, 0.1, batch 128, noClip, CEL*
# F1: [0.95005, 0.88783]
# Precision: [0.94598, 0.89011]

# 256, 256, 2, 16, 8, 2, 0.1, batch 64, 5 epochs, noClip, CEL
# F1: [0.95149, 0.88721]
# Precision: [0.95105, 0.88929]

# 256, 128, 2, 16, 16, 2, 0.1, batch 128, noClip, CEL
# F1: [0.95237, 0.88463]
# Precision: [0.95195, 0.88635]

# 256, 128, 2, 16, 16, 2, 0.1, batch 128, clip, focal mean
# F1: [0.94285, 0.87642]
# Precision: [0.93603, 0.87868]

# 256, 128, 2, 16, 16, 2, 0.1, batch 128, clip, focal sum
# F1: [0.94163, 0.87185]
# Precision: [0.93829, 0.87553]

# 256, 128, 2, 16, 16, 2, 0.1, batch 128, clip, penalty sum
# F1: [0.91501, 0.68972]
# Precision: [0.90453, 0.7091]

# 256, 128, 2, 16, 8, 2, 0.1, batch 128, noClip, focal sum
# F1: [0.93204, 0.85329]
# Precision: [0.93095, 0.85767]

# 256, 128, 3, 16, 8, 2, 0.1, batch 128, noClip, CEL
# F1: [0.90227, 0.80133]
# Precision: [0.89234, 0.80588]


class MambaSequenceClassifier(nn.Module):
    def __init__(
        self,
        N: int,
        num_classes: int,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 3,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        pooling = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.classifier = nn.Linear(d_model, num_classes)
        self.pooling = pooling

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        h = self.embedding(x).to(dtype=torch.get_default_dtype())
        # h = h.to(torch.float16)
        for block in self.layers:
            h = block(h)   # each Mamba block returns [B, L, d_model]
        h = self.norm(h)
        pooled = h.mean(dim=1)  # [B, d_model]
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)  # [B, num_classes]
        return logits


classifier = models.Classifier(
    LENGTH // 4,
    mapping,
    "cuda",
    MambaSequenceClassifier,
    format.to_tetramers,
    True
)
# classifier.model = classifier.model.to(torch.float16)

import torch.nn.functional as F
def focal_loss(inputs, targets, alpha=1, gamma=2):
    num_classes = inputs.shape[1]
    targets = F.one_hot(targets, num_classes=num_classes).float()
    bce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    pt = exp(-bce_loss)  # Convert BCE loss to probability
    focal_loss = alpha * (1 - pt) ** gamma * bce_loss  # Apply focal adjustment
    return focal_loss.sum()


def create_penalty_matrix(mapping):
    d = len(mapping)
    t = torch.zeros((d, d))
    n_ranks = len(mapping['0'])
    for i in mapping:
        for j in range(d):
            j = str(j)
            union_length = 0
            for a, b in zip(mapping[i], mapping[j]):
                if a == b:
                    union_length += 1
                else:
                    break
            penalty = ((n_ranks - union_length) / n_ranks) ** 0.5
            t[int(i), int(j)] = penalty
    return t
penalty_matrix = create_penalty_matrix(mapping).to("cuda")
def penalty_cross_entropy(y_pred, y_true):
    probs = F.softmax(y_pred, dim=-1)  # [B, M]
    penalty_rows = penalty_matrix[y_true]  # [B, M]
    loss = -torch.sum((1 - penalty_rows) * torch.log(probs + 1e-12), dim=-1)
    return loss.sum()
def hierarchical(input, targets):
    return models.penalized_cross_entropy(input, targets, penalty_matrix)


model = benchmark(
    classifier,
    "MAMBA",
    max_epochs=10,
    loss_fn=nn.CrossEntropyLoss()
)
matrix = models.confusion_matrix(model, test_data, "cuda", mapping)
plt.matshow(matrix)
plt.show()
np.set_printoptions(precision=3, suppress=True, linewidth=500, threshold=np.inf)
print(matrix)

Number of parameters: 388_785


 22%|██▏       | 2003/9196 [04:03<6:20:14,  3.17s/it] 

P: [0.57526, 0.22857]


 44%|████▎     | 4002/9196 [08:44<5:56:33,  4.12s/it]

P: [0.65106, 0.35171]


 65%|██████▌   | 6001/9196 [13:25<5:32:00,  6.23s/it]

P: [0.69989, 0.45002]


 87%|████████▋ | 8003/9196 [17:58<1:02:40,  3.15s/it]

P: [0.72618, 0.48692]


100%|██████████| 9196/9196 [20:25<00:00,  7.51it/s]  


1/10 T loss: 6.61341. V loss: 4.84294. F1: [0.7473, 0.50389]. P: [0.75105, 0.53675] Patience: 3


 22%|██▏       | 2003/9196 [04:20<6:17:59,  3.15s/it] 

P: [0.78049, 0.55876]


 44%|████▎     | 4003/9196 [08:41<4:33:56,  3.17s/it]

P: [0.79384, 0.58838]


 65%|██████▌   | 6003/9196 [13:06<2:47:26,  3.15s/it]

P: [0.80103, 0.60817]


 87%|████████▋ | 8003/9196 [17:29<1:02:31,  3.14s/it]

P: [0.80639, 0.63857]


100%|██████████| 9196/9196 [19:57<00:00,  7.68it/s]  


2/10 T loss: 4.13323. V loss: 3.68842. F1: [0.80875, 0.63045]. P: [0.79989, 0.64341] Patience: 3


 22%|██▏       | 2003/9196 [04:27<6:15:09,  3.13s/it] 

P: [0.82103, 0.64781]


 44%|████▎     | 4003/9196 [08:48<4:28:45,  3.11s/it]

P: [0.82164, 0.66638]


 65%|██████▌   | 6003/9196 [13:08<2:39:19,  2.99s/it]

P: [0.83035, 0.67798]


 87%|████████▋ | 8003/9196 [17:25<1:00:39,  3.05s/it]

P: [0.8352, 0.68557]


100%|██████████| 9196/9196 [19:49<00:00,  7.73it/s]  


3/10 T loss: 3.35514. V loss: 3.08872. F1: [0.84296, 0.69312]. P: [0.83396, 0.70252] Patience: 3


 22%|██▏       | 2002/9196 [26:50<4:12:30,  2.11s/it]   

P: [0.85194, 0.70102]


 44%|████▎     | 4003/9196 [30:21<2:18:18,  1.60s/it]

P: [0.85008, 0.70382]


 65%|██████▌   | 6003/9196 [34:18<2:43:12,  3.07s/it]

P: [0.84822, 0.71657]


 87%|████████▋ | 8003/9196 [38:46<1:03:00,  3.17s/it]

P: [0.85833, 0.72639]


100%|██████████| 9196/9196 [41:19<00:00,  3.71it/s]  


4/10 T loss: 2.91102. V loss: 2.77387. F1: [0.86056, 0.7215]. P: [0.85924, 0.72444] Patience: 3


 22%|██▏       | 2001/9196 [04:37<11:33:42,  5.78s/it]

P: [0.86929, 0.73768]


 44%|████▎     | 4002/9196 [09:18<5:57:19,  4.13s/it] 

P: [0.86356, 0.74011]


 65%|██████▌   | 6001/9196 [13:56<5:07:52,  5.78s/it]

P: [0.87261, 0.74341]


 87%|████████▋ | 8001/9196 [18:34<1:55:36,  5.80s/it]

P: [0.87052, 0.74563]


100%|██████████| 9196/9196 [21:10<00:00,  7.24it/s]  


5/10 T loss: 2.59942. V loss: 2.52903. F1: [0.87302, 0.7456]. P: [0.86696, 0.75424] Patience: 3


 22%|██▏       | 2002/9196 [04:41<8:11:37,  4.10s/it] 

P: [0.86764, 0.75526]


 44%|████▎     | 4002/9196 [09:22<5:54:18,  4.09s/it]

P: [0.88441, 0.76443]


 65%|██████▌   | 6002/9196 [14:03<3:38:29,  4.10s/it]

P: [0.87482, 0.76742]


 87%|████████▋ | 8002/9196 [18:45<1:21:35,  4.10s/it]

P: [0.88295, 0.77152]


100%|██████████| 9196/9196 [21:21<00:00,  7.18it/s]  


6/10 T loss: 2.36366. V loss: 2.31394. F1: [0.88486, 0.76933]. P: [0.88552, 0.77066] Patience: 3


 22%|██▏       | 2002/9196 [04:41<8:12:56,  4.11s/it] 

P: [0.88251, 0.77883]


 44%|████▎     | 4002/9196 [09:27<5:54:43,  4.10s/it]

P: [0.89717, 0.77904]


 65%|██████▌   | 6001/9196 [14:08<5:11:34,  5.85s/it]

P: [0.89149, 0.78166]


 87%|████████▋ | 8002/9196 [18:49<1:21:39,  4.10s/it]

P: [0.89935, 0.78397]


100%|██████████| 9196/9196 [21:25<00:00,  7.15it/s]  


7/10 T loss: 2.16952. V loss: 2.14979. F1: [0.89476, 0.78544]. P: [0.90153, 0.78891] Patience: 3


 22%|██▏       | 2002/9196 [04:41<8:12:25,  4.11s/it] 

P: [0.89517, 0.79322]


 44%|████▎     | 4002/9196 [09:22<5:58:38,  4.14s/it]

P: [0.89351, 0.79669]


 65%|██████▌   | 6002/9196 [14:03<3:39:00,  4.11s/it]

P: [0.90727, 0.79766]


 87%|████████▋ | 8002/9196 [18:49<1:25:25,  4.29s/it]

P: [0.90441, 0.80098]


100%|██████████| 9196/9196 [21:27<00:00,  7.14it/s]  


8/10 T loss: 2.00815. V loss: 2.05332. F1: [0.90019, 0.79307]. P: [0.89525, 0.79748] Patience: 3


 16%|█▌        | 1462/9196 [03:12<16:56,  7.61it/s]


KeyboardInterrupt: 

In [31]:
result = models.evaluate(classifier, test_data, "cuda", mapping)
rounded_result = [float(f"{r:.5}") for r in result]
print(f"F1: {rounded_result}")
result = models.evaluate_precision(classifier, test_data, "cuda", mapping)
rounded_result = [float(f"{r:.5}") for r in result]
print(f"Precision: {rounded_result}")

F1: [0.90227, 0.80133]
Precision: [0.89234, 0.80588]


In [12]:
print(mapping)
m = models.create_penalty_matrix(mapping)
print(m)

{'0': ['Archaea', 'Methanobacteriota'], '1': ['Archaea', 'Nitrososphaerota'], '2': ['Archaea', 'Thermoplasmatota'], '3': ['Archaea', 'Thermoproteota'], '4': ['Bacteria', 'Actinomycetota'], '5': ['Bacteria', 'Aquificota'], '6': ['Bacteria', 'Bacillota'], '7': ['Bacteria', 'Bacteroidota'], '8': ['Bacteria', 'Bdellovibrionota'], '9': ['Bacteria', 'Campylobacterota'], '10': ['Bacteria', 'Chlamydiota'], '11': ['Bacteria', 'Chlorobiota'], '12': ['Bacteria', 'Chloroflexota'], '13': ['Bacteria', 'Cyanobacteriota'], '14': ['Bacteria', 'Deinococcota'], '15': ['Bacteria', 'Fusobacteriota'], '16': ['Bacteria', 'Gemmatimonadota'], '17': ['Bacteria', 'Lentisphaerota'], '18': ['Bacteria', 'Mycoplasmatota'], '19': ['Bacteria', 'Myxococcota'], '20': ['Bacteria', 'Nitrospirota'], '21': ['Bacteria', 'Planctomycetota'], '22': ['Bacteria', 'Pseudomonadota'], '23': ['Bacteria', 'Rhodothermota'], '24': ['Bacteria', 'Spirochaetota'], '25': ['Bacteria', 'Thermodesulfobacteriota'], '26': ['Bacteria', 'Thermotog

# 1-NT Tokens

In [None]:
from torch.optim import Adam
import torch
from mamba_ssm import Mamba
from torch import nn

# 256, 64, 2, 16, 8, 2, 0.1, batch 128
# F1: [0.88082, 0.78494]
# P: [0.89169, 0.80668]

class MambaSequenceClassifier(nn.Module):
    def __init__(
        self,
        N: int,
        num_classes: int,
        vocab_size: int = 4,
        d_model: int = 64,
        n_layers: int = 2,
        d_state: int = 16,
        d_conv: int = 8,
        expand: int = 2,
        pooling = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.classifier = nn.Linear(d_model, num_classes)
        self.pooling = pooling

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        h = self.embedding(x).to(dtype=torch.get_default_dtype())
        for block in self.layers:
            h = block(h)   # each Mamba block returns [B, L, d_model]
        h = self.norm(h)
        pooled = h.mean(dim=1)  # [B, d_model]
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)  # [B, num_classes]
        return logits


classifier = models.Classifier(LENGTH, mapping, "cuda", MambaSequenceClassifier, format.to_digits)
model = benchmark(
    classifier,
    "MAMBA",
    max_epochs=10,
)
matrix = models.confusion_matrix(model, test_data, "cuda", mapping)
plt.matshow(matrix)
plt.show()
np.set_printoptions(precision=3, suppress=True, linewidth=500, threshold=np.inf)
print(matrix)

Number of parameters: 69_873


 54%|█████▍    | 5001/9196 [28:50<8:21:51,  7.18s/it]

P: [0.87169, 0.7341]


100%|██████████| 9196/9196 [1:36:26<00:00,  1.59it/s]     


Halting evaluation after 28032 data points.
1/10 T loss: 13.99071. V loss: 10.32095. F1: [0.86647, 0.74595]. P: [0.88074, 0.77692] Patience: 3


 54%|█████▍    | 5000/9196 [1:05:05<51:28,  1.36it/s]  

Halting evaluation after 38144 data points.


 54%|█████▍    | 5001/9196 [1:05:36<11:29:05,  9.86s/it]

P: [0.90584, 0.8145]


100%|██████████| 9196/9196 [1:57:10<00:00,  1.31it/s]   


Halting evaluation after 38144 data points.
2/10 T loss: 8.43410. V loss: 8.98794. F1: [0.88082, 0.78494]. P: [0.89169, 0.80668] Patience: 3


  1%|          | 94/9196 [01:11<1:54:40,  1.32it/s]


KeyboardInterrupt: 

# Pretraining

In [None]:
# Obtain pretraining data
DATA_DIRECTORY = "../data/"
BERTAX_DIRECTORY = DATA_DIRECTORY + "bertax/pretraining/"
BERTAX_DATASET_DIRECTORY = BERTAX_DIRECTORY + "pretraining_dataset/data/fass2/projects/fk_read_classification/dna_sequences/fragments/genomic_fragments_80/"
BERTAX_DOMAINS = (
    "Archaea_db.fa",
    "Bacteria_db.fa",
    "Eukaryota_db.fa",
    "Viruses_db.fa",
)

total = 0
for domain in BERTAX_DOMAINS:
    with open(BERTAX_DATASET_DIRECTORY + domain, "r") as f:
        for line in f:
            if line.startswith(">"):
                total += 1
    print(f"Domain: {domain}. Total: {total}")

x = np.zeros((total, 1500 // 4), dtype=np.uint8)
mkdir(PROCESSED_PRETRAINING_DATA)
i = 0
for domain in BERTAX_DOMAINS:
    with open(BERTAX_DATASET_DIRECTORY + domain, "r") as f:
        for line in f:
            if not line.startswith(">"):
                sequence = line.strip().upper()
                characters = set(sequence)
                if len(characters) == 4:
                    encoding = format.encode_tetramer(sequence)
                    x[i] = encoding
                    i += 1
np.save(PROCESSED_PRETRAINING_DATA + "/x.npy", x)

In [9]:
from torch.optim import Adam, AdamW
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
from mamba_ssm import Mamba
from torch import nn


# 128, 2, 16, 4, 2, mean, 0.1, noClip, pretraining: 10_000 -> 0.15 MLM
# F1: [0.9365, 0.86395]
# Precision: [0.93262, 0.87061]

# 128, 2, 16, 4, 2, mean, 0.1, noClip, pretraining: 20_000 -> 0.15 MLM
# Precision: [0.93845, 0.86496] (incomplete training - no difference with 10k steps)
# Lowest MLM loss: 0.1352
# Fine-tuning:
#   Epoch 1: F1: [0.88529, 0.75644]. P: [0.89593, 0.76523]
#   Epoch 2: F1: [0.91113, 0.81615]. P: [0.91365, 0.82102]

# 128, 2, 16, 4, 2, mean, 0.1, noClip, pretraining: 100_000 -> 0.15 MLM
# 1/10 T loss: 2.97148. V loss: 2.46173. F1: [0.87928, 0.75458]. P: [0.87739, 0.77189] Patience: 3
# 2/10 T loss: 1.97898. V loss: 1.84719. F1: [0.91845, 0.82066]. P: [0.92227, 0.82479] Patience: 3
# 5/10 T loss: 1.37992. V loss: 1.43371. F1: [0.9326, 0.86184]. P: [0.92561, 0.86369]

# 128, 2, 16, 4, 2, mean, 0.1, noClip, pretraining: 50_000 new -> 0.15 MLM
# 1/10 T loss: 3.11947. V loss: 2.29345. F1: [0.89625, 0.77654]. P: [0.89769, 0.78591] Patience: 3
# 2/10 T loss: 2.04814. V loss: 1.93529. F1: [0.91427, 0.81052]. P: [0.913, 0.81998] Patience: 3
# 3/10 T loss: 1.70350. V loss: 1.71731. F1: [0.92385, 0.83688]. P: [0.9214, 0.84447] Patience: 3

# 128, 2, 16, 4, 2, mean, 0.1, noClip, pretraining: 40_000 new -> 0.15 MLM, lr = 0.0001
# 1/10 T loss: 3.12365. V loss: 2.35835. F1: [0.8859, 0.76903]. P: [0.88096, 0.77557] Patience: 3
# 2/10 T loss: 1.99888. V loss: 1.85153. F1: [0.9151, 0.8191]. P: [0.9125, 0.82259] Patience: 3
# 3/10 T loss: 1.68860. V loss: 1.62224. F1: [0.92761, 0.84333]. P: [0.92996, 0.84681] Patience: 3

# 128, 3 (norms), 16, 4, 2, mean, 0.1, clip, pretraining: 10_000 new -> 0.2 MLM, lr = 0.0001


class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state, d_conv, expand):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

    def forward(self, x):
        # Pre-norm + residual connection
        return x + self.mamba(self.norm(x))


class MambaSequenceClassifier(nn.Module):
    def __init__(
        self,
        N: int,
        num_classes: int,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 3,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        pooling = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size + 1, d_model)
        self.layers = nn.ModuleList([
            # Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            MambaBlock(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.classifier = nn.Linear(d_model, num_classes)
        # self.classifier = nn.Sequential(
        #     nn.Conv1d(128, 256, kernel_size=5, padding=2),
        #     nn.ReLU(),
        #     nn.Flatten(),
        #     nn.Linear(N * 256, num_classes),
        # )

        self.pooling = pooling
        self.mlm_head = nn.Linear(d_model, vocab_size + 1)
        self.mlm_head.weight = self.embedding.weight

    def forward(self, x: torch.LongTensor, mlm=None) -> torch.Tensor:
        h = self.embedding(x)#.to(dtype=torch.get_default_dtype())
        for block in self.layers:
            h = block(h)   # each Mamba block returns [B, L, d_model]
        h = self.norm(h)
        if mlm is not None:
            logits = self.mlm_head(h)  # [B, L, vocab_size + 1]
            return logits
        else:
            # pooled = h[:, 0]
            pooled = h.mean(dim=1)  # [B, d_model]
            pooled = self.dropout(pooled)
            # pooled = pooled.permute(0, 2, 1)
            logits = self.classifier(pooled)  # [B, num_classes]
            return logits


classifier = models.Classifier(
    LENGTH // 4,
    mapping,
    "cuda",
    MambaSequenceClassifier,
    format.to_tetramers,
    True
)

print("Pretraining:")
optimizer = AdamW(classifier.get_parameters(), lr=0.001)
pretraining_data = DataLoader(
    models.SyntheticTetramerDataset(PROCESSED_PRETRAINING_DATA, labels=False),
    batch_size=BATCH_SIZE,
    shuffle=True
)
classifier.pretrain(
    pretraining_data,
    optimizer,
    10_000,
    256,
    patience=3,
    mlm_probability=0.25
)

print("Fine-tuning:")
model = benchmark(
    classifier,
    "MAMBA",
    max_epochs=10,
    loss_fn=nn.CrossEntropyLoss(),
    evaluation_interval=2000,
    learning_rate=0.0001
)
matrix = models.confusion_matrix(model, test_data, "cuda", mapping)
plt.matshow(matrix)
plt.show()
np.set_printoptions(precision=3, suppress=True, linewidth=500, threshold=np.inf)
print(matrix)

Fine-tuning:
Number of parameters: 389_938


 22%|██▏       | 2001/9196 [03:41<9:12:55,  4.61s/it]

P: [0.80326, 0.64982]


 44%|████▎     | 4001/9196 [07:23<6:36:08,  4.58s/it]

P: [0.85748, 0.70958]


 65%|██████▌   | 6001/9196 [11:08<4:09:57,  4.69s/it]

P: [0.88819, 0.7525]


 87%|████████▋ | 8001/9196 [15:02<1:41:36,  5.10s/it]

P: [0.89519, 0.77108]


100%|██████████| 9196/9196 [17:18<00:00,  8.86it/s]  


1/10 T loss: 3.19294. V loss: 2.32218. F1: [0.89688, 0.77193]. P: [0.90517, 0.78354] Patience: 3


 22%|██▏       | 2002/9196 [04:06<7:58:51,  3.99s/it] 

P: [0.89802, 0.80257]


 44%|████▎     | 4002/9196 [08:26<6:06:06,  4.23s/it]

P: [0.90943, 0.80979]


 65%|██████▌   | 6002/9196 [12:33<3:16:53,  3.70s/it]

P: [0.92629, 0.81882]


 87%|████████▋ | 8002/9196 [16:39<1:16:15,  3.83s/it]

P: [0.90996, 0.82479]


100%|██████████| 9196/9196 [19:04<00:00,  8.04it/s]  


2/10 T loss: 1.95762. V loss: 1.75689. F1: [0.92233, 0.82708]. P: [0.92497, 0.83141] Patience: 3


 22%|██▏       | 2002/9196 [03:59<7:22:55,  3.69s/it] 

P: [0.91814, 0.83478]


 44%|████▎     | 4002/9196 [08:01<5:21:18,  3.71s/it]

P: [0.93004, 0.84943]


 65%|██████▌   | 6001/9196 [12:09<4:51:13,  5.47s/it]

P: [0.93896, 0.84498]


 87%|████████▋ | 8002/9196 [16:12<1:16:59,  3.87s/it]

P: [0.93467, 0.85354]


100%|██████████| 9196/9196 [18:24<00:00,  8.32it/s]  


3/10 T loss: 1.57684. V loss: 1.56207. F1: [0.93461, 0.84664]. P: [0.93917, 0.85025] Patience: 3


 22%|██▏       | 2002/9196 [04:02<8:10:50,  4.09s/it] 

P: [0.93231, 0.85553]


 44%|████▎     | 4002/9196 [08:11<5:34:05,  3.86s/it]

P: [0.93944, 0.85901]


 50%|█████     | 4622/9196 [09:30<09:24,  8.10it/s]  


KeyboardInterrupt: 

# Frozen Initial Layers

In [None]:
from torch.optim import Adam, AdamW
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
from mamba_ssm import Mamba
from torch import nn


class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state, d_conv, expand):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )

    def forward(self, x):
        # Pre-norm + residual connection
        return x + self.mamba(self.norm(x))


class MambaSequenceClassifier(nn.Module):
    def __init__(
        self,
        N: int,
        num_classes: int,
        vocab_size: int = 256,
        d_model: int = 256,
        n_layers: int = 3,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
        pooling = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size + 1, d_model)
        self.first_layers = nn.ModuleList([
            MambaBlock(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ])
        self.last_layers = nn.ModuleList([
            MambaBlock(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.classifier = nn.Linear(d_model, num_classes)
        # self.classifier = nn.Sequential(
        #     nn.Conv1d(128, 256, kernel_size=5, padding=2),
        #     nn.ReLU(),
        #     nn.Flatten(),
        #     nn.Linear(N * 256, num_classes),
        # )

        self.pooling = pooling
        self.mlm_head = nn.Linear(d_model, vocab_size + 1)
        # self.mlm_head.weight = self.embedding.weight

    def forward(self, x: torch.LongTensor, mlm=None) -> torch.Tensor:
        h = self.embedding(x)#.to(dtype=torch.get_default_dtype())
        if mlm is not None:
            for block in self.first_layers:
                h = block(h)   # each Mamba block returns [B, L, d_model]
            for block in self.last_layers:
                h = block(h)   # each Mamba block returns [B, L, d_model]
            h = self.norm(h)
            logits = self.mlm_head(h)  # [B, L, vocab_size + 1]
            return logits
        else:
            for block in self.first_layers:
                h = block(h)   # each Mamba block returns [B, L, d_model]
            for block in self.last_layers:
                h = block(h)   # each Mamba block returns [B, L, d_model]
            h = self.norm(h)
            pooled = h.mean(dim=1)  # [B, d_model]
            pooled = self.dropout(pooled)
            # pooled = pooled.permute(0, 2, 1)
            logits = self.classifier(pooled)  # [B, num_classes]
            return logits

    def freeze_base(self):
        for block in self.first_layers:
            for param in block.parameters():
                param.requires_grad = False

    def unfreeze_base(self):
        for block in self.first_layers:
            for param in block.parameters():
                param.requires_grad = True


classifier = models.Classifier(
    LENGTH // 4,
    mapping,
    "cuda",
    MambaSequenceClassifier,
    format.to_tetramers,
    True
)

print("Pretraining:")
optimizer = AdamW(classifier.get_parameters(), lr=0.001)
pretraining_data = DataLoader(
    models.SyntheticTetramerDataset(PROCESSED_PRETRAINING_DATA, labels=False),
    batch_size=BATCH_SIZE,
    shuffle=True
)
classifier.pretrain(
    pretraining_data,
    optimizer,
    20_000,
    256,
    patience=3,
    mlm_probability=0.1
)
# classifier.model.freeze_base()

print("Fine-tuning:")
model = benchmark(
    classifier,
    "MAMBA",
    max_epochs=10,
    loss_fn=nn.CrossEntropyLoss(),
    evaluation_interval=5000,
    learning_rate=0.0001
)
matrix = models.confusion_matrix(model, test_data, "cuda", mapping)
plt.matshow(matrix)
plt.show()
np.set_printoptions(precision=3, suppress=True, linewidth=500, threshold=np.inf)
print(matrix)


# 128, 3 / 3, 16, 4, 2, mean, 0.1, clip, pretraining: None
# 1/10 T loss: 3.08970. V loss: 1.81286. F1: [0.91764, 0.82083]. P: [0.9162, 0.82467] Patience: 3
# 2/10 T loss: 1.49454. V loss: 1.34284. F1: [0.94197, 0.87123]. P: [0.94196, 0.87661] Patience: 3

# 128, 3 / 3, 16, 4, 2, mean, 0.1, clip, pretraining: 5k, freeze
# 1/10 T loss: 9.88273. V loss: 9.09895. F1: [0.35103, 0.10588]. P: [0.38429, 0.12499]

# 128, 2 / 2, 16, 4, 2, mean, 0.1, clip, pretraining: 10k, freeze
# 1/10 T loss: 3.26994. V loss: 2.42430. F1: [0.89472, 0.75593]. P: [0.89744, 0.76282] Patience: 3

# 128, 2 / 2, 16, 4, 2, mean, 0.1, clip, pretraining: 20k, freeze
# 1/10 T loss: 4.59330. V loss: 3.32864. F1: [0.84704, 0.67053]. P: [0.84721, 0.67978] Patience: 3
# 2/10 T loss: 2.78251. V loss: 2.51102. F1: [0.88506, 0.74861]. P: [0.88034, 0.75483] Patience: 3

# 128, 2 / 2, 16, 4, 2, mean, 0.1, clip, pretraining: 20k, no freeze, span
# 1/10 T loss: 3.30888. V loss: 2.14662. F1: [0.90234, 0.79144]. P: [0.90968, 0.79664] Patience: 3
# 2/10 T loss: 1.76234. V loss: 1.54726. F1: [0.93, 0.84993]. P: [0.92694, 0.85452] Patience: 3
# 3/10 T loss: 1.34409. V loss: 1.35847. F1: [0.93591, 0.86929]. P: [0.93794, 0.87497] Patience: 3

# 128, 2 / 2, 16, 4, 2, mean, 0.1, clip, pretraining: 10k, no freeze, span
# 1/10 T loss: 3.32805. V loss: 1.91426. F1: [0.91257, 0.81601]. P: [0.91242, 0.81589] Patience: 3
# 2/10 T loss: 1.67703. V loss: 1.49529. F1: [0.93111, 0.85618]. P: [0.9346, 0.86045] Patience: 3
# 3/10 T loss: 1.33147. V loss: 1.36263. F1: [0.93998, 0.86843]. P: [0.93879, 0.87145] Patience: 3

# 128, 2 / 2, 16, 4, 2, mean, 0.1, clip, pretraining: 40k, no freeze, span
# 1/10 T loss: 2.66859. V loss: 1.88050. F1: [0.91414, 0.81202]. P: [0.90473, 0.82252] Patience: 3
# 2/10 T loss: 1.59724. V loss: 1.41483. F1: [0.94162, 0.86189]. P: [0.94196, 0.86865] Patience: 3
# 3/10 T loss: 1.28868. V loss: 1.25275. F1: [0.94469, 0.8754]. P: [0.94138, 0.87864] Patience: 3

# 128, 2 / 2, 16, 4, 2, mean, 0.1, clip, pretraining: 20k, no freeze, all masked 15 %
# 1/10 T loss: 2.53448. V loss: 1.84904. F1: [0.91666, 0.8147]. P: [0.90713, 0.82688] Patience: 3
# 2/10 T loss: 1.52909. V loss: 1.40461. F1: [0.94339, 0.86156]. P: [0.9458, 0.86432] Patience: 3
# 3/10 T loss: 1.23919. V loss: 1.23687. F1: [0.95087, 0.8776]. P: [0.95401, 0.87988] Patience: 3

# 128, 3 / 3, 16, 4, 2, mean, 0.1, clip, pretraining: 20k, no freeze, all masked 10 %
# 1/10 T loss: 2.03367. V loss: 1.45920. F1: [0.93885, 0.85342]. P: [0.93636, 0.86178] Patience: 3
# 2/10 T loss: 1.20462. V loss: 1.13109. F1: [0.95273, 0.88658]. P: [0.95174, 0.88691] Patience: 3
# 3/10 T loss: 0.97474. V loss: 1.03495. F1: [0.95888, 0.89874]. P: [0.95798, 0.90307] Patience: 3
# 4/10 T loss: 0.84054. V loss: 0.89632. F1: [0.96534, 0.91211]. P: [0.96508, 0.91451] Patience: 3

Pretraining:


  0%|          | 0/19473 [00:00<?, ?it/s]

Step: 1. Epoch: 1. MLM loss: 5.7056. Patience: 3
Mean entropy: 5.3827


  0%|          | 2/19473 [00:04<10:46:28,  1.99s/it]

Correctly predicted masked tokens: 20 / 4757 (0.42043 %).


  1%|▏         | 250/19473 [00:53<1:04:14,  4.99it/s]

Step: 250. Epoch: 1. MLM loss: 5.3718. Patience: 3
Mean entropy: 5.3696
Correctly predicted masked tokens: 76 / 4741 (1.603 %).


  3%|▎         | 500/19473 [01:43<1:04:09,  4.93it/s]

Step: 500. Epoch: 1. MLM loss: 2.6515. Patience: 3
Mean entropy: 5.27
Correctly predicted masked tokens: 124 / 4776 (2.5963 %).


  5%|▌         | 1000/19473 [03:25<1:03:04,  4.88it/s]

Step: 1000. Epoch: 1. MLM loss: 2.6291. Patience: 3
Mean entropy: 5.024
Correctly predicted masked tokens: 171 / 4780 (3.5774 %).


  8%|▊         | 1500/19473 [05:07<1:01:34,  4.87it/s]

Step: 1500. Epoch: 1. MLM loss: 1.7436. Patience: 3
Mean entropy: 4.7029
Correctly predicted masked tokens: 93 / 4859 (1.914 %).


 10%|█         | 2000/19473 [06:50<59:42,  4.88it/s]  

Step: 2000. Epoch: 1. MLM loss: 1.3040. Patience: 3
Mean entropy: 4.4719
Correctly predicted masked tokens: 129 / 4840 (2.6653 %).


 13%|█▎        | 2500/19473 [08:32<57:57,  4.88it/s]

Step: 2500. Epoch: 1. MLM loss: 1.0409. Patience: 3
Mean entropy: 4.2244
Correctly predicted masked tokens: 121 / 4733 (2.5565 %).


 15%|█▌        | 3000/19473 [10:14<56:29,  4.86it/s]

Step: 3000. Epoch: 1. MLM loss: 0.8664. Patience: 3
Mean entropy: 3.8972
Correctly predicted masked tokens: 161 / 4830 (3.3333 %).


 18%|█▊        | 3500/19473 [11:57<54:48,  4.86it/s]

Step: 3500. Epoch: 1. MLM loss: 0.7419. Patience: 3
Mean entropy: 3.9608
Correctly predicted masked tokens: 130 / 4820 (2.6971 %).


 21%|██        | 4000/19473 [13:39<53:06,  4.86it/s]

Step: 4000. Epoch: 1. MLM loss: 0.6484. Patience: 3
Mean entropy: 3.7569
Correctly predicted masked tokens: 156 / 4883 (3.1948 %).


 23%|██▎       | 4500/19473 [15:22<51:43,  4.82it/s]

Step: 4500. Epoch: 1. MLM loss: 0.5759. Patience: 3
Mean entropy: 3.605
Correctly predicted masked tokens: 175 / 4820 (3.6307 %).


 26%|██▌       | 5000/19473 [17:06<50:14,  4.80it/s]

Step: 5000. Epoch: 1. MLM loss: 0.5182. Patience: 3
Mean entropy: 3.726
Correctly predicted masked tokens: 153 / 4887 (3.1308 %).


 28%|██▊       | 5500/19473 [18:54<48:09,  4.84it/s]  

Step: 5500. Epoch: 1. MLM loss: 0.4708. Patience: 3
Mean entropy: 3.7266
Correctly predicted masked tokens: 100 / 4751 (2.1048 %).


 31%|███       | 6000/19473 [20:37<46:25,  4.84it/s]

Step: 6000. Epoch: 1. MLM loss: 0.4314. Patience: 3
Mean entropy: 3.6123
Correctly predicted masked tokens: 122 / 4807 (2.538 %).


 33%|███▎      | 6500/19473 [22:20<44:39,  4.84it/s]

Step: 6500. Epoch: 1. MLM loss: 0.3980. Patience: 3
Mean entropy: 3.285
Correctly predicted masked tokens: 124 / 4875 (2.5436 %).


 36%|███▌      | 7000/19473 [24:03<42:54,  4.84it/s]

Step: 7000. Epoch: 1. MLM loss: 0.3693. Patience: 3
Mean entropy: 3.6564
Correctly predicted masked tokens: 140 / 4874 (2.8724 %).


 39%|███▊      | 7500/19473 [25:45<41:05,  4.86it/s]

Step: 7500. Epoch: 1. MLM loss: 0.3446. Patience: 3
Mean entropy: 3.7492
Correctly predicted masked tokens: 114 / 4879 (2.3365 %).


 41%|████      | 8000/19473 [27:28<39:32,  4.84it/s]

Step: 8000. Epoch: 1. MLM loss: 0.3231. Patience: 3
Mean entropy: 3.6094
Correctly predicted masked tokens: 111 / 4844 (2.2915 %).


 44%|████▎     | 8500/19473 [29:11<37:39,  4.86it/s]

Step: 8500. Epoch: 1. MLM loss: 0.3037. Patience: 3
Mean entropy: 3.7117
Correctly predicted masked tokens: 117 / 4814 (2.4304 %).


 46%|████▌     | 9000/19473 [30:54<36:00,  4.85it/s]

Step: 9000. Epoch: 1. MLM loss: 0.2869. Patience: 3
Mean entropy: 3.7194
Correctly predicted masked tokens: 121 / 4755 (2.5447 %).


 49%|████▉     | 9500/19473 [32:36<34:16,  4.85it/s]

Step: 9500. Epoch: 1. MLM loss: 0.2717. Patience: 3
Mean entropy: 3.745
Correctly predicted masked tokens: 175 / 4795 (3.6496 %).


 51%|█████▏    | 10000/19473 [34:19<32:29,  4.86it/s]

Step: 10000. Epoch: 1. MLM loss: 0.2581. Patience: 3
Mean entropy: 3.8241
Correctly predicted masked tokens: 115 / 4830 (2.381 %).


 54%|█████▍    | 10500/19473 [36:02<30:46,  4.86it/s]

Step: 10500. Epoch: 1. MLM loss: 0.2458. Patience: 3
Mean entropy: 3.9131
Correctly predicted masked tokens: 142 / 4738 (2.997 %).


 56%|█████▋    | 11000/19473 [37:44<29:05,  4.85it/s]

Step: 11000. Epoch: 1. MLM loss: 0.2345. Patience: 3
Mean entropy: 4.029
Correctly predicted masked tokens: 172 / 4826 (3.564 %).


 59%|█████▉    | 11500/19473 [39:27<27:24,  4.85it/s]

Step: 11500. Epoch: 1. MLM loss: 0.2243. Patience: 3
Mean entropy: 4.0439
Correctly predicted masked tokens: 143 / 4823 (2.965 %).


 62%|██████▏   | 12000/19473 [41:09<25:35,  4.87it/s]

Step: 12000. Epoch: 1. MLM loss: 0.2148. Patience: 3
Mean entropy: 4.1731
Correctly predicted masked tokens: 117 / 4691 (2.4941 %).


 64%|██████▍   | 12500/19473 [42:52<23:52,  4.87it/s]

Step: 12500. Epoch: 1. MLM loss: 0.2062. Patience: 3
Mean entropy: 4.2026
Correctly predicted masked tokens: 157 / 4774 (3.2886 %).


 67%|██████▋   | 13000/19473 [44:34<22:09,  4.87it/s]

Step: 13000. Epoch: 1. MLM loss: 0.1983. Patience: 3
Mean entropy: 4.3419
Correctly predicted masked tokens: 155 / 4744 (3.2673 %).


 69%|██████▉   | 13500/19473 [46:17<20:28,  4.86it/s]

Step: 13500. Epoch: 1. MLM loss: 0.1908. Patience: 3
Mean entropy: 4.4118
Correctly predicted masked tokens: 141 / 4811 (2.9308 %).


 72%|███████▏  | 14000/19473 [47:59<18:46,  4.86it/s]

Step: 14000. Epoch: 1. MLM loss: 0.1841. Patience: 3
Mean entropy: 4.5218
Correctly predicted masked tokens: 111 / 4871 (2.2788 %).


 74%|███████▍  | 14500/19473 [49:42<17:03,  4.86it/s]

Step: 14500. Epoch: 1. MLM loss: 0.1776. Patience: 3
Mean entropy: 4.6857
Correctly predicted masked tokens: 135 / 4721 (2.8596 %).


 77%|███████▋  | 15000/19473 [51:24<15:21,  4.85it/s]

Step: 15000. Epoch: 1. MLM loss: 0.1717. Patience: 3
Mean entropy: 4.7635
Correctly predicted masked tokens: 114 / 4780 (2.3849 %).


 80%|███████▉  | 15500/19473 [53:08<13:43,  4.82it/s]

Step: 15500. Epoch: 1. MLM loss: 0.1661. Patience: 3
Mean entropy: 4.8268
Correctly predicted masked tokens: 138 / 4964 (2.78 %).


 82%|████████▏ | 16000/19473 [54:51<12:00,  4.82it/s]

Step: 16000. Epoch: 1. MLM loss: 0.1610. Patience: 3
Mean entropy: 4.8371
Correctly predicted masked tokens: 164 / 4789 (3.4245 %).


 85%|████████▍ | 16500/19473 [56:34<10:17,  4.82it/s]

Step: 16500. Epoch: 1. MLM loss: 0.1560. Patience: 3
Mean entropy: 4.9073
Correctly predicted masked tokens: 122 / 4822 (2.5301 %).


 87%|████████▋ | 17000/19473 [58:17<08:32,  4.83it/s]

Step: 17000. Epoch: 1. MLM loss: 0.1515. Patience: 3
Mean entropy: 4.9364
Correctly predicted masked tokens: 116 / 4926 (2.3549 %).


 90%|████████▉ | 17500/19473 [1:00:00<06:47,  4.84it/s]

Step: 17500. Epoch: 1. MLM loss: 0.1471. Patience: 3
Mean entropy: 4.9816
Correctly predicted masked tokens: 142 / 4848 (2.929 %).


 92%|█████████▏| 18000/19473 [1:01:43<05:03,  4.85it/s]

Step: 18000. Epoch: 1. MLM loss: 0.1430. Patience: 3
Mean entropy: 5.0167
Correctly predicted masked tokens: 134 / 4762 (2.8139 %).


 95%|█████████▌| 18500/19473 [1:03:26<03:20,  4.85it/s]

Step: 18500. Epoch: 1. MLM loss: 0.1391. Patience: 3
Mean entropy: 4.9511
Correctly predicted masked tokens: 205 / 4736 (4.3285 %).


 98%|█████████▊| 19000/19473 [1:05:09<01:37,  4.84it/s]

Step: 19000. Epoch: 1. MLM loss: 0.1354. Patience: 3
Mean entropy: 5.014
Correctly predicted masked tokens: 145 / 4817 (3.0102 %).


100%|██████████| 19473/19473 [1:06:47<00:00,  4.86it/s]
  0%|          | 27/19473 [00:05<1:06:52,  4.85it/s]

Step: 19500. Epoch: 2. MLM loss: 0.1320. Patience: 3
Mean entropy: 5.0309
Correctly predicted masked tokens: 131 / 4651 (2.8166 %).


  3%|▎         | 527/19473 [01:48<1:05:09,  4.85it/s]

Step: 20000. Epoch: 2. MLM loss: 0.1286. Patience: 3
Mean entropy: 5.0233
Correctly predicted masked tokens: 125 / 4808 (2.5998 %).


  3%|▎         | 527/19473 [01:48<1:05:13,  4.84it/s]


Performed enough steps.
Fine-tuning:
Number of parameters: 773_042


  0%|          | 2/9196 [00:31<46:19:39, 18.14s/it]

P: [0.348, 0.02801]


 27%|██▋       | 2502/9196 [09:30<11:48:10,  6.35s/it]

P: [0.91342, 0.7974]


 54%|█████▍    | 5001/9196 [18:28<10:28:26,  8.99s/it]

P: [0.93815, 0.82926]


100%|██████████| 9196/9196 [32:42<00:00,  4.69it/s]   


1/10 T loss: 2.03367. V loss: 1.45920. F1: [0.93885, 0.85342]. P: [0.93636, 0.86178] Patience: 3


  0%|          | 2/9196 [00:29<44:39:10, 17.48s/it]

P: [0.93849, 0.86398]


 27%|██▋       | 2502/9196 [09:26<11:46:11,  6.33s/it]

P: [0.94704, 0.87765]


 54%|█████▍    | 5001/9196 [18:22<10:25:50,  8.95s/it]

P: [0.94981, 0.88286]


100%|██████████| 9196/9196 [32:33<00:00,  4.71it/s]   


2/10 T loss: 1.20462. V loss: 1.13109. F1: [0.95273, 0.88658]. P: [0.95174, 0.88691] Patience: 3


  0%|          | 2/9196 [00:29<44:35:34, 17.46s/it]

P: [0.94957, 0.88635]


 27%|██▋       | 2502/9196 [09:30<11:51:41,  6.38s/it]

P: [0.95439, 0.89453]


 54%|█████▍    | 5001/9196 [18:25<10:25:18,  8.94s/it]

P: [0.95805, 0.89627]


100%|██████████| 9196/9196 [33:05<00:00,  4.63it/s]   


3/10 T loss: 0.97474. V loss: 1.03495. F1: [0.95888, 0.89874]. P: [0.95798, 0.90307] Patience: 3


  0%|          | 2/9196 [00:30<45:41:51, 17.89s/it]

P: [0.9548, 0.9002]


 27%|██▋       | 2500/9196 [09:31<23:06,  4.83it/s]

Halting evaluation after 61056 data points.


 27%|██▋       | 2502/9196 [10:02<12:18:32,  6.62s/it]

P: [0.96177, 0.90755]


 54%|█████▍    | 5000/9196 [19:11<16:14,  4.31it/s]   

Halting evaluation after 53504 data points.


 54%|█████▍    | 5001/9196 [19:42<10:54:32,  9.36s/it]

P: [0.9697, 0.90802]


100%|██████████| 9196/9196 [35:01<00:00,  4.38it/s]   


4/10 T loss: 0.84054. V loss: 0.89632. F1: [0.96534, 0.91211]. P: [0.96508, 0.91451] Patience: 3


  0%|          | 3/9196 [00:30<24:32:20,  9.61s/it]

P: [0.96495, 0.91195]


 27%|██▋       | 2502/9196 [09:44<11:59:09,  6.45s/it]

P: [0.96592, 0.90785]


 29%|██▉       | 2688/9196 [10:22<25:07,  4.32it/s]   


KeyboardInterrupt: 

In [5]:
torch.save(classifier.model.state_dict(), "trained_models/mamba_v1/model.pt2")

# Hierarchical

In [None]:
BERTAX_TRAIN = BERTAX_DIRECTORY + "train/"
BERTAX_VALIDATION = BERTAX_DIRECTORY + "validation/"
BERTAX_TEST = BERTAX_DIRECTORY + "test/"

from time import time
import json
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch import nn

from stelaro import models
from stelaro.models import autoencoder, feedforward, transformer

LENGTH = 1500
BATCH_SIZE = 64

with open(BERTAX_DIRECTORY + "statistics/map.json", "r") as f:
    mapping = json.load(f)


train_data = DataLoader(
    models.SyntheticMultiLevelTetramerDataset(
        BERTAX_TRAIN,
        mapping,
        (),
        1,
        balance = True,
        other_factor = 0
    ),
    batch_size=BATCH_SIZE,
    shuffle=True
)
validation_data = DataLoader(
    models.SyntheticMultiLevelTetramerDataset(
        BERTAX_VALIDATION,
        mapping,
        (),
        1,
        balance = False,
        other_factor = 0
    ),
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_data = DataLoader(
    models.SyntheticMultiLevelTetramerDataset(
        BERTAX_TEST,
        mapping,
        (),
        1,
        balance = False,
        other_factor = 0
    ),
    batch_size=BATCH_SIZE,
    shuffle=True
)
mapping = train_data.dataset.mapping


def benchmark(
        classifier: models.Classifier,
        name: str,
        max_epochs: int = 20,
        learning_rate: float = 0.001
    ):
    parameters = classifier.get_parameters()
    if parameters:
        optimizer = Adam(classifier.get_parameters(), lr=learning_rate)
        total_params = sum(param.numel() for param in parameters)
        print(f"Number of parameters: {total_params:_}")
    else:
        optimizer = None
    a = time()
    losses, f1, validation_losses = classifier.train(
        train_data,
        validation_data,
        optimizer,
        max_n_epochs=max_epochs,
        patience=3,
        loss_function=nn.CrossEntropyLoss(),
        loss_function_type="supervised",
        evaluation_interval=5000,
    )
    b = time()
    print(f"Training took {(b - a):.3f} s.")
    if losses:
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
        x = list(range(len(losses)))
        ax[0].plot(x, losses, label="Training")
        ax[0].plot(x, validation_losses, label="Validation")
        ax[0].set(xlabel='Epochs', ylabel='Loss')
        ax[0].set_title("Normalized Loss Against Epochs")
        ax[0].legend()
        ax[1].set(xlabel='Epochs', ylabel="f1")
        ax[1].set_title("F1 Score")
        r = 0
        for f in f1:
            ax[1].plot(x, f, label=f'Rank {r}')
            r += 1
        ax[1].legend()
        fig.suptitle(f"Classification Training for {name}")
        plt.show()
    result = models.evaluate(classifier, test_data, "cuda", mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"Test F1 score: {rounded_result}")

    result = models.evaluate_precision(classifier, test_data, "cuda", mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"Test precision score: {rounded_result}")

    return classifier