In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

from common import *
from dataset import ArrhythmiaDataset

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

from torch.utils.tensorboard import SummaryWriter


RECORD_DIR_PATH = '../data/mit-bih-arrhythmia-database-1.0.0'
WINDOW_SIZE = 540
MOVING_AVERAGE_RANGE = 17
USE_CLASSES_FROM_MANUAL_LABELS = True
SUBSET_FROM_MANUAL_LABELS = False
INCLUDE_MANUAL_LABELS = False
INCLUDE_RAW_SIGNAL = False

CLASSES = ['N', 'L', 'R', 'a', 'V', 'J', 'F'] if USE_CLASSES_FROM_MANUAL_LABELS else ['N', 'L', 'R', 'A', 'a', 'V', 'j', 'J', 'E', 'f', 'F', '[', '!', ']', '/', 'x', '|', 'Q']

batch_size = 256
n_epoch = 300

RUN_NAME = 'moving_average_full_dataset'
CHECKPOINT_PATH = f'../models/{RUN_NAME} - checkpoint.pt'
ACCURACY_MOVING_AVERAGE_SIZE = 30  # moving average for accuracy to check if performance degraded


# TODO: S, e - need some preprocessing, dimensions seem to be wrong in one of these
# TODO: Q - of course, quite confusing, this is the most confused beat in confusion matrices

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# Randomness seed
random_seed = 1 # or any of your favorite number
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)

In [4]:
dataset = ArrhythmiaDataset(RECORD_DIR_PATH, WINDOW_SIZE, only_include_labels = CLASSES, moving_average_range = MOVING_AVERAGE_RANGE, include_manual_labels = INCLUDE_MANUAL_LABELS, subset_from_manual_labels = SUBSET_FROM_MANUAL_LABELS, include_raw_signal =
INCLUDE_RAW_SIGNAL)

print(dataset.data.shape)
print(len(dataset.labels))

filename='100.atr' patient_record_number=100
beat_slice_array.shape=(2240, 1080) beat_slices.shape=torch.Size([2240, 1080])
filename='124.atr' patient_record_number=124
beat_slice_array.shape=(1612, 1080) beat_slices.shape=torch.Size([1612, 1080])
self.data.shape=torch.Size([2240, 1080]) beat_slices.shape=torch.Size([1612, 1080])
filename='219.atr' patient_record_number=219
beat_slice_array.shape=(2147, 1080) beat_slices.shape=torch.Size([2147, 1080])
self.data.shape=torch.Size([3852, 1080]) beat_slices.shape=torch.Size([2147, 1080])
filename='112.atr' patient_record_number=112
beat_slice_array.shape=(2537, 1080) beat_slices.shape=torch.Size([2537, 1080])
self.data.shape=torch.Size([5999, 1080]) beat_slices.shape=torch.Size([2537, 1080])
filename='119.atr' patient_record_number=119
beat_slice_array.shape=(1987, 1080) beat_slices.shape=torch.Size([1987, 1080])
self.data.shape=torch.Size([8536, 1080]) beat_slices.shape=torch.Size([1987, 1080])
filename='209.atr' patient_record_number=209

In [5]:
labels, counts = torch.unique(dataset.labels_encoded, dim = 0, return_counts = True)

for label, count in zip(labels, counts):
    print(f'{dataset.get_label_from_tensor(label)}: {count}')


L: 8075
R: 7259
N: 75052
J: 83
a: 150
V: 7130
F: 803


In [6]:
# Drop some Normal beats to balance classes
normal_beat_mask = np.array(dataset.labels) == 'N'

new_labels = []
for idx, l in enumerate(normal_beat_mask):
    # Leave 10% samples in (currently theres 75k samples, while other popular classes are at about 8k)
    if l and random.uniform(0, 1) < 0.1:
        normal_beat_mask[idx] = False
    if not normal_beat_mask[idx]:
        new_labels.append(dataset.labels[idx])

new_data = dataset.data[normal_beat_mask == False]
dataset.data = new_data
dataset.labels = new_labels
dataset.encode_labels()

def show_class_count(dataset: ArrhythmiaDataset):
    print(dataset.data.shape)
    print(len(dataset.labels))
    labels, counts = torch.unique(dataset.labels_encoded, dim = 0, return_counts = True)

    for label, count in zip(labels, counts):
        print(f'{dataset.get_label_from_tensor(label)}: {count}')

show_class_count(dataset)

torch.Size([31021, 1080])
31021
L: 8075
R: 7259
N: 7521
J: 83
a: 150
V: 7130
F: 803


In [7]:
def collate_fn(batch):

    # A data tuple has the form:
    # waveform, one-hot-encoded_label

    tensors, targets = [], []

    # Gather in lists, and encode labels as indices
    for waveform, label in batch:
        tensors += [waveform]
        targets += [label]

    # Group the list of tensors into a batched tensor
    tensors = torch.stack(tensors)
    tensors = tensors[:, None, :]
    targets = torch.stack(targets)

    return tensors, targets


if device == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

