In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

from common import *
from dataset import ArrhythmiaDataset

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

from torch.utils.tensorboard import SummaryWriter


RECORD_DIR_PATH = '../data/mit-bih-arrhythmia-database-1.0.0'
WINDOW_SIZE = 540
MOVING_AVERAGE_RANGE = 17
USE_CLASSES_FROM_MANUAL_LABELS = True
SUBSET_FROM_MANUAL_LABELS = False
INCLUDE_MANUAL_LABELS = False
INCLUDE_RAW_SIGNAL = True

CLASSES = ['N', 'L', 'R', 'a', 'V', 'J', 'F'] if USE_CLASSES_FROM_MANUAL_LABELS else ['N', 'L', 'R', 'A', 'a', 'V', 'j', 'J', 'E', 'f', 'F', '[', '!', ']', '/', 'x', '|', 'Q']

batch_size = 256
n_epoch = 300

RUN_NAME = f'raw_signal_and_moving_average-{MOVING_AVERAGE_RANGE}_full_dataset'
CHECKPOINT_PATH = f'../models/{RUN_NAME} - checkpoint.pt'
ACCURACY_MOVING_AVERAGE_SIZE = 30  # moving average for accuracy to check if performance degraded


# TODO: S, e - need some preprocessing, dimensions seem to be wrong in one of these
# TODO: Q - of course, quite confusing, this is the most confused beat in confusion matrices

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# Randomness seed
random_seed = 1 # or any of your favorite number
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)

In [4]:
dataset = ArrhythmiaDataset(RECORD_DIR_PATH, WINDOW_SIZE, only_include_labels = CLASSES, moving_average_range = MOVING_AVERAGE_RANGE, include_manual_labels = INCLUDE_MANUAL_LABELS, subset_from_manual_labels = SUBSET_FROM_MANUAL_LABELS, include_raw_signal =
INCLUDE_RAW_SIGNAL)

print(dataset.data.shape)
print(len(dataset.labels))

