In [1]:
import os
import json
from stelaro.data import format, ncbi
import numpy as np
import random

# Root of all locally stored data; every other path below is derived from it.
DATA_DIRECTORY = "../data/"
SUMMARY_DIRECTORY = f"{DATA_DIRECTORY}ncbi_genome_summaries/"
NCBI_TAXONOMY_DIRECTORY = f"{DATA_DIRECTORY}ncbi_taxonomy/"

# BERTAX data. The literal spelling "seperate" below presumably matches the
# on-disk directory name, so it must not be "corrected".
BERTAX_DIRECTORY = f"{DATA_DIRECTORY}bertax/final/"
BERTAX_DATASET_DIRECTORY = BERTAX_DIRECTORY + (
    "final_model_data_seperate_fasta_per_superkingdom/"
    "data/fass2/projects/fk_read_classification/dna_sequences/fragments/"
    "genomic_fragments_80_big/"
)
# One FASTA file per superkingdom.
BERTAX_DOMAINS = tuple(
    f"{domain}_db.fa" for domain in ("Archaea", "Bacteria", "Eukaryota", "Viruses")
)
BERTAX_STATISTIC_DIRECTORY = f"{BERTAX_DIRECTORY}statistics/"
SEQUENCE_LENGTH = 1_500  # presumably nucleotides per fragment — confirm
N_MINIMUM_READS_PER_TAXON = 10_000


def mkdir(path: str) -> None:
    """Create ``path`` (and any missing parent directories) if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` replaces a separate ``os.path.exists``
    check, which could race with another process creating the directory in
    between the check and the creation.

    Args:
        path: Directory to create. Calling this repeatedly is harmless.

    Raises:
        FileExistsError: If ``path`` already exists but is not a directory
            (previously this case was silently ignored).
    """
    os.makedirs(path, exist_ok=True)


# Ensure the BERTAX statistics directory exists (the label map is later read
# from its "map.json").
mkdir(path=BERTAX_STATISTIC_DIRECTORY)

# One-Shot

In [7]:
# Pre-split BERTAX directories, one sub-directory per split.
BERTAX_TRAIN = f"{BERTAX_DIRECTORY}train/"
BERTAX_VALIDATION = f"{BERTAX_DIRECTORY}validation/"
BERTAX_TEST = f"{BERTAX_DIRECTORY}test/"

import numpy as np
from torch.utils.data import DataLoader
import json
from torch.optim import Adam
import matplotlib.pyplot as plt
from time import time
from torch import nn

from stelaro.data import format
from stelaro import models

# Fragment length (presumably nucleotides — confirm) and mini-batch size for
# the one-shot experiments.
LENGTH, BATCH_SIZE = 1_500, 128


def _tetramer_loader(directory: str) -> DataLoader:
    """Build a shuffled DataLoader over the SyntheticTetramerDataset in ``directory``."""
    dataset = models.SyntheticTetramerDataset(directory)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


# NOTE(review): validation and test are shuffled as well; consider
# shuffle=False for the evaluation splits so their batch order is deterministic.
train_data = _tetramer_loader(BERTAX_TRAIN)
validation_data = _tetramer_loader(BERTAX_VALIDATION)
test_data = _tetramer_loader(BERTAX_TEST)

# Mapping consumed by models.Classifier / models.evaluate. JSON is UTF-8 by
# specification, so decode it explicitly instead of relying on the platform's
# locale-dependent default encoding.
with open(BERTAX_DIRECTORY + "statistics/map.json", "r", encoding="utf-8") as f:
    mapping = json.load(f)


def benchmark(
        classifier: models.BaseClassifier,
        name: str,
        max_epochs: int = 20,
        learning_rate: float = 0.001,
        device: str = "cuda",
    ):
    """Train ``classifier``, plot its learning curves and print test metrics.

    Uses the module-level ``train_data``, ``validation_data``, ``test_data``
    and ``mapping`` defined in this cell.

    Args:
        classifier: Wrapper exposing ``get_parameters()`` and ``train(...)``.
            When ``get_parameters()`` returns a falsy value, no optimizer is
            built and ``None`` is passed to ``train``.
        name: Model name used in the figure title.
        max_epochs: Maximum number of training epochs (patience is fixed at 3).
        learning_rate: Adam learning rate. The default is the previously
            hard-coded 0.001; the parameter matches the hierarchical
            ``benchmark`` signature.
        device: Device string passed to ``models.evaluate`` and
            ``models.evaluate_precision``.

    Returns:
        The same ``classifier`` object, after training.
    """
    parameters = classifier.get_parameters()
    if parameters:
        # Materialize once so the optimizer and the parameter count share a
        # single get_parameters() call, even if it returns a one-shot generator.
        parameters = list(parameters)
        optimizer = Adam(parameters, lr=learning_rate)
        total_params = sum(param.numel() for param in parameters)
        print(f"Number of parameters: {total_params:_}")
    else:
        optimizer = None

    a = time()
    losses, f1, validation_losses = classifier.train(
        train_data,
        validation_data,
        optimizer,
        max_n_epochs=max_epochs,
        patience=3,
        loss_function=nn.CrossEntropyLoss(),
        loss_function_type="supervised",
        evaluation_interval=5000,
    )
    b = time()
    print(f"Training took {(b - a):.3f} s.")

    if losses:
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
        x = list(range(len(losses)))
        ax[0].plot(x, losses, label="Training")
        ax[0].plot(x, validation_losses, label="Validation")
        ax[0].set(xlabel='Epochs', ylabel='Loss')
        ax[0].set_title("Normalized Loss Against Epochs")
        ax[0].legend()
        ax[1].set(xlabel='Epochs', ylabel="f1")
        ax[1].set_title("F1 Score")
        # One curve per entry of ``f1``, labelled by its index.
        for rank, rank_scores in enumerate(f1):
            ax[1].plot(x, rank_scores, label=f'Rank {rank}')
        ax[1].legend()
        fig.suptitle(f"Classification Training for {name}")
        plt.show()

    # ``:.5`` keeps 5 significant digits (not 5 decimal places).
    result = models.evaluate(classifier, test_data, device, mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"F1: {rounded_result}")
    result = models.evaluate_precision(classifier, test_data, device, mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"Precision: {rounded_result}")
    return classifier

In [None]:
from torch.optim import Adam
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
from mamba_ssm import Mamba
from torch import nn


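# Hyperparameter sweep. Columns are presumably vocab_size, d_model, n_layers,
# d_state, d_conv, expand, dropout, followed by the batch size.
# 256, 128, 2, 16, 8, 2, 0.1, batch 128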
# F1: [0.94682, 0.88288]
# Precision: [0.94012, 0.88521]

# 256, 64, 2, 16, 8, 2, 0.1, batch 128
# F1: [0.92456, 0.8333]
# Precision: [0.93198, 0.84133]

# 256, 128, 2, 16, 8, 2, 0.1, batch 64
# F1: [0.94531, 0.88431]
# Precision: [0.94236, 0.88785]

# 256, 128, 2, 16, 8, 3, 0.1, batch 128
# F1: [0.95005, 0.88783]
# Precision: [0.94598, 0.89011]

# 256, 256, 2, 16, 8, 2, 0.1, batch 128 (no results recorded)

class MambaSequenceClassifier(nn.Module):
    """Classify token sequences with a stack of Mamba blocks and a linear head.

    Pipeline: ``nn.Embedding`` -> ``n_layers`` ``Mamba`` blocks -> ``LayerNorm``
    -> pooling over the sequence dimension -> dropout -> ``nn.Linear``.

    NOTE(review): the Mamba blocks are chained without residual connections or
    per-block normalization; verify whether mamba_ssm's residual ``Block``
    wrapper was intended.

    Args:
        N: Sequence length supplied by the caller; not used by any layer.
        num_classes: Number of output logits.
        vocab_size: Number of distinct token ids the embedding accepts
            (256 = 4**4, presumably one id per tetramer — confirm).
        d_model: Embedding / hidden width.
        n_layers: Number of stacked ``Mamba`` blocks.
        d_state, d_conv, expand: Forwarded unchanged to every ``Mamba`` block.
        pooling: Reduction over the sequence dimension: ``"mean"`` (default),
            ``"max"`` (element-wise maximum) or ``"last"`` (final position).
        dropout: Dropout probability on the pooled features; ``0`` disables it.

    Raises:
        ValueError: If ``pooling`` is not one of ``POOLING_MODES``.
    """

    POOLING_MODES = ("mean", "max", "last")

    def __init__(
        self,
        N: int,
        num_classes: int,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 2,
        d_state: int = 16,
        d_conv: int = 8,
        expand: int = 3,
        pooling = "mean",
        dropout: float = 0.1,
    ):
        super().__init__()
        # Validate before building any layers so a typo fails fast instead of
        # silently falling back to mean pooling.
        if pooling not in self.POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}"
            )
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.classifier = nn.Linear(d_model, num_classes)
        self.pooling = pooling

    @staticmethod
    def _pool(h: torch.Tensor, mode: str) -> torch.Tensor:
        """Reduce ``h`` of shape [B, L, d_model] over dim 1 to [B, d_model]."""
        if mode == "mean":
            return h.mean(dim=1)
        if mode == "max":
            return h.amax(dim=1)
        if mode == "last":
            return h[:, -1]
        raise ValueError(f"unknown pooling mode {mode!r}")

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """Return class logits of shape [B, num_classes] for token ids ``x``."""
        h = self.embedding(x).to(dtype=torch.get_default_dtype())
        for block in self.layers:
            h = block(h)   # each Mamba block returns [B, L, d_model]
        h = self.norm(h)
        pooled = self._pool(h, self.pooling)  # [B, d_model]
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)  # [B, num_classes]
        return logits


# LENGTH // 4 is presumably the number of tetramer tokens per fragment —
# confirm against format.to_tetramers.
classifier = models.Classifier(LENGTH // 4, mapping, "cuda", MambaSequenceClassifier, format.to_tetramers)
model = benchmark(
    classifier,
    "MAMBA",
    max_epochs=10,
)
matrix = models.confusion_matrix(model, test_data, "cuda", mapping)

# NOTE(review): which axis is "true" vs "predicted" depends on
# models.confusion_matrix; confirm before adding axis labels.
fig, ax = plt.subplots()
image = ax.matshow(matrix)
fig.colorbar(image, ax=ax)
ax.set_title("MAMBA confusion matrix (test set)")
plt.show()

# Scope the verbose print options to this single print instead of changing
# NumPy's global defaults for every later cell.
with np.printoptions(precision=3, suppress=True, linewidth=500, threshold=np.inf):
    print(matrix)

Number of parameters: 391_857


  0%|          | 0/9196 [00:00<?, ?it/s]

 54%|█████▍    | 5001/9196 [10:22<6:48:27,  5.84s/it]

P: [0.85659, 0.71649]


100%|██████████| 9196/9196 [19:26<00:00,  7.89it/s]  


1/10 T loss: 3.36632. V loss: 2.25728. F1: [0.88917, 0.77206]. P: [0.88976, 0.78047] Patience: 3


 54%|█████▍    | 5002/9196 [11:19<4:42:11,  4.04s/it]

P: [0.91593, 0.81432]


100%|██████████| 9196/9196 [20:34<00:00,  7.45it/s]  


2/10 T loss: 1.89184. V loss: 1.68167. F1: [0.92152, 0.82997]. P: [0.91607, 0.83817] Patience: 3


 54%|█████▍    | 5002/9196 [11:33<4:48:19,  4.12s/it]

P: [0.93227, 0.84978]


100%|██████████| 9196/9196 [20:44<00:00,  7.39it/s]  


3/10 T loss: 1.48514. V loss: 1.51054. F1: [0.93244, 0.846]. P: [0.92662, 0.85752] Patience: 3


 54%|█████▍    | 5003/9196 [11:18<3:42:51,  3.19s/it]

P: [0.93364, 0.8667]


100%|██████████| 9196/9196 [20:23<00:00,  7.52it/s]  


4/10 T loss: 1.29058. V loss: 1.32135. F1: [0.93738, 0.86708]. P: [0.93545, 0.87031] Patience: 3


 54%|█████▍    | 5001/9196 [11:00<6:45:32,  5.80s/it]

P: [0.94699, 0.8788]


100%|██████████| 9196/9196 [20:09<00:00,  7.60it/s]  


5/10 T loss: 1.17214. V loss: 1.18121. F1: [0.94918, 0.88245]. P: [0.95223, 0.88366] Patience: 3


 54%|█████▍    | 5002/9196 [11:21<4:47:18,  4.11s/it]

P: [0.94922, 0.88268]


100%|██████████| 9196/9196 [20:09<00:00,  7.60it/s]  


6/10 T loss: 1.08505. V loss: 1.23265. F1: [0.94069, 0.87787]. P: [0.93593, 0.88513] Patience: 2


 49%|████▉     | 4548/9196 [09:18<09:31,  8.14it/s]


KeyboardInterrupt: 

In [9]:
# Re-evaluate the classifier on the test split after the interrupted training run.
for label, metric in (
    ("F1", models.evaluate),
    ("Precision", models.evaluate_precision),
):
    result = metric(classifier, test_data, "cuda", mapping)
    # ``:.5`` keeps 5 significant digits, not 5 decimal places.
    rounded_result = [float(f"{score:.5}") for score in result]
    print(f"{label}: {rounded_result}")

F1: [0.95005, 0.88783]
Precision: [0.94598, 0.89011]


# Hierarchical

In [None]:
# Split directories, redefined with the same values as in the one-shot section
# (presumably so this section does not depend on that cell having run).
BERTAX_TRAIN, BERTAX_VALIDATION, BERTAX_TEST = (
    f"{BERTAX_DIRECTORY}{split}/" for split in ("train", "validation", "test")
)

from time import time
import json
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch import nn

from stelaro import models
from stelaro.models import autoencoder, feedforward, transformer

# Fragment length and (smaller) mini-batch size for the hierarchical experiments.
LENGTH, BATCH_SIZE = 1_500, 64

# Initial mapping passed to the multi-level datasets below. JSON is UTF-8 by
# specification, so decode it explicitly rather than with the locale default.
with open(BERTAX_DIRECTORY + "statistics/map.json", "r", encoding="utf-8") as f:
    mapping = json.load(f)


def _multi_level_loader(directory: str, balance: bool) -> DataLoader:
    """Build a shuffled DataLoader over a SyntheticMultiLevelTetramerDataset.

    Reads the module-level ``mapping`` and ``BATCH_SIZE`` at call time.
    NOTE(review): the positional ``()`` and ``1`` arguments are forwarded
    unchanged; their meaning is defined by the dataset class — document there.
    """
    dataset = models.SyntheticMultiLevelTetramerDataset(
        directory,
        mapping,
        (),
        1,
        balance=balance,
        other_factor=0,
    )
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


# Only the training split is built with balance=True.
train_data = _multi_level_loader(BERTAX_TRAIN, balance=True)
validation_data = _multi_level_loader(BERTAX_VALIDATION, balance=False)
test_data = _multi_level_loader(BERTAX_TEST, balance=False)

# Later code (benchmark) uses the training dataset's own mapping.
# NOTE(review): the validation/test datasets were built from the JSON mapping,
# not this one — confirm the two are identical.
mapping = train_data.dataset.mapping


def benchmark(
        classifier: models.Classifier,
        name: str,
        max_epochs: int = 20,
        learning_rate: float = 0.001,
        device: str = "cuda",
    ):
    """Train a hierarchical ``classifier``, plot its curves and print test metrics.

    Uses the module-level ``train_data``, ``validation_data``, ``test_data``
    and ``mapping`` defined in this section.

    Args:
        classifier: Wrapper exposing ``get_parameters()`` and ``train(...)``.
            When ``get_parameters()`` returns a falsy value, no optimizer is
            built and ``None`` is passed to ``train``.
        name: Model name used in the figure title.
        max_epochs: Maximum number of training epochs (patience is fixed at 3).
        learning_rate: Adam learning rate.
        device: Device string passed to ``models.evaluate`` and
            ``models.evaluate_precision``.

    Returns:
        The same ``classifier`` object, after training.
    """
    # This section's import cell does not import Adam; importing it here
    # removes the hidden dependency on an earlier section having been run.
    from torch.optim import Adam

    parameters = classifier.get_parameters()
    if parameters:
        # Materialize once so the optimizer and the parameter count share a
        # single get_parameters() call, even if it returns a one-shot generator.
        parameters = list(parameters)
        optimizer = Adam(parameters, lr=learning_rate)
        total_params = sum(param.numel() for param in parameters)
        print(f"Number of parameters: {total_params:_}")
    else:
        optimizer = None

    a = time()
    losses, f1, validation_losses = classifier.train(
        train_data,
        validation_data,
        optimizer,
        max_n_epochs=max_epochs,
        patience=3,
        loss_function=nn.CrossEntropyLoss(),
        loss_function_type="supervised",
        evaluation_interval=5000,
    )
    b = time()
    print(f"Training took {(b - a):.3f} s.")

    if losses:
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
        x = list(range(len(losses)))
        ax[0].plot(x, losses, label="Training")
        ax[0].plot(x, validation_losses, label="Validation")
        ax[0].set(xlabel='Epochs', ylabel='Loss')
        ax[0].set_title("Normalized Loss Against Epochs")
        ax[0].legend()
        ax[1].set(xlabel='Epochs', ylabel="f1")
        ax[1].set_title("F1 Score")
        # One curve per entry of ``f1``, labelled by its index.
        for rank, rank_scores in enumerate(f1):
            ax[1].plot(x, rank_scores, label=f'Rank {rank}')
        ax[1].legend()
        fig.suptitle(f"Classification Training for {name}")
        plt.show()

    # ``:.5`` keeps 5 significant digits (not 5 decimal places).
    result = models.evaluate(classifier, test_data, device, mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"Test F1 score: {rounded_result}")

    result = models.evaluate_precision(classifier, test_data, device, mapping)
    rounded_result = [float(f"{r:.5}") for r in result]
    print(f"Test precision score: {rounded_result}")

    return classifier