# Collect Results and Plots

## Load PCG Model

In [9]:
from classifier.datasets import HeartAudioDataset
from classifier.model_factory import ModelFactory
import torch

# Single PCG checkpoint loaded at the end of this cell. The trailing "-9" has
# the same "-{i}" form that load_all_pcg_ta_models appends to its base path.
model_path = "../../models/wav2vec-4s-pcg-training-a-pcg-oct22.log-9"
# Architecture tag; the loaders in this cell compare it with "inception" and
# "wav2vec". NOTE: `model` is rebound to the loaded model object further down.
model = "wav2vec"
# Kept under a second name because later cells pass it to
# create_dataloaders(datasets, aux_type) after `model` has been rebound.
aux_type = model


def load_all_pcg_ta_models(model_base_path, model_str, class_names):
    """Load the ten PCG checkpoints ``<model_base_path>-1`` .. ``-10``.

    Args:
        model_base_path: Checkpoint path without the trailing ``-<index>``.
        model_str: Architecture tag; it is compared with ``"inception"`` and
            ``"wav2vec"`` and the two resulting booleans are passed as the
            second and third positional arguments of ``load_model``.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        List of the ten objects returned by ``ModelFactory.load_model``,
        ordered by checkpoint index.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")

    factory = ModelFactory(target, class_names, freeze=False, optimizer_type='sgd')
    is_inception = model_str == "inception"
    is_wav2vec = model_str == "wav2vec"

    return [
        factory.load_model(f"{model_base_path}-{index}", is_inception, is_wav2vec)
        for index in range(1, 11)
    ]

def load_pcg_ta_model(model_path, model_str, class_names):
    """Load a single PCG checkpoint through ``ModelFactory``.

    Args:
        model_path: Full path of the checkpoint to load.
        model_str: Architecture tag; compared with ``"inception"`` and
            ``"wav2vec"`` to build the two boolean arguments of
            ``load_model``.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        Whatever ``ModelFactory.load_model`` returns for ``model_path``.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")

    factory = ModelFactory(target, class_names, freeze=False, optimizer_type='sgd')
    is_inception = model_str == "inception"
    is_wav2vec = model_str == "wav2vec"
    return factory.load_model(model_path, is_inception, is_wav2vec)

def get_pcg_ta_dataset(fs=4125):
    """Build the ``'test'`` dataset for the PCG-only training-a experiment.

    Args:
        fs: Forwarded unchanged as the ``fs`` keyword of
            ``HeartAudioDataset`` (presumably the sampling rate in Hz --
            confirm against the dataset class).

    Returns:
        A dict with the single key ``'test'`` mapping to a
        ``HeartAudioDataset`` whose ``channel`` attribute is set to ``0``.
    """
    data_dir = "../../data/physionet.org/files/challenge-2016/1.0.0/training-a"
    split_path = "../../data/heart-sounds/actually-is-reference-CTH/REFERENCE.csv"
    audio_dir = "../../data/preprocessed_audio/training-a-4s-for-wav2vec-paper"
    segment_dir = "../../data/segmentation/wav2vec-training-a-4"
    database = "training-a-pcg"

    datasets = {
        phase: HeartAudioDataset(
            data_dir,
            split_path,
            segment_dir,
            phase,
            audio_dir,
            # Evaluates to False here: the tag is "training-a-pcg".
            ecg=(database == "training-a"),
            segmentation="time",
            augmentation=False,
            four_band=False,
            fs=fs,
            skip_data_valid=False,
            sig_len=4,
        )
        for phase in ['test']
    }

    # Channel index 0 for this PCG-only run (the ECG helper in this notebook
    # uses 1) -- presumably selects the PCG channel; confirm in the dataset.
    for dataset in datasets.values():
        dataset.channel = 0

    return datasets

# Build the PCG test dataset and read the class labels from its first
# (and only) entry.
datasets = get_pcg_ta_dataset()
class_names = next(iter(datasets.values())).classes

# `model` still holds the architecture string when it is passed in here and is
# rebound to the loaded model object by this assignment.
model = load_pcg_ta_model(model_path, model, class_names)

100%|██████████| 81/81 [00:00<00:00, 1746.41it/s]
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Load ECG Model

In [1]:

from classifier.datasets import HeartAudioDataset
from classifier.model_factory import ModelFactory
import torch

# Single ECG checkpoint loaded at the end of this cell ("-2" has the same
# "-{i}" form that load_all_ecg_ta_models appends to its base path).
model_path = "../../models/wav2vec-4s-ecg-training-a-ecg-oct23.log-2"
# Architecture tag; rebound to the loaded model object at the end of the cell.
model = "wav2vec"
# Second name for the tag; later cells pass it to create_dataloaders.
aux_type = model

def load_all_ecg_ta_models(model_base_path, model_str, class_names):
    """Load the ten ECG checkpoints ``<model_base_path>-1`` .. ``-10``.

    Args:
        model_base_path: Checkpoint path without the trailing ``-<index>``.
        model_str: Architecture tag; compared with ``"inception"`` and
            ``"wav2vec"`` to build the two boolean arguments of
            ``load_model``.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        List of loaded models in checkpoint-index order.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )

    loaded = []
    for checkpoint_path in (f"{model_base_path}-{index}" for index in range(1, 11)):
        loaded.append(
            factory.load_model(
                checkpoint_path,
                model_str == "inception",
                model_str == "wav2vec",
            )
        )
    return loaded

def load_ecg_ta_model(model_path, model_str, class_names):
    """Load a single ECG checkpoint through ``ModelFactory``.

    Args:
        model_path: Full path of the checkpoint to load.
        model_str: Architecture tag; compared with ``"inception"`` and
            ``"wav2vec"`` to build the two boolean arguments of
            ``load_model``.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        Whatever ``ModelFactory.load_model`` returns for ``model_path``.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return factory.load_model(
        model_path,
        model_str == "inception",
        model_str == "wav2vec",
    )

