# Finetuning classifier output

## Notebook setup

In [1]:
#| code-fold: true
#| code-summary: "Click to see packages imported"
import os
import configparser
import random
import shutil
from pathlib import Path

import torch
import wfdb
import numpy as np
import dsail
from dsail.model.model_utils import get_model
from dsail.train import Trainer
from dsail.data import get_loss_weights_and_flags

In [2]:
#|include: false
# If the current working directory is the nbs/ folder, change to the project 
# root directory instead.

if Path.cwd().stem == "nbs":
    os.chdir(Path.cwd().parent)
print(f"The current working directory is {Path.cwd()}")

The current working directory is c:\Users\Shaun\source\Thesis\MisdiagnosisOfAthleteECG


In [3]:
#| code-fold: true
#| code-summary: "Click to see local packages imported"
from src.run_12ECG_classifier import load_12ECG_model, run_12ECG_classifier
from src.data.util import get_all_records, get_predicted_findings, diagnosis_codes, codes_to_label_vector
import src.data.norwegian as norwegian

In [4]:
#|include: false
# Import configuration settings, like location of data directory.
config = configparser.ConfigParser()
if not Path("config.ini").exists():
    print("WARNING: Please generate a config.ini file by running scripts/get_datasets.py")
else:
    config.read("config.ini")
    data_dir = Path((config["datasets"]["path"])).expanduser()
    print(f"Datasets are located at {data_dir.resolve()}")

Datasets are located at C:\Users\Shaun\source\Thesis\MisdiagnosisOfAthleteECG\data


## Scoring settings

We only care about labels that correspond to "borderline" athlete findings.

In [5]:
sinus_labels = [426177001, 426783006, 427084000, 427393009]
# sinus_labels = [426177001, 426783006]       # Bradycardia or Normal
rbbb_labels = [713427006, 713426002]        # Incomplete RBBB, Complete RBBB
# won't do t-wave inversion, because no output for lead number provided.
athlete_labels = sinus_labels + rbbb_labels

## Model setup

In [6]:
original_weights_dir = Path.cwd() / "checkpoints" / "original"
finetune_dir = Path.cwd() / "checkpoints" / "finetune_2"

config_dir = Path.cwd() / "config"
training_data_dir = data_dir / "challenge-2020" / "1.0.2" / "training"
target_data_dir = data_dir / "norwegian-athlete-ecg" / "1.0.0"

# Ensure output directory exists
if finetune_dir.exists():
    print(f"{finetune_dir} already exists. Are we overwriting an existing finetune?")
else:
    finetune_dir.mkdir()

In [7]:
# Load model config from disk
data_cfg = dsail.config.DataConfig(config_dir / "data.json")
preprocess_cfg = dsail.config.PreprocessConfig(config_dir / "preprocess.json")
model_cfg = dsail.config.ModelConfig(config_dir / "model.json")
run_cfg = dsail.config.RunConfig(config_dir / "run.json")

# Check if CUDA device available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
device

device(type='cuda')

In [9]:
# Borrowed from DSAIL_SNU

def set_seeds(seed):
    """ set random seeds """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [10]:
print('Loading 12ECG model...')
model = load_12ECG_model(original_weights_dir, config_dir)

Loading 12ECG model...


In [11]:
print(f"This model is an ensemble of {len(model[3])} networks")

This model is an ensemble of 10 networks


In [12]:
# For each classification output, there are weights for 256 inputs to adjust.
# Can we use partial backpropogation for this layer only?
model[3][0].linear.weight.shape

torch.Size([24, 256])

In [13]:
# Alternatively, is just adding bias enough?
model[3][0].linear.bias.shape

torch.Size([24])

We only want the linear layer at the output of each network to be trainable.

In [14]:
%%script echo skipping cell
# Only classifier output stem should be trainable
# https://stackoverflow.com/questions/62523912/freeze-certain-layers-of-an-existing-model-in-pytorch

