In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm
import os

from model import VQVAE
from dataset import ECGDataset
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from plot import extract_features,plot_embedding

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (all devices) so weight initialisation and DataLoader
# shuffling repeat across runs. NOTE(review): some CUDA kernels are still
# nondeterministic; full determinism would also need torch.use_deterministic_algorithms.
SEED = 42
torch.manual_seed(SEED)

In [None]:
def train(epoch, loader, model, optimizer, device, max_len=134, latent_loss_weight=0.25):
    """Run one training epoch and report running losses on a tqdm progress bar.

    Each batch is cropped to ``max_len`` samples along its last axis and passed
    through ``model``, which must return ``(reconstruction, latent_loss)``. The
    optimised objective is
    ``MSE(reconstruction, input) + latent_loss_weight * latent_loss.mean()``.

    Args:
        epoch: Zero-based epoch index (displayed 1-based in the progress bar).
        loader: Iterable of ``(data, label)`` batches; labels are ignored.
            NOTE(review): ``data`` is indexed as ``[:, :, :max_len]``, so it is
            presumably ``(batch, channels, length)`` — confirm against ECGDataset.
        model: Module returning ``(out, latent_loss)`` for an input batch.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the input batches are moved to.
        max_len: Number of samples kept along the last axis, so the input
            matches the sequence length the model produces (default 134).
        latent_loss_weight: Weight of the model's latent-loss term.

    Returns:
        The sample-weighted average reconstruction MSE over the epoch, or
        ``float('nan')`` when ``loader`` yields no batches.
    """
    model.train()
    progress = tqdm(loader)

    criterion = nn.MSELoss()

    mse_sum = 0.0
    mse_n = 0

    for data, _label in progress:
        optimizer.zero_grad()

        # Crop the input so its length matches the model's output length;
        # otherwise `out` and `data` could not be compared element-wise.
        data = data.to(device)[:, :, :max_len]

        out, latent_loss = model(data)
        recon_loss = criterion(out, data)
        # Reduce to a scalar in case the model returns a non-scalar latent loss.
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        optimizer.step()

        # `.item()` forces a device sync, so read the value once per batch.
        recon_value = recon_loss.item()
        batch_size = data.shape[0]
        mse_sum += recon_value * batch_size
        mse_n += batch_size

        lr = optimizer.param_groups[0]["lr"]

        progress.set_description(
            f"epoch: {epoch + 1}; mse: {recon_value:.5f}; "
            f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
            f"lr: {lr:.5f}"
        )

    return mse_sum / mse_n if mse_n else float("nan")

def evaluate(loader, model, device, max_len=134):
    """Compute the sample-weighted reconstruction MSE and MAE of ``model`` over ``loader``.

    Runs under ``torch.no_grad()`` with the model in eval mode. Each batch is
    cropped to ``max_len`` samples along its last axis, matching ``train``, so
    the model output and the input line up element-wise.

    Args:
        loader: Iterable of ``(data, label)`` batches; labels are ignored.
        model: Module returning ``(out, latent_loss)``; only ``out`` is used.
        device: Device the input batches are moved to.
        max_len: Number of samples kept along the last axis (default 134).

    Returns:
        ``(avg_mse, avg_mae)``, both averaged per sample. Both are
        ``float('nan')`` if ``loader`` yields no batches, instead of raising
        ``ZeroDivisionError``.
    """
    model.eval()
    criterion = nn.MSELoss()
    mse_sum = 0.0
    mae_sum = 0.0
    n = 0

    with torch.no_grad():
        for data, _label in loader:
            data = data.to(device)[:, :, :max_len]

            out, _ = model(data)
            batch_size = data.shape[0]
            # Per-batch means are re-weighted by batch size so a short final
            # batch does not count as much as a full one.
            mse_sum += criterion(out, data).item() * batch_size
            mae_sum += torch.mean(torch.abs(out - data)).item() * batch_size
            n += batch_size

    if n == 0:
        avg_mse = avg_mae = float("nan")
    else:
        avg_mse = mse_sum / n
        avg_mae = mae_sum / n
    print(f"Avg MSE: {avg_mse:.5f}, Avg MAE: {avg_mae:.5f}")
    return avg_mse, avg_mae

In [None]:
train_dataset = ECGDataset('./data/ECG5000_TRAIN.txt')
test_dataset = ECGDataset('./data/ECG5000_TEST.txt')

# No validation split is carved out of the training data; the test set is the
# one evaluated periodically during training.
BATCH_SIZE = 232

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: print the shapes of one training batch.
sample_data, sample_labels = next(iter(train_loader))
print(sample_data.shape, sample_labels.shape)


In [None]:
# VQ-VAE architecture hyperparameters, gathered in one place for easy tuning.
vqvae_kwargs = dict(
    in_channel=1,
    channel=64,
    n_res_block=1,
    n_res_channel=16,
    embed_dim=32,
    n_embed=256,
)
model = VQVAE(**vqvae_kwargs).to(device)

