In [1]:
import os
import itertools
import time
import random
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns


import torchvision
from torchsummary import summary
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import (CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts,
                                      StepLR,
                                      ExponentialLR)

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, auc, f1_score, precision_score, recall_score

In [2]:
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Now you can access files in your Google Drive


Mounted at /content/drive


In [3]:
mitbih_test = pd.read_csv('/content/drive/MyDrive/ecg/mitbih_test.csv', header=None)
mitbih_train = pd.read_csv('/content/drive/MyDrive/ecg/mitbih_train.csv', header=None)

In [4]:
mitbih_test.rename(columns={187: 'class'}, inplace=True)
id_to_label = {
    0: "Normal",
    1: "Artial Premature",
    2: "Premature ventricular contraction",
    3: "Fusion of ventricular and normal",
    4: "Fusion of paced and normal"
}
mitbih_test['label'] = mitbih_test.iloc[:, -1].map(id_to_label)


In [5]:
mitbih_train.rename(columns={187: 'class'}, inplace=True)
id_to_label = {
    0: "Normal",
    1: "Artial Premature",
    2: "Premature ventricular contraction",
    3: "Fusion of ventricular and normal",
    4: "Fusion of paced and normal"
}
mitbih_train['label'] = mitbih_train.iloc[:, -1].map(id_to_label)


In [6]:
mitbih_test.to_csv('mitbih_test_new.csv', index=False)
mitbih_train.to_csv('mitbih_train_new.csv', index=False)

In [7]:
class Config:
    csv_path = ''
    seed = 2021
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    train_csv_path = 'mitbih_train_new.csv'
    test_csv_path = 'mitbih_test_new.csv'

In [8]:
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

In [9]:
config = Config()
seed_everything(config.seed)

In [10]:
class ECGDataset(Dataset):

    def __init__(self, df):
        self.df = df
        self.data_columns = self.df.columns[:-2].tolist()

    def __getitem__(self, idx):
        signal = self.df.loc[idx, self.data_columns].astype('float32')
        signal = torch.FloatTensor([signal.values])
        target = torch.LongTensor(np.array(self.df.loc[idx, 'class']))
        return signal, target

    def __len__(self):
        return len(self.df)

In [15]:
def get_dataloader(phase: str, batch_size: int = 96) -> DataLoader:
    '''
    Dataset and DataLoader.
    Parameters:
        pahse: training or validation phase.
        batch_size: data per iteration.
    Returns:
        data generator
    '''
    df = pd.read_csv(config.train_csv_path)
    train_df, val_df = train_test_split(
        df, test_size=0.15, random_state=config.seed, stratify=df['label']
    )
    train_df, val_df = train_df.reset_index(drop=True), val_df.reset_index(drop=True)
    df = train_df if phase == 'train' else val_df
    dataset = ECGDataset(df)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=4)
    return dataloader

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import models
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import numpy as np
from tqdm import tqdm
import time
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Swish activation function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

# Convolutional, normalization, and pooling block
class ConvNormPool(nn.Module):
    def __init__(self, input_size, hidden_size, kernel_size, norm_type='batchnorm'):
        super(ConvNormPool, self).__init__()

        self.kernel_size = kernel_size
        self.conv_1 = nn.Conv1d(in_channels=input_size, out_channels=hidden_size, kernel_size=kernel_size)
        self.conv_2 = nn.Conv1d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=kernel_size)
        self.conv_3 = nn.Conv1d(in_channels=hidden_size, out_channels=hidden_size, kernel_size=kernel_size)
        self.swish_1 = Swish()
        self.swish_2 = Swish()
        self.swish_3 = Swish()

        if norm_type == 'group':
            self.normalization_1 = nn.GroupNorm(num_groups=8, num_channels=hidden_size)
            self.normalization_2 = nn.GroupNorm(num_groups=8, num_channels=hidden_size)
            self.normalization_3 = nn.GroupNorm(num_groups=8, num_channels=hidden_size)
        else:
            self.normalization_1 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_2 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_3 = nn.BatchNorm1d(num_features=hidden_size)

        self.pool = nn.MaxPool1d(kernel_size=2)

    def forward(self, input):
        conv1 = self.conv_1(input)
        x = self.normalization_1(conv1)
        x = self.swish_1(x)
        x = nn.functional.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.conv_2(x)
        x = self.normalization_2(x)
        x = self.swish_2(x)
        x = nn.functional.pad(x, pad=(self.kernel_size - 1, 0))

        conv3 = self.conv_3(x)
        x = self.normalization_3(conv1 + conv3)
        x = self.swish_3(x)
        x = nn.functional.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.pool(x)
        return x

