In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import random
from glob import glob
from tqdm import tqdm
from scipy.io import loadmat

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import os, sys
from typing import *
import torch
import random
import copy

In [2]:
def detach(
    batch_dict: Dict[int, List[torch.Tensor]] = None, 
    k_shot:int = None, 
    k_query:int = None
    ) -> Tuple[Dict[int, List[torch.Tensor]]]:
    sample_len = len(batch_dict[list(batch_dict.keys())[0]])
    
    if k_shot + k_query > sample_len:
        raise ValueError(f"Many data to unpack. Since #sample in support set: k_shot and #sample \
            in query set k_query must satisfy the condition: k_shot + k_query == #sample \
                in a batch per task.")
    elif k_shot + k_query < sample_len:
        raise UserWarning(f"the #sample in support set: k_shot and #sample in query set: k_query \
            totally are less than the #sample available in batch task dict. The redundant samples are \
                used in automatically used in query set.")
    
    support_dct = {
        _cls : batch_dict[_cls][:k_shot] for _cls in batch_dict
    }
    
    query_dct = {
        _cls : batch_dict[_cls][k_shot:] for _cls in batch_dict
    }
    
    return (support_dct, query_dct)

def maml_detach(
    batch_dict: Dict[int, List[torch.Tensor]] = None, 
    k_shot:int = None, 
    k_query:int = None,
    task:int = None
    ) -> Tuple[torch.Tensor]:
    
    support_dct, query_dct = detach(
        batch_dict=batch_dict,
        k_shot=k_shot,
        k_query=k_query
    )
    
    if not isinstance(task, int):
        raise ValueError(f"task arg must be integer type but found {type(task)} instead")
    elif task not in batch_dict.keys():
        raise Exception(f"Found no task {task} in batch dict")
    
    tasks = list(batch_dict.keys())
    
    support_x, support_y, query_x, query_y = [], [], [], []
    for _task in tasks:
        support_x.extend(support_dct[_task])
        query_x.extend(query_dct[_task])
        if _task == task:
            support_y.extend([1]*k_shot)
            query_y.extend([1]*k_query)
        else:
            support_y.extend([0]*k_shot)
            query_y.extend([0]*k_query)
    
    support_x = torch.stack(support_x)
    support_y = torch.FloatTensor(support_y)
    query_x = torch.stack(query_x)
    query_y = torch.FloatTensor(query_y)
    
    return (support_x, support_y, query_x, query_y)

def single_task_detach(
    batch_dict: Dict[int, List[torch.Tensor]] = None, 
    k_shot:int = None, 
    k_query:int = None,
    task:int = None
    ):
    
    support_dct, query_dct = detach(
        batch_dict=batch_dict,
        k_shot=k_shot,
        k_query=k_query
    )
    
    if not isinstance(task, int):
        raise ValueError(f"task arg must be integer type but found {type(task)} instead")
    elif task not in batch_dict.keys():
        raise Exception(f"Found no task {task} in batch dict")
    
    support_x, support_y, query_x, query_y = [], [], [], []
    
    support_x.extend(support_dct[task])
    support_y.extend([task]*len(support_dct[task]))
    query_x.extend(query_dct[task])
    query_y.extend([task]*len(query_dct[task]))
    
    support_x = torch.stack(support_x)
    support_y = torch.LongTensor(support_y)
    query_x = torch.stack(query_x)
    query_y = torch.LongTensor(query_y)
    
    return (support_x, support_y, query_x, query_y)


In [3]:
print(os.getcwd())

for i in range (3):
    os.chdir("..")
    
main_data_dir = os.getcwd() + "/Data set"

/home/thaobeo/git/HeartResearch/Experiment/Approach/Model_signal


In [4]:
data_dir = "/media/mountHDD2/khoibaocon"
print(os.listdir(data_dir))

['TrainingSet3', 'Label.csv', 'alldata', 'TrainingSet1', 'single_label.csv', 'TrainingSet2']


In [5]:
main_df = pd.read_csv(data_dir + "/Label.csv")
main_df.shape

(6877, 4)

In [6]:
main_df.head()

Unnamed: 0,Recording,First_label,Second_label,Third_label
0,A0001,5,,
1,A0002,1,,
2,A0003,2,,
3,A0004,2,,
4,A0005,7,,


