In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchaudio
import torch
import numpy as np
import pandas as pd
import os
import pickle
import re
import torchaudio.transforms as T
import math
import librosa
import librosa.display
import matplotlib.patches as patches
from glob import glob

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(1)
def process_filename(filename):
    """Lower-case *filename* and drop every underscore, dot and ASCII letter.

    Whatever other characters the name contained (digits, dashes, ...) are
    returned in their original order.
    """
    unwanted = "_abcdefghijklmnopqrstuvwxyz."
    deletion_table = str.maketrans("", "", unwanted)
    return filename.lower().translate(deletion_table)

def get_info_from_fname(filename):
    """Return ``<YYYY>_<MM>_<DD>/<filename>`` for a date-prefixed file name.

    The first eight characters of *filename* are split 4/2/2 and joined with
    underscores to form the sub-directory part of the result.
    """
    date_parts = (filename[:4], filename[4:6], filename[6:8])
    return '_'.join(date_parts) + '/' + filename
    
def get_whitelisted_filepaths(new_filename_list): #NOT IN USE
    """Collect the on-disk paths of the given file names across zone folders.

    Every name is probed under ``/project/graziul/data/Zone<i>/`` for ``i``
    from 0 to 14 (in that order); each candidate that exists is kept, so a
    name can contribute several paths or none.
    """
    zone_root = '/project/graziul/data/Zone'
    return [
        candidate
        for fname in new_filename_list
        for candidate in (
            zone_root + str(zone) + '/' + get_info_from_fname(fname)
            for zone in range(15)
        )
        if os.path.exists(candidate)
    ]

def is_in_list(my_list,elt):
    """Return 1 if any item of *my_list* compares equal to *elt*, else 0."""
    found = any(elt == candidate for candidate in my_list)
    return 1 if found else 0

def divide_audio(datafile, div_size = 30):
    """Split a ``(1, feature_length, time)`` array into equal clips along time.

    The previous implementation reshaped the buffer directly to
    ``(div_size, feature_length, time // div_size)``.  Because the time axis
    is the innermost one, that does not cut the recording into consecutive
    time segments: it spreads whole feature rows over the output clips.
    The array is now split on the time axis first and the clip axis is then
    moved to the front, so ``result[k, f, t] == datafile[0, f, k*clip_len + t]``.

    Args:
        datafile: array-like of shape ``(1, feature_length, time)``.
        div_size: number of clips to produce (with the default of 30, a
            30-minute recording yields one-minute clips).

    Returns:
        ``numpy.ndarray`` of shape ``(div_size, feature_length, time // div_size)``.

    Raises:
        ValueError: if the input is not ``(1, F, T)`` or ``T`` is not a
            multiple of ``div_size``.
    """
    features = np.asarray(datafile)
    if features.ndim != 3 or features.shape[0] != 1:
        raise ValueError(
            "divide_audio expects an array of shape (1, feature_length, time), "
            "got %s" % (features.shape,)
        )
    _, feature_length, n_time = features.shape
    if div_size <= 0 or n_time % div_size != 0:
        raise ValueError(
            "time axis of length %d cannot be split into %d equal clips"
            % (n_time, div_size)
        )
    clip_len = n_time // div_size
    # (F, T) -> (F, clips, clip_len) keeps each clip's frames contiguous in time;
    # the transpose then puts the clip index first.
    return features[0].reshape(feature_length, div_size, clip_len).transpose(1, 0, 2)

