In [1]:
# Later cells open data, checkpoint and config files via relative paths, so the
# working directory must be the project root.
import os

# Overridable via the PROJECT_ROOT environment variable so the notebook runs outside
# the original container; defaults to the previously hard-coded location.
PROJECT_ROOT = os.environ.get("PROJECT_ROOT", "/app/")
os.chdir(PROJECT_ROOT)

In [2]:
import pytorch_lightning
from pathlib import Path
from yaml import safe_load as load_yaml
from omegaconf import DictConfig
import torch
from hydra.utils import instantiate
import pandas as pd
from ptls.frames import PtlsDataModule
from torch.utils.data import DataLoader
from torcheval.metrics.functional import multiclass_f1_score
from scipy.optimize import minimize_scalar
from ptls.data_load.datasets.memory_dataset import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load import PaddedBatch
from ptls.data_load.utils import collate_feature_dict
from src.preprocessing.churn_preproc import preprocessing
from src.datamodules.autoencoder import MyColesDataset
from src.networks.decoders import LSTMCellDecoder
from src.networks.modules import VanillaAE

In [3]:
# Every artefact of the evaluated W&B run is located through this single id, so the
# checkpoint and the config it was trained with cannot silently drift apart.
RUN_ID = "3iqys3po"
CKPT_DIR = Path(f"transaction_data_generation/{RUN_ID}/checkpoints/")
CONFIG_PATH = Path(f"wandb/run-20230925_093441-{RUN_ID}/files/config.yaml")

# Sequence-length window (min, max) passed to SeqLenFilter for the evaluation set.
MIN_SEQ_LEN, MAX_SEQ_LEN = 20, 40

# Path.glob order is filesystem-dependent, so taking the first hit is arbitrary when
# several checkpoints exist; fail loudly instead of evaluating a random one.
ckpt_candidates = sorted(CKPT_DIR.glob("*.ckpt"))
if not ckpt_candidates:
    raise FileNotFoundError(f"No *.ckpt file found in {CKPT_DIR}")
if len(ckpt_candidates) > 1:
    raise ValueError(
        f"Ambiguous checkpoint in {CKPT_DIR}: {[p.name for p in ckpt_candidates]}; "
        "keep only the one to evaluate or set ckpt_path explicitly."
    )
ckpt_path = ckpt_candidates[0]

# The W&B config file wraps each top-level entry under a "value" key, hence the
# ["value"] lookups below. safe_load accepts the open stream directly.
with CONFIG_PATH.open("r", encoding="utf-8") as f:
    cfg = DictConfig(load_yaml(f))

dataset_cfg = cfg["dataset"]["value"]

encoder = instantiate(cfg["encoder"]["value"])
decoder = instantiate(cfg["decoder"]["value"])

mcc_column: str = dataset_cfg["mcc_column"]
amt_column: str = dataset_cfg["amt_column"]

dataset = MemoryMapDataset(
    preprocessing(dataset_cfg),
    [SeqLenFilter(MIN_SEQ_LEN, MAX_SEQ_LEN)],
)

# _recursive_=False keeps nested configs un-instantiated; the returned object is then
# called with the already-built encoder/decoder (presumably the module config is
# marked _partial_ — confirm in the run config).
module: VanillaAE = instantiate(cfg["module"]["value"], _recursive_=False)(
    encoder=encoder,
    decoder=decoder,
    amnt_col=amt_column,
    mcc_col=mcc_column,
)

# map_location="cpu" makes loading independent of the device the checkpoint was
# saved on; Lightning moves the module onto the accelerator when predicting.
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
module.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"])

<All keys matched successfully>

In [4]:
# One sequence per batch: collate_feature_dict then has nothing to pad, and each
# prediction batch corresponds to exactly one dataset row, in dataset order
# (no shuffling), so predictions can be lined up with the raw sequences.
predict_loader = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=collate_feature_dict,
)

trainer = pytorch_lightning.Trainer(accelerator="gpu", devices=1)
preds = trainer.predict(module, dataloaders=predict_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [5]:
# Flatten per-sequence outputs into one event-level tensor each. With the
# batch_size=1 predict loader every batch holds a single sequence, so [0] selects it.
# NOTE(review): assumes row[0] holds the batched MCC logits returned by the
# module's predict step — confirm against VanillaAE.
pred_mccs = torch.concat([row[0][0] for row in preds])

# Read targets through the configured column name (the same key the model receives
# as mcc_col) instead of a hard-coded "mcc_code".
orig_mccs = torch.concat([row[mcc_column] for row in dataset])

# The metric compares predictions and targets position by position, so they must align.
assert pred_mccs.shape[0] == orig_mccs.shape[0], (
    f"prediction/target length mismatch: {pred_mccs.shape[0]} vs {orig_mccs.shape[0]}"
)

In [9]:

# Logit column 0 is dropped and targets are shifted down by one. Presumably MCC index
# 0 is reserved for padding, so real codes 1..K map onto classes 0..K-1.
# The class count is derived from the logits rather than hard-coded.
n_mcc_classes = pred_mccs.shape[1] - 1
# A code of 0 would become target -1 after the shift and silently count as an error.
assert int(orig_mccs.min()) >= 1, "target MCC codes must start at 1"

# For single-label multiclass data, micro-averaged F1 equals plain accuracy.
multiclass_f1_score(
    pred_mccs[:, 1:].argmax(1),
    orig_mccs - 1,
    num_classes=n_mcc_classes,
    average="micro",
)

tensor(0.3594)