In [1]:
import torch.nn as nn
import torch.optim as optim
import torch_geometric.utils
import torch, os, torch_geometric
import matplotlib.pyplot as plt
import datasets
from torch.utils.data import random_split, DataLoader
import utils, math

mimii_dataset_SNR = 0
N, K, query_size = 16, 3, 1
dataset_transform = 'spectrogram'
train_test_ratio = 0.1

args = utils.Args(seed=42, N=N, K=K, query_size=query_size, SNR=mimii_dataset_SNR, transform=dataset_transform)
if args.device != "cpu":
    num_workers = 1
    if N*K <= 32:
        pin_memory = True
    else:
        pin_memory = False
else:
    num_workers = 0
    pin_memory = False
utils.global_manual_seed(seed=args.seed, device=args.device)
split_generator = torch.Generator().manual_seed(args.seed)
train_generator = torch.Generator().manual_seed(args.seed)
test_generator = torch.Generator().manual_seed(args.seed)

# 加载数据集
mimii_dataset = datasets.MIMII(root_dir=f'./data/mimii/{mimii_dataset_SNR}'+'dB_SNR',
                                machine_classes=['fan', 'pump', 'valve', 'slider'],
                                model_ids=['id_00', 'id_02'], categories=['normal', 'abnormal'],
                                N=N, K=K, query_size=query_size, transform='spectrogram', seed=args.seed)
print(f"数据集大小: {len(mimii_dataset)}")
train_size = math.ceil(train_test_ratio * len(mimii_dataset))
test_size = len(mimii_dataset) - train_size
train_set, test_set = random_split(mimii_dataset, [train_size, test_size], generator=split_generator)
train_loader = DataLoader(train_set, batch_size=1, num_workers=num_workers, pin_memory=pin_memory, generator=train_generator)
test_loader = DataLoader(test_set, batch_size=1, num_workers=num_workers, pin_memory=pin_memory, generator=test_generator)
print(f"Train Size:{train_size}  |  Test Size:{test_size}")
print(mimii_dataset.N_classes)

  from .autonotebook import tqdm as notebook_tqdm


数据集大小: 27
Train Size:3  |  Test Size:24
['fan_id_02_abnormal', 'fan_id_00_normal', 'valve_id_02_abnormal', 'pump_id_00_normal', 'slider_id_02_abnormal', 'slider_id_00_abnormal', 'fan_id_02_normal', 'fan_id_00_abnormal', 'valve_id_00_normal', 'slider_id_00_normal', 'valve_id_02_normal', 'slider_id_02_normal', 'pump_id_02_normal', 'pump_id_00_abnormal', 'valve_id_00_abnormal', 'pump_id_02_abnormal']


In [2]:
import nets
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse

