导入库

In [22]:
import os
import imp
import numpy as np
import json

from dataset import EEGAudioDataset
from model.CLUB import CLUBSample_group
from model.CPC import Cross_CPC
from model.VQVAE import VQVAEEncoder,VQVAEDecoder,SemanticDecoder
from dataset import EEGAudioDataset
from torch_dct import dct

公共变量

In [23]:
words_path = r'./feat/words'
pt = 'sub-06'
test_word = 95
config_path = r'./config'
model_name = 'cmg_noclip'

读取数据

In [24]:
folder_path = os.path.join(words_path,f'{pt}')
filename = os.listdir(folder_path)[test_word]
word_info = np.load(os.path.join(folder_path,filename),allow_pickle=True)
word=word_info.item()['label']
eeg=word_info.item()['eeg']
audio=word_info.item()['audio']

print(word,eeg.shape,audio.shape)

helft (829, 127) (12960,)


数据预处理

In [25]:
with open(os.path.join(config_path,f'{model_name}.json'),'r') as f:
    cfg = json.load(f)
    model_cfg = cfg['model_config']
    data_cfg = cfg['data_config']

# load config 
seg_size = model_cfg['seg_size']
pred_size = model_cfg['pred_size']
# batch_size = model_cfg['batch_size']
# end_epoch = model_cfg['epochs'] if argu.epoch is None else argu.epoch
lr = model_cfg['lr']
b1 = model_cfg['b1']
b2 = model_cfg['b2']
clip_grad = model_cfg['clip_grad']
hidden_dim = model_cfg['hidden_dim']
d_model = model_cfg['d_model']
nhead = model_cfg['nhead']
n_layer = model_cfg['n_layer']
n_embedding = model_cfg['n_embedding']
mi_iter = model_cfg['mi_iter']

data_path = data_cfg['data_path']
win_len = data_cfg['win_len']
frame_shift = data_cfg['frame_shift']
eeg_sr = data_cfg['eeg_sr']
audio_sr = data_cfg['audio_sr']
pad_mode = data_cfg['pad_mode']

提取高频eeg信号和音频信号的梅尔频谱

In [26]:
eeg = EEGAudioDataset.extractHG(eeg,eeg_sr,windowLength=win_len,frameshift=frame_shift)
melspec = EEGAudioDataset.extractMelSpecs(audio,audio_sr,windowLength=win_len,frameshift=frame_shift)
# print(eeg.shape,melspec.shape)
if melspec.shape[0]!=eeg.shape[0]:
    minlen = min(melspec.shape[0],eeg.shape[0])
    melspec = melspec[:minlen,:]
    eeg = eeg[:minlen,:]
print(eeg.shape,melspec.shape)

(162, 127) (162, 40)


预处理

In [27]:
eeg_mean = np.mean(eeg)
eeg_std = np.std(eeg)
eeg = (eeg-eeg_mean)/eeg_std
eeg_list = []
mel_list = []
hop_size = 1
pad_width = ((int(np.floor((seg_size-hop_size)/2.0)),int(np.ceil((seg_size-hop_size)/2.0))),(0,0))
pad_eeg = np.pad(eeg,pad_width,mode=pad_mode)
pad_mel = np.pad(melspec,pad_width,mode=pad_mode)
num_win = int(len(eeg)/float(hop_size))
for i in range(num_win):
    start = i*hop_size
    end = start + seg_size
    eeg_list.append(pad_eeg[start:end,:])
    mel_list.append(pad_mel[start:end,:])

eeg_data = np.stack(eeg_list,axis=0)
mel_data = np.stack(mel_list,axis=0)
print(eeg_data.shape,mel_data.shape)

(162, 16, 127) (162, 16, 40)


加载模型

In [160]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from itertools import chain

In [161]:
input_dim = eeg_data.shape[-1]
output_dim = mel_data.shape[-1]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vqvae_encoder = VQVAEEncoder(mel_dim=output_dim,eeg_dim=input_dim,mel_output_dim=d_model,eeg_output_dim=d_model,n_embedding=n_embedding,embedding_dim=d_model).to(device)
cpc = Cross_CPC(embedding_dim=d_model,hidden_dim=hidden_dim,context_dim=hidden_dim,num_layers=n_layer,predict_step=pred_size).to(device)
mel_mi_net = CLUBSample_group(x_dim=d_model,y_dim=d_model,hidden_size=hidden_dim).to(device)
eeg_mi_net = CLUBSample_group(x_dim=d_model,y_dim=d_model,hidden_size=hidden_dim).to(device)
vqvae_decoder = VQVAEDecoder(mel_dim=output_dim,eeg_dim=input_dim,mel_output_dim=d_model,eeg_output_dim=d_model,embedding_dim=d_model).to(device)
mel_vq_decoder = SemanticDecoder(input_dim=d_model,output_dim=output_dim).to(device)