In [7]:
main_df["First_label"].value_counts()

First_label
5    1695
2    1098
1     918
8     826
3     704
7     653
6     574
4     207
9     202
Name: count, dtype: int64

In [8]:
main_df["Second_label"].value_counts()

Second_label
5.0    162
2.0    123
7.0     47
6.0     42
8.0     41
4.0     28
3.0     18
9.0     16
Name: count, dtype: int64

In [9]:
main_df["Third_label"].value_counts()

Third_label
9.0    2
8.0    2
6.0    1
4.0    1
Name: count, dtype: int64

In [10]:
single_main_df = main_df[main_df["Second_label"].isnull()]
single_main_df.shape

(6400, 4)

In [11]:
single_main_df.to_csv(main_data_dir + "/single_label.csv")

In [12]:
mat_files = glob(data_dir + "/alldata/*")
print(len(mat_files))

6877


In [13]:
single_fns = single_main_df["Recording"].values.tolist()
print(len(single_fns))

6400


In [14]:
single_mat_paths = [data_dir + f"/alldata/{x}.mat" for x in single_fns]
print(os.path.exists(single_mat_paths[0]))

True


In [15]:
sample_data = loadmat(single_mat_paths[0])
sample_data.keys()

dict_keys(['__header__', '__version__', '__globals__', 'ECG'])

In [16]:
sample_signal_data = sample_data['ECG'][0][0][2]
sample_signal_data.shape

(12, 7500)

In [17]:
print(len(single_mat_paths))

6400


In [18]:
# plt.plot(sample_signal_data[0])

In [19]:
sample_sig = torch.randn(1, 12, 32)
conv_test = nn.Conv1d(12, 12, 3, 1, 1)
print(conv_test(sample_sig).shape)

torch.Size([1, 12, 32])


In [20]:
len_lst = [loadmat(x)['ECG'][0][0][2].shape[1] for x in single_mat_paths]
print(f"MAX: {max(len_lst)}")
print(f"MIN: {min(len_lst)}")
print(f"AVG: {sum(len_lst)/len(len_lst)}")

MAX: 72000
MIN: 3000
AVG: 7946.03703125


In [21]:
class BasicBlock(nn.Module):
    def __init__(self, channel_num):
        super(BasicBlock, self).__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv1d(channel_num, channel_num, 3, padding=1),
            nn.BatchNorm1d(channel_num),
            nn.ReLU())
        
        self.conv_block2 = nn.Sequential(
            nn.Conv1d(channel_num, channel_num, 3, padding=1),
            nn.BatchNorm1d(channel_num),
        )
        self.relu = nn.ReLU()
        torch.nn.init.kaiming_normal_(self.conv_block1[0].weight)
        torch.nn.init.kaiming_normal_(self.conv_block2[0].weight)
        
    def forward(self, x):
        residual = x
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = x + residual
        out = self.relu(x)
        return out

In [22]:
test_basic_block = BasicBlock(2)
sample_sig = torch.randn(1, 2, 32)
print(test_basic_block(sample_sig).shape)

torch.Size([1, 2, 32])


In [23]:
class ResNet(nn.Module):
    def __init__(self, in_channels = 12, type = 18, num_classes = 9):
        super(ResNet, self).__init__()
        self.struc_dict = {
            18: {
                "num_channels" : [64, 128, 256, 512],
                "counts" : [2, 2, 2, 2]
            }
        }
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2)
        torch.nn.init.kaiming_normal_(self.conv1.weight)
        self.max1 = nn.MaxPool1d(kernel_size=3, stride=2)
        self.main = nn.Sequential()
        for idx, struc in enumerate(
            zip(
                self.struc_dict[type]["num_channels"], 
                self.struc_dict[type]["counts"]
            )
        ):
            num_channel, cnt = struc
            for i in range(cnt):
                self.main.add_module(f"conv{idx+1}_{i}", BasicBlock(num_channel))
            if idx < len(self.struc_dict[type]["num_channels"]) - 1:
                self.main.add_module(f"ext_{idx}", nn.Conv1d(num_channel, self.struc_dict[type]["num_channels"][idx+1], 3, 1))
                self.main.add_module(f"extbn_{idx}", nn.BatchNorm1d(self.struc_dict[type]["num_channels"][idx+1]))
                                     
        self.avg = torch.nn.AdaptiveAvgPool1d((1))
        self.lin = nn.Linear(self.struc_dict[type]["num_channels"][-1], num_classes)
        torch.nn.init.kaiming_normal_(self.lin.weight)
    def forward(self, x):
        x = self.conv1(x)
        x = self.max1(x)
        x = self.main(x)
        x = self.avg(x)
        x = x.reshape(x.shape[0], -1)
        x = self.lin(x)
        return x

