In [1]:
import torch.nn as nn
import torch.optim as optim
import torch_geometric.utils
import torch, os, torch_geometric
import matplotlib.pyplot as plt
import datasets
from torch.utils.data import random_split, DataLoader
import utils, math

N, K, query_size = 8, 3, 1
dataset_transform = 'spectrogram'
train_test_ratio = 0.1

args = utils.Args(seed=42, N=N, K=K, query_size=query_size, transform=dataset_transform, dis_calcu='euci')
if args.device != "cpu":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
utils.global_manual_seed(seed=args.seed, device=args.device)
split_generator = torch.Generator().manual_seed(args.seed)
train_generator = torch.Generator().manual_seed(args.seed)
test_generator = torch.Generator().manual_seed(args.seed)

# 加载数据集
te_trainset = datasets.TennesseeEastman(root_dir='./data/te', N=N, K=K, query_size=query_size, is_training=True,
                                        selected_faults=[1, 4, 5, 6, 7, 8, 9, 10], seed=args.seed)
print(f"数据集大小: {len(te_trainset)}")
train_size = math.ceil(train_test_ratio * len(te_trainset))
test_size = len(te_trainset) - train_size
train_set, test_set = random_split(te_trainset, [train_size, test_size], generator=split_generator)
train_loader = DataLoader(train_set, batch_size=1, num_workers=num_workers, pin_memory=pin_memory, generator=train_generator)
test_loader = DataLoader(test_set, batch_size=1, num_workers=num_workers, pin_memory=pin_memory, generator=test_generator)
print(f"Train Size:{train_size}  |  Test Size:{test_size}")

  from .autonotebook import tqdm as notebook_tqdm


数据集大小: 125
Train Size:13  |  Test Size:112


In [4]:
import nets
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse

class EnsembleGCN(nn.Module):
    """
    GCN架构:
    将time_features送入一个time_GCN, 将freq_features送入一个frea_GCN,
    得到图卷积后的两种特征形状为 [n_samples, time_gcn_embed_sz] 和 [n_samples, freq_gcn_embed_sz]
    连接两种特征 -> features.shape = [n_samples ,time_gcn_embed_sz+freq_gcn_embed_sz]
    连接后的特征送入一个GCN, 再次卷积 -> [n_samples, gcn_embed_sz] -> [Linear] -> n_classes
    """
    def __init__(self, num_time_features_inputs, num_freq_features_inputs, 
                 time_embed_size, freq_embed_size, embed_size, num_classes_output, dis_calcu):
        super().__init__()
        self.time_gconv = GCNConv(num_time_features_inputs, time_embed_size)
        self.freq_gconv = GCNConv(num_freq_features_inputs, freq_embed_size)
        self.concat_gconv = GCNConv(time_embed_size+freq_embed_size+num_classes_output, embed_size)
        self.out = nn.Linear(embed_size, num_classes_output)
        self.dis_calcu = dis_calcu

    def to_one_hot(self, labels, num_classes, query_size):
        num_samples = labels.shape[0]
        one_hot = torch.zeros((num_samples, num_classes), device=labels.device)
        one_hot[:num_samples-query_size, labels] = 1

        return one_hot

    def graph_construction(self, feature_vecs, num_classes=args.N, dis_calcu='euci'):
        """内嵌于EnsembleGCN的更新连接权重的函数"""
        num_nodes, vec_len = feature_vecs.shape[0], feature_vecs.shape[1]
        adj_matrix = torch.zeros((num_nodes, num_nodes))
        distance_functions = {
                            'euci': lambda x, y: 1 / (utils.euclidean_distance(x, y, squared=False) + torch.tensor(1e-4)),
                            'mah': lambda x, y: 1 / (utils.mahalanobis_distance(x, y, feature_vecs[:, :vec_len-num_classes]) + torch.tensor(1e-4)),
                            'cos': lambda x, y: 1 / (utils.cosine_distance(x, y) + torch.tensor(1e-4)),
                             }
        distance_function = distance_functions[dis_calcu]
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i == j:
                    continue
                adj_matrix[i][j] = distance_function(feature_vecs[i, :vec_len-num_classes], feature_vecs[j, :vec_len-num_classes])
        _, edge_weight = dense_to_sparse(adj_matrix)
        return edge_weight

    def forward(self, time_features, edge_index, time_edge_weight,
                freq_features, freq_edge_weight, labels, num_classes, query_size):
        time_embedding = F.leaky_relu(self.time_gconv(time_features, edge_index, edge_weight=time_edge_weight))
        freq_embedding = F.leaky_relu(self.freq_gconv(freq_features, edge_index, edge_weight=freq_edge_weight))
        concated_features = torch.concat([time_embedding, freq_embedding, self.to_one_hot(labels, num_classes, query_size)], dim=-1)
        with torch.no_grad():
            concated_edge_weight = self.graph_construction(concated_features,  num_classes, dis_calcu=self.dis_calcu).to(time_embedding.device)
        embedding = F.leaky_relu(self.concat_gconv(concated_features, edge_index, edge_weight=concated_edge_weight))
        out = self.out(embedding)
        return out
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
class EnsembleNet(nn.Module):
    def __init__(self, time_feature_extractor, freq_feature_extractor, gcn, dis_calcu):
        super().__init__()
        self.time_feature_extractor = time_feature_extractor
        self.freq_feature_extractor = freq_feature_extractor
        self.gcn = gcn
        self.dis_calcu = dis_calcu

    def graph_construction(self, feature_vecs, dis_calcu='euci'):
        """
        返回edge_index和edge_weight
        dis_calcu参数为向量距离的计算方式, 可选['euci', 'mah', 'cos'] 分别为欧氏, 马氏, 余弦距离
        """
        num_nodes = feature_vecs.shape[0]
        adj_matrix = torch.zeros((num_nodes, num_nodes))
        distance_functions = {
                            'euci': lambda x, y: 1 / (utils.euclidean_distance(x, y, squared=False) + torch.tensor(1e-4)),
                            'mah': lambda x, y: 1 / (utils.mahalanobis_distance(x, y, feature_vecs) + torch.tensor(1e-4)),
                            'cos': lambda x, y: 1 / (utils.cosine_distance(x, y) + torch.tensor(1e-4)),
                             }
        distance_function = distance_functions[dis_calcu]
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i == j:
                    continue
                adj_matrix[i][j] = distance_function(feature_vecs[i], feature_vecs[j])
                    
        edge_index, edge_weight = dense_to_sparse(adj_matrix)
        edge_index = edge_index.to(feature_vecs.device)
        edge_weight = edge_weight.to(feature_vecs.device)
        return edge_index, edge_weight

    def forward(self, waveforms, spectrograms, labels, num_classes, query_size):
        time_features = self.time_feature_extractor(waveforms)
        freq_features = self.freq_feature_extractor(spectrograms)
        edge_index, time_edge_weight = self.graph_construction(time_features, dis_calcu=self.dis_calcu)
        _, freq_edge_weight = self.graph_construction(freq_features, dis_calcu=self.dis_calcu)
        out = self.gcn(time_features, edge_index, time_edge_weight,
                       freq_features, freq_edge_weight, labels, num_classes, query_size)
        return out

