---
title: Train ResNet Models on ECG Image Data 
author: Frederik Bennhoff
date: 01/12/2023
---

In [None]:
# Experiment configuration -- the values meant to be edited between runs.
folder_name, valid_pct, seed = (
    "spectrograms_128",  # image folder inside ../data/, with one sub-folder per class label
    0.1,                 # fraction of distinct ids held out as the validation set
    42,                  # random seed (used for the validation split)
)

**Import libraries and set path**

In [None]:
from fastcore.all import *
from shutil import rmtree
from fastai.vision.all import *
import os
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score

# Seed Python's `random`, NumPy and PyTorch in one call (fastai's set_seed).
# The NumPy seed drives the validation split in get_df_for_model; seeding torch
# as well covers the torch-side randomness that np.random.seed alone misses.
set_seed(seed)

# Work from the `src` folder; the guard keeps this cell safe to re-run.
if Path.cwd().name != "src":
    os.chdir("src")
p = Path("../data") / folder_name     # input folder: one sub-folder of images per class label
p_out = Path("../out") / folder_name  # output folder for saved models and dataloaders
p_out.mkdir(parents=True, exist_ok=True)

**Function definitions**

*For loading training data*

In [None]:
def get_files_df(folder_name, prefix, n_classes=4, data_dir="../data"):
    """Collect the image files of all class sub-folders into one table.

    Images are read from ``{data_dir}/{folder_name}/{label}/`` for every
    ``label`` in ``range(n_classes)``. File names are expected to look like
    ``{prefix}_{id}_{sequence}.{ext}``. Files are skipped when their first
    ``_``-separated part is not ``prefix``, or when they do not have exactly
    three ``_``-separated parts (e.g. ``.DS_Store``).

    Args:
        folder_name: image folder inside ``data_dir``.
        prefix: file-name prefix of the images to keep (e.g. ``"sg"``).
        n_classes: number of class sub-folders, named ``0 .. n_classes - 1``.
        data_dir: directory that contains ``folder_name``.

    Returns:
        DataFrame with columns ``filename`` (bare file name), ``id`` (int),
        ``sequence`` (int) and ``label`` (int). It is sorted by ``id`` then
        ``sequence`` and has a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: if a class sub-folder does not exist.
        ValueError: if the id or sequence part of a kept file is not an integer.
    """
    columns = ["filename", "id", "sequence", "label"]
    frames = []
    for label in range(n_classes):
        folder_path = os.path.join(data_dir, folder_name, str(label))
        filenames = pd.Series(sorted(os.listdir(folder_path)), dtype=object)
        # "<prefix>_<id>_<sequence>.<ext>" -> "<prefix>", "<id>", "<sequence>.<ext>";
        # non-matching names become NaN and are filtered out by `keep`.
        parts = filenames.str.extract(r"^([^_]*)_([^_]*)_([^_]*)$")
        keep = parts[0] == prefix
        frames.append(pd.DataFrame({
            "filename": filenames[keep],
            "id": parts.loc[keep, 1].astype(int),
            # strip the file extension: keep the text before the first "."
            "sequence": parts.loc[keep, 2].str.split(".").str[0].astype(int),
            "label": label,
        }, columns=columns))

    if not frames:  # n_classes == 0
        return pd.DataFrame(columns=columns)
    # concatenate once instead of growing a frame inside the loop
    files_df = pd.concat(frames, ignore_index=True)
    return files_df.sort_values(by=["id", "sequence"]).reset_index(drop=True)

def get_df_for_model(folder_name, prefix, valid_pct):
    """Build the fastai input table with an id-level train/validation split.

    A fraction ``valid_pct`` of the distinct ``id`` values is drawn without
    replacement, using NumPy's global RNG. Every image of a drawn id is marked
    as validation, so no id appears in both splits.

    Args:
        folder_name: image folder passed on to ``get_files_df``.
        prefix: file-name prefix passed on to ``get_files_df``.
        valid_pct: fraction of ids to hold out; ``int(n_ids * valid_pct)`` ids
            are drawn, so a very small value can yield an empty validation set.

    Returns:
        The ``get_files_df`` frame plus a boolean ``is_valid`` column.
        ``filename`` is rewritten to ``"<label>/<filename>"``, i.e. relative to
        the image folder.
    """
    files_df = get_files_df(folder_name, prefix)
    ids = files_df["id"].unique()
    n_valid = int(len(ids) * valid_pct)
    valid_ids = np.random.choice(ids, n_valid, replace=False)
    # vectorised membership test: O(rows) instead of a Python-level scan per row
    files_df["is_valid"] = files_df["id"].isin(valid_ids)
    files_df["filename"] = files_df["label"].astype(str) + "/" + files_df["filename"]
    return files_df

*For aggregating and scoring predictions on the validation data*

