In [1]:
import os
import sys
sys.path.append("../models")
sys.path.append("../py_utils")

import torch
from torch.autograd import Variable

from ipywidgets import interact, fixed
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

from models import EnhancementMultidecoder
from kaldi_data import KaldiEvalDataset

In [2]:
# Environment variables needed from path.sh, followed by the experiment
# settings from models/base_config.sh.
os.environ.update({
    "MODELS": "/data/sls/scratch/atitus5/meng/models",
    "FEATS": "/data/sls/scratch/atitus5/meng/feats",
    "FEAT_DIM": "40",          # 40-dim Mel filter bank
    "LEFT_CONTEXT": "7",
    "RIGHT_CONTEXT": "7",
    "OPTIMIZER": "Adam",
    "LEARNING_RATE": "0.001",
    "EPOCHS": "35",
    "BATCH_SIZE": "128",
    "LATENT_DIM": "1024",
    "USE_BATCH_NORM": "false",
    "NUM_PHONES": "2020",
    "CLEAN_DATASET": "timit_clean",
    "DIRTY_DATASET": "timit_dirty_100_rir",
})

# Architecture lists. Each one is exported as a "_"-prefixed, "_"-joined
# string, e.g. [64, 128, 128] -> "_64_128_128".
channels = [64, 128, 128]
kernels = [5, 3, 3]
downsamples = [2, 0, 2]
phones_fc = [1024, 1024]


def _underscore_delim(values):
    """Return "_v1_v2_..." for the given sequence of values."""
    return "_" + "_".join(map(str, values))


os.environ["CHANNELS_DELIM"] = _underscore_delim(channels)
os.environ["KERNELS_DELIM"] = _underscore_delim(kernels)
os.environ["DOWNSAMPLES_DELIM"] = _underscore_delim(downsamples)
os.environ["PHONE_FC_DELIM"] = _underscore_delim(phones_fc)

# Feature directories: <FEATS>/<dataset name>.
os.environ["CLEAN_FEATS"] = "/".join((os.environ["FEATS"], os.environ["CLEAN_DATASET"]))
os.environ["DIRTY_FEATS"] = "/".join((os.environ["FEATS"], os.environ["DIRTY_DATASET"]))

# Experiment name encodes the architecture and the training hyperparameters.
# The *_DELIM values already start with "_", hence e.g. "C_64_128_128" and "PHONE_FC__1024_1024".
os.environ["EXPT_NAME"] = (
    "C{CHANNELS_DELIM}_K{KERNELS_DELIM}_P{DOWNSAMPLES_DELIM}_LATENT_{LATENT_DIM}"
    "_PHONE_FC_{PHONE_FC_DELIM}/BN_{USE_BATCH_NORM}_OPT_{OPTIMIZER}"
    "_LR_{LEARNING_RATE}_EPOCHS_{EPOCHS}_BATCH_{BATCH_SIZE}"
).format_map(os.environ)

os.environ["MODEL_DIR"] = "/".join((os.environ["MODELS"], os.environ["DIRTY_DATASET"], os.environ["EXPT_NAME"]))

In [3]:
# Sanity check: echo the experiment name assembled by the configuration cell.
print("Using experiment " + os.environ["EXPT_NAME"])

Using experiment C_64_128_128_K_5_3_3_P_2_0_2_LATENT_1024_PHONE_FC__1024_1024/BN_false_OPT_Adam_LR_0.001_EPOCHS_35_BATCH_128


In [4]:
# Dataset names come from the configuration cell, so the two cannot drift apart.
clean_dataset = os.environ["CLEAN_DATASET"]
dirty_dataset = os.environ["DIRTY_DATASET"]

# Feature list evaluated for both baselines. Swap to "feats.scp" to view the
# features without the "-norm" variant (presumably un-normalized; verify).
FEATS_SCP = "feats-norm.scp"

# Set up datasets for clean, dirty baselines (test set only).
clean_feat_dir = os.path.join(os.environ["CLEAN_FEATS"], "test")
clean_baseline = KaldiEvalDataset(os.path.join(clean_feat_dir, FEATS_SCP))

dirty_feat_dir = os.path.join(os.environ["DIRTY_FEATS"], "test")
dirty_baseline = KaldiEvalDataset(os.path.join(dirty_feat_dir, FEATS_SCP))

print("Set up baseline test datasets")

Set up baseline test datasets


In [5]:
# Build the multi-decoder model and restore its best checkpoint.
# Returning each storage unchanged from map_location loads every tensor onto the CPU.
model = EnhancementMultidecoder()
ckpt_path = model.ckpt_path()
checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint["state_dict"])
# Switch to evaluation mode before decoding (affects layers such as dropout/batch norm).
model.eval()