# RNN module (LSTM or GRU)
class RNN(nn.Module):
    def __init__(self, input_size, hid_size, num_rnn_layers=1, dropout_p=0.2, bidirectional=False, rnn_type='lstm'):
        super(RNN, self).__init__()

        if rnn_type == 'lstm':
            self.rnn_layer = nn.LSTM(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )
        else:
            self.rnn_layer = nn.GRU(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers > 1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )

    def forward(self, input):
        outputs, hidden_states = self.rnn_layer(input)
        return outputs, hidden_states

# Combined CNN+LSTM model
class CNNLSTMModel(nn.Module):
    def __init__(self, cnn_input_size, cnn_hid_size, rnn_hid_size, rnn_type, bidirectional, n_classes=5, kernel_size=5):
        super(CNNLSTMModel, self).__init__()

        self.rnn_layer = RNN(
            input_size=cnn_hid_size * 2 if bidirectional else cnn_hid_size,
            hid_size=rnn_hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )
        self.conv1 = ConvNormPool(
            input_size=cnn_input_size,
            hidden_size=cnn_hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=cnn_hid_size,
            hidden_size=cnn_hid_size,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(in_features=cnn_hid_size, out_features=n_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x, _ = self.rnn_layer(x)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = nn.functional.softmax(self.fc(x), dim=1)
        return x

# Meter class for tracking metrics
class Meter:
    def __init__(self, n_classes=5):
        self.metrics = {}
        self.confusion = torch.zeros((n_classes, n_classes))

    def update(self, x, y, loss):
        x = np.argmax(x.detach().cpu().numpy(), axis=1)
        y = y.detach().cpu().numpy()
        self.metrics['loss'] += loss
        self.metrics['accuracy'] += accuracy_score(x, y)
        self.metrics['f1'] += f1_score(x, y, average='macro')
        self.metrics['precision'] += precision_score(x, y, average='macro', zero_division=1)
        self.metrics['recall'] += recall_score(x, y, average='macro', zero_division=1)

        self._compute_cm(x, y)

    def _compute_cm(self, x, y):
        for prob, target in zip(x, y):
            if prob == target:
                self.confusion[target][target] += 1
            else:
                self.confusion[target][prob] += 1

    def init_metrics(self):
        self.metrics['loss'] = 0
        self.metrics['accuracy'] = 0
        self.metrics['f1'] = 0
        self.metrics['precision'] = 0
        self.metrics['recall'] = 0

    def get_metrics(self):
        return self.metrics

    def get_confusion_matrix(self):
        return self.confusion

# Trainer class
class Trainer:
    def __init__(self, net, lr, batch_size, num_epochs):
        self.net = net.to(device)
        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.net.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs, eta_min=5e-6)
        self.best_loss = float('inf')
        self.phases = ['train', 'val']
        self.dataloaders = {
            phase: get_dataloader(phase, batch_size) for phase in self.phases
        }
        self.train_df_logs = pd.DataFrame()
        self.val_df_logs = pd.DataFrame()

    def _train_epoch(self, phase):
        print(f"{phase} mode | time: {time.strftime('%H:%M:%S')}")

        self.net.train() if phase == 'train' else self.net.eval()
        meter = Meter()
        meter.init_metrics()

        for i, (data, target) in enumerate(tqdm(self.dataloaders[phase])):
            data = data.to(device)
            target = target.to(device)

            output = self.net(data)
            loss = self.criterion(output, target)

            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            meter.update(output, target, loss.item())

        metrics = meter.get_metrics()
        metrics = {k: v / i for k, v in metrics.items()}
        df_logs = pd.DataFrame([metrics])
        confusion_matrix = meter.get_confusion_matrix()

        if phase == 'train':
            self.train_df_logs = pd.concat([self.train_df_logs, df_logs], axis=0)
        else:
            self.val_df_logs = pd.concat([self.val_df_logs, df_logs], axis=0)

        # show logs
        print('{}: {}, {}: {}, {}: {}, {}: {}, {}: {}'
              .format(*(x for kv in metrics.items() for x in kv)))

        return loss

    def run(self):
        for epoch in range(self.num_epochs):
            self._train_epoch(phase='train')
            with torch.no_grad():
                val_loss = self._train_epoch(phase='val')
                self.scheduler.step()

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                print('\nNew checkpoint\n')
                torch.save(self.net.state_dict(), f"best_model_epoch{epoch}.pth")

model = CNNLSTMModel(1, 64, 64, 'lstm', True)
trainer = Trainer(net=model, lr=1e-3, batch_size=96, num_epochs=10)
trainer.run()




train mode | time: 15:59:07


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [03:55<00:00,  3.30it/s]


loss: 1.0160900411298197, accuracy: 0.9109623655913967, f1: 0.512061820943652, precision: 0.5259997922520085, recall: 0.9124008559795304
val mode | time: 16:03:02


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:16<00:00,  8.24it/s]


