In [1]:
import sys
sys.path.append(r'C:/Program Files (zk)/PythonFiles/AClassification/AudioClassification-Pytorch-KZhao/')

In [2]:
import os
import yaml
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from ackit.data_utils.collate_fn import collate_fn
from ackit.trainer_setting import get_model
from ackit.utils.utils import load_ckpt
from ackit.data_utils.coughvid_reader import CoughVID_Lists, CoughVID_Dataset
from ackit.data_utils.featurizer import Wave2Mel
from ackit.utils.plotter import calc_accuracy, plot_heatmap

### 基本参数设置

In [15]:
configs = {
    "run_save_dir": "../runs/tdnn_coughvid/",
    "model":{
        "num_class": 3,
        "input_length": 94,
        "wav_length": 48000,
        "input_dim": 512,
        "n_mels": 80,
        },
    "fit":{
        "batch_size": 64,
        "epochs" : 23,
        "start_scheduler_epoch": 6
        },
}

# istrain: 如果是评估环节，设为False，读取测试集，并且不创建optimizer
# isdemo: 如果只是测试一下，设为True，仅读取32条数据方便快速测试是否有bug
istrain, isdemo = True, False

configs = "../configs/tdnn_coughvid.yaml"
with open(configs) as stream:
    configs = yaml.safe_load(stream)
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
num_epoch = configs["fit"]["epochs"]
timestr = time.strftime("%Y%m%d%H%M", time.localtime())
if istrain:
    run_save_dir = configs["run_save_dir"] + timestr + f'_tdnn/'
if not isdemo:
    os.makedirs(run_save_dir, exist_ok=True)
train_dataset, valid_dataset = None, None
train_loader, valid_loader = None, None

In [None]:
trp, trl, vap, val = CoughVID_Lists(filename="./datasets/waveinfo_annotation.csv", istrain=True, isdemo=False)

## 极为耗时的一句，读取数据

In [5]:
train_dataset = CoughVID_Dataset(path_list=trp, label_list=trl)
valid_dataset = CoughVID_Dataset(path_list=vap, label_list=val)

  samples, sample_rate = librosa.core.load(file)  # , dtype='float32')
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
  samples, sample_rate = librosa.core.load(file)  # , dtype='float32')
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
  samples, sample_rate = librosa.core.load(file)  # , dtype='float32')
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Loading: 100%|███████████████████████████████████████████████████████████████████| 15085/15085 [17:19<00:00, 14.52it/s]


In [7]:
# dataloader
train_loader = DataLoader(train_dataset, batch_size=self.configs["fit"]["batch_size"], shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=self.configs["fit"]["batch_size"], shuffle=True, collate_fn=collate_fn)
# data_transform
w2m = Wave2Mel(sr=16000, n_mels=80)

In [17]:
# model loss_function optimizer scheduler
model = get_model("tdnn", configs, istrain=True).to(device)
cls_loss = nn.CrossEntropyLoss().to(device)
print("All model and loss are on device:", device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=5e-5)

All model and loss are on device: cuda


# train and valid

In [None]:
from ackit.utils.plotter import calc_accuracy, plot_heatmap

