In [1]:
import sys
from collections import defaultdict

from scipy.io import wavfile
import numpy as np
from numpy.lib.stride_tricks import as_strided
import tensorflow as tf
from tensorflow.keras.layers import Input, Reshape, Conv2D, BatchNormalization, Conv1D
from tensorflow.keras.layers import MaxPool2D, Dropout, Permute, Flatten, Dense, MaxPool1D
from tensorflow.keras.models import Model
from shutil import copyfile
import pyhocon
import os
import matplotlib.pyplot as plt
import pandas as pd
from pydub import AudioSegment
from resampy import resample
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import pickle
from tqdm import tqdm
import math
import h5py
import os
import faiss
from sklearn.model_selection import LeaveOneOut
from scipy import stats
from sklearn.neighbors import KNeighborsClassifier
import tensorflow as tf
from sklearn.model_selection import train_test_split
import pickle



In [2]:
import numpy as np
import faiss


class FaissKNeighbors:
    """Majority-vote k-nearest-neighbour classifier on top of a FAISS flat L2 index.

    Samples are referred to by integer ids. Each id is translated through
    ``mbid_map`` and the feature vector is read from ``full_spd_lm_file`` under
    the key ``'<wd>_<mbid>'``.
    """

    def __init__(self, full_spd_lm_file, mbid_map,wd, k=5):
        """Store the feature source, the id -> mbid mapping, the key prefix and k."""
        self.index = None
        self.y = None
        self.k = k
        self.full_spd_lm_file = full_spd_lm_file
        self.mbid_map = mbid_map
        self.wd = wd

    def fit(self, X_ind, y):
        """Index the feature vectors of the samples ``X_ind`` with labels ``y``."""
        X = self.get_X(X_ind)
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X.astype(np.float32))
        # predict() fancy-indexes the labels with a 2-D neighbour-index array,
        # which plain lists (and pandas Series) do not support.
        self.y = np.asarray(y)

    def get_X(self, X_ind):
        """Return a ``(len(X_ind), n_features)`` matrix of the requested samples.

        The feature length is taken from the entry of ``mbid_map[0]``.
        """
        x2 = len(self.full_spd_lm_file['{}_{}'.format(self.wd, self.mbid_map[0])])
        X = np.zeros([len(X_ind), x2])
        for idx, x_id in enumerate(X_ind):
            X[idx] = self.full_spd_lm_file['{}_{}'.format(self.wd, self.mbid_map[x_id])]
        return X

    def predict(self, X_ind):
        """Return, for every sample in ``X_ind``, the most frequent label among its k neighbours.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.index is None:
            raise RuntimeError('FaissKNeighbors.predict() called before fit()')
        X = self.get_X(X_ind)
        _, indices = self.index.search(X.astype(np.float32), k=self.k)
        votes = self.y[indices]
        # np.bincount requires non-negative integer labels.
        predictions = np.array([np.argmax(np.bincount(x)) for x in votes])
        return predictions

class SPDKNN:
    """k-NN classifier over stored feature vectors using a Bhattacharyya-style distance.

    Samples are referred to by integer ids that are resolved through
    ``mbid_map`` to keys ``'<wd>_<mbid>'`` in ``full_spd_lm_file``.
    """

    def __init__(self, full_spd_lm_file, mbid_map,wd,k=5):
        """Store the feature source and build the underlying scikit-learn classifier."""
        self.y = None
        self.full_spd_lm_file = full_spd_lm_file
        self.mbid_map = mbid_map
        self.knn = KNeighborsClassifier(n_neighbors=k, algorithm='ball_tree', metric=self.bhatta)
        self.wd = wd

    def get_X(self, X_ind):
        """Return a ``(len(X_ind), n_features)`` matrix of the requested samples.

        The feature length is taken from the entry of ``mbid_map[0]``.
        """
        x2 = len(self.full_spd_lm_file['{}_{}'.format(self.wd, self.mbid_map[0])])
        X = np.zeros([len(X_ind), x2])
        for idx, x_id in enumerate(X_ind):
            X[idx] = self.full_spd_lm_file['{}_{}'.format(self.wd, self.mbid_map[x_id])]
        return X

    def bhatta(self, hist1,  hist2):
        """Distance ``sqrt(1 - sum(sqrt(h1*h2)) / sqrt(mean(h1)*mean(h2)*len(h1)*len(h2)))``.

        Returns 0 for identical histograms and 1 for histograms without overlap.
        If either histogram has zero mean the normaliser vanishes; such a pair
        is reported at the maximum distance 1.0 instead of dividing by zero.
        """
        h1_ = np.mean(hist1)
        h2_ = np.mean(hist2)
        norm = math.sqrt(h1_*h2_*len(hist1)*len(hist2))
        if norm == 0:
            return 1.0

        score = np.sum(np.sqrt(np.multiply(hist1, hist2)))
        # Rounding can push the radicand a few ULPs below zero for (nearly)
        # identical histograms, which would make math.sqrt raise ValueError.
        return math.sqrt(max(0.0, 1 - ( 1 / norm ) * score))

    def fit(self, X_ind, y):
        """Fit the classifier on the samples ``X_ind`` with labels ``y``."""
        X = self.get_X(X_ind)
        self.knn.fit(X,y)

    def predict(self, X_ind):
        """Return class probabilities (``predict_proba``) for the samples ``X_ind``."""
        X = self.get_X(X_ind)
        return self.knn.predict_proba(X)

In [2]:
def get_pitch_data(pitch_file_path):
    """Read a tab-separated pitch file and return its non-zero second-column values.

    NOTE(review): ``read_csv`` is used with its default header handling, so the
    first line of the file is consumed as column names — confirm the pitch
    files really start with a header row.
    """
    table = pd.read_csv(pitch_file_path, sep='\t')
    second_column = table.values[:, 1]
    non_zero = [value for value in second_column if value != 0]
    return np.array(non_zero)

In [3]:
def fix_paths(path, add_mp3=False):
    """Return ``path`` with every ``&``, ``:`` and ``'`` replaced by an underscore.

    ``add_mp3`` is accepted but not used.
    """
    for unsafe_char in ('&', ':', '\''):
        path = path.replace(unsafe_char, '_')
    return path

def copy_pitch_file(old_path, mbid, lm_file):
    """Load the pitch track that belongs to ``old_path`` and store it in ``lm_file[mbid]``.

    The source file is found by replacing ``'audio'`` with ``'features'`` in
    ``old_path`` and turning the result into a Windows extended-length
    (``\\\\?\\``) path below the current working directory. Returns ``mbid``.
    """
    # abspath of the literal string "__file__" resolves against the current
    # working directory, so this is the directory the notebook runs in.
    base_dir = os.path.dirname(os.path.abspath("__file__"))

    source_rel = old_path.replace('audio', 'features')
    target_rel = 'data\\pitches\\' + mbid +'.pitch'
    source_abs = '\\\\?\\' + os.path.join(base_dir, source_rel).replace('/', '\\')
    # Destination of the (currently disabled) file copy; kept for that step only.
    target_abs = '\\\\?\\' + os.path.join(base_dir, target_rel).replace('/', '\\')
    lm_file[mbid] = get_pitch_data(source_abs)

    return mbid
    
def fetch_tonic(data, data_path):
    """Attach tonic information to ``data`` and point its ``path`` column at pitch files.

    For every entry of ``data['path']`` the sibling ``.tonic`` and
    ``.tonicFine`` files are looked up under the matching ``/features/``
    directory (as Windows extended-length paths below the current working
    directory) and their first line is read.

    Args:
        data: DataFrame with a ``path`` column; it is modified in place.
        data_path: accepted for interface compatibility; not used.

    Returns:
        The same DataFrame with ``'.pitch'`` appended to ``path`` and the new
        columns ``tonic`` and ``tonic_fine``. A missing ``.tonic`` file is
        printed and recorded as ``-1``; a missing ``.tonicFine`` file falls
        back to the tonic value of the same row.
    """
    # abspath of the literal string "__file__" resolves against the current
    # working directory, i.e. the directory the notebook runs in.
    current_path = os.path.dirname(os.path.abspath("__file__"))

    def _feature_file(feature_path, extension):
        # Build the extended-length Windows path of one feature file.
        relative = feature_path.replace('/', '\\') + extension
        return '\\\\?\\' + os.path.join(current_path, relative)

    def _first_line(file_path):
        # Return the stripped first line of a text file.
        with open(file_path, 'r') as f:
            return f.readline().strip()

    paths = data['path']
    tonic_list = []
    tonic_fine_list = []
    for path in tqdm(paths, total=len(paths)):
        feature_path = fix_paths(path).replace('/audio/', '/features/')
        tonic_path = _feature_file(feature_path, '.tonic')
        tonic_fine_path = _feature_file(feature_path, '.tonicFine')

        if os.path.exists(tonic_path):
            tonic_list.append(_first_line(tonic_path))
        else:
            print(tonic_path)
            tonic_list.append(-1)

        if os.path.exists(tonic_fine_path):
            tonic_fine_list.append(_first_line(tonic_fine_path))
        else:
            tonic_fine_list.append(tonic_list[-1])

    data['path'] = data['path'].map(lambda x: os.path.join(x+'.pitch'))
    data['tonic'] = tonic_list
    data['tonic_fine'] = tonic_fine_list
    return data

def create_data_file():
    """Build ``pitch_data.h5`` and one ``data.tsv`` per tradition.

    For each tradition the path/mbid/raga table is read, tonic values are
    attached, every raga group gets a consecutive integer label (in group
    order) and its raga name, the pitch track of each recording is stored in
    the HDF5 file under its mbid, and the table is written as a
    tab-separated file.
    """
    folder_path = 'data/'
    traditions = ['Hindustani', 'Carnatic']
    with h5py.File('data/RagaDataset/pitch_data.h5', "w") as lm_file:
        for trad in traditions:
            data_path = 'data/RagaDataset/{}/data.tsv'.format(trad)
            info_dir = os.path.join(folder_path, 'RagaDataset', trad, '_info_')

            df = pd.read_csv(os.path.join(info_dir, 'path_mbid_ragaid.txt'),
                             names=['path', 'mbid', 'rag_id'], sep='\t')
            df['path'] = df['path'].map(lambda x: os.path.join(folder_path, x))
            df = fetch_tonic(df, folder_path)

            ragaId_to_ragaName = pd.read_csv(
                os.path.join(info_dir, 'ragaId_to_ragaName_mapping.txt'),
                sep='\t', names=['rag_id', 'rag_name'],
            ).set_index(['rag_id'])

            data_list = []
            for lbl, (_, group) in enumerate(df.groupby(['rag_id'])):
                # assign() returns a new frame, so the groupby slice is never
                # written to (no SettingWithCopy warning).
                group = group.assign(
                    rag_name=group['rag_id'].map(lambda x: ragaId_to_ragaName.loc[x, 'rag_name']),
                    labels=lbl,
                    path=group['path'].map(lambda x: fix_paths(x, False)),
                ).reset_index(drop=True)
                data_list.append(group)

            data_list = pd.concat(data_list)
            # copy_pitch_file stores the pitch track in lm_file and returns the mbid.
            data_list['new_path'] = data_list.apply(
                lambda x: copy_pitch_file(x['path'], x['mbid'], lm_file), axis=1)
            data_list.to_csv(data_path, sep='\t', index=False)

In [None]:
# create_data_file()

In [4]:
def freq_to_cents_np(freq, cents_mapping, std=25):
    """Convert frequencies into Gaussian-smoothed, octave-folded 120-bin activations.

    Args:
        freq: sequence of frequencies (same unit as the 10-unit reference).
        cents_mapping: 1-D array of bin centres in cents relative to the
            reference; its length must be a multiple of 120.
        std: standard deviation of the Gaussian, in cents.

    Returns:
        Array of shape ``(len(freq), 120)``: for every input the Gaussian
        activation of each bin, summed over all consecutive 120-bin groups of
        ``cents_mapping``.
    """
    frequency_reference = 10
    # The small offset keeps log2 finite for zero frequencies.
    c_true = 1200 * np.log2((np.array(freq)+1e-5) / frequency_reference)
    c_true = np.expand_dims(c_true, 1)
    # Broadcasting (1, n_bins) against (n, 1) avoids materialising a tiled copy.
    cents_mapping = np.expand_dims(cents_mapping, 0)
    target = np.exp(-(cents_mapping - c_true) ** 2 / (2 * std ** 2))
    pitch_cent = np.sum(target.reshape([c_true.shape[0], -1, 120]), 1)
    return pitch_cent
    

In [5]:
def freq_to_cents(freq, cents_mapping, std=25):
    """Scalar counterpart of ``freq_to_cents_np``.

    Converts one frequency to cents relative to a reference of 10, evaluates a
    Gaussian of width ``std`` at each of the 720 entries of ``cents_mapping``
    and folds the result into 120 bins by summing the six 120-bin groups.
    """
    frequency_reference = 10
    ratio = (freq+1e-5) / frequency_reference
    c_true = 1200 * math.log(ratio, 2)
    deviation = cents_mapping - c_true
    target = np.exp(-deviation ** 2 / (2 * std ** 2))
    folded = target.reshape([6, 120])
    return np.sum(folded, 0)
    

In [6]:
def get_pitchvalues(pitches_arr):
    """Return the 120-bin activations of every frequency in ``pitches_arr``.

    Uses 720 bins spaced 10 cents apart, starting at 2051.1487628680297 cents
    above the reference used by ``freq_to_cents_np``.
    """
    bin_offsets = np.linspace(0, 7190, 720)
    cents_mapping = bin_offsets + 2051.1487628680297
    pitchvalue_prob = freq_to_cents_np(pitches_arr, cents_mapping)
    return pitchvalue_prob

In [7]:
def reorder_tonic(pitchvalue_prob, tonic_freq):
    """Rotate the bin axis of ``pitchvalue_prob`` so the tonic's bin becomes bin 0.

    The tonic bin is the arg-max of the 120-bin activation of ``tonic_freq``.
    """
    bin_offsets = np.linspace(0, 7190, 720)
    cents_mapping = bin_offsets + 2051.1487628680297
    tonic_bin = np.argmax(freq_to_cents(tonic_freq, cents_mapping))
    shifted = np.roll(pitchvalue_prob, -tonic_bin, axis=1)
    return shifted

In [8]:
def normalize(z):
    """Min-max scale ``z`` into (approximately) ``[0, 1]``.

    A constant of 1e-6 is added to the range so a constant input does not
    divide by zero.
    """
    lowest = np.min(z)
    span = np.max(z)-lowest+1e-6
    return (z - lowest)/span

In [9]:
def compare(a,b,x,asc):
    """Return True when ``x`` lies on the circular path from ``a`` to ``b``.

    For ``asc=False`` the two end points are swapped first. The test is
    ``a <= modulo_add(a, x) <= modulo_add(a, b)`` on the (possibly swapped)
    end points.
    """
    lower, upper = (a, b) if asc else (b, a)
    return bool(lower <= modulo_add(lower, x) <= modulo_add(lower, upper))

In [10]:
def modulo(x):
    """Wrap ``x`` onto the 120-bin circle (``x % 120``)."""
    wrapped = x % 120
    return wrapped

In [11]:
def modulo_add(x,y):
    """Return ``y`` wrapped to 0..119, lifted by 120 when it falls below the wrapped ``x``.

    This unrolls ``y`` so it is never smaller than ``modulo(x)``.
    """
    wrapped_x = modulo(x)
    wrapped_y = modulo(y)
    return wrapped_y + 120 if wrapped_x > wrapped_y else wrapped_y

In [12]:
def relax_fun(p,add,r=4):
    """Shift bin ``p`` by ``r`` bins on the 120-bin circle.

    Args:
        p: bin index.
        add: shift upwards when truthy, downwards otherwise.
        r: shift size in bins. Previously this argument was ignored and a
            literal 4 was used; the default keeps that behaviour.
    """
    if add:
        return modulo(p+r)
    return modulo(p-r)

In [13]:
def get_pitch_distribution(pitchvalue_prob):
    """Return the min-max normalised mean of ``pitchvalue_prob`` over axis 0."""
    mean_over_frames = np.mean(pitchvalue_prob, axis=0)
    return normalize(mean_over_frames)

In [14]:
def get_pitch_histograms(pitchvalue_prob):
    """Stack the normalised per-bin mean and standard deviation along a new last axis."""
    mean_hist = get_pitch_distribution(pitchvalue_prob)
    std_hist = normalize(np.std(pitchvalue_prob, axis=0))
    return np.stack([mean_hist, std_hist], axis=-1)


In [15]:
def get_pd_between_pspe(pitchvalue_prob, ps, pe, asc, relax=2):
    """Return the normalised pitch distribution restricted to the span ``ps`` -> ``pe``.

    The span is widened by ``relax`` bins on both ends (sign flipped for
    descending spans); bins outside it are zeroed.

    NOTE(review): only the first 60 bins are considered here while other
    helpers in this notebook iterate over 120 — confirm this is intended.
    """
    signed_relax = relax if asc else -relax
    distribution = get_pitch_distribution(pitchvalue_prob)
    span_start = ps-signed_relax
    span_end = pe+signed_relax
    w = np.zeros(60)
    for bin_idx in range(60):
        w[bin_idx] = distribution[bin_idx]*compare(span_start, span_end, bin_idx, asc)
    return normalize(w)

In [16]:
# def get_dist_btw_idx(pitchvalue_prob, pitch_st_mapping, idx):
#     dist = 0
#     start_idx = pitch_st_mapping[idx][0]
#     end_idx = pitch_st_mapping[idx][1]
#     for i in range(start_idx, end_idx+1):
#         dist+=pitchvalue_prob[i]
#     return dist

def get_dist_btw_idx(pitchvalue_prob, start_idx, end_idx):
    """Sum the rows ``start_idx`` .. ``end_idx`` (both inclusive) of ``pitchvalue_prob``.

    Returns the integer 0 when the range is empty.
    """
    rows = (pitchvalue_prob[row] for row in range(start_idx, end_idx+1))
    # The built-in sum starts from 0 and adds the rows one after another.
    return sum(rows, 0)

def get_dist_btw_shortlisted_idxs(pitchvalue_prob,shortlisted_idxs, off_start=0, off_end=None):
    """Accumulate and normalise the rows covered by the shortlisted segments.

    Every entry of ``shortlisted_idxs`` holds four frame indices: the first
    and last frame of the segment's opening run and of its closing run. A
    segment is skipped when its opening run ends before ``off_start`` or its
    closing run starts after ``off_end``; otherwise it is clipped to the
    window where a boundary falls inside the respective run.

    ``off_end`` defaults to the last row of ``pitchvalue_prob``.
    """
    if off_end is None:
        off_end = len(pitchvalue_prob)-1
    total = 0
    for entry in shortlisted_idxs:
        open_first, open_last, close_first, close_last = entry[0], entry[1], entry[2], entry[3]
        if open_last<off_start or close_first>off_end:
            continue
        if open_first<=off_start<=open_last:
            open_first = off_start
        if close_first<=off_end<=close_last:
            close_last = off_end

        total += get_dist_btw_idx(pitchvalue_prob, open_first, close_last)
    return normalize(total)


In [17]:
def get_width(ps, pe, asc, relax=4):
    """Count how many of the 120 bins fall inside the relaxed span ``ps`` -> ``pe``.

    The span is widened by ``relax`` bins on both ends; the sign of the
    widening is flipped for descending spans.
    """
    signed_relax = relax if asc else -relax
    span_start = ps-signed_relax
    span_end = pe+signed_relax
    return sum(1 for bin_idx in range(120) if compare(span_start, span_end, bin_idx, asc))

In [18]:
# def add_lm_file(lm_file, pitch_st_mapping, start_index, end_index, prev_dist, base_key):
#     psm_ss = pitch_st_mapping[start_index][0] 
#     psm_se = pitch_st_mapping[start_index][1] 
#     psm_es = pitch_st_mapping[end_index][0]
#     psm_ee = pitch_st_mapping[end_index][1] 
            
#     lm_file[base_key.format(psm_ss, psm_se, psm_es, psm_ee)] = prev_dist
def update_shortlisted_index(shortlisted_index, pitch_st_mapping, start_index, end_index):
    """Append the frame bounds of a segment to ``shortlisted_index`` (in place).

    The appended tuple is ``(first, last)`` of ``pitch_st_mapping[start_index]``
    followed by ``(first, last)`` of ``pitch_st_mapping[end_index]``.
    """
    opening_run = pitch_st_mapping[start_index]
    closing_run = pitch_st_mapping[end_index]
    bounds = (opening_run[0], opening_run[1], closing_run[0], closing_run[1])
    shortlisted_index.append(bounds)
    

In [19]:
def compute_spd_ps_pe(pitchvalue_prob, start_idx, pitches_arg, pitch_st_mapping, ps, pe, asc, mbid, relax=4):
    """Collect the frame bounds of every ``ps`` -> ``pe`` segment of a run sequence.

    ``pitches_arg`` is walked from the first occurrence of ``ps``. A segment
    opens at a ``ps`` run, closes at a ``pe`` run and is committed when the
    run after ``pe`` is reached; it is discarded as soon as a run lies outside
    the ``ps`` -> ``pe`` span (``compare``). After a discard or commit the
    walk jumps to the next occurrence of ``ps``.

    Args:
        pitchvalue_prob: unused; kept for interface compatibility.
        start_idx: positions in ``pitches_arg`` at which ``ps`` occurs.
        pitches_arg: sequence of run values.
        pitch_st_mapping: per run, ``[first_frame, last_frame]``.
        ps, pe: start and end value of the segment.
        asc: direction passed to ``compare``.
        mbid, relax: unused; kept for interface compatibility.

    Returns:
        List of ``(ps_first, ps_last, pe_first, pe_last)`` frame tuples. When
        no segment was committed inside the loop, the full range (run 0 to
        run ``n-1``) is appended as a fallback.
    """
    n = len(pitches_arg)
    shortlisted_index = []

    # ps never occurs: nothing can be matched, so go straight to the same
    # full-range fallback used when no segment is found. (Previously this
    # raised IndexError on start_idx[0].)
    if len(start_idx) == 0:
        update_shortlisted_index(shortlisted_index, pitch_st_mapping, 0, n-1)
        return shortlisted_index

    start = True
    end = False
    si = 0
    start_index = -1
    end_index = -1
    idx = start_idx[si]
    dist_pres = False
    while idx<n:
        p = pitches_arg[idx]

        if start and end and p!=pe:
            # This verifies Equation 13; Page 4
            update_shortlisted_index(shortlisted_index, pitch_st_mapping, start_index, end_index)
            start = False
            end = False
            start_index = -1
            end_index = -1
            dist_pres = True

        if start and p==pe:
            end = True  # This verifies Equation 14; Page 4
            end_index = idx

        if p==ps:
            start = True  # This verifies Equation 12; Page 4
            start_index = idx

        if not (compare(ps, pe, p, asc)):
            # The run left the ps -> pe span: drop the open segment.
            start = False
            end = False
            start_index = -1
            end_index = -1

        if p==ps:
            si+=1
        if not start:
            # No open segment: continue at the next occurrence of ps.
            if si>=len(start_idx):
                break
            else:
                idx = start_idx[si]
                idx-=1
        idx+=1

    # Edge case: a segment was still open and closed when the sequence ended.
    # NOTE(review): this commit does not set dist_pres, so the full-range
    # fallback below is appended as well — confirm that is intended.
    if start and end:
        update_shortlisted_index(shortlisted_index, pitch_st_mapping, start_index, end_index)

    if not dist_pres:
        update_shortlisted_index(shortlisted_index, pitch_st_mapping, 0, n-1)
    return shortlisted_index

In [20]:
def get_all_smooth_pitch_values(std=25):
    """Build the 120 circularly shifted Gaussian kernels used for smoothing.

    Args:
        std: standard deviation of the Gaussian, in cents.

    Returns:
        ``(all_notes, c_note)`` where ``c_note`` is the 120-bin activation of
        the frequency ``32.7 * 2`` and ``all_notes[p]`` is ``c_note`` rolled
        by ``p`` bins.
    """
    # freq_to_cents takes (freq, cents_mapping, std); the bin grid has to be
    # passed explicitly. Previously std was passed in the cents_mapping slot.
    cents_mapping = np.linspace(0, 7190, 720) + 2051.1487628680297
    c_note = freq_to_cents(32.7 * 2, cents_mapping, std=std)
    all_notes = np.zeros([120, 120])
    for p in range(120):
        all_notes[p] = get_smooth_pitch_value(c_note, p)

    return all_notes, c_note

def get_smooth_pitch_value(c_note, note):
    """Return ``c_note`` circularly shifted by ``note`` positions along its last axis."""
    shifted = np.roll(c_note, note, axis=-1)
    return shifted

def gauss_smooth(raga_feat):
    """Smooth every off-diagonal 120-bin histogram of a ``(12, 12, 120, 2)`` feature.

    Diagonal entries (``i == j``) are left as zeros in the result.
    """
    kernels, _ = get_all_smooth_pitch_values(std=25)
    smooth = np.zeros([12,12,120,2])
    for i in range(12):
        for j in range(12):
            if i != j:
                for channel in (0, 1):
                    smooth[i,j,:,channel] = gauss_smooth_util(raga_feat[i,j,:,channel], kernels)
    return smooth
                
def gauss_smooth_util(arr1, all_notes):
    """Return the normalised sum of the 120 kernels, each weighted by ``arr1[i]``."""
    weighted_kernels = (all_notes[i]*arr1[i] for i in range(120))
    # The built-in sum starts from 0 and adds the terms in index order.
    blended = sum(weighted_kernels, 0)
    return normalize(blended)

In [21]:
def full_spd(pitchvalue_prob, pitches_arg, mbid, lm_file):
    """Store the segment bounds of every ordered pair of distinct values 0..11.

    For both directions and every pair ``(s, e)`` with ``s != e`` the result
    of ``compute_spd_ps_pe`` is written to
    ``lm_file['<mbid>_<s>_<e>_<asc>']``.
    """
    pitch_dict, std_pitches, pitch_st_mapping = get_std_idx(pitches_arg)
    values = range(12)
    for asc in (True, False):
        for s in values:
            start_idx = pitch_dict[s]
            for e in values:
                if s != e:
                    segments = compute_spd_ps_pe(
                        pitchvalue_prob, start_idx, std_pitches, pitch_st_mapping, s, e, asc, mbid)
                    lm_file["{}_{}_{}_{}".format(mbid, s, e, asc)] = segments

In [22]:
def get_std_idx(pitches_arg, relax=4):
    """Collapse a per-frame bin sequence into runs of equal ``bin // 10``.

    Args:
        pitches_arg: sequence of per-frame bin indices.
        relax: accepted but not used.

    Returns:
        ``(pitch_dict, std_pitches, pitch_st_mapping)`` where ``std_pitches``
        lists ``bin // 10`` per run, ``pitch_dict`` maps each such value to
        the run numbers at which it occurs, and ``pitch_st_mapping[k]`` is the
        ``[first_frame, last_frame]`` of run ``k``.
    """
    pitch_dict = defaultdict(list)
    std_pitches = []
    pitch_st_mapping = []
    previous_group = None
    run = 0
    for frame, pitch in enumerate(pitches_arg):
        group = pitch//10
        if previous_group is None or previous_group != group:
            if previous_group is not None:
                run += 1
            std_pitches.append(group)
            pitch_dict[group].append(run)
        # Extend the current run when this frame directly follows it,
        # otherwise open a new [first, last] entry.
        if run < len(pitch_st_mapping) and pitch_st_mapping[-1][1]+1 == frame:
            pitch_st_mapping[-1][1] = frame
        else:
            pitch_st_mapping.append([frame, frame])
        previous_group = group

    return pitch_dict, std_pitches, pitch_st_mapping

In [23]:
def get_unique_seq(pitches_arg):
    """Run-length encode ``pitches_arg``.

    Returns:
        ``(values, counts)`` as NumPy arrays: the value of every run of equal
        consecutive elements and the length of that run. Both arrays are
        empty for empty input (previously this raised IndexError).
    """
    if len(pitches_arg) == 0:
        return np.array([]), np.array([], dtype=int)

    ps = pitches_arg[0]
    count = 1
    pitches_unique = []
    pitches_count = []
    for p in pitches_arg[1:]:
        if p == ps:
            count += 1
        else:
            # The run ended: record it and start counting the next one.
            pitches_unique.append(ps)
            pitches_count.append(count)
            count = 1
        ps = p

    # Record the run that was still open when the input ended.
    pitches_unique.append(ps)
    pitches_count.append(count)

    return np.array(pitches_unique), np.array(pitches_count)

In [24]:
def get_nearest_end_idx(key, data, start=0):
    """Binary-search the sorted sequence ``data`` for ``key``, starting at ``start``.

    Returns the index of an element equal to ``key`` when one is hit,
    otherwise the index of the first element greater than ``key``. Returns -1
    when that index would lie past the end, and 0 when the upper bound drops
    below 0 (which includes empty ``data``).
    """
    last = len(data)-1
    lo, hi = start, last

    while lo <= hi:
        mid = (lo+hi)//2
        probe = data[mid]
        if probe < key:
            lo = mid + 1
        elif probe > key:
            hi = mid - 1
        else:
            return mid

    if hi < 0:
        return 0
    if lo > last:
        return -1
    # The loop only ends with lo > hi, so hi + 1 is the insertion point.
    return hi+1


In [25]:
def get_nearest_end_idx_lin(key,data, start=0):
    """Linear scan for the first index ``>= start`` whose element is ``>= key``.

    Returns -1 when no such element exists. This now also covers empty
    ``data`` and a ``start`` past the end, which previously raised IndexError.
    """
    for idx in range(start, len(data)):
        if data[idx] >= key:
            return idx
    return -1

In [26]:
def generate_spd_idx_all_files(tradition):
    """Compute and cache the segment bounds of every recording of ``tradition``.

    Reads ``data/RagaDataset/<tradition>/data.tsv``, takes each recording's
    pitch track from ``pitch_data.h5``, aligns it to the ``tonic_fine`` value
    and writes the results of ``full_spd`` to ``data/<tradition>_spd_cache``
    (the file is overwritten).
    """
    recordings = pd.read_csv('data/RagaDataset/{}/data.tsv'.format(tradition), sep='\t')

    with h5py.File('data/'+tradition+'_spd_cache', "w") as spd_idx_lm_file, \
            h5py.File('data/RagaDataset/pitch_data.h5', "r") as pitch_lm_file:
        for _, record in recordings.iterrows():
            file_name = record['new_path']
            tonic = record['tonic_fine']
            mbid = record['mbid']
            print(file_name)

            pitchvalue_prob = reorder_tonic(get_pitchvalues(pitch_lm_file[mbid]), tonic)
            # Strongest bin of every frame.
            pitches_arg = np.argmax(pitchvalue_prob, axis=1)
            full_spd(pitchvalue_prob, pitches_arg, mbid, spd_idx_lm_file)

In [1929]:
# pitch_dict = generate_spd_idx_all_files('Carnatic')
# pitch_dict = generate_spd_idx_all_files('Hindustani')

0ded0c52-7f15-4140-b45a-ca829106d053
ca5c5c46-47d9-4fa9-a5d8-ab6120940a6d
874a5b30-2eed-4760-8fba-431a4f290dcc
4b327e7a-6146-4f44-a5ff-b4c099a8bfb5
96ea7753-a2cc-4cf5-be0e-042acd56d29c
af6a1ff7-98b3-4e99-8133-5a15f66d8904
cb4a75f3-0005-4311-bd4c-2fa97da53bf3
33d896d2-1a7b-4e5f-9508-c3fb5d228c94
16a3263f-31dc-40da-839b-f5955b77c0b6
e1b0148b-1542-4e44-83b1-c92c1f0ca56e
7e43c413-89d0-491c-add2-2b1520d7cb33
22601164-1342-42ae-9a27-02a9a657bb55
16fbd94c-b09c-44d9-8d0f-841ada1b62c7
f769247c-e0c7-4078-8c5a-6741d12d06dd
3f0ea455-7e47-48c6-942a-9dd60d36e04b
a91ca01b-3338-4968-815a-6ae61644ccb5
0096a390-e1dd-4e29-a4c9-2e92352535c4
92c61c93-4a02-457e-8416-30af03d207dd
d6b02b8d-ae96-4a43-a8a5-2cc830761cce
31ca1f76-ce7a-42f4-b429-ca43ef6b2ba1
2c666f9b-17ae-4b27-8d29-a2bb49c2c2f5
3d13fd51-456e-46d6-a6e7-c6e50991fdac
5853dea4-9f86-4f62-8f47-abc520f3493e
f1f04b03-3f91-4d31-bf7b-6c21d26eeb7e
fafd1a88-325d-4daa-999e-9dfdba4853ba
cf7acd46-da61-47c7-99aa-1de0c4ebc94c
6fbf2d80-ada4-41aa-8baa-fb288619883e
7

9f56f2b9-9cb3-41d1-bf03-65486374f4a5
66cfbbdb-444b-4155-9565-3dee209775de
75a6e70e-916d-46fd-ad12-0d8a7161384b
417747bb-9210-456d-b6e5-6b4502b924bc
5a6ed43b-e79f-4dfc-ae67-f0c97770a6e2
17b24021-3716-4de5-907c-8105e0219647
b45961aa-9f9a-4f07-ae72-207135d3a679
3f3e341e-38d7-4479-8780-359f8e554195
57b2324e-f4c3-4b4e-8c03-ff985e29dd7b
68f26dd3-5974-460f-b43c-31a90cfe9b54
6f5c24f4-e570-436f-a8a9-b2a3d74a34d2
1e3e10ad-93ed-439e-85b6-de67cd362d05
7f3e0c2c-ff3e-4a9a-853f-159269f88941
4434f3a0-f11a-4027-95c8-cdefb92237cf
829df365-78bc-4157-9346-5a3b39bf12a5
1587bdd5-60b6-481e-9a30-60525607850f
225e1609-23af-44af-a360-5e1f5518caae
2611e0fd-7c4c-4fe2-8c35-d4dcde7a7962
1e2f83ca-b0ce-4f12-a571-037cdc676738
9449e420-b340-4475-ad25-676460010369
9cfa8a63-5ffe-4f48-ae53-ffbe74b03f55
aee1e256-38fc-4721-9cc9-403542ac0b94
5226b897-06fc-4032-9766-66e5b19e74fb
902b21c0-985b-4b6b-a30e-c3c505b69fb1
41b3231c-6b46-4eb8-bc20-67ff3f065fd3
6ba73b55-24e8-4efd-befc-1c8c163f2e17
91195a3f-b411-47e1-b16c-a507e44d891d
e

77f42a8e-6e25-4145-a919-0e6f62ae0fb8
afd1f1c8-ecf2-46aa-ac68-09e92d10f071
68bcad2f-b0b3-485f-ad5f-7edf9d783551
a9bf3ac1-9110-458c-8a3a-b0b6d9bf3d03
68f2f8bf-c53f-4180-95b1-4ee08a94b7bd
c44b016a-aa9f-48f9-8007-175b7b38bd54
de53bd66-536b-4929-88b1-9fe07eaa0121
c73d9043-f739-4a6e-887f-be32ddde73a7
1c0836c6-d7d8-4ceb-96c5-9c968cd41442
09bd5b5f-1b57-4269-aa76-e3bb87471743
46c4267b-f2c8-4fd7-8423-d7cc648395bb
7e8015eb-ca83-464d-aec6-ee0f50ab1186
323ff36b-db5c-4d91-a33a-239a944810e2
fea29da7-6bea-4a86-862c-7d91c366ea12
958dbe4f-d910-425b-91de-3125f08a0a6c
1520c037-248b-46c8-8eab-65e31d3f3b08
28bbe763-865b-4dd8-9e47-3b74919e945c
f6c659e6-abbf-4eb8-aacb-105056d7b47f
738433ad-9009-44df-80c0-6febc57e4ce8
0ba5d4d1-037b-4dcc-ae0c-c832a7681b89
2bade8d8-1cfa-4076-9329-98f7cacc65a0
855fa024-a4c5-4cca-bb26-0b2c806db722
728b5ad7-5348-4d24-bcf9-a6e6223ff7db
6f0a54e1-4dc0-4b3e-9482-94670862d8bf
9d8a5fd9-0a2d-4411-aff4-ca67698a27ec
4bdce886-d7f1-4076-b2c8-ee94096be9e5
58c2c5cf-b258-4f32-a15a-6a9f2b5acd88
2

In [27]:
def get_cliped_dist(s,e,asc,dist,clip=15):
    """Cut the circular slice between ``s*10`` and ``e*10`` (widened by ``clip``) out of ``dist``.

    The walk starts at ``modulo(s*10 - relax)`` and steps by +1 (ascending) or
    -1 (descending) until ``modulo(e*10 + relax)`` is reached, where ``relax``
    is ``clip`` with the sign of the direction. The start bin itself is not
    part of the slice, the end bin is. When the walk is at most ``clip`` steps
    long, ``dist`` is returned unchanged.
    """
    step = 1 if asc else -1
    relax = clip if asc else -clip
    cursor = modulo(s*10-relax)
    stop = modulo(e*10+relax)

    visited = []
    while cursor != stop:
        cursor = modulo(cursor+step)
        visited.append(cursor)

    if len(visited) <= abs(relax):
        return dist

    dist_sliced = np.zeros(len(visited))
    for slot, bin_idx in enumerate(visited):
        dist_sliced[slot] = dist[bin_idx]
    return dist_sliced

def get_spd_from_idx(mbid, tonic, tradition, just_spd_lm_file, just_hist_lm_file, off_start=0, off_end=None):
    """Build the ``(12, 12, 120, 2)`` feature of one recording from the cached segment bounds.

    The recording's pitch track is read from ``pitch_data.h5`` and aligned to
    ``tonic``. Entry ``[s, e, :, c]`` holds the normalised distribution of the
    cached ``s`` -> ``e`` segments (``c = 0`` ascending, ``c = 1``
    descending); diagonal entries hold the overall normalised distribution.
    The feature is written to ``just_spd_lm_file[mbid]`` and the overall
    distribution to ``just_hist_lm_file[mbid]``.
    """
    with h5py.File('data/RagaDataset/pitch_data.h5', "r") as pitch_lm_file:
        pitchvalue_prob = reorder_tonic(get_pitchvalues(pitch_lm_file[mbid]), tonic)
        dist_hist = normalize(get_dist_btw_idx(pitchvalue_prob, 0, len(pitchvalue_prob)-1))
        with h5py.File('data/'+tradition+'_spd_cache', "r") as spd_idx_lm_file:
            full_spd_dist = np.zeros([12,12,120,2])
            # Channel 0 is the ascending direction, channel 1 the descending one.
            for channel, asc in enumerate((True, False)):
                for s in range(12):
                    for e in range(12):
                        if s == e:
                            full_spd_dist[s,e,:,channel] = dist_hist
                        else:
                            segments = spd_idx_lm_file['{}_{}_{}_{}'.format(mbid, s, e, asc)]
                            full_spd_dist[s,e,:,channel] = get_dist_btw_shortlisted_idxs(
                                pitchvalue_prob, segments, off_start, off_end)
            just_spd_lm_file[mbid] = full_spd_dist
            just_hist_lm_file[mbid] = dist_hist
            

In [28]:
def generate_full_spd_cache(tradition):
    """Write the feature and histogram caches for every recording of ``tradition``.

    Creates (overwrites) ``data/<tradition>_just_spd_cache`` and
    ``data/<tradition>_just_hist_cache`` and fills them via
    ``get_spd_from_idx`` for each row of the tradition's ``data.tsv``.

    NOTE(review): this passes the ``tonic`` column, whereas
    ``generate_spd_idx_all_files`` aligns with ``tonic_fine`` — confirm the
    two are meant to differ.
    """
    with h5py.File('data/'+tradition+'_just_spd_cache', "w") as just_spd_lm_file, \
            h5py.File('data/'+tradition+'_just_hist_cache', "w") as just_hist_lm_file:
        recordings = pd.read_csv('data/RagaDataset/{}/data.tsv'.format(tradition), sep='\t')
        for _, record in recordings.iterrows():
            mbid = record['mbid']
            tonic = record['tonic']
            print(mbid)
            get_spd_from_idx(mbid, tonic, tradition, just_spd_lm_file, just_hist_lm_file)


In [1930]:
# generate_full_spd_cache('Carnatic')
# generate_full_spd_cache('Hindustani')

0ded0c52-7f15-4140-b45a-ca829106d053
ca5c5c46-47d9-4fa9-a5d8-ab6120940a6d
874a5b30-2eed-4760-8fba-431a4f290dcc
4b327e7a-6146-4f44-a5ff-b4c099a8bfb5
96ea7753-a2cc-4cf5-be0e-042acd56d29c
af6a1ff7-98b3-4e99-8133-5a15f66d8904
cb4a75f3-0005-4311-bd4c-2fa97da53bf3
33d896d2-1a7b-4e5f-9508-c3fb5d228c94
16a3263f-31dc-40da-839b-f5955b77c0b6
e1b0148b-1542-4e44-83b1-c92c1f0ca56e
7e43c413-89d0-491c-add2-2b1520d7cb33
22601164-1342-42ae-9a27-02a9a657bb55
16fbd94c-b09c-44d9-8d0f-841ada1b62c7
f769247c-e0c7-4078-8c5a-6741d12d06dd
3f0ea455-7e47-48c6-942a-9dd60d36e04b
a91ca01b-3338-4968-815a-6ae61644ccb5
0096a390-e1dd-4e29-a4c9-2e92352535c4
92c61c93-4a02-457e-8416-30af03d207dd
d6b02b8d-ae96-4a43-a8a5-2cc830761cce
31ca1f76-ce7a-42f4-b429-ca43ef6b2ba1
2c666f9b-17ae-4b27-8d29-a2bb49c2c2f5
3d13fd51-456e-46d6-a6e7-c6e50991fdac
5853dea4-9f86-4f62-8f47-abc520f3493e
f1f04b03-3f91-4d31-bf7b-6c21d26eeb7e
fafd1a88-325d-4daa-999e-9dfdba4853ba
cf7acd46-da61-47c7-99aa-1de0c4ebc94c
6fbf2d80-ada4-41aa-8baa-fb288619883e
7

9f56f2b9-9cb3-41d1-bf03-65486374f4a5
66cfbbdb-444b-4155-9565-3dee209775de
75a6e70e-916d-46fd-ad12-0d8a7161384b
417747bb-9210-456d-b6e5-6b4502b924bc
5a6ed43b-e79f-4dfc-ae67-f0c97770a6e2
17b24021-3716-4de5-907c-8105e0219647
b45961aa-9f9a-4f07-ae72-207135d3a679
3f3e341e-38d7-4479-8780-359f8e554195
57b2324e-f4c3-4b4e-8c03-ff985e29dd7b
68f26dd3-5974-460f-b43c-31a90cfe9b54
6f5c24f4-e570-436f-a8a9-b2a3d74a34d2
1e3e10ad-93ed-439e-85b6-de67cd362d05
7f3e0c2c-ff3e-4a9a-853f-159269f88941
4434f3a0-f11a-4027-95c8-cdefb92237cf
829df365-78bc-4157-9346-5a3b39bf12a5
1587bdd5-60b6-481e-9a30-60525607850f
225e1609-23af-44af-a360-5e1f5518caae
2611e0fd-7c4c-4fe2-8c35-d4dcde7a7962
1e2f83ca-b0ce-4f12-a571-037cdc676738
9449e420-b340-4475-ad25-676460010369
9cfa8a63-5ffe-4f48-ae53-ffbe74b03f55
aee1e256-38fc-4721-9cc9-403542ac0b94
5226b897-06fc-4032-9766-66e5b19e74fb
902b21c0-985b-4b6b-a30e-c3c505b69fb1
41b3231c-6b46-4eb8-bc20-67ff3f065fd3
6ba73b55-24e8-4efd-befc-1c8c163f2e17
91195a3f-b411-47e1-b16c-a507e44d891d
e

77f42a8e-6e25-4145-a919-0e6f62ae0fb8
afd1f1c8-ecf2-46aa-ac68-09e92d10f071
68bcad2f-b0b3-485f-ad5f-7edf9d783551
a9bf3ac1-9110-458c-8a3a-b0b6d9bf3d03
68f2f8bf-c53f-4180-95b1-4ee08a94b7bd
c44b016a-aa9f-48f9-8007-175b7b38bd54
de53bd66-536b-4929-88b1-9fe07eaa0121
c73d9043-f739-4a6e-887f-be32ddde73a7
1c0836c6-d7d8-4ceb-96c5-9c968cd41442
09bd5b5f-1b57-4269-aa76-e3bb87471743
46c4267b-f2c8-4fd7-8423-d7cc648395bb
7e8015eb-ca83-464d-aec6-ee0f50ab1186
323ff36b-db5c-4d91-a33a-239a944810e2
fea29da7-6bea-4a86-862c-7d91c366ea12
958dbe4f-d910-425b-91de-3125f08a0a6c
1520c037-248b-46c8-8eab-65e31d3f3b08
28bbe763-865b-4dd8-9e47-3b74919e945c
f6c659e6-abbf-4eb8-aacb-105056d7b47f
738433ad-9009-44df-80c0-6febc57e4ce8
0ba5d4d1-037b-4dcc-ae0c-c832a7681b89
2bade8d8-1cfa-4076-9329-98f7cacc65a0
855fa024-a4c5-4cca-bb26-0b2c806db722
728b5ad7-5348-4d24-bcf9-a6e6223ff7db
6f0a54e1-4dc0-4b3e-9482-94670862d8bf
9d8a5fd9-0a2d-4411-aff4-ca67698a27ec
4bdce886-d7f1-4076-b2c8-ee94096be9e5
58c2c5cf-b258-4f32-a15a-6a9f2b5acd88
2

In [29]:
class SPDKNN:
    """k-NN classifier on explicit feature matrices using a Bhattacharyya-style distance.

    ``fit`` and ``predict`` take the feature matrix directly. ``get_X`` is
    only usable when a feature source, an id -> mbid mapping and a key prefix
    were supplied to the constructor.
    """

    def __init__(self,k=5, full_spd_lm_file=None, mbid_map=None, wd=None):
        """Build the underlying classifier.

        ``full_spd_lm_file``, ``mbid_map`` and ``wd`` are optional and only
        needed by ``get_X``. Previously ``wd`` was silently read from a
        notebook-level global and the other two attributes were never set.
        """
        self.y = None
        self.full_spd_lm_file = full_spd_lm_file
        self.mbid_map = mbid_map
        self.knn = KNeighborsClassifier(n_neighbors=k, algorithm='ball_tree', metric=self.bhatta)
        self.wd = wd

    def get_X(self, X_ind):
        """Return a ``(len(X_ind), n_features)`` matrix read from the feature source.

        Raises:
            RuntimeError: if no feature source / mapping was configured.
        """
        if self.full_spd_lm_file is None or self.mbid_map is None:
            raise RuntimeError('SPDKNN.get_X() needs full_spd_lm_file and mbid_map')
        x2 = len(self.full_spd_lm_file['{}_{}'.format(self.wd, self.mbid_map[0])])
        X = np.zeros([len(X_ind), x2])
        for idx, x_id in enumerate(X_ind):
            X[idx] = self.full_spd_lm_file['{}_{}'.format(self.wd, self.mbid_map[x_id])]
        return X

    def bhatta(self, hist1,  hist2):
        """Distance ``sqrt(1 - sum(sqrt(h1*h2)) / sqrt(mean(h1)*mean(h2)*len(h1)*len(h2)))``.

        Returns 0 for identical histograms and 1 for histograms without overlap.
        If either histogram has zero mean the normaliser vanishes; such a pair
        is reported at the maximum distance 1.0 instead of dividing by zero.
        """
        h1_ = np.mean(hist1)
        h2_ = np.mean(hist2)
        norm = math.sqrt(h1_*h2_*len(hist1)*len(hist2))
        if norm == 0:
            return 1.0

        score = np.sum(np.sqrt(np.multiply(hist1, hist2)))
        # Rounding can push the radicand a few ULPs below zero for (nearly)
        # identical histograms, which would make math.sqrt raise ValueError.
        return math.sqrt(max(0.0, 1 - ( 1 / norm ) * score))

    def fit(self, X, y):
        """Fit the classifier on feature matrix ``X`` with labels ``y``."""
        self.knn.fit(X,y)

    def predict(self, X):
        """Return class probabilities (``predict_proba``) for feature matrix ``X``."""
        return self.knn.predict_proba(X)

In [36]:
# Scratch check: print every (s, e) pair with s != e on a 10-step grid,
# followed by a marker line after each s.
for wd in range(120,240,10):
    s = wd-120
    for e in range(0,120,10):
        if s != e:
            print(s,e)

    print('ads')

0 10
0 20
0 30
0 40
0 50
0 60
0 70
0 80
0 90
0 100
0 110
ads
10 0
10 20
10 30
10 40
10 50
10 60
10 70
10 80
10 90
10 100
10 110
ads
20 0
20 10
20 30
20 40
20 50
20 60
20 70
20 80
20 90
20 100
20 110
ads
30 0
30 10
30 20
30 40
30 50
30 60
30 70
30 80
30 90
30 100
30 110
ads
40 0
40 10
40 20
40 30
40 50
40 60
40 70
40 80
40 90
40 100
40 110
ads
50 0
50 10
50 20
50 30
50 40
50 60
50 70
50 80
50 90
50 100
50 110
ads
60 0
60 10
60 20
60 30
60 40
60 50
60 70
60 80
60 90
60 100
60 110
ads
70 0
70 10
70 20
70 30
70 40
70 50
70 60
70 80
70 90
70 100
70 110
ads
80 0
80 10
80 20
80 30
80 40
80 50
80 60
80 70
80 90
80 100
80 110
ads
90 0
90 10
90 20
90 30
90 40
90 50
90 60
90 70
90 80
90 100
90 110
ads
100 0
100 10
100 20
100 30
100 40
100 50
100 60
100 70
100 80
100 90
100 110
ads
110 0
110 10
110 20
110 30
110 40
110 50
110 60
110 70
110 80
110 90
110 100
ads


In [121]:
def _spd_pair_features(spd, pairs):
    """Concatenate the clipped histograms of every ``(s, e)`` pair with ``s != e``.

    For each pair two slices of ``spd`` are taken, ``spd[s//10, e//10, :, 0]``
    and ``spd[e//10, s//10, :, 1]``, each passed through ``get_cliped_dist``
    with ``clip=15``.
    """
    parts = []
    for s, e in pairs:
        if s == e:
            continue
        s10 = s // 10
        e10 = e // 10
        parts.append(get_cliped_dist(s, e, True, spd[s10, e10, :, 0], clip=15))
        parts.append(get_cliped_dist(e, s, False, spd[e10, s10, :, 1], clip=15))
    return np.concatenate(parts, axis=-1)


def _window_features(wd, mbids, n_rows, hist_file, spd_file):
    """Build the feature matrix (one row per recording) for window id ``wd``.

    * ``wd == 0``        -> the 120-value vector stored in ``hist_file``.
    * ``0 < wd < 120``   -> pairs ``(s, modulo(s + wd))`` for s = 0, 10, ..., 110.
    * ``120 <= wd < 240``-> pairs ``(wd - 120, e)`` for e = 0, 10, ..., 110.
    * otherwise          -> the whole array from ``spd_file``, flattened.
    """
    if wd == 0:
        feat = np.zeros([n_rows, 120])
        # Rows are filled by position; this matches the default RangeIndex
        # that pd.read_csv produces for the metadata table.
        for row_no, mbid in enumerate(mbids):
            feat[row_no] = hist_file[mbid]
        return feat

    if 0 < wd < 120:
        pairs = [(s, modulo(s + wd)) for s in range(0, 120, 10)]
    elif 120 <= wd < 240:
        start = wd - 120
        pairs = [(start, e) for e in range(0, 120, 10)]
    else:
        return np.array([np.reshape(np.array(spd_file[mbid]), [-1]) for mbid in mbids])

    # Each recording's array is read once instead of once per slice.
    return np.array([_spd_pair_features(np.array(spd_file[mbid]), pairs) for mbid in mbids])


def train_model(tradition, only_save=False):
    """Train one ``SPDKNN(k=5)`` model per window id (0, 10, ..., 240).

    Reads ``data/RagaDataset/<tradition>/data.tsv`` (columns ``mbid`` and
    ``labels``) and the HDF5 caches ``data/<tradition>_just_spd_cache`` and
    ``data/<tradition>_just_hist_cache``.

    Args:
        tradition: ``'Hindustani'`` (300 rows, 30 labels); any other value is
            treated as the 480-row / 40-label dataset.
        only_save: When True, fit each model on all rows and pickle it to
            ``data/RagaDataset/<tradition>/model/spd_knn_<wd>.pkl``.  When
            False, run leave-one-out prediction and store the resulting
            ``(n_rows, n_labels)`` matrix under key ``str(wd)`` in
            ``data/<tradition>_output_cache_clipped``.

    The output cache used to be opened read-only for the whole run, which made
    the write in the leave-one-out branch fail.  It is now opened in append
    mode only when a result is written, and an existing key is replaced.
    """
    data_file_path = 'data/RagaDataset/{}/data.tsv'.format(tradition)
    if tradition == 'Hindustani':
        n_rows, n_labels = 300, 30
    else:
        n_rows, n_labels = 480, 40

    df = pd.read_csv(data_file_path, sep='\t')
    y_labels = df['labels'].values
    mbids = df['mbid'].tolist()
    output_path = 'data/' + tradition + '_output_cache_clipped'

    with h5py.File('data/' + tradition + '_just_spd_cache', "r") as just_spd_lm_file, \
            h5py.File('data/' + tradition + '_just_hist_cache', "r") as just_hist_lm_file:
        for wd in range(0, 250, 10):
            spd_knn = SPDKNN(k=5)
            feat = _window_features(wd, mbids, n_rows, just_hist_lm_file, just_spd_lm_file)
            print('wd', wd)

            if only_save:
                spd_knn.fit(feat, y_labels)
                with open('data/RagaDataset/{}/model/spd_knn_{}.pkl'.format(tradition, wd), 'wb') as pkl_f:
                    pickle.dump(spd_knn, pkl_f)
                continue

            # Leave-one-out: refit without one row, store that row's probabilities.
            y_pred = np.zeros([n_rows, n_labels])
            for train_index, test_index in LeaveOneOut().split(feat):
                spd_knn.fit(feat[train_index], y_labels[train_index])
                y_pred_proba = spd_knn.predict(feat[test_index])
                y_pred[test_index[0]] = y_pred_proba[0]

            key = str(wd)
            with h5py.File(output_path, "a") as output_lm_file:
                if key in output_lm_file:
                    # h5py refuses to assign over an existing dataset name.
                    del output_lm_file[key]
                output_lm_file[key] = y_pred


In [42]:
def interp1d(array: np.ndarray) -> np.ndarray:
    """Resample a 1-D array to half its length (floor) by linear interpolation.

    The output samples are evenly spaced over the full index range of the
    input, so the first and last input values are kept whenever the output has
    at least two points.
    """
    n_in = len(array)
    source_positions = np.arange(n_in)
    target_positions = np.linspace(0, n_in - 1, num=n_in // 2)
    return np.interp(target_positions, source_positions, array)

In [122]:
# Call train_model with only_save=True for both traditions: each per-window
# model is fitted on all rows and pickled instead of being evaluated.
train_model('Carnatic', True)
train_model('Hindustani', True)

wd 0
wd 10
wd 20
wd 30
wd 40
wd 50
wd 60
wd 70
wd 80
wd 90
wd 100
wd 110
wd 120
wd 130
wd 140
wd 150
wd 160
wd 170
wd 180
wd 190
wd 200
wd 210
wd 220
wd 230
wd 240
wd 0
wd 10
wd 20
wd 30
wd 40
wd 50
wd 60
wd 70
wd 80
wd 90
wd 100
wd 110
wd 120
wd 130
wd 140
wd 150
wd 160
wd 170
wd 180
wd 190
wd 200
wd 210
wd 220
wd 230
wd 240


In [56]:
# Run train_model for the Carnatic dataset with its default arguments.
# NOTE(review): this cell's execution count is lower than that of the
# train_model definition above, so the recorded output may come from an
# earlier version of the function - re-run after Restart & Run All to confirm.
train_model('Carnatic')
# train_model('Hindustani')

wd 0
wd 10
wd 20
wd 30
wd 40
wd 50
wd 60
wd 70
wd 80
wd 90
wd 100
wd 110
wd 120
wd 130
wd 140
wd 150
wd 160
wd 170
wd 180
wd 190
wd 200
wd 210
wd 220
wd 230
wd 240


In [27]:
class SPDEnsembleModel:
    """Weighted ensemble of twelve per-window ``SPDKNN`` models (wd = 0, 10, ..., 110).

    ``models_weights`` holds one weight per window model, in window order.
    """

    def __init__(self, tradition):
        """Select dataset sizes and ensemble weights for ``tradition``.

        ``'Carnatic'`` uses 480 rows / 40 labels; any other value uses
        300 rows / 30 labels.
        """
        self.tradition = tradition
        self.labels = None
        if tradition == 'Carnatic':
            self.n_rows = 480
            self.n_labels = 40
            self.models_weights = np.array([[0.13067141], [0.01998654], [0.00263896], [0.01067577], [0.03622559], [0.03518761], [0.04490862],
                                   [0.05965471], [0.04594504], [0.02294095], [0.02302309], [0.03178377]])
        else:
            self.models_weights = np.array([[ 0.4077039 ], [-0.12407973], [ 0.36005375], [ 0.11901897], [ 0.06872483], [ 0.37223506],
                                         [ 0.4211175 ], [-0.2689638 ], [-0.3007931 ], [ 0.4191207 ], [-0.27818865], [ 0.49861392]])
            self.n_rows = 300
            self.n_labels = 30

    def _model_path(self, wd):
        """Path of the pickled k-NN model for window ``wd``."""
        return 'data/{}_spd_knn_{}'.format(self.tradition, wd)

    @staticmethod
    def _sample_features(wd, hist, spd):
        """Feature vector of one recording for window ``wd``.

        ``wd == 0`` uses ``hist`` as is; otherwise the clipped slices of ``spd``
        for every pair ``(s, modulo(s + wd))`` with ``s != e`` are concatenated.
        """
        if wd == 0:
            return np.asarray(hist, dtype=float)
        parts = []
        for s in range(0, 120, 10):
            e = modulo(s + wd)
            if s == e:
                continue
            s10 = s // 10
            e10 = e // 10
            # NOTE(review): the original passed an undefined name `asc` here;
            # the True/False + swapped-argument convention below is the one
            # used by train_model(tradition, only_save) - confirm.
            parts.append(get_cliped_dist(s, e, True, spd[s10, e10, :, 0], clip=15))
            parts.append(get_cliped_dist(e, s, False, spd[e10, s10, :, 1], clip=15))
        return np.concatenate(parts, axis=-1)

    def train(self, data):
        """Fit one ``SPDKNN(k=5)`` per window on all rows of ``data`` and pickle it.

        Args:
            data: table with ``mbid`` and ``labels`` columns.
                NOTE(review): the original body ignored this argument and read
                a global ``df`` instead; ``data`` is assumed to be that frame.
        """
        y_labels = data['labels'].values
        mbids = data['mbid'].tolist()
        with h5py.File('data/' + self.tradition + '_just_spd_cache', "r") as just_spd_lm_file, \
                h5py.File('data/' + self.tradition + '_just_hist_cache', "r") as just_hist_lm_file:
            for wd in range(0, 120, 10):
                rows = []
                for mbid in mbids:
                    # Only the cache that the window actually needs is read.
                    hist = np.array(just_hist_lm_file[mbid]) if wd == 0 else None
                    spd = None if wd == 0 else np.array(just_spd_lm_file[mbid])
                    rows.append(self._sample_features(wd, hist, spd))
                feat = np.array(rows)

                spd_knn = SPDKNN(k=5)
                spd_knn.fit(feat, y_labels)
                # Binary write mode is required by pickle.dump.
                with open(self._model_path(wd), 'wb') as f:
                    pickle.dump(spd_knn, f)

    def test(self, hist, spd):
        """Score one recording with the weighted ensemble.

        NOTE(review): the original body was an unfinished copy of ``train``
        that did not parse and never used ``hist`` or ``spd``.  This
        implementation loads the models written by ``train`` and returns the
        ``models_weights``-weighted sum of their class probabilities - confirm
        this is the intended inference.

        Args:
            hist: the recording's 120-value vector (used for window 0).
            spd: the recording's array indexed as ``spd[s//10, e//10, :, c]``.

        Returns:
            1-D array of ensemble scores, one per class known to the models.
        """
        scores = None
        for weight, wd in zip(self.models_weights[:, 0], range(0, 120, 10)):
            # pickle.load executes arbitrary code; only load files produced by train().
            with open(self._model_path(wd), 'rb') as f:
                spd_knn = pickle.load(f)
            feat = self._sample_features(wd, hist, spd)
            proba = np.asarray(spd_knn.predict(feat[np.newaxis, :]))[0]
            scores = weight * proba if scores is None else scores + weight * proba
        return scores

In [None]:
def save_trained_model(tradition):
    """Leave-one-out k-NN probabilities for windows 0, 10, ..., 110.

    Reads ``data/RagaDataset/<tradition>/data.tsv`` plus the
    ``_just_spd_cache`` / ``_just_hist_cache`` HDF5 files and writes one
    ``(n_rows, n_labels)`` matrix per window, keyed by ``str(wd)``, to
    ``data/<tradition>_output_cache_clipped``.  That file is opened with mode
    "w", so any previous content is discarded.

    Args:
        tradition: ``'Hindustani'`` (300 rows, 30 labels); any other value is
            treated as the 480-row / 40-label dataset.
    """
    data_file_path = 'data/RagaDataset/{}/data.tsv'.format(tradition)
    if tradition=='Hindustani':
        n_rows = 300
        n_labels = 30
        # The original literal had no commas between the rows, which Python
        # parses as list subscription and fails with a TypeError at run time.
        # NOTE(review): the value is not used anywhere in this function.
        models_weights = np.array([[0.13067141], [0.01998654], [0.00263896], [0.01067577],
                                   [0.03622559], [0.03518761], [0.04490862], [0.05965471],
                                   [0.04594504], [0.02294095], [0.02302309], [0.03178377]])
    else:
        n_rows = 480
        n_labels = 40
    df = pd.read_csv(data_file_path, sep='\t')
    y_labels = df['labels'].values
    with h5py.File('data/'+tradition+'_just_spd_cache', "r") as just_spd_lm_file:
        with h5py.File('data/'+tradition+'_just_hist_cache', "r") as just_hist_lm_file:
            with h5py.File('data/'+tradition+'_output_cache_clipped', "w") as output_lm_file:
                for wd in range(0,120,10):
                    spd_knn = SPDKNN(k=5)
                    if wd == 0:
                        feat = np.zeros([n_rows, 120])
                        for row in df.iterrows():
                            feat[row[0]] = just_hist_lm_file[row[1]['mbid']]
                    else:
                        feat = []
                        for row in df.iterrows():
                            mbid = row[1]['mbid']
                            feat_curr = []
                            for s in range(0,120,10):
                                e = modulo(s+wd)
                                if s==e:
                                    continue
                                s10 = s//10
                                e10 = e//10
                                hist_1 = just_spd_lm_file[mbid][s10,e10,:,0]
                                hist_2 = just_spd_lm_file[mbid][e10,s10,:,1]
                                # NOTE(review): the original passed `asc`, which is
                                # not defined in this function; the arguments below
                                # follow train_model(tradition, only_save) - confirm.
                                hist_1 = get_cliped_dist(s,e,True,hist_1,clip=15)
                                hist_2 = get_cliped_dist(e,s,False,hist_2,clip=15)
                                feat_curr.append(hist_1)
                                feat_curr.append(hist_2)
                            feat.append(np.concatenate(feat_curr, axis=-1))
                    feat = np.array(feat)
                    loo = LeaveOneOut()
                    y_pred = np.zeros([n_rows,n_labels])
                    print('wd',wd)
                    # Refit without one row, keep that row's class probabilities.
                    for train_index, test_index in loo.split(feat):
                        spd_knn.fit(feat[train_index], y_labels[train_index])
                        y_pred_proba = spd_knn.predict(feat[test_index])
                        y_pred[test_index[0]] = y_pred_proba[0]
                    output_lm_file[str(wd)] = y_pred


In [None]:
def get_data():
    """Stack the twelve cached prediction matrices and print a weighted-vote score.

    Relies on the notebook globals ``tradition`` and ``df``.  For each window
    0, 10, ..., 110 the matrix stored under ``str(window)`` in
    ``data/<tradition>_output_cache`` is copied into ``X[:, window // 10, :]``
    and added to a running sum, weighted by that window's own accuracy.
    The number of rows whose arg-max of the weighted sum equals the label is
    printed.

    Returns:
        The stacked ``(300, 12, 30)`` array.  The original version built this
        array but discarded it.
    """
    n_rows, n_models, n_labels = 300, 12, 30
    # np.zeros takes the shape as a single sequence argument.
    X = np.zeros((n_rows, n_models, n_labels))
    labels = df['labels'].values
    with h5py.File('data/'+tradition+'_output_cache', "r") as output_lm_file:
        outp = 0
        for i in range(0,120,10):
            # The original called np.array() with no argument, which raises TypeError.
            curr_out = np.array(output_lm_file[str(i)])
            X[:,i//10,:] = curr_out
            wt = np.sum(np.argmax(curr_out, axis=1) == labels)/n_rows
            outp += wt*curr_out

        print(np.sum(np.argmax(outp, axis=1) == labels))
    return X

In [44]:
def get_data(tradition, test_size=0.1, random_state=None):
    """Load the cached per-window predictions and split them into train / test.

    For each window 0, 10, ..., 240 the matrix stored under ``str(window)`` in
    ``data/<tradition>_output_cache_clipped`` becomes one slice of the last
    axis of ``X`` (shape ``(n_rows, n_labels, 25)``).  Labels come from the
    ``labels`` column of ``data/RagaDataset/<tradition>/data.tsv``.

    Args:
        tradition: ``'Hindustani'`` (300 rows, 30 labels); any other value is
            treated as the 480-row / 40-label dataset.
        test_size: forwarded to ``train_test_split`` (default 0.1, as before).
        random_state: forwarded to ``train_test_split``; pass an int for a
            reproducible split.  The default ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` - note the order.
    """
    data_file_path = 'data/RagaDataset/{}/data.tsv'.format(tradition)
    if tradition=='Hindustani':
        n_rows = 300
        n_labels = 30
    else:
        n_rows = 480
        n_labels = 40

    df = pd.read_csv(data_file_path, sep='\t')
    windows = range(0, 250, 10)
    X = np.zeros([n_rows, n_labels, len(windows)])
    y = df['labels']
    with h5py.File('data/'+tradition+'_output_cache_clipped', "r") as output_lm_file:
        for slot, wd in enumerate(windows):
            X[:, :, slot] = output_lm_file[str(wd)]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)
    return X_train, y_train, X_test, y_test

In [123]:
class BestWeights:
    """Learns one scalar weight per window model with a single ``Dense(1)`` layer.

    The network takes ``(n_labels, 25)`` inputs, applies ``Dense(1)`` along the
    last axis and squeezes that axis away, giving one score per label.
    """

    def __init__(self, tradition):
        """Pick the label count (30 for 'Hindustani', otherwise 40) and build the model."""
        self.n_labels = 30 if tradition == 'Hindustani' else 40
        self.tradition = tradition
        self.model = self.build_model()

    def build_model(self):
        """Create and compile the weighting network."""
        stack = [
            tf.keras.Input(shape=[self.n_labels,25]),
            tf.keras.layers.Dense(1),
            tf.keras.layers.Lambda(lambda x: tf.squeeze(x,2)),
        ]
        net = tf.keras.Sequential(stack)
        net.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            metrics='accuracy',
        )
        return net

    def train(self, X, y):
        """Fit for 1300 epochs, checkpointing the best training accuracy.

        The checkpoint is written to ``data/<tradition>_mdl_wts.hdf5``.
        """
        checkpoint_path = 'data/{}_mdl_wts.hdf5'.format(self.tradition)
        keep_best = tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, save_best_only=True, monitor='accuracy', mode='max')
        self.model.fit(X, y, epochs=1300, batch_size=10, callbacks=[keep_best])

    def test(self, X):
        """Load a saved model, print its learned weights and return its predictions.

        The model is read from ``data/<tradition>_Best_mdl_wts.hdf5``, which is
        a different file name from the one ``train`` writes.
        """
        restored = tf.keras.models.load_model('data/{}_Best_mdl_wts.hdf5'.format(self.tradition))
        # First weight tensor of the first layer: one row per input column.
        kernel = restored.layers[0].weights[0].numpy()
        print([row[0] for row in kernel])
        return restored.predict(X)

In [113]:
# Fit the ensemble-weight model for the Carnatic dataset on the training part
# of the split returned by get_data; the held-out part is not used here.
X_train, y_train, X_test, y_test = get_data('Carnatic')
bw = BestWeights(tradition='Carnatic')
# X_train, y_train, X_test, y_test = get_data('Hindustani')
# bw = BestWeights(tradition='Hindustani')

# Alternative kept for reference: train on train + test together.
# bw.train(np.concatenate([X_train, X_test], axis=0),np.concatenate([y_train.values, y_test.values], axis=0))
bw.train(X_train, y_train.values)

Epoch 1/1300
Epoch 2/1300
Epoch 3/1300
Epoch 4/1300
Epoch 5/1300
Epoch 6/1300
Epoch 7/1300
Epoch 8/1300
Epoch 9/1300
Epoch 10/1300
Epoch 11/1300
Epoch 12/1300
Epoch 13/1300
Epoch 14/1300
Epoch 15/1300
Epoch 16/1300
Epoch 17/1300
Epoch 18/1300
Epoch 19/1300
Epoch 20/1300
Epoch 21/1300
Epoch 22/1300
Epoch 23/1300
Epoch 24/1300
Epoch 25/1300
Epoch 26/1300
Epoch 27/1300
Epoch 28/1300
Epoch 29/1300
Epoch 30/1300
Epoch 31/1300
Epoch 32/1300
Epoch 33/1300
Epoch 34/1300
Epoch 35/1300
Epoch 36/1300
Epoch 37/1300
Epoch 38/1300
Epoch 39/1300
Epoch 40/1300
Epoch 41/1300
Epoch 42/1300
Epoch 43/1300
Epoch 44/1300
Epoch 45/1300
Epoch 46/1300
Epoch 47/1300
Epoch 48/1300
Epoch 49/1300
Epoch 50/1300
Epoch 51/1300
Epoch 52/1300
Epoch 53/1300
Epoch 54/1300
Epoch 55/1300
Epoch 56/1300
Epoch 57/1300
Epoch 58/1300
Epoch 59/1300
Epoch 60/1300
Epoch 61/1300
Epoch 62/1300
Epoch 63/1300
Epoch 64/1300
Epoch 65/1300
Epoch 66/1300
Epoch 67/1300
Epoch 68/1300
Epoch 69/1300
Epoch 70/1300
Epoch 71/1300
Epoch 72/1300
E

Epoch 81/1300
Epoch 82/1300
Epoch 83/1300
Epoch 84/1300
Epoch 85/1300
Epoch 86/1300
Epoch 87/1300
Epoch 88/1300
Epoch 89/1300
Epoch 90/1300
Epoch 91/1300
Epoch 92/1300
Epoch 93/1300
Epoch 94/1300
Epoch 95/1300
Epoch 96/1300
Epoch 97/1300
Epoch 98/1300
Epoch 99/1300
Epoch 100/1300
Epoch 101/1300
Epoch 102/1300
Epoch 103/1300
Epoch 104/1300
Epoch 105/1300
Epoch 106/1300
Epoch 107/1300
Epoch 108/1300
Epoch 109/1300
Epoch 110/1300
Epoch 111/1300
Epoch 112/1300
Epoch 113/1300
Epoch 114/1300
Epoch 115/1300
Epoch 116/1300
Epoch 117/1300
Epoch 118/1300
Epoch 119/1300
Epoch 120/1300
Epoch 121/1300
Epoch 122/1300
Epoch 123/1300
Epoch 124/1300
Epoch 125/1300
Epoch 126/1300
Epoch 127/1300
Epoch 128/1300
Epoch 129/1300
Epoch 130/1300
Epoch 131/1300
Epoch 132/1300
Epoch 133/1300
Epoch 134/1300
Epoch 135/1300
Epoch 136/1300
Epoch 137/1300
Epoch 138/1300
Epoch 139/1300
Epoch 140/1300
Epoch 141/1300
Epoch 142/1300
Epoch 143/1300
Epoch 144/1300
Epoch 145/1300
Epoch 146/1300
Epoch 147/1300
Epoch 148/1300

Epoch 160/1300
Epoch 161/1300
Epoch 162/1300
Epoch 163/1300
Epoch 164/1300
Epoch 165/1300
Epoch 166/1300
Epoch 167/1300
Epoch 168/1300
Epoch 169/1300
Epoch 170/1300
Epoch 171/1300
Epoch 172/1300
Epoch 173/1300
Epoch 174/1300
Epoch 175/1300
Epoch 176/1300
Epoch 177/1300
Epoch 178/1300
Epoch 179/1300
Epoch 180/1300
Epoch 181/1300
Epoch 182/1300
Epoch 183/1300
Epoch 184/1300
Epoch 185/1300
Epoch 186/1300
Epoch 187/1300
Epoch 188/1300
Epoch 189/1300
Epoch 190/1300
Epoch 191/1300
Epoch 192/1300
Epoch 193/1300
Epoch 194/1300
Epoch 195/1300
Epoch 196/1300
Epoch 197/1300
Epoch 198/1300
Epoch 199/1300
Epoch 200/1300
Epoch 201/1300
Epoch 202/1300
Epoch 203/1300
Epoch 204/1300
Epoch 205/1300
Epoch 206/1300
Epoch 207/1300
Epoch 208/1300
Epoch 209/1300
Epoch 210/1300
Epoch 211/1300
Epoch 212/1300
Epoch 213/1300
Epoch 214/1300
Epoch 215/1300
Epoch 216/1300
Epoch 217/1300
Epoch 218/1300
Epoch 219/1300
Epoch 220/1300
Epoch 221/1300
Epoch 222/1300
Epoch 223/1300
Epoch 224/1300
Epoch 225/1300
Epoch 226/

Epoch 239/1300
Epoch 240/1300
Epoch 241/1300
Epoch 242/1300
Epoch 243/1300
Epoch 244/1300
Epoch 245/1300
Epoch 246/1300
Epoch 247/1300
Epoch 248/1300
Epoch 249/1300
Epoch 250/1300
Epoch 251/1300
Epoch 252/1300
Epoch 253/1300
Epoch 254/1300
Epoch 255/1300
Epoch 256/1300
Epoch 257/1300
Epoch 258/1300
Epoch 259/1300
Epoch 260/1300
Epoch 261/1300
Epoch 262/1300
Epoch 263/1300
Epoch 264/1300
Epoch 265/1300
Epoch 266/1300
Epoch 267/1300
Epoch 268/1300
Epoch 269/1300
Epoch 270/1300
Epoch 271/1300
Epoch 272/1300
Epoch 273/1300
Epoch 274/1300
Epoch 275/1300
Epoch 276/1300
Epoch 277/1300
Epoch 278/1300
Epoch 279/1300
Epoch 280/1300
Epoch 281/1300
Epoch 282/1300
Epoch 283/1300
Epoch 284/1300
Epoch 285/1300
Epoch 286/1300
Epoch 287/1300
Epoch 288/1300
Epoch 289/1300
Epoch 290/1300
Epoch 291/1300
Epoch 292/1300
Epoch 293/1300
Epoch 294/1300
Epoch 295/1300
Epoch 296/1300
Epoch 297/1300
Epoch 298/1300
Epoch 299/1300
Epoch 300/1300
Epoch 301/1300
Epoch 302/1300
Epoch 303/1300
Epoch 304/1300
Epoch 305/

Epoch 318/1300
Epoch 319/1300
Epoch 320/1300
Epoch 321/1300
Epoch 322/1300
Epoch 323/1300
Epoch 324/1300
Epoch 325/1300
Epoch 326/1300
Epoch 327/1300
Epoch 328/1300
Epoch 329/1300
Epoch 330/1300
Epoch 331/1300
Epoch 332/1300
Epoch 333/1300
Epoch 334/1300
Epoch 335/1300
Epoch 336/1300
Epoch 337/1300
Epoch 338/1300
Epoch 339/1300
Epoch 340/1300
Epoch 341/1300
Epoch 342/1300
Epoch 343/1300
Epoch 344/1300
Epoch 345/1300
Epoch 346/1300
Epoch 347/1300
Epoch 348/1300
Epoch 349/1300
Epoch 350/1300
Epoch 351/1300
Epoch 352/1300
Epoch 353/1300
Epoch 354/1300
Epoch 355/1300
Epoch 356/1300
Epoch 357/1300
Epoch 358/1300
Epoch 359/1300
Epoch 360/1300
Epoch 361/1300
Epoch 362/1300
Epoch 363/1300
Epoch 364/1300
Epoch 365/1300
Epoch 366/1300
Epoch 367/1300
Epoch 368/1300
Epoch 369/1300
Epoch 370/1300
Epoch 371/1300
Epoch 372/1300
Epoch 373/1300
Epoch 374/1300
Epoch 375/1300
Epoch 376/1300
Epoch 377/1300
Epoch 378/1300
Epoch 379/1300
Epoch 380/1300
Epoch 381/1300
Epoch 382/1300
Epoch 383/1300
Epoch 384/

Epoch 396/1300
Epoch 397/1300
Epoch 398/1300
Epoch 399/1300
Epoch 400/1300
Epoch 401/1300
Epoch 402/1300
Epoch 403/1300
Epoch 404/1300
Epoch 405/1300
Epoch 406/1300
Epoch 407/1300
Epoch 408/1300
Epoch 409/1300
Epoch 410/1300
Epoch 411/1300
Epoch 412/1300
Epoch 413/1300
Epoch 414/1300
Epoch 415/1300
Epoch 416/1300
Epoch 417/1300
Epoch 418/1300
Epoch 419/1300
Epoch 420/1300
Epoch 421/1300
Epoch 422/1300
Epoch 423/1300
Epoch 424/1300
Epoch 425/1300
Epoch 426/1300
Epoch 427/1300
Epoch 428/1300
Epoch 429/1300
Epoch 430/1300
Epoch 431/1300
Epoch 432/1300
Epoch 433/1300
Epoch 434/1300
Epoch 435/1300
Epoch 436/1300
Epoch 437/1300
Epoch 438/1300
Epoch 439/1300
Epoch 440/1300
Epoch 441/1300
Epoch 442/1300
Epoch 443/1300
Epoch 444/1300
Epoch 445/1300
Epoch 446/1300
Epoch 447/1300
Epoch 448/1300
Epoch 449/1300
Epoch 450/1300
Epoch 451/1300
Epoch 452/1300
Epoch 453/1300
Epoch 454/1300
Epoch 455/1300
Epoch 456/1300
Epoch 457/1300
Epoch 458/1300
Epoch 459/1300
Epoch 460/1300
Epoch 461/1300
Epoch 462/

Epoch 475/1300
Epoch 476/1300
Epoch 477/1300
Epoch 478/1300
Epoch 479/1300
Epoch 480/1300
Epoch 481/1300
Epoch 482/1300
Epoch 483/1300
Epoch 484/1300
Epoch 485/1300
Epoch 486/1300
Epoch 487/1300
Epoch 488/1300
Epoch 489/1300
Epoch 490/1300
Epoch 491/1300
Epoch 492/1300
Epoch 493/1300
Epoch 494/1300
Epoch 495/1300
Epoch 496/1300
Epoch 497/1300
Epoch 498/1300
Epoch 499/1300
Epoch 500/1300
Epoch 501/1300
Epoch 502/1300
Epoch 503/1300
Epoch 504/1300
Epoch 505/1300
Epoch 506/1300
Epoch 507/1300
Epoch 508/1300
Epoch 509/1300
Epoch 510/1300
Epoch 511/1300
Epoch 512/1300
Epoch 513/1300
Epoch 514/1300
Epoch 515/1300
Epoch 516/1300
Epoch 517/1300
Epoch 518/1300
Epoch 519/1300
Epoch 520/1300
Epoch 521/1300
Epoch 522/1300
Epoch 523/1300
Epoch 524/1300
Epoch 525/1300
Epoch 526/1300
Epoch 527/1300
Epoch 528/1300
Epoch 529/1300
Epoch 530/1300
Epoch 531/1300
Epoch 532/1300
Epoch 533/1300
Epoch 534/1300
Epoch 535/1300
Epoch 536/1300
Epoch 537/1300
Epoch 538/1300
Epoch 539/1300
Epoch 540/1300
Epoch 541/

Epoch 554/1300
Epoch 555/1300
Epoch 556/1300
Epoch 557/1300
Epoch 558/1300
Epoch 559/1300
Epoch 560/1300
Epoch 561/1300
Epoch 562/1300
Epoch 563/1300
Epoch 564/1300
Epoch 565/1300
Epoch 566/1300
Epoch 567/1300
Epoch 568/1300
Epoch 569/1300
Epoch 570/1300
Epoch 571/1300
Epoch 572/1300
Epoch 573/1300
Epoch 574/1300
Epoch 575/1300
Epoch 576/1300
Epoch 577/1300
Epoch 578/1300
Epoch 579/1300
Epoch 580/1300
Epoch 581/1300
Epoch 582/1300
Epoch 583/1300
Epoch 584/1300
Epoch 585/1300
Epoch 586/1300
Epoch 587/1300
Epoch 588/1300
Epoch 589/1300
Epoch 590/1300
Epoch 591/1300
Epoch 592/1300
Epoch 593/1300
Epoch 594/1300
Epoch 595/1300
Epoch 596/1300
Epoch 597/1300
Epoch 598/1300
Epoch 599/1300
Epoch 600/1300
Epoch 601/1300
Epoch 602/1300
Epoch 603/1300
Epoch 604/1300
Epoch 605/1300
Epoch 606/1300
Epoch 607/1300
Epoch 608/1300
Epoch 609/1300
Epoch 610/1300
Epoch 611/1300
Epoch 612/1300
Epoch 613/1300
Epoch 614/1300
Epoch 615/1300
Epoch 616/1300
Epoch 617/1300
Epoch 618/1300
Epoch 619/1300
Epoch 620/

Epoch 632/1300
Epoch 633/1300
Epoch 634/1300
Epoch 635/1300
Epoch 636/1300
Epoch 637/1300
Epoch 638/1300
Epoch 639/1300
Epoch 640/1300
Epoch 641/1300
Epoch 642/1300
Epoch 643/1300
Epoch 644/1300
Epoch 645/1300
Epoch 646/1300
Epoch 647/1300
Epoch 648/1300
Epoch 649/1300
Epoch 650/1300
Epoch 651/1300
Epoch 652/1300
Epoch 653/1300
Epoch 654/1300
Epoch 655/1300
Epoch 656/1300
Epoch 657/1300
Epoch 658/1300
Epoch 659/1300
Epoch 660/1300
Epoch 661/1300
Epoch 662/1300
Epoch 663/1300
Epoch 664/1300
Epoch 665/1300
Epoch 666/1300
Epoch 667/1300
Epoch 668/1300
Epoch 669/1300
Epoch 670/1300
Epoch 671/1300
Epoch 672/1300
Epoch 673/1300
Epoch 674/1300
Epoch 675/1300
Epoch 676/1300
Epoch 677/1300
Epoch 678/1300
Epoch 679/1300
Epoch 680/1300
Epoch 681/1300
Epoch 682/1300
Epoch 683/1300
Epoch 684/1300
Epoch 685/1300
Epoch 686/1300
Epoch 687/1300
Epoch 688/1300
Epoch 689/1300
Epoch 690/1300
Epoch 691/1300
Epoch 692/1300
Epoch 693/1300
Epoch 694/1300
Epoch 695/1300
Epoch 696/1300
Epoch 697/1300
Epoch 698/

Epoch 711/1300
Epoch 712/1300
Epoch 713/1300
Epoch 714/1300
Epoch 715/1300
Epoch 716/1300
Epoch 717/1300
Epoch 718/1300
Epoch 719/1300
Epoch 720/1300
Epoch 721/1300
Epoch 722/1300
Epoch 723/1300
Epoch 724/1300
Epoch 725/1300
Epoch 726/1300
Epoch 727/1300
Epoch 728/1300
Epoch 729/1300
Epoch 730/1300
Epoch 731/1300
Epoch 732/1300
Epoch 733/1300
Epoch 734/1300
Epoch 735/1300
Epoch 736/1300
Epoch 737/1300
Epoch 738/1300
Epoch 739/1300
Epoch 740/1300
Epoch 741/1300
Epoch 742/1300
Epoch 743/1300
Epoch 744/1300
Epoch 745/1300
Epoch 746/1300
Epoch 747/1300
Epoch 748/1300
Epoch 749/1300
Epoch 750/1300
Epoch 751/1300
Epoch 752/1300
Epoch 753/1300
Epoch 754/1300
Epoch 755/1300
Epoch 756/1300
Epoch 757/1300
Epoch 758/1300
Epoch 759/1300
Epoch 760/1300
Epoch 761/1300
Epoch 762/1300
Epoch 763/1300
Epoch 764/1300
Epoch 765/1300
Epoch 766/1300
Epoch 767/1300
Epoch 768/1300
Epoch 769/1300
Epoch 770/1300
Epoch 771/1300
Epoch 772/1300
Epoch 773/1300
Epoch 774/1300
Epoch 775/1300
Epoch 776/1300
Epoch 777/

Epoch 789/1300
Epoch 790/1300
Epoch 791/1300
Epoch 792/1300
Epoch 793/1300
Epoch 794/1300
Epoch 795/1300
Epoch 796/1300
Epoch 797/1300
Epoch 798/1300
Epoch 799/1300
Epoch 800/1300
Epoch 801/1300
Epoch 802/1300
Epoch 803/1300
Epoch 804/1300
Epoch 805/1300
Epoch 806/1300
Epoch 807/1300
Epoch 808/1300
Epoch 809/1300
Epoch 810/1300
Epoch 811/1300
Epoch 812/1300
Epoch 813/1300
Epoch 814/1300
Epoch 815/1300
Epoch 816/1300
Epoch 817/1300
Epoch 818/1300
Epoch 819/1300
Epoch 820/1300
Epoch 821/1300
Epoch 822/1300
Epoch 823/1300
Epoch 824/1300
Epoch 825/1300
Epoch 826/1300
Epoch 827/1300
Epoch 828/1300
Epoch 829/1300
Epoch 830/1300
Epoch 831/1300
Epoch 832/1300
Epoch 833/1300
Epoch 834/1300
Epoch 835/1300
Epoch 836/1300
Epoch 837/1300
Epoch 838/1300
Epoch 839/1300
Epoch 840/1300
Epoch 841/1300
Epoch 842/1300
Epoch 843/1300
Epoch 844/1300
Epoch 845/1300
Epoch 846/1300
Epoch 847/1300
Epoch 848/1300
Epoch 849/1300
Epoch 850/1300
Epoch 851/1300
Epoch 852/1300
Epoch 853/1300
Epoch 854/1300
Epoch 855/

Epoch 868/1300
Epoch 869/1300
Epoch 870/1300
Epoch 871/1300
Epoch 872/1300
Epoch 873/1300
Epoch 874/1300
Epoch 875/1300
Epoch 876/1300
Epoch 877/1300
Epoch 878/1300
Epoch 879/1300
Epoch 880/1300
Epoch 881/1300
Epoch 882/1300
Epoch 883/1300
Epoch 884/1300
Epoch 885/1300
Epoch 886/1300
Epoch 887/1300
Epoch 888/1300
Epoch 889/1300
Epoch 890/1300
Epoch 891/1300
Epoch 892/1300
Epoch 893/1300
Epoch 894/1300
Epoch 895/1300
Epoch 896/1300
Epoch 897/1300
Epoch 898/1300
Epoch 899/1300
Epoch 900/1300
Epoch 901/1300
Epoch 902/1300
Epoch 903/1300
Epoch 904/1300
Epoch 905/1300
Epoch 906/1300
Epoch 907/1300
Epoch 908/1300
Epoch 909/1300
Epoch 910/1300
Epoch 911/1300
Epoch 912/1300
Epoch 913/1300
Epoch 914/1300
Epoch 915/1300
Epoch 916/1300
Epoch 917/1300
Epoch 918/1300
Epoch 919/1300
Epoch 920/1300
Epoch 921/1300
Epoch 922/1300
Epoch 923/1300
Epoch 924/1300
Epoch 925/1300
Epoch 926/1300
Epoch 927/1300
Epoch 928/1300
Epoch 929/1300
Epoch 930/1300
Epoch 931/1300
Epoch 932/1300
Epoch 933/1300
Epoch 934/

Epoch 947/1300
Epoch 948/1300
Epoch 949/1300
Epoch 950/1300
Epoch 951/1300
Epoch 952/1300
Epoch 953/1300
Epoch 954/1300
Epoch 955/1300
Epoch 956/1300
Epoch 957/1300
Epoch 958/1300
Epoch 959/1300
Epoch 960/1300
Epoch 961/1300
Epoch 962/1300
Epoch 963/1300
Epoch 964/1300
Epoch 965/1300
Epoch 966/1300
Epoch 967/1300
Epoch 968/1300
Epoch 969/1300
Epoch 970/1300
Epoch 971/1300
Epoch 972/1300
Epoch 973/1300
Epoch 974/1300
Epoch 975/1300
Epoch 976/1300
Epoch 977/1300
Epoch 978/1300
Epoch 979/1300
Epoch 980/1300
Epoch 981/1300
Epoch 982/1300
Epoch 983/1300
Epoch 984/1300
Epoch 985/1300
Epoch 986/1300
Epoch 987/1300
Epoch 988/1300
Epoch 989/1300
Epoch 990/1300
Epoch 991/1300
Epoch 992/1300
Epoch 993/1300
Epoch 994/1300
Epoch 995/1300
Epoch 996/1300
Epoch 997/1300
Epoch 998/1300
Epoch 999/1300
Epoch 1000/1300
Epoch 1001/1300
Epoch 1002/1300
Epoch 1003/1300
Epoch 1004/1300
Epoch 1005/1300
Epoch 1006/1300
Epoch 1007/1300
Epoch 1008/1300
Epoch 1009/1300
Epoch 1010/1300
Epoch 1011/1300
Epoch 1012/13

Epoch 1025/1300
Epoch 1026/1300
Epoch 1027/1300
Epoch 1028/1300
Epoch 1029/1300
Epoch 1030/1300
Epoch 1031/1300
Epoch 1032/1300
Epoch 1033/1300
Epoch 1034/1300
Epoch 1035/1300
Epoch 1036/1300
Epoch 1037/1300
Epoch 1038/1300
Epoch 1039/1300
Epoch 1040/1300
Epoch 1041/1300
Epoch 1042/1300
Epoch 1043/1300
Epoch 1044/1300
Epoch 1045/1300
Epoch 1046/1300
Epoch 1047/1300
Epoch 1048/1300
Epoch 1049/1300
Epoch 1050/1300
Epoch 1051/1300
Epoch 1052/1300
Epoch 1053/1300
Epoch 1054/1300
Epoch 1055/1300
Epoch 1056/1300
Epoch 1057/1300
Epoch 1058/1300
Epoch 1059/1300
Epoch 1060/1300
Epoch 1061/1300
Epoch 1062/1300
Epoch 1063/1300
Epoch 1064/1300
Epoch 1065/1300
Epoch 1066/1300
Epoch 1067/1300
Epoch 1068/1300
Epoch 1069/1300
Epoch 1070/1300
Epoch 1071/1300
Epoch 1072/1300
Epoch 1073/1300
Epoch 1074/1300
Epoch 1075/1300
Epoch 1076/1300
Epoch 1077/1300
Epoch 1078/1300
Epoch 1079/1300
Epoch 1080/1300
Epoch 1081/1300
Epoch 1082/1300
Epoch 1083/1300
Epoch 1084/1300
Epoch 1085/1300
Epoch 1086/1300
Epoch 10

Epoch 1103/1300
Epoch 1104/1300
Epoch 1105/1300
Epoch 1106/1300
Epoch 1107/1300
Epoch 1108/1300
Epoch 1109/1300
Epoch 1110/1300
Epoch 1111/1300
Epoch 1112/1300
Epoch 1113/1300
Epoch 1114/1300
Epoch 1115/1300
Epoch 1116/1300
Epoch 1117/1300
Epoch 1118/1300
Epoch 1119/1300
Epoch 1120/1300
Epoch 1121/1300
Epoch 1122/1300
Epoch 1123/1300
Epoch 1124/1300
Epoch 1125/1300
Epoch 1126/1300
Epoch 1127/1300
Epoch 1128/1300
Epoch 1129/1300
Epoch 1130/1300
Epoch 1131/1300
Epoch 1132/1300
Epoch 1133/1300
Epoch 1134/1300
Epoch 1135/1300
Epoch 1136/1300
Epoch 1137/1300
Epoch 1138/1300
Epoch 1139/1300
Epoch 1140/1300
Epoch 1141/1300
Epoch 1142/1300
Epoch 1143/1300
Epoch 1144/1300
Epoch 1145/1300
Epoch 1146/1300
Epoch 1147/1300
Epoch 1148/1300
Epoch 1149/1300
Epoch 1150/1300
Epoch 1151/1300
Epoch 1152/1300
Epoch 1153/1300
Epoch 1154/1300
Epoch 1155/1300
Epoch 1156/1300
Epoch 1157/1300
Epoch 1158/1300
Epoch 1159/1300
Epoch 1160/1300
Epoch 1161/1300
Epoch 1162/1300
Epoch 1163/1300
Epoch 1164/1300
Epoch 11

Epoch 1181/1300
Epoch 1182/1300
Epoch 1183/1300
Epoch 1184/1300
Epoch 1185/1300
Epoch 1186/1300
Epoch 1187/1300
Epoch 1188/1300
Epoch 1189/1300
Epoch 1190/1300
Epoch 1191/1300
Epoch 1192/1300
Epoch 1193/1300
Epoch 1194/1300
Epoch 1195/1300
Epoch 1196/1300
Epoch 1197/1300
Epoch 1198/1300
Epoch 1199/1300
Epoch 1200/1300
Epoch 1201/1300
Epoch 1202/1300
Epoch 1203/1300
Epoch 1204/1300
Epoch 1205/1300
Epoch 1206/1300
Epoch 1207/1300
Epoch 1208/1300
Epoch 1209/1300
Epoch 1210/1300
Epoch 1211/1300
Epoch 1212/1300
Epoch 1213/1300
Epoch 1214/1300
Epoch 1215/1300
Epoch 1216/1300
Epoch 1217/1300
Epoch 1218/1300
Epoch 1219/1300
Epoch 1220/1300
Epoch 1221/1300
Epoch 1222/1300
Epoch 1223/1300
Epoch 1224/1300
Epoch 1225/1300
Epoch 1226/1300
Epoch 1227/1300
Epoch 1228/1300
Epoch 1229/1300
Epoch 1230/1300
Epoch 1231/1300
Epoch 1232/1300
Epoch 1233/1300
Epoch 1234/1300
Epoch 1235/1300
Epoch 1236/1300
Epoch 1237/1300
Epoch 1238/1300
Epoch 1239/1300
Epoch 1240/1300
Epoch 1241/1300
Epoch 1242/1300
Epoch 12

Epoch 1259/1300
Epoch 1260/1300
Epoch 1261/1300
Epoch 1262/1300
Epoch 1263/1300
Epoch 1264/1300
Epoch 1265/1300
Epoch 1266/1300
Epoch 1267/1300
Epoch 1268/1300
Epoch 1269/1300
Epoch 1270/1300
Epoch 1271/1300
Epoch 1272/1300
Epoch 1273/1300
Epoch 1274/1300
Epoch 1275/1300
Epoch 1276/1300
Epoch 1277/1300
Epoch 1278/1300
Epoch 1279/1300
Epoch 1280/1300
Epoch 1281/1300
Epoch 1282/1300
Epoch 1283/1300
Epoch 1284/1300
Epoch 1285/1300
Epoch 1286/1300
Epoch 1287/1300
Epoch 1288/1300
Epoch 1289/1300
Epoch 1290/1300
Epoch 1291/1300
Epoch 1292/1300
Epoch 1293/1300
Epoch 1294/1300
Epoch 1295/1300
Epoch 1296/1300
Epoch 1297/1300
Epoch 1298/1300
Epoch 1299/1300
Epoch 1300/1300


In [124]:
# Score the saved Hindustani weight model on all rows (train and test parts
# re-joined in the same order for inputs and labels); the last expression is
# the number of rows whose arg-max prediction equals the label.
bw = BestWeights('Hindustani')
X_train, y_train, X_test, y_test = get_data('Hindustani')
y_pred = bw.test(np.concatenate([X_train, X_test], axis=0))
np.sum(np.argmax(y_pred, 1)==np.concatenate([y_train.values, y_test.values], axis=0))

[0.14530538, 0.16622296, 0.023467897, -0.14283751, 0.28589576, -0.18367307, 0.06799893, 0.12999865, -0.50317067, 0.39374444, -0.011883757, -0.027416285, -0.10365037, 0.026902597, -0.0032952016, 0.1016536, 0.09131258, 0.29101947, 0.050576165, -0.16490516, -0.09560222, -0.116639435, 0.13686526, 0.14130622, 0.2870755]


297

In [125]:
# Score the saved Carnatic weight model on all rows (train and test parts
# re-joined in the same order for inputs and labels); the last expression is
# the number of rows whose arg-max prediction equals the label.
# BestWeights is constructed once - the original built it twice and discarded
# the first instance.
X_train, y_train, X_test, y_test = get_data('Carnatic')
bw = BestWeights('Carnatic')
y_pred = bw.test(np.concatenate([X_train, X_test], axis=0))
np.sum(np.argmax(y_pred, 1)==np.concatenate([y_train.values, y_test.values], axis=0))

[0.21073548, 0.09824791, -0.11856023, 0.10903256, 0.071041666, 0.029498437, 0.058856107, 0.005804086, 0.23304237, -0.0036090536, -0.16575447, 0.1428653, -0.114484884, 0.067209795, -0.15399341, 0.06962044, 0.04790523, -0.1683678, 0.067225285, -0.06486485, 0.12629355, -0.06692766, 0.035410862, 0.14386775, 0.13790666]


422

In [1453]:
def train_model(tradition):
    """Collect leave-one-out SPDKNN predictions for every ``wd`` setting.

    Opens ``data/<tradition>_spd_cache_flat`` read-only and, for each ``wd``
    in ``range(1, 12)``, builds ``SPDKNN(file, mbid_list, wd, k=3)`` and runs
    leave-one-out over the rows of the notebook-level ``df``: the model is
    fitted on every other row index and asked to predict the held-out one.

    NOTE(review): ``df`` is read from notebook global state and is not derived
    from ``tradition`` -- confirm it holds the matching tradition's rows (with
    ``mbid`` and ``labels`` columns) before calling.

    Parameters
    ----------
    tradition : str
        Prefix of the HDF5 cache file name under ``data/``.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the raw ``SPDKNN.predict`` outputs, indexed first by
        ``wd`` (11 entries) and then by held-out row (``len(df)`` entries).
        Trailing axes are whatever ``predict`` returns. No voting across
        ``wd`` values and no accuracy computation is done here; aggregate the
        result in the caller.
    """
    with h5py.File('data/' + tradition + '_spd_cache_flat', "r") as full_spd_lm_file:
        # Ids and labels do not depend on wd, so they are built once instead
        # of on every pass of the wd loop.
        mbid_list = []
        labels = []
        for _, row in df.iterrows():
            mbid_list.append(row['mbid'])
            labels.append(row['labels'])
        y_label = np.array(labels)

        # Samples are addressed by position; SPDKNN receives mbid_list to map
        # those positions back to ids.
        X_ind = np.arange(len(mbid_list)).astype(np.int64)
        loo = LeaveOneOut()

        y_all_pred = []
        for wd in range(1, 12):
            fknn_1 = SPDKNN(full_spd_lm_file, mbid_list, wd, k=3)
            y_pred_curr = []
            for train_index, test_index in loo.split(X_ind):
                fknn_1.fit(train_index, y_label[train_index])
                y_pred_curr.append(fknn_1.predict(test_index))
            y_all_pred.append(y_pred_curr)

        # Converted while the file is still open, as in the original flow.
        return np.array(y_all_pred)
    

In [1454]:
# Leave-one-out SPDKNN sweep over all wd settings for the Hindustani cache;
# keeps the raw per-wd, per-item predictions for inspection below.
y_all_pred = train_model(tradition='Hindustani')

In [1461]:
# Drop the singleton axis 2, then show the entry at first-axis index 0 and
# item index 13 (first axis presumably the wd sweep from train_model -- confirm).
np.squeeze(y_all_pred, axis=2)[0, 13]

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])