Author : _Sasika Amarasinghe_

## Clone Repository

In [1]:
!git clone https://github.com/IRS-UOM/vgg-16-cnn.git

fatal: destination path 'vgg-16-cnn' already exists and is not an empty directory.


In [2]:
!pip install -r vgg-16-cnn/requirements.txt



In [3]:
# !pip install wfdb==4.1.2

## Download Dataset


In [4]:
# prompt: Download this "https://drive.google.com/file/d/1ZBC4vUCM2znV7ngAsigWnMkNdxoH7bMi/view?usp=drive_link" zip as train_data.zip

!gdown --id 1ZBC4vUCM2znV7ngAsigWnMkNdxoH7bMi -O train_data.zip

Downloading...
From (original): https://drive.google.com/uc?id=1ZBC4vUCM2znV7ngAsigWnMkNdxoH7bMi
From (redirected): https://drive.google.com/uc?id=1ZBC4vUCM2znV7ngAsigWnMkNdxoH7bMi&confirm=t&uuid=7020b974-ff81-4347-91ef-d3118f18c32e
To: /content/train_data.zip
100% 367M/367M [00:03<00:00, 103MB/s]


In [5]:
# prompt: unzip "/content/train_data.zip" to a folder "train_data"

!unzip -j train_data.zip -d train_data_0

Archive:  train_data.zip
replace train_data_0/00928_hr.dat? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: train_data_0/00928_hr.dat  
  inflating: train_data_0/00182_hr-0.png  
  inflating: train_data_0/00343_hr.hea  
  inflating: train_data_0/00426_hr.hea  
  inflating: train_data_0/00541_hr-0.png  
  inflating: train_data_0/00213_hr.dat  
  inflating: train_data_0/00323_hr.hea  
  inflating: train_data_0/00205_hr.hea  
  inflating: train_data_0/00699_hr.dat  
  inflating: train_data_0/00296_hr.dat  
  inflating: train_data_0/00558_hr-0.png  
  inflating: train_data_0/00600_hr.dat  
  inflating: train_data_0/00888_hr.dat  
  inflating: train_data_0/00678_hr.dat  
  inflating: train_data_0/00779_hr-0.png  
  inflating: train_data_0/00873_hr-0.png  
  inflating: train_data_0/00683_hr.dat  
  inflating: train_data_0/00250_hr-0.png  
  inflating: train_data_0/00801_hr.hea  
  inflating: train_data_0/00868_hr.dat  
  inflating: train_data_0/00394_hr.hea  
  inflating: train_data_0/0

---
# Original code by competition organisers
---

---
### helper_code.py
---

In [6]:
#!/usr/bin/env python

# Do *not* edit this script.
# These are helper functions that you can use with your code.
# Check the example code to see how to use these functions in your code.

import numpy as np
import os
import sys
import wfdb

### Challenge variables
substring_labels = '# Labels:'
substring_images = '# Image:'

### Challenge data I/O functions

# Find the records in a folder and its subfolders.
def find_records(folder):
    records = set()
    for root, directories, files in os.walk(folder):
        for file in files:
            extension = os.path.splitext(file)[1]
            if extension == '.hea':
                record = os.path.relpath(os.path.join(root, file), folder)[:-4]
                records.add(record)
    records = sorted(records)
    return records

# Load the header for a record.
def load_header(record):
    header_file = get_header_file(record)
    header = load_text(header_file)
    return header

# Load the signals for a record.
def load_signals(record):
    signal_files = get_signal_files(record)
    path = os.path.split(record)[0]
    signal_files_exist = all(os.path.isfile(os.path.join(path, signal_file)) for signal_file in signal_files)
    if signal_files and signal_files_exist:
        signal, fields = wfdb.rdsamp(record)
    else:
        signal, fields = None, None
    return signal, fields

# Load the images for a record.
def load_images(record):
    from PIL import Image

    path = os.path.split(record)[0]
    image_files = get_image_files(record)

    images = list()
    for image_file in image_files:
        image_file_path = os.path.join(path, image_file)
        if os.path.isfile(image_file_path):
            image = Image.open(image_file_path)
            images.append(image)

    return images

# Load the labels for a record.
def load_labels(record):
    header = load_header(record)
    labels = get_labels_from_header(header)
    return labels

# Save the header for a record.
def save_header(record, header):
    header_file = get_header_file(record)
    save_text(header_file, header)

# Save the signals for a record.
def save_signals(record, signal, comments=list()):
    header = load_header(record)
    path, record = os.path.split(record)
    sampling_frequency = get_sampling_frequency(header)
    signal_formats = get_signal_formats(header)
    adc_gains = get_adc_gains(header)
    baselines = get_baselines(header)
    signal_units = get_signal_units(header)
    signal_names = get_signal_names(header)
    comments = [comment.replace('#', '').strip() for comment in comments]

    wfdb.wrsamp(record, fs=sampling_frequency, units=signal_units, sig_name=signal_names, \
                p_signal=signal, fmt=signal_formats, adc_gain=adc_gains, baseline=baselines, comments=comments,
                write_dir=path)

# Save the labels for a record.
def save_labels(record, labels):
    header_file = get_header_file(record)
    header = load_text(header_file)
    header += substring_labels + ' ' + ', '.join(labels) + '\n'
    save_text(header_file, header)
    return header

### Helper Challenge functions

# Load a text file as a string.
def load_text(filename):
    with open(filename, 'r') as f:
        string = f.read()
    return string

# Save a string as a text file.
def save_text(filename, string):
    with open(filename, 'w') as f:
        f.write(string)

# Get a variable from a string.
def get_variable(string, variable_name):
    variable = ''
    has_variable = False
    for l in string.split('\n'):
        if l.startswith(variable_name):
            variable = l[len(variable_name):].strip()
            has_variable = True
    return variable, has_variable

# Get variables from a string.
def get_variables(string, variable_name, sep=','):
    variables = list()
    has_variable = False
    for l in string.split('\n'):
        if l.startswith(variable_name):
            variables += [variable.strip() for variable in l[len(variable_name):].strip().split(sep)]
            has_variable = True
    return variables, has_variable

# Get the signal files from a header or a similar string.
def get_signal_files_from_header(string):
    signal_files = list()
    for i, l in enumerate(string.split('\n')):
        arrs = [arr.strip() for arr in l.split(' ')]
        if i==0 and not l.startswith('#'):
            num_channels = int(arrs[1])
        elif i<=num_channels and not l.startswith('#'):
            signal_file = arrs[0]
            if signal_file not in signal_files:
                signal_files.append(signal_file)
        else:
            break
    return signal_files

# Get the image files from a header or a similar string.
def get_image_files_from_header(string):
    images, has_image = get_variables(string, substring_images)
    if not has_image:
        raise Exception('No images available: did you forget to generate or include the images?')
    return images

# Get the labels from a header or a similar string.
def get_labels_from_header(string):
    labels, has_labels = get_variables(string, substring_labels)
    if not has_labels:
        raise Exception('No labels available: are you trying to load the labels from the held-out data, or did you forget to prepare the data to include the labels?')
    return labels

# Get the header file for a record.
def get_header_file(record):
    if not record.endswith('.hea'):
        header_file = record + '.hea'
    else:
        header_file = record
    return header_file

# Get the signal files for a record.
def get_signal_files(record):
    header_file = get_header_file(record)
    header = load_text(header_file)
    signal_files = get_signal_files_from_header(header)
    return signal_files

# Get the image files for a record.
def get_image_files(record):
    header_file = get_header_file(record)
    header = load_text(header_file)
    image_files = get_image_files_from_header(header)
    return image_files

### WFDB functions

# Get the record name from a header file.
def get_record_name(string):
    value = string.split('\n')[0].split(' ')[0].split('/')[0].strip()
    return value

# Get the number of signals from a header file.
def get_num_signals(string):
    value = string.split('\n')[0].split(' ')[1].strip()
    if is_integer(value):
        value = int(value)
    else:
        value = None
    return value

# Get the sampling frequency from a header file.
def get_sampling_frequency(string):
    value = string.split('\n')[0].split(' ')[2].split('/')[0].strip()
    if is_number(value):
        value = float(value)
    else:
        value = None
    return value