class audio_file():
    """One audio recording together with its MFCC features and frame labels.

    Typical use, as in the script that follows the class: ``get_slices`` ->
    ``get_split_mfcc`` -> ``get_split_frames`` -> ``get_split_labels`` (or the
    non-split variants ``get_mfcc`` / ``get_frames`` / ``get_labels``).

    Attributes:
        name: key into the VAD dictionary; also the audio path when
            ``flag == 1``.
        flag: ``1`` selects the flat ``vad_dict[name]['nonsilent_slices']``
            layout and uses ``name`` as the path; any other value selects the
            ``['pydub'][-24]`` layout and prefixes ``LEGACY_DIR``.
        n_clips: number of equal pieces the recording is split into.
    """

    # MFCC front-end settings shared by get_mfcc() and get_split_mfcc().
    # HOP_LENGTH is also the samples-per-frame stride used for labelling.
    # NOTE(review): the original code additionally declared win_length = 551
    # but never passed it to the transform; it is left out so the computed
    # features stay exactly as before.
    N_FFT = 2048
    HOP_LENGTH = 220
    N_MELS = 40
    N_MFCC = 40
    # Recordings are truncated to this many seconds after loading.
    MAX_DURATION_S = 1800
    # Directory prepended to ``name`` for instances created with new_flag != 1.
    LEGACY_DIR = '/project/graziul/data/Zone1/2018_08_04/'

    def __init__(self, name,new_flag = 1):
        self.name = name
        self.vad_slices = None
        self.frames = None
        self.frames_labels = None
        self.mfcc = None
        # Populated by get_mfcc() / get_split_mfcc().
        self.waveform = None
        self.sample_rate = None
        self.n_clips = 30
        self.flag = new_flag

    def get_slices(self, vad_dict):
        """Look up this file's non-silent slices in *vad_dict* and store them.

        The slices are later read as ``(start, end)`` pairs in milliseconds.
        """
        if self.flag == 1:
            self.vad_slices = vad_dict[self.name]['nonsilent_slices']
        else:
            self.vad_slices = vad_dict[self.name]['pydub'][-24]['nonsilent_slices']
        return self.vad_slices

    def _count_speech_samples(self, n_frames):
        """Return, per frame, how many non-silent samples were credited to it.

        A sample ``i`` belongs to hop ``n = i // HOP_LENGTH`` with offset
        ``j = i % HOP_LENGTH``.  It is credited to frames ``n`` and ``n - 1``,
        and also to ``n - 2`` when ``j <= HOP_LENGTH // 2``.  Those are the
        only rules the original per-sample loop could ever reach (its
        ``j >= 221`` branches were dead because ``j < 220``).

        Differences from the original loop, all deliberate:
          * neighbours with a negative index are skipped instead of wrapping
            round to the end of the array (which wrongly marked the last
            frames as speech whenever a slice touched the first two hops);
          * hops past the end are clamped to the last frame in every caller
            (previously only get_split_frames did so; get_frames raised
            IndexError);
          * counts are accumulated per hop with NumPy rather than one Python
            iteration per audio sample; the resulting counts are identical.
        """
        hop = self.HOP_LENGTH
        half = hop // 2
        counts = np.zeros(n_frames)
        if n_frames == 0:
            return counts
        ms_2_sample = self.sample_rate/1000

        for v in self.vad_slices:
            start = max(math.floor(v[0]*ms_2_sample), 0)
            end = math.ceil(v[1]*ms_2_sample)
            if end <= start:
                continue
            hop_idx = np.arange(start // hop, (end - 1) // hop + 1)
            hop_start = hop_idx * hop
            first_sample = np.maximum(hop_start, start)
            # Samples of the slice that fall inside each hop ...
            n_all = np.minimum(hop_start + hop, end) - first_sample
            # ... and those among them with offset j <= half.
            n_low = np.clip(np.minimum(hop_start + half + 1, end) - first_sample, 0, None)
            target = np.minimum(hop_idx, n_frames - 1)
            for offset, weight in ((0, n_all), (1, n_all), (2, n_low)):
                idx = target - offset
                keep = idx >= 0
                # add.at accumulates correctly when clamping repeats an index.
                np.add.at(counts, idx[keep], weight[keep])
        return counts

    def get_frames(self):
        """Compute per-frame speech counts for the un-split MFCC.

        Requires ``mfcc``, ``sample_rate`` and ``vad_slices`` to be set.
        Returns a 1-D array with ``mfcc.shape[2]`` entries.
        """
        self.frames = self._count_speech_samples(self.mfcc.shape[2])
        return self.frames

    def get_split_frames(self):
        """Compute per-frame speech counts for the split MFCC.

        Counts are accumulated on one timeline of
        ``mfcc.shape[2] * n_clips`` frames and then cut into ``n_clips`` rows,
        giving an array of shape ``(n_clips, mfcc.shape[2])``.
        """
        self.clip_size = self.mfcc.shape[2]
        counts = self._count_speech_samples(self.clip_size*self.n_clips)
        self.frames = counts.reshape(self.n_clips, self.clip_size)
        return self.frames

    def get_labels(self):
        """Binarise ``frames``: 1.0 where the count is positive, else 0.0."""
        self.frames_labels = (self.frames > 0).astype(np.float64)
        return self.frames_labels

    def get_split_labels(self):
        """Binarise the split ``frames`` array, keeping its shape and dtype."""
        self.frames_labels = (self.frames > 0).astype(self.frames.dtype)
        return self.frames_labels

    def _audio_path(self):
        """Return the path to load, following the same flag rule as get_slices."""
        if self.flag == 1:
            return self.name
        return self.LEGACY_DIR + self.name

    def _load_waveform(self):
        """Load the audio into ``waveform``/``sample_rate``, clipped to MAX_DURATION_S."""
        self.waveform, self.sample_rate = torchaudio.load(self._audio_path())
        self.waveform = self.waveform[:,:self.MAX_DURATION_S*self.sample_rate] #Clip the file at 1800s

    def _mfcc_transform(self):
        """Build the MFCC transform for the current sample rate."""
        return T.MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=self.N_MFCC,
            melkwargs={
              'n_fft': self.N_FFT,
              'n_mels': self.N_MELS,
              'hop_length': self.HOP_LENGTH,
              'mel_scale': 'htk',
            }
        )

    def get_mfcc(self):
        """Load the recording and compute MFCCs over the whole (clipped) waveform."""
        self._load_waveform()
        self.mfcc = self._mfcc_transform()(self.waveform)
        return self.mfcc

    def get_split_mfcc(self):
        """Load the recording and compute MFCCs separately for ``n_clips`` pieces.

        The clipped waveform is cut into ``n_clips`` pieces of
        ``floor(num_samples / n_clips)`` samples; the per-piece MFCCs are
        concatenated along the first axis.
        """
        self._load_waveform()
        clip_size = math.floor(self.waveform.shape[1]/self.n_clips)
        transform = self._mfcc_transform()
        mfcc_list = [
            transform(self.waveform[:,i*clip_size:(i+1)*clip_size])
            for i in range(self.n_clips)
        ]
        self.mfcc = torch.cat(mfcc_list)
        return self.mfcc

    def plot_waveform_with_labels(self,i,clip_size):
        """Plot clip *i* of the waveform above that clip's frame labels.

        Expects the split pipeline to have run, so that ``frames_labels[i]``
        is the label row of clip *i*.
        """
        # One figure only: the former plt.figure() call left an empty extra figure open.
        fig,(ax1,ax2) = plt.subplots(2,1,figsize=(14,5))
        samples = self.waveform.squeeze().numpy()
        # sr is passed by keyword; recent librosa releases reject it positionally.
        librosa.display.waveshow(samples[i*clip_size:(i+1)*clip_size],sr=self.sample_rate,ax = ax1)
        ax2.plot(self.frames_labels[i])
        plt.show()
        return

    def get_plots(self):
        """Plot every clip with its labels, printing the clip index first."""
        # Same clip length as get_split_mfcc(), so waveform and labels stay
        # aligned even when the recording is shorter than MAX_DURATION_S.
        clip_size = math.floor(self.waveform.shape[1]/self.n_clips)
        for i in range(self.n_clips):
            print(i)
            self.plot_waveform_with_labels(i,clip_size)
        return
    



