In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

if torch.cuda.is_available():
    device = "cuda:0"
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [2]:
# Network Parameters
epochs = 20
# train_size = 100000
# test_size = 50000
train_size = 1000
test_size = 500
lr = 5e-4
pitch_shift = 0
jitter = 0.
num_workers = 10
sequence = 3

# lvl1 convolutions are shared between regions
m = 128
k = 512              # lvl1 nodes
n_fft = 4096              # lvl1 receptive field
window = 16384 # total number of audio samples?
stride = 512
batch_size = 100

In [3]:
def create_filtersv2(n_fft, freq_bins=None, low=50,high=6000, sr=44100, freq_scale='linear', mode="fft"):
    if freq_bins==None:
        freq_bins = n_fft//2+1
    
    s = torch.arange(0, n_fft, 1.)
    wsin = torch.empty((freq_bins,1,n_fft))
    wcos = torch.empty((freq_bins,1,n_fft))
    start_freq = low
    end_freq = high
    

    # num_cycles = start_freq*d/44000.
    # scaling_ind = np.log(end_freq/start_freq)/k
    
    if mode=="fft":
        window_mask = 1
    elif mode=="stft":
        window_mask = 0.5-0.5*torch.cos(2*math.pi*s/(n_fft)) # same as hann(n_fft, sym=False)
    else:
        raise Exception("Unknown mode, please chooes either \"stft\" or \"fft\"")
        
    if freq_scale == 'linear':
        start_bin = start_freq*n_fft/sr
        scaling_ind = (end_freq/start_freq)/freq_bins
        for k in range(freq_bins): # Only half of the bins contain useful info
            wsin[k,0,:] = window_mask*torch.sin(2*math.pi*(k*scaling_ind*start_bin)*s/n_fft)
            wcos[k,0,:] = window_mask*torch.cos(2*math.pi*(k*scaling_ind*start_bin)*s/n_fft)
    elif freq_scale == 'log':
        start_bin = start_freq*n_fft/sr
        scaling_ind = np.log(end_freq/start_freq)/freq_bins
        for k in range(freq_bins): # Only half of the bins contain useful info
            wsin[k,0,:] = window_mask*torch.sin(2*math.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
            wcos[k,0,:] = window_mask*torch.cos(2*math.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)    
    else:
        print("Please select the correct frequency scale, 'linear' or 'log'")
    
    return wsin,wcos


In [6]:
class Model(nn.Module):
    def __init__(self, avg=.9998):
        super(Model, self).__init__()
        # Create filter windows
        wsin, wcos = create_filtersv2(n_fft,k, low=50, high=6000,
                                      mode="stft", freq_scale='log')

        self.wsin = torch.Tensor(wsin)
        self.wcos = torch.Tensor(wcos)
            
        # Creating Layers
#         self.linear = torch.nn.Linear(regions*k, k,bias=False)
#         self.linear_output = torch.nn.Linear(k,m, bias=False)
#         wscale = 10e-5
#         torch.nn.init.normal_(self.linear.weight, std=1e-4) # initialize
#         torch.nn.init.normal_(self.linear_output.weight, std=1e-4)
#         torch.nn.init.zeros_(self.linear.weight)
#         torch.nn.init.zeros_(self.linear_output.weight)
        k_out = 128
        k2_out = 256
        stride1 = (2,1)
        self.CNN_freq = nn.Conv2d(1,k_out,
                                kernel_size=(128,1),stride=(2,1))
        self.CNN_time = nn.Conv2d(k_out,k2_out,
                                kernel_size=(1,25),stride=(1,1))
        self.Linear = nn.Linear(k2_out*1*193, m)
        self.activation = nn.ReLU()
        self.lstm = nn.LSTM(input_size=m, hidden_size=m)
        
        torch.nn.init.normal_(self.CNN_freq.weight, std=1e-4)
        torch.nn.init.normal_(self.CNN_time.weight, std=1e-4)
        torch.nn.init.normal_(self.Linear.weight, std=1e-4)
        
    def _get_conv_output(self, shape):
        bs = 1
        x = torch.tensor(torch.rand(bs, *shape))
        output_feat = self._forward_features(x)
        n_size = output_feat.data.view(bs, -1).size(1)
        return n_size
        
    def forward(self,x):
        zx = F.conv1d(x[:,None,:], self.wsin, stride=stride).pow(2) \
           + F.conv1d(x[:,None,:], self.wcos, stride=stride).pow(2)
        zx = torch.log(zx + 10e-8) # Log Magnitude Spectrogram
        z2 = self.CNN_freq(zx.unsqueeze(1)) # Make channel as 1 (N,C,H,W)
        z2 = F.relu(z2)
        z3 = self.CNN_time(z2)
        z3 = F.relu(z3)
        y = self.Linear(self.activation(torch.flatten(z3,1)))
        print(y.view(batch_size, sequence, m).transpose(0,1).shape)
        y, _ = self.lstm(y.view(batch_size, sequence, m).transpose(0,1))
        return y
    

In [7]:
model = Model()
model.to(device)

Model(
  (CNN_freq): Conv2d(1, 128, kernel_size=(128, 1), stride=(2, 1))
  (CNN_time): Conv2d(128, 256, kernel_size=(1, 25), stride=(1, 1))
  (Linear): Linear(in_features=49408, out_features=128, bias=True)
  (activation): ReLU()
  (lstm): LSTM(128, 128)
)

In [15]:
from torchsummary import summary

In [17]:
summary(model, (window,), batch_size=batch_size)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [6]:
# x = torch.rand([3, 100, 128])
batch_size=10
sequence=2
x = torch.rand(batch_size*sequence,window)

In [8]:
model(x)

NameError: name 'x' is not defined

In [10]:
test = torch.randn([2, 10, 128])

In [11]:
output, hidden = model.lstm(test)

In [12]:
output

tensor([[[ 9.2887e-02,  1.2167e-02,  2.4735e-01,  ...,  9.7279e-03,
          -8.9668e-02, -1.3971e-01],
         [ 7.5475e-02,  1.7518e-02,  9.7688e-02,  ..., -2.0750e-02,
           7.7046e-03, -1.7113e-01],
         [-6.0126e-02,  1.4890e-01, -4.2717e-02,  ..., -2.1198e-02,
          -7.6942e-02, -2.2452e-01],
         ...,
         [ 4.0418e-03,  6.9984e-02, -5.2068e-02,  ..., -3.2988e-02,
           3.3356e-02, -2.2957e-02],
         [-4.0318e-02,  6.6363e-02, -2.2246e-02,  ..., -6.1911e-02,
          -2.2948e-01, -3.1753e-01],
         [ 7.0087e-02,  8.1119e-03, -3.8013e-02,  ...,  2.0789e-01,
           1.0536e-01,  3.6563e-02]],

        [[ 2.3396e-01,  1.5800e-01,  1.5966e-01,  ..., -9.7576e-03,
           2.4485e-02, -8.2572e-02],
         [-2.8950e-02,  1.2393e-01,  5.8415e-02,  ..., -1.8619e-01,
           1.4774e-01,  3.2130e-02],
         [ 7.1430e-02,  2.0507e-01,  6.8264e-02,  ...,  1.5714e-02,
           1.7735e-04, -1.0748e-01],
         ...,
         [-3.2890e-02,  3