main_optimizer = torch.optim.Adam(chain(vqvae_encoder.parameters(),cpc.parameters(),vqvae_decoder.parameters()),lr=lr,betas=(b1,b2))
mel_mi_net_optimizer = torch.optim.Adam(mel_mi_net.parameters(),lr=lr,betas=(b1,b2))
eeg_mi_net_optimizer = torch.optim.Adam(eeg_mi_net.parameters(),lr=lr,betas=(b1,b2))
mel_vq_decoder_optimizer = torch.optim.Adam(mel_vq_decoder.parameters(),lr=lr,betas=(b1,b2))
# scheduler = MultiStepLR(main_optimizer,milestones=[10,20,30],gamma=0.5)

criterion = nn.MSELoss().to(device)
loss_fn = lambda x,y:(criterion(x, y)+criterion(torch.exp(x),torch.exp(y))+criterion(dct(x,norm='ortho'),dct(y,norm='ortho')))


122 40


输入预处理

In [162]:
data_padding = np.zeros((1,eeg.shape[1]))
eeg_list = []
for idx in range(eeg.shape[0]):
    if idx-prv_frame+1<0:
        tmp = eeg[0:idx+1]
        for _ in range(prv_frame-idx-1):
            tmp=np.insert(tmp,0,data_padding,axis=0)
        eeg_list.append(tmp)
    else:
        eeg_list.append(eeg[idx-prv_frame+1:idx+1])
eeg = np.stack(eeg_list,axis=0)

In [163]:
eeg.shape

(96, 3, 122)

In [164]:
# pbar = tqdm.trange(epochs, desc=f"Epochs")
model.load_state_dict(torch.load(f'./res/{pt}/{model_name}.pt')['model_state_dict'])
model.eval()

Model(
  (l1): Sequential(
    (0): Linear(in_features=122, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=256, bias=True)
  )
  (transformer): TransformerModel(
    (encoder_layer): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (linear1): Linear(in_features=256, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=1024, out_features=256, bias=True)
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): Multihe

转换为MFCC

In [165]:
model_output = model(torch.from_numpy(eeg).to(device).type(tensor_type)).detach().cpu().numpy()

In [166]:
import matplotlib.pyplot as plt
import librosa
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'./logs/{pt}/{model_name}')
origin_melspec_fig = plt.figure()
librosa.display.specshow(melspec.T,sr=16000,hop_length=80,win_length=400,x_axis='time', y_axis='mel')
plt.colorbar(format='%+2.0f dB')        
plt.title(f'{pt}-{word}-origin')
writer.add_figure(tag=f"{pt}-{word}-origin log Mel spectrogram",figure=origin_melspec_fig)

model_melspec_fig = plt.figure()
librosa.display.specshow(model_output.T,sr=16000,hop_length=80,win_length=400,x_axis='time', y_axis='mel')
plt.colorbar(format='%+2.0f dB')        
plt.title(f'{pt}-{word}-model')
writer.add_figure(tag=f"{pt}-{word}-model log Mel spectrogram",figure=model_melspec_fig)
plt.show()
# librosa_melspec_fig = plt.figure()
# # numWindows = int(np.floor((audio.shape[0]-window_length*audio_sameple_rate)/(frameshift*audio_sameple_rate)))
# librosa_melspec = librosa.feature.melspectrogram(y=audio.astype(np.float32),sr=audio_sameple_rate,n_fft=400,hop_length=80,n_mels=80,center=False)
# librosa_melspec = librosa.power_to_db(librosa_melspec, ref=np.max)
# librosa.display.specshow(librosa_melspec,sr=16000,hop_length=80,win_length=400,x_axis='time', y_axis='mel')
# plt.colorbar(format='%+2.0f dB')        
# plt.title(f'{pt}-{word}-librosa')
# plt.show()
# # print(numWindows)
# writer.add_figure(tag=f"{pt}-{word}-librosa log Mel spectrogram",figure=librosa_melspec_fig)
writer.close()

In [167]:
model_mfcc = utils.toMFCC(model_output)
mfcc = utils.toMFCC(melspec)
eu_dis = 0
for i in range(mfcc.shape[0]):
    eu_dis += np.linalg.norm(model_mfcc[i] - mfcc[i])
mcd = eu_dis/mfcc.shape[0]
print(model_output.shape,model_mfcc.shape)
print(melspec.shape,mfcc.shape)
print(mcd)

(96, 40) (96, 13)
(96, 40) (96, 13)
19.382871546148376


In [168]:
np.save(os.path.join('mel_files',f'{pt}_{test_word}_model.npy'),model_output.T)