In [24]:
model = ResNet()
sample_sig = torch.randn(1, 12, 3000)
model(sample_sig).shape

torch.Size([1, 9])

In [25]:
# class ECG(Dataset):
#     def __init__(self, data_paths, label_df):
#         self.data_paths = data_paths
#         random.shuffle(self.data_paths)
#         self.label_df = label_df

#     def __getitem__(self, idx):
#         data_path = self.data_paths[idx]        
#         data = loadmat(data_path)['ECG'][0][0][2]
#         clip_data = data[:, 300:3000]

#         filename = data_path.split("/")[-1].split(".")[0]
#         label = self.label_df[self.label_df["Recording"] == filename]["First_label"].values.item()

#         torch_data = torch.from_numpy(clip_data)

#         return torch_data.float(), label-1

#     def __len__(self):
#         return len(self.data_paths)

In [26]:
ratio = [0.8, 0.1, 0.1]

train_index = int(len(single_mat_paths)*ratio[0])
valid_index = int(len(single_mat_paths)*(ratio[0]+ratio[1]))

train_image_paths = single_mat_paths[:train_index]
valid_image_paths = single_mat_paths[train_index:valid_index]
test_image_paths = single_mat_paths[valid_index:]

In [27]:
ks = 1
kq = 5

def set_dataset(mat_path):
    sig_dict = {i : [] for i in range(9)}    

    for data_path in mat_path:
        filename = data_path.split("/")[-1].split(".")[0]
        _cls = single_main_df[single_main_df["Recording"] == filename]["First_label"].values.item()
        data = loadmat(data_path)['ECG'][0][0][2]
        clip_data = data[:, 300:3000]
        torch_data = torch.from_numpy(clip_data)

        sig_dict[_cls-1].append(clip_data)

    max_sample = max([len(sig_dict[i]) for i in range(9)])
    sample_cls_cnt = max_sample + ((ks + kq) - max_sample % (ks + kq))
    
    for i in range(9):

        a = sample_cls_cnt/len(sig_dict[i])

        if a < 1:
            sig_dict[i] = sig_dict[i] + random.sample(sig_dict[i], k = sample_cls_cnt - len(sig_dict[i]))
        else:
            sig_dict[i] = (int(np.floor(sample_cls_cnt/len(sig_dict[i]))))*sig_dict[i]
            sig_dict[i] = sig_dict[i] + random.sample(sig_dict[i], k = sample_cls_cnt - len(sig_dict[i]))
            
    return sig_dict

In [28]:
# class HeartData(Dataset):
#     def __init__(self, data_path):
#         self.data_path = data_path

#     def __len__(self):
#         return len(self.data_path)
        
#     def __getitem__(self, index):
#         filename = self.data_path[index].split("/")[-1].split(".")[0]
#         label = single_main_df[single_main_df["Recording"] == filename]["First_label"].values.item() - 1
#         data = loadmat(self.data_path[index])['ECG'][0][0][2]
#         signal = data[:, 300:3000]
        
#         return signal, label