def get_ecg_ta_dataset(fs=4125):
    """Build the ``'test'`` dataset for the ECG-only training-a experiment.

    Args:
        fs: Forwarded unchanged as the ``fs`` keyword of
            ``HeartAudioDataset`` (presumably the sampling rate in Hz --
            confirm against the dataset class).

    Returns:
        A dict with the single key ``'test'`` mapping to a
        ``HeartAudioDataset`` whose ``channel`` attribute is set to ``1``.
    """
    data_dir = "../../data/physionet.org/files/challenge-2016/1.0.0/training-a"
    split_path = "../../data/heart-sounds/actually-is-reference-CTH/REFERENCE.csv"
    audio_dir = "../../data/preprocessed_audio/training-a-4s-for-wav2vec-paper"
    segment_dir = "../../data/segmentation/wav2vec-training-a-4"
    database = "training-a-ecg"

    datasets = {
        phase: HeartAudioDataset(
            data_dir,
            split_path,
            segment_dir,
            phase,
            audio_dir,
            # NOTE(review): evaluates to False because the tag is
            # "training-a-ecg"; confirm that ecg=False is intended for a run
            # that then selects channel 1.
            ecg=(database == "training-a"),
            segmentation="time",
            augmentation=False,
            four_band=False,
            fs=fs,
            skip_data_valid=False,
            sig_len=4,
        )
        for phase in ['test']
    }

    # Channel index 1 for this ECG-only run (the PCG helper in this notebook
    # uses 0) -- presumably selects the ECG channel; confirm in the dataset.
    for dataset in datasets.values():
        dataset.channel = 1

    return datasets

# Build the ECG test dataset and read the class labels from its only entry.
datasets = get_ecg_ta_dataset()
class_names = next(iter(datasets.values())).classes
# `model` is the architecture string on the way in and the loaded model object
# after this assignment.
model = load_ecg_ta_model(model_path, model, class_names)

100%|██████████| 81/81 [00:00<00:00, 154.09it/s]
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Load PECG Model

In [2]:
from classifier.datasets import HeartAudioDataset
from classifier.model_factory import ModelFactory
import torch

# Architecture tag; rebound to the loaded model object at the end of the cell.
model = "wav2vec"
# Second name for the tag; later cells pass it to create_dataloaders.
aux_type = model
# Single combined (PECG) checkpoint; "-1" has the same "-{i}" form that
# load_all_pecg_ta_models appends to its base path.
model_path = "../../models/wav2vec-4s-training-a-split-1-jan17.log.pt-1"

def load_all_pecg_ta_models(model_base_path, model_str, class_names):
    """Load the ten checkpoints ``<model_base_path>-1`` .. ``-10``.

    Every checkpoint is loaded with ``is_large_wav2vec=True``.

    Args:
        model_base_path: Checkpoint path without the trailing ``-<index>``.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        List of loaded models in checkpoint-index order.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return [
        factory.load_model(f"{model_base_path}-{index}", is_large_wav2vec=True)
        for index in range(1, 11)
    ]

def load_pecg_ta_model(model_path, model_str, class_names):
    """Load one checkpoint with ``is_large_wav2vec=True`` and move it to the device.

    Args:
        model_path: Full path of the checkpoint to load.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        The loaded model; ``.to(device)`` has been called on its ``model_ft``
        and on every entry of ``model_ft.models``.
    """
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    factory = ModelFactory(target, class_names, freeze=False, optimizer_type='sgd')
    loaded = factory.load_model(model_path, is_large_wav2vec=True)

    # The wrapper and each entry of its `models` collection are moved
    # individually, exactly as the original code does.
    wrapper = loaded.model_ft
    wrapper.to(target)
    for part in wrapper.models:
        part.to(target)

    return loaded

def get_pecg_ta_dataset(fs=16000, augment=False):
    """Build the ``'test'`` dataset for the combined training-a experiment.

    Args:
        fs: Forwarded unchanged as the ``fs`` keyword of
            ``HeartAudioDataset`` (presumably the sampling rate in Hz --
            confirm against the dataset class).
        augment: Forwarded as the ``augmentation`` keyword of
            ``HeartAudioDataset``. Previously this argument was accepted but
            ignored (``augmentation=False`` was hard-coded); the default
            keeps that behaviour.

    Returns:
        A dict with the single key ``'test'`` mapping to a
        ``HeartAudioDataset`` whose ``channel`` attribute is set to ``-1``.
    """
    data_dir = "../../data/physionet.org/files/challenge-2016/1.0.0/training-a"
    split_path = "../../data/heart-sounds/actually-is-reference-CTH/REFERENCE.csv"
    audio_dir = "../../data/preprocessed_audio/training-a-4s-for-wav2vec-paper"
    segment_dir = "../../data/segmentation/wav2vec-training-a-4"
    database = "training-a"

    datasets = {
        phase: HeartAudioDataset(
            data_dir,
            split_path,
            segment_dir,
            phase,
            audio_dir,
            # Evaluates to True: the tag is exactly "training-a".
            ecg=(database == "training-a"),
            segmentation="time",
            augmentation=augment,
            four_band=False,
            fs=fs,
            skip_data_valid=False,
            sig_len=4,
        )
        for phase in ['test']
    }

    # -1 rather than the 0 / 1 used by the single-signal helpers in this
    # notebook -- presumably "use all channels"; confirm in the dataset class.
    for dataset in datasets.values():
        dataset.channel = -1

    return datasets

# Build the combined test dataset (default fs=16000) and read its class labels.
datasets = get_pecg_ta_dataset()
class_names = next(iter(datasets.values())).classes

# `model` is the architecture string on the way in and the loaded model object
# after this assignment.
model = load_pecg_ta_model(model_path, model, class_names)
# Put the loaded model into its 'test' mode before the evaluation cells run.
model.set_mode('test')

100%|██████████| 81/81 [00:00<00:00, 1233.42it/s]
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
from classifier.datasets import HeartAudioDataset
from classifier.model_factory import ModelFactory
import torch

# Architecture tag for the wav2vec-cnn variant. Unlike the other load cells
# this one uses `model_str`, so the global `model` is not defined here first.
model_str = "wav2vec-cnn"
# Second name for the tag; later cells pass it to create_dataloaders.
aux_type = model_str
# Single wav2vec-cnn checkpoint loaded at the end of this cell.
model_path = "../../models/wav2vec-cnn-4s-training-a-wav2vec_cnn-split-1-jan21.log.pt-1"

def load_all_pecg_ta_models(model_base_path, model_str, class_names):
    """Load the ten checkpoints ``<model_base_path>-1`` .. ``-10``.

    Every checkpoint is loaded with ``is_large_wav2veccnn=True``. This
    definition replaces the earlier function of the same name once the cell
    is executed.

    Args:
        model_base_path: Checkpoint path without the trailing ``-<index>``.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        List of loaded models in checkpoint-index order.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return [
        factory.load_model(f"{model_base_path}-{index}", is_large_wav2veccnn=True)
        for index in range(1, 11)
    ]

