In [1]:
import pandas as pd
import numpy as np
from sktime.utils import mlflow_sktime
from sktime.classification.kernel_based import RocketClassifier
from sklearn.metrics import accuracy_score
import pickle
from torch.utils.data import Dataset, DataLoader, random_split
import torch

In [2]:
# Input locations. Both live under one dataset directory, so the shared prefix
# is spelled out once and the two paths are derived from it.
DATA_ROOT = "data/hms-harmful-brain-activity-classification"

train_file = f"{DATA_ROOT}/cleaned_train.csv"  # metadata CSV (read with pd.read_csv)
train_dir = f"{DATA_ROOT}/train_eegs"  # directory holding one <eeg_id>.parquet per recording

In [10]:
class RocketDataset(Dataset):
    """Map-style dataset returning one centred, downsampled EEG window per ``eeg_id``.

    The metadata CSV is filtered to rows whose ``is_center`` flag is true and then
    de-duplicated on ``eeg_id`` (first occurrence wins), so each recording
    contributes exactly one sample.

    Parameters
    ----------
    train_file : str or path-like
        CSV with at least the columns ``eeg_id``, ``is_center`` and the six
        vote columns listed in ``VOTE_COLUMNS``.
    train_path : str
        Directory containing one ``<eeg_id>.parquet`` file per recording.
    window : int, default 2_000
        Number of consecutive rows taken from the middle of each recording
        before downsampling.
    step : int, default 2
        Stride applied inside the window (``step=2`` keeps every other row).

    Raises
    ------
    ValueError
        If ``window`` or ``step`` is not a positive integer.
    """

    # Vote columns, in the order they appear in the returned label list.
    VOTE_COLUMNS = (
        'seizure_vote',
        'lpd_vote',
        'gpd_vote',
        'lrda_vote',
        'grda_vote',
        'other_vote',
    )

    def __init__(self, train_file, train_path, window: int = 2_000, step: int = 2):
        if window <= 0 or step <= 0:
            raise ValueError("window and step must be positive integers")

        meta = pd.read_csv(train_file)
        # `== True` (rather than plain boolean indexing) also accepts a flag
        # column stored as 0/1 integers.
        meta = meta.loc[meta['is_center'] == True]  # noqa: E712
        meta = meta.drop_duplicates(subset='eeg_id', keep="first")

        self.df = meta
        self.dir = train_path
        self.len = len(self.df)
        self.window = window
        self.step = step

    def __len__(self):
        """Return the number of unique recordings kept after filtering."""
        return self.len

    def __getitem__(self, ind):
        """Return ``(eeg_frame, votes)`` for the recording at position ``ind``.

        ``eeg_frame`` is the centred window of the recording's parquet file,
        strided by ``step`` and re-indexed from 0. ``votes`` is a list with the
        six vote values in ``VOTE_COLUMNS`` order.
        """
        # Read the id from its own column: going through a whole-row Series
        # would upcast an integer id to float whenever the row is all-numeric
        # with a float column, producing a bogus "<id>.0.parquet" path.
        eeg_id = self.df['eeg_id'].iloc[ind]
        pq = pd.read_parquet(f"{self.dir}/{eeg_id}.parquet")

        # Clamp at 0 so a recording shorter than the window yields its start
        # onwards instead of an off-centre tail slice from a negative offset.
        middle = max((len(pq) - self.window) // 2, 0)
        pq = pq.iloc[middle:middle + self.window:self.step]
        # NOTE(review): reset_index() without drop=True keeps the old row labels
        # as an extra column (named "index" when the parquet index is unnamed),
        # so it travels with the signal columns. Kept as-is because a previously
        # fitted model may expect that column -- confirm before dropping it.
        pq = pq.reset_index()

        row = self.df.iloc[ind]  # materialise the metadata row once
        lbl = [row[col] for col in self.VOTE_COLUMNS]

        return pq, lbl
    
def collate_fn(batch):
    """Collate ``(eeg_frame, votes)`` samples into one panel frame and one label frame.

    Parameters
    ----------
    batch : list of (pandas.DataFrame, sequence)
        Each element pairs one sample's data frame with its six vote values.

    Returns
    -------
    x_batch : pandas.DataFrame
        All sample frames stacked row-wise under a two-level index named
        ``('instances', 'timepoints')``. ``instances`` is the sample's position
        within ``batch``; ``timepoints`` is that sample frame's own index.
        Missing values are replaced with 0.
    y_batch : pandas.DataFrame
        One row per sample with columns ``Seizure, LPD, GPD, LRDA, GRDA, Other``
        holding the vote values unchanged (no one-hot encoding is applied).
    """
    frames = [sample[0] for sample in batch]
    votes = [sample[1] for sample in batch]

    # Let concat build and name both index levels directly. This avoids the
    # reset/rename/set_index round trip, which depended on the auto-generated
    # "level_1" column name and failed when a sample frame had a named index.
    x_batch = pd.concat(
        frames,
        keys=list(range(len(frames))),
        names=['instances', 'timepoints'],
        axis=0,
    ).fillna(0)

    # A frame built from a list of rows already has a fresh 0..n-1 index.
    y_batch = pd.DataFrame(
        votes,
        columns=['Seizure', 'LPD', 'GPD', 'LRDA', 'GRDA', 'Other'],
    )

    return x_batch, y_batch

In [11]:
# Build the dataset over the metadata CSV and the directory of EEG parquet files.
rd = RocketDataset(train_file=train_file, train_path=train_dir)

In [12]:
BATCH_SIZE = 16

# 70/30 train/validation partition; the seeded generator makes the split
# reproducible across runs.
train_set, val_set = random_split(
    rd,
    [.7, .3],
    generator=torch.Generator().manual_seed(42),
)

# Both loaders share the same settings, so build them from one expression.
train_loader, val_loader = (
    DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    for subset in (train_set, val_set)
)

In [13]:
# Directory holding the saved model in MLflow format.
model_dir = "rocket-mini/"

# Load the saved sktime model from that directory.
rocket_model = mlflow_sktime.load_model(
    model_uri=model_dir,
)

In [14]:
# Smoke test: run the loaded model on a single validation batch.
pred, y_val = [], []
for x, y in val_loader:
    predictions = rocket_model.predict(x)
    # Reduce every row to the name of its largest column.
    # NOTE(review): assumes predict() returns a DataFrame here -- confirm.
    pred += list(predictions.idxmax(axis=1))
    y_val += list(y.idxmax(axis=1))
    break  # deliberately stop after the first batch