In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle

In [2]:
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv1d = nn.Conv1d(in_channels=768, out_channels=5, kernel_size=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(5, 4) 

    def forward(self, x):
        x = x.permute(0, 2, 1)  
        x = self.conv1d(x) 
        x = x.mean(dim=2) 
        x = self.relu(x)
        x = self.fc(x) 
        return x

In [3]:
from sklearn.metrics import recall_score

# 训练函数
def train(model, train_data, train_labels, criterion, optimizer, num_epochs, batch_size):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        total_correct_per_class = np.zeros(len(np.unique(train_labels)))  
        total_samples_per_class = np.zeros(len(np.unique(train_labels)))  

        indices = np.arange(len(train_data))
        np.random.shuffle(indices)  # 随机打乱索引

        for i in range(0, len(train_data), batch_size):
            batch_indices = indices[i:i+batch_size]
            inputs = torch.tensor(train_data[batch_indices], dtype=torch.float32)
            labels = torch.tensor(train_labels[batch_indices], dtype=torch.long)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)

            # 计算每个类的正确样本数和总样本数
            for label in range(len(np.unique(train_labels))):
                total_correct_per_class[label] += ((predicted == labels) & (labels == label)).sum().item()
                total_samples_per_class[label] += (labels == label).sum().item()

        epoch_loss = running_loss / (len(train_data) / batch_size)

        # 计算每个类的准确率，并取平均得到 UA
        class_accuracies = total_correct_per_class / total_samples_per_class
        epoch_ua = np.mean(class_accuracies)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, UA: {epoch_ua:.4f}')

        print('########################')

        test(model, wav2vec_last3, label_last3)


# 测试函数
def test(model, test_data, test_labels):
    model.eval()
    with torch.no_grad():
        inputs = torch.tensor(test_data, dtype=torch.float32)
        labels = torch.tensor(test_labels, dtype=torch.long)

        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

        # 计算每个类的准确率并取平均得到 UA
        total_correct_per_class = np.zeros(len(np.unique(test_labels)))  # 用于记录每个类的正确样本数
        total_samples_per_class = np.zeros(len(np.unique(test_labels)))  # 用于记录每个类的总样本数

        for label in range(len(np.unique(test_labels))):
            total_correct_per_class[label] += ((predicted == labels) & (labels == label)).sum().item()
            total_samples_per_class[label] += (labels == label).sum().item()

        class_accuracies = total_correct_per_class / total_samples_per_class
        ua = np.mean(class_accuracies)

        print(f'Unweighted Accuracy (UA): {ua:.4f}')

In [4]:
import pickle
#读取数据集
with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session1_w2v2.pkl', 'rb') as f:
    wav2vec_last1 = pickle.load(f)
    print('wav2vec_last1',wav2vec_last1.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session1_label.pkl', 'rb') as f:
    label_last1 = pickle.load(f)
    print('label_last1',label_last1.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session2_w2v2.pkl', 'rb') as f:
    wav2vec_last2 = pickle.load(f)
    print('wav2vec_last2',wav2vec_last2.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session2_label.pkl', 'rb') as f:
    label_last2 = pickle.load(f)
    print('label_last2',label_last2.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session3_w2v2.pkl', 'rb') as f:
    wav2vec_last3 = pickle.load(f)
    print('wav2vec_last3',wav2vec_last3.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session3_label.pkl', 'rb') as f:
    label_last3 = pickle.load(f)
    print('label_last3',label_last3.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session4_w2v2.pkl', 'rb') as f:
    wav2vec_last4 = pickle.load(f)
    print('wav2vec_last4',wav2vec_last4.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session4_label.pkl', 'rb') as f:
    label_last4 = pickle.load(f)
    print('label_last4',label_last4.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session5_w2v2.pkl', 'rb') as f:
    wav2vec_last5 = pickle.load(f)
    print('wav2vec_last5',wav2vec_last5.shape)

with open('/home/ni/step1-提取数据特征/整合-按条提取语音_Session3_pt_特征/data_Session5_label.pkl', 'rb') as f:
    label_last5 = pickle.load(f)
    print('label_last5',label_last5.shape)

wav2vec_last1 (1085, 256, 768)
label_last1 (1085,)
wav2vec_last2 (1023, 256, 768)
label_last2 (1023,)
wav2vec_last3 (1151, 256, 768)
label_last3 (1151,)
wav2vec_last4 (1031, 256, 768)
label_last4 (1031,)
wav2vec_last5 (1241, 256, 768)
label_last5 (1241,)


In [5]:
import numpy as np
wav2vec_last = np.concatenate((wav2vec_last1, wav2vec_last2, wav2vec_last4, wav2vec_last5),axis=0)
label_last = np.concatenate((label_last1,label_last2,label_last4,label_last5))
print(wav2vec_last.shape,label_last.shape)

(4380, 256, 768) (4380,)


In [6]:
num_epochs = 15
batch_size = 256

# 初始化模型、损失函数和优化器
model = CustomModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
train(model, wav2vec_last, label_last, criterion, optimizer, num_epochs, batch_size)

# 测试模型
test(model, wav2vec_last3, label_last3)

Epoch [1/15], Loss: 1.1747, UA: 0.4768
########################
Unweighted Accuracy (UA): 0.4229
Epoch [2/15], Loss: 0.9066, UA: 0.4977
########################
Unweighted Accuracy (UA): 0.4253
Epoch [3/15], Loss: 0.8085, UA: 0.4988
########################
Unweighted Accuracy (UA): 0.4217
Epoch [4/15], Loss: 0.7237, UA: 0.4988
########################
Unweighted Accuracy (UA): 0.4819
Epoch [5/15], Loss: 0.5852, UA: 0.7252
########################
Unweighted Accuracy (UA): 0.6762
Epoch [6/15], Loss: 0.3731, UA: 0.9794
########################
Unweighted Accuracy (UA): 0.7080
Epoch [7/15], Loss: 0.2102, UA: 0.9907
########################
Unweighted Accuracy (UA): 0.7011
Epoch [8/15], Loss: 0.1401, UA: 0.9909
########################
Unweighted Accuracy (UA): 0.6997
Epoch [9/15], Loss: 0.1064, UA: 0.9909
########################
Unweighted Accuracy (UA): 0.7041
Epoch [10/15], Loss: 0.0890, UA: 0.9909
########################
Unweighted Accuracy (UA): 0.7058
Epoch [11/15], Loss: 0.0749, 