def load_pecg_ta_model(model_path, model_str, class_names):
    """Load one checkpoint with ``is_large_wav2veccnn=True`` and move it to the device.

    Args:
        model_path: Full path of the checkpoint to load.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        The loaded model; ``.to(device)`` has been called on its ``model_ft``
        and on every entry of ``model_ft.models``.
    """
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    factory = ModelFactory(target, class_names, freeze=False, optimizer_type='sgd')
    loaded = factory.load_model(model_path, is_large_wav2veccnn=True)

    # The wrapper and each entry of its `models` collection are moved
    # individually, exactly as the original code does.
    wrapper = loaded.model_ft
    wrapper.to(target)
    for part in wrapper.models:
        part.to(target)

    return loaded

def get_pecg_ta_dataset(fs=16000, augment=False):
    """Build the ``'test'`` dataset for the combined training-a experiment.

    Args:
        fs: Forwarded unchanged as the ``fs`` keyword of
            ``HeartAudioDataset`` (presumably the sampling rate in Hz --
            confirm against the dataset class).
        augment: Forwarded as the ``augmentation`` keyword of
            ``HeartAudioDataset``. Previously this argument was accepted but
            ignored (``augmentation=False`` was hard-coded); the default
            keeps that behaviour.

    Returns:
        A dict with the single key ``'test'`` mapping to a
        ``HeartAudioDataset`` whose ``channel`` attribute is set to ``-1``.
    """
    data_dir = "../../data/physionet.org/files/challenge-2016/1.0.0/training-a"
    split_path = "../../data/heart-sounds/actually-is-reference-CTH/REFERENCE.csv"
    audio_dir = "../../data/preprocessed_audio/training-a-4s-for-wav2vec-paper"
    segment_dir = "../../data/segmentation/wav2vec-training-a-4"
    database = "training-a"

    datasets = {
        phase: HeartAudioDataset(
            data_dir,
            split_path,
            segment_dir,
            phase,
            audio_dir,
            # Evaluates to True: the tag is exactly "training-a".
            ecg=(database == "training-a"),
            segmentation="time",
            augmentation=augment,
            four_band=False,
            fs=fs,
            skip_data_valid=False,
            sig_len=4,
        )
        for phase in ['test']
    }

    # -1 rather than the 0 / 1 used by the single-signal helpers in this
    # notebook -- presumably "use all channels"; confirm in the dataset class.
    for dataset in datasets.values():
        dataset.channel = -1

    return datasets

# Build the combined test dataset with fs=4125 (overriding the 16000 default)
# and read its class labels.
datasets = get_pecg_ta_dataset(fs=4125)
class_names = next(iter(datasets.values())).classes

# Load the wav2vec-cnn checkpoint into the notebook-global `model`.
model = load_pecg_ta_model(model_path, model_str, class_names)
# Put the loaded model into its 'test' mode before the evaluation cells run.
model.set_mode('test')

100%|██████████| 81/81 [00:00<00:00, 967.67it/s]
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


## Load PCG CinC Model

In [15]:
# PCG CNN model
from classifier.datasets import HeartAudioDatabase
from classifier.model_factory import ModelFactory
import torch