train_dataset, test_dataset = dataset.train_test_split(0.2)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

print('TRAIN DATASET:')
show_class_count(train_dataset)

print('TEST DATASET:')
show_class_count(test_dataset)

TRAIN DATASET:
torch.Size([24816, 1080])
0
L: 6460
R: 5807
N: 6017
J: 66
a: 120
V: 5704
F: 642
TEST DATASET:
torch.Size([6205, 1080])
0
L: 1615
R: 1452
N: 1504
J: 17
a: 30
V: 1426
F: 161


In [8]:
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=1, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=3, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(n_channel)
        self.pool3 = nn.MaxPool1d(3)
        self.conv4 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(3)
        self.conv5 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn5 = nn.BatchNorm1d(2 * n_channel)
        self.pool5 = nn.MaxPool1d(3)
        self.conv6 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn6 = nn.BatchNorm1d(2 * n_channel)
        self.pool6 = nn.MaxPool1d(3)
        self.fc1 = nn.Linear(2 * n_channel, n_channel)
        self.fc2 = nn.Linear(n_channel, n_output)

    def forward(self, x):
        # print(f'CONV1 INPUT SHAPE: {x.shape}')
        x = self.conv1(x)
        # print(f'CONV1 OUTPUT SHAPE: {x.shape}')
        x = F.relu(self.bn1(x))
        # print(f'POOL1 INPUT SHAPE: {x.shape}')
        x = self.pool1(x)
        # print(f'POOL1 OUTPUT SHAPE: {x.shape}')
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        # print(f'POOL2 INPUT SHAPE: {x.shape}')
        x = self.pool2(x)
        # print(f'POOL2 OUTPUT SHAPE: {x.shape}')
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        # print(f'POOL3 INPUT SHAPE: {x.shape}')
        x = self.pool3(x)
        # print(f'POOL3 OUTPUT SHAPE: {x.shape}')
        x = self.conv4(x)
        # print(f'BATCHNORM4 INPUT SHAPE: {x.shape}')
        x = F.relu(self.bn4(x))
        # print(f'POOL4 INPUT SHAPE: {x.shape}')
        x = self.pool4(x)
        # print(f'POOL4 OUTPUT SHAPE: {x.shape}')
        x = self.conv5(x)
        # print(f'BATCHNORM5 INPUT SHAPE: {x.shape}')
        x = F.relu(self.bn5(x))
        # print(f'POOL5 INPUT SHAPE: {x.shape}')
        x = self.pool5(x)
        # print(f'POOL5 OUTPUT SHAPE: {x.shape}')
        x = self.conv6(x)
        # print(f'BATCHNORM6 INPUT SHAPE: {x.shape}')
        x = F.relu(self.bn6(x))
        # print(f'POOL6 INPUT SHAPE: {x.shape}')
        x = self.pool6(x)
        # print(f'POOL6 OUTPUT SHAPE: {x.shape}')
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=2)


model = M5(n_output=len(set(dataset.labels)))
model.double().to(device)
print(model)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


n = count_parameters(model)
print("Number of parameters: %s" % n)

M5(
  (conv1): Conv1d(1, 32, kernel_size=(3,), stride=(1,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 32, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv1d(32, 64, kernel_size=(3,), stride=(1,))
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool4): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv5): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
  (bn5): Bat

In [9]:
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)  # reduce the learning after 20 epochs by a factor

In [10]:
def train(model, epoch, log_interval, writer: SummaryWriter):
    train_losses = []
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):

        data = data.to(device)
        # print(f'DATA SHAPE: {data.shape}')
        target = target.to(device)

        # apply transform and model on whole batch directly on device
        output = model(data)

        # negative log-likelihood for a tensor of size (batch x 1 x n_output)
        squeezed_output = output.squeeze()
        loss = F.nll_loss(squeezed_output, target.argmax(dim = 1))

        writer.add_scalar('Train loss', loss.item(), epoch * len(train_loader.dataset) + batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print training stats
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

        # update progress bar
        pbar.update(pbar_update)
        # record loss
        train_losses.append(loss.item())
    return train_losses

In [11]:
def number_of_correct(pred, target):
    # count number of correct predictions
    return pred.squeeze().eq(target).sum().item()


def get_likely_index(tensor):
    # find most likely label index for each element in the batch
    return tensor.argmax(dim=-1)


def test(model, epoch, writer: SummaryWriter):
    model.eval()
    correct = 0
    y_true = []
    y_pred = []
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)

        output = model(data)

        pred = get_likely_index(output)
        correct += number_of_correct(pred, target.argmax(dim = 1))

        y_true.extend(pred.squeeze().data.cpu().numpy())
        y_pred.extend(target.data.cpu().numpy().argmax(axis = 1))

        # update progress bar
        pbar.update(pbar_update)
    accuracy = 100. * correct / len(test_loader.dataset)
    writer.add_scalar('Test accuracy', accuracy, epoch)

    # Build confusion matrix
    cf_matrix = confusion_matrix(y_true, y_pred)
    df_cm = pd.DataFrame(cf_matrix, index = [i for i in CLASSES],
                         columns = [i for i in CLASSES])
    plt.figure(figsize = (12,7))
    cf_matrix_figure = sn.heatmap(df_cm, annot=True).get_figure()
    writer.add_figure('Test confusion matrix', cf_matrix_figure, epoch)

    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n")
    return accuracy