# Apply to all 10 networks in ensemble model
for net in model[3]:
    # Freeze all parameters in model
    for param in net.parameters():
        param.requires_grad = False

    # Unfreeze linear layer from output stem (classifier output)
    for param in net.linear.parameters():
        param.requires_grad = True

skipping cell


## Finetuning

In [15]:
%%script echo skipping cell
#|include: false
# Deprecated, keeping the carcass in case I need to use Trainer class again.

from src.train import evaluate, train
from dsail.data import collate_into_block

loss_fn = torch.nn.L1Loss()

for entry in get_all_records(target_data_dir):
    # 12-lead ECG signals (input data), and header info (e.g. sampling frequency)
    record = wfdb.rdrecord(target_data_dir / entry)
    signals = record.p_signal.transpose()
    with open((target_data_dir / entry).with_suffix(".hea"), 'r') as f:
        header_data=f.readlines()

    # Actual labels from cardiologist
    comments_c = record.comments[1]
    findings_c = norwegian.extract_findings(comments_c)
    actual_findings = norwegian.classify_relevant_findings(findings_c)
    actual_labels = codes_to_label_vector(actual_findings, athlete_labels)
    actual_scores = np.array(actual_labels, dtype=float)

    # Run full model, get predictions
    current_label, current_score, classes = run_12ECG_classifier(signals, header_data, model)
    predicted_scores = np.zeros(len(athlete_labels))
    for i, code in enumerate(classes):
        if int(code) in athlete_labels:
            index = athlete_labels.index(int(code))
            predicted_scores[index] = current_score[i]
    
    # Run each ensemble model separately
    losses = []
    for net in model[3]:
        # Setup for using DSAIL_SNU Trainer object for eval
        data_cfg.data = signals
        data_cfg.header = header_data
        dataset_val = dsail.data.get_dataset_from_configs(data_cfg, preprocess_cfg)
        iterator_val = torch.utils.data.DataLoader(dataset_val, 1, collate_fn=dsail.data.collate_into_list)
        dataset_train = dsail.data.get_dataset_from_configs(data_cfg, preprocess_cfg)
        iterator_train = torch.utils.data.DataLoader(dataset_train, 1, collate_fn=dsail.data.collate_into_list)

        loss_weights_and_flags = get_loss_weights_and_flags(data_cfg, run_cfg)
        # trainer = Trainer(net, data_cfg, run_cfg.multilabel, loss_weights_and_flags)
        # trainer.set_device(device, data_parallel=False)
        # trainer.set_optim_scheduler(run_cfg, list(net.parameters()))
        optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01)

        # Evaluate model (forward pass)
        # If signal is too long, it'll get cut into multiple batches
        # for batch in iterator_val:
        #     trainer.evaluate(batch)
        # predicted_scores_net = np.zeros(len(athlete_labels))
        # for i, code in enumerate(classes):
        #     if int(code) in athlete_labels:
        #         index = athlete_labels.index(int(code))
        #         predicted_scores_net[index] = trainer.logger_eval.scalar_outputs[0][0][i]
        
        scalar_outputs = None
        for batch in iterator_train:
            # scalar_outputs = evaluate(net, batch, device, loss_weights_and_flags)
            scalar_outputs, ce_loss = train(net, batch, device, optimizer, classes, athlete_labels, actual_scores)
        # trainer.scheduler_step()
        predicted_scores_net = np.zeros(len(athlete_labels))
        for i, code in enumerate(classes):
            if int(code) in athlete_labels:
                index = athlete_labels.index(int(code))
                predicted_scores_net[index] = scalar_outputs[0][i]
                # predicted_scores_net[index] = trainer.logger_train.scalar_outputs[0][0][i]

        # Find error
        loss = loss_fn(torch.Tensor(predicted_scores_net), torch.Tensor(actual_scores))
        losses.append(loss)

        # # 
        # optimizer.zero_grad()
        # loss.backward()

            
    # Find error (for full model)
    loss = loss_fn(torch.Tensor(predicted_scores), torch.Tensor(actual_scores))
    print(f"loss (full model) = {loss}")
    print(f"loss (avg of 10) = {sum(losses) / len(losses)}")

    # Adjust output stem