filename='100.atr' patient_record_number=100
beat_slice_array.shape=(2240, 2, 1080) beat_slices.shape=torch.Size([2240, 2, 1080])
filename='124.atr' patient_record_number=124
beat_slice_array.shape=(1612, 2, 1080) beat_slices.shape=torch.Size([1612, 2, 1080])
self.data.shape=torch.Size([2240, 2, 1080]) beat_slices.shape=torch.Size([1612, 2, 1080])
filename='219.atr' patient_record_number=219
beat_slice_array.shape=(2147, 2, 1080) beat_slices.shape=torch.Size([2147, 2, 1080])
self.data.shape=torch.Size([3852, 2, 1080]) beat_slices.shape=torch.Size([2147, 2, 1080])
filename='112.atr' patient_record_number=112
beat_slice_array.shape=(2537, 2, 1080) beat_slices.shape=torch.Size([2537, 2, 1080])
self.data.shape=torch.Size([5999, 2, 1080]) beat_slices.shape=torch.Size([2537, 2, 1080])
filename='119.atr' patient_record_number=119
beat_slice_array.shape=(1987, 2, 1080) beat_slices.shape=torch.Size([1987, 2, 1080])
self.data.shape=torch.Size([8536, 2, 1080]) beat_slices.shape=torch.Size([1987, 

In [5]:
labels, counts = torch.unique(dataset.labels_encoded, dim = 0, return_counts = True)

for label, count in zip(labels, counts):
    print(f'{dataset.get_label_from_tensor(label)}: {count}')


L: 8075
F: 803
J: 83
V: 7130
R: 7259
a: 150
N: 75052


In [6]:
# Drop some Normal beats to balance classes
normal_beat_mask = np.array(dataset.labels) == 'N'

new_labels = []
for idx, l in enumerate(normal_beat_mask):
    # Leave 10% samples in (currently theres 75k samples, while other popular classes are at about 8k)
    if l and random.uniform(0, 1) < 0.1:
        normal_beat_mask[idx] = False
    if not normal_beat_mask[idx]:
        new_labels.append(dataset.labels[idx])

new_data = dataset.data[normal_beat_mask == False]
dataset.data = new_data
dataset.labels = new_labels
dataset.encode_labels()

def show_class_count(dataset: ArrhythmiaDataset):
    print(dataset.data.shape)
    print(len(dataset.labels))
    labels, counts = torch.unique(dataset.labels_encoded, dim = 0, return_counts = True)

    for label, count in zip(labels, counts):
        print(f'{dataset.get_label_from_tensor(label)}: {count}')

show_class_count(dataset)

torch.Size([31058, 2, 1080])
31058
L: 8075
F: 803
J: 83
V: 7130
R: 7259
a: 150
N: 7558


In [7]:
def collate_fn(batch):

    # A data tuple has the form:
    # waveform, one-hot-encoded_label

    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for waveform, label in batch:
        tensors += [waveform]
        targets += [label]

    # Group the list of tensors into a batched tensor
    tensors = torch.stack(tensors)
    tensors = tensors[:, :]
    targets = torch.stack(targets)

    return tensors, targets


if device == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

train_dataset, test_dataset = dataset.train_test_split(0.2)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

print('TRAIN DATASET:')
show_class_count(train_dataset)

print('TEST DATASET:')
show_class_count(test_dataset)

TRAIN DATASET:
torch.Size([24846, 2, 1080])
0
L: 6460
F: 642
J: 67
V: 5704
R: 5807
a: 120
N: 6046
TEST DATASET:
torch.Size([6212, 2, 1080])
0
L: 1615
F: 161
J: 16
V: 1426
R: 1452
a: 30
N: 1512


In [8]:
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=1, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(n_channel)
        self.pool3 = nn.MaxPool1d(3)
        self.conv4 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(3)
        self.conv5 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn5 = nn.BatchNorm1d(2 * n_channel)
        self.pool5 = nn.MaxPool1d(3)
        self.conv6 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn6 = nn.BatchNorm1d(2 * n_channel)
        self.pool6 = nn.MaxPool1d(3)
        self.fc1 = nn.Linear(2 * n_channel, n_channel)
        self.fc2 = nn.Linear(n_channel, n_output)

    def forward(self, x):
        # print(f'CONV1 INPUT SHAPE: {x.shape}')
        x = self.conv1(x)
        # print(f'CONV1 OUTPUT SHAPE: {x.shape}')
        x = F.relu(self.bn1(x))
        # print(f'POOL1 INPUT SHAPE: {x.shape}')
        x = self.pool1(x)
        # print(f'POOL1 OUTPUT SHAPE: {x.shape}')
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        # print(f'POOL2 INPUT SHAPE: {x.shape}')
        x = self.pool2(x)
        # print(f'POOL2 OUTPUT SHAPE: {x.shape}')
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        # print(f'POOL3 INPUT SHAPE: {x.shape}')
        x = self.pool3(x)
        # print(f'POOL3 OUTPUT SHAPE: {x.shape}')
        x = self.conv4(x)
        # print(f'BATCHNORM4 INPUT SHAPE: {x.shape}')
        x = F.relu(self.bn4(x))
        # print(f'POOL4 INPUT SHAPE: {x.shape}')
        x = self.pool4(x)
        # print(f'POOL4 OUTPUT SHAPE: {x.shape}')
        x = self.conv5(x)
        # print(f'BATCHNORM5 INPUT SHAPE: {x.shape}')
        x = F.relu(self.bn5(x))
        # print(f'POOL5 INPUT SHAPE: {x.shape}')
        x = self.pool5(x)
        # print(f'POOL5 OUTPUT SHAPE: {x.shape}')
        x = self.conv6(x)
        # print(f'BATCHNORM6 INPUT SHAPE: {x.shape}')
        x = F.relu(self.bn6(x))
        # print(f'POOL6 INPUT SHAPE: {x.shape}')
        x = self.pool6(x)
        # print(f'POOL6 OUTPUT SHAPE: {x.shape}')
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=2)


model = M5(n_input = 2, n_output = len(set(dataset.labels)))
model.double().to(device)
print(model)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


n = count_parameters(model)
print("Number of parameters: %s" % n)

M5(
  (conv1): Conv1d(2, 32, kernel_size=(3,), stride=(1,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv5): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn5): Bat

In [9]:
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0001)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)  # reduce the learning after 20 epochs by a factor
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.1, patience = 7, verbose = True)  # reduce learning after 7 epochs with no improvement

