In [1]:
import sys
sys.path.append(r'C:/Program Files (zk)/PythonFiles/AClassification/SoundDL-CoughVID')
import os
import yaml
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchaudio
from pretrained.wav2vec import Wav2Vec

from models.conv_vae import ConvVAE, vae_loss
from models.classifiers import LSTM_Classifier, LSTM_Attn_Classifier

from modules.loss import FocalLoss
from readers.coughvid_reader import CoughVID_Class, CoughVID_Dataset
from readers.featurizer import Wave2Mel
from readers.collate_fn import collate_fn
from tools.plotter import calc_accuracy, plot_heatmap

In [2]:
class ConvEncoder(nn.Module):
    def __init__(self, inp_shape=(1, 298, 512), n_class=3):
        super().__init__()
        c, h, w = inp_shape
        hh, ww = h, w
        self.shapes = [(hh, ww)]
        self.encoder = nn.Sequential()
        cl = [16, 32, -1, 64, 128, -1, 512]
        ksp = [(4, 2, 1), ((5,4),2,1), (2, 2, -1), (4, (1, 2), 1), (4, 2, 1), (2, 2, -1), (3, 2, 0)]
        pre_c = 1
        for i, (k, s, p) in enumerate(ksp):
            if cl[i] == -1:            
                self.encoder.append(nn.MaxPool2d(kernel_size=k, stride=s, return_indices=False))
                hh /= 2
                ww /= 2
            else:
                self.encoder.append(nn.Conv2d(pre_c, cl[i], kernel_size=k, stride=s, padding=p))
                self.encoder.append(nn.BatchNorm2d(cl[i]))
                self.encoder.append(nn.ReLU(inplace=True))
                pre_c = cl[i]
                if isinstance(k, tuple):
                    hh = (hh-k[0]+2*p) // s + 1
                    ww = (ww-k[1]+2*p) // s + 1
                elif isinstance(s, tuple):
                    hh = (hh-k+2*p) // s[0] + 1
                    ww = (ww-k+2*p) // s[1] + 1
                else:
                    hh = (hh-k+2*p) // s + 1
                    ww = (ww-k+2*p) // s + 1
            self.shapes.append((hh, ww))
        print(self.shapes)

        self.flatten = nn.Flatten(start_dim=1)
        print("zero later:", cl[-1]*hh*ww)
        hidden_size = [int(cl[-1]*hh*ww), 256, 64, n_class]
        self.cls = nn.Sequential()
        for i in range(len(hidden_size) - 1):
            in_dim = hidden_size[i]
            out_dim = hidden_size[i + 1]
            self.cls.append(nn.Linear(in_dim, out_dim))
            if (i < len(hidden_size) - 2):
                self.cls.append(nn.BatchNorm1d(out_dim))
                self.cls.append(nn.ReLU(inplace=True))
            elif (i == len(hidden_size) - 2):
                self.cls.append(nn.BatchNorm1d(out_dim))
        
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x_input):
        # feat = self.mp1(self.encoder_conv1(x_input))
        # print(feat.shape)
        # feat = self.mp2(self.encoder_conv2(feat))
        # feat = self.encoder_conv3(feat)
        feat = self.encoder(x_input)
        # print("after encoder:", feat.shape)
        feat = self.flatten(feat)
        # print("after flatten:", feat.shape)
        feat = self.cls(feat)
        # print("after cls:", feat.shape)
        pred = self.softmax(feat)
        # print("after softmax:", pred.shape)
        return pred

# x_mel = torch.randn(size=(16, 1, 298, 512))
# # print(x_mel.shape)
# model = ConvEncoder()
# out = model(x_mel)
# out.shape

In [3]:
import pandas as pd
src_data = pd.read_csv("./datasets/waveinfo_labedfine_forcls.csv", header=0, index_col=0, delimiter=',')
print("原始数据：", src_data.shape)
print(src_data.iloc[:, [0, 6]].groupby("status_full").count())

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

encoder = Wav2Vec(pretrained=True).to(device)
print("Load Pretrained model Wav2Vec...")

criterion = FocalLoss(class_num=3)
print("Create CrossEntropyLoss...")

