# GRUX Speech Commands


In [1]:
skip_training = False #flag to set True before validation

import os
import random
import math
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if skip_training:
    # The models are always evaluated on CPU
    device = torch.device("cpu")
else:
    #device = torch.device('cuda:0')
    device = torch.device("cpu")

In [3]:
n_tgt_class = 0
train_weights = None
csv_path = '../speech_commands/commands.pkl'
torch.autograd.set_detect_anomaly(True)

def squeeze_tgts(numerical_labels, tgt_classes, class_dict):
    filtered_labels = []
    for i, label in enumerate(tgt_classes):
        if class_dict[label] in numerical_labels:
            filtered_labels.append(i)
    return filtered_labels

class AudioDataset(Dataset):
    def __init__(self, folder, csv_path, file_prefix, train = True):
        global n_tgt_class
        global train_weights
        self.folder = folder
        self.file_prefix = file_prefix
        (self.file_data, self.class_dict) = pickle.load(open(csv_path, "rb"))
        self.file_list = []
        #self.tgt_classes = ['Guitar', 'Human_voice', 'Glass', 'Vehicle', 'Water', 'Dog', 'Explosion', 'Bird']
        #self.tgt_classes = ['Guitar', 'Human_voice', 'Glass', 'Water', 'Dog', 'Explosion']
        #self.tgt_classes = ['Human_voice', 'Glass']
        self.tgt_classes = ['yes', 'no']
        n_tgt_class = len(self.tgt_classes)
        for row in self.file_data:
            if (row[-1] == 'train' and train) or (row[-1] != 'train' and not train):
                squeezed_tgts = squeeze_tgts(row[1], self.tgt_classes, self.class_dict)
                if len(squeezed_tgts) != 0:
                    comps = row[0].split('/')
                    pickle_name = comps[0] + comps[1].split('.')[0]
                    path = os.path.join(self.folder, self.file_prefix + pickle_name + '.pkl')
                    self.file_list.append(path)
        self.n_samples = len(self.file_list)

        if train:
            class_counts = np.zeros(n_tgt_class)
            for row in self.file_data:
                squeezed_labels = squeeze_tgts(row[1], self.tgt_classes, self.class_dict)
                for label in squeezed_labels:
                    class_counts[label] += 1
            mode = max(class_counts)
            train_weights = mode*(1./class_counts)
        print('n_tgt_class: ' + str(n_tgt_class))
        print('class_dict: ' + str(self.class_dict))
            
    def __len__(self):
        return self.n_samples
        
    def __getitem__(self, idx):
        path = self.file_list[idx]
        gt = pickle.load(open(path, "rb"))
        item_0 = torch.tensor(gt[0]).permute(1,0).float().to(device)
        item_1 = squeeze_tgts(gt[1], self.tgt_classes, self.class_dict)
        item = (item_0, item_1, path)
        #print("fetching tensor of size " + str(item[0].size()) + " and class " + str(gt[1]))
        
        return item
        

## Collate Function

In [4]:
from torch.nn.utils.rnn import pad_sequence

padding_value = 0

