# Reweighting source datasets

## Notebook setup

In [1]:
#| code-fold: true
#| code-summary: "Click to see packages imported"
import os
import configparser
import random
import shutil
from pathlib import Path

import torch
import wfdb
import numpy as np
import dsail
from dsail.model.model_utils import get_model
from dsail.train import Trainer
from dsail.data import get_loss_weights_and_flags

In [2]:
#|include: false
# If the notebook was started from the nbs/ folder, move up to the project
# root so that relative paths (config.ini, checkpoints/, config/) resolve.
# The directory's full name is compared rather than its stem: `.stem` drops a
# trailing suffix, so a folder such as "nbs.old" would wrongly match "nbs".
if Path.cwd().name == "nbs":
    os.chdir(Path.cwd().parent)
print(f"The current working directory is {Path.cwd()}")

The current working directory is c:\Users\Shaun\source\Thesis\MisdiagnosisOfAthleteECG


In [22]:
#| code-fold: true
#| code-summary: "Click to see local packages imported"
from src.run_12ECG_classifier import load_12ECG_model, run_12ECG_classifier
from src.data.util import get_all_records, get_predicted_findings, diagnosis_codes, codes_to_label_vector
from src.data.challenge2020 import extract_snomed_ct_codes_from_comment

In [4]:
#|include: false
# Read project settings from config.ini; the [datasets] section's "path" entry
# gives the data directory. When the file is absent only a warning is printed,
# and data_dir is left undefined.
config = configparser.ConfigParser()
if Path("config.ini").exists():
    config.read("config.ini")
    data_dir = Path(config["datasets"]["path"]).expanduser()
    print(f"Datasets are located at {data_dir.resolve()}")
else:
    print("WARNING: Please generate a config.ini file by running scripts/get_datasets.py")

Datasets are located at C:\Users\Shaun\source\Thesis\MisdiagnosisOfAthleteECG\data


## Young subset of training datasets

In [5]:
# Root of the PhysioNet Challenge 2020 (v1.0.2) training data, followed by the
# four source dataset folders inside it.
physionet_data_dir = data_dir.joinpath("challenge-2020", "1.0.2", "training")
georgia_dataset_dir = physionet_data_dir.joinpath("georgia")
ptbxl_dataset_dir = physionet_data_dir.joinpath("ptb-xl")
cpsc_dataset_dir = physionet_data_dir.joinpath("cpsc_2018")
cpscextra_dataset_dir = physionet_data_dir.joinpath("cpsc_2018_extra")

In [None]:
# Gather every record from the four source datasets.
source_records = (
    get_all_records(georgia_dataset_dir)
    + get_all_records(ptbxl_dataset_dir)
    + get_all_records(cpsc_dataset_dir)
    + get_all_records(cpscextra_dataset_dir)
)

# Records whose age (as written in the header) is below this value are "young".
YOUNG_AGE_LIMIT = 40

# This will take 3-4 minutes
young_subset = []   # Every record under a certain age
vampires = 0        # Vampires don't age: records whose age is missing or not a number
for entry in source_records:
    header = wfdb.rdheader(entry)
    # The age is taken from the first header comment, split on ": " (so it is
    # expected to look like "Age: 57"). A header without comments, or a first
    # comment without the separator, is counted as "age unknown" instead of
    # raising IndexError.
    fields = header.comments[0].split(': ') if header.comments else []
    age_str = fields[1] if len(fields) > 1 else ""
    # isdecimal() is used rather than isnumeric(): isnumeric() also accepts
    # characters such as fractions or superscripts that int() cannot parse.
    age = int(age_str) if age_str.isdecimal() else None
    if age is None:
        vampires += 1
    elif age < YOUNG_AGE_LIMIT:
        young_subset.append(entry)

In [7]:
# Total number of records gathered from the four source datasets
len(source_records)

41414

In [8]:
# Number of records whose header age is numeric and below 40
len(young_subset)

5252

In [12]:
# Number of records whose header age was not numeric
vampires

177

## Model setup

In [9]:
# Checkpoint locations: the directory the model is loaded from, and the
# directory the fine-tuned ("reweighted") weights are written to.
original_weights_dir = Path.cwd() / "checkpoints" / "original"
reweight_dir = Path.cwd() / "checkpoints" / "reweight_1"

config_dir = Path.cwd() / "config"
training_data_dir = data_dir / "challenge-2020" / "1.0.2" / "training"
target_data_dir = data_dir / "norwegian-athlete-ecg" / "1.0.0"

# Ensure output directory exists
if reweight_dir.exists():
    print(f"{reweight_dir} already exists. Are we overwriting an existing model?")
else:
    # parents=True: also create checkpoints/ if it is missing, rather than
    # failing with FileNotFoundError.
    reweight_dir.mkdir(parents=True)