# Disabled one-off preprocessing script, kept as a bare string literal so it never runs.
# It reads the whitelisted-file CSV and the transcripts CSV, groups the transcripts by
# (zone, day, month, year, file), multiplies each group's start/end values by 1000 and
# truncates them to int, and pickles a
# {file path: {'nonsilent_slices': [(start, end), ...], 'units': 'milliseconds'}} dict
# to the same pkl_path that is loaded right after this block.
'''datapath = '/project/graziul/data/whitelisted_vad_files.csv'
dataframe = pd.read_csv(datapath, header=None)

transcripts_path = '/project/graziul/transcripts/transcripts2021_10_27.csv'
transcripts_df = pd.read_csv(transcripts_path)
df_groups = transcripts_df.groupby(['zone','day','month','year','file'])
#clean_transcripts_df_files = [process_filename(transcripts_file) for transcripts_file in list(transcripts_df['file'])]
#print(clean_transcripts_df_files)

new_filename_list = []
chars_to_remove = ['_','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','.']
for idx,elt in enumerate(list(dataframe[0])):
    my_elt = process_filename(elt)
    new_filename_list.append(my_elt[:-1] + '.mp3')
    #new_filename_list.append(my_elt[:-1])
new_filename_list = list(set(new_filename_list))

to_ms = 1000
fname_list = []
transcripts_list = []
for state,frame in df_groups: 
    #print(frame)
    info_list = list(state)
    
    zone = info_list[0]
    day = info_list[1]
    month = info_list[2]
    year = info_list[3]
    if(len(str(day)) > 1):
        str_day = str(day)
    else:
        str_day = '0' + str(day)
    date = str(year) + '_0' + str(month) + '_' + str_day
    filename = process_filename(info_list[4]) + '.mp3'
    fpath = '/project/graziul/data/' + zone + '/' + date + '/' + filename
    print(list(state))
    print(fpath)
    #print(filename)
    if(is_in_list(new_filename_list, filename) == 1):
        #print(fpath)
        if(os.path.exists(fpath)):
            #print(frame)
            if(is_in_list(fname_list,fpath) == 0):
                print(list(set(list(frame['transcriber']))))
                #print(frame.head())
                start_times = list(frame['start'])
                end_times = list(frame['end'])
                start_samples = [(int)(to_ms*start_time) for start_time in start_times]  #Convert to milliseconds
                end_samples = [(int)(to_ms*end_time) for end_time in end_times]  #Convert to milliseconds
                #transcripts = list(zip(start_times,end_times))
                transcripts = list(zip(start_samples,end_samples))
                transcripts_list.append(transcripts)
                fname_list.append(fpath)
            else:
                pass

new_vad_dict = {}
for i,fname in enumerate(fname_list):
    new_vad_dict[fname] = {'nonsilent_slices': transcripts_list[i], 'units':'milliseconds'}

pkl_path = '/project/graziul/ra/ajays/whitelisted_vad_dict.pkl' 
file = open(pkl_path,'wb')
pickle.dump(new_vad_dict,file)
file.close()'''