def collate(list_of_samples):
    """Merges a list of samples to form a mini-batch.

    Args:
      list_of_samples is a list of tuples (src_seq, tgt_cat):
          src_seq is of shape (src_seq_length,)
          tgt_cat is list of class indices

    Returns:
      src_pad of shape (max_src_seq_length, batch_size, encoding_size): Tensor of padded src sequences.
          The sequences should be sorted by length in a decreasing order, that is src_pad[:,0,:] should be
          the longest sequence, and src_pad[:,-1,:] should be the shortest.
      src_seq_lengths: List of lengths of src sequences.
      tgt_cats of shape (batch_size, n_class): Tensor of target categories' probabilities.
    """
    
    #print('sample size: ' + str(list_of_samples[0][0].size()))
    sortSamples = sorted(list_of_samples, key = lambda x: x[0].size(0), reverse = True)
    
    #figure out max lengths and src_seq_lengths list
    src_seq_lengths = []
    for seq in sortSamples:
        src_seq_lengths.append(seq[0].size(0))
    max_src_seq_length = src_seq_lengths[0]
    #print(src_seq_lengths)
    
    #padding loop
    src_pad = None
    tgt_cats = None
    file_paths = []
    
    for seq in sortSamples:
        src_pad_len = max_src_seq_length - seq[0].size(0)
        pad_src_seq = F.pad(seq[0], (0, 0, 0, src_pad_len), value = padding_value).unsqueeze(0).to(device)
        if src_pad is None:
            src_pad = pad_src_seq
        else:
            src_pad = torch.cat([src_pad, pad_src_seq], 0).to(device)
        
        tgt_cat_list = seq[1]
        if len(tgt_cat_list) == 0:
            raise Exception("Found sample with no target class")
        new_tgt_vec = torch.zeros((1, n_tgt_class)).float().to(device)
        for i in tgt_cat_list:
            new_tgt_vec[0,i] = 1
                
        if tgt_cats is None:
            tgt_cats = new_tgt_vec
        else:
            tgt_cats = torch.cat([tgt_cats, new_tgt_vec], 0).to(device)
        
        file_paths.append(seq[2])
            
    return src_pad.permute(1, 0, 2), src_seq_lengths, tgt_cats, file_paths

In [5]:
# Create custom DataLoader using the implemented collate function
from torch.utils.data import DataLoader
dataset_direc = '../speech_commands/gammatone'
trainset = AudioDataset(dataset_direc + '/train_redo', csv_path, 'gammatone_')
print(n_tgt_class)

n_tgt_class: 2
class_dict: {'backward': 0, 'bed': 1, 'bird': 2, 'cat': 3, 'dog': 4, 'down': 5, 'eight': 6, 'five': 7, 'follow': 8, 'forward': 9, 'four': 10, 'go': 11, 'happy': 12, 'house': 13, 'learn': 14, 'left': 15, 'marvin': 16, 'nine': 17, 'no': 18, 'off': 19, 'on': 20, 'one': 21, 'right': 22, 'seven': 23, 'sheila': 24, 'six': 25, 'stop': 26, 'three': 27, 'tree': 28, 'two': 29, 'up': 30, 'visual': 32, 'wow': 33, 'yes': 34, 'zero': 35}
2


In [6]:
trainloader = DataLoader(dataset=trainset, batch_size=4, shuffle=True, collate_fn=collate, pin_memory=False)

## Experimental Model

In [7]:
class Experimental(nn.Module):
    def __init__(self, encode_size, hidden_size, n_tgt_class):
        """
        Args:
          encode_size: The size of the (encoded) spectral input
          hidden_size: The number of features in the hidden state of GRU.
        """
        super(Experimental, self).__init__()
        self.hidden_size = hidden_size
        self.n_tgt_class = n_tgt_class
        self.update_gate = nn.Linear(hidden_size + encode_size, hidden_size)
        self.reset_gate = nn.Linear(hidden_size + encode_size, hidden_size)
        self.candidate_activation = nn.Linear(hidden_size + encode_size, hidden_size)
        self.activation = nn.Sigmoid()
        
    def emulate_gru(self, src_pad, src_len):
        h_in = torch.zeros(self.hidden_size)
        for i in range(src_len):
            z_in = torch.cat((src_pad[i,:], h_in))
            z = self.activation(self.update_gate(z_in))
            r = self.activation(self.reset_gate(z_in))
            h_hat_in = torch.cat((src_pad[i,:], r*h_in))
            h_hat = self.activation(self.candidate_activation(h_hat_in))
            h_out = z*h_hat + (1 - z) * h_hat
            h_in = h_out
        return h_out
    
    def emulate_batch(self, src_pad, src_lengths):
        batch_sz = len(src_lengths)
        hidden = torch.zeros(batch_sz, self.hidden_size)
        for i in range(batch_sz):
            hidden[i,:] = self.emulate_gru(src_pad[:,i,:], src_lengths[i])
        return hidden

    def forward(self, src_pad, src_lengths):
        """
        Args:
          src_pad of shape (max_src_seq_length, batch_size, encoding_size): Tensor of padded src sequences.
          src_lengths: List of source sequence lengths.

        Returns:
          outputs of shape (max_src_seq_length, batch_size, hidden_size): Padded outputs of GRU at every step.
          hidden of shape (1, batch_size, hidden_size): Updated states of the GRU.
        """

        hidden = self.emulate_batch(src_pad, src_lengths)
        hidden = hidden.unsqueeze(0)
        probs = hidden[:,:,-self.n_tgt_class:].to(device)
        return probs, hidden, 0