c:\Users\Shaun\source\Thesis\MisdiagnosisOfAthleteECG\checkpoints\reweight_1 already exists. Are we overwriting an existing model?


In [10]:
# Build the four DSAIL configuration objects from the JSON files in config_dir.
data_cfg = dsail.config.DataConfig(config_dir / "data.json")
preprocess_cfg = dsail.config.PreprocessConfig(config_dir / "preprocess.json")
model_cfg = dsail.config.ModelConfig(config_dir / "model.json")
run_cfg = dsail.config.RunConfig(config_dir / "run.json")

# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [11]:
# Show which device was selected ("cuda" when available, otherwise "cpu")
device

device(type='cuda')

In [13]:
# Borrowed from DSAIL_SNU

def set_seeds(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and turns its benchmark mode off.
    Returns nothing.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: deterministic on, benchmark off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [14]:
# Load the model from the original checkpoint directory, passing the config directory as well.
print('Loading 12ECG model...')
model = load_12ECG_model(original_weights_dir, config_dir)

Loading 12ECG model...


In [15]:
# model[3] is the collection of individual networks; report how many there are.
print(f"This model is an ensemble of {len(model[3])} networks")

This model is an ensemble of 10 networks


## Scoring settings

We only care about sinus rhythm findings and right bundle branch block (RBBB) findings.

In [19]:
# SNOMED CT codes of the sinus findings we score, in order: sinus bradycardia,
# normal sinus rhythm, sinus tachycardia, sinus arrhythmia
# (meanings per the PhysioNet/CinC Challenge 2020 Dx mapping -- TODO confirm).
sinus_labels = [426177001, 426783006, 427084000, 427393009]
# sinus_labels = [426177001, 426783006]       # Bradycardia or Normal
# NOTE(review): in the Challenge 2020 Dx mapping 713427006 is *complete* RBBB and
# 713426002 is *incomplete* RBBB, so the order is complete first -- verify.
rbbb_labels = [713427006, 713426002]        # Complete RBBB, Incomplete RBBB
# T-wave inversion is not included, because no output for lead number is provided.
# The combined list fixes the column order of the label vectors built from it.
athlete_labels = sinus_labels + rbbb_labels

## Training loop

In [20]:
# Class labels (diagnosis codes as strings) returned by the classifier, i.e. the
# `classes` value from:
#   current_label, current_score, classes = run_12ECG_classifier(signals, header_data, model)
classes = [
    "10370003", "111975006", "164889003", "164890007",
    "164909002", "164917005", "164934002", "164947007",
    "251146004", "270492004", "284470004", "39732003",
    "426177001", "426627000", "426783006", "427084000",
    "427172004", "427393009", "445118002", "47665007",
    "59931005", "698252002", "713426002", "713427006",
]

In [35]:
# `evaluate` is only referenced by a commented-out line below; the import is kept.
from src.train import evaluate, train

N_epochs = 10
lr = 0.001

# Fine-tune each network of the ensemble separately, on the young subset only.
# For every network: run N_epochs passes, save a checkpoint after each epoch,
# then copy the checkpoint with the lowest average training loss to
# finalized_model_{fold}.sav together with the original class thresholds.
# TODO: What does "fold" mean in this context? Presumably each network of the
# ensemble was trained on a different cross-validation fold -- confirm.
for fold, net in enumerate(model[3]):
    # Report the real ensemble size instead of a hard-coded 10.
    print(f"Finetuning network {fold+1} / {len(model[3])} over {N_epochs} epochs")
    # Re-seed before every network so each one is fine-tuned from the same RNG state.
    set_seeds(2020)
    optimizer = torch.optim.SGD(params=net.parameters(), lr=lr)
    avg_ce_losses = []      # one average training loss per epoch
    for epoch in range(N_epochs):
        ce_losses = []      # non-NaN losses seen during this epoch, as plain floats
        for entry in young_subset:      # TODO: Split into train and validation subsets
            # Inputs
            # 12-lead ECG signals (input data), and header info (e.g. sampling frequency)
            record = wfdb.rdrecord(entry)
            signals = record.p_signal.transpose()
            with open(entry.with_suffix(".hea"), 'r') as f:
                header_data = f.readlines()

            # True labels from cardiologist, read from the third header comment.
            # A bare 'Dx:' comment means no findings were recorded.
            if record.comments[2] == 'Dx:':
                finding_codes = []
            else:
                finding_codes = extract_snomed_ct_codes_from_comment(record.comments[2])
            actual_labels = codes_to_label_vector(finding_codes, athlete_labels)

            # # Hack: If no sinus rhythm findings, assume normal sinus rhythm (426783006)
            # if sum(actual_labels) == 0:
            #     actual_labels[1] = 1
            # if sum(predicted_labels) == 0:
            #     predicted_labels[1] = 1

            actual_scores = np.array(actual_labels, dtype=float)

            # Load data for training. Note that the shared data_cfg object is
            # overwritten with this record's signals and header.
            data_cfg.data = signals
            data_cfg.header = header_data
            dataset_train = dsail.data.get_dataset_from_configs(data_cfg, preprocess_cfg)
            iterator_train = torch.utils.data.DataLoader(dataset_train, 1, collate_fn=dsail.data.collate_into_list)

            # Training pass (as described by the original author's notes; the
            # implementation lives in src.train.train):
            # 1. Forward pass
            # 2. Cross-entropy loss
            # 3. Zero gradients
            # 4. Backpropagation on loss
            # 5. Gradient descent
            scalar_outputs = None
            for batch in iterator_train:
                # scalar_outputs = evaluate(net, batch, device, loss_weights_and_flags)
                scalar_outputs, ce_loss = train(net, batch, device, optimizer, classes, athlete_labels, actual_scores)
                # Keep a plain Python float rather than the returned object, so the
                # list does not keep thousands of loss tensors (and anything they
                # may reference, e.g. device memory) alive for the whole epoch.
                loss_value = float(ce_loss)
                # hack to prevent nan poison
                if not np.isnan(loss_value):
                    ce_losses.append(loss_value)

            # Select athletic subset of predicted scores for later analysis.
            # NOTE: this array is overwritten for every record and is not used
            # further within this cell. It is skipped when the loader produced
            # no batch, because scalar_outputs is then still None.
            if scalar_outputs is not None:
                predicted_scores_net = np.zeros(len(athlete_labels))
                for i, code in enumerate(classes):
                    if int(code) in athlete_labels:
                        index = athlete_labels.index(int(code))
                        predicted_scores_net[index] = scalar_outputs[0][i]

        # Without at least one usable loss the average is undefined; fail with a
        # clear message instead of an opaque ZeroDivisionError.
        if not ce_losses:
            raise RuntimeError(
                f"Network {fold+1}, epoch {epoch}: no non-NaN training loss was "
                "recorded, so the average loss cannot be computed."
            )
        avg_ce_losses.append(sum(ce_losses) / len(ce_losses))

        # Save checkpoint
        torch.save(net.state_dict(), reweight_dir / f"model_{fold}_epoch_{epoch}.sav")

    # Pick the epoch with the lowest average training loss (first one on ties).
    # Note that "init" is the average loss measured *during* epoch 0, i.e. while
    # the weights were already being updated, not the loss before fine-tuning.
    best_loss = min(avg_ce_losses)
    best_epoch = avg_ce_losses.index(best_loss)
    print(f"\tinit (epoch 0): \t{avg_ce_losses[0]}")
    print(f"\tbest (epoch {best_epoch}): \t{best_loss}")
    print(f"\timprovement: {avg_ce_losses[0] - best_loss}")

    # Save best network weights
    shutil.copyfile(
        reweight_dir / f"model_{fold}_epoch_{best_epoch}.sav",
        reweight_dir / f"finalized_model_{fold}.sav"
    )

    # Copy original class thresholds
    shutil.copyfile(
        original_weights_dir / f"finalized_model_thresholds_{fold}.npy",
        reweight_dir / f"finalized_model_thresholds_{fold}.npy"
    )

Finetuning network 1 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.2214528322219849
	best (epoch 3): 	1.2202023267745972
	improvement: 0.0012505054473876953
Finetuning network 2 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.2327011823654175
	best (epoch 8): 	1.231736660003662
	improvement: 0.0009645223617553711
Finetuning network 3 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.2653496265411377
	best (epoch 5): 	1.261620044708252
	improvement: 0.003729581832885742
Finetuning network 4 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.227596402168274
	best (epoch 0): 	1.227596402168274
	improvement: 0.0
Finetuning network 5 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.219138264656067
	best (epoch 9): 	1.217279076576233
	improvement: 0.0018591880798339844
Finetuning network 6 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.2039555311203003
	best (epoch 4): 	1.2029688358306885
	improvement: 0.0009866952896118164
Finetuning network 7 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.2477622032165527
	best (epoch 9): 	1.2452261447906494
	improvement: 0.0025360584259033203
Finetuning network 8 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.221293330192566
	best (epoch 3): 	1.2192747592926025
	improvement: 0.002018570899963379
Finetuning network 9 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.237667441368103
	best (epoch 5): 	1.2363046407699585
	improvement: 0.0013628005981445312
Finetuning network 10 / 10 over 10 epochs


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  updated_mean = (last_sum + new_sum) / updated_samp

	init (epoch 0): 	1.208930492401123
	best (epoch 1): 	1.2071137428283691
	improvement: 0.0018167495727539062