In [19]:
history1 = []
for epoch_id in range(configs["fit"]["epochs"]):
    # ---------------------------
    # -----------TRAIN-----------
    # ---------------------------
    model.train()
    for x_idx, (x_wav, y_label, _) in enumerate(tqdm(train_loader, desc="Training")):
        x_mel = w2m(x_wav).to(device)
        y_label = torch.tensor(y_label, device=device)
        # print("shape of x_mel:", x_mel.shape)
        optimizer.zero_grad()
        y_pred, _ = model(x=x_mel)
        # recon_loss = self.recon_loss(recon_spec, x_mel)
        pred_loss = cls_loss(y_pred, y_label)
        pred_loss.backward()
        optimizer.step()

        if x_idx > 2:
            history1.append(pred_loss.item())
        if x_idx % 60 == 0:
            print(f"Epoch[{epoch_id}], mtid pred loss:{pred_loss.item():.4f}")
    if epoch_id >= configs["fit"]["start_scheduler_epoch"]:
        scheduler.step()

    # ---------------------------
    # -----------SAVE------------
    # ---------------------------
    plt.figure(0)
    plt.plot(range(len(history1)), history1, c="green", alpha=0.7)
    plt.savefig(run_save_dir + f'cls_loss_iter_{epoch_id}.png')
    plt.close()
    # if epoch > 6 and epoch % 2 == 0:
    os.makedirs(run_save_dir + f"model_epoch_{epoch_id}/", exist_ok=True)
    tmp_model_path = "{model}model_{epoch}.pth".format(
        model=run_save_dir + f"model_epoch_{epoch_id}/",
        epoch=epoch_id)
    torch.save(model.state_dict(), tmp_model_path)
    # ---------------------------
    # -----------TEST------------
    # ---------------------------
    model.eval()
    heatmap_input = None
    labels = None
    for x_idx, (x_wav, y_label, _) in enumerate(tqdm(train_loader, desc="Test")):
        x_mel = w2m(x_wav).to(device)
        y_label = torch.tensor(y_label, device=device)
        y_pred, _ = model(x=x_mel)
        if x_idx == 0:
            heatmap_input, labels = y_pred, y_label
        else:
            heatmap_input = torch.concat((heatmap_input, y_pred), dim=0)
            labels = torch.concat((labels, y_label), dim=0)
        if x_idx * configs["fit"]["batch_size"] > 800:
            break
    print("heatmap_input shape:", heatmap_input.shape)
    print("lables shape:", labels.shape)
    # if epoch > 3:
    #     self.plot_reduction(resume_path="", load_epoch=epoch, reducers=["heatmap"])
    heatmap_input = heatmap_input.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    calc_accuracy(pred_matrix=heatmap_input, label_vec=labels,
                  save_path=run_save_dir + f"/accuracy_epoch_{epoch_id}.png")
    plot_heatmap(pred_matrix=heatmap_input, label_vec=labels,
                 ticks=["healthy", "symptomatic", "COVID-19"],
                 save_path=run_save_dir + f"/heatmap_epoch_{epoch_id}.png")
print("============== END TRAINING ==============")

  y_label = torch.tensor(y_label, device=device)
Training:   1%|▉                                                                       | 3/236 [00:00<00:14, 15.62it/s]

Epoch[0], mtid pred loss:1.4905


Training:  26%|██████████████████▋                                                    | 62/236 [00:02<00:08, 21.06it/s]

Epoch[0], mtid pred loss:1.0240


Training:  52%|████████████████████████████████████▏                                 | 122/236 [00:06<00:06, 16.90it/s]

Epoch[0], mtid pred loss:1.0846


Training:  76%|█████████████████████████████████████████████████████▍                | 180/236 [00:10<00:04, 12.01it/s]

Epoch[0], mtid pred loss:0.9461


Training: 100%|██████████████████████████████████████████████████████████████████████| 236/236 [00:13<00:00, 17.97it/s]
  y_label = torch.tensor(y_label, device=device)
Test:   0%|                                                                                    | 0/236 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [13]:
model.eval()
tsne_input = None
heatmap_input = None
labels = None
for x_idx, (x_wav, y_label, _) in enumerate(tqdm(train_loader, desc="Test")):
    x_mel = w2m(x_wav).to(device)
    y_label = torch.tensor(y_label, device=device)
    y_pred, featmap = model(x=x_mel)
    if x_idx == 0:
        heatmap_input, labels, tsne_input = y_pred, y_label, featmap
    else:
        heatmap_input = torch.concat((heatmap_input, y_pred), dim=0)
        labels = torch.concat((labels, y_label), dim=0)
        tsne_input = torch.concat((tsne_input, featmap), dim=0)
    if x_idx * configs["fit"]["batch_size"] > 800:
        break
print("heatmap_input shape:", heatmap_input.shape)
print("lables shape:", labels.shape)
# plot_reduction(resume_path="", load_epoch=epoch, reducers=["heatmap"])
heatmap_input = heatmap_input.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
tsne_input = tsne_input.detach().cpu().numpy()

torch.Size([64, 80, 94])