# In [None]:
from glob import glob
from sklearn import svm
import pandas as pd
import numpy as np

from tqdm import tqdm

import gc

import torch

from initialiser import build_dataset, build_model
from config import config
import polars as pl

from tqdm import tqdm

from util.persist import load_model

from torch.utils.data import DataLoader

from collections import defaultdict

import torch

from glob import glob

import torch.nn.functional as F

# Glob pattern matching every checkpoint to evaluate. The directory directly
# above each .ckpt file is used as the seed label in the results.
models_path = "/data/pfn/models/seeds_paper_1/*/*.ckpt"


def embed(model, tensor):
    """Embed a batch of series with the model's time-series encoder.

    For every sample a time grid of ``tensor.shape[-2]`` evenly spaced values
    from -3 to 1 is built and handed to ``model.TS_encoder`` together with the
    series values. The first element returned by the encoder is then passed
    through ``model.proj`` and ``model.avg_pool``.

    Args:
        model: Module exposing ``TS_encoder``, ``proj`` and ``avg_pool``.
        tensor: Series values as a 3-D tensor; the callers in this file pass
            shape ``(batch, length, 1)``.

    Returns:
        The output of ``model.avg_pool(model.proj(...))``.
    """
    # Run on whichever device holds the model weights instead of assuming
    # CUDA; a parameter-free model falls back to the input's device.
    device = next((p.device for p in model.parameters()), tensor.device)

    n_series, n_steps = tensor.shape[0], tensor.shape[-2]

    # Built on the CPU and moved afterwards so the grid values do not depend
    # on the target device's linspace kernel.
    time_grid = (
        torch.linspace(-3, 1, n_steps)
        .view(1, 1, n_steps, 1)
        .repeat(n_series, 1, 1, 1)
        .to(device)
    )
    # Insert a singleton axis at dim 1 to match the time grid's layout.
    values = tensor.unsqueeze(1).float().to(device)

    embedding_context, _ = model.TS_encoder(time_grid, values)
    return model.avg_pool(model.proj(embedding_context))

# ---------------------------------------------------------------------------
# SVM evaluation of the embeddings: for every checkpoint matched by
# `models_path`, embed the TRAIN/TEST split of each dataset under `base_path`,
# fit an SVM on the train embeddings and append the test accuracy (plus one
# "Overall" line per checkpoint) to `results_path`.
# ---------------------------------------------------------------------------
device = "cuda"
base_path = "/data/pfn/test_datasets/ucr/UCRArchive_2018/"
results_path = "classification_results.csv"


def _series_and_labels(df, device):
    """Split *df* into a series tensor on *device* and its label column.

    Column 0 holds the labels and every remaining column is one time step.
    The series come back as a float tensor of shape (rows, length, 1) with
    NaNs replaced by 0; the labels come back as a NumPy array.
    """
    values = df.to_numpy()
    series = (
        torch.from_numpy(values[:, 1:])
        .to(device)
        .float()
        .unsqueeze(-1)
        .nan_to_num(0)
    )
    return series, values[:, 0]


def _embed_in_chunks(model, series, chunk_size):
    """Run `embed` over *series* in chunks along dim 0; return a CPU NumPy array."""
    with torch.no_grad():
        embedded = [
            embed(model, chunk)
            for chunk in torch.split(series, chunk_size, dim=0)
        ]
        return torch.cat(embedded, 0).cpu().numpy()


def _append_result(line):
    """Append one record to the results file and close the handle again."""
    with open(results_path, "a") as results_file:
        # print() adds its own newline after the trailing "\n" of *line*, so
        # records are separated by a blank line.
        print(line, file=results_file)


# The dataset list does not depend on the checkpoint, so resolve it once.
# Sorting makes the evaluation order deterministic (glob order is arbitrary).
train_paths = sorted(glob(base_path + "/*/*_TRAIN.tsv"))

for model_path in glob(models_path):

    # Checkpoints match <...>/<seed>/<name>.ckpt, so the parent directory
    # identifies the seed.
    seed = model_path.split("/")[-2]

    model_arch = build_model(
        width=8,
        config=config,
        n_outputs=100,
        use_mup_parametrization=True,
        load_base_shapes=True,
        build_base_shapes=True,
    )

    model = load_model(model_path, model_arch, device).to(device)
    model.eval()

    # Per-dataset predictions/targets; concatenated once after the loop.
    pred_chunks = []
    target_chunks = []

    for train_df_name in tqdm(train_paths):
        # Derive the TEST file from the TRAIN file so the two can never be
        # paired with the wrong dataset.
        dataset_prefix = train_df_name[: -len("_TRAIN.tsv")]
        test_df_name = dataset_prefix + "_TEST.tsv"

        print(dataset_prefix)

        # Release memory left over from the previous dataset.
        gc.collect()
        torch.cuda.empty_cache()

        train_df = pd.read_csv(train_df_name, sep="\t", header=None)
        test_df = pd.read_csv(test_df_name, sep="\t", header=None)

        # Re-index the class labels as 0..n_classes-1 based on the train split.
        label_to_index = {
            label: i for i, label in enumerate(sorted(train_df[0].unique()))
        }
        train_df[0] = train_df[0].map(label_to_index)
        test_df[0] = test_df[0].map(label_to_index)

        train_tensor, train_y = _series_and_labels(train_df, device)
        sequence_length = train_tensor.shape[1]

        # Shrink the chunk for long series (at most 128 rows, fewer once the
        # length exceeds 786) but never below one row, which torch.split
        # would reject.
        chunk_size = max(1, min(128, int(128 * (786 / sequence_length))))

        print(train_tensor.shape, chunk_size)

        train_X = _embed_in_chunks(model, train_tensor, chunk_size)

        # Built only after the train embeddings exist, to keep peak GPU
        # memory down.
        test_tensor, test_y = _series_and_labels(test_df, device)
        test_X = _embed_in_chunks(model, test_tensor, chunk_size)

        clf = svm.SVC()
        clf.fit(train_X, train_y)

        pred = clf.predict(test_X)

        acc = (pred == test_y).mean()

        _append_result(
            f"{test_df_name.split('/')[-1].split('.')[0]},{acc},{seed}\n"
        )

        pred_chunks.append(pred)
        target_chunks.append(test_y)

    if not pred_chunks:
        # Nothing was evaluated, so there is no overall accuracy to report.
        all_preds = None
        all_targets = None
        print(f"No datasets found under {base_path}; skipping overall accuracy")
        continue

    all_preds = np.concatenate(pred_chunks)
    all_targets = np.concatenate(target_chunks)

    all_acc = (all_preds == all_targets).mean()

    _append_result(f"Overall,{all_acc},{seed}\n")


