In [28]:
import torch
import os
import json
from tqdm import tqdm 

from Lime.cams import read_hdf5
from Define_Model.ParallelBlocks import Parallel, gumbel_softmax
from Define_Model.TDNN.ECAPA_brain import ECAPA_TDNN, Classifier

In [9]:
# Project root. Defaults to the original author's checkout; set SV_PROJECT_ROOT to run
# this notebook from another location without editing the cell.
root_dir = os.environ.get('SV_PROJECT_ROOT', '/home/yangwenhao/project/SpeakerVerification-pytorch')
# Directory holding the saved gradient / input data used below.
grad_dir = root_dir + '/Data/gradient'

In [4]:
# ECAPA-TDNN encoder with an fbank front end (feat_dim=80) at sr=16000 and 5994 output
# classes (presumably the VoxCeleb2 dev speaker count — confirm).
embedding_model = ECAPA_TDNN(
    filter="fbank",
    sr=16000,
    feat_dim=80,
    input_dim=80,
    input_norm='Mean',
    num_classes=5994,
)

In [5]:
# Replace the encoder's classifier head: 192-dim input, 192 linear neurons, 5994 outputs.
embedding_model.classifier = Classifier(
    input_size=192,
    lin_neurons=192,
    out_neurons=5994,
)

In [7]:
# Wrap the encoder in the Parallel container with layers=4.
# NOTE(review): agent_model=None is passed here, yet a later cell calls model.agent_model(...);
# presumably Parallel builds a default agent when None is given — confirm.
model = Parallel(
    embedding_model,
    layers=4,
    agent_model=None,
)

In [8]:
resume = root_dir + '/Data/checkpoint/ECAPA_brain/Mean_batch48_SASP2_em192_official_2sesmix8/arcsoft_adam_cyclic/vox2/wave_fb80_dist_fine16pspot_onelayer/123456/checkpoint_12.pth'

# Map to CPU so the checkpoint opens even without the device it was saved from;
# load_state_dict copies the tensors into the model's own parameters afterwards.
checkpoint = torch.load(resume, map_location='cpu')

checkpoint_state_dict = checkpoint['state_dict']
# Some checkpoints store the state dict wrapped in a tuple; take its first element.
if isinstance(checkpoint_state_dict, tuple):
    checkpoint_state_dict = checkpoint_state_dict[0]

# Drop BatchNorm step counters; the model's own values are kept for those keys below.
filtered = {k: v for k, v in checkpoint_state_dict.items() if 'num_batches_tracked' not in k}
if not filtered:
    raise ValueError('no parameters found in checkpoint: ' + resume)

# Strip the 'module.' prefix that wrapping the model (e.g. DataParallel/DDP) typically adds.
prefix = 'module.'
stripped = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in filtered.items()}

# Always merge into the model's current state dict: a strict load of `stripped` alone would
# fail on the num_batches_tracked keys that were filtered out above.
model_dict = model.state_dict()
model_dict.update(stripped)
model.load_state_dict(model_dict)
del model_dict, stripped

In [14]:
# Directory of the saved inputs (data.h5py + uid_idx.json), assembled segment by segment.
input_path = '/'.join([
    grad_dir,
    'ThinResNet34_ser07',
    'Mean_batch128_cbam_downk5_avg0_SAP2_em256_dp01_alpha0_none1_chn32_wde4_varesmix8',
    'arcsoft_sgd_rop',
    'vox2',
    'wave_sp161_dist',
    '123456',
    'vox2_dev4',
])

In [17]:
# load selected input uids
data_reader = input_path + '/data.h5py'
assert os.path.exists(data_reader), print(data_reader)

uid_reader = input_path + '/uid_idx.json'
assert os.path.exists(uid_reader)
with open(uid_reader, 'r') as f:
    uididx = json.load(f)
    
some_data = set([uid for uid,idx in uididx])
print("Length of data: ", len(some_data))

Length of data:  11988


In [None]:
model.eval();  # switch to evaluation mode; trailing ';' suppresses the module repr in the notebook

In [None]:
# Run the agent network on every selected utterance and collect the sampled policies.
# Iterate in sorted order so runs are reproducible (set iteration order of str keys varies
# between interpreter sessions), and record each uid next to its policy.
policys = []
policy_uids = []

with torch.no_grad():  # inference only: don't build or retain an autograd graph
    for uid in tqdm(sorted(some_data), ncols=50):
        data = read_hdf5(data_reader, uid)
        # Add two leading singleton dims: shape becomes (1, 1, *data.shape).
        data = torch.tensor(data).float().unsqueeze(0).unsqueeze(0)

        x = model.model.input_mask(data)
        # NOTE(review): model was built with agent_model=None; this assumes Parallel
        # provides a default agent in that case — confirm.
        policy, _ = model.agent_model(x)

        # Moved to GPU before sampling, presumably because gumbel_softmax allocates its
        # noise on CUDA — confirm before removing.
        policy = gumbel_softmax(policy.cuda())
        policys.append(policy[0].cpu())
        policy_uids.append(uid)

In [34]:
a = torch.randn((12, 192, 23))  # random (12, 192, 23) tensor for the broadcasting check below

In [37]:
b = torch.randn((12, 1, 1))  # one random scalar per leading index, broadcastable over a

In [38]:
torch.mul(a, b)  # elementwise product; b's (12, 1, 1) shape broadcasts over a's last two dims

tensor([[[ 4.7388e-01,  1.1668e+00, -6.9824e-01,  ...,  1.5910e+00,
           1.1053e+00, -1.0082e+00],
         [-1.1910e-01, -3.9314e-01, -9.9971e-01,  ..., -3.9227e-01,
           8.7295e-01,  6.7731e-02],
         [-2.8258e-01, -1.4829e+00,  9.5587e-01,  ..., -4.5215e-01,
           1.4841e+00, -9.4898e-03],
         ...,
         [-4.0878e-01,  1.0492e-01,  5.6134e-01,  ..., -1.3302e+00,
           3.5066e-01, -4.2690e-01],
         [ 4.6019e-01, -1.3650e+00, -6.3043e-01,  ...,  1.9371e-01,
           3.4028e-01, -1.1978e+00],
         [-3.3661e-01,  6.6808e-01, -1.8488e-01,  ..., -2.7147e-01,
          -8.8073e-02, -1.4807e-03]],

        [[-5.9109e-02, -3.7060e-02, -1.9911e-01,  ...,  4.3472e-03,
           5.4537e-02, -9.8661e-02],
         [-1.4320e-02, -1.6565e-01, -1.0048e-01,  ..., -1.3820e-02,
           8.4347e-02, -1.6191e-01],
         [ 4.6202e-02, -1.2334e-01,  3.2102e-02,  ..., -2.7667e-02,
          -9.1732e-02, -6.6295e-02],
         ...,
         [-7.0280e-02, -4