In [29]:
class HeartData(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths
        random.shuffle(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]        
        data = loadmat(data_path)['ECG'][0][0][2]
        clip_data = data[:, 300:3000]

        filename = data_path.split("/")[-1].split(".")[0]
        label = single_main_df[single_main_df["Recording"] == filename]["First_label"].values.item()

        torch_data = torch.from_numpy(clip_data)

        return torch_data.float(), label-1

    def __len__(self):
        return len(self.data_paths)

In [30]:
train_dataset = set_dataset(train_image_paths)
# valid_dataset = HearData(valid_image_paths)

In [31]:
valid_dataset = HeartData(valid_image_paths)

In [32]:
class ECG(Dataset):
    def __init__(self, dict_ds):
        self.dict_ds = dict_ds

    def __getitem__(self, idx):
        selected_dict = { _cls : self.dict_ds[_cls][idx] for _cls in range(9)}

        return selected_dict

    def __len__(self):
        
        return len(self.dict_ds[1])

In [33]:
# check_ds = ECG(data_paths=single_mat_paths, label_df=single_main_df)
# sample, lbl = check_ds[0]
# print(sample.shape, lbl)

In [34]:
# model(sample.unsqueeze(dim=0)).shape

In [35]:
# data_dict = {
#     idx : [] for idx in range(9)
# }

# for data_path in single_mat_paths:
#     filename = data_path.split("/")[-1].split(".")[0]
#     _cls = single_main_df[single_main_df["Recording"] == filename]["First_label"].values.item()

#     data_dict[_cls-1].append(data_path)

# for key in data_dict:
#     print(f"{key}->{len(data_dict[key])}")

In [36]:
# train_data_dict = {
#     _cls : data_dict[_cls][:int(0.9*len(data_dict[_cls]))] for _cls in data_dict
# }

# valid_data_dict = {
#     _cls : data_dict[_cls][int(0.9*len(data_dict[_cls])):] for _cls in data_dict
# }

# for key in train_data_dict:
#     print(f"{key}->{len(train_data_dict[key])}--{len(valid_data_dict[key])}")

In [37]:
# train_data_paths = []
# for key in train_data_dict:
#     train_data_paths.extend(train_data_dict[key])
# valid_data_paths = []
# for key in valid_data_dict:
#     valid_data_paths.extend(valid_data_dict[key])
# print(len(train_data_paths))
# print(len(valid_data_paths))

In [38]:
# sig_dict = {
#     idx : [] for idx in range(9)
# }    

# for data_path in single_mat_paths:
#     filename = data_path.split("/")[-1].split(".")[0]
#     _cls = single_main_df[single_main_df["Recording"] == filename]["First_label"].values.item()
#     data = loadmat(data_path)['ECG'][0][0][2]
#     sig_dict[_cls-1].append(data)

In [39]:
# print(sig_dict[0][0].shape)

In [40]:
train_ds = ECG(train_dataset)
# valid_ds = ECG(valid_dataset)

print(len(train_ds))
# print(len(valid_ds))

1248


In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", index = 1)
batch_size = ks + kq

traindl = DataLoader(
    train_ds,
    batch_size=batch_size, 
    shuffle=True, 
    pin_memory=True, 
    num_workers=os.cpu_count()//2
)

validdl = DataLoader(
    valid_dataset,
    batch_size=1, 
    shuffle=True, 
    pin_memory=True, 
    num_workers=os.cpu_count()//2
)

print(len(traindl))
print(len(validdl))

208
640


In [42]:
print(device)

cuda:1


In [43]:
epoch = 100
lr = 0.0001
inner_epoch = 1

model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epoch)
loss_fn = nn.CrossEntropyLoss()

In [None]:
for e in range(epoch):
    num_task = 9
    model.train()
    print(f"Epoch: {e}")
    batch_cnt = 0
    total_loss = 0
    correct = 0
    for train_idx, train_data_dict in tqdm(enumerate(traindl)):
        metaloss = 0
        for task in train_data_dict:
            task_model = copy.deepcopy(model)
            task_optimizer = Adam(task_model.parameters(), lr=lr, weight_decay=1e-4)
            sp_x, sp_y, qr_x, qr_y = single_task_detach(batch_dict=train_data_dict,
                                                        k_shot=ks,
                                                        k_query=kq,
                                                        task=task)
        
            for in_e in range (inner_epoch):
                sp_x, sp_y = sp_x.to(device, dtype = torch.float), sp_y.to(device)
                sp_logits = task_model(sp_x)
                sp_loss = loss_fn(sp_logits, sp_y)
                task_optimizer.zero_grad()
                sp_loss.backward()
                task_optimizer.step() 
                
            qr_x, qr_y = qr_x.to(device, dtype = torch.float), qr_y.to(device)
            qr_logits = task_model(qr_x)
            qr_loss = loss_fn(qr_logits, qr_y)
            metaloss += qr_loss.item()
            qr_loss.backward()            

            for w_global, w_local in zip(model.parameters(), task_model.parameters()):
                if w_global.grad is None:
                    w_global.grad = w_local.grad
                else:
                    w_global.grad += w_local.grad  
        optimizer.step()
        optimizer.zero_grad()            

    model.eval()
    with torch.no_grad():
        test_loss = 0
        correct = 0
        total = 0
        batch_count = 0
        for test_idx, (test_sigs, test_labels) in tqdm(enumerate(validdl)):
            batch_count = test_idx
            test_sigs = test_sigs.to(device, dtype = torch.float)
            test_labels = test_labels.to(device)
            test_logits = model(test_sigs)                

            test_loss += loss_fn(test_logits, test_labels).item()
            _, predicted = test_logits.max(1)
            total += test_labels.size(0)
            correct += predicted.eq(test_labels).sum().item()

    print(f"Epoch: {e} - MetaLoss: {metaloss/num_task} - Test Loss: {test_loss/batch_count} - Test Acc: {100*correct/total}%") 