## Training

In [8]:
hidden_size = 128
encode_size = 256
input_bottleneck_size = 8
experimental = Experimental(encode_size, hidden_size, n_tgt_class).to(device)

In [9]:
def testShapes():
    test_tensor = torch.zeros(5, 4, encode_size).to(device)
    src_lengths = [5, 5, 5, 5]
    softMax_probs, hidden, _ = experimental.forward(test_tensor, src_lengths)
    print(softMax_probs.size())
    print(hidden.size())
    
testShapes()

torch.Size([1, 4, 2])
torch.Size([1, 4, 128])


In [10]:
def maskInputs(module, key, low_idx, high_idx):
    state_dict = module.state_dict()[key]
    for i in range(3):
        start = i*hidden_size
        state_dict[start+low_idx:start+high_idx,:] = torch.zeros(high_idx-low_idx, encode_size)
    
#maskInputs(experimental, 'gru.weight_ih_l0', input_bottleneck_size, hidden_size)
#print(experimental.state_dict()['gru.weight_ih_l0'])

# a function that serves a testing purpose in finding out units with masked inputs
def printMaskedIntervals(module, key):
    state = module.state_dict()[key]
    intervals = []
    inter_started = False
    for i in range(state.size(0)):
        row = state[i,:]
        if torch.sum(row) == 0.0:
            if not inter_started:
                inter_started = True
                intervals.append(i)
        else:
            if inter_started:
                inter_started = False
                intervals[-1] = (intervals[-1], i)
    if inter_started:
        intervals[-1] = (intervals[-1], state.size(0))
    print(intervals)

#printMaskedIntervals(experimental, 'gru.weight_ih_l0')

In [11]:
def trainBatch(src_batch, seq_lengths, target_batch, experimental, experimental_opt, criterion, verbose):
    """
    Args:
      ae_batch of shape (max_src_seq_length, batch_size, encode_size): Padded ae sequences.
      seq_lengths: List of sequence lengths.
      target_batch of shape (1, batch_size): Tensor of target classes.

    Returns:
      loss
    """

    batch_size = src_batch.size(1)
    max_src_seq_length = src_batch.size(0)

    experimental_opt.zero_grad()

    #print('Src batch min: ' + str(torch.min(src_batch)))
    #print('Src batch max: ' + str(torch.max(src_batch)))
    
    #if verbose:
        #print('Current allocated memory: ' + str(torch.cuda.memory_stats()['allocated_bytes.all.current']))
        #print('Peak allocated memory: ' + str(torch.cuda.memory_stats()['allocated_bytes.all.peak']))
    probs, hidden, _ = experimental.forward(src_batch, seq_lengths)
    pred = probs.squeeze(0)
    #print('Pred min: ' + str(torch.min(pred)))
    #print('Pred max: ' + str(torch.max(pred)))
    #print('Target min: ' + str(torch.min(target_batch)))
    #print('Target max: ' + str(torch.max(target_batch)))

    loss = criterion(pred, target_batch)
    loss.backward()
    
    experimental_opt.step()
    #maskInputs(experimental, 'gru.weight_ih_l0', input_bottleneck_size, hidden_size)
    #print("Loss: " + str(loss.item()))

    return loss.item()
    