loss: 0.9829451082383885, accuracy: 0.9403810803167422, f1: 0.6168885061127378, precision: 0.6170944636464478, recall: 0.9361581928007615

New checkpoint

train mode | time: 16:03:19


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [03:57<00:00,  3.27it/s]


loss: 0.9629422598500406, accuracy: 0.9453682795698937, f1: 0.6332502010160292, precision: 0.6314744043248198, recall: 0.9556641435572446
val mode | time: 16:07:17


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:16<00:00,  8.17it/s]


loss: 0.9648529711891624, accuracy: 0.954874858597285, f1: 0.6492029617786615, precision: 0.6476272426860423, recall: 0.9667853804665567

New checkpoint

train mode | time: 16:07:33


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [03:58<00:00,  3.25it/s]


loss: 0.9563253931076295, accuracy: 0.9515107526881722, f1: 0.6472223476553586, precision: 0.6461836905861901, recall: 0.9653100901546475
val mode | time: 16:11:32


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:16<00:00,  8.11it/s]


loss: 0.9603227854651564, accuracy: 0.9586279223227754, f1: 0.6564205122999177, precision: 0.6482972011092157, recall: 0.9809517756395036
train mode | time: 16:11:49


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [04:02<00:00,  3.20it/s]


loss: 0.9538436328980231, accuracy: 0.9536209677419357, f1: 0.652200571609294, precision: 0.6527467924704757, recall: 0.9672232228858396
val mode | time: 16:15:52


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:16<00:00,  8.16it/s]


loss: 0.9618870081270442, accuracy: 0.9567896870286577, f1: 0.654835938735793, precision: 0.6465616061038603, recall: 0.9788103196499778
train mode | time: 16:16:09


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [03:58<00:00,  3.25it/s]


loss: 0.953130817720967, accuracy: 0.9544005376344088, f1: 0.6532010975431918, precision: 0.6549010044305441, recall: 0.9667955724268835
val mode | time: 16:20:07


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:17<00:00,  8.01it/s]


loss: 0.9578885628896601, accuracy: 0.960772530165913, f1: 0.6590026144352835, precision: 0.6579377698486781, recall: 0.9737994448319173
train mode | time: 16:20:25


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [04:01<00:00,  3.21it/s]


loss: 0.9522833622655561, accuracy: 0.9551397849462357, f1: 0.6549007000959479, precision: 0.6561609944779161, recall: 0.9685714693512405
val mode | time: 16:24:27


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:17<00:00,  7.66it/s]


loss: 0.9569620056187406, accuracy: 0.9617093231523384, f1: 0.661717549121507, precision: 0.6636368050186325, recall: 0.9737185045113845

New checkpoint

train mode | time: 16:24:44


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [03:59<00:00,  3.24it/s]


loss: 0.9511857199668884, accuracy: 0.9559731182795701, f1: 0.6568436745298731, precision: 0.6591056959662202, recall: 0.969123141128493
val mode | time: 16:28:44


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:17<00:00,  7.83it/s]


loss: 0.9575060364954612, accuracy: 0.961326357466064, f1: 0.6598809191758357, precision: 0.6624832009341467, recall: 0.9709154619655473

New checkpoint

train mode | time: 16:29:02


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [04:00<00:00,  3.23it/s]


loss: 0.9502709817117261, accuracy: 0.9570483870967739, f1: 0.6592560457131474, precision: 0.6604158111021642, recall: 0.9720224752196631
val mode | time: 16:33:02


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:17<00:00,  8.04it/s]


loss: 0.9571015080984902, accuracy: 0.9613852752639519, f1: 0.6599744615642427, precision: 0.6621680536303107, recall: 0.9722163816165948
train mode | time: 16:33:19


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [04:00<00:00,  3.22it/s]