# CinC wav2vec-cnn checkpoint loaded at the end of this cell; the commented
# line below is an alternative checkpoint kept for quick switching.
model_path = '../../data/models/cinc-wav2vecwav2vec-4s-cinc-wav2vec-cnn-cinc-wav2vec-cnn-dec19.log-10.pth'
#model_path = '../../data/models/cinc-wav2vecwav2vec-4s-cinc-cinc-jan06.log-10.pth'
# Architecture tag; rebound to the loaded model object at the end of the cell.
model = "wav2vec-cnn"
# Second name for the tag; later cells pass it to create_dataloaders.
aux_type = model

def load_all_pcg_cinc_cnn_models(model_base_path, model_str, class_names):
    """Load the ten checkpoints ``<model_base_path>-1.pth`` .. ``-10.pth``.

    Every checkpoint is loaded with ``is_wav2veccnn=True``.

    Args:
        model_base_path: Checkpoint path without the ``-<index>.pth`` suffix.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        List of loaded models in checkpoint-index order.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return [
        factory.load_model(f"{model_base_path}-{index}.pth", is_wav2veccnn=True)
        for index in range(1, 11)
    ]

def load_pcg_cinc_cnn_model(model_path, model_str, class_names):
    """Load a single checkpoint with ``is_wav2veccnn=True``.

    Args:
        model_path: Full path of the checkpoint to load.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        Whatever ``ModelFactory.load_model`` returns for ``model_path``.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return factory.load_model(model_path, is_wav2veccnn=True)

def get_pcg_cinc_dataset(fs=16000):
    """Build the ``'test'`` dataset for the CinC experiment.

    Args:
        fs: Forwarded unchanged as the ``fs`` keyword of
            ``HeartAudioDatabase`` (presumably the sampling rate in Hz --
            confirm against the dataset class).

    Returns:
        A dict with the single key ``'test'`` mapping to a
        ``HeartAudioDatabase`` instance.
    """
    corpus = "cinc"

    # Keyword arguments shared by every phase; `ecg` is False because the
    # corpus tag is not "training-a".
    shared_kwargs = dict(
        ecg=(corpus == "training-a"),
        segmentation="time",
        augmentation=False,
        four_band=False,
        fs=fs,
        skip_data_valid=True,
        sig_len=4,
    )

    result = {}
    for phase in ['test']:
        result[phase] = HeartAudioDatabase(
            "../../data/physionet.org/files/challenge-2016/1.0.0/",
            "../../data/splits/rnn",
            "../../data/segmentation/rnn",
            phase,
            "../../data/processed_audio/cinc/entire",
            **shared_kwargs,
        )
    return result

# Build the CinC test dataset with fs=4125 (overriding the 16000 default) and
# read its class labels.
datasets = get_pcg_cinc_dataset(fs=4125)
class_names = next(iter(datasets.values())).classes

# `model` is the architecture string on the way in and the loaded model object
# after this assignment.
model = load_pcg_cinc_cnn_model(model_path, model, class_names)


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


In [8]:
from classifier.datasets import HeartAudioDatabase
from classifier.model_factory import ModelFactory
import torch

# CinC wav2vec checkpoint loaded at the end of this cell; the commented line
# below is an alternative checkpoint kept for quick switching.
model_path = '../../models/models_for_maybe_paper/wav2vec-4s-cinc-16k-cinc-16k-dec16.log-4.pth'
#model_path = '../../data/models/cinc-wav2vecwav2vec-4s-cinc-cinc-jan06.log-10.pth'
# Architecture tag; rebound to the loaded model object at the end of the cell.
model = "wav2vec"
# Second name for the tag; later cells pass it to create_dataloaders.
aux_type = model

def load_all_pcg_cinc_models(model_base_path, model_str, class_names):
    """Load the ten checkpoints ``<model_base_path>-1.pth`` .. ``-10.pth``.

    Every checkpoint is loaded with ``is_wav2vec=True``.

    Args:
        model_base_path: Checkpoint path without the ``-<index>.pth`` suffix.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        List of loaded models in checkpoint-index order.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return [
        factory.load_model(f"{model_base_path}-{index}.pth", is_wav2vec=True)
        for index in range(1, 11)
    ]

def load_pcg_cinc_model(model_path, model_str, class_names):
    """Load a single checkpoint with ``is_wav2vec=True``.

    Args:
        model_path: Full path of the checkpoint to load.
        model_str: Accepted for signature parity with the other loaders; it
            is not read by this function.
        class_names: Class labels handed to ``ModelFactory``.

    Returns:
        Whatever ``ModelFactory.load_model`` returns for ``model_path``.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    factory = ModelFactory(
        torch.device(device_name),
        class_names,
        freeze=False,
        optimizer_type='sgd',
    )
    return factory.load_model(model_path, is_wav2vec=True)

def get_pcg_cinc_dataset(fs=16000):
    """Build the ``'test'`` dataset for the CinC experiment.

    Args:
        fs: Forwarded unchanged as the ``fs`` keyword of
            ``HeartAudioDatabase`` (presumably the sampling rate in Hz --
            confirm against the dataset class).

    Returns:
        A dict with the single key ``'test'`` mapping to a
        ``HeartAudioDatabase`` instance.
    """
    corpus = "cinc"

    # Keyword arguments shared by every phase; `ecg` is False because the
    # corpus tag is not "training-a".
    shared_kwargs = dict(
        ecg=(corpus == "training-a"),
        segmentation="time",
        augmentation=False,
        four_band=False,
        fs=fs,
        skip_data_valid=True,
        sig_len=4,
    )

    result = {}
    for phase in ['test']:
        result[phase] = HeartAudioDatabase(
            "../../data/physionet.org/files/challenge-2016/1.0.0/",
            "../../data/splits/rnn",
            "../../data/segmentation/rnn",
            phase,
            "../../data/processed_audio/cinc/entire",
            **shared_kwargs,
        )
    return result