dis_calcu = 'euci'
time_cnn_num_hidden_channels = 64
freq_cnn_num_hidden_channels = 32

time_cnn = nets.AudioCNN(num_channels_input=52, 
                        num_channels_hidden=time_cnn_num_hidden_channels,
                        first_kernel_size=8, other_kernel_size=2, stride=4, max_pool_size=2,
                        num_classes_output=args.N)
freq_cnn = nets.AudioCNN2D(num_channels_input=52,
                        num_channels_hidden=freq_cnn_num_hidden_channels,
                        first_kernel_size=(3, 8), other_kernel_size=2, stride=1, avg_pool_size=1,
                        num_classes_output=args.N)

time_embed_size, freq_embed_size = 256, 192
gcn_embed_size = 128
ensemble_gcn = EnsembleGCN(num_time_features_inputs=time_cnn_num_hidden_channels*2,
                    num_freq_features_inputs=freq_cnn_num_hidden_channels*2,
                    time_embed_size=time_embed_size, freq_embed_size=freq_embed_size,
                    embed_size=gcn_embed_size, num_classes_output=args.N,
                    dis_calcu=dis_calcu)

ensemble_net = EnsembleNet(time_feature_extractor=time_cnn,
                          freq_feature_extractor=freq_cnn,
                          gcn=ensemble_gcn,
                          dis_calcu=dis_calcu)