loss: 0.9496000436044508, accuracy: 0.9576935483870963, f1: 0.659696073556946, precision: 0.662290262578431, recall: 0.971181587889166
val mode | time: 16:37:20


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:17<00:00,  7.80it/s]


loss: 0.9565450119621614, accuracy: 0.962015695701358, f1: 0.6618577136960712, precision: 0.661464051949482, recall: 0.9761211229831893
train mode | time: 16:37:37


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 776/776 [03:59<00:00,  3.23it/s]


loss: 0.9492270828062488, accuracy: 0.9580564516129026, f1: 0.6599964724499399, precision: 0.662775345548293, recall: 0.9713139667685643
val mode | time: 16:41:37


  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
  signal = torch.FloatTensor([signal.values])
100%|██████████| 137/137 [00:16<00:00,  8.08it/s]

loss: 0.9561561499448383, accuracy: 0.9623986613876324, f1: 0.6617581217848936, precision: 0.6635118807771425, recall: 0.9740671920629839





In [23]:
test_df = pd.read_csv(config.test_csv_path)
print(test_df.shape)
test_dataset = ECGDataset(test_df)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=96, num_workers=0, shuffle=False)

(21892, 189)


In [69]:
import torch
from sklearn.metrics import accuracy_score
model.load_state_dict(torch.load("best_model_epoch7.pth"))
model.eval()
model.to(device)
predictions = []
true_labels = []
with torch.no_grad():
    for data, labels in test_dataloader:
        data = data.to(device)
        labels = labels.to(device)
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

accuracy = accuracy_score(true_labels, predictions)

print(f"Accuracy on the test set: {accuracy:.4f}")


Loaded pretrained weights for efficientnet-b0
train mode | time: 19:33:42


0it [00:00, ?it/s]


TypeError: ignored

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import numpy as np
from tqdm import tqdm
import time
from sklearn.metrics import confusion_matrix
from timm import create_model

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EfficientNetModel(nn.Module):
    def __init__(self, num_classes=5):
        super(EfficientNetModel, self).__init__()
        efficientnet_model = create_model('efficientnet_b7', pretrained=False)
        self.features = efficientnet_model
        in_features = efficientnet_model.num_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])  # Global average pooling
        x = self.classifier(x)
        return x

# EfficientNet Trainer class
class EfficientNetTrainer:
    def __init__(self, net, lr, batch_size, num_epochs):
        self.net = net.to(device)
        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.net.parameters(), lr=lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs, eta_min=5e-6)
        self.best_loss = float('inf')
        self.phases = ['train', 'val']
        self.dataloaders = {
            phase: get_dataloader(phase, batch_size) for phase in self.phases
        }
        self.train_df_logs = pd.DataFrame()
        self.val_df_logs = pd.DataFrame()

    def _train_epoch(self, phase):
        print(f"{phase} mode | time: {time.strftime('%H:%M:%S')}")

        self.net.train() if phase == 'train' else self.net.eval()
        meter = Meter()
        meter.init_metrics()

        for i, (data, target) in enumerate(tqdm(self.dataloaders[phase])):
            data = data.to(device)
            target = target.to(device)

            output = self.net(data)
            loss = self.criterion(output, target)

            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            meter.update(output, target, loss.item())

        metrics = meter.get_metrics()
        metrics = {k: v / i for k, v in metrics.items()}
        df_logs = pd.DataFrame([metrics])
        confusion_matrix = meter.get_confusion_matrix()

        if phase == 'train':
            self.train_df_logs = pd.concat([self.train_df_logs, df_logs], axis=0)
        else:
            self.val_df_logs = pd.concat([self.val_df_logs, df_logs], axis=0)

        # show logs
        print('{}: {}, {}: {}, {}: {}, {}: {}, {}: {}'
              .format(*(x for kv in metrics.items() for x in kv)))

        return loss

    def run(self):
        for epoch in range(self.num_epochs):
            self._train_epoch(phase='train')
            with torch.no_grad():
                val_loss = self._train_epoch(phase='val')
                self.scheduler.step()

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                print('\nNew checkpoint\n')
                torch.save(self.net.state_dict(), f"best_model_epoch{epoch}.pth")


# Usage
efficientnet_model = EfficientNetModel(num_classes=5)
efficientnet_trainer = EfficientNetTrainer(net=efficientnet_model, lr=1e-3, batch_size=96, num_epochs=10)
efficientnet_trainer.run()