# Load the VAD dictionary (audio path -> non-silent slices) and turn every
# recording into split MFCC features plus matching per-frame 0/1 labels.
pkl_path = '/project/graziul/ra/ajays/whitelisted_vad_dict.pkl'
#pkl_path = '/project/graziul/data/Zone1/2018_08_04/2018_08_04vad_dict.pkl'
# NOTE: pickle.load can execute arbitrary code while unpickling; only point
# pkl_path at files produced by a trusted source.
# The context manager closes the handle even if unpickling raises.
with open(pkl_path,'rb') as file:
    vad_dict = pickle.load(file)

input_list = []
labels_list = []

for idx,key in enumerate(vad_dict):
    print(idx)
    a = audio_file(key)
    a.get_slices(vad_dict)
    input_list.append(a.get_split_mfcc())
    # get_split_frames() must run before get_split_labels(): the labels are
    # derived from the frame counts it stores on the instance.
    a.get_split_frames()
    labels_list.append(a.get_split_labels())
    #a.get_plots()
# Stack all clips of all files along the first axis, then swap the last two
# axes of the features so the model receives (clip, frame, coefficient).
input_list = torch.cat(input_list)
input_list = torch.transpose(input_list,1,2)
labels_list = torch.from_numpy(np.concatenate(labels_list,axis = 0)).float()
print(input_list.size())
print(labels_list.size())

