In [27]:
from lz78 import Sequence, LZ78SPA, CharacterMap
import pandas as pd
import numpy as np
from multiprocessing import Pool, Value, Lock
from sys import stdout
from tqdm import tqdm

In [108]:
# Directory (relative to the working directory) that holds the train/dev/test
# CSV files of the dataset being classified.
DATASET = "GUE/virus/covid"
# Number of full passes made over the training split.
EPOCHS = 1

In [109]:
# One CSV per split; note that the validation split is stored as dev.csv.
train_path = f"{DATASET}/train.csv"
val_path = f"{DATASET}/dev.csv"
test_path = f"{DATASET}/test.csv"

In [110]:
# Load the three splits. The rest of the notebook reads the "sequence" and
# "label" columns of these frames.
train_data = pd.read_csv(train_path)
validation_data = pd.read_csv(val_path)
test_data = pd.read_csv(test_path)

# Number of symbols in the sequence alphabet (A, C, G, T).
ALPHABET_SIZE = 4
unique_labels = sorted(train_data['label'].unique())
n = len(unique_labels)

# Labels are used directly as positions in the per-class model list and are
# compared against argmin row indices at evaluation time, so they must be
# exactly the integers 0..n-1. Fail early instead of mis-indexing later.
assert list(unique_labels) == list(range(n)), (
    f"labels must be the contiguous integers 0..{n - 1}, got {unique_labels}"
)

In [116]:
def _keep_acgt(sequences: pd.Series) -> pd.Series:
    """Return `sequences` with every character other than A, C, G, T removed."""
    return sequences.str.replace(r"[^ACGT]", "", regex=True)


# Strip everything outside the ACGT alphabet from all three splits. A single
# vectorised string operation per frame replaces the per-row .loc
# read/modify/write loops, which were slow and triplicated.
for frame in (train_data, validation_data, test_data):
    frame["sequence"] = _keep_acgt(frame["sequence"])

In [117]:
def train_spa_oneIter(data: pd.DataFrame, spas: list[LZ78SPA]):
    """Run one training pass, feeding every sequence to the model of its class.

    Args:
        data: Frame with a "label" column and a "sequence" column.
        spas: Per-class models; `spas[label]` is trained on the sequences
            carrying that label, so labels must be valid list indices.
    """
    # The character map is identical for every sequence, so build it once
    # instead of once per row.
    charmap = CharacterMap("ACGT")
    for label, sequences in data.groupby("label")["sequence"]:
        spa = spas[label]
        for seq in sequences:
            # The model state is reset before each sequence so that every
            # sequence is trained as an independent block.
            spa.reset_state()
            spa.train_on_block(Sequence(seq, charmap=charmap))

In [118]:
# One model per class label. The alphabet size comes from the ALPHABET_SIZE
# constant defined with the data loading instead of a repeated literal.
spas = [
    LZ78SPA(alphabet_size=ALPHABET_SIZE, gamma=1, compute_training_loss=False)
    for _ in range(n)
]

In [135]:
# Flush pending stdout first so earlier output does not interleave with the
# progress bar, then run the configured number of training passes.
stdout.flush()
epoch_bar = tqdm(range(EPOCHS))
for _epoch in epoch_bar:
    train_spa_oneIter(train_data, spas)

100%|██████████| 1/1 [00:12<00:00, 12.69s/it]


In [136]:
# Inference settings shared by all per-class models, collected in one place
# and applied to each model in turn.
inference_config = dict(
    gamma=1,
    lb=1e-5,
    ensemble_type="entropy",
    ensemble_n=10,
    backshift_parsing=True,
    backshift_ctx_len=20,
    backshift_break_at_phrase=True,
)
for spa in spas:
    spa.set_inference_config(**inference_config)

In [137]:
def classify(data: pd.DataFrame, spas: list[LZ78SPA], n_threads=64):
    """Classify each sequence by its lowest-loss model and return the accuracy.

    Every model scores every sequence; the predicted class of a sequence is
    the position in `spas` of the model with the smallest average log loss.

    Args:
        data: Frame with a "sequence" column and a "label" column.
        spas: Per-class models; a model's position in the list is the class
            it predicts.
        n_threads: Passed as `num_threads` to `compute_test_loss_parallel`.

    Returns:
        Fraction of rows whose predicted class equals the row's label.
    """
    labels = data["label"].to_numpy()
    # The character map is the same for every sequence, so build it once.
    charmap = CharacterMap("ACGT")
    sequences = [Sequence(seq, charmap=charmap) for seq in data["sequence"]]

    # log_losses[k, j] is the average log loss of model k on sequence j.
    log_losses = np.zeros((len(spas), len(sequences)))
    for row, spa in enumerate(spas):
        results = spa.compute_test_loss_parallel(sequences, num_threads=n_threads)
        log_losses[row, :] = [res["avg_log_loss"] for res in results]

    predicted = np.argmin(log_losses, axis=0)
    return (predicted == labels).sum() / len(labels)

In [138]:
# Accuracy of the per-class models on the test split (displayed as cell output).
classify(test_data, spas)

np.float64(0.5093804537521816)