In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import models
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal, stats
import torch

from datasets.masked_tf_dataset import *
from omegaconf import OmegaConf
import umap
import plotly.express as px
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def build_model(cfg, *args, **kwargs):
    ckpt_path = cfg.upstream_ckpt
    init_state = torch.load(ckpt_path)
    upstream_cfg = init_state["model_cfg"]
    # upstream_cfg['upstream_ckpt'] = "/mnt/AI_Magic/projects/BrainBERT/pretrained_weights/stft_large_pretrained.pth" 
    upstream = models.build_model(upstream_cfg, *args, **kwargs)
    return upstream, upstream_cfg

def load_model_weights(model, states, multi_gpu):
    if multi_gpu:
        model.module.load_weights(states)
    else:
        model.load_weights(states)

In [3]:
from omegaconf import OmegaConf

device = "cuda:1"
ckpt_path = "/mnt/AI_Magic/projects/BrainBERT/pretrained_weights/stft_large_pretrained.pth"
cfg = OmegaConf.create({"upstream_ckpt": ckpt_path})
model, upstream_cfg = build_model(cfg)
model.to(device)
init_state = torch.load(ckpt_path)
load_model_weights(model, init_state['model'], False)

upstream_model = model



In [4]:
dataset_cfg = OmegaConf.load("../conf/data/masked_tf_dataset.yaml")
task_cfg = OmegaConf.load("../conf/task/fixed_mask_pretrain.yaml")
preprocessor_cfg = OmegaConf.load("../conf/preprocessor/stft.yaml")
dataset_cfg["data"] = "/mnt/Data_Storage/dataset/brain/spindle_ieeg/brainbert_format"
dataset_cfg["use_mask"] = False

dataset = MaskedTFDataset(dataset_cfg, task_cfg=task_cfg, preprocessor_cfg=preprocessor_cfg)

In [5]:
preprocessor_cfg

{'name': 'stft', 'fs': 2000, 'freq_channel_cutoff': 40, 'nperseg': 400, 'noverlap': 350, 'normalizing': 'zscore'}

In [21]:
from tqdm import tqdm
stats = []
samples = []

for sample in tqdm(dataset):
    inputs = torch.FloatTensor(sample["masked_input"]).unsqueeze(0).to(device)
    mask = torch.FloatTensor(sample["mask_label"]).unsqueeze(0).to(device)
    
    with torch.no_grad():
        pred_out = model.forward(inputs, None, intermediate_rep=False)
        embed_out = model.forward(inputs, None, intermediate_rep=True)
        
    original = sample["target"].cpu().numpy()
    recovery = pred_out[0][0].cpu().numpy()
    embedding = embed_out[0][0].cpu().numpy()
    
    mean = sample["mean"]
    std = sample["std"]
    un_normalized_data = sample["un_normalized_target"]
    wav = sample["wav"]
    erased_Zxx = sample["erased_Zxx"]
    fn = sample["fn"]
    
    stats += [np.mean(np.abs(recovery - original)) / np.mean(np.abs(original))]
    samples += [[sample["masked_input"].cpu().numpy(), 
                 sample["target"].cpu().numpy(), 
                 recovery,
                 embedding,
                 sample["freq"], 
                 sample["time"], 
                 np.mean(np.abs(recovery - original)) / np.mean(np.abs(original)), 
                 mean, 
                 std, 
                 un_normalized_data, 
                 wav, 
                 erased_Zxx,
                 fn]]


100%|██████████| 87544/87544 [17:59<00:00, 81.13it/s]


In [23]:
fns = [s[-1] for s in samples]
pids = np.array([fn.split("/")[-2] for fn in fns])
channels = np.array([fn.split("/")[-1].split("_")[0] for fn in fns])
labels = np.array(["non" in fn for fn in fns])
embeds = np.array([s[3] for s in samples])

In [24]:
embeds.shape

(87544, 768)

In [26]:
# use umap to visualize the embeddings, subsample labels first

# subsample each label 100 from Pt1 and each channel

patient_index = np.where(pids == "Pt1")[0]

for channel in np.unique(channels[patient_index]):
    patient_channel_index = np.where((pids == "Pt1") & (channels == channel))[0]
    
    patinet_channel_labels = labels[patient_channel_index]
    patinet_channel_embeds = embeds[patient_channel_index]
    
    subsampled_embeds = []
    subsampled_labels = []

    sample_size = np.min([np.sum(patinet_channel_labels == 0), np.sum(patinet_channel_labels == 1), 50])
    
    if sample_size == 0:
        continue

    for i in range(2):
        idx = np.where(patinet_channel_labels == i)[0]
        idx = np.random.choice(idx, sample_size, replace=False)
        subsampled_embeds.append(patinet_channel_embeds[idx])
        subsampled_labels.append(patinet_channel_labels[idx])
        
    subsampled_embeds = np.concatenate(subsampled_embeds, axis=0)
    subsampled_labels = np.concatenate(subsampled_labels, axis=0).astype(str)
                                
    reducer = umap.UMAP()
    embedding = reducer.fit_transform(subsampled_embeds)

    fig = px.scatter(x=embedding[:, 0], y=embedding[:, 1], color=subsampled_labels)
    # set the title
    fig.update_layout(title=f"Pt1 {channel} UMAP")
    
    fig.show()



n_neighbors is larger than the dataset size; truncating to X.shape[0] - 1




n_neighbors is larger than the dataset size; truncating to X.shape[0] - 1




n_neighbors is larger than the dataset size; truncating to X.shape[0] - 1



KeyboardInterrupt: 

In [31]:
# split the data so that 2 patients are in the test set and 8 patients are in the training set