In [None]:
def aggregate_predictions_on_validation_set(pred_prob, pred_class, files_df):
    """Average per-image class probabilities into one prediction per id.

    Args:
        pred_prob: (n_valid_images, n_classes) class probabilities, as a tensor
            or array. Rows must be in the same order as the validation rows
            (``is_valid`` True) of ``files_df``.
        pred_class: second value returned by ``learn.get_preds()``. It is not
            used by the aggregation and is kept for backward compatibility.
        files_df: frame from ``get_df_for_model`` with columns ``filename``,
            ``id``, ``sequence``, ``label`` and ``is_valid``.

    Returns:
        One row per validation id: the ``files_df`` row with that id's lowest
        ``sequence``, without the ``sequence`` column. Each row also carries the
        mean probabilities ``p_0 .. p_{k-1}`` and ``pred_label``, the index of
        the largest mean probability. The index is reset to 0..n-1.

    Raises:
        ValueError: if ``pred_prob`` is not 2-D or its number of rows differs
            from the number of validation rows in ``files_df``. A mismatch would
            otherwise misalign probabilities and files silently.
    """
    valid_df = files_df[files_df["is_valid"].astype(bool)].reset_index(drop=True)
    probs = np.asarray(pred_prob)
    if probs.ndim != 2 or len(probs) != len(valid_df):
        raise ValueError(
            f"pred_prob has shape {probs.shape}, expected ({len(valid_df)}, n_classes)"
        )
    prob_cols = [f"p_{i}" for i in range(probs.shape[1])]

    # rows of valid_df and probs are matched by position
    per_image = pd.concat([valid_df[["id"]], pd.DataFrame(probs, columns=prob_cols)], axis=1)
    per_id = per_image.groupby("id", as_index=False)[prob_cols].mean()
    # column position == class index because the columns are p_0 .. p_{k-1}
    per_id["pred_label"] = per_id[prob_cols].to_numpy().argmax(axis=1)

    # keep one row per id: its first (lowest) sequence
    first_rows = (
        valid_df.sort_values(["id", "sequence"])
        .drop_duplicates(subset="id", keep="first")
    )
    return (
        first_rows.merge(per_id, on="id")
        .drop(columns="sequence")
        .reset_index(drop=True)
    )

def calc_score(predictions_valset, average="micro"):
    """F1 score of the per-id predictions.

    Args:
        predictions_valset: frame with integer-valued ``label`` (truth) and
            ``pred_label`` (prediction) columns, such as the output of
            ``aggregate_predictions_on_validation_set``.
        average: averaging mode forwarded to ``sklearn.metrics.f1_score``.
            The default ``"micro"`` keeps the previous behaviour.

    Returns:
        The score returned by ``f1_score``.
    """
    y_true = predictions_valset["label"].to_numpy(dtype=np.int32)
    y_pred = predictions_valset["pred_label"].to_numpy(dtype=np.int32)
    return f1_score(y_true, y_pred, average=average)

## Create Data Loaders

In [None]:
# check for failed images (not an issue, so not active)
#resize_images(p/"0", max_size=400, dest = p/"0") # resize images
# failed = verify_images(get_image_files(p/"0")) # verify images
# failed.map(Path.unlink) # delete failed images

Create the data loaders (train/validation split by patient id) and look at some of the images.

In [None]:
files_df = get_df_for_model(folder_name, "sg", valid_pct)

# int(n_ids * valid_pct) can round to 0 (or cover every id); an empty split would
# otherwise only surface later as a confusing error in training or aggregation.
assert files_df["is_valid"].any(), "validation split is empty; increase valid_pct"
assert (~files_df["is_valid"]).any(), "training split is empty; decrease valid_pct"

img_size = 224  # per-item Resize target
dls = ImageDataLoaders.from_df(
    files_df,
    path=p,
    label_col="label",
    valid_col="is_valid",
    item_tfms=Resize(img_size),
)
dls.show_batch(max_n=6);

## Train and evaluate models

### Fit resnet50 model, full fit

#### Fitting

#### Validation sample predictions

### Fine-tune resnet50 model

#### Fitting

#### Validation sample predictions

### Fit resnet18 model, full fit

#### Fitting

In [None]:
# NOTE(review): with the default pretrained=True, fastai's vision_learner freezes the
# pretrained body, so fit_one_cycle below presumably updates only the trainable
# parameters (the new head, plus BatchNorm layers by default), not all layers.
# Call learn.unfreeze() before fitting if a full fit is intended -- TODO confirm.
n_epochs = 10
model_name = f"resnet18_{n_epochs}_fit"  # file names stay in sync with the epoch count

learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fit_one_cycle(n_epochs)
learn.recorder.plot_loss()

# save the weights and, next to them in <p_out>/<model_dir>/, the dataloaders
learn.path = p_out
learn.save(model_name)
torch.save(dls, learn.path / learn.model_dir / f"{model_name}_dls.pkl")

#### Validation sample predictions

In [None]:
# get predictions on validation set
from IPython.display import display

# fastai's get_preds() returns (predictions, targets) by default, so `pred_class`
# holds the true labels; only the probabilities are used for the aggregation.
pred_prob, pred_class = learn.get_preds()
predictions_valset = aggregate_predictions_on_validation_set(pred_prob, pred_class, files_df)
display(predictions_valset.head(20))  # a mid-cell expression is not rendered without display()
print("\n\nF1 Score:", calc_score(predictions_valset))

### Fine-tune resnet18 model

#### Fitting

In [None]:
n_epochs = 10
model_name = f"resnet18_{n_epochs}_tuned"  # file names stay in sync with the epoch count

learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(n_epochs)
learn.recorder.plot_loss()

# save the weights and, next to them in <p_out>/<model_dir>/, the dataloaders
learn.path = p_out
learn.save(model_name)
torch.save(dls, learn.path / learn.model_dir / f"{model_name}_dls.pkl")

#### Validation sample predictions

In [None]:
# get predictions on validation set
from IPython.display import display

# fastai's get_preds() returns (predictions, targets) by default, so `pred_class`
# holds the true labels; only the probabilities are used for the aggregation.
pred_prob, pred_class = learn.get_preds()
predictions_valset = aggregate_predictions_on_validation_set(pred_prob, pred_class, files_df)
display(predictions_valset.head(20))  # a mid-cell expression is not rendered without display()
print("\n\nF1 Score:", calc_score(predictions_valset))

## Show results of last training run

## Prediction
Each patient (`id`) contributes several spectrogram images (one per `sequence`). We aggregate the per-image predictions for each patient by averaging the class probabilities, and predict the class with the highest mean probability.

**Predictions on Test Set**

**Predictions on validation set**