def full_train(n_epochs, max_batch_per_epoch = 1000, suppress = False, verbose = False, unique_path_log = []):
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(train_weights).to(device))
    experimental_opt = optim.Adam(experimental.parameters(), lr=0.001)

    error = 0

    for epoch in range(n_epochs):
        experimental.train()
        for i, data in enumerate(trainloader, 0):
            src_batch = data[0]
            seq_lengths = data[1]
            target_batch = data[2]
            if data[3] not in unique_path_log:
                unique_path_log.append(data[3])
            batch_verbose = False
            if verbose and i % 100 == 0:
                print('Training batch: ' + str(i))
                batch_verbose = True
            error = trainBatch(src_batch, seq_lengths, target_batch, experimental,
                              experimental_opt, criterion, batch_verbose)
            if i > max_batch_per_epoch:
                break
        if not suppress:
            print('epoch: ', epoch, ' loss: ', error)
            print('number of unique paths encountered: ' + str(len(unique_path_log)))
    return error

if not skip_training:
    full_train(0, max_batch_per_epoch = 1000, verbose = True)

In [12]:
def save_model(model, filename, skip_dialogue = False):
    try:
        if not skip_dialogue:
            do_save = input('Do you want to save the model (type yes to confirm)? ').lower()
        else:
            do_save = 'yes'
            
        if do_save == 'yes':
            torch.save(model.state_dict(), filename)
            print('Model saved to %s.' % (filename))
        else:
            print('Model not saved.')
    except:
        raise Exception('The notebook should be run or validated with skip_training=True.')


def load_model(model, filename, device):
    model.load_state_dict(torch.load(filename, map_location=lambda storage, loc: storage))
    print('Model loaded from %s.' % filename)
    model.to(device)
    model.eval()
    
# Save the model to disk
if not skip_training:
    save_model(experimental, 'gru_experimental.pth')
else:
    experimental = Experimental(encode_size, hidden_size, n_tgt_class).to(device)
    load_model(experimental, 'gru_experimental.pth', device)

Do you want to save the model (type yes to confirm)? yes
Model saved to gru_experimental.pth.


## Evaluation

In [13]:
def sensitivity_specificity(bin_pred, tgt_batch, num_label, n_tgt_class):
    #print('Binary prediction size: ' + str(bin_pred.size()))
    #print('Target batch size: ' + str(tgt_batch.size()))
    diff = tgt_batch[:, num_label] - bin_pred[:, num_label]
    false_neg_M = np.equal(diff, 1)
    n_false_neg = torch.sum(false_neg_M).item()
    false_pos_M = np.equal(diff, -1)
    n_false_pos = torch.sum(false_pos_M).item()
    n_true_pos = torch.sum(tgt_batch[:, num_label]).item() - n_false_neg
    batch_sz = tgt_batch.size(0)
    n_true_neg = batch_sz - (n_false_neg + n_false_pos + n_true_pos)
    sens = 0
    spec = 0
    if n_true_pos != 0:
        try:
            sens = n_true_pos / (n_true_pos + n_false_neg)
        except:
            print('Sensitivity division failed, n_true_pos: ' + str(n_true_pos))
    if n_true_neg != 0:
        try:
            spec = n_true_neg / (n_true_neg + n_false_pos)
        except:
            print('Specificity division failed, n_true_neg: ' + str(n_true_neg))
    return sens, spec
    
bin_pred = torch.tensor([[1, 0], [0, 1]])
tgt_batch = torch.tensor([[1, 0], [1, 0]])

sens1 = sensitivity_specificity(bin_pred, tgt_batch, 0, 2)
sens2 = sensitivity_specificity(bin_pred, tgt_batch, 1, 2)
print(sens1)
print(sens2)

bin_pred = torch.tensor([[1, 0], [0, 1]])
tgt_batch = torch.tensor([[1, 0], [1, 0]])

(0.5, 0)
(0, 0.5)