In [10]:
def train(model, epoch, log_interval, writer: SummaryWriter):
    train_losses = []
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        # print(f'DATA SHAPE: {data.shape}')
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        output = model(data)

        # negative log-likelihood for a tensor of size (batch x 1 x n_output)
        squeezed_output = output.squeeze()
        loss = F.nll_loss(squeezed_output, target.argmax(dim = 1))

        writer.add_scalar('Train loss', loss.item(), epoch * len(train_loader.dataset) + batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        train_losses.append(loss.item())
    return train_losses

In [11]:
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score
)


def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    return tensor.argmax(dim=-1)


def test(model, epoch, writer: SummaryWriter):
    model.eval()
    correct = 0
    y_true = []
    y_pred = []
    loss_sum = 0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        output = model(data)

        squeezed_output = output.squeeze()
        loss_sum += F.nll_loss(squeezed_output, target.argmax(dim = 1)).item()

        pred = get_likely_index(output)
        correct += number_of_correct(pred, target.argmax(dim = 1))

        y_true.extend(pred.squeeze().data.cpu().numpy())
        y_pred.extend(target.data.cpu().numpy().argmax(axis = 1))

        # update progress bar
        pbar.update(pbar_update)
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='micro')
    recall = recall_score(y_true, y_pred, average='micro')
    f1 = f1_score(y_true, y_pred, average='micro')

    writer.add_scalar('Test accuracy', accuracy, epoch)
    writer.add_scalar('Test precision', precision, epoch)
    writer.add_scalar('Test recall', recall, epoch)
    writer.add_scalar('Test f1', f1, epoch)
    writer.add_scalar('Test average loss', loss_sum / len(test_loader.dataset), epoch)

    # Build confusion matrix
    cf_matrix = confusion_matrix(y_true, y_pred)
    df_cm = pd.DataFrame(cf_matrix, index = [i for i in CLASSES],
                         columns = [i for i in CLASSES])
    plt.figure(figsize = (12,7))
    cf_matrix_figure = sn.heatmap(df_cm, annot=True).get_figure()
    writer.add_figure('Test confusion matrix', cf_matrix_figure, epoch)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.4%})\n")
    return accuracy, precision, recall, f1, loss_sum

In [None]:
writer = SummaryWriter()

log_interval = 20

writer.add_hparams({f'data_shape_{i}': shape for i, shape in enumerate(dataset.data.shape)} | {'data_moving_average_range': MOVING_AVERAGE_RANGE, 'data_window_size': WINDOW_SIZE, 'batch_size': batch_size, 'n_epoch': n_epoch}, {'hparam/fake_accuracy_just_to_have_any_metric': 10}, run_name = RUN_NAME)

pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []
accuracies = []

with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, CHECKPOINT_PATH)

        train_losses = train(model, epoch, log_interval, writer)
        losses.extend(train_losses)

        accuracy, precision, recall, f1, loss_sum = test(model, epoch, writer)
        accuracies.append(accuracy)
        scheduler.step(loss_sum)

        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # Early stopping
        if len(accuracies) >= ACCURACY_MOVING_AVERAGE_SIZE + 1:
            is_performance_degraded = np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE - 1:-1]) > np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE:])
            if is_performance_degraded:
                # Reload the last non-degraded checkpoint
                checkpoint = torch.load(CHECKPOINT_PATH)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                break