In [None]:
class StackedLSTM(nn.Module):
    """Two stacked LSTMs with a temporal-attention branch between them.

    Pipeline of ``forward``: ``lstm1`` -> per-frame max/mean/std statistics
    -> 1-D convolution stack -> sigmoid mask added to the ``lstm1`` output ->
    ``lstm2`` -> two linear layers.  The result is one raw score per frame;
    no sigmoid is applied to it.

    Input is ``(batch, sequence_length, 40)``; output is
    ``(batch, sequence_length)``.

    Changes from the previous version:
      * the random initial LSTM states are registered as non-persistent
        buffers, so they follow ``.to(device)`` without altering
        ``state_dict()``; values and RNG order are unchanged;
      * a batch whose size differs from ``batch_size`` no longer crashes: it
        uses PyTorch's default zero initial state instead;
      * the ``BatchNorm1d`` layers declare the channel counts they actually
        receive (3, 5, 5) instead of 64 (they have no parameters or running
        statistics, so results are unchanged);
      * only the trailing singleton axis is squeezed, so batches or sequences
        of length 1 keep their shape;
      * the per-call debug ``print`` of the output tensor is removed.
    """

    def __init__(self):
        super(StackedLSTM, self).__init__()
        self.input_dim1 = 40
        self.input_dim2 = 64
        self.hidden_dim = 64
        self.n_layers = 3
        self.batch_size = 2
        #(input is of format batch_size, sequence_length, num_features)
        #hidden states should be (num_layers, batch_size, hidden_length)
        state_shape = (self.n_layers, self.batch_size, self.hidden_dim)
        # Drawn in the original order so seeded runs get the same values.
        for state_name in ("hidden_state1", "cell_state1", "hidden_state2", "cell_state2"):
            self.register_buffer(state_name, torch.randn(*state_shape), persistent=False)
        self.lstm1 = nn.LSTM(input_size = self.input_dim1, hidden_size = self.hidden_dim, num_layers = self.n_layers, batch_first=True)
        self.lstm2 = nn.LSTM(input_size = self.input_dim2, hidden_size = self.hidden_dim, num_layers = self.n_layers, batch_first=True)
        self.lstm2_out = None
        self.hidden = None
        # Operates on (batch, 3 statistics, sequence_length) and reduces the
        # three statistic channels to a single channel.
        self.convolve1d = nn.Sequential(
            nn.Conv1d(3,3, kernel_size=11, padding=5),
            nn.BatchNorm1d(3, affine=False, track_running_stats=False),
            nn.ReLU(),
            nn.Conv1d(3,5, kernel_size=11, padding=5),
            nn.BatchNorm1d(5, affine=False, track_running_stats=False),
            nn.ReLU(),
            nn.Conv1d(5,5, kernel_size=11, padding=5),
            nn.BatchNorm1d(5, affine=False, track_running_stats=False),
            nn.ReLU(),
            nn.Conv1d(5,1, kernel_size=11, padding=5)
        )
        self.output_stack = nn.Sequential(
            nn.Linear(64, 64),
            nn.Linear(64, 1)
        )
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _initial_state(hidden, cell, batch):
        """Return the stored (h0, c0) when it fits *batch*, else None (zeros)."""
        if hidden.size(1) == batch:
            return (hidden, cell)
        return None

    def temp_attention(self, data):
        """Run ``lstm1`` and summarise each frame of its output.

        Returns ``(stats, H)`` where ``H`` is the ``lstm1`` output
        ``(batch, seq, hidden)`` and ``stats`` is ``(batch, 3, seq)`` holding
        the max, mean and standard deviation of ``H`` over the hidden axis.
        """
        state = self._initial_state(self.hidden_state1, self.cell_state1, data.size(0))
        H, hidden = self.lstm1(data, state)
        H_maxtemp = torch.unsqueeze(torch.max(H, -1).values,2)
        H_avgtemp = torch.unsqueeze(torch.mean(H, -1),2)
        H_stdtemp = torch.unsqueeze(torch.std(H, -1),2)
        H_concattemp = torch.cat([H_maxtemp, H_avgtemp,H_stdtemp], dim=2)
        H_concattemp = torch.transpose(H_concattemp, 1,2)
        return H_concattemp,H

    def convolve1(self, data):
        """Return the ``lstm1`` output plus a per-frame sigmoid attention mask."""
        H_concattemp,H = self.temp_attention(data)
        H_temp = self.convolve1d(H_concattemp)
        # "Expand/copy" the single mask channel to the hidden width of H.
        H_temp = H_temp.expand(-1,self.hidden_dim,-1)
        H_temp = self.sigmoid(H_temp)
        H_temp = torch.transpose(H_temp, 1, 2)
        # Merge H_temp and H by element wise summation
        return H + H_temp

    def forward(self, data):
        """Return one raw (pre-sigmoid) score per frame, shape ``(batch, seq)``."""
        input1 = self.convolve1(data)
        state = self._initial_state(self.hidden_state2, self.cell_state2, input1.size(0))
        lstm2_out, hidden = self.lstm2(input1, state)
        # Drop only the trailing size-1 feature axis produced by the last Linear.
        self.output = torch.squeeze(self.output_stack(lstm2_out), -1)
        return self.output
        