In [14]:
def evaluation(loader, n = 50):
    experimental.eval()
    eval_crit = nn.CrossEntropyLoss(weight=torch.tensor(train_weights).to(device))
    losses = []
    sensitivities = []
    specificities = []
    
    for idx, data in enumerate(loader, 0):
        src_batch = data[0]
        seq_lengths = data[1]
        tgt_batch = data[2]
        batch_size = src_batch.size(1)
        max_src_seq_length = src_batch.size(0)
        probs, hidden, _ = experimental.forward(src_batch, seq_lengths)
        pred = probs.squeeze(0)
        loss = eval_crit(pred, tgt_batch)
        losses.append(loss.item())
        bin_pred = np.greater(pred.cpu().detach(), 0.5)
        
        sens_arr = []
        spec_arr = []
        for i in range(n_tgt_class):
            sens, spec = sensitivity_specificity(bin_pred, tgt_batch.cpu().detach(), i, n_tgt_class)
            sens_arr.append(sens)
            spec_arr.append(spec)
        sensitivities.append(sens_arr)
        specificities.append(spec_arr)
        if idx >= n:
            break
    losses = np.mean(losses)
    print('Average loss: ' + str(losses))
    for i in range(n_tgt_class):
        sens_total = 0
        for arr in sensitivities:
            sens_total += arr[i]
        sens_avg = sens_total / len(sensitivities)
        spec_total = 0
        for arr in specificities:
            spec_total += arr[i]
        spec_avg = spec_total / len(specificities)
        print ('Class with label ' + str(i) + ' Sensitivity: ' + str(round(sens_avg, 3)) + '\tSpecificity: ' + str(round(spec_avg, 3)))
    return (losses, sensitivities, specificities)

testset = AudioDataset(dataset_direc + '/test_redo', csv_path, 'gammatone_', train = False)
testloader = DataLoader(dataset=testset, batch_size=16, shuffle=True, collate_fn=collate, pin_memory=False)
param_search = True
train_accs = []
test_accs = []

evaluate = True
if evaluate:
    print('Evaluating with training set', end='... ')
    train_acc = evaluation(trainloader)
    train_accs.append(train_acc)

    print('Evaluating with testing set', end='... ')
    test_acc = evaluation(testloader)
    test_accs.append(test_acc)

n_tgt_class: 2
class_dict: {'backward': 0, 'bed': 1, 'bird': 2, 'cat': 3, 'dog': 4, 'down': 5, 'eight': 6, 'five': 7, 'follow': 8, 'forward': 9, 'four': 10, 'go': 11, 'happy': 12, 'house': 13, 'learn': 14, 'left': 15, 'marvin': 16, 'nine': 17, 'no': 18, 'off': 19, 'on': 20, 'one': 21, 'right': 22, 'seven': 23, 'sheila': 24, 'six': 25, 'stop': 26, 'three': 27, 'tree': 28, 'two': 29, 'up': 30, 'visual': 32, 'wow': 33, 'yes': 34, 'zero': 35}
Evaluating with training set... Average loss: 0.6986339359128593
Class with label 0 Sensitivity: 0.029	Specificity: 0.961
Class with label 1 Sensitivity: 0.871	Specificity: 0.118
Evaluating with testing set... Average loss: 0.7002094694337444
Class with label 0 Sensitivity: 0.036	Specificity: 0.988
Class with label 1 Sensitivity: 0.888	Specificity: 0.188


In [15]:
param_search = True
save_best = True
if param_search:
    unique_path_log = []
    smallest_test_loss = math.inf
    for n in range(10):
        
        print('Training with ' + str(1 + n) + ' epochs')
        full_train(1, max_batch_per_epoch = 500, verbose = True, unique_path_log = unique_path_log)
        
        print('Evaluating with training set', end='... ')
        train_acc = evaluation(trainloader)
        train_accs.append(train_acc)
        
        print('Evaluating with testing set', end='... ')
        test_acc = evaluation(testloader)
        test_accs.append(test_acc)
        
        if test_acc[0] < smallest_test_loss:
            smallest_test_loss = test_acc[0]
            if save_best:
                save_model(experimental, 'gru_experimental_best.pth', skip_dialogue = True)
        """
        sens = test_acc[1]
        spec = test_acc[2]
        print("Sensitivity array: " + str(sens))
        print("Specificity array: " + str(spec))
        """
        