# Let's plot the training loss versus the number of iteration.
# plt.plot(losses);
# plt.title("training loss");

  0%|          | 0.008130081300813009/300 [00:00<6:40:35, 80.12s/it]



  0%|          | 0.17073170731707313/300 [00:03<1:44:39, 20.94s/it] 



  0%|          | 0.3333333333333331/300 [00:07<2:07:05, 25.45s/it] 



  0%|          | 0.5121951219512191/300 [00:10<1:25:48, 17.19s/it] 



  0%|          | 0.658536585365854/300 [00:13<1:45:15, 21.10s/it] 



  0%|          | 0.9918699186991892/300 [00:16<29:56,  6.01s/it]  


Test Epoch: 1	Accuracy: 5555/6212 (89.4237%)



  0%|          | 1.170731707317074/300 [00:20<1:47:02, 21.49s/it] 



  0%|          | 1.3333333333333328/300 [00:23<1:57:58, 23.70s/it]



  1%|          | 1.5121951219512175/300 [00:26<1:23:41, 16.82s/it]



  1%|          | 1.6585365853658505/300 [00:29<1:42:51, 20.68s/it]



  1%|          | 1.991869918699181/300 [00:32<30:05,  6.06s/it]   


Test Epoch: 2	Accuracy: 5877/6212 (94.6072%)



  1%|          | 2.1869918699186965/300 [00:36<1:24:04, 16.94s/it]



  1%|          | 2.34959349593496/300 [00:39<1:24:09, 16.96s/it]  



  1%|          | 2.4959349593495967/300 [00:41<1:43:12, 20.82s/it]



  1%|          | 2.65853658536586/300 [00:44<1:44:26, 21.08s/it]  



  1%|          | 3.008130081300826/300 [00:48<45:44,  9.24s/it]   


Test Epoch: 3	Accuracy: 5994/6212 (96.4907%)



  1%|          | 3.1869918699187156/300 [00:51<1:24:14, 17.03s/it]



  1%|          | 3.349593495934979/300 [00:54<1:24:34, 17.11s/it] 



  1%|          | 3.495934959349616/300 [00:57<1:48:09, 21.89s/it] 



  1%|          | 3.658536585365879/300 [01:00<1:59:18, 24.16s/it] 



  1%|▏         | 4.000000000000032/300 [01:04<27:04,  5.49s/it]   


Test Epoch: 4	Accuracy: 5996/6212 (96.5229%)



  1%|▏         | 4.170731707317099/300 [01:07<1:40:32, 20.39s/it] 



  1%|▏         | 4.3333333333333535/300 [01:10<1:41:56, 20.69s/it]



  2%|▏         | 4.512195121951233/300 [01:13<1:22:33, 16.77s/it] 



  2%|▏         | 4.674796747967488/300 [01:16<1:22:31, 16.77s/it] 



  2%|▏         | 4.9999999999999964/300 [01:19<26:51,  5.46s/it]  


Test Epoch: 5	Accuracy: 6061/6212 (97.5692%)



  2%|▏         | 5.186991869918689/300 [01:22<1:21:53, 16.67s/it] 



  2%|▏         | 5.349593495934943/300 [01:25<1:22:24, 16.78s/it] 



  2%|▏         | 5.512195121951198/300 [01:28<1:22:01, 16.71s/it] 



  2%|▏         | 5.674796747967452/300 [01:31<1:22:15, 16.77s/it] 



  2%|▏         | 5.999999999999961/300 [01:34<26:50,  5.48s/it]   


Test Epoch: 6	Accuracy: 6072/6212 (97.7463%)



  2%|▏         | 6.1869918699186535/300 [01:37<1:21:56, 16.73s/it]



  2%|▏         | 6.349593495934908/300 [01:40<1:21:36, 16.67s/it] 



  2%|▏         | 6.512195121951162/300 [01:43<1:21:53, 16.74s/it] 



  2%|▏         | 6.674796747967417/300 [01:46<1:21:44, 16.72s/it] 



  2%|▏         | 6.999999999999925/300 [01:49<26:45,  5.48s/it]   