class ToyModel(nn.Module):
    """Minimal baseline: ``lstm1`` -> two linear layers -> sigmoid.

    Input is ``(batch, sequence_length, 40)``; output is a per-frame
    probability of shape ``(batch, sequence_length)``.

    ``lstm2`` and ``convolve1d`` are constructed (so the parameter set stays
    the same as before) but are not used by ``forward``.

    Changes from the previous version:
      * the random initial LSTM states are registered as non-persistent
        buffers, so they follow ``.to(device)`` without altering
        ``state_dict()``; values and RNG order are unchanged;
      * a batch whose size differs from ``batch_size`` no longer crashes: it
        uses PyTorch's default zero initial state instead;
      * the ``BatchNorm1d`` layers declare the channel counts of the
        convolutions they follow (3, 5, 5) instead of 64;
      * only the trailing singleton axis is squeezed, so batches or sequences
        of length 1 keep their shape.
    """

    def __init__(self):
        super(ToyModel, self).__init__()
        self.input_dim1 = 40
        self.input_dim2 = 64
        self.hidden_dim = 64
        self.n_layers = 3
        self.batch_size = 2
        #(input is of format batch_size, sequence_length, num_features)
        #hidden states should be (num_layers, batch_size, hidden_length)
        state_shape = (self.n_layers, self.batch_size, self.hidden_dim)
        # Drawn in the original order so seeded runs get the same values.
        for state_name in ("hidden_state1", "cell_state1", "hidden_state2", "cell_state2"):
            self.register_buffer(state_name, torch.randn(*state_shape), persistent=False)
        self.lstm1 = nn.LSTM(input_size = self.input_dim1, hidden_size = self.hidden_dim, num_layers = self.n_layers, batch_first=True)
        self.lstm2 = nn.LSTM(input_size = self.input_dim2, hidden_size = self.hidden_dim, num_layers = self.n_layers, batch_first=True)
        self.lstm2_out = None
        self.hidden = None
        self.convolve1d = nn.Sequential(
            nn.Conv1d(3,3, kernel_size=11, padding=5),
            nn.BatchNorm1d(3, affine=False, track_running_stats=False),
            nn.ReLU(),
            nn.Conv1d(3,5, kernel_size=11, padding=5),
            nn.BatchNorm1d(5, affine=False, track_running_stats=False),
            nn.ReLU(),
            nn.Conv1d(5,5, kernel_size=11, padding=5),
            nn.BatchNorm1d(5, affine=False, track_running_stats=False),
            nn.ReLU(),
            nn.Conv1d(5,1, kernel_size=11, padding=5)
        )
        self.output_stack = nn.Sequential(
            nn.Linear(64, 128),
            nn.Linear(128, 1)
        )
        self.sigmoid = nn.Sigmoid()


    def forward(self, data):
        """Return per-frame probabilities in [0, 1], shape ``(batch, seq)``."""
        # Use the stored random state only when it matches this batch size;
        # passing None lets nn.LSTM start from zeros.
        if self.hidden_state1.size(1) == data.size(0):
            state = (self.hidden_state1, self.cell_state1)
        else:
            state = None
        out1,_ = self.lstm1(data, state)
        out2 = self.sigmoid(self.output_stack(out1))
        # Drop only the trailing size-1 feature axis produced by the last Linear.
        return torch.squeeze(out2, -1)

class WeightedFocalLoss(nn.Module):
    """Alpha-weighted binary focal loss computed from raw logits.

    Per element: ``at * (1 - pt) ** gamma * bce`` where ``bce`` is the
    binary cross-entropy with logits, ``pt = exp(-bce)`` and ``at`` is
    ``alpha`` for target 0 and ``1 - alpha`` for target 1.  The mean over all
    elements is returned.

    Args:
        alpha: weight given to target-0 elements (target-1 elements get
            ``1 - alpha``).
        gamma: focusing exponent; 0 disables the focal modulation.
    """

    def __init__(self, alpha=.25, gamma=1):
        super(WeightedFocalLoss, self).__init__()
        # A (non-persistent) buffer so the weights follow .to(device) without
        # adding a key to state_dict().
        self.register_buffer("alpha", torch.tensor([alpha, 1-alpha]), persistent=False)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss of logits *inputs* against 0/1 *targets*.

        *inputs* and *targets* must have the same shape (any number of
        dimensions).
        """
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Index (rather than gather on a flattened view) so the weights keep
        # the shape of the targets; the flattened version only broadcast
        # correctly for 1-D inputs.
        at = self.alpha[targets.long()]
        pt = torch.exp(-BCE_loss)
        F_loss = at*(1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

class FocalLoss(nn.modules.loss._WeightedLoss):
    """Binary focal loss on probabilities (inputs already passed through a sigmoid).

    Per element: ``(1 - pt) ** gamma * bce`` with ``bce`` the binary
    cross-entropy and ``pt = exp(-bce)`` (the probability assigned to the
    true class), optionally multiplied by ``weight``.

    The previous body ignored ``gamma``, ``weight`` and ``reduction`` and
    returned the plain mean BCE, printing both tensors on every call.  All
    three arguments are now honoured and the debug prints are gone.
    ``gamma=0`` with the default arguments reproduces the old return value.

    Args:
        weight: optional tensor broadcastable to the input shape; acts as the
            alpha term that balances the classes.
        gamma: focusing exponent; larger values down-weight easy examples.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.
    """

    def __init__(self, weight=None, gamma=1,reduction='mean'):
        # _WeightedLoss stores `weight` (as a buffer) and `reduction`.
        super(FocalLoss, self).__init__(weight,reduction=reduction)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the focal loss of probabilities *inputs* against *targets*.

        Raises:
            ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
        """
        bce = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        if self.weight is not None:
            focal = focal * self.weight
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'none':
            return focal
        raise ValueError("unsupported reduction: %r" % (self.reduction,))