Training with 1 epochs
Training batch: 0
Training batch: 100
Training batch: 200
Training batch: 300
Training batch: 400
Training batch: 500
epoch:  0  loss:  0.6279609099737553
number of unique paths encountered: 502
Evaluating with training set... Average loss: 0.6482402680460558
Class with label 0 Sensitivity: 0.541	Specificity: 0.673
Class with label 1 Sensitivity: 0.667	Specificity: 0.511
Evaluating with testing set... Average loss: 0.6546552348778629
Class with label 0 Sensitivity: 0.531	Specificity: 0.666
Class with label 1 Sensitivity: 0.666	Specificity: 0.537
Model saved to gru_experimental_best.pth.
Training with 2 epochs
Training batch: 0
Training batch: 100
Training batch: 200


KeyboardInterrupt: 

In [None]:
printMaskedIntervals(experimental, 'gru.weight_ih_l0')

In [None]:
import matplotlib.pyplot as plt

W = experimental.state_dict()['gru.weight_hh_l0'].cpu()
plt.imshow(W, interpolation='nearest', aspect='auto')
plt.show()
print(W)
print(torch.max(W))

In [None]:
W_hr = W[0:hidden_size, :]
W_hz = W[hidden_size:2*hidden_size, :]
W_hn = W[2*hidden_size:3*hidden_size, :]

print(W_hr.size())
print(W_hz.size())
print(W_hn.size())

plt.imshow(W_hr, interpolation='nearest', aspect='auto')
plt.show()
plt.imshow(W_hz, interpolation='nearest', aspect='auto')
plt.show()
plt.imshow(W_hn, interpolation='nearest', aspect='auto')
plt.show()

In [None]:
classif_input_hr = W_hr[-n_tgt_class:,:]
classif_input_hz = W_hz[-n_tgt_class:,:]
classif_input_hn = W_hn[-n_tgt_class:,:]

plt.imshow(classif_input_hr, interpolation='nearest', aspect='auto')
plt.show()
print((torch.min(classif_input_hr), torch.max(classif_input_hr)))
plt.imshow(classif_input_hz, interpolation='nearest', aspect='auto')
plt.show()
print((torch.min(classif_input_hn), torch.max(classif_input_hn)))
plt.imshow(classif_input_hn, interpolation='nearest', aspect='auto')
plt.show()
print((torch.min(classif_input_hz), torch.max(classif_input_hz)))

In [None]:
import networkx as nx
hr_pos_threshold = 0.15
hr_neg_threshold = -0.25
hz_pos_threshold = 0.25
hz_neg_threshold = -0.25
hn_pos_threshold = 0.15
hn_neg_threshold = -0.15

g = nx.DiGraph()
node_colors = []
for i in range(hidden_size):
    g.add_node(i)
    node_color = 'grey'
    if i < input_bottleneck_size:
        node_color = 'green'
    if i >= hidden_size - n_tgt_class:
        node_color = 'red'
    node_colors.append(node_color)

hr_pos_edges = np.argwhere(W_hr > hr_pos_threshold)
hr_neg_edges = np.argwhere(W_hr < hr_neg_threshold)
hz_pos_edges = np.argwhere(W_hz > hz_pos_threshold)
hz_neg_edges = np.argwhere(W_hz < hz_neg_threshold)
hn_pos_edges = np.argwhere(W_hn > hn_pos_threshold)
hn_neg_edges = np.argwhere(W_hn < hn_neg_threshold)