Test Epoch: 7	Accuracy: 6074/6212 (97.7785%)



  2%|▏         | 7.186991869918618/300 [01:52<1:21:14, 16.65s/it] 



  2%|▏         | 7.349593495934872/300 [01:55<1:21:48, 16.77s/it] 



  2%|▏         | 7.495934959349501/300 [01:58<1:40:27, 20.61s/it] 



  3%|▎         | 7.674796747967381/300 [02:01<1:21:47, 16.79s/it] 



  3%|▎         | 7.99999999999989/300 [02:04<26:37,  5.47s/it]    


Test Epoch: 8	Accuracy: 6044/6212 (97.2956%)



  3%|▎         | 8.186991869918582/300 [02:08<1:20:57, 16.64s/it]



  3%|▎         | 8.349593495934837/300 [02:10<1:21:29, 16.76s/it]



  3%|▎         | 8.512195121951091/300 [02:13<1:21:12, 16.72s/it]



  3%|▎         | 8.674796747967346/300 [02:16<1:21:27, 16.78s/it]



  3%|▎         | 8.999999999999854/300 [02:19<26:31,  5.47s/it]  


Test Epoch: 9	Accuracy: 6071/6212 (97.7302%)



  3%|▎         | 9.186991869918547/300 [02:23<1:20:40, 16.64s/it]



  3%|▎         | 9.349593495934801/300 [02:26<1:21:21, 16.80s/it]



  3%|▎         | 9.512195121951056/300 [02:28<1:21:16, 16.79s/it]



  3%|▎         | 9.67479674796731/300 [02:31<1:21:18, 16.80s/it] 



  3%|▎         | 9.999999999999819/300 [02:35<26:26,  5.47s/it]  


Test Epoch: 10	Accuracy: 6093/6212 (98.0844%)



  3%|▎         | 10.170731707316886/300 [02:38<1:38:23, 20.37s/it]



  3%|▎         | 10.349593495934766/300 [02:41<1:20:56, 16.77s/it]



  4%|▎         | 10.51219512195102/300 [02:44<1:20:56, 16.78s/it] 



  4%|▎         | 10.674796747967275/300 [02:46<1:20:38, 16.72s/it]



  4%|▎         | 10.999999999999783/300 [02:50<26:21,  5.47s/it]  


Test Epoch: 11	Accuracy: 6062/6212 (97.5853%)



  4%|▎         | 11.186991869918476/300 [02:53<1:20:10, 16.65s/it]



  4%|▍         | 11.34959349593473/300 [02:56<1:20:34, 16.75s/it] 



  4%|▍         | 11.512195121950985/300 [02:59<1:20:23, 16.72s/it]



  4%|▍         | 11.674796747967239/300 [03:02<1:20:24, 16.73s/it]



  4%|▍         | 11.999999999999748/300 [03:05<26:20,  5.49s/it]  


Test Epoch: 12	Accuracy: 6092/6212 (98.0683%)



  4%|▍         | 12.18699186991844/300 [03:08<1:19:58, 16.67s/it] 



  4%|▍         | 12.349593495934695/300 [03:11<1:20:19, 16.75s/it]



  4%|▍         | 12.512195121950949/300 [03:14<1:20:10, 16.73s/it]



  4%|▍         | 12.674796747967203/300 [03:17<1:20:04, 16.72s/it]



  4%|▍         | 12.999999999999712/300 [03:20<26:15,  5.49s/it]  


Test Epoch: 13	Accuracy: 6123/6212 (98.5673%)



  4%|▍         | 13.186991869918405/300 [03:23<1:19:28, 16.63s/it]



  4%|▍         | 13.34959349593466/300 [03:26<1:19:56, 16.73s/it] 



  5%|▍         | 13.512195121950914/300 [03:29<1:20:05, 16.77s/it]



  5%|▍         | 13.674796747967168/300 [03:32<1:20:07, 16.79s/it]



  5%|▍         | 13.991869918698864/300 [03:35<28:04,  5.89s/it]  