Epoch: 0


640it [00:01, 332.97it/s]

Epoch: 0 - MetaLoss: 0.08939669219156106 - Test Loss: 13.638493604916986 - Test Acc: 17.03125%
Epoch: 1



640it [00:01, 331.36it/s]

Epoch: 1 - MetaLoss: 0.04670681359453334 - Test Loss: 13.261530585522896 - Test Acc: 17.03125%
Epoch: 2



640it [00:02, 314.61it/s]

Epoch: 2 - MetaLoss: 0.061318441604574524 - Test Loss: 8.58577856491411 - Test Acc: 17.03125%
Epoch: 3



640it [00:01, 322.24it/s]

Epoch: 3 - MetaLoss: 0.028517675715395145 - Test Loss: 14.200269855523892 - Test Acc: 13.4375%
Epoch: 4



640it [00:01, 326.71it/s]

Epoch: 4 - MetaLoss: 0.029498849995434284 - Test Loss: 14.497840355707574 - Test Acc: 17.03125%
Epoch: 5



640it [00:01, 327.35it/s]

Epoch: 5 - MetaLoss: 0.03803380560647282 - Test Loss: 11.984705242472636 - Test Acc: 13.4375%
Epoch: 6



640it [00:02, 319.57it/s]

Epoch: 6 - MetaLoss: 0.04611928229375432 - Test Loss: 9.2911780169394 - Test Acc: 13.4375%
Epoch: 7



640it [00:01, 320.18it/s]

Epoch: 7 - MetaLoss: 0.029211038671847846 - Test Loss: 8.783220217634629 - Test Acc: 15.15625%
Epoch: 8



640it [00:02, 316.31it/s]

Epoch: 8 - MetaLoss: 0.01921009396513303 - Test Loss: 10.055898709371336 - Test Acc: 13.4375%
Epoch: 9



640it [00:01, 321.20it/s]

Epoch: 9 - MetaLoss: 0.028489939894320235 - Test Loss: 7.549764797861028 - Test Acc: 13.4375%
Epoch: 10



640it [00:02, 314.93it/s]

Epoch: 10 - MetaLoss: 0.009574218420311809 - Test Loss: 7.074091498869732 - Test Acc: 13.4375%
Epoch: 11



640it [00:01, 323.92it/s]

Epoch: 11 - MetaLoss: 0.0374773307186034 - Test Loss: 10.643665140749807 - Test Acc: 13.4375%
Epoch: 12



640it [00:01, 329.78it/s]

Epoch: 12 - MetaLoss: 0.026148395276524954 - Test Loss: 7.500314531098696 - Test Acc: 17.03125%
Epoch: 13



640it [00:02, 316.39it/s]

Epoch: 13 - MetaLoss: 0.0356681920044745 - Test Loss: 9.385736420434439 - Test Acc: 17.03125%
Epoch: 14



640it [00:01, 331.49it/s]

Epoch: 14 - MetaLoss: 0.006270249878677229 - Test Loss: 7.412292716205586 - Test Acc: 15.9375%
Epoch: 15



640it [00:01, 321.99it/s]

Epoch: 15 - MetaLoss: 0.01302759428249879 - Test Loss: 7.574366318414655 - Test Acc: 16.875%
Epoch: 16



640it [00:01, 322.63it/s]

Epoch: 16 - MetaLoss: 0.020841238865007956 - Test Loss: 9.004754348035686 - Test Acc: 13.4375%
Epoch: 17