# Build the CinC test dataset with the default fs=16000 and read its class
# labels.
datasets = get_pcg_cinc_dataset()
class_names = next(iter(datasets.values())).classes

# `model` is the architecture string on the way in and the loaded model object
# after this assignment.
model = load_pcg_cinc_model(model_path, model, class_names)


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Run Test Set

In [None]:
from classifier.dataloaders import create_dataloaders
from classifier.testing import FineTunerFragmentTester, FineTunerPatientTester
import torch

# `datasets`, `aux_type` and `model` are not defined in this cell; they are the
# notebook globals left by whichever "Load ... Model" cell was executed last.
dataloader = create_dataloaders(datasets, aux_type)

# Two testers over the same model and dataloader.
fragment_tester = FineTunerFragmentTester(model, dataloader)
patient_tester = FineTunerPatientTester(model, dataloader)

# Run both tests, separated by blank lines; the recorded output below shows a
# "Fragment Stats:" section followed by a "Patient Stats:" section.
print()
fragment_tester.test()
print()
patient_tester.test()
print()





Fragment Stats:
test
loss: 0.635, tpr: 0.831, tnr: 0.750, fpr: 0.250, ppv: 0.890, npv: 0.646, acc: 0.807, acc_mu: 0.790, qi: 0.789, j: 0.581, f1p: 0.859, f1n: 0.694, mcc: 0.558
defaultdict(<class 'int'>, {'tn': 117, 'fp': 39, 'fn': 64, 'tp': 314})

Patient Stats:
test
loss: 0.000, tpr: 0.860, tnr: 0.750, fpr: 0.250, ppv: 0.891, npv: 0.692, acc: 0.827, acc_mu: 0.805, qi: 0.803, j: 0.610, f1p: 0.875, f1n: 0.720, mcc: 0.596
defaultdict(<class 'int'>, {'tn': 18, 'fp': 6, 'fn': 8, 'tp': 49})



In [17]:
from classifier.dataloaders import create_dataloaders
from classifier.testing import FineTunerFragmentTester, FineTunerPatientTester
import torch

# load_all_ecg_ta_models is defined in the "Load ecg model" cell, so that cell
# must have been executed in the current kernel first; the NameError recorded
# below is what happens otherwise. Class names are hard-coded as ["0", "1"].
models = load_all_ecg_ta_models("../../models/wav2vec-4s-ecg-training-a-ecg-oct23.log", "wav2vec", ["0", "1"])

# NOTE(review): `dataloader` is not created here -- it is reused from the
# "Run Test Set" cell above, i.e. built from whatever `datasets` was loaded at
# that time. Confirm it is the ECG dataset before trusting the numbers.
# The loop also rebinds the notebook-global `model`.
for model in models:
    fragment_tester = FineTunerFragmentTester(model, dataloader)
    patient_tester = FineTunerPatientTester(model, dataloader)

    print()
    fragment_tester.test()
    print()
    patient_tester.test()
    print()

NameError: name 'load_all_ecg_ta_models' is not defined

## Create ROC Plots

In [10]:
import matplotlib.pyplot as plt
from typing import Optional
import numpy as np
from classifier.dataloaders import create_dataloaders
from classifier.testing import FineTunerFragmentTester, FineTunerPatientTester

def _interpolate_fprs(tpr_curves, fpr_curves, tpr_grid):
    """Resample each curve's FPR values onto ``tpr_grid`` (one row per curve)."""
    rows = []
    for tpr, fpr in zip(tpr_curves, fpr_curves):
        # The curves are reversed before interpolating, as in the original
        # code. np.interp expects increasing x-coordinates, so this assumes
        # roc_curve() yields points in decreasing-TPR order -- TODO confirm.
        rows.append(np.interp(tpr_grid, tpr[::-1], fpr[::-1]))
    return np.array(rows)


def _plot_fpr_summary(interpolated_fprs, common_tpr, output_path):
    """Print and plot the mean and the 2.5 / 97.5 percentile FPR per TPR point."""
    mean_fpr = np.mean(interpolated_fprs, axis=0)
    fpr_2_5 = np.percentile(interpolated_fprs, 2.5, axis=0)
    fpr_97_5 = np.percentile(interpolated_fprs, 97.5, axis=0)
    print(f"{mean_fpr=}, {fpr_2_5=}, {fpr_97_5=}")
    plot_roc_curves(
        [common_tpr, common_tpr, common_tpr],
        [mean_fpr, fpr_2_5, fpr_97_5],
        output_path=output_path,
    )