In [12]:
writer = SummaryWriter()

log_interval = 20

writer.add_hparams({f'data_shape_{i}': shape for i, shape in enumerate(dataset.data.shape)} | {'data_moving_average_range': MOVING_AVERAGE_RANGE, 'data_window_size': WINDOW_SIZE, 'batch_size': batch_size, 'n_epoch': n_epoch}, {'hparam/fake_accuracy_just_to_have_any_metric': 10}, run_name = RUN_NAME)

pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []
accuracies = []

with tqdm(total=n_epoch) as pbar:
    for epoch in range(1, n_epoch + 1):
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, CHECKPOINT_PATH)

        train_losses = train(model, epoch, log_interval, writer)
        losses.extend(train_losses)

        accuracy = test(model, epoch, writer)
        accuracies.append(accuracy)
        scheduler.step()

        # Early stopping
        if len(accuracies) >= ACCURACY_MOVING_AVERAGE_SIZE + 1:
            is_performance_degraded = np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE - 1:-1]) > np.mean(accuracies[-ACCURACY_MOVING_AVERAGE_SIZE:])
            if is_performance_degraded:
                # Reload the last non-degraded checkpoint
                checkpoint = torch.load(CHECKPOINT_PATH)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                break


# Let's plot the training loss versus the number of iteration.
# plt.plot(losses);
# plt.title("training loss");

  0%|          | 0.00819672131147541/300 [00:01<11:14:26, 134.89s/it]



  0%|          | 0.17213114754098363/300 [00:04<1:42:47, 20.57s/it]  



  0%|          | 0.336065573770492/300 [00:07<1:45:19, 21.09s/it]  



  0%|          | 0.5163934426229512/300 [00:10<1:21:04, 16.24s/it] 



  0%|          | 0.6803278688524595/300 [00:13<1:20:49, 16.20s/it]



  0%|          | 1.0081967213114762/300 [00:17<52:12, 10.48s/it]  


Test Epoch: 1	Accuracy: 5310/6205 (86%)



  0%|          | 1.188524590163933/300 [00:20<1:24:38, 17.00s/it] 



  0%|          | 1.3524590163934391/300 [00:23<1:21:04, 16.29s/it]



  1%|          | 1.5163934426229453/300 [00:26<1:20:54, 16.26s/it]



  1%|          | 1.6803278688524514/300 [00:28<1:20:26, 16.18s/it]



  1%|          | 2.0081967213114638/300 [00:32<37:37,  7.58s/it]  


Test Epoch: 2	Accuracy: 5825/6205 (94%)



  1%|          | 2.1885245901639205/300 [00:35<1:20:10, 16.15s/it]



  1%|          | 2.3524590163934267/300 [00:38<1:20:09, 16.16s/it]



  1%|          | 2.516393442622933/300 [00:40<1:20:50, 16.31s/it] 



  1%|          | 2.680327868852439/300 [00:43<1:21:04, 16.36s/it] 



  1%|          | 3.0081967213114513/300 [00:47<37:53,  7.66s/it]  


Test Epoch: 3	Accuracy: 5901/6205 (95%)



  1%|          | 3.188524590163908/300 [00:50<1:20:12, 16.21s/it] 



  1%|          | 3.3524590163934143/300 [00:52<1:20:36, 16.30s/it]



  1%|          | 3.5163934426229204/300 [00:55<1:20:41, 16.33s/it]



  1%|          | 3.663934426229476/300 [00:58<2:00:37, 24.42s/it] 



  1%|▏         | 4.008196721311439/300 [01:02<49:09,  9.97s/it]   


Test Epoch: 4	Accuracy: 5993/6205 (97%)



  1%|▏         | 4.188524590163896/300 [01:05<1:24:28, 17.13s/it] 



  1%|▏         | 4.352459016393402/300 [01:08<1:19:58, 16.23s/it] 



  2%|▏         | 4.516393442622908/300 [01:11<1:19:45, 16.19s/it] 



  2%|▏         | 4.680327868852414/300 [01:14<1:19:49, 16.22s/it] 



  2%|▏         | 5.0081967213114265/300 [01:17<38:30,  7.83s/it]  


Test Epoch: 5	Accuracy: 6031/6205 (97%)



  2%|▏         | 5.188524590163883/300 [01:20<1:23:42, 17.04s/it] 



  2%|▏         | 5.336065573770439/300 [01:23<1:47:03, 21.80s/it] 



  2%|▏         | 5.499999999999945/300 [01:26<1:39:32, 20.28s/it] 



  2%|▏         | 5.663934426229451/300 [01:29<1:42:18, 20.86s/it] 



  2%|▏         | 6.008196721311414/300 [01:33<43:48,  8.94s/it]   