Test Epoch: 14	Accuracy: 6106/6212 (98.2936%)



  5%|▍         | 14.18699186991837/300 [03:38<1:19:20, 16.66s/it] 



  5%|▍         | 14.349593495934624/300 [03:41<1:19:45, 16.75s/it]



  5%|▍         | 14.512195121950878/300 [03:44<1:19:29, 16.71s/it]



  5%|▍         | 14.674796747967132/300 [03:47<1:19:45, 16.77s/it]



  5%|▍         | 14.999999999999641/300 [03:50<26:00,  5.48s/it]  


Test Epoch: 15	Accuracy: 6107/6212 (98.3097%)



  5%|▌         | 15.170731707316708/300 [03:53<1:38:56, 20.84s/it]



  5%|▌         | 15.349593495934588/300 [03:56<1:19:34, 16.77s/it]



  5%|▌         | 15.512195121950842/300 [03:59<1:19:26, 16.75s/it]



  5%|▌         | 15.674796747967097/300 [04:02<1:19:14, 16.72s/it]



  5%|▌         | 15.999999999999606/300 [04:05<26:07,  5.52s/it]  


Test Epoch: 16	Accuracy: 6115/6212 (98.4385%)



  5%|▌         | 16.186991869918298/300 [04:09<1:18:42, 16.64s/it]



  5%|▌         | 16.333333333332927/300 [04:11<1:37:13, 20.57s/it]



  6%|▌         | 16.512195121950807/300 [04:14<1:19:07, 16.75s/it]



  6%|▌         | 16.67479674796706/300 [04:17<1:19:03, 16.74s/it] 



  6%|▌         | 16.99999999999957/300 [04:21<25:51,  5.48s/it]   


Test Epoch: 17	Accuracy: 6122/6212 (98.5512%)



  6%|▌         | 17.186991869918263/300 [04:24<1:21:01, 17.19s/it]



  6%|▌         | 17.349593495934517/300 [04:27<1:18:59, 16.77s/it]



  6%|▌         | 17.51219512195077/300 [04:30<1:18:51, 16.75s/it] 



  6%|▌         | 17.674796747967026/300 [04:32<1:18:42, 16.73s/it]



  6%|▌         | 17.991869918698722/300 [04:36<27:46,  5.91s/it]  


Test Epoch: 18	Accuracy: 6120/6212 (98.5190%)



  6%|▌         | 18.186991869918227/300 [04:39<1:18:20, 16.68s/it]



  6%|▌         | 18.34959349593448/300 [04:42<1:18:45, 16.78s/it] 



  6%|▌         | 18.512195121950736/300 [04:45<1:18:35, 16.75s/it]



  6%|▌         | 18.67479674796699/300 [04:48<1:18:30, 16.74s/it] 



  6%|▋         | 18.9999999999995/300 [04:51<25:44,  5.50s/it]    


Test Epoch: 19	Accuracy: 6115/6212 (98.4385%)



  6%|▋         | 19.18699186991819/300 [04:54<1:19:27, 16.98s/it] 



  6%|▋         | 19.349593495934446/300 [04:57<1:18:31, 16.79s/it]



  7%|▋         | 19.5121951219507/300 [05:00<1:18:14, 16.74s/it]  



  7%|▋         | 19.674796747966955/300 [05:03<1:18:00, 16.70s/it]



  7%|▋         | 19.999999999999464/300 [05:06<25:44,  5.52s/it]  


Test Epoch: 20	Accuracy: 6112/6212 (98.3902%)



  7%|▋         | 20.186991869918156/300 [05:09<1:17:49, 16.69s/it]



  7%|▋         | 20.34959349593441/300 [05:12<1:18:08, 16.76s/it] 



  7%|▋         | 20.512195121950665/300 [05:15<1:18:04, 16.76s/it]



  7%|▋         | 20.67479674796692/300 [05:18<1:17:55, 16.74s/it] 



  7%|▋         | 20.999999999999428/300 [05:21<25:29,  5.48s/it]  