In [5]:
def train(ensemble_net, train_loader, num_epochs, lr, args):

    def init_weights(module): 
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(module.weight)
        elif isinstance(module, torch_geometric.nn.GCNConv):
            nn.init.xavier_uniform_(module.lin.weight)

    def get_query_acc(y_hat, y, query_size):
        return torch.sum(y_hat[-query_size*args.N:].argmax(dim=1) == y[-query_size*args.N:]) / (query_size*args.N)

    def get_acc(y_hat, y):
        return torch.sum(y_hat.argmax(dim=1) == y) / len(y)
    
    ensemble_net.to(args.device)
    ensemble_net.apply(init_weights)
    ensemble_net.train()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ensemble_net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)

    for epoch in range(1, num_epochs+1):
        metric = utils.Accumulator(4) # 一个epoch单样本损失平均值 一个epoch的query正确率 一个epoch的总正确率 一个epoch的批量数
        for waveforms, spectrograms, labels in train_loader:
            labels = labels.squeeze(0).to(args.device)
            waveforms = waveforms.squeeze(0).to(args.device)
            spectrograms = spectrograms.squeeze(0).to(args.device)
            
            optimizer.zero_grad()
            y_hat = ensemble_net(waveforms, spectrograms, labels, args.N, args.query_size)
            loss = loss_function(y_hat, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            metric.add(loss.item()/len(labels), 
                       get_query_acc(y_hat, labels, args.query_size), get_acc(y_hat, labels), 1)
            del spectrograms, waveforms, labels
            torch.cuda.empty_cache()
        print(f"Epoch:{epoch}  \t  Loss:{metric[0]:.7f}  \t  Query正确率:{metric[1]/metric[3]*100:.2f}%  \t  整体正确率:{metric[2]/metric[3]*100:.2f}%")

train(ensemble_net, train_loader, num_epochs=100, lr=0.001, args=args)

Epoch:1  	  Loss:0.5376588  	  Query正确率:38.46%  	  整体正确率:47.84%
Epoch:2  	  Loss:0.2319626  	  Query正确率:54.81%  	  整体正确率:78.85%
Epoch:3  	  Loss:0.1140551  	  Query正确率:80.77%  	  整体正确率:86.30%
Epoch:4  	  Loss:0.0783247  	  Query正确率:80.77%  	  整体正确率:87.74%
Epoch:5  	  Loss:0.0566993  	  Query正确率:97.12%  	  整体正确率:97.84%
Epoch:6  	  Loss:0.0654364  	  Query正确率:89.42%  	  整体正确率:92.55%
Epoch:7  	  Loss:0.0618531  	  Query正确率:80.77%  	  整体正确率:90.87%
Epoch:8  	  Loss:0.0267133  	  Query正确率:100.00%  	  整体正确率:99.04%
Epoch:9  	  Loss:0.0139702  	  Query正确率:99.04%  	  整体正确率:99.76%
Epoch:10  	  Loss:0.0036176  	  Query正确率:100.00%  	  整体正确率:100.00%
Epoch:11  	  Loss:0.0014318  	  Query正确率:100.00%  	  整体正确率:100.00%
Epoch:12  	  Loss:0.0009253  	  Query正确率:100.00%  	  整体正确率:100.00%
Epoch:13  	  Loss:0.0006832  	  Query正确率:100.00%  	  整体正确率:100.00%
Epoch:14  	  Loss:0.0005582  	  Query正确率:100.00%  	  整体正确率:100.00%
Epoch:15  	  Loss:0.0004947  	  Query正确率:100.00%  	  整体正确率:100.00%
Epoch:16  	  Loss:0.0

In [6]:
torch.cuda.empty_cache()
def test(ensemble_net, test_loader, args:utils.Args):
    def get_query_acc(y_hat, y, query_size):
            return torch.sum(y_hat[-query_size*args.N:].argmax(dim=1) == y[-query_size*args.N:]) / (query_size*args.N)
    def get_acc(y_hat, y):
        return torch.sum(y_hat.argmax(dim=1) == y) / len(y)
    
    ensemble_net.to(args.device)
    ensemble_net.eval()
    accs = []
    query_accs = []

    with torch.no_grad():
        for test_idx, (waveforms, spectrograms, labels) in enumerate(test_loader):
            labels = labels.squeeze(0).to(args.device)
            waveforms = waveforms.squeeze(0).to(args.device)
            spectrograms = spectrograms.squeeze(0).to(args.device)
            y_hat = ensemble_net(waveforms, spectrograms, labels, args.N, args.query_size)
            accs.append(get_acc(y_hat, labels))
            query_accs.append(get_query_acc(y_hat, labels, args.query_size))
            del spectrograms, waveforms, labels
            torch.cuda.empty_cache()
            print(f"第{test_idx}批测试  |  Query正确率:{query_accs[-1]*100:.2f}%  |  整体正确率:{accs[-1]*100:.2f}%")
    print(f"平均Query正确率:{sum(query_accs)/len(query_accs)*100:.2f}%  |  平均整体正确率:{sum(accs)/len(accs)*100:.2f}%")
    
test(ensemble_net, test_loader, args)

第0批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第1批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第2批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第3批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第4批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第5批测试  |  Query正确率:87.50%  |  整体正确率:96.88%
第6批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第7批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第8批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第9批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第10批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第11批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第12批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第13批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第14批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第15批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第16批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第17批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第18批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第19批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第20批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
第21批测试  |  Query正确率:100.00%  |  整体正确率:100.00%