Test Epoch: 6	Accuracy: 6032/6205 (97%)



  2%|▏         | 6.17213114754092/300 [01:36<1:39:58, 20.41s/it]  



  2%|▏         | 6.336065573770426/300 [01:39<1:40:56, 20.62s/it] 



  2%|▏         | 6.4999999999999325/300 [01:42<1:43:26, 21.15s/it]



  2%|▏         | 6.680327868852389/300 [01:45<1:22:15, 16.83s/it] 



  2%|▏         | 6.991803278688451/300 [01:48<28:55,  5.92s/it]   


Test Epoch: 7	Accuracy: 6060/6205 (98%)



  2%|▏         | 7.172131147540908/300 [01:51<1:40:54, 20.67s/it] 



  2%|▏         | 7.3524590163933645/300 [01:54<1:22:10, 16.85s/it]



  3%|▎         | 7.516393442622871/300 [01:57<1:22:15, 16.87s/it] 



  3%|▎         | 7.663934426229426/300 [02:00<1:47:43, 22.11s/it] 



  3%|▎         | 8.00819672131139/300 [02:04<37:04,  7.62s/it]    


Test Epoch: 8	Accuracy: 6045/6205 (97%)



  3%|▎         | 8.188524590163865/300 [02:07<1:23:09, 17.10s/it]



  3%|▎         | 8.336065573770437/300 [02:10<1:46:04, 21.82s/it]



  3%|▎         | 8.516393442622913/300 [02:13<1:22:05, 16.90s/it]



  3%|▎         | 8.680327868852437/300 [02:16<1:24:09, 17.33s/it]



  3%|▎         | 8.991803278688533/300 [02:19<27:36,  5.69s/it]  


Test Epoch: 9	Accuracy: 6070/6205 (98%)



  3%|▎         | 9.188524590163961/300 [02:23<1:19:34, 16.42s/it]



  3%|▎         | 9.336065573770533/300 [02:25<1:41:41, 20.99s/it]



  3%|▎         | 9.51639344262301/300 [02:29<1:27:41, 18.11s/it] 



  3%|▎         | 9.680327868852533/300 [02:31<1:18:43, 16.27s/it]



  3%|▎         | 10.000000000000105/300 [02:35<25:37,  5.30s/it] 


Test Epoch: 10	Accuracy: 6061/6205 (98%)



  3%|▎         | 10.188524590164057/300 [02:38<1:24:12, 17.43s/it]



  3%|▎         | 10.352459016393581/300 [02:41<1:19:16, 16.42s/it]



  4%|▎         | 10.516393442623105/300 [02:44<1:19:48, 16.54s/it]



  4%|▎         | 10.680327868852629/300 [02:47<1:19:33, 16.50s/it]



  4%|▎         | 11.008196721311677/300 [02:50<37:15,  7.74s/it]  


Test Epoch: 11	Accuracy: 6074/6205 (98%)



  4%|▎         | 11.1721311475412/300 [02:53<1:43:25, 21.49s/it]  



  4%|▍         | 11.336065573770725/300 [02:56<1:40:26, 20.88s/it]



  4%|▍         | 11.516393442623201/300 [02:59<1:17:59, 16.22s/it]



  4%|▍         | 11.680327868852725/300 [03:02<1:18:10, 16.27s/it]



  4%|▍         | 12.008196721311773/300 [03:05<36:35,  7.62s/it]  


Test Epoch: 12	Accuracy: 6049/6205 (97%)



  4%|▍         | 12.18852459016425/300 [03:08<1:21:15, 16.94s/it] 



  4%|▍         | 12.352459016393773/300 [03:11<1:20:42, 16.84s/it]



  4%|▍         | 12.516393442623297/300 [03:14<1:18:13, 16.32s/it]



  4%|▍         | 12.680327868852821/300 [03:17<1:17:58, 16.28s/it]



  4%|▍         | 12.991803278688916/300 [03:20<27:55,  5.84s/it]  


Test Epoch: 13	Accuracy: 6085/6205 (98%)



  4%|▍         | 13.172131147541393/300 [03:23<1:49:41, 22.95s/it]



  4%|▍         | 13.336065573770917/300 [03:26<1:42:32, 21.46s/it]



  5%|▍         | 13.50000000000044/300 [03:29<1:42:53, 21.55s/it] 



  5%|▍         | 13.680327868852917/300 [03:32<1:19:29, 16.66s/it]



  5%|▍         | 14.008196721311965/300 [03:36<36:58,  7.76s/it]  


Test Epoch: 14	Accuracy: 6052/6205 (98%)



  5%|▍         | 14.188524590164441/300 [03:39<1:18:20, 16.45s/it]



  5%|▍         | 14.352459016393965/300 [03:42<1:17:32, 16.29s/it]



  5%|▍         | 14.516393442623489/300 [03:44<1:18:08, 16.42s/it]



  5%|▍         | 14.680327868853013/300 [03:47<1:18:11, 16.44s/it]



  5%|▍         | 14.975409836066156/300 [03:50<26:55,  5.67s/it]  