skipping cell


In [16]:
# List of class labels returned by classifier
# current_label, current_score, classes = run_12ECG_classifier(signals, header_data, model)
classes = [
    '10370003',
    '111975006',
    '164889003',
    '164890007',
    '164909002',
    '164917005',
    '164934002',
    '164947007',
    '251146004',
    '270492004',
    '284470004',
    '39732003',
    '426177001',
    '426627000',
    '426783006',
    '427084000',
    '427172004',
    '427393009',
    '445118002',
    '47665007',
    '59931005',
    '698252002',
    '713426002',
    '713427006'
]

In [17]:
%%script echo skipping cell
# Evaluate original model

loss_fn = torch.nn.L1Loss()

for entry in get_all_records(target_data_dir):
    # 12-lead ECG signals (input data), and header info (e.g. sampling frequency)
    record = wfdb.rdrecord(target_data_dir / entry)
    signals = record.p_signal.transpose()
    with open((target_data_dir / entry).with_suffix(".hea"), 'r') as f:
        header_data=f.readlines()

    # Actual labels from cardiologist
    comments_c = record.comments[1]
    findings_c = norwegian.extract_findings(comments_c)
    actual_findings = norwegian.classify_relevant_findings(findings_c)
    actual_labels = codes_to_label_vector(actual_findings, athlete_labels)
    actual_scores = np.array(actual_labels, dtype=float)

    # Run full model, get predictions
    current_label, current_score, classes = run_12ECG_classifier(signals, header_data, model)
    predicted_scores = np.zeros(len(athlete_labels))
    for i, code in enumerate(classes):
        if int(code) in athlete_labels:
            index = athlete_labels.index(int(code))
            predicted_scores[index] = current_score[i]
    
    # Find error (for full model)
    loss = loss_fn(torch.Tensor(predicted_scores), torch.Tensor(actual_scores))
    print(f"loss (full model) = {loss}")

skipping cell


In [18]:
from src.train import evaluate, train

N_epochs = 100

# Train each of the 10 networks separately
# What does "fold" mean in this context?
for fold, net in enumerate(model[3]):
    print(f"Finetuning network {fold+1} / 10 over {N_epochs} epochs")
    set_seeds(2020)
    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.0001)
    avg_ce_losses = []
    for epoch in range(N_epochs):
        ce_losses = []
        for entry in get_all_records(target_data_dir):      # TODO: Split into train and validation subsets
            # Inputs
            # 12-lead ECG signals (input data), and header info (e.g. sampling frequency)
            record = wfdb.rdrecord(target_data_dir / entry)
            signals = record.p_signal.transpose()
            with open((target_data_dir / entry).with_suffix(".hea"), 'r') as f:
                header_data=f.readlines()

            # True labels from cardiologist
            comments_c = record.comments[1]
            findings_c = norwegian.extract_findings(comments_c)
            actual_findings = norwegian.classify_relevant_findings(findings_c)
            actual_labels = codes_to_label_vector(actual_findings, athlete_labels)
            actual_scores = np.array(actual_labels, dtype=float)

            # Load data for training
            data_cfg.data = signals
            data_cfg.header = header_data
            dataset_train = dsail.data.get_dataset_from_configs(data_cfg, preprocess_cfg)
            iterator_train = torch.utils.data.DataLoader(dataset_train, 1, collate_fn=dsail.data.collate_into_list)

            # Training pass
            # 1. Forward pass
            # 2. Cross-entropy loss
            # 3. Zero gradients
            # 4. Backpropagation on loss
            # 5. Gradient descent
            scalar_outputs = None
            for batch in iterator_train:
                # scalar_outputs = evaluate(net, batch, device, loss_weights_and_flags)
                scalar_outputs, ce_loss = train(net, batch, device, optimizer, classes, athlete_labels, actual_scores)
                ce_losses.append(ce_loss)

            # Select athletic subset of predicted scores for later analysis
            predicted_scores_net = np.zeros(len(athlete_labels))
            for i, code in enumerate(classes):
                if int(code) in athlete_labels:
                    index = athlete_labels.index(int(code))
                    predicted_scores_net[index] = scalar_outputs[0][i]
        
        avg_ce_losses.append(sum(ce_losses) / len(ce_losses))

        # Save checkpoint
        torch.save(net.state_dict(), finetune_dir / f"model_{fold}_epoch_{epoch}.sav")
    
    print(f"\tinit (epoch 0): \t{avg_ce_losses[0]}")
    print(f"\tbest (epoch {avg_ce_losses.index(min(avg_ce_losses))}): \t{min(avg_ce_losses)}")
    print(f"\timprovement: {avg_ce_losses[0] - min(avg_ce_losses)}")
    
    # Save best network weights
    shutil.copyfile(
        finetune_dir / f"model_{fold}_epoch_{avg_ce_losses.index(min(avg_ce_losses))}.sav", 
        finetune_dir / f"finalized_model_{fold}.sav"
    )
    
    # Copy original class thresholds
    shutil.copyfile(
        original_weights_dir / f"finalized_model_thresholds_{fold}.npy",
        finetune_dir / f"finalized_model_thresholds_{fold}.npy"
    )