print("All model and loss are on device:", device)

shapes, class_num = [298, 512], 3

model = ConvEncoder().to(device)

# model loss_function optimizer scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-1, step_size_up=10)
print("Create TDNN, Adam with lr=1e-3, CosineAnnealingLR Shceduler")

原始数据： (6341, 7)
             filename
status_full          
0                2114
1                3288
2                 939
Load Pretrained model Wav2Vec...
Create CrossEntropyLoss...
All model and loss are on device: cuda
[(298, 512), (149, 256), (74, 128), (37.0, 64.0), (36.0, 32.0), (18.0, 16.0), (9.0, 8.0), (4.0, 3.0)]
zero later: 6144.0
Create TDNN, Adam with lr=1e-3, CosineAnnealingLR Shceduler


In [4]:
from torch.utils.data import DataLoader
train_x, train_y, test_x, test_y = CoughVID_Class(isdemo=False)

tic = time.time()
cough_dataset = CoughVID_Dataset(path_list=train_x, label_list=train_y)
toc = time.time()
print("Train Dataset Creat Completely, cost time:", toc-tic)

tic = time.time()
valid_dataset = CoughVID_Dataset(path_list=test_x, label_list=test_y)
toc = time.time()
print("Valid Dataset Creat Completely, cost time:", toc-tic)

num of trainingset:  6044 6044
num of testingset: 297 297


  samples, sample_rate = librosa.load(file)  # , dtype='float32')
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
  samples, sample_rate = librosa.load(file)  # , dtype='float32')
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
  samples, sample_rate = librosa.load(file)  # , dtype='float32')
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Loading: 100%|█████████████████████████████████████████████████████████████████████| 6044/6044 [06:52<00:00, 14.66it/s]


Train Dataset Creat Completely, cost time: 412.3014051914215


Loading: 100%|███████████████████████████████████████████████████████████████████████| 297/297 [00:20<00:00, 14.37it/s]

Valid Dataset Creat Completely, cost time: 20.67452073097229





In [5]:
configs = {
    "run_save_dir": "./runs/wav2vec_coughvid/",
    "model":{
        "num_class": 3,
        "input_length": 94,
        "wav_length": 48000,
        "input_dim": 512,
        "n_mels": 128,
        },
    "fit":{
        "batch_size": 64,
        "epochs" : 23,
        "start_scheduler_epoch": 6
        },
}

num_epoch = configs["fit"]["epochs"]

In [6]:
train_loader = DataLoader(cough_dataset, batch_size=configs["fit"]["batch_size"], shuffle=True,
                          collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=configs["fit"]["batch_size"], shuffle=True,
                          collate_fn=collate_fn)
print("Create Training Loader and Valid Loader.")

Create Training Loader and Valid Loader.


In [7]:
for i, (x_wav, y_label, max_len_rate) in enumerate(train_loader):
    # print(x_wav.shape)
    print(y_label)
    # print(max_len_rate)
    x_wav = x_wav.to(device)
    x_mel = encoder(x_wav).transpose(1,2).unsqueeze(1)
    print(x_mel.shape)
    print(model(x_mel).shape)
    if i>1:
        break

tensor([1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 0, 1, 1, 2, 0, 1, 2, 1, 1, 1, 1,
        1, 2, 1, 0, 1, 2, 1, 0, 0, 1, 0, 1, 2, 0, 0, 1, 0, 1, 1, 0, 2, 0, 0, 0,
        0, 1, 0, 0, 2, 0, 0, 0, 2, 1, 1, 1, 1, 1, 2, 2])
torch.Size([64, 1, 298, 512])
torch.Size([64, 3])
tensor([1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0,
        2, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 2, 0, 1, 1,
        1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 2, 1, 0, 0, 1])
torch.Size([64, 1, 298, 512])
torch.Size([64, 3])
tensor([2, 1, 0, 1, 0, 0, 2, 1, 1, 1, 1, 2, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 2, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
        2, 1, 2, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1])
torch.Size([64, 1, 298, 512])
torch.Size([64, 3])