Test Epoch: 15	Accuracy: 6072/6205 (98%)



  5%|▌         | 15.024590163935013/300 [03:51<45:11,  9.51s/it]



  5%|▌         | 15.188524590164537/300 [03:54<1:17:06, 16.24s/it]



  5%|▌         | 15.35245901639406/300 [03:57<1:18:13, 16.49s/it] 



  5%|▌         | 15.516393442623585/300 [03:59<1:18:03, 16.46s/it]



  5%|▌         | 15.680327868853109/300 [04:02<1:17:13, 16.30s/it]



  5%|▌         | 16.008196721312153/300 [04:06<36:20,  7.68s/it]  


Test Epoch: 16	Accuracy: 6102/6205 (98%)



  5%|▌         | 16.18852459016459/300 [04:09<1:16:32, 16.18s/it] 



  5%|▌         | 16.35245901639408/300 [04:11<1:17:09, 16.32s/it] 



  6%|▌         | 16.516393442623567/300 [04:14<1:17:39, 16.44s/it]



  6%|▌         | 16.680327868853055/300 [04:17<1:16:46, 16.26s/it]



  6%|▌         | 17.008196721312032/300 [04:20<36:13,  7.68s/it]  


Test Epoch: 17	Accuracy: 6102/6205 (98%)



  6%|▌         | 17.18852459016447/300 [04:23<1:16:30, 16.23s/it] 



  6%|▌         | 17.352459016393958/300 [04:26<1:17:08, 16.38s/it]



  6%|▌         | 17.516393442623446/300 [04:29<1:16:31, 16.25s/it]



  6%|▌         | 17.663934426229986/300 [04:32<1:36:37, 20.53s/it]



  6%|▌         | 18.00819672131191/300 [04:35<36:06,  7.68s/it]   


Test Epoch: 18	Accuracy: 6106/6205 (98%)



  6%|▌         | 18.18852459016435/300 [04:38<1:16:09, 16.21s/it] 



  6%|▌         | 18.352459016393837/300 [04:41<1:16:18, 16.26s/it]



  6%|▌         | 18.516393442623325/300 [04:44<1:16:15, 16.25s/it]



  6%|▌         | 18.680327868852814/300 [04:46<1:16:02, 16.22s/it]



  6%|▋         | 19.00819672131179/300 [04:50<35:36,  7.60s/it]   


Test Epoch: 19	Accuracy: 6104/6205 (98%)



  6%|▋         | 19.188524590164228/300 [04:53<1:15:35, 16.15s/it]



  6%|▋         | 19.352459016393716/300 [04:56<1:15:30, 16.14s/it]



  7%|▋         | 19.516393442623205/300 [04:58<1:16:01, 16.26s/it]



  7%|▋         | 19.680327868852693/300 [05:01<1:15:46, 16.22s/it]



  7%|▋         | 20.00819672131167/300 [05:04<36:02,  7.72s/it]   


Test Epoch: 20	Accuracy: 6118/6205 (99%)



  7%|▋         | 20.188524590164107/300 [05:07<1:15:30, 16.19s/it]



  7%|▋         | 20.352459016393595/300 [05:10<1:15:32, 16.21s/it]



  7%|▋         | 20.516393442623084/300 [05:13<1:15:39, 16.24s/it]



  7%|▋         | 20.680327868852572/300 [05:16<1:15:34, 16.23s/it]



  7%|▋         | 21.00819672131155/300 [05:19<36:01,  7.75s/it]   


Test Epoch: 21	Accuracy: 6101/6205 (98%)



  7%|▋         | 21.188524590163986/300 [05:22<1:17:11, 16.61s/it]



  7%|▋         | 21.352459016393475/300 [05:25<1:17:30, 16.69s/it]



  7%|▋         | 21.516393442622963/300 [05:28<1:15:49, 16.34s/it]



  7%|▋         | 21.68032786885245/300 [05:31<1:15:13, 16.22s/it] 



  7%|▋         | 21.99180327868848/300 [05:34<27:35,  5.95s/it]   


Test Epoch: 22	Accuracy: 6121/6205 (99%)



  7%|▋         | 22.188524590163865/300 [05:37<1:15:04, 16.21s/it]



  7%|▋         | 22.352459016393354/300 [05:40<1:15:15, 16.26s/it]



  8%|▊         | 22.516393442622842/300 [05:43<1:15:07, 16.25s/it]



  8%|▊         | 22.68032786885233/300 [05:46<1:14:52, 16.20s/it] 



  8%|▊         | 22.999999999999833/300 [05:49<24:40,  5.35s/it]  


Test Epoch: 23	Accuracy: 6112/6205 (99%)



  8%|▊         | 23.188524590163745/300 [05:52<1:18:11, 16.95s/it]



  8%|▊         | 23.336065573770284/300 [05:55<1:40:30, 21.80s/it]



  8%|▊         | 23.51639344262272/300 [05:58<1:25:00, 18.45s/it] 



  8%|▊         | 23.66393442622926/300 [06:01<1:33:44, 20.35s/it] 



  8%|▊         | 23.999999999999712/300 [06:05<24:37,  5.35s/it]  