# mute warnings
import warnings
warnings.filterwarnings("ignore")

for i in range(1, 11, 2):
    patients_2_test = [f"Pt{i}", f"Pt{i+1}"]
    test_index = np.where(np.isin(pids, patients_2_test))[0]
    
    X_train, X_test = embeds[np.where(np.isin(pids, patients_2_test) == False)[0]], embeds[test_index]
    y_train, y_test = labels[np.where(np.isin(pids, patients_2_test) == False)[0]], labels[test_index]
    
    clf = LogisticRegression(random_state=0).fit(X_train, y_train)
    preds = clf.predict(X_test)

    # print acc, precision, recall, f1
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    acc = accuracy_score(y_test, preds)
    precision = precision_score(y_test, preds)
    recall = recall_score(y_test, preds)
    f1 = f1_score(y_test, preds)

    print(f"Accuracy: {acc}, Precision: {precision}, Recall: {recall}, F1: {f1}")

Accuracy: 0.8761795460341749, Precision: 0.8761795460341749, Recall: 1.0, F1: 0.9340039420920274
Accuracy: 0.90579087744677, Precision: 0.90579087744677, Recall: 1.0, F1: 0.9505669149390389
Accuracy: 0.8009540074200577, Precision: 0.8009540074200577, Recall: 1.0, F1: 0.8894774704074292
Accuracy: 0.7163479923518165, Precision: 0.7163479923518165, Recall: 1.0, F1: 0.8347351417590375
Accuracy: 0.8821321321321322, Precision: 0.8821321321321322, Recall: 1.0, F1: 0.9373753490227363


In [5]:
from omegaconf import OmegaConf
device = "cuda:1"
ckpt_path = "/mnt/AI_Magic/projects/BrainBERT/pretrained_weights/stft_large_pretrained.pth"
cfg = OmegaConf.create({"upstream_ckpt": ckpt_path})
model, upstream_cfg = build_model(cfg)
model.to(device)
init_state = torch.load(ckpt_path)
load_model_weights(model, init_state['model'], False)

upstream_model = model



In [6]:
# /mnt/AI_Magic/projects/BrainBERT/outputs/2024-05-02/19-10-03/checkpoint_best_f0.pth
# /mnt/AI_Magic/projects/BrainBERT/outputs/2024-05-02/19-58-49/checkpoint_best_f1.pth
# /mnt/AI_Magic/projects/BrainBERT/outputs/2024-05-02/20-10-35/checkpoint_best_f2.pth

device = "cuda:1"
ckpt_path = "/mnt/AI_Magic/projects/BrainBERT/outputs/2024-05-03/08-33-42/checkpoint_best.pth"
cfg = OmegaConf.create({"upstream_ckpt": ckpt_path})
model, upstream_cfg = build_model(cfg, upstream_model=upstream_model)
model.to(device)
init_state = torch.load(ckpt_path)
load_model_weights(model, init_state['model'], False)

In [7]:
from datasets.finetuning_spindle_dataset import *

dataset_cfg = OmegaConf.load("../conf/data/finetuning_spindle.yaml")
task_cfg = OmegaConf.load("../conf/task/finetune_task.yaml")
preprocessor_cfg = OmegaConf.load("../conf/preprocessor/stft.yaml")
dataset_cfg["manifest_name"] = "test_manifest_0.tsv"

dataset = SpindleFinetuning(dataset_cfg, task_cfg=task_cfg, preprocessor_cfg=preprocessor_cfg)

In [8]:
from tqdm import tqdm

preds = []
labels = []

for sample in tqdm(dataset):
    inputs_batch = torch.FloatTensor(sample["input"]).unsqueeze(0).to(device)
    labels_batch = torch.LongTensor([sample["label"]]).to(device)
    
    with torch.no_grad():
        pred_out = model.forward(inputs_batch, None)
        # sigmoid
        pred_out = torch.softmax(pred_out, dim=1)
        pred_out = torch.argmax(pred_out, dim=1).cpu().numpy()

    preds.append(pred_out)
    labels.append(labels_batch.cpu().numpy())

preds = np.array(preds).flatten()
labels = np.array(labels).flatten()

  0%|          | 0/8196 [00:00<?, ?it/s]

100%|██████████| 8196/8196 [00:56<00:00, 145.67it/s]


In [14]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score

print(f"Accuracy: {accuracy_score(labels, preds)}")
print(f"Precision: {precision_score(labels, preds)}")
print(f"Recall: {recall_score(labels, preds)}")
print(f"F1: {f1_score(labels, preds)}")
print(f"Balanced Accuracy: {balanced_accuracy_score(labels, preds)}")

Accuracy: 0.5006100536847242
Precision: 0.5009769441187965
Recall: 0.3128355295265983
F1: 0.3851584797957038
Balanced Accuracy: 0.5006100536847242


In [11]:
print(np.unique(labels, return_counts=True))
print(np.unique(preds, return_counts=True))

(array([0, 1]), array([4098, 4098]))
(array([0, 1]), array([5637, 2559]))


In [13]:
labels

array([0, 0, 0, ..., 1, 1, 1])

In [12]:
preds

array([1, 0, 0, ..., 0, 1, 0])

In [45]:
training_data = np.load("/mnt/AI_Magic/projects/BrainBERT/outputs/2024-05-03/05-40-22/training_debug.npz", allow_pickle=True)
inference_data = np.load("/mnt/AI_Magic/projects/BrainBERT/outputs/2024-05-03/05-40-22/valid_debug.npz", allow_pickle=True)

In [46]:
np.sum(np.abs(training_data["output"] - inference_data["output"]))

5.664646