In [8]:
timestr = time.strftime("%Y%m%d%H%M", time.localtime())
run_save_dir = configs["run_save_dir"] + timestr + f'_cnn_focalloss_bs64/'
os.makedirs(run_save_dir, exist_ok=True)
print("创建运行保存文件", run_save_dir)
with open("setting.txt", 'w', encoding="utf_8") as fout:
    fout.write("这次def __getitem__(self, ind):里面写几行：\n")
    fout.write("tmpseg = copy(self.wav_list[ind])\n")
    fout.write("tmpseg.crop(duration=3.0, mode=\"train\")\n")
    fout.write("tmpseg.wav_padding()\n")
    fout.write("assert len(tmpseg) == 48000, \"Error Length\"\n")
    fout.write("return tmpseg.samples, self.label_list[ind]\n")
    fout.write("，虽然慢，但是随机切片是必要的，或者先读，后续再切似乎更快吧？")
    fout.write("batch_size调到64，希望能缓解样本标签不均衡的问题？")

创建运行保存文件 ./runs/wav2vec_coughvid/202404301457_cnn_focalloss_bs64/


In [9]:
history1 = []
for epoch_id in range(configs["fit"]["epochs"]):
    # ---------------------------
    # -----------TRAIN-----------
    # ---------------------------
    model.train()
    for x_idx, (x_wav, y_label, _) in enumerate(tqdm(train_loader, desc="Training")):
        x_wav = x_wav.to(device)
        x_mel = encoder(x_wav).transpose(1,2).unsqueeze(1)
        y_label = torch.tensor(y_label, device=device)
        # print("shape of x_mel:", x_mel.shape)
        
        optimizer.zero_grad()
        y_hat = model(x_mel)
        pred_loss = criterion(y_hat, y_label)
        pred_loss.backward()
        optimizer.step()

        if x_idx > 2:
            history1.append(pred_loss.item())
        if x_idx % 60 == 0:
            print(f"Epoch[{epoch_id}], mtid pred loss:{pred_loss.item():.4f}")
    if epoch_id >= configs["fit"]["start_scheduler_epoch"]:
        scheduler.step()

    # ---------------------------
    # -----------SAVE------------
    # ---------------------------
    plt.figure(0)
    plt.plot(range(len(history1)), history1, c="green", alpha=0.7)
    plt.savefig(run_save_dir + f'cls_loss_iter_{epoch_id}.png')
    plt.close()
    # if epoch > 6 and epoch % 2 == 0:
    os.makedirs(run_save_dir + f"model_epoch_{epoch_id}/", exist_ok=True)
    tmp_model_path = "{model}model_{epoch}.pth".format(
        model=run_save_dir + f"model_epoch_{epoch_id}/",
        epoch=epoch_id)
    torch.save(model.state_dict(), tmp_model_path)
    # ---------------------------
    # -----------TEST------------
    # ---------------------------
    model.eval()
    heatmap_input = None
    labels = None
    for x_idx, (x_wav, y_label, _) in enumerate(tqdm(valid_loader, desc="Validate")):
        x_wav = x_wav.to(device)
        x_mel = encoder(x_wav).transpose(1,2).unsqueeze(1)
        print(x_mel.shape)
        y_label = torch.tensor(y_label, device=device)
        
        y_pred = model(x_mel)
        pred_loss = criterion(y_pred, y_label)
        
        if x_idx == 0:
            heatmap_input, labels = y_pred, y_label
        else:
            heatmap_input = torch.concat((heatmap_input, y_pred), dim=0)
            labels = torch.concat((labels, y_label), dim=0)
        # if x_idx * configs["fit"]["batch_size"] > 800:
        #     break
    print("heatmap_input shape:", heatmap_input.shape)
    print("lables shape:", labels.shape)
    # if epoch > 3:
    #     self.plot_reduction(resume_path="", load_epoch=epoch, reducers=["heatmap"])
    heatmap_input = heatmap_input.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    calc_accuracy(pred_matrix=heatmap_input, label_vec=labels,
                  save_path=run_save_dir + f"/accuracy_epoch_{epoch_id}.txt")
    plot_heatmap(pred_matrix=heatmap_input, label_vec=labels,
                 ticks=["healthy", "symptomatic", "COVID-19"],
                 save_path=run_save_dir + f"/heatmap_epoch_{epoch_id}.png")