Finetuning network 1 / 10 over 100 epochs
	init (epoch 0): 	1.6440274715423584
	best (epoch 1): 	1.6220239400863647
	improvement: 0.022003531455993652
Finetuning network 2 / 10 over 100 epochs
	init (epoch 0): 	1.6619373559951782
	best (epoch 60): 	1.6334235668182373
	improvement: 0.028513789176940918
Finetuning network 3 / 10 over 100 epochs
	init (epoch 0): 	1.6975668668746948
	best (epoch 81): 	1.6526380777359009
	improvement: 0.044928789138793945
Finetuning network 4 / 10 over 100 epochs
	init (epoch 0): 	1.6788705587387085
	best (epoch 88): 	1.6343796253204346
	improvement: 0.044490933418273926
Finetuning network 5 / 10 over 100 epochs
	init (epoch 0): 	1.6671278476715088
	best (epoch 9): 	1.623307466506958
	improvement: 0.04382038116455078
Finetuning network 6 / 10 over 100 epochs
	init (epoch 0): 	1.6786848306655884
	best (epoch 36): 	1.6391350030899048
	improvement: 0.039549827575683594
Finetuning network 7 / 10 over 100 epochs
	init (epoch 0): 	1.6817437410354614
	best (epoch 

## Evaluation

In [19]:
print("Loading finetuned model")
model_finetune = load_12ECG_model(finetune_dir, config_dir)

Loading finetuned model


In [20]:
# Evaluate finetuned model

# loss_fn = torch.nn.L1Loss()     # TODO: Should use cross-entropy loss for multiclass classifier
loss_fn = torch.nn.CrossEntropyLoss()

for entry in get_all_records(target_data_dir):
    # 12-lead ECG signals (input data), and header info (e.g. sampling frequency)
    record = wfdb.rdrecord(target_data_dir / entry)
    signals = record.p_signal.transpose()
    with open((target_data_dir / entry).with_suffix(".hea"), 'r') as f:
        header_data=f.readlines()

    # Actual labels from cardiologist
    comments_c = record.comments[1]
    findings_c = norwegian.extract_findings(comments_c)
    actual_findings = norwegian.classify_relevant_findings(findings_c)
    actual_labels = codes_to_label_vector(actual_findings, athlete_labels)
    actual_scores = np.array(actual_labels, dtype=float)

    # Run original model, get predictions
    current_label, current_score, classes = run_12ECG_classifier(signals, header_data, model)
    predicted_scores = np.zeros(len(athlete_labels))
    for i, code in enumerate(classes):
        if int(code) in athlete_labels:
            index = athlete_labels.index(int(code))
            predicted_scores[index] = current_score[i]
    
    # Run finetuned model, get predictions
    current_label, current_score, classes = run_12ECG_classifier(signals, header_data, model_finetune)
    predicted_scores_finetune = np.zeros(len(athlete_labels))
    for i, code in enumerate(classes):
        if int(code) in athlete_labels:
            index = athlete_labels.index(int(code))
            predicted_scores_finetune[index] = current_score[i]
    
    # Find error (for full model)
    loss_original = loss_fn(torch.Tensor(predicted_scores), torch.Tensor(actual_scores))
    loss_finetune = loss_fn(torch.Tensor(predicted_scores_finetune), torch.Tensor(actual_scores))
    print(f"loss (original) = {loss_original}\tloss (finetune) = {loss_finetune}")
    print(actual_scores)
    print(predicted_scores)
    print(predicted_scores_finetune)

loss (original) = 1.6488487720489502	loss (finetune) = 1.6584820747375488
[0. 0. 0. 1. 0. 0.]
[0.01571385 0.57213237 0.0145252  0.35925963 0.15850817 0.04178687]
[0.01418421 0.54637196 0.01346938 0.34263794 0.17458124 0.04220469]
loss (original) = 1.57094144821167	loss (finetune) = 1.5903799533843994
[0. 0. 0. 1. 0. 0.]
[0.04767237 0.74101483 0.01472449 0.48700721 0.01970278 0.02061614]
[0.04414863 0.73619656 0.01411685 0.46071363 0.02058608 0.02280122]
loss (original) = 1.8113987445831299	loss (finetune) = 1.8125780820846558
[0. 1. 0. 0. 0. 0.]
[0.01351097 0.15493916 0.02472729 0.07357612 0.55262998 0.11695744]
[0.01310635 0.1527068  0.02221194 0.06393385 0.55728036 0.11722222]
loss (original) = 1.268202781677246	loss (finetune) = 1.2776403427124023
[0. 1. 0. 0. 0. 0.]
[0.0723825  0.7378892  0.04833657 0.04445688 0.07290736 0.09258826]
[0.06857638 0.72607492 0.04551834 0.04085563 0.07762357 0.10403387]
loss (original) = 3.1200459003448486	loss (finetune) = 3.1249308586120605
[1. 1. 0.

In [21]:
torch.optim.SGD?

[1;31mInit signature:[0m
[0mtorch[0m[1;33m.[0m[0moptim[0m[1;33m.[0m[0mSGD[0m[1;33m([0m[1;33m
[0m    [0mparams[0m[1;33m:[0m [0mUnion[0m[1;33m[[0m[0mIterable[0m[1;33m[[0m[0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m][0m[1;33m,[0m [0mIterable[0m[1;33m[[0m[0mDict[0m[1;33m[[0m[0mstr[0m[1;33m,[0m [0mAny[0m[1;33m][0m[1;33m][0m[1;33m,[0m [0mIterable[0m[1;33m[[0m[0mTuple[0m[1;33m[[0m[0mstr[0m[1;33m,[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m][0m[1;33m][0m[1;33m][0m[1;33m,[0m[1;33m
[0m    [0mlr[0m[1;33m:[0m [0mUnion[0m[1;33m[[0m[0mfloat[0m[1;33m,[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m][0m [1;33m=[0m [1;36m0.001[0m[1;33m,[0m[1;33m
[0m    [0mmomentum[0m[1;33m:[0m [0mfloat[0m [1;33m=[0m [1;36m0[0m[1;33m,[0m[1;33m
[0m    [0mdampening[0m[1;33m:[0m [0mfloat[0m [1;33m=[0m [1;36m0[0m[1;33m,[0m[1;33m
[0m    [0mweight_decay[0m[1;33m:[0m [0mfloat[0m [1;33m=[0m 

In [22]:
print(np.finfo(float).eps)

2.220446049250313e-16