Test Epoch: 24	Accuracy: 6113/6205 (99%)



  8%|▊         | 24.172131147540675/300 [06:08<1:48:46, 23.66s/it]



  8%|▊         | 24.352459016393112/300 [06:12<1:23:55, 18.27s/it]



  8%|▊         | 24.5163934426226/300 [06:15<1:17:26, 16.87s/it]  



  8%|▊         | 24.68032786885209/300 [06:18<1:17:10, 16.82s/it] 



  8%|▊         | 24.975409836065168/300 [06:21<26:03,  5.69s/it]  


Test Epoch: 25	Accuracy: 6118/6205 (99%)



  8%|▊         | 25.008196721311066/300 [06:22<41:56,  9.15s/it]



  8%|▊         | 25.172131147540554/300 [06:25<1:54:03, 24.90s/it]



  8%|▊         | 25.336065573770043/300 [06:28<1:48:41, 23.74s/it]



  9%|▊         | 25.51639344262248/300 [06:31<1:14:46, 16.34s/it] 



  9%|▊         | 25.68032786885197/300 [06:34<1:15:34, 16.53s/it] 



  9%|▊         | 25.99999999999947/300 [06:37<26:17,  5.76s/it]   


Test Epoch: 26	Accuracy: 6091/6205 (98%)



  9%|▊         | 26.172131147540433/300 [06:41<1:42:58, 22.56s/it]



  9%|▉         | 26.336065573769922/300 [06:44<1:53:59, 24.99s/it]



  9%|▉         | 26.51639344262236/300 [06:48<1:16:40, 16.82s/it] 



  9%|▉         | 26.6639344262289/300 [06:50<1:36:52, 21.27s/it]  



  9%|▉         | 26.991803278687875/300 [06:54<27:19,  6.00s/it]  


Test Epoch: 27	Accuracy: 6106/6205 (98%)



  9%|▉         | 27.18852459016326/300 [06:57<1:21:48, 17.99s/it] 



  9%|▉         | 27.35245901639275/300 [07:01<1:18:03, 17.18s/it] 



  9%|▉         | 27.51639344262224/300 [07:03<1:14:43, 16.45s/it] 



  9%|▉         | 27.680327868851727/300 [07:06<1:14:29, 16.41s/it]



  9%|▉         | 28.008196721310703/300 [07:10<34:32,  7.62s/it]  


Test Epoch: 28	Accuracy: 6119/6205 (99%)



  9%|▉         | 28.18852459016314/300 [07:13<1:13:34, 16.24s/it] 



  9%|▉         | 28.33606557376968/300 [07:15<1:42:02, 22.54s/it] 



  9%|▉         | 28.49999999999917/300 [07:19<1:47:58, 23.86s/it] 



 10%|▉         | 28.680327868851606/300 [07:22<1:24:43, 18.74s/it]



 10%|▉         | 29.008196721310583/300 [07:26<35:00,  7.75s/it]  


Test Epoch: 29	Accuracy: 6084/6205 (98%)



 10%|▉         | 29.18852459016302/300 [07:29<1:13:54, 16.38s/it] 



 10%|▉         | 29.35245901639251/300 [07:32<1:13:39, 16.33s/it] 



 10%|▉         | 29.516393442621997/300 [07:34<1:13:54, 16.39s/it]



 10%|▉         | 29.680327868851485/300 [07:37<1:13:55, 16.41s/it]



 10%|█         | 30.008196721310462/300 [07:41<34:37,  7.70s/it]  


Test Epoch: 30	Accuracy: 6119/6205 (99%)



 10%|█         | 30.1885245901629/300 [07:44<1:13:19, 16.31s/it]  



 10%|█         | 30.352459016392388/300 [07:46<1:13:31, 16.36s/it]



 10%|█         | 30.516393442621876/300 [07:49<1:13:05, 16.27s/it]



 10%|█         | 30.680327868851364/300 [07:52<1:13:16, 16.32s/it]



 10%|█         | 31.00819672131034/300 [07:55<34:30,  7.70s/it]   


Test Epoch: 31	Accuracy: 6121/6205 (99%)



 10%|█         | 31.18852459016278/300 [07:58<1:12:34, 16.20s/it] 



 10%|█         | 31.352459016392267/300 [08:01<1:13:11, 16.35s/it]



 11%|█         | 31.516393442621755/300 [08:04<1:12:51, 16.28s/it]



 11%|█         | 31.680327868851244/300 [08:07<1:12:57, 16.32s/it]



 11%|█         | 32.00819672131022/300 [08:10<34:10,  7.65s/it]   


Test Epoch: 32	Accuracy: 6124/6205 (99%)



 11%|█         | 32.188524590162736/300 [08:13<1:12:42, 16.29s/it]



 11%|█         | 32.352459016392295/300 [08:16<1:13:03, 16.38s/it]



 11%|█         | 32.516393442621855/300 [08:19<1:12:35, 16.28s/it]



 11%|█         | 32.680327868851414/300 [08:22<1:12:50, 16.35s/it]



 11%|█         | 33.00819672131053/300 [08:25<34:27,  7.74s/it]   