def addEdges(g, edge_tensor):
    for i in range(edge_tensor.size(1)):
        node_from = edge_tensor[1,i].item()
        node_to = edge_tensor[0,i].item()
        #print((node_from, node_to))
        g.add_edge(node_from, node_to)
        
addEdges(g, hr_pos_edges)
addEdges(g, hr_neg_edges)
addEdges(g, hz_pos_edges)
addEdges(g, hz_neg_edges)
addEdges(g, hn_pos_edges)
addEdges(g, hn_neg_edges)
    
nx.draw_kamada_kawai(g, node_size = 70, node_color = node_colors)

In [None]:
def findStrongestInputs(W, idx, n):
    input_nodes = set()
    W_input = W[idx,:]
    strength_order = np.argsort(torch.abs(W_input))
    for i in range(n):
        input_nodes.add(strength_order[i].item())
    return input_nodes

def findStrongestOutputs(W, idx, n):
    output_nodes = set()
    W_output = W[:, idx]
    strength_order = np.argsort(torch.abs(W_output))
    for i in range(n):
        output_nodes.add(strength_order[i].item())
    return output_nodes

def findCombinedInputs(W_hr, W_hz, W_hn, idx, n):
    input_nodes = set()
    input_nodes.update(findStrongestInputs(W_hr, idx, n))
    input_nodes.update(findStrongestInputs(W_hz, idx, n))
    input_nodes.update(findStrongestInputs(W_hn, idx, n))
    return input_nodes

def findCombinedOutputs(W_hr, W_hz, W_hn, idx, n):
    output_nodes = set()
    output_nodes.update(findStrongestOutputs(W_hr, idx, n))
    output_nodes.update(findStrongestOutputs(W_hz, idx, n))
    output_nodes.update(findStrongestOutputs(W_hn, idx, n))
    return output_nodes

def addInputEdgeSet(g, idx, node_set):
    for node in node_set:
        g.add_edge(node, idx)

def addInputSection(g, tgt_node_idx, layer_dict, W_hr, W_hz, W_hn, n):
    input_node_set = findCombinedInputs(W_hr, W_hz, W_hn, tgt_node_idx, n)
    addInputEdgeSet(g, tgt_node_idx, input_node_set)
    for i in range(len(input_node_set)):
        node = list(input_node_set)[i]
        if node not in layer_dict:
            n_layer = max(layer_dict[tgt_node_idx][0] - 1, 0)
            n_section = layer_dict[tgt_node_idx][1] * 3 * n + i
            layer_dict[node] = (n_layer, n_section)
    return input_node_set

def addInputLayer(g, tgt_node_idxs, layer_dict, W_hr, W_hz, W_hn, n):
    input_layer = set()
    for idx in tgt_node_idxs:
        input_layer.update(addInputSection(g, idx, layer_dict, W_hr, W_hz, W_hn, n))
    return input_layer

def addRecursiveInputLayers(g, idx_list, W_hr, W_hz, W_hn, n, n_layers, layer_dict):
    new_layer = addInputLayer(g, idx_list, layer_dict, W_hr, W_hz, W_hn, n)
    if n_layers > 1:
        addRecursiveInputLayers(g, new_layer, W_hr, W_hz, W_hn, n, n_layers - 1, layer_dict)
    

#nx.draw_kamada_kawai(g, node_size = 70, node_color = node_colors)
#spring_pos = nx.spring_layout(g)
#print(spring_pos)

#print(findCombinedInputs(W_hr, W_hz, W_hn, 255, 5))

In [None]:
g = nx.DiGraph()
node_colors = []
for i in range(hidden_size):
    g.add_node(i)
    node_color = 'grey'
    if i < input_bottleneck_size:
        node_color = 'green'
    if i >= hidden_size - n_tgt_class:
        node_color = 'red'
    node_colors.append(node_color)

n_inter_layers = 2
layer_dict = {}
for i in range(input_bottleneck_size):
    layer_dict[i] = (0, 0)