# Pick a device.
# NOTE(review): `device` is not used by the active code below -- the model built
# on the next lines is never moved to it (only the commented-out StackedLSTM
# line calls .to(device)) and the batches are not moved either.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
#model = StackedLSTM().to(device)
model = ToyModel()
loss_fn = FocalLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Training: a fixed number of optimizer steps, each on the next consecutive
# slice of `batch_size` rows; `idx` wraps back to 0 after `num_samples` batches.
training_steps = 30
batch_size = model.batch_size
# Number of whole batches; a trailing partial batch is never visited.
num_samples = input_list.size()[0]//batch_size
idx = 0
# NOTE(review): `flag` is assigned here but not read anywhere in the visible code.
flag = 0
for step in range(training_steps):
    input_batch = input_list[idx*batch_size:(idx+1)*batch_size]
    labels_batch = labels_list[idx*batch_size:(idx+1)*batch_size]
    idx = (idx+1)%num_samples
    print(step)
    optimizer.zero_grad()
    output_hat = model(input_batch)
    #print(output_hat)
    loss = loss_fn(output_hat, labels_batch)
    loss.backward()
    #for param in model.parameters():
    #    print(param.grad)
    print(loss)
    optimizer.step()
    
# Inference: run the trained model once over every whole batch (gradients
# disabled) and concatenate the per-batch outputs along the first axis.
output_list = []
idx = 0
num_samples = labels_list.size()[0]//batch_size
with torch.no_grad():
    while(idx < num_samples):
        print(idx)
        input_batch = input_list[idx*batch_size:(idx+1)*batch_size]
        # labels_batch is sliced for symmetry with training but not used in this loop.
        labels_batch = labels_list[idx*batch_size:(idx+1)*batch_size]
        idx = idx+1
        output_hat = model(input_batch)
        #print(output_hat)
        #for param in model.parameters():
        #    print(param.grad)
        output_list.append(output_hat)
    output_list = torch.cat(output_list, dim = 0)
    
def get_frame_error_rate(output_hat, labels):
    """Return the frame error rate of every row, in percent.

    For each row index of *labels*, the matching rows of *output_hat* and
    *labels* are added, reduced modulo 2 and averaged; with 0/1 entries that
    mean is the fraction of positions where the two rows disagree.

    Returns:
        A list with one scalar tensor per row of *labels*.
    """
    row_count = labels.size()[0]
    return [
        torch.mean(torch.add(output_hat[row], labels[row]) % 2).data * 100
        for row in range(row_count)
    ]

def test_frame_error_rate(output_hat, labels):
    num_samples = labels.size()[0]
    s_length = labels.size()[1]
    fer_arr = []
    sum = 0
    for i in range(num_samples):
        curr_output = output_hat[i]
        curr_label = labels[i]
        for j in range(s_length):
            if curr_output[j] == curr_label[j]:
                pass
            else:
                sum = sum+1
        fer_arr.append(torch.mean(torch.add(curr_output,curr_label)%2)*100)
    return sum

with torch.no_grad():
    # Threshold the predicted probabilities once and give BOTH metrics the
    # same 0/1 predictions.  The mismatch count was previously fed the raw
    # probabilities, which almost never equal a 0/1 label exactly, so it
    # reported nearly every frame as an error.
    rounded_output = torch.round(output_list)
    print("Frame error Rate :" + str(get_frame_error_rate(rounded_output,labels_list)))
    print(test_frame_error_rate(rounded_output, labels_list))