class EnsembleGCN(nn.Module):
    """
    GCN架构:
    将time_features送入一个time_GCN, 将freq_features送入一个frea_GCN,
    得到图卷积后的两种特征形状为 [n_samples, time_gcn_embed_sz] 和 [n_samples, freq_gcn_embed_sz]
    连接两种特征 -> features.shape = [n_samples ,time_gcn_embed_sz+freq_gcn_embed_sz]
    连接后的特征送入一个GCN, 再次卷积 -> [n_samples, gcn_embed_sz] -> [Linear] -> n_classes
    """
    def __init__(self, num_time_features_inputs, num_freq_features_inputs, 
                 time_embed_size, freq_embed_size, embed_size, num_classes_output):
        super().__init__()
        self.time_gconv = GCNConv(num_time_features_inputs, time_embed_size)
        self.freq_gconv = GCNConv(num_freq_features_inputs, freq_embed_size)
        # +num_classes_output
        self.concat_gconv = GCNConv(time_embed_size+freq_embed_size, embed_size)
        self.out = nn.Linear(embed_size, num_classes_output)

    def to_one_hot(self, labels, num_classes, query_size):
        num_samples = labels.shape[0]
        one_hot = torch.zeros((num_samples, num_classes), device=labels.device)
        one_hot[:num_samples-query_size, labels] = 1

        return one_hot

    def graph_construction(self, feature_vecs, num_classes=0):
        """内嵌于EnsembleGCN的更新连接权重的函数"""
        num_nodes, vec_len = feature_vecs.shape[0], feature_vecs.shape[1]
        adj_matrix = torch.zeros((num_nodes, num_nodes))
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i == j:
                    continue
                adj_matrix[i][j] = 1 / (torch.sum(torch.pow(feature_vecs[i, :vec_len-num_classes] - feature_vecs[j, :vec_len-num_classes], 2), 0) + torch.tensor(1e-4))
        _, edge_weight = dense_to_sparse(adj_matrix)
        return edge_weight

    def forward(self, time_features, edge_index, time_edge_weight,
                freq_features, freq_edge_weight):
        # , labels, num_classes, query_size
        time_embedding = F.leaky_relu(self.time_gconv(time_features, edge_index, edge_weight=time_edge_weight))
        freq_embedding = F.leaky_relu(self.freq_gconv(freq_features, edge_index, edge_weight=freq_edge_weight))
        # , self.to_one_hot(labels, num_classes, query_size)
        concated_features = torch.concat([time_embedding, freq_embedding], dim=-1)
        with torch.no_grad():
            # , num_classes
            concated_edge_weight = self.graph_construction(concated_features).to(time_embedding.device)
        embedding = F.leaky_relu(self.concat_gconv(concated_features, edge_index, edge_weight=concated_edge_weight))
        out = self.out(embedding)
        return out
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
class EnsembleNet(nn.Module):
    def __init__(self, time_feature_extractor, freq_feature_extractor, gcn):
        super().__init__()
        self.time_feature_extractor = time_feature_extractor
        self.freq_feature_extractor = freq_feature_extractor
        self.gcn = gcn

    def graph_construction(self, feature_vecs):
        """返回edge_index和edge_weight"""
        num_nodes = feature_vecs.shape[0]
        adj_matrix = torch.zeros((num_nodes, num_nodes))
        for i in range(num_nodes):
            for j in range(num_nodes):
                if i == j:
                    continue
                adj_matrix[i][j] = 1 / (torch.sum(torch.pow(feature_vecs[i] - feature_vecs[j], 2), 0) + torch.tensor(1e-4))
        edge_index, edge_weight = dense_to_sparse(adj_matrix)
        edge_index = edge_index.to(feature_vecs.device)
        edge_weight = edge_weight.to(feature_vecs.device)
        return edge_index, edge_weight

    def forward(self, waveforms, spectrograms):
        time_features = self.time_feature_extractor(waveforms)
        freq_features = self.freq_feature_extractor(spectrograms)
        edge_index, time_edge_weight = self.graph_construction(time_features)
        _, freq_edge_weight = self.graph_construction(freq_features)
        out = self.gcn(time_features, edge_index, time_edge_weight,
                       freq_features, freq_edge_weight)
        return out


time_cnn_num_hidden_channels = 64
freq_cnn_num_hidden_channels = 32

time_cnn = nets.AudioCNN(num_channels_input=8, 
                        num_channels_hidden=time_cnn_num_hidden_channels, 
                        num_classes_output=args.N)
freq_cnn = nets.AudioCNN2D(num_channels_input=8, 
                        num_channels_hidden=freq_cnn_num_hidden_channels, 
                        num_classes_output=args.N)

time_embed_size, freq_embed_size = 384, 256
gcn_embed_size = 128
ensemble_gcn = EnsembleGCN(num_time_features_inputs=time_cnn_num_hidden_channels*2,
                    num_freq_features_inputs=freq_cnn_num_hidden_channels*2,
                    time_embed_size=time_embed_size, freq_embed_size=freq_embed_size,
                    embed_size=gcn_embed_size, num_classes_output=args.N)

