---
title: Train ResNet Models on ECG Image Data 
author: Frederik Bennhoff
date: 01/12/2023
---

In [None]:
# Notebook parameters
folder_name: str = "spectrograms_256"  # image folder below ../data/ (outputs go below ../out/)
valid_pct: float = 0.1                 # fraction of ids held out for validation
seed: int = 42                         # random seed used in the setup cell

**Import libraries and set path**

In [None]:
from fastcore.all import *
from shutil import rmtree
from fastai.vision.all import *
import os
import pandas as pd
import numpy as np
import PIL
from sklearn.metrics import f1_score

# Seed numpy's global RNG (get_df_for_model draws the validation ids from it).
# fastai's set_seed also seeds Python's `random` and torch (re-seeding numpy
# with the same value), so model initialisation and batch shuffling are
# repeatable too; reproducible=True additionally makes cuDNN deterministic.
np.random.seed(seed)
set_seed(seed, reproducible=True)

# Input images are read from ../data/<folder_name>; models and results are
# written below ../out/<folder_name>.
p = Path('../data') / folder_name
p_out = Path('../out') / folder_name
p_out.mkdir(parents=True, exist_ok=True)

**Function definitions**

In [None]:
def get_files_df(folder_name, data_dir="../data", n_classes=4):
    """Index the spectrogram image files of one data folder.

    Looks in ``<data_dir>/<folder_name>/<label>/`` for every ``label`` in
    ``range(n_classes)`` and parses file names of the form
    ``sg_<id>_<sequence>.<ext>``. Files whose name does not start with
    ``sg_`` (e.g. ``.DS_Store``) are ignored; an empty class folder is allowed.

    Args:
        folder_name: sub-folder of ``data_dir`` holding one folder per class.
        data_dir: root data directory (defaults to ``"../data"``).
        n_classes: number of class folders, named ``0`` .. ``n_classes - 1``.

    Returns:
        DataFrame with columns ``filename`` (bare file name, without the class
        folder), ``id`` and ``sequence`` (ints parsed from the file name) and
        ``label`` (the class folder as an int), sorted by ``(id, sequence)``
        and with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: if a class folder does not exist.
        ValueError: if an ``sg_`` file name lacks an integer id or sequence.
    """
    records = []
    for label in range(n_classes):
        folder_path = os.path.join(data_dir, folder_name, str(label))
        for filename in os.listdir(folder_path):
            # Filter before parsing so unrelated files cannot break the split.
            if not filename.startswith("sg_"):
                continue
            # "sg_<id>_<sequence>.<ext>" -> "sg", "<id>", "<sequence>.<ext>"
            _, file_id, rest = filename.split("_", 2)
            records.append({
                "filename": filename,
                "id": int(file_id),
                # drop the file extension (everything after the first '.')
                "sequence": int(rest.split(".")[0]),
                "label": label,
            })

    # Build the frame once (no concat in a loop) with a fixed column order.
    files_df = pd.DataFrame(records, columns=["filename", "id", "sequence", "label"])
    return files_df.sort_values(by=["id", "sequence"]).reset_index(drop=True)

def get_df_for_model(folder_name, valid_pct, rng=np.random):
    """Build the DataFrame passed to ``ImageDataLoaders.from_df``.

    The train/validation split is made by ``id``: a random fraction
    ``valid_pct`` of the unique ids is drawn without replacement, and every
    file of those ids is marked ``is_valid``. All files of one id therefore
    end up on the same side of the split.

    Args:
        folder_name: data folder, forwarded to ``get_files_df``.
        valid_pct: fraction of unique ids to use for validation. The count is
            ``int(n_ids * valid_pct)`` (floored), so it can be 0 for few ids.
        rng: object with a numpy-style ``choice`` method. The default, numpy's
            global RNG, reproduces the original behavior (seeded by
            ``np.random.seed``); a ``np.random.Generator`` can be passed instead.

    Returns:
        The ``get_files_df`` frame with ``filename`` rewritten to
        ``"<label>/<filename>"`` (relative to the data folder) and a boolean
        ``is_valid`` column appended.
    """
    files_df = get_files_df(folder_name)
    ids = files_df["id"].unique()
    n_valid = int(len(ids) * valid_pct)
    valid_ids = rng.choice(ids, n_valid, replace=False)
    return files_df.assign(
        # vectorised membership test instead of a per-row `in` over an array
        is_valid=files_df["id"].isin(valid_ids),
        filename=files_df["label"].astype(str) + "/" + files_df["filename"],
    )

## Create Data Loaders

In [None]:
# Optional image checks, left disabled because failed images were not an issue.
# Uncomment to resize the images of class folder "0" in place (dest is the
# source folder) and to delete files that fail verification.
# NOTE(review): only the "0" folder is covered; the other class folders are not.
#resize_images(p/"0", max_size=400, dest = p/"0") # resize images
# failed = verify_images(get_image_files(p/"0")) # verify images
# failed.map(Path.unlink) # delete failed images

Create a data loader and look at some of the images.

In [None]:
files_df = get_df_for_model(folder_name, valid_pct)
# The split is by id; make sure neither side is empty (the validation size is
# floored, so a small dataset or valid_pct can yield zero validation ids).
assert files_df["is_valid"].any(), "validation set is empty - increase valid_pct"
assert not files_df["is_valid"].all(), "training set is empty - decrease valid_pct"

# Images are read from p/<filename>; filename already contains the class folder.
# fn_col is named explicitly rather than relying on the default column position 0.
dls = ImageDataLoaders.from_df(
    files_df,
    path=p,
    fn_col="filename",
    label_col="label",
    valid_col="is_valid",
    item_tfms=Resize(224),
)
dls.show_batch(max_n=6);

## Train models

### Fit resnet50 model with `fit_one_cycle` (pretrained body frozen: only the head and BatchNorm layers are trained)

In [None]:
# vision_learner loads pretrained weights and freezes the body, so
# fit_one_cycle trains only the new head (and BatchNorm layers).
# NOTE(review): if all layers should be trained, call learn.unfreeze() first.
# path=p_out makes learn.save write below p_out/models/.
learn = vision_learner(dls, resnet50, metrics=error_rate, path=p_out)
learn.fit_one_cycle(10)
learn.recorder.plot_loss()
learn.save('resnet50-10_fit')

### Fine-tune resnet50 model

In [None]:
# fine_tune: one epoch on the head with the pretrained body frozen, then
# 10 epochs with all layers unfrozen. plot_loss shows the last fit only.
# path=p_out makes learn.save write below p_out/models/.
learn = vision_learner(dls, resnet50, metrics=error_rate, path=p_out)
learn.fine_tune(10)
learn.recorder.plot_loss()
learn.save('resnet50-10_tuned')

### Fit resnet18 model with `fit_one_cycle` (pretrained body frozen: only the head and BatchNorm layers are trained)

In [None]:
# vision_learner loads pretrained weights and freezes the body, so
# fit_one_cycle trains only the new head (and BatchNorm layers).
# NOTE(review): if all layers should be trained, call learn.unfreeze() first.
# path=p_out makes learn.save write below p_out/models/.
learn = vision_learner(dls, resnet18, metrics=error_rate, path=p_out)
learn.fit_one_cycle(20)
learn.recorder.plot_loss()
learn.save('resnet18-20_fit')

### Fine-tune resnet18 model

In [None]:
# fine_tune: one epoch on the head with the pretrained body frozen, then
# 6 epochs with all layers unfrozen. plot_loss shows the last fit only.
# path=p_out makes learn.save write below p_out/models/.
learn = vision_learner(dls, resnet18, metrics=error_rate, path=p_out)
learn.fine_tune(6)
learn.recorder.plot_loss()
learn.save('resnet18-6_tuned')

## Show results of last training run

In [None]:
# Show validation images with their target and predicted labels for the most
# recently created learner (the resnet18 fine-tuned in the previous cell).
learn.show_results()

## Prediction
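We have multiple samples (images) for each patient. We aggregate the image-level predictions for each patient by averaging the predicted class probabilities over all of that patient's images, and predict the class with the highest mean probability. This is a simple alternative to a full Bayes'-rule combination, which would instead multiply the per-image probabilities (sum their logarithms).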

In [None]:
# Display the image folder the data loaders read from.
p

In [None]:
# Interpretation on the validation DataLoader (dls.valid is dls[1]).
# NOTE(review): `inter` is not used further in this notebook.
inter = Interpretation.from_learner(learn, dl=dls.valid)

In [None]:
# Preview the validation rows, re-indexed from 0.
files_df.loc[files_df["is_valid"]].reset_index(drop=True)

In [None]:
# learn.get_preds() returns (probabilities, targets) for the validation set,
# not (probabilities, predicted classes), so the predicted class indices are
# derived from the probabilities with argmax.
pred_prob, targets = learn.get_preds()
pred_class = pred_prob.argmax(dim=1)

In [None]:
# Per-image predictions on the validation set. fastai's validation DataLoader
# is not shuffled and follows the row order of files_df[files_df.is_valid]
# (the split comes from the "is_valid" column), so rows line up with pred_prob.
valid_df = files_df[files_df.is_valid].reset_index(drop=True)
assert len(valid_df) == len(pred_prob), "predictions do not match validation rows"

# dls.vocab maps model output positions to class labels; p_<c> is the
# probability of class c.
class_labels = np.asarray(list(dls.vocab))
prob_cols = [f"p_{c}" for c in class_labels]
image_probs = pred_prob.numpy()
predictions_df = valid_df.assign(
    **dict(zip(prob_cols, image_probs.T)),
    pred_class=class_labels[image_probs.argmax(axis=1)],
)

# Per-id prediction: average each class probability over all images of an id
# and pick the class with the highest mean probability.
mean_probs = predictions_df.groupby("id")[prob_cols].mean()
mean_probs["pred_class"] = class_labels[mean_probs.to_numpy().argmax(axis=1)]

# Broadcast the per-id prediction back to every validation image.
# NOTE(review): scores computed on `predictions` weight each id by its number
# of images; use one row per id (mean_probs) for a per-patient score.
predictions = valid_df.merge(mean_probs.reset_index(), on="id")

In [None]:
# Inspect the per-id predicted class attached to every validation image.
predictions.pred_class

In [None]:
# Import f1_score here as well so this cell runs on a fresh kernel
# (otherwise it is only imported in a later cell).
from sklearn.metrics import f1_score
help(f1_score)

In [None]:
# Sanity check before scoring: number of true labels vs. number of predictions
# (both columns come from the same `predictions` frame).
len(predictions["label"].to_numpy(dtype=np.int32)), len(predictions["pred_class"].to_numpy())

In [None]:
from sklearn.metrics import f1_score

# For single-label multi-class data, micro-averaged F1 equals accuracy.
# NOTE(review): consider average="macro" if the classes are imbalanced.
f1_score(
    y_true=predictions["label"].astype(np.int32).to_numpy(),
    y_pred=predictions["pred_class"].astype(np.int32).to_numpy(),
    average='micro',
)

In [None]:
# Repeats the earlier show_results cell for the current learner.
# NOTE(review): duplicate of the cell under "Show results of last training run".
learn.show_results()

In [None]:
# Rows belonging to class folder 3.
files_df.loc[files_df["label"].eq(3)]

In [None]:
# Load the saved model
# Load the saved model
# Loads the 'resnet18-20_fit' weights into the current learner (the resnet18
# learner from the last training cell, so the architectures match). learn.path
# is p_out, so the file is read from p_out/models/.
learn.load('resnet18-20_fit')

In [None]:
import PIL

def open_image(fname, size=224):
    """Load an image file as a (3, size, size) float tensor scaled to [0, 1].

    NOTE(review): not used in this notebook; `PILImage.create` + `learn.predict`
    below handle loading and preprocessing.
    """
    with PIL.Image.open(fname) as src:  # close the file once pixels are read
        img = src.convert('RGB').resize((size, size))
    t = torch.Tensor(np.array(img))
    return t.permute(2, 0, 1).float() / 255.0

# Predict a single validation image (a training image would flatter the model).
# Distinct result names keep the notebook-level `pred_class` from get_preds intact.
img_path = p / files_df.loc[files_df["is_valid"], "filename"].iloc[0]
img = PILImage.create(img_path)
img_label, img_idx, img_probs = learn.predict(img)

print(img_idx)

In [None]:
# Show `img` titled with the predicted class and that class's probability
# (learn.predict returns (decoded label, class index, class probabilities)).
img_label, img_idx, img_probs = learn.predict(img)
img.show(title=f'Predicted class {img_label} with probability {float(img_probs[int(img_idx)]):.4f}.');

In [None]:
# Raw prediction for `img`: (decoded label, class index, class probabilities).
learn.predict(img)