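# Deep Learning

Train and compare CNN, CNN+LSTM and CNN+LSTM+Attention classifiers on 200-sample ECG beat segments (annotation classes N, V, /, L, R), then evaluate each model and their ensemble on a held-out test split.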

In [1]:
import os
import itertools
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import (CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts,
                                      StepLR,
                                      ExponentialLR)

from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, auc, f1_score, precision_score, recall_score

In [2]:
# Render matplotlib figures directly below the cell that creates them.
%matplotlib inline

## Loading Data

In [3]:
# ecg_path = "C://Users//HP//Documents//Varun//DHAI//Capstone//checkpoint//processed_data//ecg_processed_data.csv"
# Path to the processed ECG CSV. Set the ECG_CSV_PATH environment variable to
# point elsewhere instead of editing this machine-specific default.
ecg_path = os.environ.get(
    "ECG_CSV_PATH",
    "C://Users//HP//Documents//Varun//DHAI//Capstone//checkpoint//processed_data//ecg_processed_data_2.csv",
)

In [4]:
# Use the first CSV column as the index (the head() output below shows it as "Record ID").
ecg_data = pd.read_csv(ecg_path, index_col=0)

In [5]:
# Preview: segment bounds, 200 signal samples (xs0..xs199) and the label columns.
ecg_data.head()

Unnamed: 0_level_0,Segment Start,Segment End,xs0,xs1,xs2,xs3,xs4,xs5,xs6,xs7,...,xs193,xs194,xs195,xs196,xs197,xs198,xs199,Annotation Class,Annotation Class Numeric,acn
Record ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
100,13,213,0.059449,0.055308,0.052035,0.049666,0.047961,0.046922,0.046801,0.04777,...,0.033233,0.033981,0.034276,0.034166,0.033702,0.033134,0.032648,N,1,0
100,307,507,0.084239,0.087566,0.088956,0.088621,0.086919,0.084459,0.081823,0.079506,...,0.020049,0.020272,0.020719,0.020971,0.020866,0.020622,0.020434,N,1,0
100,563,763,-0.006937,-0.00569,-0.005184,-0.005172,-0.005405,-0.005802,-0.006385,-0.007163,...,-0.084798,-0.084841,-0.084325,-0.082953,-0.080644,-0.077424,-0.073181,N,1,0
100,883,1083,0.072086,0.073128,0.073026,0.071802,0.069731,0.067336,0.065156,0.063624,...,0.040605,0.040177,0.039328,0.038051,0.036671,0.035565,0.034817,N,1,0
100,1168,1368,0.084762,0.083995,0.082327,0.080127,0.077701,0.075216,0.072963,0.071264,...,0.025977,0.027249,0.029211,0.031715,0.034364,0.036462,0.037443,N,1,0


In [6]:
# Distinct annotation symbols present in the data.
ecg_data["Annotation Class"].unique()

array(['N', 'V', '/', 'L', 'R'], dtype=object)

In [7]:
# Number of distinct annotation classes.
ecg_data["Annotation Class"].nunique()

5

In [8]:
# Beats per class; the output shows the data is heavily dominated by 'N'.
ecg_data["Annotation Class"].value_counts()

N    73439
L     8068
R     7255
V     6793
/     3619
Name: Annotation Class, dtype: int64

## Preprocessing

### Encoding the non-numeric annotation column as integer class labels

In [15]:
# NOTE(review): this output starts at record 118 while In [5] started at record 100,
# so ecg_data appears to have been modified by a cell no longer in the notebook.
# Confirm the expected state after Restart & Run All.
ecg_data.head()

Unnamed: 0_level_0,Segment Start,Segment End,xs0,xs1,xs2,xs3,xs4,xs5,xs6,xs7,...,xs193,xs194,xs195,xs196,xs197,xs198,xs199,Annotation Class,Annotation Class Numeric,acn
Record ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
118,10,210,0.070648,0.056115,0.036897,0.013635,-0.012495,-0.039801,-0.066433,-0.090459,...,0.091568,0.08972,0.091422,0.095186,0.099423,0.103057,0.105667,R,5,4
118,284,484,0.181398,0.193873,0.208499,0.225425,0.24424,0.263613,0.280935,0.292966,...,-0.022158,-0.008332,0.000986,0.007224,0.012365,0.018527,0.027373,R,5,4
118,589,789,0.153018,0.163216,0.175094,0.188575,0.20313,0.218118,0.232817,0.246466,...,-0.055312,-0.04935,-0.048095,-0.04838,-0.046108,-0.037715,-0.021261,R,5,4
118,895,1095,0.178514,0.188845,0.200294,0.213053,0.226713,0.240282,0.252544,0.262816,...,0.01627,0.026095,0.032914,0.038232,0.043546,0.049988,0.058251,R,5,4
118,1221,1421,0.226174,0.224501,0.21723,0.205161,0.189663,0.172243,0.154287,0.137009,...,0.095548,0.092868,0.09053,0.089298,0.089521,0.091152,0.093633,R,5,4


In [10]:
# rename_col_dict = {str(i): f"xs{i}" for i in range(0, 200)}
# ecg_data = ecg_data.rename(columns=rename_col_dict)
# ecg_data.to_csv("C://Users//HP//Documents//Varun//DHAI//Capstone//checkpoint//processed_data//ecg_processed_data_2.csv", index=False)

In [12]:
# Map each annotation symbol to the integer training target.
# Despite its name, this dict maps label -> id. It is derived from the `acn`
# column (the last column, which ECGDataset uses as the target), not from the
# order in which symbols first appear. It therefore stays aligned with the
# training targets even if the row order of `ecg_data` changes.
_label_codes = ecg_data[["Annotation Class", "acn"]].drop_duplicates()
assert _label_codes["Annotation Class"].is_unique, "an annotation symbol maps to several acn codes"
assert _label_codes["acn"].is_unique, "an acn code maps to several annotation symbols"
id_to_label = {
    label: int(code)
    for label, code in _label_codes.sort_values("acn").itertuples(index=False)
}
print(id_to_label)

{'N': 0, 'V': 1, '/': 2, 'L': 3, 'R': 4}


In [None]:
# Names of the 200 signal-sample columns, in order: xs0, xs1, ..., xs199.
x_columns = list(map("xs{}".format, range(200)))

In [None]:
import torch

class Config:
    # Experiment-wide settings: RNG seed, compute device and artefact locations.
    csv_path = ''
    seed = 2021
    device = 'cpu' if not torch.cuda.is_available() else 'cuda:0'

    # All checkpoints, logs and data live in the same processed_data folder.
    _base_dir = 'C://Users//HP//Documents//Varun//DHAI//Capstone//checkpoint//processed_data'

    # Saved model weights.
    attn_state_path = f'{_base_dir}//attn.pth'
    lstm_state_path = f'{_base_dir}//lstm.pth'
    cnn_state_path = f'{_base_dir}//cnn.pth'

    # Per-epoch training logs.
    attn_logs = f'{_base_dir}//attn.csv'
    lstm_logs = f'{_base_dir}//lstm.csv'
    cnn_logs = f'{_base_dir}//cnn.csv'

    # Processed ECG dataset.
    ecg_csv_path = f'{_base_dir}//ecg_processed_data_2.csv'


def seed_everything(seed: int):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (plus CUDA when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
# Shared configuration instance used by the training and evaluation code below;
# seed every RNG before any data split or model initialisation happens.
config = Config()
seed_everything(config.seed)

## What does each annotation class look like?

In [None]:
# Plot the first beat of every annotation class, one figure per class.
for label in ecg_data["Annotation Class"].unique():
    first_beat = ecg_data.loc[ecg_data["Annotation Class"] == label, x_columns].iloc[0].values
    plt.figure(figsize=(10, 8))
    plt.plot(first_beat)
    plt.title(label=label)
    plt.show()

## Data Loader

In [None]:
class ECGDataset(Dataset):
    """Torch dataset over a processed ECG DataFrame.

    Each item is ``(signal, target)``:
      * ``signal``: float32 tensor of shape ``(1, n_samples)`` (one input channel),
        where ``n_samples`` is the number of signal columns (200 for this data).
      * ``target``: long scalar class id.
    """

    def __init__(self, df):
        self.df = df
        # Select the signal columns (xs0, xs1, ...) and the `acn` target by name.
        # Fall back to the original positional layout (3 leading metadata columns,
        # 3 trailing label columns, target last) when those names are absent.
        signal_cols = [
            col for col in df.columns
            if isinstance(col, str) and col.startswith("xs") and col[2:].isdigit()
        ]
        signal = df[signal_cols] if signal_cols else df.iloc[:, 3:-3]
        target = df["acn"] if "acn" in df.columns else df.iloc[:, -1]
        self.X = torch.tensor(signal.values, dtype=torch.float32)
        self.y = torch.tensor(target.values, dtype=torch.long)

    def __getitem__(self, idx):
        # Add a leading channel axis: (n_samples,) -> (1, n_samples).
        signal = self.X[idx].unsqueeze(0)
        target = self.y[idx]
        return signal, target

    def __len__(self):
        return len(self.df)

In [None]:
def get_dataloader(train_csv_path: str, phase: str, batch_size: int = 96,
                   num_workers: int = 4) -> DataLoader:
    '''
    Build the DataLoader for one split of the processed ECG CSV.

    The CSV is split 60/20/20 into train/val/test, stratified on the `acn`
    target with a fixed random_state, so every call reproduces the same partition.

    Parameters:
        train_csv_path: path to the processed ECG CSV.
        phase: 'train' or 'val'; any other value selects the test split.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
            NOTE(review): on Windows/Jupyter, workers > 0 may fail to unpickle
            notebook-defined classes such as ECGDataset; pass 0 if iteration fails.
    Returns:
        DataLoader over an ECGDataset; only the training split is shuffled.
    '''
    df = pd.read_csv(train_csv_path)

    # 80/20 split, then 75/25 of the remainder -> 60% train, 20% val, 20% test.
    t_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['acn'])
    t_df, test_df = t_df.reset_index(drop=True), test_df.reset_index(drop=True)

    train_df, val_df = train_test_split(t_df, test_size=0.25, random_state=42, stratify=t_df['acn'])
    train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True)

    if phase == 'train':
        df = train_df
    elif phase == 'val':
        df = val_df
    else:
        df = test_df

    dataset = ECGDataset(df)
    # Shuffle the training split. The CSV rows appear to be ordered by record and
    # segment, so unshuffled batches would each come from a few recordings and
    # repeat in the same order every epoch. Evaluation splits keep a fixed order.
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=(phase == 'train'),
        num_workers=num_workers,
    )

    return dataloader

## Trainer and Metrics

In [None]:
class Meter:
    """Accumulate per-batch classification metrics and a confusion matrix.

    Scalar metrics are *summed* over batches; the caller divides by the number
    of batches to get means. `confusion[true_class][predicted_class]` counts
    samples, so rows index the true class and columns the predicted class.
    """

    def __init__(self, n_classes=5):
        self.metrics = {}
        self.confusion = torch.zeros((n_classes, n_classes))
        # Start from zeroed metrics so update() works without a separate init call.
        self.init_metrics()

    def update(self, x, y, loss):
        """Add one batch.

        Parameters:
            x: model outputs, one row of class scores per sample (argmax taken on axis 1).
            y: integer target class per sample.
            loss: scalar batch loss.
        """
        preds = np.argmax(x.detach().cpu().numpy(), axis=1)
        targets = y.detach().cpu().numpy()
        self.metrics['loss'] += loss
        # sklearn metric functions take (y_true, y_pred); the reverse order
        # silently swaps precision and recall.
        self.metrics['accuracy'] += accuracy_score(targets, preds)
        self.metrics['f1'] += f1_score(targets, preds, average='macro')
        self.metrics['precision'] += precision_score(targets, preds, average='macro', zero_division=1)
        self.metrics['recall'] += recall_score(targets, preds, average='macro', zero_division=1)

        self._compute_cm(preds, targets)

    def _compute_cm(self, x, y):
        """Add predicted class ids `x` against true class ids `y` to the confusion matrix."""
        for pred, target in zip(x, y):
            self.confusion[target][pred] += 1

    def init_metrics(self):
        """Reset all accumulated scalar metrics to zero."""
        self.metrics['loss'] = 0
        self.metrics['accuracy'] = 0
        self.metrics['f1'] = 0
        self.metrics['precision'] = 0
        self.metrics['recall'] = 0

    def get_metrics(self):
        return self.metrics

    def get_confusion_matrix(self):
        return self.confusion


class Trainer:
    """Train `net` on the ECG CSV and checkpoint the weights with the best validation loss.

    Parameters:
        train_csv_path: processed ECG CSV passed to `get_dataloader` for the 'train' and 'val' splits.
        net: model to train; moved to `config.device`.
        lr: AdamW learning rate.
        batch_size: samples per batch for both splits.
        num_epochs: number of epochs; also the period (T_max) of the cosine LR schedule.
    """

    def __init__(self, train_csv_path, net, lr, batch_size, num_epochs):
        self.net = net.to(config.device)
        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdamW(self.net.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs, eta_min=5e-6)
        self.best_loss = float('inf')
        self.phases = ['train', 'val']
        self.dataloaders = {
            phase: get_dataloader(train_csv_path, phase, batch_size) for phase in self.phases
        }
        # One row per epoch, appended by _train_epoch.
        self.train_df_logs = pd.DataFrame()
        self.val_df_logs = pd.DataFrame()

    def _train_epoch(self, phase):
        """Run one pass over the `phase` dataloader, log its metrics and plot its confusion matrix.

        Weights are updated only when `phase == 'train'`.

        Returns:
            float: mean per-batch loss over the epoch (used by `run` for checkpointing).

        Raises:
            RuntimeError: if the dataloader yields no batches.
        """
        print(f"{phase} mode | time: {time.strftime('%H:%M:%S')}")

        is_train = phase == 'train'
        self.net.train(is_train)  # train(False) is the same as eval()
        meter = Meter()
        meter.init_metrics()

        n_batches = 0
        for data, target in self.dataloaders[phase]:
            data = data.to(config.device)
            target = target.to(config.device)

            output = self.net(data)
            loss = self.criterion(output, target)

            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            meter.update(output, target, loss.item())
            n_batches += 1

        if n_batches == 0:
            raise RuntimeError(f"The '{phase}' dataloader yielded no batches.")

        # Meter sums per-batch values; divide by the batch count (not the last
        # batch index) to get per-batch means.
        metrics = {k: v / n_batches for k, v in meter.get_metrics().items()}
        df_logs = pd.DataFrame([metrics])
        cm = meter.get_confusion_matrix()

        if is_train:
            self.train_df_logs = pd.concat([self.train_df_logs, df_logs], axis=0, ignore_index=True)
        else:
            self.val_df_logs = pd.concat([self.val_df_logs, df_logs], axis=0, ignore_index=True)

        # show logs
        print(', '.join(f'{k}: {v}' for k, v in metrics.items()))
        fig, ax = plt.subplots(figsize=(5, 5))
        cm_ = ax.imshow(cm, cmap='hot')
        ax.set_title('Confusion matrix', fontsize=15)
        # Meter fills confusion[true][predicted]: rows (y-axis) are the true class.
        ax.set_xlabel('Predicted', fontsize=13)
        ax.set_ylabel('Actual', fontsize=13)
        plt.colorbar(cm_)
        plt.show()
        plt.close(fig)

        return metrics['loss']

    def run(self):
        """Train for `num_epochs`, validating after each epoch and saving a
        checkpoint whenever the mean validation loss improves."""
        for epoch in range(self.num_epochs):
            self._train_epoch(phase='train')
            with torch.no_grad():
                val_loss = self._train_epoch(phase='val')
            self.scheduler.step()

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                print('\nNew checkpoint\n')
                torch.save(self.net.state_dict(), f"best_model_epoc{epoch}.pth")
            # clear_output()


## Models

In [None]:
from utils.models import CNN, RNNModel, RNNAttentionModel

In [None]:
# Fresh CNN to train below; num_classes=5 matches the 5 annotation symbols.
model = CNN(num_classes=5, hid_size=128)

## Experimentation

In [None]:
# Train the CNN on the train/val splits of the ECG CSV; Trainer.run saves a
# checkpoint whenever the validation loss improves.
# NOTE(review): the trailing "#100)" suggests 100 epochs were used before; confirm the intended epoch count.
trainer = Trainer(train_csv_path=ecg_path , net=model, lr=1e-3, batch_size=96, num_epochs=10)#100)
trainer.run()

In [None]:
# Combine per-epoch train and val logs side by side, one row per epoch.
# add_prefix returns new frames, so re-running this cell does not keep stacking
# "train_"/"val_" prefixes onto the trainer's own log DataFrames.
train_logs = trainer.train_df_logs.add_prefix("train_").reset_index(drop=True)
val_logs = trainer.val_df_logs.add_prefix("val_").reset_index(drop=True)

# Both frames now share a 0..n_epochs-1 index, so rows align by epoch.
logs = pd.concat([train_logs, val_logs], axis=1)
logs = logs.loc[:, [
    'train_loss', 'val_loss',
    'train_accuracy', 'val_accuracy',
    'train_f1', 'val_f1',
    'train_precision', 'val_precision',
    'train_recall', 'val_recall',
]]
logs.to_csv('cnn.csv', index=False)
logs.head()

In [None]:
# Rebuild the CNN, load its saved weights from config.cnn_state_path and switch to eval mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
cnn_model = CNN(num_classes=5, hid_size=128).to(config.device)
cnn_model.load_state_dict(
    torch.load(config.cnn_state_path,
               map_location=config.device)
);
cnn_model.eval();
# Per-epoch training logs used by the plots below.
logs = pd.read_csv(config.cnn_logs)

In [None]:
# CNN training curves: loss (left) and accuracy/F1 (right), train vs val.
colors = ['#C042FF', '#03C576FF', '#FF355A', '#03C5BF', '#96C503', '#C5035B']
palettes = [
    sns.color_palette(colors, 2),
    sns.color_palette(colors, 4),
    sns.color_palette(colors[:2] + colors[-2:] + colors[2:-2], 6),
]

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, metric_ax = ax

sns.lineplot(data=logs.iloc[:, :2], ax=loss_ax, palette=palettes[0],
             linewidth=2.5, markers=True)
loss_ax.set_title("Loss Function during Model Training", fontsize=14)
loss_ax.set_xlabel("Epoch", fontsize=14)

sns.lineplot(data=logs.iloc[:, 2:6], ax=metric_ax, palette=palettes[1],
             linewidth=2.5, markers=True, legend="full")
metric_ax.set_title("Metrics during Model Training", fontsize=15)
metric_ax.set_xlabel("Epoch", fontsize=14)

plt.suptitle('CNN Model', fontsize=18)
plt.tight_layout()

fig.savefig("cnn.png", format="png", pad_inches=0.2, transparent=False, bbox_inches='tight')
fig.savefig("cnn.svg", format="svg", pad_inches=0.2, transparent=False, bbox_inches='tight')

In [None]:
# Rebuild the CNN+LSTM model, load its saved weights from config.lstm_state_path and switch to eval mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
lstm_model = RNNModel(1, 64, 'lstm', True).to(config.device)
lstm_model.load_state_dict(
    torch.load(config.lstm_state_path,
               map_location=config.device)
);
lstm_model.eval();
# Per-epoch training logs used by the plots below.
logs = pd.read_csv(config.lstm_logs)

In [None]:
# CNN+LSTM training curves: loss (left) and accuracy/F1 (right), train vs val.
colors = ['#C042FF', '#03C576FF', '#FF355A', '#03C5BF', '#96C503', '#C5035B']
palettes = [
    sns.color_palette(colors, 2),
    sns.color_palette(colors, 4),
    sns.color_palette(colors[:2] + colors[-2:] + colors[2:-2], 6),
]

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, metric_ax = ax

sns.lineplot(data=logs.iloc[:, :2], ax=loss_ax, palette=palettes[0],
             linewidth=2.5, markers=True)
loss_ax.set_title("Loss Function during Model Training", fontsize=14)
loss_ax.set_xlabel("Epoch", fontsize=14)

sns.lineplot(data=logs.iloc[:, 2:6], ax=metric_ax, palette=palettes[1],
             linewidth=2.5, markers=True, legend="full")
metric_ax.set_title("Metrics during Model Training", fontsize=15)
metric_ax.set_xlabel("Epoch", fontsize=14)

plt.suptitle('CNN+LSTM Model', fontsize=18)
plt.tight_layout()

fig.savefig("lstm.png", format="png", pad_inches=0.2, transparent=False, bbox_inches='tight')
fig.savefig("lstm.svg", format="svg", pad_inches=0.2, transparent=False, bbox_inches='tight')

In [None]:
# Rebuild the CNN+LSTM+Attention model, load its saved weights from config.attn_state_path and switch to eval mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
attn_model = RNNAttentionModel(1, 64, 'lstm', False).to(config.device)
attn_model.load_state_dict(
    torch.load(config.attn_state_path,
               map_location=config.device)
);
attn_model.eval();
# Per-epoch training logs used by the plots below.
logs = pd.read_csv(config.attn_logs)

In [None]:
# CNN+LSTM+Attention training curves: loss (left) and accuracy/F1 (right), train vs val.
colors = ['#C042FF', '#03C576FF', '#FF355A', '#03C5BF', '#96C503', '#C5035B']
palettes = [
    sns.color_palette(colors, 2),
    sns.color_palette(colors, 4),
    sns.color_palette(colors[:2] + colors[-2:] + colors[2:-2], 6),
]

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, metric_ax = ax

sns.lineplot(data=logs.iloc[:, :2], ax=loss_ax, palette=palettes[0],
             linewidth=2.5, markers=True)
loss_ax.set_title("Loss Function during Model Training", fontsize=14)
loss_ax.set_xlabel("Epoch", fontsize=14)

sns.lineplot(data=logs.iloc[:, 2:6], ax=metric_ax, palette=palettes[1],
             linewidth=2.5, markers=True, legend="full")
metric_ax.set_title("Metrics during Model Training", fontsize=15)
metric_ax.set_xlabel("Epoch", fontsize=14)

plt.suptitle('CNN+LSTM+Attention Model', fontsize=18)
plt.tight_layout()

fig.savefig("attn.png", format="png", pad_inches=0.2, transparent=False, bbox_inches='tight')
fig.savefig("attn.svg", format="svg", pad_inches=0.2, transparent=False, bbox_inches='tight')

## Testing

In [None]:
# Held-out test split (get_dataloader's split uses fixed random_state values, so it is reproducible).
test_dataloader = get_dataloader(train_csv_path=config.ecg_csv_path, phase='test', batch_size =96)

In [None]:
def make_test_stage(dataloader, model, probs=False, device=None):
    '''
    Run `model` over `dataloader` and collect predictions and ground truths.

    Parameters:
        dataloader: iterable of (data, target) batches.
        model: callable returning one row of class scores per sample.
            The caller is responsible for putting it in eval mode.
        probs: if True, return the raw model outputs; otherwise argmax class ids.
            NOTE(review): raw outputs are probabilities only if the model ends
            in a softmax, which is not visible here.
        device: device the inputs are moved to; defaults to `config.device`.
    Returns:
        (predictions, ground_truths) as numpy arrays. Both are empty if
        `dataloader` yields no batches.
    '''
    if device is None:
        device = config.device

    predictions, ground_truths = [], []
    with torch.no_grad():
        for data, target in dataloader:
            output = model(data.to(device))
            if not probs:
                output = torch.argmax(output, dim=1)
            predictions.append(output.detach().cpu())
            ground_truths.append(target.cpu())

    # torch.cat raises on an empty list, so handle an empty dataloader explicitly.
    if not predictions:
        return np.empty((0,), dtype=np.float32), np.empty((0,), dtype=np.int64)
    return torch.cat(predictions).numpy(), torch.cat(ground_truths).numpy()

In [None]:
# Loaded models, indexed below as 0 = CNN, 1 = CNN+LSTM, 2 = CNN+LSTM+Attention.
models = [cnn_model, lstm_model, attn_model]


In [None]:
# CNN test-set predictions (argmax class ids) and matching targets.
y_pred, y_true = make_test_stage(test_dataloader, models[0])
y_pred.shape, y_true.shape

In [None]:
# Per-class precision/recall/F1 for the CNN. classification_report expects
# (y_true, y_pred); the reversed order swaps each class's precision and recall.
report = pd.DataFrame(
    classification_report(
        y_true,
        y_pred,
        output_dict=True
    )
).transpose()

In [None]:
colors = ['#00FA9A', '#D2B48C', '#FF69B4']
report_plot = report.apply(lambda x: x*100)
ax = report_plot[["precision", "recall", "f1-score"]].plot(
    kind='bar', figsize=(13, 4), legend=True, fontsize=15, color=colors)
# DataFrame.plot draws on a new figure; save that one, not a stale `fig`
# left over from an earlier plotting cell.
fig = ax.get_figure()

# Report rows are the class ids as strings ('0', '1', ...) followed by the
# summary rows; label the class rows with their annotation symbol.
id_to_symbol = {code: symbol for symbol, code in id_to_label.items()}
tick_labels = [
    id_to_symbol.get(int(row), row) if str(row).isdigit() else row
    for row in report.index
]
ax.set_xlabel("Estimators", fontsize=15)
ax.set_xticklabels(tick_labels, rotation=15, fontsize=11)
ax.set_ylabel("Percentage", fontsize=15)
ax.set_title("CNN Model Classification Report", fontsize=20)

# Annotate each row's first (precision) bar with its precision/recall/F1 values.
for scores, patch in zip(report[['precision', 'recall', 'f1-score']].values, ax.patches):
    text = " ".join([str(round(s*100, 2))+"%" for s in scores])
    label_x = patch.get_x() + patch.get_width() - 0.4
    label_y = patch.get_y() + patch.get_height() / 4
    ax.annotate(text, (label_x, label_y), fontsize=8, rotation=15, fontweight='bold')
fig.savefig("cnn_report.png", format="png",  pad_inches=0.2, transparent=False, bbox_inches='tight')
fig.savefig("cnn_report.svg", format="svg",  pad_inches=0.2, transparent=False, bbox_inches='tight')
plt.show()

In [None]:
# cnn +lstm
# CNN+LSTM test-set predictions (argmax class ids) and matching targets.

y_pred, y_true = make_test_stage(test_dataloader, models[1])
y_pred.shape, y_true.shape

In [None]:
# Per-class precision/recall/F1 for the CNN+LSTM model. classification_report
# expects (y_true, y_pred); the reversed order swaps precision and recall.
report = pd.DataFrame(
    classification_report(
        y_true,
        y_pred,
        output_dict=True
    )
).transpose()

In [None]:
colors = ['#00FA9A', '#D2B48C', '#FF69B4']
report_plot = report.apply(lambda x: x*100)
ax = report_plot[["precision", "recall", "f1-score"]].plot(
    kind='bar', figsize=(13, 4), legend=True, fontsize=15, color=colors)
# DataFrame.plot draws on a new figure; save that one, not a stale `fig`
# left over from an earlier plotting cell.
fig = ax.get_figure()

# Report rows are the class ids as strings ('0', '1', ...) followed by the
# summary rows; label the class rows with their annotation symbol.
id_to_symbol = {code: symbol for symbol, code in id_to_label.items()}
tick_labels = [
    id_to_symbol.get(int(row), row) if str(row).isdigit() else row
    for row in report.index
]
ax.set_xlabel("Estimators", fontsize=15)
ax.set_xticklabels(tick_labels, rotation=15, fontsize=11)
ax.set_ylabel("Percentage", fontsize=15)
ax.set_title("CNN+LSTM Model Classification Report", fontsize=20)

# Annotate each row's first (precision) bar with its precision/recall/F1 values.
for scores, patch in zip(report[['precision', 'recall', 'f1-score']].values, ax.patches):
    text = " ".join([str(round(s*100, 2))+"%" for s in scores])
    label_x = patch.get_x() + patch.get_width() - 0.4
    label_y = patch.get_y() + patch.get_height() / 4
    ax.annotate(text, (label_x, label_y), fontsize=8, rotation=15, fontweight='bold')
fig.savefig("lstm_report.png", format="png",  pad_inches=0.2, transparent=False, bbox_inches='tight')
fig.savefig("lstm_report.svg", format="svg",  pad_inches=0.2, transparent=False, bbox_inches='tight')
plt.show()

In [None]:
# cnn + lstm + attn 
# CNN+LSTM+Attention test-set predictions (argmax class ids) and matching targets.
y_pred, y_true = make_test_stage(test_dataloader, models[2])
y_pred.shape, y_true.shape

In [None]:
# Per-class precision/recall/F1 for the CNN+LSTM+Attention model.
# classification_report expects (y_true, y_pred); the reversed order swaps
# precision and recall.
report = pd.DataFrame(
    classification_report(
        y_true,
        y_pred,
        output_dict=True
    )
).transpose()

In [None]:
colors = ['#00FA9A', '#D2B48C', '#FF69B4']
report_plot = report.apply(lambda x: x*100)
ax = report_plot[["precision", "recall", "f1-score"]].plot(
    kind='bar', figsize=(13, 4), legend=True, fontsize=15, color=colors)
# DataFrame.plot draws on a new figure; save that one, not a stale `fig`
# left over from an earlier plotting cell.
fig = ax.get_figure()

# Report rows are the class ids as strings ('0', '1', ...) followed by the
# summary rows; label the class rows with their annotation symbol.
id_to_symbol = {code: symbol for symbol, code in id_to_label.items()}
tick_labels = [
    id_to_symbol.get(int(row), row) if str(row).isdigit() else row
    for row in report.index
]
ax.set_xlabel("Estimators", fontsize=15)
ax.set_xticklabels(tick_labels, rotation=15, fontsize=11)
ax.set_ylabel("Percentage", fontsize=15)
ax.set_title("CNN+LSTM+Attention Model Classification Report", fontsize=20)

# Annotate each row's first (precision) bar with its precision/recall/F1 values.
for scores, patch in zip(report[['precision', 'recall', 'f1-score']].values, ax.patches):
    text = " ".join([str(round(s*100, 2))+"%" for s in scores])
    label_x = patch.get_x() + patch.get_width() - 0.4
    label_y = patch.get_y() + patch.get_height() / 4
    ax.annotate(text, (label_x, label_y), fontsize=8, rotation=15, fontweight='bold')
fig.savefig("attn_report.png", format="png",  pad_inches=0.2, transparent=False, bbox_inches='tight')
fig.savefig("attn_report.svg", format="svg",  pad_inches=0.2, transparent=False, bbox_inches='tight')
plt.show()

In [None]:
## Ensemble
# Average the raw outputs of all models, then take the per-sample argmax.
# The array size comes from the model outputs themselves, so this cell does not
# depend on a `y_pred` left behind by an earlier cell or on a leaked loop index.
summed_outputs = None
for model in models:
    model_outputs, y_true = make_test_stage(test_dataloader, model, True)
    summed_outputs = model_outputs if summed_outputs is None else summed_outputs + model_outputs
y_pred = np.argmax(summed_outputs / len(models), axis=1)

In [None]:
# Annotation symbols ordered by class id, so target_names line up with labels.
label_names = sorted(id_to_label, key=id_to_label.get)

# classification_report expects (y_true, y_pred); the reversed order swaps
# each class's precision and recall.
clf_report = classification_report(y_true,
                                   y_pred,
                                   labels=[id_to_label[name] for name in label_names],
                                   target_names=label_names,
                                   output_dict=True)


plt.figure(figsize=(10, 8))
# Drop the last row (support) and show classes/averages as rows.
ax = sns.heatmap(pd.DataFrame(clf_report).iloc[:-1, :].T, annot=True)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=15)
ax.set_yticklabels(ax.get_yticklabels(), fontsize=12, rotation=0)
plt.title("Ensemble Classification Report", fontsize=20)
plt.savefig("ensemble result.svg", format="svg", bbox_inches='tight', pad_inches=0.2)
plt.savefig("ensemble result.png", format="png", bbox_inches='tight', pad_inches=0.2)

In [None]:
# Display the ensemble's classification report dictionary.
clf_report