print("============== END TRAINING ==============")

  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<01:01,  1.52it/s]

Epoch[0], mtid pred loss:0.5371


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.67it/s]

Epoch[0], mtid pred loss:0.5383


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.49it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.29it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.81it/s]

torch.Size([64, 1, 298, 512])
torch.Size([41, 1, 298, 512])





heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.32996632996632996
precision: ['0.3491', '0.3200', '0.3190']
recall: ['0.3737', '0.2424', '0.3737']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.76it/s]

Epoch[1], mtid pred loss:0.4825


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.74it/s]

Epoch[1], mtid pred loss:0.4817


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.72it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.08it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.44it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.36363636363636365
precision: ['0.3920', '0.3678', '0.3176']
recall: ['0.4949', '0.3232', '0.2727']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.92it/s]

Epoch[2], mtid pred loss:0.4184


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.66it/s]

Epoch[2], mtid pred loss:0.4286


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.80it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.02it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.16it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.367003367003367
precision: ['0.3852', '0.3875', '0.3263']
recall: ['0.4747', '0.3131', '0.3131']


  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<00:14,  6.43it/s]

Epoch[3], mtid pred loss:0.4295


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.73it/s]

Epoch[3], mtid pred loss:0.4205


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.69it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 19.05it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.61it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.38047138047138046
precision: ['0.3879', '0.4062', '0.3412']
recall: ['0.4545', '0.3939', '0.2929']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.65it/s]

Epoch[4], mtid pred loss:0.3601


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:05,  6.59it/s]

Epoch[4], mtid pred loss:0.3534


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.69it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.85it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.82it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.39057239057239057
precision: ['0.4153', '0.3854', '0.3614']
recall: ['0.4949', '0.3737', '0.3030']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.82it/s]

Epoch[5], mtid pred loss:0.3622


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.70it/s]

Epoch[5], mtid pred loss:0.3654


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.82it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.24it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.49it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.3569023569023569
precision: ['0.3667', '0.3933', '0.3068']
recall: ['0.4444', '0.3535', '0.2727']


  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<00:14,  6.69it/s]

Epoch[6], mtid pred loss:0.3393


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:05,  6.39it/s]

Epoch[6], mtid pred loss:0.3054


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.71it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.12it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.58it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.38047138047138046
precision: ['0.3879', '0.4286', '0.3222']
recall: ['0.4545', '0.3939', '0.2929']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.66it/s]

Epoch[7], mtid pred loss:0.3203


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.81it/s]

Epoch[7], mtid pred loss:0.4361


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.72it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.43it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.65it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.42424242424242425
precision: ['0.3952', '0.4508', '0.6250']
recall: ['0.6667', '0.5556', '0.0505']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.97it/s]

Epoch[8], mtid pred loss:0.3582


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.73it/s]

Epoch[8], mtid pred loss:0.4629


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.81it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.08it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.19it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.38047138047138046
precision: ['0.5200', '0.3541', '0.6000']
recall: ['0.1313', '0.9192', '0.0909']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.81it/s]

Epoch[9], mtid pred loss:0.3903


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.77it/s]

Epoch[9], mtid pred loss:0.3844


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.78it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.24it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.18it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.3872053872053872
precision: ['0.3571', '0.4245', '0.5556']
recall: ['0.6566', '0.4545', '0.0505']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.77it/s]

Epoch[10], mtid pred loss:0.3503


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.72it/s]

Epoch[10], mtid pred loss:0.3885


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.77it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.86it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.11it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.4276094276094276
precision: ['0.4118', '0.4354', '0.5000']
recall: ['0.5657', '0.6465', '0.0707']


  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<00:14,  6.54it/s]

Epoch[11], mtid pred loss:0.2855


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.78it/s]

Epoch[11], mtid pred loss:0.2866


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.78it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.86it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.90it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.4444444444444444
precision: ['0.4359', '0.4151', '0.7143']
recall: ['0.5152', '0.6667', '0.1515']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  9.06it/s]

Epoch[12], mtid pred loss:0.2520


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.76it/s]