640it [00:01, 320.43it/s]

Epoch: 17 - MetaLoss: 0.010658516289873255 - Test Loss: 10.532537056992162 - Test Acc: 17.03125%
Epoch: 18



640it [00:02, 319.92it/s]

Epoch: 18 - MetaLoss: 0.01353944473925771 - Test Loss: 4.9274738039442445 - Test Acc: 13.28125%
Epoch: 19



640it [00:01, 322.94it/s]

Epoch: 19 - MetaLoss: 0.04195714798859424 - Test Loss: 6.66797097552865 - Test Acc: 16.875%
Epoch: 20



640it [00:01, 321.54it/s]

Epoch: 20 - MetaLoss: 0.007034535468038585 - Test Loss: 9.018517810714473 - Test Acc: 17.03125%
Epoch: 21



640it [00:02, 319.67it/s]

Epoch: 21 - MetaLoss: 0.005246155959967937 - Test Loss: 6.088466988296389 - Test Acc: 16.25%
Epoch: 22



640it [00:04, 131.69it/s]

Epoch: 22 - MetaLoss: 0.011135696135978732 - Test Loss: 7.284234023989646 - Test Acc: 11.25%
Epoch: 23



640it [00:02, 276.56it/s]

Epoch: 23 - MetaLoss: 0.0024269526701472285 - Test Loss: 7.536974912331511 - Test Acc: 8.59375%
Epoch: 24



640it [00:02, 273.75it/s]

Epoch: 24 - MetaLoss: 0.0047368327246254515 - Test Loss: 6.143697755367543 - Test Acc: 9.53125%
Epoch: 25



640it [00:03, 208.73it/s]

Epoch: 25 - MetaLoss: 0.021033506009391405 - Test Loss: 5.708271020820076 - Test Acc: 17.03125%
Epoch: 26



640it [00:05, 123.03it/s]

Epoch: 26 - MetaLoss: 0.0036897114284026125 - Test Loss: 6.223067693479347 - Test Acc: 16.5625%
Epoch: 27



640it [00:02, 256.84it/s]

Epoch: 27 - MetaLoss: 0.0030363864619478895 - Test Loss: 7.348534596017788 - Test Acc: 13.4375%
Epoch: 28



640it [00:02, 274.53it/s]

Epoch: 28 - MetaLoss: 0.043553299314226024 - Test Loss: 6.415330081385253 - Test Acc: 9.84375%
Epoch: 29



640it [00:03, 208.84it/s]

Epoch: 29 - MetaLoss: 0.012843783202697523 - Test Loss: 13.646325879429197 - Test Acc: 13.4375%
Epoch: 30



640it [00:05, 123.07it/s]

Epoch: 30 - MetaLoss: 0.010643162738738788 - Test Loss: 7.199783072443985 - Test Acc: 13.4375%
Epoch: 31



640it [00:02, 297.82it/s]

Epoch: 31 - MetaLoss: 0.007447969420657803 - Test Loss: 10.38109100020581 - Test Acc: 13.4375%
Epoch: 32





In [None]:
#         batch_cnt = batch
#         train_sig = train_sig.to(device)
#         train_label = train_label.to(device)
        
#         pred = model(train_sig)
#         loss = loss_fn(pred, train_label)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         scheduler.step()
        
#         total_loss += loss.item()
#         correct += (pred.argmax(1) == train_label).type(torch.float).sum().item()
    
#     total_loss /= batch_cnt
#     correct /= len(traindl.dataset)
    
#     print(f"train loss: {total_loss} - train acc: {100*correct}")
    
#     batch_cnt = 0
#     val_total_loss = 0
#     val_correct = 0
#     model.eval()
#     with torch.no_grad():
#         for batch, (valid_sig, valid_label) in tqdm(enumerate(validdl)):
#             batch_cnt = batch
#             valid_sig = valid_sig.to(device)
#             valid_label = valid_label.to(device)
            
#             pred = model(valid_sig)
#             loss = loss_fn(pred, valid_label)
            
#             val_total_loss += loss.item()
#             val_correct += (pred.argmax(1) == valid_label).type(torch.float).sum().item()
    
#         val_total_loss /= batch_cnt
#         val_correct /= len(validdl.dataset)
        
#         print(f"valid loss: {val_total_loss} - valid acc: {100*val_correct}")