In [None]:
import sys
import os
import time
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.optim as optim
sys.path.append('../')
sys.path.append('C:/Program Files (zk)/PythonFiles/AClassification/Heart-Sound-Diagnosis/')
from models.classifiers import LSTM_Attn_Classifier
from models.mobilefacenet import MobileFaceNet
from datautils.PhysioNet2016Dataset import get_loaders

In [None]:
# Reproducibility: seed torch (model init, random smoke-test inputs, and - presumably -
# any torch-RNG-based shuffling inside get_loaders; TODO confirm) and NumPy.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build a torch.device on both branches so `device` always has the same type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_loaders()
class_loss = nn.CrossEntropyLoss().to(device)

In [None]:
# Build the classifier and run one forward pass on random data to check output shapes.
cl_model = MobileFaceNet(inp_c=1, input_dim=87, latent_size=(6, 8), num_class=2, inp=1).to(device)
# Dummy batch in the 4-D (batch, channel, 87, 128) layout the training loop feeds the model.
# Which of the 87 / 128 axes is time vs. frequency is not visible here - TODO confirm.
x = torch.randn(16, 1, 87, 128, device=device)
label = torch.randint(low=0, high=2, size=(16,), device=device)
# eval() + no_grad(): a train-mode pass on random noise would let normalization layers
# (if the model has any) update their running statistics before training starts.
# The training loop switches back with cl_model.train() at the start of every epoch.
cl_model.eval()
with torch.no_grad():
    tmp_pred, tmp_feat = cl_model(x, label)
print(tmp_pred.shape, tmp_feat.shape)

In [None]:
# Number of epochs; the scheduler is stepped once per epoch in the training loop.
iter_max = 1000
warm_up_iter, T_max, lr_max, lr_min = 30, iter_max // 3, 5e-4, 5e-5


# Warm-up + cosine-annealing learning-rate rule for LambdaLR.
# reference: https://blog.csdn.net/qq_36560894/article/details/114004799
def lambda0(cur_iter):
    """Return the LambdaLR multiplier for epoch `cur_iter`.

    LambdaLR sets lr = base_lr * lambda0(cur_iter), so this returns target_lr / base_lr.
    The optimizer below uses base_lr = lr_max. (The reference divided by 0.1 because its
    base lr was 0.1; with base lr 5e-4 that scaled the whole schedule down to a 2.5e-6 peak.)

    - cur_iter < warm_up_iter: linear ramp from 0 to lr_max (epoch 0 therefore runs at lr 0).
    - afterwards: cosine between lr_max and lr_min with half-period (T_max - warm_up_iter).
      Past T_max the cosine keeps going, so the lr climbs back toward lr_max.
    """
    if cur_iter < warm_up_iter:
        return cur_iter / warm_up_iter
    progress = (cur_iter - warm_up_iter) / (T_max - warm_up_iter)
    target_lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(progress * math.pi))
    return target_lr / lr_max


optimizer = optim.Adam(cl_model.parameters(), lr=lr_max)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda0)
datetimestr = time.strftime("%Y%m%d%H%M", time.localtime())
# Written to setting.txt with the first checkpoint; keep it in sync with the model actually built.
setting_content = "MobileFaceNet, adam LambdaLR 0 ~ 5e-4 ~ 5e-5."
run_save_dir = "./ckpt/physionet/" + datetimestr + '_/'

In [None]:
# Train for `iter_max` epochs, evaluating on test_loader after each one.
# A checkpoint is written only when validation accuracy beats the best so far
# AND exceeds SAVE_ACC_THRESHOLD.
SAVE_ACC_THRESHOLD = 0.85
old = 0          # best validation accuracy seen so far
STD_acc = []     # validation accuracy per epoch
STD_loss = []    # validation loss per epoch
loss_line = []   # mean training loss per epoch
lr_list = []     # learning rate in effect during each epoch


