In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, iterating over os.walk('/kaggle/input') lists all files under the input directory (that listing loop is not included in this notebook)

import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install git+https://github.com/sderooij/seizure_data_processing.git@v0.0.1 -q

In [None]:
!pip install git+https://github.com/Roodster/ai2p-asd.git

In [None]:
!pip install nptyping

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score, roc_auc_score

from asd.dataset import get_dataloaders, OnlineSegmentsDataset, OfflineSegmentsDataset, OfflineSegmentsDataset4, DummyDataset
from asd.args import Args
from asd.writer import Writer
from asd.plots import EventPlots
from asd.results import Results, EventResults
from asd.models import VisionTransformer, SSLTransformer, ShallowAE, BetaVAE, SoftMaxClassifier, ShallowEncoder, BetaEncoder, CNNBiLSTM, DARLNet
from asd.learner import Learner, AELearner, SSLLearner, DARLNetLearner
from asd.models.losses import BetaVAELoss, ContrastiveLoss
from asd.labels import OHELabelTransformer
from asd.experiment import Experiment
from asd.event_scoring.annotation import Annotation
from asd.event_scoring.scoring import EventScoring

import torch as th
import torch.nn as nn
import warnings
import torch.optim as optim
warnings.filterwarnings('ignore')

In [None]:
if __name__ == "__main__":
    args = Args(file="/kaggle/input/config/default.yaml")
    args.patient_id = '24'
    args.val_patient_id = '01'
    # Load dataset

    # Replace with Kaggle training set link 
    train_dataset = OfflineSegmentsDataset("/kaggle/input/chb-mit-train-4s-ratio151/train_set_15_1/full_train", mode='train', patient_id=args.patient_id, val_patient_id=args.val_patient_id)
    # Replace with Kaggle test set link 
    test_dataset = OfflineSegmentsDataset("/kaggle/input/chb24-50-overlap/chb24_test_overlap", mode='test', patient_id=args.patient_id)
    # Replace with Kaggle validation set link 
    val_dataset = OfflineSegmentsDataset("/kaggle/input/chb01-50-overlap/chb01_test_overlap", mode='validation', patient_id=args.val_patient_id)

In [None]:
# Training configuration. This reuses the `args` created in the dataset cell instead of
# re-reading default.yaml, which would silently discard the patient_id / val_patient_id
# set there.
args.model_name = "darlnet"
args.root_dir = "/kaggle/working/"
args.n_epochs = 15
args.eval_interval = 1
args.batch_size = 128
args.learning_rate = 1e-4
args.eval_sample_rate = 0.5
args.device = th.device("cuda" if th.cuda.is_available() else "cpu")

# Instantiate dataloaders. Cap worker processes at the number of CPUs: extra workers
# only add overhead, and PyTorch's warning about it is hidden by the global
# warnings filter set in the imports cell.
NUM_WORKERS = min(8, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# Build the DARLNet model with its loss and optimiser, then wrap them in a learner
# and an experiment runner, both with event scoring enabled.
WEIGHT_DECAY = 5e-4  # Adam's weight_decay (L2 penalty)

model = DARLNet(args=args)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=args.learning_rate,
    weight_decay=WEIGHT_DECAY,
)

results = EventResults()
# label_transformer (e.g. OHELabelTransformer()) is left disabled here.
learner = DARLNetLearner(
    args=args,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    event_scoring=True,
)
experiment = Experiment(
    args=args,
    learner=learner,
    results=results,
    event_scoring=True,
)


In [None]:
# Train with the train loader and the validation loader. `run` returns a pair that
# this notebook treats as (decision threshold, model weights) — TODO confirm in asd.experiment.
run_output = experiment.run(train_loader, val_loader)
threshold, weights = run_output

In [None]:
# Save the weights returned by experiment.run under args.root_dir. The file name embeds
# args.n_epochs so changing the epoch count no longer produces a mislabelled checkpoint
# (with the current settings this is still /kaggle/working/chb24_15epochs.pth).
model_save_path = os.path.join(args.root_dir, f"chb24_{args.n_epochs}epochs.pth")
th.save(weights, model_save_path)

In [None]:
# Display the threshold returned by experiment.run.
print(str(threshold))

In [None]:
# Test the model on the held-out test loader.
# Use the threshold returned by experiment.run (presumably chosen on the validation set)
# rather than a hard-coded constant, so test scoring uses the same operating point.
# NOTE(review): `model` holds whatever weights it ended training with; if `weights` from
# experiment.run are the best-validation weights rather than the final ones, load them
# into `model` before evaluating — TODO confirm.
experiment = Experiment(args=args, event_scoring=True)
experiment.evaluate_predictions(model, test_loader, threshold=threshold)