for i in range(hidden_size-n_tgt_class, hidden_size):
    layer_dict[i] = (n_inter_layers + 1, i-hidden_size+n_tgt_class)

addRecursiveInputLayers(g, range(hidden_size-n_tgt_class, hidden_size), W_hr, W_hz, W_hn, 2, n_inter_layers,
                        layer_dict)

def drawNetwork(g, layer_dict):
    pos_x_max = 0.95
    pos_x_min = -0.95
    pos_y_max = 0.95
    pos_y_min = -0.95

    pos_dict = {}
    n_section_dict = {}
    for i in range(n_inter_layers + 2):
        sec_max = 0
        for j in range(hidden_size):
            if j in layer_dict:
                if layer_dict[j][0] == i and layer_dict[j][1] > sec_max:
                    sec_max = layer_dict[j][1]
        n_section_dict[i] = sec_max
        
    n_orphans = 0
    for i in range(hidden_size):
        if i in layer_dict:
            layer = layer_dict[i][0]
            section = layer_dict[i][1]
            x_pos = pos_x_min + layer * (pos_x_max - pos_x_min) / (n_inter_layers + 1)
            y_pos = pos_y_min + section * (pos_y_max - pos_y_min) / (n_section_dict[layer] + 1)
            pos_dict[i] = [x_pos, y_pos]
        else:
            n_orphans += 1
    for i in range(hidden_size):
        if i not in layer_dict:
            pos_dict[i] = [-0.9, 0]
            
    for i in range(input_bottleneck_size):
        output_set = findCombinedOutputs(W_hr, W_hz, W_hn, i, 10)
        for node in output_set:
            node_colors[node] = 'orange'
    
    nx.draw(g, pos = pos_dict, node_size = 50, node_color = node_colors)
    print("Orphans left over: " + str(n_orphans))

drawNetwork(g, layer_dict)

## Classification Visualizer

In [None]:
from scipy.io import wavfile
import gammatone.gtgram as gtg

def getAudioGammatone(filename):
    direc = '../cats_dogs/'

    print('Processing ' + filename)
    fs, data = wavfile.read(direc + filename)
    gram = gtg.gtgram(data, fs, 0.1, 0.01, encode_size, 20)
    np_M = np.array(gram)
    np_M = np.clip(np_M, 1.0, 1000000.0)
    log_M = np.log(np_M)
    max_M = max(log_M.flatten())
    M = log_M / max_M
    plt.imshow(M, interpolation='nearest', aspect='auto')
    plt.show()
    return (M, 0)
    
    print('File ' + filename + ' not found')
    return (np.zeros((1,1)), 0)

def plotClassification(filename):
    _, _, output = experimental.forward(src_batch, seq_lengths)
    print(output)

class DummyDataset(Dataset):
    def __init__(self, gt):
        self.len = 1
        self.gt = gt
                        
    def __len__(self):
        return self.len
        
    def __getitem__(self, idx):
        item = (torch.tensor(self.gt[0]).permute(1,0).float().to(device), gt[1])
        #print("fetching tensor of size " + str(item[0].size()) + " and class " + str(gt[1]))
        return item

filename = 'dog_barking_1.wav'
gt = getAudioGammatone(filename)
dummyset = DummyDataset(gt)
dummyloader = DataLoader(dataset=dummyset, batch_size=1, shuffle=False, collate_fn=collate, pin_memory=False)
gt, seq_len, _ = next(iter(dummyloader))
_, _, output = experimental.forward(gt, seq_len)
output = output.squeeze(1)
t_class = output[:, -n_tgt_class:]
t = range(len(t_class))
t_class = t_class.cpu().detach().numpy()
for i in range(n_tgt_class):
    label = 'animal'
    if i == 1:
        label = 'nature'
    if i == 2:
        label = 'people'
    if i == 3:
        label = 'percussive'
    if i == 4:
        label = 'machine'
    plt.plot(t, t_class[:,i], label = label)
plt.legend()
plt.show