def generate_roc_plots(models, datasets, output_path = None):
    """Plot fragment-level and patient-level ROC summaries over several models.

    For every model the fragment and patient testers' ``roc_curve()`` results
    are collected, each curve's FPR is interpolated onto 100 evenly spaced
    TPR values in [0, 1], and the mean plus the 2.5 / 97.5 percentiles across
    models are printed and passed to ``plot_roc_curves``.

    Args:
        models: Iterable of loaded models exposing a ``model_ft`` attribute.
        datasets: Passed to ``create_dataloaders`` together with the
            notebook-global ``aux_type``.
        output_path: Path prefix; the plots are written to
            ``<output_path>_fragment`` and ``<output_path>_patient``. When
            ``None`` the plots are shown instead (previously ``None`` was
            formatted into the literal prefix ``"None"`` and files were
            written).

    Raises:
        ValueError: If ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("generate_roc_plots needs at least one model.")

    dataloader = create_dataloaders(datasets, aux_type)
    eval_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    fragment_tprs = []
    fragment_fprs = []
    patient_tprs = []
    patient_fprs = []

    for model in models:
        # Only one model sits on the evaluation device at a time; it is moved
        # back to the CPU even when a tester raises.
        model.model_ft = model.model_ft.to(eval_device)
        try:
            fragment_tester = FineTunerFragmentTester(model, dataloader)
            patient_tester = FineTunerPatientTester(model, dataloader)

            fragment_tpr, fragment_fpr, _ = fragment_tester.roc_curve()
            patient_tpr, patient_fpr, _ = patient_tester.roc_curve()
        finally:
            model.model_ft = model.model_ft.to("cpu")

        fragment_fprs.append(fragment_fpr)
        fragment_tprs.append(fragment_tpr)
        patient_fprs.append(patient_fpr)
        patient_tprs.append(patient_tpr)

    common_tpr = np.linspace(0, 1, 100)
    interpolated_fragment_fprs = _interpolate_fprs(fragment_tprs, fragment_fprs, common_tpr)
    interpolated_patient_fprs = _interpolate_fprs(patient_tprs, patient_fprs, common_tpr)

    # Keep None as None so plot_roc_curves shows the figure instead of saving
    # it under a file literally named "None_fragment" / "None_patient".
    fragment_path = None if output_path is None else f"{output_path}_fragment"
    patient_path = None if output_path is None else f"{output_path}_patient"

    _plot_fpr_summary(interpolated_fragment_fprs, common_tpr, fragment_path)
    _plot_fpr_summary(interpolated_patient_fprs, common_tpr, patient_path)



def plot_roc_curves(
        tprs: list[list[float]], 
        tnrs: list[list[float]], 
        output_path: Optional[str] = None,
        title: Optional[str] = None
):
    """Plot a set of ROC curves on one figure and show or save it.

    Args:
        tprs: One sequence of y-values (true positive rates) per curve.
        tnrs: One sequence of x-values per curve. Despite the parameter name,
            ``generate_roc_plots`` in this notebook passes false positive
            rates here, and the x-axis is labelled accordingly.
        output_path: File to save the figure to; when ``None`` the figure is
            shown instead.
        title: Figure title; ``"ROC Curve"`` is used when ``None``.

    The first three curves are labelled "Mean", "2.5% CI" and "97.5% CI";
    any further curves are labelled "Curve <n>".
    """
    assert len(tprs) == len(tnrs), "Should contain the same length arrays."

    curve_labels = ["Mean", "2.5% CI", "97.5% CI"]

    plt.figure()
    plt.grid()
    # x carries the `tnrs` values and y the `tprs` values (see plt.plot
    # below); the original labels were swapped.
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.title("ROC Curve" if title is None else title)

    for index, (tpr, x_values) in enumerate(zip(tprs, tnrs)):
        if index < len(curve_labels):
            label = curve_labels[index]
        else:
            label = f"Curve {index + 1}"
        plt.plot(x_values, tpr, label=label)

    plt.legend()

    if output_path is None:
        plt.show()
    else:
        print('saving')
        plt.savefig(output_path)

    plt.close()

In [None]:
import matplotlib.pyplot as plt

import matplotlib
#matplotlib.use('TkAgg')  # or 'QtAgg' or 'MacOSX' depending on your system
# Turn on matplotlib's interactive mode. The original line was the bare
# attribute `plt.ion`, which never called the function.
plt.ion()

# Base paths; load_all_pcg_cinc_models appends "-{i}.pth" for i in 1..10.
models_base_path_4125 = '../../data/models/cinc-wav2vecwav2vec-4s-cinc-cinc-jan06.log'
models_base_path_16 = '../../models/models_for_maybe_paper/wav2vec-4s-cinc-16k-cinc-16k-dec16.log'

datasets_4125 = get_pcg_cinc_dataset(fs=4125)
# Read the class names from the dataset that was just built rather than from
# the notebook-global `datasets` left behind by whichever cell ran last.
class_names = next(iter(datasets_4125.values())).classes

models_4125 = load_all_pcg_cinc_models(models_base_path_4125, "wav2vec", class_names)

datasets_16 = get_pcg_cinc_dataset(fs=16000)
class_names = next(iter(datasets_16.values())).classes

models_16 = load_all_pcg_cinc_models(models_base_path_16, "wav2vec", class_names)

In [None]:
# Each call produces two plots, saved under "<prefix>_fragment" and
# "<prefix>_patient". NOTE(review): presumably the "roc/" directory has to
# exist beforehand -- confirm.
generate_roc_plots(models_4125, datasets_4125, "roc/cinc-4125")
generate_roc_plots(models_16, datasets_16, "roc/cinc-16k")

In [13]:
import matplotlib.pyplot as plt

import matplotlib
#matplotlib.use('TkAgg')  # or 'QtAgg' or 'MacOSX' depending on your system
# Turn on matplotlib's interactive mode. The original line was the bare
# attribute `plt.ion`, which never called the function.
plt.ion()

# Base paths; the training-a loaders append "-{i}" for i in 1..10.
models_base_path_4125 = '../../models/wav2vec-4s-training-a-split-1-jan17.log.pt'

models_base_path_16 = '../../models/models_for_maybe_paper/wav2vec-4s-training-a-16k-oct23.log'
models_base_path_pcg = '../../models/wav2vec-4s-pcg-training-a-pcg-oct22.log'
models_base_path_ecg = '../../models/wav2vec-4s-ecg-training-a-ecg-oct23.log'
models_base_path_no_augment = '../../models/models_for_maybe_paper/wav2vec-4s-training-a-16k-no-augment-jan03.log'

datasets_4125 = get_pecg_ta_dataset(fs=4125)
class_names = next(iter(datasets_4125.values())).classes

models_4125 = load_all_pecg_ta_models(models_base_path_4125, "wav2vec", class_names)

# The remaining experiments are kept as commented-out toggles. Each one now
# refers to its own dataset and its own base path (the ECG and no-augment
# toggles previously pointed at models_base_path_pcg, and two of them read
# class names from datasets_4125).
#datasets_16 = get_pecg_ta_dataset(fs=16000)
#class_names = next(iter(datasets_16.values())).classes
#
##models_16 = load_all_pecg_ta_models(models_base_path_16, "wav2vec", class_names)
#
#datasets_pcg = get_pcg_ta_dataset(fs=4125)
#class_names = next(iter(datasets_pcg.values())).classes
#
#models_pcg = load_all_pcg_ta_models(models_base_path_pcg, "wav2vec", class_names)

#datasets_ecg = get_ecg_ta_dataset(fs=4125)
#models_ecg = load_all_ecg_ta_models(models_base_path_ecg, "wav2vec", class_names)
#
#datasets_no_augment = get_pecg_ta_dataset(fs=16000, augment=False)
#class_names = next(iter(datasets_no_augment.values())).classes
#
##models_no_augment = load_all_pecg_ta_models(models_base_path_no_augment, "wav2vec", class_names)

100%|██████████| 81/81 [00:00<00:00, 1099.18it/s]
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-ba

In [14]:
# Only the combined 4125 run is enabled; the other calls need the matching
# commented-out models/datasets in the previous cell to be re-enabled first.
# The recorded output below shows this call failing with
# "RuntimeError: Expected target size [64, 2], got [64]".
#generate_roc_plots(models_pcg, datasets_pcg, "roc/ta-pcg-4125")
#generate_roc_plots(models_ecg, datasets_ecg, "roc/ta-ecg-4125")
generate_roc_plots(models_4125, datasets_4125, "roc/ta-pecg-4125")
#generate_roc_plots(models_16, datasets_16, "roc/ta-pecg-16k")
#generate_roc_plots(models_no_augment, datasets_no_augment, "roc/ta-pecg-16k-no-augment")

  0%|                                                                                            | 0/61 [00:00…

RuntimeError: Expected target size [64, 2], got [64]

In [5]:
#model_base_path = '../../data/models/cinc-wav2vecwav2vec-4s-cinc-wav2vec-cnn-cinc-wav2vec-cnn-dec19.log'
# Base path of the CinC wav2vec-cnn checkpoints; load_all_pcg_cinc_cnn_models
# appends "-{i}.pth" for i in 1..10.
model_base_path = '../../data/models/cinc-wav2vecwav2vec-4s-cinc-wav2vec-cnn-cinc-wav2vec-cnn-dec19.log'

# Overwrites the notebook globals `datasets`, `class_names` and `models`.
datasets = get_pcg_cinc_dataset(fs=4125)
class_names = next(iter(datasets.values())).classes
# The "wav2vec" tag passed here is not read by load_all_pcg_cinc_cnn_models,
# which always loads with is_wav2veccnn=True.
models = load_all_pcg_cinc_cnn_models(model_base_path, "wav2vec", class_names)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


16500


In [9]:
# Uses the `models` / `datasets` globals set by the previous cell; the plots
# are saved under "roc/cinc-pcg-cnn-4125_fragment" and "..._patient".
generate_roc_plots(models, datasets, "roc/cinc-pcg-cnn-4125")

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9915074309978769, 0.4002123142250531, 0.1746284501061571, 0.09394904458598727, 0.037154989384288746, 0.0031847133757961785, 0.0010615711252653928, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.39384288747346075, 0.11942675159235669, 0.08333333333333333, 0.038747346072186835, 0.0037154989384288748, 0.002653927813163482, 0.0005307855626326964, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.994692144373673, 0.351380042462845, 0.160828025477707, 0.09607218683651805, 0.03927813163481953, 0.004246284501061571, 0.0005307855626326964, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.37367303609341823, 0.17834394904458598, 0.09713375796178345, 0.034501061571125265, 0.0010615711252653928, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4166666666666667, 0.14012738853503184, 0.08545647558386411, 0.032908704883227176, 0.006369426751592357, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9952229299363057, 0.3619957537154989, 0.1618895966029724, 0.09182590233545647, 0.029723991507430998, 0.0010615711252653928, 0.0, 0.0005307855626326964, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.37154989384288745, 0.1751592356687898, 0.09925690021231423, 0.032908704883227176, 0.0031847133757961785, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.3880042462845011, 0.13853503184713375, 0.09288747346072186, 0.04617834394904458, 0.011146496815286623, 0.0005307855626326964, 0.0, 0.0005307855626326964, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9925690021231423, 0.46390658174097665, 0.18099787685774946, 0.09819532908704884, 0.025477707006369428, 0.0010615711252653928, 0.0005307855626326964, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 

  0%|                                                                                           | 0/120 [00:00…

  0%|                                                                                           | 0/120 [00:00…

fragment_fpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9920382165605095, 0.39118895966029726, 0.12154989384288747, 0.07908704883227176, 0.042993630573248405, 0.009023354564755838, 0.0010615711252653928, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], fragment_tpr=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1

## Create PaCMAP Plots

In [2]:
import numpy as np
from typing import Optional
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
import pacmap

import matplotlib.pyplot as plt

def plot_dimension_reduction(data: np.ndarray, labels: np.ndarray, transform, name: str, output_path: Optional[str] = None, normalise: bool = True, fit_embeddings: bool = True, original_data = None):
    """Project `data` with `transform` and scatter-plot the first two components.

    Points whose label equals 0 are drawn in blue as "Normal" and points whose
    label equals 1 are drawn in red as "Abnormal".

    Args:
        data: Feature matrix handed to `transform`.
        labels: Per-sample class labels (0 or 1); coerced to a NumPy array so
            that plain Python sequences are compared element-wise.
        transform: Reducer exposing `fit_transform` and `transform`
            (e.g. a `pacmap.PaCMAP` or `PCA` instance).
        name: Name of the technique, used in the plot title.
        output_path: Where to save the figure. When None the figure is shown
            instead.
        normalise: Rescale the embedding to [-1, 1] with a `MinMaxScaler`.
            Only applied when `fit_embeddings` is True; the transform-only
            path plots the raw output of `transform.transform`.
        fit_embeddings: Fit `transform` on `data` when True; otherwise reuse
            the already-fitted `transform`.
        original_data: Data the reducer was originally fitted on. Required
            only when reusing a fitted PaCMAP, which is called with it as
            `basis`.

    Returns:
        The tuple `(transform, data)`: the reducer together with the *input*
        `data` (not the embedding).

    Raises:
        ValueError: If a fitted PaCMAP is reused without `original_data`.
    """
    if fit_embeddings:
        X_transformed = transform.fit_transform(data)  # type: ignore
        if normalise:
            X_transformed = MinMaxScaler((-1, 1)).fit_transform(X_transformed)
    elif isinstance(transform, pacmap.PaCMAP):
        # PaCMAP needs the data it was fitted on as `basis`; other reducers
        # (e.g. PCA) do not, so only PaCMAP is required to supply it.
        if original_data is None:
            raise ValueError("Must provide the original data with the PaCMAP transform.")
        X_transformed = transform.transform(data, basis=original_data)
    else:
        X_transformed = transform.transform(data)

    # A plain list compared with `== 0` yields a single bool rather than a
    # mask, which would silently plot nothing; force element-wise comparison.
    labels = np.asarray(labels)

    plt.figure()
    for class_value, colour, legend_label in ((0, 'blue', 'Normal'), (1, 'red', 'Abnormal')):
        class_indices = np.where(labels == class_value)[0]
        plt.scatter(
            X_transformed[class_indices, 0],
            X_transformed[class_indices, 1],
            color=colour,
            label=legend_label,
            s=0.8,
        )  # type: ignore
    plt.legend()
    plt.title(f"{name} Visualisation of Features")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")

    if output_path is None:
        plt.show()
    else:
        plt.savefig(output_path)

    return transform, data

def plot_pacmap_features(data: np.ndarray, labels: np.ndarray, output_path: Optional[str] = None, normalise: bool = True, embedding_transform = None, original_data = None):
    """Plot a 2-component PaCMAP view of `data`, coloured by `labels`.

    When `embedding_transform` is None a new `pacmap.PaCMAP(n_components=2)`
    is created and fitted on `data`. Otherwise the supplied transform is
    reused without fitting, and `original_data` is forwarded to
    `plot_dimension_reduction` along with it.

    Returns:
        Whatever `plot_dimension_reduction` returns: the embedding transform
        and the data passed in.
    """
    needs_fit = embedding_transform is None  # only fit a freshly created PaCMAP
    reducer = pacmap.PaCMAP(n_components=2) if needs_fit else embedding_transform  # type: ignore

    fitted_reducer, basis_data = plot_dimension_reduction(
        data,
        labels,
        reducer,
        "PaCMAP",
        output_path=output_path,
        normalise=normalise,
        fit_embeddings=needs_fit,
        original_data=original_data,
    )
    return fitted_reducer, basis_data

In [None]:
from classifier.dataloaders import create_dataloaders
from classifier.testing import FineTunerFragmentTester, FineTunerPatientTester
import matplotlib.pyplot as plt

import matplotlib
matplotlib.use('TkAgg')  # or 'QtAgg' or 'MacOSX' depending on your system
# Interactive mode is only enabled by *calling* plt.ion(); the bare attribute
# reference `plt.ion` evaluates the function object and discards it.
plt.ion()

# Extract fragment-level embeddings and their labels from the fine-tuned model.
dataloader = create_dataloaders(datasets, aux_type)
features, labels = FineTunerFragmentTester(model.model_ft, dataloader).embeddings()

# Fit a fresh PaCMAP on the embeddings and save the scatter plot.
plot_pacmap_features(features, labels, output_path="pacmap/pcg-cinc-4125")