# Adam over all model parameters.
learning_rate = 3e-5
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
num_epochs = 5000
eval_every = 100  # save a checkpoint, evaluate, and plot embeddings every this many epochs
os.makedirs("checkpoint", exist_ok=True)
os.makedirs("pca_sample", exist_ok=True)
os.makedirs("tsne_sample", exist_ok=True)


def _make_tsne(n_iterations=300, **kwargs):
    """Build a TSNE estimator that runs ``n_iterations`` optimisation steps.

    scikit-learn 1.5 renamed TSNE's ``n_iter`` argument to ``max_iter`` and later
    releases drop ``n_iter`` entirely. The new keyword is tried first; the old
    one is used only when the installed version rejects it.
    """
    try:
        return TSNE(max_iter=n_iterations, **kwargs)
    except TypeError:
        return TSNE(n_iter=n_iterations, **kwargs)


def _reduce_and_plot(reducer, flat_features, labels, title, filename):
    """Fit ``reducer`` on ``flat_features``, save the 2-D plot via ``plot_embedding``, and return the embedding."""
    embedding = reducer.fit_transform(flat_features)
    plot_embedding(embedding, labels, title, filename)
    return embedding


for epoch in range(num_epochs):
    train(epoch, train_loader, model, optimizer, device)
    if (epoch + 1) % eval_every != 0:
        continue

    torch.save(model.state_dict(), f"checkpoint/vqvae_{str(epoch + 1).zfill(3)}.pt")

    # There is no separate validation split: the held-out test set is evaluated here.
    print(f"Evaluating on test data at epoch {epoch + 1}")
    avg_mse_val, avg_mae_val = evaluate(test_loader, model, device)

    test_features, test_labels = extract_features(test_loader, model, device)
    # Flatten each sample's features into one vector; PCA/TSNE expect 2-D input.
    # NOTE(review): assumes extract_features returns an array-like with .reshape
    # (e.g. a NumPy array) — confirm in plot.py.
    flat_test_features = test_features.reshape(test_features.shape[0], -1)

    test_features_pca = _reduce_and_plot(
        PCA(n_components=2),
        flat_test_features,
        test_labels,
        f'PCA of ECG5000 Embeddings at Epoch {epoch + 1}',
        f"pca_sample/pca_epoch_{epoch + 1}.png",
    )

    # Fixed random_state makes each TSNE frame reproducible across re-runs.
    test_features_tsne = _reduce_and_plot(
        _make_tsne(n_components=2, perplexity=30, random_state=0),
        flat_test_features,
        test_labels,
        f'TSNE of ECG5000 Embeddings at Epoch {epoch + 1}',
        f"tsne_sample/tsne_epoch_{epoch + 1}.png",
    )

In [None]:
import imageio
import os
import re

def create_gif_from_folder(folder_path, output_path, duration=1.0):
    """Assemble the numbered ``.png`` frames in ``folder_path`` into an animated GIF.

    Frames are ordered numerically by the first ``_<digits>`` group in their
    file name (e.g. ``pca_epoch_100.png`` -> 100), so ``..._1000.png`` sorts
    after ``..._200.png``. PNG files without such a group are skipped rather
    than crashing the sort.

    Args:
        folder_path: Directory containing the frame images.
        output_path: Destination path of the GIF.
        duration: Per-frame duration, passed unchanged to ``imageio.mimsave``.
            NOTE(review): newer imageio releases may interpret GIF ``duration``
            in milliseconds rather than seconds — confirm for the installed version.

    Raises:
        ValueError: If ``folder_path`` contains no numbered ``.png`` frames.
    """
    frame_number = re.compile(r'_(\d+)')

    # Collect (number, file name) pairs for every numbered PNG frame.
    frames = []
    for name in os.listdir(folder_path):
        if not name.endswith('.png'):
            continue
        match = frame_number.search(name)
        if match is None:
            continue  # no numeric suffix: not a progress frame, ignore it
        frames.append((int(match.group(1)), name))

    if not frames:
        raise ValueError(f"No numbered .png frames found in {folder_path!r}")

    # Stable sort on the number only, as before, so ties keep directory order.
    frames.sort(key=lambda pair: pair[0])

    images = [imageio.imread(os.path.join(folder_path, name)) for _, name in frames]
    imageio.mimsave(output_path, images, duration=duration)



In [None]:
# Stitch the per-epoch PCA scatter plots into a single animated GIF.
create_gif_from_folder(
    folder_path='pca_sample',
    output_path='./pca_sample/pca_progress.gif',
    duration=3.0,
)

In [None]:
# Stitch the per-epoch TSNE scatter plots into a single animated GIF.
create_gif_from_folder(
    folder_path='tsne_sample',
    output_path='./tsne_sample/tsne_progress.gif',
    duration=3.0,
)