FileNotFoundError: [Errno 2] No such file or directory: '/data/sls/scratch/atitus5/meng/models/timit_dirty_100_rir/C_64_128_128_K_5_3_3_P_2_0_2_LATENT_1024_PHONE_FC__1024_1024/BN_false_OPT_Adam_LR_0.001_EPOCHS_35_BATCH_128/best_enhancement_md.pth.tar'

In [13]:
# Context-window geometry, read back from the configuration environment.
freq_dim = int(os.environ["FEAT_DIM"])
left_context, right_context = (int(os.environ[key]) for key in ("LEFT_CONTEXT", "RIGHT_CONTEXT"))
# One spliced window = left context + the centre frame + right context.
time_dim = left_context + 1 + right_context
def augmentFeats(model, feats, decoder_class):
    """Decode an utterance frame by frame through one of the model's decoders.

    Each frame is spliced with ``left_context`` preceding and ``right_context``
    following frames, using zero rows past either end of the utterance. The
    window (``time_dim`` x ``freq_dim``) is passed to
    ``model.forward_decoder(window, decoder_class)``, and the centre row of the
    reconstruction is kept.

    Relies on the module-level ``freq_dim``, ``left_context``,
    ``right_context`` and ``time_dim``.

    Args:
        model: object exposing ``forward_decoder(tensor, decoder_class)`` that
            returns a tensor reshapeable to ``(-1, freq_dim)``.
        feats: array with a ``.reshape`` method whose size is a multiple of
            ``freq_dim`` (presumably the numpy features of one utterance).
        decoder_class: forwarded unchanged to ``forward_decoder``; this
            notebook passes "clean" or "dirty".

    Returns:
        float64 numpy array of shape ``(num_frames, freq_dim)``.
    """
    frames = feats.reshape((-1, freq_dim))
    n_frames = frames.shape[0]

    # Zero-pad once so every window is a plain slice rather than edge arithmetic.
    padded = np.zeros((left_context + n_frames + right_context, freq_dim))
    padded[left_context:left_context + n_frames, :] = frames

    decoded = np.empty((n_frames, freq_dim))
    for start in range(n_frames):
        window = padded[start:start + time_dim, :]
        recon = model.forward_decoder(Variable(torch.FloatTensor(window)), decoder_class)
        recon_rows = recon.cpu().data.numpy().reshape((-1, freq_dim))
        # Keep only the reconstructed centre frame of this window.
        decoded[start, :] = recon_rows[left_context:left_context + 1, :]
    return decoded

In [14]:
# Colour map shared by every panel ("coolwarm" is an alternative worth trying).
color_map = "viridis"


def _show_feats(ax, feats, title):
    """Draw a 2-D feature matrix transposed (its rows along the x axis) on ``ax``."""
    ax.axis('off')    # Pretty-up the resulting output by removing gridlines
    ax.imshow(np.transpose(feats), origin='lower', cmap=color_map, aspect='auto', interpolation='none')
    ax.set_title(title)


def plotParallelUtts(utt_id_idx):
    """Plot one clean/dirty utterance pair and all four decoder reconstructions.

    Grid layout (3 rows x 2 columns):
      * column 0 starts from the CLEAN input, column 1 from the DIRTY input;
      * row 0 shows the inputs themselves;
      * row 1 shows each input decoded with the "clean" decoder;
      * row 2 shows each input decoded with the "dirty" decoder.

    Args:
        utt_id_idx: index used for both test sets. NOTE(review): assumes
            clean_baseline and dirty_baseline list parallel utterances in the
            same order; this is not checked here.
    """
    fig, axarr = plt.subplots(3, 2, sharex=True)
    fig.set_size_inches(12, 8)

    sources = (("CLEAN", clean_baseline), ("DIRTY", dirty_baseline))
    for col, (src_name, dataset) in enumerate(sources):
        src_feats = dataset.feats_for_uttid(dataset.utt_ids[utt_id_idx])
        _show_feats(axarr[0, col], src_feats, src_name)
        for row, decoder_class in enumerate(("clean", "dirty"), start=1):
            decoded = augmentFeats(model, src_feats, decoder_class)
            _show_feats(axarr[row, col], decoded, "%s>%s" % (src_name, decoder_class.upper()))

    plt.tight_layout()

    # To export: fig.savefig("%s_idx%d.eps" % (clean_baseline.utt_ids[utt_id_idx], utt_id_idx))

    plt.show()


# Both datasets are indexed with the same position, so only offer indices valid for both.
num_parallel_utts = min(len(clean_baseline), len(dirty_baseline))
interact(plotParallelUtts, utt_id_idx=range(num_parallel_utts))

A Jupyter Widget

<function __main__.plotParallelUtts>