Epoch[12], mtid pred loss:0.2327


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.79it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.24it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.18it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.39730639730639733
precision: ['0.3879', '0.3983', '0.5000']
recall: ['0.6465', '0.4747', '0.0707']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.83it/s]

Epoch[13], mtid pred loss:0.2759


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.75it/s]

Epoch[13], mtid pred loss:0.2830


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.77it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.86it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.11it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.3838383838383838
precision: ['0.4328', '0.3636', '0.4062']
recall: ['0.2929', '0.7273', '0.1313']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  9.10it/s]

Epoch[14], mtid pred loss:0.1793


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.74it/s]

Epoch[14], mtid pred loss:0.2817


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.77it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.73it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.47it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.4074074074074074
precision: ['0.4495', '0.3704', '0.4615']
recall: ['0.4949', '0.6061', '0.1212']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  9.00it/s]

Epoch[15], mtid pred loss:0.1573


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:05,  6.58it/s]

Epoch[15], mtid pred loss:0.2485


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.70it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.70it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.88it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.36363636363636365
precision: ['0.3605', '0.3537', '0.3953']
recall: ['0.6263', '0.2929', '0.1717']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.74it/s]

Epoch[16], mtid pred loss:0.1850


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.90it/s]

Epoch[16], mtid pred loss:0.1830


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.69it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.12it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.04it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.43097643097643096
precision: ['0.4500', '0.4037', '0.5625']
recall: ['0.5455', '0.6566', '0.0909']


  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<00:14,  6.53it/s]

Epoch[17], mtid pred loss:0.1537


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.72it/s]

Epoch[17], mtid pred loss:0.1782


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.79it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.09it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.83it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.44107744107744107
precision: ['0.5385', '0.3960', '0.5294']
recall: ['0.4242', '0.8081', '0.0909']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.88it/s]

Epoch[18], mtid pred loss:0.1324


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.73it/s]

Epoch[18], mtid pred loss:0.1482


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.76it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 16.81it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.17it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.4107744107744108
precision: ['0.4714', '0.3738', '0.5714']
recall: ['0.3333', '0.7778', '0.1212']


  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<00:14,  6.30it/s]

Epoch[19], mtid pred loss:0.1399


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:05,  6.44it/s]

Epoch[19], mtid pred loss:0.1520


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.48it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.18it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.30it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.4276094276094276
precision: ['0.4884', '0.3927', '0.5000']
recall: ['0.4242', '0.7576', '0.1010']


  y_label = torch.tensor(y_label, device=device)
Training:   1%|▊                                                                        | 1/95 [00:00<00:14,  6.35it/s]

Epoch[20], mtid pred loss:0.1513


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.75it/s]

Epoch[20], mtid pred loss:0.1332


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.75it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.83it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.22it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.43434343434343436
precision: ['0.4787', '0.4033', '0.5000']
recall: ['0.4545', '0.7374', '0.1111']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.95it/s]

Epoch[21], mtid pred loss:0.1210


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:04,  6.63it/s]

Epoch[21], mtid pred loss:0.1276


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.61it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 17.14it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 18.76it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.43434343434343436
precision: ['0.4386', '0.4337', '0.4118']
recall: ['0.5051', '0.7273', '0.0707']


  y_label = torch.tensor(y_label, device=device)
Training:   3%|██▎                                                                      | 3/95 [00:00<00:10,  8.83it/s]

Epoch[22], mtid pred loss:0.1055


Training:  65%|██████████████████████████████████████████████▉                         | 62/95 [00:09<00:05,  6.38it/s]

Epoch[22], mtid pred loss:0.1565


Training: 100%|████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.65it/s]
  y_label = torch.tensor(y_label, device=device)
Validate:  40%|█████████████████████████████▌                                            | 2/5 [00:00<00:00, 18.35it/s]

torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])
torch.Size([64, 1, 298, 512])


Validate: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 19.25it/s]


torch.Size([41, 1, 298, 512])
heatmap_input shape: torch.Size([297, 3])
lables shape: torch.Size([297])
(297, 3)
acc: 0.468013468013468
precision: ['0.5326', '0.4407', '0.4286']
recall: ['0.4949', '0.7879', '0.1212']


# End