In [1]:
import pandas as pd
from tqdm import tqdm
from dataclasses import dataclass
from sys import stdout
from multiprocessing import Pool
import numpy as np
from time import time
from lz78 import spa_from_file, LZ78SPA, Sequence, LZ78Classifier, CharacterMap, classifier_from_files

In [2]:
path = "best_spas/minimal"
prefix = "virus_covid"
test_path = "GUE/virus/covid/test.csv"
n_classes = 9

In [3]:
stdout.flush()
spas = classifier_from_files([f"{path}/{prefix}_{i}.bin" for i in range(n_classes)])

In [4]:
test_data = pd.read_csv(test_path)
for i in range(len(test_data)):
    test_data.loc[i, "sequence"] =  "".join([x for x in test_data.loc[i, "sequence"] if x in "ACGT"])

In [5]:
spas.set_inference_config(
    ensemble_n=10,
    ensemble_type="entropy",
    backshift_ctx_len=20
)

In [6]:
def classify_singlethread(data: pd.DataFrame, spas: LZ78Classifier):
    labels = data["label"]
    classes = np.zeros(len(labels))
    for (i, seq) in enumerate(tqdm(data["sequence"])):
        classes[i] = spas.classify(Sequence(seq, charmap=CharacterMap("ACGT")))

    return (classes == labels).sum() / len(labels)

In [12]:
stdout.flush()
tic = time()
print(classify_singlethread(test_data, spas))
elapsed = time() - tic

100%|██████████| 9168/9168 [03:34<00:00, 42.69it/s]

0.7219677137870855





In [15]:
elapsed / len(test_data["sequence"]) / len(test_data["sequence"][0]) * 1e6

23.45038106284405