# Get the number of samples from a header file.
def get_num_samples(string):
    value = string.split('\n')[0].split(' ')[3].strip()
    if is_integer(value):
        value = int(value)
    else:
        value = None
    return value

# Get the signal formats from a header file.
def get_signal_formats(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[1]
            if 'x' in field:
                field = field.split('x')[0]
            if ':' in field:
                field = field.split(':')[0]
            if '+' in field:
                field = field.split('+')[0]
            value = field
            values.append(value)
    return values

# Get the ADC gains from a header file.
def get_adc_gains(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[2]
            if '/' in field:
                field = field.split('/')[0]
            if '(' in field and ')' in field:
                field = field.split('(')[0]
            value = float(field)
            values.append(value)
    return values

# Get the baselines from a header file.
def get_baselines(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[2]
            if '/' in field:
                field = field.split('/')[0]
            if '(' in field and ')' in field:
                field = field.split('(')[1].split(')')[0]
            else:
                field = get_adc_zeros(string)[i-1]
            value = int(field)
            values.append(value)
    return values

# Get the signal units from a header file.
def get_signal_units(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[2]
            if '/' in field:
                value = field.split('/')[1]
            else:
                value = 'mV'
            values.append(value)
    return values

# Get the ADC resolutions from a header file.
def get_adc_resolutions(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[3]
            value = int(field)
            values.append(value)
    return values

# Get the ADC zeros from a header file.
def get_adc_zeros(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[4]
            value = int(field)
            values.append(value)
    return values

# Get the initial values of a signal from a header file.
def get_initial_values(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[5]
            value = int(field)
            values.append(value)
    return values

# Get the checksums of a signal from a header file.
def get_checksums(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[6]
            value = int(field)
            values.append(value)
    return values

# Get the block sizes of a signal from a header file.
def get_block_sizes(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            field = l.split(' ')[7]
            value = int(field)
            values.append(value)
    return values

# Get the signal names from a header file.
def get_signal_names(string):
    num_signals = get_num_signals(string)
    values = list()
    for i, l in enumerate(string.split('\n')):
        if 1 <= i <= num_signals:
            value = l.split(' ')[8]
            values.append(value)
    return values

### Evaluation functions

# Construct the binary one-vs-rest confusion matrices, where the columns are the expert labels and the rows are the classifier
# for the given classes.
def compute_one_vs_rest_confusion_matrix(labels, outputs, classes):
    assert np.shape(labels) == np.shape(outputs)

    num_instances = len(labels)
    num_classes = len(classes)

    A = np.zeros((num_classes, 2, 2))
    for i in range(num_instances):
        for j in range(num_classes):
            if labels[i, j] == 1 and outputs[i, j] == 1: # TP
                A[j, 0, 0] += 1
            elif labels[i, j] == 0 and outputs[i, j] == 1: # FP
                A[j, 0, 1] += 1
            elif labels[i, j] == 1 and outputs[i, j] == 0: # FN
                A[j, 1, 0] += 1
            elif labels[i, j] == 0 and outputs[i, j] == 0: # TN
                A[j, 1, 1] += 1

    return A

# Compute macro F-measure.
def compute_f_measure(labels, outputs):
    # Compute confusion matrix.
    classes = sorted(set.union(*map(set, labels)))
    labels = compute_one_hot_encoding(labels, classes)
    outputs = compute_one_hot_encoding(outputs, classes)
    A = compute_one_vs_rest_confusion_matrix(labels, outputs, classes)

    num_classes = len(classes)
    per_class_f_measure = np.zeros(num_classes)
    for k in range(num_classes):
        tp, fp, fn, tn = A[k, 0, 0], A[k, 0, 1], A[k, 1, 0], A[k, 1, 1]
        if 2 * tp + fp + fn > 0:
            per_class_f_measure[k] = float(2 * tp) / float(2 * tp + fp + fn)
        else:
            per_class_f_measure[k] = float('nan')

    if np.any(np.isfinite(per_class_f_measure)):
        macro_f_measure = np.nanmean(per_class_f_measure)
    else:
        macro_f_measure = float('nan')

    return macro_f_measure, per_class_f_measure, classes

# Reorder channels in signal.
def reorder_signal(input_signal, input_channels, output_channels):
    # Do not allow repeated channels with potentially different values in a signal.
    assert(len(set(input_channels)) == len(input_channels))
    assert(len(set(output_channels)) == len(output_channels))

    if input_channels == output_channels:
        output_signal = input_signal
    else:
        input_channels = [channel.strip().casefold() for channel in input_channels]
        output_channels = [channel.strip().casefold() for channel in output_channels]

        input_signal = np.asarray(input_signal)
        num_samples = np.shape(input_signal)[0]
        num_channels = len(output_channels)
        data_type = input_signal.dtype
        output_signal = np.zeros((num_samples, num_channels), dtype=data_type)

        for i, output_channel in enumerate(output_channels):
            for j, input_channel in enumerate(input_channels):
                if input_channel == output_channel:
                    output_signal[:, i] = input_signal[:, j]

    return output_signal

# Pad or truncate signal.
def trim_signal(input_signal, num_samples_trimmed):
    input_signal = np.asarray(input_signal)
    num_samples, num_channels = np.shape(input_signal)
    data_type = input_signal.dtype

    if num_samples == num_samples_trimmed:
        output_signal = input_signal
    else:
        output_signal = np.zeros((num_samples_trimmed, num_channels), dtype=data_type)
        if num_samples < num_samples_trimmed: # Zero-pad the signals.
            output_signal[:num_samples, :] = input_signal
        else: # Truncate the signals.
            output_signal = input_signal[:num_samples_trimmed, :]

    return output_signal

# Compute SNR.
def compute_snr(label_signal, output_signal):
    label_signal = np.asarray(label_signal)
    output_signal = np.asarray(output_signal)

    assert(label_signal.ndim == output_signal.ndim == 1)
    assert(np.size(label_signal) == np.size(output_signal))

    idx_finite_signal = np.isfinite(label_signal)
    label_signal = label_signal[idx_finite_signal]
    output_signal = output_signal[idx_finite_signal]

    idx_nan_signal = np.isnan(output_signal)
    output_signal[idx_nan_signal] = 0

    noise_signal = output_signal - label_signal

    x = np.sum(label_signal**2)
    y = np.sum(noise_signal**2)

    if x > 0 and y > 0:
        snr = 10 * np.log10(x / y)
    elif x > 0 and y == 0:
        snr = float('inf')
    else:
        snr = float('nan')

    return snr

# Compute the mean signal power to median noise power metric.
def compute_snr_median(label_signal, output_signal):
    label_signal = np.asarray(label_signal)
    output_signal = np.asarray(output_signal)

    assert(label_signal.ndim == output_signal.ndim == 1)
    assert(np.size(label_signal) == np.size(output_signal))

    idx_finite_signal = np.isfinite(label_signal)
    label_signal = label_signal[idx_finite_signal]
    output_signal = output_signal[idx_finite_signal]

    idx_nan_signal = np.isnan(output_signal)
    output_signal[idx_nan_signal] = 0

    noise_signal = output_signal - label_signal

    x = np.mean(label_signal**2)
    y = np.median(noise_signal**2)

    if y > 0:
        snr = 10 * np.log10(x / y)
    else:
        snr = float('inf')

    return snr

# Compute a metric inspired by the Kolmogorov-Smirnov test statistic.
def compute_ks_metric(label_signal, output_signal):
    label_signal = np.asarray(label_signal)
    output_signal = np.asarray(output_signal)

    assert(label_signal.ndim == output_signal.ndim == 1)
    assert(np.size(label_signal) == np.size(output_signal))

    idx_finite_signal = np.isfinite(label_signal)
    label_signal = label_signal[idx_finite_signal]
    output_signal = output_signal[idx_finite_signal]

    idx_nan_signal = np.isnan(output_signal)
    output_signal[idx_nan_signal] = 0

    label_signal_cdf = np.cumsum(np.abs(label_signal))
    output_signal_cdf = np.cumsum(np.abs(output_signal))

    if label_signal_cdf[-1] > 0:
        label_signal_cdf = label_signal_cdf / label_signal_cdf[-1]
    if output_signal_cdf[-1] > 0:
        output_signal_cdf = output_signal_cdf / output_signal_cdf[-1]

    goodness_of_fit = 1.0 - np.max(np.abs(label_signal_cdf - output_signal_cdf))

    return goodness_of_fit

# Compute the adaptive signed correlation index (ASCI) metric.
def compute_asci_metric(label_signal, output_signal, beta=0.05):
    label_signal = np.asarray(label_signal)
    output_signal = np.asarray(output_signal)

    assert(label_signal.ndim == output_signal.ndim == 1)
    assert(np.size(label_signal) == np.size(output_signal))

    idx_finite_signal = np.isfinite(label_signal)
    label_signal = label_signal[idx_finite_signal]
    output_signal = output_signal[idx_finite_signal]

    idx_nan_signal = np.isnan(output_signal)
    output_signal[idx_nan_signal] = 0

    if beta <= 0 or beta > 1:
        raise ValueError('The beta value should be greater than 0 and less than or equal to 1.')

    threshold = beta * np.std(label_signal)

    noise_signal = np.abs(label_signal - output_signal)

    discrete_noise = np.zeros_like(noise_signal)
    discrete_noise[noise_signal <= threshold] = 1
    discrete_noise[noise_signal > threshold] = -1

    asci = np.mean(discrete_noise)

    return asci

# Compute a weighted absolute difference metric.
def compute_weighted_absolute_difference(label_signal, output_signal, sampling_frequency):
    label_signal = np.asarray(label_signal)
    output_signal = np.asarray(output_signal)

    assert(label_signal.ndim == output_signal.ndim == 1)
    assert(np.size(label_signal) == np.size(output_signal))

    idx_finite_signal = np.isfinite(label_signal)
    label_signal = label_signal[idx_finite_signal]
    output_signal = output_signal[idx_finite_signal]

    idx_nan_signal = np.isnan(output_signal)
    output_signal[idx_nan_signal] = 0

    from scipy.signal import filtfilt

    m = round(0.1 * sampling_frequency)
    w = filtfilt(np.ones(m), m, label_signal, method='gust')
    w = 1 - 0.5/np.max(w) * w
    n = np.sum(w)

    weighted_absolute_difference_metric = np.sum(np.abs(label_signal-output_signal) * w)/n

    return weighted_absolute_difference_metric

### Other helper functions

# Check if a variable is a number or represents a number.
def is_number(x):
    try:
        float(x)
        return True
    except (ValueError, TypeError):
        return False

# Check if a variable is an integer or represents an integer.
def is_integer(x):
    if is_number(x):
        return float(x).is_integer()
    else:
        return False

# Check if a variable is a finite number or represents a finite number.
def is_finite_number(x):
    if is_number(x):
        return np.isfinite(float(x))
    else:
        return False

# Check if a variable is a NaN (not a number) or represents a NaN.
def is_nan(x):
    if is_number(x):
        return np.isnan(float(x))
    else:
        return False

# Cast a value to an integer if an integer, a float if a non-integer float, and an unknown value otherwise.
def cast_int_float_unknown(x):
    if is_integer(x):
        x = int(x)
    elif is_finite_number(x):
        x = float(x)
    elif is_number(x):
        x = 'Unknown'
    else:
        raise NotImplementedError(f'Unable to cast {x}.')
    return x

# Construct the one-hot encoding of data for the given classes.
def compute_one_hot_encoding(data, classes):
    num_instances = len(data)
    num_classes = len(classes)

    one_hot_encoding = np.zeros((num_instances, num_classes), dtype=np.bool_)
    unencoded_data = list()
    for i, x in enumerate(data):
        for y in x:
            for j, z in enumerate(classes):
                if (y == z) or (is_nan(y) and is_nan(z)):
                    one_hot_encoding[i, j] = 1

    return one_hot_encoding

---
### prepare_image_data.py
---

In [7]:
#!/usr/bin/env python

# Load libraries.
import argparse
import json
import os
import os.path
import shutil
import sys
from collections import defaultdict

# from helper_code import *

# Parse arguments.
def get_parser():
    description = 'Prepare the ECG image data from ECG-Image-Kit for the Challenge.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-i', '--input_folder', type=str, required=True)
    parser.add_argument('-o', '--output_folder', type=str, required=True)
    return parser

# Find files.
def find_files(folder, extensions, remove_extension=False, sort=False):
    selected_files = set()
    for root, directories, files in os.walk(folder):
        for file in files:
            extension = os.path.splitext(file)[1]
            if extension in extensions:
                file = os.path.relpath(os.path.join(root, file), folder)
                if remove_extension:
                    file = os.path.splitext(file)[0]
                selected_files.add(file)
    if sort:
        selected_files = sorted(selected_files)
    return selected_files

# Run script.
def run(args):
    # Define variables.
    image_file_types = ['.png', '.jpg', '.jpeg']

    # Find the header files.
    records = find_records(args.input_folder)

    # Find the image files.
    image_files = find_files(args.input_folder, image_file_types)
    record_to_image_files = defaultdict(set)
    for image_file in image_files:
        root, ext = os.path.splitext(image_file)
        record = '-'.join(root.split('-')[:-1])
        basename = os.path.basename(image_file)
        record_to_image_files[record].add(basename)

    # Update the header files and copy signal files.
    for record in records:
        record_path, record_basename = os.path.split(record)
        record_image_files = record_to_image_files[record]

        # Sort the images numerically if numerical and alphanumerically otherwise.
        record_suffixes = [os.path.splitext(image_file)[0].split('-')[-1] for image_file in record_image_files]
        if all(is_number(suffix) for suffix in record_suffixes):
            record_image_files = sorted(record_image_files, key=lambda image_file: float(os.path.splitext(image_file)[0].split('-')[-1]))
        else:
            record_image_files = sorted(record_image_files)

        # Update the header files.
        input_header_file = os.path.join(args.input_folder, record + '.hea')
        output_header_file = os.path.join(args.output_folder, record + '.hea')

        input_header = load_text(input_header_file)
        output_header = ''
        for l in input_header.split('\n'):
            if not l.startswith(substring_images) and l:
                output_header += l + '\n'

        record_image_string = ', '.join(record_image_files)
        output_header += f'{substring_images} {record_image_string}\n'

        input_path = os.path.join(args.input_folder, record_path)
        output_path = os.path.join(args.output_folder, record_path)

        os.makedirs(output_path, exist_ok=True)

        with open(output_header_file, 'w') as f:
            f.write(output_header)

        # Copy the signal and image files if available.
        if os.path.normpath(args.input_folder) != os.path.normpath(args.output_folder):
            relative_path = os.path.split(record)[0]

            signal_files = get_signal_files(output_header_file)
            relative_path = os.path.split(record)[0]
            for signal_file in signal_files:
                input_signal_file = os.path.join(args.input_folder, relative_path, signal_file)
                output_signal_file = os.path.join(args.output_folder, relative_path, signal_file)
                if os.path.isfile(input_signal_file):
                    shutil.copy2(input_signal_file, output_signal_file)

            image_files = get_image_files(output_header_file)
            for image_file in image_files:
                input_image_file = os.path.join(args.input_folder, relative_path, image_file)
                output_image_file = os.path.join(args.output_folder, relative_path, image_file)
                if os.path.isfile(input_image_file):
                    shutil.copy2(input_image_file, output_image_file)

# if __name__=='__main__':
#     run(get_parser().parse_args(sys.argv[1:]))

---
### prepare_ptbxl_data.py
---

In [8]:
#!/usr/bin/env python

# Load libraries.
import argparse
import ast
import numpy as np
import os
import os.path
import pandas as pd
import shutil
import sys

# from helper_code import *

# Parse arguments.
def get_parser():
    description = 'Prepare the PTB-XL database for use in the Challenge.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-i', '--input_folder', type=str, required=True)
    parser.add_argument('-pd', '--ptbxl_database_file', type=str, required=True) # ptbxl_database.csv
    parser.add_argument('-pm', '--ptbxl_mapping_file', type=str, required=True) # scp_statements.csv
    parser.add_argument('-sd', '--sl_database_file', type=str, required=True) # 12sl_statements.csv
    parser.add_argument('-sm', '--sl_mapping_file', type=str, required=True) # 12slv23ToSNOMED.csv
    parser.add_argument('-o', '--output_folder', type=str, required=True)
    return parser

# Run script.
def run(args):
    # Assign each class to a superclass; these commands were adapted from the PhysioNet project documentation.
    df_ptbxl_mapping = pd.read_csv(args.ptbxl_mapping_file, index_col=0)
    subclass_to_superclass = dict()
    for i, row in df_ptbxl_mapping.iterrows():
        if row['diagnostic'] == 1:
            subclass_to_superclass[i] = row['diagnostic_class']

    def assign_superclass(subclasses):
        superclasses = list()
        for subclass in subclasses:
            if subclass in subclass_to_superclass:
                superclass = subclass_to_superclass[subclass]
                if superclass not in superclasses:
                    superclasses.append(superclass)
        return superclasses

    # Load the PTB-XL labels.
    df_ptbxl_database = pd.read_csv(args.ptbxl_database_file, index_col='ecg_id')
    df_ptbxl_database.scp_codes = df_ptbxl_database.scp_codes.apply(lambda x: ast.literal_eval(x))

    # Map the PTB-XL classes to superclasses.
    df_ptbxl_database['diagnostic_superclass'] = df_ptbxl_database.scp_codes.apply(assign_superclass)

    # Load the 12SL labels.
    df_sl_database = pd.read_csv(args.sl_database_file, index_col='ecg_id')

    # Map the 12SL classes to the PTB-XL classes for the following acute myocardial infarction (MI) classes; PTB-XL does not include
    # a separate acute MI class.
    df_sl_mapping = pd.read_csv(args.sl_mapping_file, index_col='StatementNumber')

    acute_mi_statements = set([821, 822, 823, 827, 829, 902, 903, 904, 963, 964, 965, 966, 967, 968])
    acute_mi_classes = set()
    for statement in acute_mi_statements:
        if statement in df_sl_mapping.index:
            acute_mi_classes.add(df_sl_mapping.loc[statement]['Acronym'])

    # Identify the header files.
    records = find_records(args.input_folder)

    # Update the header files to include demographics data and labels and copy the signal files unchanged.
    for record in records:

        # Extract the demographics data.
        record_path, record_basename = os.path.split(record)
        ecg_id = int(record_basename.split('_')[0])
        row = df_ptbxl_database.loc[ecg_id]

        recording_date_string = row['recording_date']
        date_string, time_string = recording_date_string.split(' ')
        yyyy, mm, dd = date_string.split('-')
        date_string = f'{dd}/{mm}/{yyyy}'

        age = row['age']
        age = cast_int_float_unknown(age)

        sex = row['sex']
        if sex == 0:
            sex = 'Male'
        elif sex == 1:
            sex = 'Female'
        else:
            sex = 'Unknown'

        height = row['height']
        height = cast_int_float_unknown(height)

        weight = row['weight']
        weight = cast_int_float_unknown(weight)

        scp_code_dict = row['scp_codes']
        scp_codes = [scp_code for scp_code in scp_code_dict if scp_code_dict[scp_code] >= 0]
        superclasses = row['diagnostic_superclass']

        if ecg_id in df_sl_database.index:
            sl_codes = df_sl_database.loc[ecg_id]['statements']
        else:
            sl_codes = list()

        labels = list()
        if 'NORM' in superclasses:
            labels.append('NORM')
        if any(c in sl_codes for c in acute_mi_classes):
            labels.append('Acute MI')
        if 'MI' in superclasses and not any(c in sl_codes for c in acute_mi_classes):
            labels.append('Old MI')
        if 'STTC' in superclasses:
            labels.append('STTC')
        if 'CD' in superclasses:
            labels.append('CD')
        if 'HYP' in superclasses:
            labels.append('HYP')
        if 'PAC' in scp_codes:
            labels.append('PAC')
        if 'PVC' in scp_codes:
            labels.append('PVC')
        if 'AFIB' in scp_codes or 'AFLT' in scp_codes:
            labels.append('AFIB/AFL')
        if 'STACH' in scp_codes or 'SVTAC' in scp_codes or 'PSVT' in scp_codes:
            labels.append('TACHY')
        if 'SBRAD' in scp_codes:
            labels.append('BRADY')
        labels = ', '.join(labels)

        # Update the header file.
        input_header_file = os.path.join(args.input_folder, record + '.hea')
        output_header_file = os.path.join(args.output_folder, record + '.hea')

        input_path = os.path.join(args.input_folder, record_path)
        output_path = os.path.join(args.output_folder, record_path)

        os.makedirs(output_path, exist_ok=True)

        with open(input_header_file, 'r') as f:
            input_header = f.read()

        lines = input_header.split('\n')
        record_line = ' '.join(lines[0].strip().split(' ')[:4]) + '\n'
        signal_lines = '\n'.join(l.strip() for l in lines[1:] \
            if l.strip() and not l.startswith('#')) + '\n'
        comment_lines = '\n'.join(l.strip() for l in lines[1:] \
            if l.startswith('#') and not any((l.startswith(x) for x in ('# Age:', '# Sex:', '# Height:', '# Weight:', f'{substring_labels}')))) + '\n'

        record_line = record_line.strip() + f' {time_string} {date_string} ' + '\n'
        signal_lines = signal_lines.strip() + '\n'
        comment_lines = comment_lines.strip() + f'# Age: {age}\n# Sex: {sex}\n# Height: {height}\n# Weight: {weight}\n{substring_labels} {labels}\n'

        output_header = record_line + signal_lines + comment_lines

        with open(output_header_file, 'w') as f:
            f.write(output_header)

        # Copy the signal files if the input and output folders are different.
        if os.path.normpath(args.input_folder) != os.path.normpath(args.output_folder):
            relative_path = os.path.split(record)[0]

            signal_files = get_signal_files(input_header_file)
            for signal_file in signal_files:
                input_signal_file = os.path.join(args.input_folder, relative_path, signal_file)
                output_signal_file = os.path.join(args.output_folder, relative_path, signal_file)
                if os.path.isfile(input_signal_file):
                    shutil.copy2(input_signal_file, output_signal_file)

# if __name__=='__main__':
    # run(get_parser().parse_args(sys.argv[1:]))

---
### remove_hidden_data.py
---

In [9]:
#!/usr/bin/env python

# Load libraries.
import argparse
import os
import os.path
import shutil
import sys

# from helper_code import *

# Parse arguments.
def get_parser():
    description = 'Remove hidden data.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-i', '--input_folder', type=str, required=True)
    parser.add_argument('-w', '--include_waveforms', action='store_true')
    parser.add_argument('-l', '--include_labels', action='store_true')
    parser.add_argument('-m', '--include_images', action='store_true')
    parser.add_argument('-o', '--output_folder', type=str, required=True)
    return parser

# Run script.
def run(args):
    # Identify header files.
    records = find_records(args.input_folder)

    # Update header files and copy signal files.
    for record in records:
        record_path, record_basename = os.path.split(record)

        input_header_file = os.path.join(args.input_folder, record + '.hea')
        output_header_file = os.path.join(args.output_folder, record + '.hea')

        input_header = load_text(input_header_file)
        output_header = ''

        num_signals = get_num_signals(input_header)
        for i, l in enumerate(input_header.split('\n')):
            arrs = l.split(' ')
            if i == 0:
                output_header += ' '.join(arrs[:4]) + '\n'
            elif 1 <= i <= num_signals:
                output_header += ' '.join(arrs[:5] + ['', '', ''] + [arrs[8]]) + '\n'
            elif l.startswith(substring_labels):
                if args.include_labels:
                    output_header += l + '\n'
            elif l.startswith(substring_images):
                if args.include_images:
                    output_header += l + '\n'

        input_path = os.path.join(args.input_folder, record_path)
        output_path = os.path.join(args.output_folder, record_path)

        os.makedirs(output_path, exist_ok=True)

        with open(output_header_file, 'w') as f:
            f.write(output_header)

        relative_path = os.path.split(record)[0]

        if args.include_waveforms and os.path.normpath(args.input_folder) != os.path.normpath(args.output_folder):
            signal_files = get_signal_files(input_header_file)
            for signal_file in signal_files:
                input_signal_file = os.path.join(args.input_folder, relative_path, signal_file)
                output_signal_file = os.path.join(args.output_folder, relative_path, signal_file)
                if os.path.isfile(input_signal_file):
                    shutil.copy2(input_signal_file, output_signal_file)
        elif not args.include_waveforms and os.path.normpath(args.input_folder) == os.path.normpath(args.output_folder):
            signal_files = get_signal_files(input_header_file)
            for signal_file in signal_files:
                input_signal_file = os.path.join(args.input_folder, relative_path, signal_file)
                output_signal_file = os.path.join(args.output_folder, relative_path, signal_file)
                if os.path.isfile(output_signal_file):
                    os.remove(output_signal_file)

        if args.include_images:
            image_files = get_image_files(input_header_file)
            for image_file in image_files:
                input_image_file = os.path.join(args.input_folder, relative_path, image_file)
                output_image_file = os.path.join(args.output_folder, relative_path, image_file)
                if os.path.isfile(input_image_file):
                    shutil.copy2(input_image_file, output_image_file)

# if __name__=='__main__':
#     run(get_parser().parse_args(sys.argv[1:]))

---
### run_model.py
---

In [10]:
#!/usr/bin/env python

# Please do *not* edit this script. Changes will be discarded so that we can run the trained models consistently.

# This file contains functions for running models for the Challenge. You can run it as follows:
#
#   python run_model.py -d data -m model -o outputs -v
#
# where 'data' is a folder containing the Challenge data, 'models' is a folder containing the your trained models, 'outputs' is a
# folder for saving your models' outputs, and -v is an optional verbosity flag.

import argparse
import os
import sys

# from helper_code import *
# from team_code import load_models, run_models

# Parse arguments.
def get_parser():
    description = 'Run the trained Challenge models.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-d', '--data_folder', type=str, required=True)
    parser.add_argument('-m', '--model_folder', type=str, required=True)
    parser.add_argument('-o', '--output_folder', type=str, required=True)
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-f', '--allow_failures', action='store_true')
    return parser

# Run the code.
def run(args):
    # Load the models.
    if args.verbose:
        print('Loading the Challenge model...')

    # You can use these functions to perform tasks, such as loading your models, that you only need to perform once.
    digitization_model, classification_model = load_models(args.model_folder, args.verbose) ### Teams: Implement this function!!!

    # Find the Challenge data.
    if args.verbose:
        print('Finding the Challenge data...')

    records = find_records(args.data_folder)
    num_records = len(records)

    if num_records==0:
        raise Exception('No data were provided.')

    # Create a folder for the Challenge outputs if it does not already exist.
    os.makedirs(args.output_folder, exist_ok=True)

    # Run the team's models on the Challenge data.
    if args.verbose:
        print('Running the Challenge model(s) on the Challenge data...')

    # Iterate over the records.
    for i in range(num_records):
        if args.verbose:
            width = len(str(num_records))
            print(f'- {i+1:>{width}}/{num_records}: {records[i]}...')

        data_record = os.path.join(args.data_folder, records[i])
        output_record = os.path.join(args.output_folder, records[i])

        # Run the models. Allow or disallow the models to fail on some of the data, which can be helpful for debugging.
        try:
            signals, labels = run_models(data_record, digitization_model, classification_model, args.verbose) ### Teams: Implement this function!!!
        except:
            if args.allow_failures:
                if args.verbose:
                    print('... failed.')
                signal = None
                labels = None
            else:
                raise

        # Save Challenge outputs.
        output_path = os.path.split(output_record)[0]
        os.makedirs(output_path, exist_ok=True)

        data_header = load_header(data_record)
        save_header(output_record, data_header)

        if signals is not None:
            comments = [l for l in data_header.split('\n') if l.startswith('#')]
            save_signals(output_record, signals, comments)
        if labels is not None:
            save_labels(output_record, labels)

    if args.verbose:
        print('Done.')

# if __name__ == '__main__':
#     run(get_parser().parse_args(sys.argv[1:]))

---
### simple_cnn.py
---

In [11]:
"""Simple convolutional neural network"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self, list_of_classes):
        super(SimpleCNN, self).__init__()
        self.list_of_classes = list_of_classes
        self.num_classes = len(self.list_of_classes)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64*53*68, 512)          # Adjusted  to input size (425, 550)
        self.fc2 = nn.Linear(512, self.num_classes)  # 11 output classes for multilabel classification

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, x.shape[1] * x.shape[2]* x.shape[3])
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # Sigmoid activation for multilabel classification
        return x


---
### train_model.py
---

In [12]:
#!/usr/bin/env python

# Please do *not* edit this script. Changes will be discarded so that we can train the models consistently.

# This file contains functions for training models for the Challenge. You can run it as follows:
#
#   python train_model.py -d data -m model -v
#
# where 'data' is a folder containing the Challenge data, 'model' is a folder for saving your models, and , and -v is an optional
# verbosity flag.

import argparse
import sys

# from helper_code import *
# from team_code import train_models

# Parse arguments.
def get_parser():
    description = 'Train the Challenge models.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('-d', '--data_folder', type=str, required=True)
    parser.add_argument('-m', '--model_folder', type=str, required=True)
    parser.add_argument('-v', '--verbose', action='store_true')
    return parser

# Run the code.
def run(args):
    train_models(args.data_folder, args.model_folder, args.verbose) ### Teams: Implement this function!!!

# if __name__ == '__main__':
#     run(get_parser().parse_args(sys.argv[1:]))


---
### team_code.py
---

In [13]:
#!/usr/bin/env python

# Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions,
# change or remove non-required functions, and add your own functions.

################################################################################
#
# Optional libraries, functions, and variables. You can change or remove them.
#
################################################################################

import joblib
import numpy as np
import os
import random
import shutil
import sys
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from collections import OrderedDict
# from data_loader import get_training_and_validation_loaders
from functools import partial
# from helper_code import *
from matplotlib import pyplot as plt
# from simple_cnn import SimpleCNN
from sklearn.metrics import average_precision_score,precision_recall_curve,roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from typing import Callable, Optional

# DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# DEVICE = "cuda"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 10
CLASSIFICATION_THRESHOLD=0.5
CLASSIFICATION_DISTANCE_TO_MAX_THRESHOLD=0.1
LIST_OF_ALL_LABELS=['NORM', 'Acute MI', 'Old MI', 'STTC', 'CD', 'HYP', 'PAC', 'PVC', 'AFIB/AFL', 'TACHY', 'BRADY']
RESIZE_TEST_IMAGES=(425, 550)
OPTIM_LR=1e-3
OPTIM_WEIGHT_DECAY=1e-4
SCHEDULER_STEP_SIZE=7
SCHEDULER_GAMMA=0.1

################################################################################
#
# Required functions. Edit these functions to add your code, but do not change the arguments of the functions.
#
################################################################################

# Train your models. This function is *required*. You should edit this function to add your code, but do *not* change the arguments
# of this function. If you do not train one of the models, then you can return None for the model.

def train_models(data_folder, model_folder, verbose):

    print(DEVICE)
    # Find the data files.
    if verbose:
        print('Finding the Challenge data...')

    records = find_records(data_folder)
    num_records = len(records)

    if num_records == 0:
        raise FileNotFoundError('No data were provided.')

    # Extract the data...
    if verbose:
        print('Loading the data...')

    classification_images = list() # list of image paths
    classification_labels = list() # list of lists of strings

    # Iterate over the records.
    for i in range(num_records):
        if verbose:
            width = len(str(num_records))
            print(f'- {i+1:>{width}}/{num_records}: {records[i]}...')

        record = os.path.join(data_folder, records[i])
        record_parent_folder=os.path.dirname(record)

        # Some images may not be labeled, so we'll exclude those
        labels = load_labels(record)
        if labels:

            # I'm imposing a further condition: the label strings should be nonempty
            nonempty_labels=[l for l in labels if l != '']
            if nonempty_labels != []:

                # Add the first image to the list
                images = get_image_files(record)
                classification_images.append(os.path.join(record_parent_folder, images[0]) )
                classification_labels.append(nonempty_labels)

    # We expect some images to be labeled for classification.
    if not classification_labels:
        raise Exception('There are no labels for the data.')

    # Fix an ordering of the labels
    num_classes=len(LIST_OF_ALL_LABELS)

    # Train the models.
    if verbose:
        print('Training the models on the data...')

    #=====================
    # Classification task
    #=====================

    # Split the training set into "training" and "validation" subsets, returning them as DataLoaders
    training_loader, validation_loader \
        = get_training_and_validation_loaders(LIST_OF_ALL_LABELS, classification_images, classification_labels)

    # Initialize a model
    classification_model = SimpleCNN(LIST_OF_ALL_LABELS).to(DEVICE)
    for param in classification_model.parameters(): # fine tune all the layers
        param.requires_grad = True

    loss = nn.BCELoss() # binary cross entropy loss for multilabel classification
    opt = optim.Adam(classification_model.parameters(), lr=OPTIM_LR, weight_decay=OPTIM_WEIGHT_DECAY)
    scheduler = StepLR(opt, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA)

    N_loss = []
    N_loss_valid = []
    train_auprc = []
    valid_auprc = []
    train_auroc= []
    valid_auroc = []
    f1_train = []
    f1_valid = []

    plot_folder=os.path.join(model_folder, "training_figures")
    os.makedirs(plot_folder, exist_ok=True)

    # Filename to save the final weights to
    final_weights=None

    # Now let's train!
    for epoch in range(EPOCHS):

        # Initialization of variables for plotting the progress
         N_item_sum = 0
         N_item_sum_valid = 0
         targets_train = []
         outputs_train = []
         targets_valid = []
         outputs_valid = []

         ### Training part
         if verbose:
            print(f"============================[{epoch}]============================")
         classification_model.train()
         for i, (image, label) in enumerate(training_loader):
             opt.zero_grad()

             image = image.float().to(DEVICE)
             label = label.to(torch.float).to(DEVICE)
             prediction = classification_model(image)

             # loss
             N = loss(prediction,label)
             N.backward()
             N_item = N.item()
             N_item_sum += N_item

             # gradient clipping plus optimizer
             torch.nn.utils.clip_grad_norm_(classification_model.parameters(), max_norm=10)
             opt.step()
             if verbose:
                print(f"Epoch: {epoch}, Iteration: {i}, Loss: {N_item}")

             targets_train.append(label.data.cpu().numpy()) #target[:,0]
             outputs_train.append(prediction.data.cpu().numpy())

         ### Validation part
         classification_model.eval()
         with torch.no_grad():
          for j, (image, label) in enumerate(validation_loader):
                 image = image.float().to(DEVICE)
                 label = label.to(torch.float).to(DEVICE)
                 prediction = classification_model(image)

                 N = loss(prediction,label)
                 N_item = N.item()
                 N_item_sum_valid += N.item()

                 targets_valid.append(label.data.cpu().numpy()) #target[:,0]
                 outputs_valid.append(prediction.data.cpu().numpy())
                 print(f"Epoch: {epoch}, Valid Iteration: {j}, Loss: {N_item}")

         scheduler.step()

         # Logging the outputs and targets to caluclate auprc and auroc
         targets_train = np.concatenate(targets_train, axis=0).T
         outputs_train = np.concatenate(outputs_train, axis=0).T
         targets_valid = np.concatenate(targets_valid, axis=0).T
         outputs_valid = np.concatenate(outputs_valid, axis=0).T

         auprc_t = average_precision_score(y_true=targets_train, y_score=outputs_train)
         auroc_t = roc_auc_score(y_true=targets_train, y_score=outputs_train)
         auprc_v = average_precision_score(y_true=targets_valid, y_score=outputs_valid)
         auroc_v = roc_auc_score(y_true=targets_valid, y_score=outputs_valid)

         train_auprc.append(auprc_t)
         train_auroc.append(auroc_t)
         valid_auprc.append(auprc_v)
         valid_auroc.append(auroc_v)

         N_loss.append(N_item_sum/i)
         N_loss_valid.append(N_item_sum_valid/j)

         # saving loss function after each epoch so you can look on progress
         fig = plt.figure()
         plt.plot(N_loss, label="train")
         plt.plot(N_loss_valid, label="valid")
         plt.title("Loss function")
         plt.xlabel('epoch')
         plt.ylabel('loss')
         plt.grid()
         plt.legend()
         plt.savefig(os.path.join(plot_folder, "loss.png"))
         plt.close()

         fig = plt.figure()
         plt.plot(train_auprc, label="train auprc")
         plt.plot(valid_auprc, label="valid auprc")
         plt.plot(train_auroc, label="train auroc")
         plt.plot(valid_auroc, label="valid auroc")

         plt.title("AUPRC and AUROC")
         plt.xlabel('epoch')
         plt.ylabel('Performace')
         plt.grid()
         plt.legend()
         plt.savefig(os.path.join(plot_folder, "auroc_auprc.png"))
         plt.close()

         ### save model after each epoch
         file_path = os.path.join(model_folder, "model_weights_" + str(epoch) + ".pth")
         torch.save(classification_model.state_dict(), file_path)

         # If this is the last epoch, then the weights of the model will be saved to this file
         final_weights = file_path

    # Create a folder for the models if it does not already exist.
    os.makedirs(model_folder, exist_ok=True)

    # Save the models.
    save_classification_model(model_folder, LIST_OF_ALL_LABELS, final_weights)

    if verbose:
        print('Done.')
        print()

# Load your trained models. This function is *required*. You should edit this
# function to add your code, but do *not* change the arguments of this
# function. If you do not train one of the models, then you can return None for
# the model.
def load_models(model_folder, verbose):
    digitization_model = None

    classes_filename = os.path.join(model_folder, 'classes.txt')
    classes = joblib.load(classes_filename)

    classification_model = SimpleCNN(classes).to(DEVICE) # instantiate a new copy of the model
    classification_filename = os.path.join(model_folder, "classification_model.pth")
    classification_model.load_state_dict(torch.load(classification_filename))

    return digitization_model, classification_model

# Run your trained digitization model. This function is *required*. You should edit this function to add your code, but do *not*
# change the arguments of this function. If you did not train one of the models, then you can return None for the model.
def run_models(record, digitization_model, classification_model, verbose):

    # Run the digitization model; if you did not train this model, then you can set signal = None.
    signal = None

    # Run the classification model.
    classes = classification_model.list_of_classes

    # Open the image:
    record_parent_folder=os.path.dirname(record)
    image_files=get_image_files(record)
    image_path=os.path.join(record_parent_folder, image_files[0])
    img = Image.open(image_path)
    # FIXME: repeated code---maybe factor out opening the image from a record
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # transform the image and make it suitable as input
    img = transforms.Resize(RESIZE_TEST_IMAGES)(img)
    img = transforms.ToTensor()(img)
    img = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])(img)
    img = img.unsqueeze(0)

    # send it to the GPU if necessary
    img = img.float().to(DEVICE)

    classification_model.eval()
    with torch.no_grad():
        probabilities = torch.squeeze(classification_model(img), 0).tolist()
        predictions=list()
        for i in range(len(classes)):
            if probabilities[i] >= CLASSIFICATION_THRESHOLD:
                predictions.append(classes[i])

    # backup if none is over the threshold: use the max
    if predictions==[]:
        highest_probability=max(probabilities)
        for i in range(len(classes)):
            if abs(highest_probability - probabilities[i]) <= CLASSIFICATION_DISTANCE_TO_MAX_THRESHOLD:
                predictions.append(classes[i])

    return signal, predictions

#########################################################################################
#
# Optional functions. You can change or remove these functions and/or add new functions.
#
#########################################################################################

# Extract features.
def extract_features(record):
    images = load_images(record)
    mean = 0.0
    std = 0.0
    for image in images:
        image = np.asarray(image)
        mean += np.mean(image)
        std += np.std(image)
    return np.array([mean, std])

# Save your trained models.
def save_classification_model(model_folder,
                list_of_classes=None,
                final_weights=None):

    if final_weights is not None:
        classes=filename = os.path.join(model_folder, 'classes.txt')
        joblib.dump(list_of_classes, filename, protocol=0)

        # copy the file with the final weights to the model path
        model_filename=os.path.join(model_folder, "classification_model.pth")
        shutil.copyfile(final_weights, model_filename)


---
### data_loader.py
---

In [14]:
from PIL import Image
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from torchvision import transforms
from sklearn.model_selection import train_test_split

# from helper_code import *

#===================================
# Parameters (configure stuff here)
#===================================

RESIZE_TO_DIMENSIONS=(425, 550)                             # resize all images to these dimensions during training
TRANSFORM_TRAINING_IMAGES=transforms.RandomHorizontalFlip() # apply this to each training image
IMAGE_MODE='RGB'                                            # set all images to this mode
RANDOM_STATE=42                                             # number for repeatable pseudorandomness

#=========
# Classes
#=========


class ECGImageDataset(Dataset):
    """Map-style dataset that yields (image, label) pairs. In this context,
       each "label" in the (image, label) tuple will really be a vector encoding
       multiple dx labels, while each "image" will be an image path.

       Initialize this with
        - a list of all possible class labels in a fixed ordering,
        - a boolean value saying whether this is a training set or not,
        - a list of paths (as strings), and
        - a list of lists of labels (as strings).
       """

    def __init__(self, list_of_all_classes:list, is_training:bool,
                    image_paths:list, image_labels:list):

        self.list_of_all_classes=list_of_all_classes
        self.num_classes = len(self.list_of_all_classes)
        # Inverse of list_of_all_classes: look up index by name
        self.class_to_index=dict()
        for i in range(self.num_classes):
            self.class_to_index[list_of_all_classes[i]]=i

        self.is_training=is_training

        self.image_paths=image_paths
        self.image_labels=image_labels

        # How to transform images
        inner_transformations = [TRANSFORM_TRAINING_IMAGES] if self.is_training else []
        transformations = [transforms.Resize(RESIZE_TO_DIMENSIONS)] \
                          + inner_transformations \
                          + [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        self.transform_images = transforms.Compose(transformations)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """index --> (image, label_vector)."""

        # First the image
        our_image = Image.open(self.image_paths[idx])
        if our_image.mode != IMAGE_MODE: # e.g. RGB
            our_image = our_image.convert(IMAGE_MODE)
        our_image = self.transform_images(our_image)

        # Next the labels

        #     Our dataset will encode each label not as a list of strings but as a vector,
        #     one index for each class. Thus we have to convert a list like
        #       ["THING1", "THING2", "THING3"]
        #     into a vector like
        #       [0,1,0,1,1,0].

        label_strings = self.image_labels[idx]
        our_label_vector = [0] * self.num_classes
        for l in label_strings: # labels assigned to this index
            our_label_vector[self.class_to_index[l]] = 1

        return our_image, our_label_vector

    @staticmethod
    def collate_fn(batch):
        """Batch of pairs -> pair of tensors representing the batch"""
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


def get_training_and_validation_loaders(list_of_all_classes, image_path_list, label_names_list):
    """
    Given  a list `list_of_all_classes` of all class labels in the dataset,
    `image_path_list` (a list of image paths), and `label_names_list` (a list
    of _lists_ of label names corresponding to the images), return the pair
    `(training_loader, validation_loader)` which can be used to train/validate
    a model.

    """

    # Divide the dataset into training and validation sets
    training_images, validation_images, \
    training_classes, validation_classes, \
        = train_test_split(image_path_list,
                           label_names_list,
                           test_size=0.2,
                           random_state=RANDOM_STATE,
                           shuffle=True)

    # Dataset for training
    train_dataset=ECGImageDataset(list_of_all_classes=list_of_all_classes,
                                  is_training=True,
                                  image_paths=training_images,
                                  image_labels=training_classes)

    # Dataset for training
    validation_dataset=ECGImageDataset(list_of_all_classes=list_of_all_classes,
                                       is_training=False,
                                       image_paths=validation_images,
                                       image_labels=validation_classes)

    # Dataloader for training
    training_loader = torch.utils.data.DataLoader(train_dataset,
                                                  batch_size=4,
                                                  shuffle=True,
                                                  pin_memory=True,
                                                  num_workers=8,
                                                  drop_last=True,
                                                  collate_fn=train_dataset.collate_fn)

    # DataLoader for validation
    validation_loader = torch.utils.data.DataLoader(validation_dataset,
                                                    batch_size=4,
                                                    shuffle=False, # note this differs
                                                    pin_memory=True,
                                                    num_workers=8,
                                                    drop_last=True,
                                                    collate_fn=validation_dataset.collate_fn)

    return training_loader, validation_loader



---
# Utilities
---

In [15]:
# a function to print the size of the files in MB in all the subdirectories when a file path is given

In [16]:
# prompt: # a function to print the size of the files in MB in all the subdirectories when a file path is given

def print_folder_sizes(root_path):
  """
  Prints the size of the files in MB in all the subdirectories when a file path is given.

  Args:
    root_path (str): The path to the root directory.
  """
  for dirpath, dirnames, filenames in os.walk(root_path):
    folder_size = 0
    for filename in filenames:
      filepath = os.path.join(dirpath, filename)
      folder_size += os.path.getsize(filepath)
    if folder_size > 0:
      print(f"{dirpath}: {(folder_size / (1024 * 1024)):.2f} MB")


---
# Training
---

This is the function call stack for the following command.
`python train_model.py -d training_data -m model`

1. train_model.py
  - run()
  
  2. team_code.py
    - train_models()
         
         3.
         - get_training_and_validation_loaders()

         -


I ran the below line with device=cuda:0

In [17]:
#!/usr/bin/env python

# Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions,
# change or remove non-required functions, and add your own functions.

################################################################################
#
# Optional libraries, functions, and variables. You can change or remove them.
#
################################################################################

import joblib
import numpy as np
import os
import random
import shutil
import sys
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from collections import OrderedDict
# from data_loader import get_training_and_validation_loaders
from functools import partial
# from helper_code import *
from matplotlib import pyplot as plt
# from simple_cnn import SimpleCNN
from sklearn.metrics import average_precision_score,precision_recall_curve,roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from typing import Callable, Optional

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# DEVICE = "cuda"

---
### configuration
---

In [18]:
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 1
CLASSIFICATION_THRESHOLD=0.5
CLASSIFICATION_DISTANCE_TO_MAX_THRESHOLD=0.1
LIST_OF_ALL_LABELS=['NORM', 'Acute MI', 'Old MI', 'STTC', 'CD', 'HYP', 'PAC', 'PVC', 'AFIB/AFL', 'TACHY', 'BRADY']
RESIZE_TEST_IMAGES=(425, 550)
OPTIM_LR=1e-3
OPTIM_WEIGHT_DECAY=1e-4
SCHEDULER_STEP_SIZE=7
SCHEDULER_GAMMA=0.1

In [19]:
verbose = True

# folder paths
data_folder = "/content/train_data_0"
model_folder = "/content/model"

os.makedirs(model_folder, exist_ok=True)

---
### training code
---

In [20]:
print(DEVICE)

cpu


In [21]:
records = find_records(data_folder)
num_records = len(records)

if num_records == 0:
    raise FileNotFoundError('No data were provided.')
else:
    print(f'Found {num_records} records.')

Found 987 records.


In [22]:
classification_images = list()  # list of image paths
classification_labels = list()  # list of lists of strings

# Iterate over the records.
for i in range(num_records):
    if verbose:
        width = len(str(num_records))
        print(f'- {i + 1:>{width}}/{num_records}: {records[i]}...')

    record = os.path.join(data_folder, records[i])
    record_parent_folder = os.path.dirname(record)

    # Some images may not be labeled, so we'll exclude those
    labels = load_labels(record)
    if labels:

        # I'm imposing a further condition: the label strings should be nonempty
        nonempty_labels = [l for l in labels if l != '']
        if nonempty_labels != []:

            # Add the first image to the list
            images = get_image_files(record)
            classification_images.append(os.path.join(record_parent_folder, images[0]))
            classification_labels.append(nonempty_labels)

-   1/987: 00001_hr...
-   2/987: 00002_hr...
-   3/987: 00003_hr...
-   4/987: 00004_hr...
-   5/987: 00005_hr...
-   6/987: 00006_hr...
-   7/987: 00007_hr...
-   8/987: 00008_hr...
-   9/987: 00009_hr...
-  10/987: 00010_hr...
-  11/987: 00011_hr...
-  12/987: 00012_hr...
-  13/987: 00013_hr...
-  14/987: 00014_hr...
-  15/987: 00015_hr...
-  16/987: 00016_hr...
-  17/987: 00017_hr...
-  18/987: 00018_hr...
-  19/987: 00019_hr...
-  20/987: 00020_hr...
-  21/987: 00021_hr...
-  22/987: 00022_hr...
-  23/987: 00023_hr...
-  24/987: 00024_hr...
-  25/987: 00025_hr...
-  26/987: 00026_hr...
-  27/987: 00027_hr...
-  28/987: 00028_hr...
-  29/987: 00029_hr...
-  30/987: 00030_hr...
-  31/987: 00031_hr...
-  32/987: 00032_hr...
-  33/987: 00033_hr...
-  34/987: 00034_hr...
-  35/987: 00035_hr...
-  36/987: 00036_hr...
-  37/987: 00037_hr...
-  38/987: 00038_hr...
-  39/987: 00039_hr...
-  40/987: 00040_hr...
-  41/987: 00041_hr...
-  42/987: 00042_hr...
-  43/987: 00043_hr...
-  44/987: 

In [23]:
# We expect some images to be labeled for classification.
if not classification_labels:
  raise Exception('There are no labels for the data.')

# Fix an ordering of the labels
num_classes=len(LIST_OF_ALL_LABELS)


---
## Data loading
---

In [24]:
# Split the training set into "training" and "validation" subsets, returning them as DataLoaders
training_loader, validation_loader \
  = get_training_and_validation_loaders(LIST_OF_ALL_LABELS, classification_images, classification_labels)



In [25]:
for i,v in enumerate(validation_loader):
  # print(i, v)
  pass

  if i == 0:
    break

In [26]:
v[0].shape

torch.Size([4, 3, 425, 550])

In [27]:
v[1].shape

torch.Size([4, 11])

In [28]:
plot_folder=os.path.join(model_folder, "training_figures")
os.makedirs(plot_folder, exist_ok=True)

# Filename to save the final weights to
final_weights=None

-------
### Classification task
-----

In [29]:
# Train the models.
if verbose:
    print('Training the models on the data...')

Training the models on the data...


In [30]:
# We expect some images to be labeled for classification.
if not classification_labels:
    raise Exception('There are no labels for the data.')

# Fix an ordering of the labels
num_classes=len(LIST_OF_ALL_LABELS)

In [31]:
LIST_OF_ALL_LABELS

['NORM',
 'Acute MI',
 'Old MI',
 'STTC',
 'CD',
 'HYP',
 'PAC',
 'PVC',
 'AFIB/AFL',
 'TACHY',
 'BRADY']

In [32]:
from torchvision import models

# Initialize a pretrained VGG16 model
classification_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:07<00:00, 78.4MB/s]


In [33]:
# Modify the final layer
classification_model.classifier

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

In [34]:
# Freeze the feature extractor layers
for param in classification_model.features.parameters():
    param.requires_grad = False

In [35]:
# Modify the final layer
classification_model.classifier[6] = nn.Linear(classification_model.classifier[6].in_features, num_classes)

In [36]:
# Move the model to the specified device
classification_model = classification_model.to(DEVICE)

In [37]:
# TODO : Select the best loss function
loss = nn.BCELoss() # binary cross entropy loss for multilabel classification

# TODO: Choose the optimal hyperparameters
opt = optim.Adam(classification_model.parameters(), lr=OPTIM_LR, weight_decay=OPTIM_WEIGHT_DECAY)

# TODO: Find out what does this do!
scheduler = StepLR(opt, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA)

In [38]:
N_loss = []
N_loss_valid = []
train_auprc = []
valid_auprc = []
train_auroc= []
valid_auroc = []
f1_train = []
f1_valid = []

In [39]:
do_validation = False

In [41]:
# Now let's train!
for epoch in range(EPOCHS):

    # Initialization of variables for plotting the progress
    N_item_sum = 0
    N_item_sum_valid = 0
    targets_train = []
    outputs_train = []
    targets_valid = []
    outputs_valid = []

    ### Training part
    if verbose:
        print(f"============================[{epoch}]============================")
    classification_model.train()
    for i, (image, label) in enumerate(training_loader): # train_dataloader
        opt.zero_grad()

        image = image.float().to(DEVICE)
        label = label.to(torch.float).to(DEVICE)
        prediction = torch.sigmoid(classification_model(image))

        # loss
        N = loss(prediction, label)
        N.backward()
        N_item = N.item()
        N_item_sum += N_item

        # gradient clipping plus optimizer
        torch.nn.utils.clip_grad_norm_(classification_model.parameters(), max_norm=10)
        opt.step()
        if verbose:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {N_item}")

        targets_train.append(label.data.cpu().numpy()) #target[:,0]
        outputs_train.append(prediction.data.cpu().numpy())

    if do_validation:
        ### Validation part
        classification_model.eval()
        with torch.no_grad():
            for j, (image, label) in enumerate(validation_loader):
                image = image.float().to(DEVICE)
                label = label.to(torch.float).to(DEVICE)
                prediction = classification_model(image)

                N = loss(prediction, label)
                N_item = N.item()
                N_item_sum_valid += N.item()

                targets_valid.append(label.data.cpu().numpy()) #target[:,0]
                outputs_valid.append(prediction.data.cpu().numpy())
                print(f"Epoch: {epoch}, Valid Iteration: {j}, Loss: {N_item}")

        scheduler.step()

        # Logging the outputs and targets to calculate auprc and auroc
        targets_train = np.concatenate(targets_train, axis=0).T
        outputs_train = np.concatenate(outputs_train, axis=0).T
        targets_valid = np.concatenate(targets_valid, axis=0).T
        outputs_valid = np.concatenate(outputs_valid, axis=0).T

        auprc_t = average_precision_score(y_true=targets_train, y_score=outputs_train)
        auroc_t = roc_auc_score(y_true=targets_train, y_score=outputs_train)
        auprc_v = average_precision_score(y_true=targets_valid, y_score=outputs_valid)
        auroc_v = roc_auc_score(y_true=targets_valid, y_score=outputs_valid)

        train_auprc.append(auprc_t)
        train_auroc.append(auroc_t)
        valid_auprc.append(auprc_v)
        valid_auroc.append(auroc_v)

        N_loss.append(N_item_sum / i)
        N_loss_valid.append(N_item_sum_valid / j)

        # Saving loss function after each epoch so you can look on progress
        fig = plt.figure()
        plt.plot(N_loss, label="train")
        plt.plot(N_loss_valid, label="valid")
        plt.title("Loss function")
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.grid()
        plt.legend()
        plt.savefig(os.path.join(plot_folder, "loss.png"))
        plt.close()

        fig = plt.figure()
        plt.plot(train_auprc, label="train auprc")
        plt.plot(valid_auprc, label="valid auprc")
        plt.plot(train_auroc, label="train auroc")
        plt.plot(valid_auroc, label="valid auroc")

        plt.title("AUPRC and AUROC")
        plt.xlabel('epoch')
        plt.ylabel('Performance')
        plt.grid()
        plt.legend()
        plt.savefig(os.path.join(plot_folder, "auroc_auprc.png"))
        plt.close()

    ### Save model after each epoch
    file_path = os.path.join(model_folder, "vgg16_weights_epochs_" + str(epoch) + ".pth")
    torch.save(classification_model.state_dict(), file_path)

    # If this is the last epoch, then the weights of the model will be saved to this file
    final_weights = file_path





Epoch: 0, Iteration: 0, Loss: 0.6496692299842834
Epoch: 0, Iteration: 1, Loss: 0.5974317789077759


KeyboardInterrupt: 

In [None]:
# Save the models.
save_classification_model(model_folder, LIST_OF_ALL_LABELS, final_weights)

if verbose:
    print('Done.')
    print()

In [None]:
print_folder_sizes(model_folder)