ensemble_net = EnsembleNet(time_feature_extractor=time_cnn,
                                freq_feature_extractor=freq_cnn,
                                gcn=ensemble_gcn)

In [3]:
def train(ensemble_net, train_loader, num_epochs, lr, args):

    def init_weights(module): 
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(module.weight)
        elif isinstance(module, torch_geometric.nn.GCNConv):
            nn.init.xavier_uniform_(module.lin.weight)

    def get_query_acc(y_hat, y, query_size):
        return torch.sum(y_hat[-query_size*args.N:].argmax(dim=1) == y[-query_size*args.N:]) / (query_size*args.N)

    def get_acc(y_hat, y):
        return torch.sum(y_hat.argmax(dim=1) == y) / len(y)
    
    ensemble_net.to(args.device)
    ensemble_net.apply(init_weights)
    ensemble_net.train()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ensemble_net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)

    for epoch in range(1, num_epochs+1):
        metric = utils.Accumulator(4) # 一个epoch单样本损失平均值 一个epoch的query正确率 一个epoch的总正确率 一个epoch的批量数
        for support_set, query_set in train_loader:
            waveforms, spectrograms, labels = [], [], []
            for waveform, spectrogram, label in support_set:
                waveforms.append(waveform.squeeze(0))
                spectrograms.append(spectrogram.squeeze(0))
                labels.append(label)
            for waveform, spectrogram, label in query_set:
                waveforms.append(waveform.squeeze(0))
                spectrograms.append(spectrogram.squeeze(0))
                labels.append(label)
            labels = torch.tensor(labels, device=args.device)
            waveforms = torch.stack(waveforms).to(args.device)
            spectrograms = torch.stack(spectrograms).to(args.device)

            optimizer.zero_grad()
            y_hat = ensemble_net(waveforms, spectrograms)
            loss = loss_function(y_hat, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            metric.add(loss.item()/len(labels), 
                       get_query_acc(y_hat, labels, args.query_size), get_acc(y_hat, labels), 1)
            del spectrograms, waveforms, labels
            torch.cuda.empty_cache()
        print(f"Epoch:{epoch}  |  Loss:{metric[0]:.7f}  |  Query正确率:{metric[1]/metric[3]*100:.2f}%  |  整体正确率:{metric[2]/metric[3]*100:.2f}%")

train(ensemble_net, train_loader, num_epochs=100, lr=0.0025, args=args)

Epoch:1  |  Loss:0.1257870  |  Query正确率:14.58%  |  整体正确率:15.10%
Epoch:2  |  Loss:0.1015206  |  Query正确率:35.42%  |  整体正确率:30.73%
Epoch:3  |  Loss:0.0808207  |  Query正确率:50.00%  |  整体正确率:46.35%
Epoch:4  |  Loss:0.0632541  |  Query正确率:54.17%  |  整体正确率:53.12%
Epoch:5  |  Loss:0.0514279  |  Query正确率:68.75%  |  整体正确率:64.58%
Epoch:6  |  Loss:0.0417319  |  Query正确率:79.17%  |  整体正确率:68.75%
Epoch:7  |  Loss:0.0352030  |  Query正确率:72.92%  |  整体正确率:70.31%
Epoch:8  |  Loss:0.0281385  |  Query正确率:85.42%  |  整体正确率:80.21%
Epoch:9  |  Loss:0.0225483  |  Query正确率:87.50%  |  整体正确率:84.90%
Epoch:10  |  Loss:0.0188425  |  Query正确率:85.42%  |  整体正确率:86.46%
Epoch:11  |  Loss:0.0154616  |  Query正确率:87.50%  |  整体正确率:87.50%
Epoch:12  |  Loss:0.0130146  |  Query正确率:95.83%  |  整体正确率:92.71%
Epoch:13  |  Loss:0.0114985  |  Query正确率:91.67%  |  整体正确率:90.10%
Epoch:14  |  Loss:0.0093058  |  Query正确率:91.67%  |  整体正确率:90.62%
Epoch:15  |  Loss:0.0071823  |  Query正确率:97.92%  |  整体正确率:96.35%
Epoch:16  |  Loss:0.0052435  |  Qu

In [4]:
torch.cuda.empty_cache()
def test(ensemble_net, test_loader, args:utils.Args):
    def get_query_acc(y_hat, y, query_size):
            return torch.sum(y_hat[-query_size*args.N:].argmax(dim=1) == y[-query_size*args.N:]) / (query_size*args.N)
    def get_acc(y_hat, y):
        return torch.sum(y_hat.argmax(dim=1) == y) / len(y)
    
    ensemble_net.to(args.device)
    ensemble_net.eval()
    accs = []
    query_accs = []

    with torch.no_grad():
        for test_idx, (support_set, query_set) in enumerate(test_loader):
            waveforms, spectrograms, labels = [], [], []
            for waveform, spectrogram, label in support_set:
                waveforms.append(waveform.squeeze(0))
                spectrograms.append(spectrogram.squeeze(0))
                labels.append(label)
            for waveform, spectrogram, label in query_set:
                waveforms.append(waveform.squeeze(0))
                spectrograms.append(spectrogram.squeeze(0))
                labels.append(label)
            labels = torch.tensor(labels, device=args.device)
            waveforms = torch.stack(waveforms).to(args.device)
            spectrograms = torch.stack(spectrograms).to(args.device)

            y_hat = ensemble_net(waveforms, spectrograms)
            accs.append(get_acc(y_hat, labels))
            query_accs.append(get_query_acc(y_hat, labels, args.query_size))
            del spectrograms, waveforms, labels
            torch.cuda.empty_cache()
            print(f"第{test_idx}批测试  |  Query正确率:{query_accs[-1]*100:.2f}%  |  整体正确率:{accs[-1]*100:.2f}%")
    print(f"平均Query正确率:{sum(query_accs)/len(query_accs)*100:.2f}%  |  平均整体正确率:{sum(accs)/len(accs)*100:.2f}%")
test(ensemble_net, test_loader, args)

第0批测试  |  Query正确率:81.25%  |  整体正确率:84.38%
第1批测试  |  Query正确率:87.50%  |  整体正确率:82.81%
第2批测试  |  Query正确率:87.50%  |  整体正确率:87.50%
第3批测试  |  Query正确率:87.50%  |  整体正确率:84.38%
第4批测试  |  Query正确率:87.50%  |  整体正确率:85.94%
第5批测试  |  Query正确率:81.25%  |  整体正确率:79.69%
第6批测试  |  Query正确率:87.50%  |  整体正确率:84.38%
第7批测试  |  Query正确率:81.25%  |  整体正确率:87.50%
第8批测试  |  Query正确率:93.75%  |  整体正确率:93.75%
第9批测试  |  Query正确率:93.75%  |  整体正确率:93.75%
第10批测试  |  Query正确率:75.00%  |  整体正确率:79.69%
第11批测试  |  Query正确率:93.75%  |  整体正确率:82.81%
第12批测试  |  Query正确率:81.25%  |  整体正确率:78.12%
第13批测试  |  Query正确率:93.75%  |  整体正确率:90.62%
第14批测试  |  Query正确率:87.50%  |  整体正确率:79.69%
第15批测试  |  Query正确率:87.50%  |  整体正确率:78.12%
第16批测试  |  Query正确率:75.00%  |  整体正确率:84.38%
第17批测试  |  Query正确率:87.50%  |  整体正确率:84.38%
第18批测试  |  Query正确率:81.25%  |  整体正确率:87.50%
第19批测试  |  Query正确率:81.25%  |  整体正确率:87.50%
第20批测试  |  Query正确率:87.50%  |  整体正确率:85.94%
第21批测试  |  Query正确率:68.75%  |  整体正确率:79.69%
第22批测试  |  Query正确率:93.75%  |  整体正确率:82.81