def _to_model_input(batch_x):
    """Move a loader batch to `device`; give 3-D batches the 4-D layout the model takes.

    A 3-D batch has its last two axes swapped and a singleton channel axis inserted at
    dim 1. Batches of any other rank are returned unchanged (apart from the device move).
    """
    batch_x = batch_x.to(device)
    if batch_x.ndim == 3:
        batch_x = batch_x.transpose(1, 2).unsqueeze(1)
    return batch_x


for epoch_id in tqdm(range(iter_max), desc="Train"):
    # Print shape diagnostics only for the first batch of the first epoch,
    # rather than on every epoch.
    show_shapes = epoch_id == 0

    cl_model.train()
    loss_list = []
    lr_list.append(optimizer.param_groups[0]['lr'])
    for idx, (X_mel, y_mel) in enumerate(train_loader):
        optimizer.zero_grad()
        if show_shapes and idx == 0:
            print(X_mel.shape)
        X_mel = _to_model_input(X_mel)
        y_mel = y_mel.to(device)
        pred, _ = cl_model(x=X_mel, label=y_mel)
        if show_shapes and idx == 0:
            print(X_mel.shape, y_mel.shape, pred.shape)
        loss_v = class_loss(pred, y_mel)
        loss_v.backward()
        loss_list.append(loss_v.item())
        optimizer.step()
    loss_line.append(np.mean(loss_list))

    cl_model.eval()
    n_correct, n_seen, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for X_mel, y_mel in test_loader:
            X_mel = _to_model_input(X_mel)
            y_mel = y_mel.to(device)
            pred, _ = cl_model(x=X_mel, label=y_mel)
            batch_n = y_mel.size(0)
            # Accumulate per-sample totals: averaging per-batch means over-weights a smaller
            # final batch. class_loss is nn.CrossEntropyLoss() with the default 'mean'
            # reduction, so multiplying by the batch size gives the batch's summed loss.
            loss_sum += class_loss(pred, y_mel).item() * batch_n
            # Assumes y_mel holds class indices (as CrossEntropyLoss / argmax compare imply).
            n_correct += (pred.argmax(-1) == y_mel).sum().item()
            n_seen += batch_n
    acc_per = n_correct / n_seen if n_seen else float("nan")
    STD_acc.append(acc_per)
    STD_loss.append(loss_sum / n_seen if n_seen else float("nan"))

    if acc_per > old:
        old = acc_per
        print("new acc:", acc_per)
        if acc_per > SAVE_ACC_THRESHOLD:
            print(f"Epoch[{epoch_id}]: {acc_per}")
            if not os.path.exists(run_save_dir):
                # First checkpoint of this run: create the directory and record the settings.
                os.makedirs(run_save_dir, exist_ok=True)
                with open(run_save_dir + "setting.txt", 'w') as setting_file:
                    setting_file.write(setting_content)
            torch.save(cl_model.state_dict(), run_save_dir + f"cls_model_{epoch_id}.pt")
            torch.save(optimizer.state_dict(), run_save_dir + f"optimizer_{epoch_id}.pt")
    scheduler.step()

In [None]:
# Training/validation curves (left) and the per-epoch learning rate (right).
# plt.subplots creates a fresh figure, so re-running this cell does not stack lines
# onto a previously drawn figure (as re-using plt.figure(0) would).
fig, (ax_metrics, ax_lr) = plt.subplots(1, 2, figsize=(12, 4))
ax_metrics.plot(range(len(loss_line)), loss_line, c="red", label="train_loss")
ax_metrics.plot(range(len(STD_loss)), STD_loss, c="blue", label="valid_loss")
ax_metrics.plot(range(len(STD_acc)), STD_acc, c="green", label="valid_accuracy")
ax_metrics.set_xlabel("iteration")
ax_metrics.set_ylabel("metrics")
ax_metrics.legend()
ax_lr.plot(lr_list)
ax_lr.set_xlabel("iteration")
ax_lr.set_ylabel("learning rate")
# Always save: savefig must not depend on whether run_save_dir already exists
# (training creates it when it writes the first checkpoint).
os.makedirs(run_save_dir, exist_ok=True)
fig.savefig(run_save_dir + "train_result.png", format="png", dpi=300)
plt.show()