Test Epoch: 33	Accuracy: 6119/6205 (99%)



 11%|█         | 33.18852459016305/300 [08:28<1:12:13, 16.24s/it] 



 11%|█         | 33.35245901639261/300 [08:31<1:12:30, 16.32s/it] 



 11%|█         | 33.51639344262217/300 [08:34<1:12:35, 16.34s/it] 



 11%|█         | 33.68032786885173/300 [08:36<1:12:56, 16.43s/it] 



 11%|█▏        | 34.008196721310846/300 [08:40<33:47,  7.62s/it]  


Test Epoch: 34	Accuracy: 6120/6205 (99%)



 11%|█▏        | 34.18852459016336/300 [08:43<1:11:51, 16.22s/it] 



 11%|█▏        | 34.35245901639292/300 [08:46<1:12:03, 16.27s/it] 



 12%|█▏        | 34.51639344262248/300 [08:48<1:12:11, 16.32s/it] 



 12%|█▏        | 34.68032786885204/300 [08:51<1:12:08, 16.31s/it] 



 12%|█▏        | 35.00819672131116/300 [08:55<33:52,  7.67s/it]   


Test Epoch: 35	Accuracy: 6125/6205 (99%)



 12%|█▏        | 35.188524590163674/300 [08:58<1:11:33, 16.21s/it]



 12%|█▏        | 35.35245901639323/300 [09:00<1:12:32, 16.45s/it] 



 12%|█▏        | 35.51639344262279/300 [09:03<1:12:08, 16.37s/it] 



 12%|█▏        | 35.68032786885235/300 [09:06<1:12:27, 16.45s/it] 



 12%|█▏        | 36.00819672131147/300 [09:09<35:08,  7.99s/it]   


Test Epoch: 36	Accuracy: 6124/6205 (99%)



 12%|█▏        | 36.188524590163986/300 [09:12<1:11:56, 16.36s/it]



 12%|█▏        | 36.352459016393546/300 [09:15<1:12:26, 16.49s/it]



 12%|█▏        | 36.516393442623105/300 [09:18<1:12:18, 16.47s/it]



 12%|█▏        | 36.680327868852665/300 [09:21<1:12:01, 16.41s/it]



 12%|█▏        | 37.000000000000306/300 [09:24<22:42,  5.18s/it]  


Test Epoch: 37	Accuracy: 6112/6205 (99%)



 12%|█▏        | 37.1885245901643/300 [09:27<1:11:07, 16.24s/it]  



 12%|█▏        | 37.35245901639386/300 [09:30<1:12:14, 16.50s/it] 



 13%|█▎        | 37.51639344262342/300 [09:33<1:11:38, 16.38s/it] 



 13%|█▎        | 37.68032786885298/300 [09:36<1:12:09, 16.50s/it] 



 13%|█▎        | 38.008196721312096/300 [09:39<42:34,  9.75s/it]  


Test Epoch: 38	Accuracy: 6124/6205 (99%)



 13%|█▎        | 38.18852459016461/300 [09:42<1:11:49, 16.46s/it] 



 13%|█▎        | 38.35245901639417/300 [09:45<1:11:39, 16.43s/it] 



 13%|█▎        | 38.51639344262373/300 [09:48<1:11:54, 16.50s/it] 



 13%|█▎        | 38.68032786885329/300 [09:51<1:11:31, 16.42s/it] 



 13%|█▎        | 39.00000000000093/300 [09:54<22:34,  5.19s/it]   


Test Epoch: 39	Accuracy: 6124/6205 (99%)



 13%|█▎        | 39.188524590164924/300 [09:57<1:11:08, 16.37s/it]



 13%|█▎        | 39.352459016394484/300 [10:00<1:11:22, 16.43s/it]



 13%|█▎        | 39.51639344262404/300 [10:03<1:11:17, 16.42s/it] 



 13%|█▎        | 39.6803278688536/300 [10:06<1:11:22, 16.45s/it]  



 13%|█▎        | 40.00819672131272/300 [10:09<35:46,  8.26s/it]   


Test Epoch: 40	Accuracy: 6122/6205 (99%)



 13%|█▎        | 40.17213114754228/300 [10:13<1:36:14, 22.23s/it] 



 13%|█▎        | 40.352459016394796/300 [10:16<1:11:52, 16.61s/it]



 14%|█▎        | 40.516393442624356/300 [10:18<1:11:15, 16.48s/it]



 14%|█▎        | 40.680327868853915/300 [10:21<1:11:37, 16.57s/it]



 14%|█▎        | 41.000000000001556/300 [10:24<22:25,  5.20s/it]  