Test Epoch: 21	Accuracy: 6128/6212 (98.6478%)



  7%|▋         | 21.18699186991812/300 [05:24<1:17:58, 16.78s/it] 



  7%|▋         | 21.349593495934375/300 [05:27<1:17:42, 16.73s/it]



  7%|▋         | 21.51219512195063/300 [05:30<1:17:26, 16.68s/it] 



  7%|▋         | 21.674796747966884/300 [05:33<1:17:43, 16.75s/it]



  7%|▋         | 21.999999999999392/300 [05:36<25:20,  5.47s/it]  


Test Epoch: 22	Accuracy: 6129/6212 (98.6639%)



  7%|▋         | 22.186991869918085/300 [05:40<1:17:09, 16.66s/it]



  7%|▋         | 22.34959349593434/300 [05:42<1:17:23, 16.72s/it] 



  7%|▋         | 22.49593495934897/300 [05:45<1:35:21, 20.62s/it] 



  8%|▊         | 22.658536585365223/300 [05:48<1:42:39, 22.21s/it]



  8%|▊         | 23.00813008130017/300 [05:52<41:32,  9.00s/it]   


Test Epoch: 23	Accuracy: 6116/6212 (98.4546%)



  8%|▊         | 23.18699186991805/300 [05:56<1:27:41, 19.01s/it] 



  8%|▊         | 23.33333333333268/300 [05:59<1:38:51, 21.44s/it] 



  8%|▊         | 23.51219512195056/300 [06:02<1:17:12, 16.75s/it] 



  8%|▊         | 23.674796747966813/300 [06:05<1:16:45, 16.67s/it]



  8%|▊         | 24.008130081300134/300 [06:08<44:18,  9.63s/it]  


Test Epoch: 24	Accuracy: 6133/6212 (98.7283%)



  8%|▊         | 24.186991869918014/300 [06:12<1:17:54, 16.95s/it]



  8%|▊         | 24.333333333332643/300 [06:14<1:39:50, 21.73s/it]



  8%|▊         | 24.495934959348897/300 [06:17<1:39:05, 21.58s/it]



  8%|▊         | 24.674796747966777/300 [06:21<1:16:39, 16.71s/it]



  8%|▊         | 24.98373983739766/300 [06:24<28:27,  6.21s/it]   


Test Epoch: 25	Accuracy: 6123/6212 (98.5673%)



  8%|▊         | 25.024390243901724/300 [06:24<49:51, 10.88s/it]



  8%|▊         | 25.170731707316353/300 [06:27<1:39:06, 21.64s/it]



  8%|▊         | 25.333333333332607/300 [06:30<1:37:57, 21.40s/it]



  8%|▊         | 25.495934959348862/300 [06:33<1:36:15, 21.04s/it]



  9%|▊         | 25.658536585365116/300 [06:36<1:34:45, 20.73s/it]



  9%|▊         | 26.008130081300063/300 [06:40<44:10,  9.67s/it]  


Test Epoch: 26	Accuracy: 6131/6212 (98.6961%)



  9%|▊         | 26.170731707316317/300 [06:43<1:34:12, 20.64s/it]



  9%|▉         | 26.349593495934197/300 [06:46<1:17:34, 17.01s/it]



  9%|▉         | 26.495934959348826/300 [06:49<1:33:08, 20.43s/it]



  9%|▉         | 26.674796747966706/300 [06:52<1:18:15, 17.18s/it]



  9%|▉         | 26.999999999999215/300 [06:56<26:13,  5.76s/it]  


Test Epoch: 27	Accuracy: 6124/6212 (98.5834%)



  9%|▉         | 27.186991869917907/300 [06:59<1:19:06, 17.40s/it]



  9%|▉         | 27.333333333332536/300 [07:02<1:38:13, 21.61s/it]



  9%|▉         | 27.49593495934879/300 [07:05<1:37:59, 21.58s/it] 



  9%|▉         | 27.60975609756017/300 [07:07<1:23:26, 18.38s/it] 