Test Epoch: 41	Accuracy: 6121/6205 (99%)



 14%|█▎        | 41.18852459016555/300 [10:28<1:10:57, 16.45s/it] 



 14%|█▍        | 41.35245901639511/300 [10:31<1:11:05, 16.49s/it] 



 14%|█▍        | 41.51639344262467/300 [10:33<1:10:58, 16.47s/it] 



 14%|█▍        | 41.68032786885423/300 [10:36<1:11:17, 16.56s/it] 



 14%|█▍        | 42.00819672131335/300 [10:40<42:55,  9.98s/it]   


Test Epoch: 42	Accuracy: 6120/6205 (99%)



 14%|█▍        | 42.172131147542906/300 [10:44<1:50:20, 25.68s/it]



 14%|█▍        | 42.35245901639542/300 [10:47<1:19:12, 18.45s/it] 



 14%|█▍        | 42.51639344262498/300 [10:50<1:10:34, 16.45s/it] 



 14%|█▍        | 42.68032786885454/300 [10:53<1:15:00, 17.49s/it] 



 14%|█▍        | 43.00819672131366/300 [10:57<32:51,  7.67s/it]   


Test Epoch: 43	Accuracy: 6122/6205 (99%)



 14%|█▍        | 43.17213114754322/300 [11:00<1:40:31, 23.49s/it] 



 14%|█▍        | 43.352459016395734/300 [11:03<1:10:42, 16.53s/it]



 15%|█▍        | 43.516393442625294/300 [11:06<1:11:27, 16.72s/it]



 15%|█▍        | 43.6639344262319/300 [11:09<1:27:18, 20.44s/it]  



 15%|█▍        | 43.991803278691016/300 [11:13<24:38,  5.77s/it]  


Test Epoch: 44	Accuracy: 6120/6205 (99%)



 15%|█▍        | 44.18852459016649/300 [11:16<1:11:15, 16.71s/it] 



 15%|█▍        | 44.35245901639605/300 [11:19<1:12:06, 16.92s/it] 



 15%|█▍        | 44.516393442625606/300 [11:22<1:14:23, 17.47s/it]



 15%|█▍        | 44.680327868855166/300 [11:25<1:13:34, 17.29s/it]



 15%|█▌        | 45.00000000000281/300 [11:28<22:32,  5.30s/it]   


Test Epoch: 45	Accuracy: 6125/6205 (99%)



 15%|█▌        | 45.1885245901668/300 [11:32<1:11:48, 16.91s/it]  



 15%|█▌        | 45.35245901639636/300 [11:35<1:11:33, 16.86s/it] 



 15%|█▌        | 45.51639344262592/300 [11:38<1:11:07, 16.77s/it] 



 15%|█▌        | 45.68032786885548/300 [11:41<1:09:24, 16.37s/it] 



 15%|█▌        | 46.00000000000312/300 [11:44<21:58,  5.19s/it]   


Test Epoch: 46	Accuracy: 6127/6205 (99%)



 15%|█▌        | 46.17213114754416/300 [11:47<1:28:37, 20.95s/it] 



 15%|█▌        | 46.336065573773716/300 [11:50<1:46:04, 25.09s/it]



 16%|█▌        | 46.500000000003276/300 [11:53<1:49:13, 25.85s/it]



 16%|█▌        | 46.663934426232835/300 [11:57<1:37:18, 23.05s/it]



 16%|█▌        | 47.00819672131491/300 [12:01<32:42,  7.76s/it]   


Test Epoch: 47	Accuracy: 6121/6205 (99%)



 16%|█▌        | 47.188524590167425/300 [12:04<1:11:46, 17.03s/it]



 16%|█▌        | 47.352459016396985/300 [12:06<1:09:49, 16.58s/it]



 16%|█▌        | 47.516393442626544/300 [12:09<1:09:39, 16.55s/it]



 16%|█▌        | 47.680327868856104/300 [12:12<1:09:27, 16.52s/it]



 16%|█▌        | 48.00819672131522/300 [12:16<32:34,  7.75s/it]   


Test Epoch: 48	Accuracy: 6125/6205 (99%)



 16%|█▌        | 48.18852459016774/300 [12:19<1:08:41, 16.37s/it] 



 16%|█▌        | 48.3524590163973/300 [12:21<1:10:45, 16.87s/it]  



 16%|█▌        | 48.51639344262686/300 [12:24<1:08:53, 16.43s/it] 



 16%|█▌        | 48.680327868856416/300 [12:27<1:08:52, 16.44s/it]



 16%|█▋        | 49.008196721315535/300 [12:31<32:14,  7.71s/it]  


Test Epoch: 49	Accuracy: 6123/6205 (99%)



 16%|█▋        | 49.18852459016805/300 [12:34<1:08:39, 16.42s/it] 



 16%|█▋        | 49.35245901639761/300 [12:36<1:08:51, 16.48s/it] 



 17%|█▋        | 49.51639344262717/300 [12:39<1:08:33, 16.42s/it] 



 17%|█▋        | 49.68032786885673/300 [12:42<1:08:32, 16.43s/it] 



 17%|█▋        | 50.00000000000437/300 [12:45<1:03:49, 15.32s/it] 


Test Epoch: 50	Accuracy: 6117/6205 (99%)




