In [1]:
import os,sys,signal, copy
import math
from contextlib import contextmanager

import pickle
import numpy as np                                       # fast vectors and matrices
import matplotlib.pyplot as plt                          # plotting
import matplotlib.ticker as ticker
from matplotlib.animation import ArtistAnimation

from scipy.fftpack import fft
from scipy.signal.windows import hann

sys.path.insert(0, '../../')
import musicnetRaven as musicnet

from time import time

sys.path.insert(0,'../../lib/')
import config
import diagnosticsP3
# import base_model

from sklearn.metrics import average_precision_score

os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'   # see issue #152
os.environ['CUDA_VISIBLE_DEVICES']='0'

import torch
from torch.nn.functional import conv1d, mse_loss
import torch.nn.functional as F
import torch.nn as nn
from torchcontrib.optim import SWA

from tqdm import tqdm
import mir_eval

from torchsummary import summary

%matplotlib inline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Network Parameters
epochs = 20
train_size = 100000
test_size = 50000
lr = 5e-4
pitch_shift = 0
jitter = 0.
num_workers = 15

# lvl1 convolutions are shared between regions
m = 128
k = 512              # lvl1 nodes
n_fft = 4096              # lvl1 receptive field
window = 16384 # total number of audio samples?
stride = 512
batch_size = 100

In [3]:
# function for returning scientific notation in a plot
def fmt(x, pos):
    a, b = '{:.0e}'.format(x).split('e')
    b = int(b)
    return fr'${a} \times 10^{{{b}}}$'

In [4]:
regions = 1 + (window - n_fft)//stride

def worker_init(args):
    signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore signals so parent can handle them
    np.random.seed(os.getpid() ^ int(time())) # approximately random seed for workers
kwargs = {'num_workers': num_workers, 'pin_memory': True, 'worker_init_fn': worker_init}

In [5]:
start = time()
root = '../../../data/'
train_set = musicnet.MusicNet(root=root, epoch_size=train_size,sequence=10
                              , train=True, download=True, refresh_cache=False, 
                              window=window, mmap=False, pitch_shift=pitch_shift, jitter=jitter)
test_set = musicnet.MusicNet(root=root, train=False, download=True, refresh_cache=False, window=window, epoch_size=test_size, mmap=False)
print("Time used = ", time()-start)

Time used =  27.62343454360962


In [6]:
train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**kwargs)

In [7]:
def create_filtersv2(n_fft, freq_bins=None, low=50,high=6000, sr=44100, freq_scale='linear', mode="fft"):
    if freq_bins==None:
        freq_bins = n_fft//2+1
    
    s = torch.arange(0, n_fft, 1.)
    wsin = torch.empty((freq_bins,1,n_fft))
    wcos = torch.empty((freq_bins,1,n_fft))
    start_freq = low
    end_freq = high
    

    # num_cycles = start_freq*d/44000.
    # scaling_ind = np.log(end_freq/start_freq)/k
    
    if mode=="fft":
        window_mask = 1
    elif mode=="stft":
        window_mask = 0.5-0.5*torch.cos(2*math.pi*s/(n_fft)) # same as hann(n_fft, sym=False)
    else:
        raise Exception("Unknown mode, please chooes either \"stft\" or \"fft\"")
        
    if freq_scale == 'linear':
        start_bin = start_freq*n_fft/sr
        scaling_ind = (end_freq/start_freq)/freq_bins
        for k in range(freq_bins): # Only half of the bins contain useful info
            wsin[k,0,:] = window_mask*torch.sin(2*math.pi*(k*scaling_ind*start_bin)*s/n_fft)
            wcos[k,0,:] = window_mask*torch.cos(2*math.pi*(k*scaling_ind*start_bin)*s/n_fft)
    elif freq_scale == 'log':
        start_bin = start_freq*n_fft/sr
        scaling_ind = np.log(end_freq/start_freq)/freq_bins
        for k in range(freq_bins): # Only half of the bins contain useful info
            wsin[k,0,:] = window_mask*torch.sin(2*math.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)
            wcos[k,0,:] = window_mask*torch.cos(2*math.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)    
    else:
        print("Please select the correct frequency scale, 'linear' or 'log'")
    
    return wsin,wcos


In [8]:
Loss = nn.MSELoss()
def L(yhatvar,y):
    return Loss(yhatvar,y) * 128/2

In [9]:
class Model(nn.Module):
    def __init__(self, avg=.9998):
        super(Model, self).__init__()
        # Create filter windows
        wsin, wcos = create_filtersv2(n_fft,k, low=50, high=6000,
                                      mode="stft", freq_scale='log')
        with torch.cuda.device(0):
            self.wsin = torch.Tensor(wsin).cuda()
            self.wcos = torch.Tensor(wcos).cuda()
            
        # Creating Layers
#         self.linear = torch.nn.Linear(regions*k, k,bias=False)
#         self.linear_output = torch.nn.Linear(k,m, bias=False)
#         wscale = 10e-5
#         torch.nn.init.normal_(self.linear.weight, std=1e-4) # initialize
#         torch.nn.init.normal_(self.linear_output.weight, std=1e-4)
#         torch.nn.init.zeros_(self.linear.weight)
#         torch.nn.init.zeros_(self.linear_output.weight)
        k_out = 128
        k2_out = 256
        stride1 = (2,1)
        self.CNN_freq = nn.Conv2d(1,k_out,
                                kernel_size=(128,1),stride=(2,1))
        self.CNN_time = nn.Conv2d(k_out,k2_out,
                                kernel_size=(1,25),stride=(1,1))
        self.Linear = nn.Linear(k2_out*1*193, m)
        self.activation = nn.ReLU()
        
        torch.nn.init.normal_(self.CNN_freq.weight, std=1e-4)
        torch.nn.init.normal_(self.CNN_time.weight, std=1e-4)
        torch.nn.init.normal_(self.Linear.weight, std=1e-4)
        
        
        self.avg = avg
        #Create a container for weight average
        self.averages = copy.deepcopy(list(parm.cuda().data for parm in self.parameters())) 
    def _get_conv_output(self, shape):
        bs = 1
        x = torch.tensor(torch.rand(bs, *shape))
        output_feat = self._forward_features(x)
        n_size = output_feat.data.view(bs, -1).size(1)
        return n_size
        
    def forward(self,x):
        zx = conv1d(x[:,None,:], self.wsin, stride=stride).pow(2) \
           + conv1d(x[:,None,:], self.wcos, stride=stride).pow(2)
        zx = torch.log(zx + 10e-8) # Log Magnitude Spectrogram
        z2 = self.CNN_freq(zx.unsqueeze(1)) # Make channel as 1 (N,C,H,W)
        z2 = F.relu(z2)
        z3 = self.CNN_time(z2)
        z3 = F.relu(z3)
        y = self.Linear(self.activation(torch.flatten(z3,1)))
        return y
    
    def average_iterates(self):
        for parm, pavg in zip(self.parameters(), self.averages):
            pavg.mul_(self.avg).add_(1.-self.avg, parm.data) # 0.9W_avg + 0.1W_this_ite
    
    
@contextmanager
def averages(model):
    orig_parms = copy.deepcopy(list(parm.data for parm in model.parameters()))
    for parm, pavg in zip(model.parameters(), model.averages):
        parm.data.copy_(pavg)
    yield
    for parm, orig in zip(model.parameters(), orig_parms):
        parm.data.copy_(orig)

In [10]:
model = Model()
model.cuda()

result_dict = {'loss_history_train': [],
               'avgp_history_train': [],
               'loss_history_test': [],
               'avgp_history_test': [],
               'parameters': {}}

result_dict['parameters']['train_size'] = train_size
result_dict['parameters']['test_size'] = test_size
result_dict['parameters']['lr'] = lr
result_dict['parameters']['pitch_shift'] = pitch_shift
result_dict['parameters']['jitter'] = jitter

In [11]:
avg = .9998
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
# optimizer = SWA(base_opt, swa_start=0, swa_freq=1, swa_lr=0.000001)

try:
    with train_set, test_set:
        total_i = len(train_loader)
        print("epoch\ttrain loss\ttest loss\ttrain avg\ttest avg\ttime\tutime")
        for e in range(epochs):
            yground = torch.FloatTensor(batch_size*len(train_loader), m) # what not do this together with loss
            yhat = torch.FloatTensor(batch_size*len(train_loader), m)
            avgp, loss_e = 0.,0
            t = time()
            for i, (x,y) in enumerate(train_loader):
                print(f"Training {i}/{total_i} batches", end = '\r')
                optimizer.zero_grad()
#                 print(model.Layer2.weight[0][0][0])
                # making x and y into pytorch dealable format
                x = x.cuda(non_blocking=True)
                y = y.cuda(non_blocking=True)
                break
                yhatvar = model(x)
                loss = L(yhatvar,y)
                loss.backward()
                loss_e += loss.item() #getting the number
                
                yground[i*batch_size:(i+1)*batch_size] = y.data
                yhat[i*batch_size:(i+1)*batch_size] = yhatvar.data
                
                optimizer.step()
                model.average_iterates() # Averaging the weights for validation
                
            avgp = average_precision_score(yground.flatten(),yhat.flatten())    
            result_dict['loss_history_train'].append(loss_e/len(train_loader))
            result_dict['avgp_history_train'].append(avgp)   
            t1 = time()
            avgp, loss_e = 0.,0.           
#             optimizer.swap_swa_sgd() # change to average weight
            
            # For testing
            yground = torch.FloatTensor(batch_size*len(test_loader), m) # what not do this together with loss
            yhat = torch.FloatTensor(batch_size*len(test_loader), m)
            
            for i, (x_test,y_test) in enumerate(test_loader):
                print(f"Testing {i}/{len(test_loader)} batches", end = '\r')
                x_test = x_test.cuda()
                y_test = y_test.cuda()
                yhatvar = model(x_test)
                loss_e += L(yhatvar, y_test).item() #getting the number

                yground[i*batch_size:(i+1)*batch_size] = y_test.data
                yhat[i*batch_size:(i+1)*batch_size] = yhatvar.data
            avgp = average_precision_score(yground.cpu().flatten(),yhat.cpu().flatten())
            result_dict['loss_history_test'].append(loss_e/len(test_loader))
            result_dict['avgp_history_test'].append(avgp)
            print('{}\t{:2f}\t{:2f}\t{:2f}\t{:2f}\t{:2.1f}\t{:2.1f}'.\
                  format(e,
                         result_dict['loss_history_train'][-1],result_dict['loss_history_test'][-1],
                         result_dict['avgp_history_train'][-1],result_dict['avgp_history_test'][-1],
                         time()-t, time()-t1))


        
except KeyboardInterrupt:
    print('Graceful Exit')
else:
    print("Finsihed")
    

epoch	train loss	test loss	train avg	test avg	time	utime
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq

seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10


seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10


seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10
seq =  10


  recall = tps / tps[-1]


seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
s

seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
s

seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
s

seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
s

seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
s

seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
seq =  1
s

In [14]:
x.shape

In [13]:
train_set.rec_ids[1]

In [None]:
x, y = train_set.accessv2(2560, 100,10)

In [None]:
x.shape

In [None]:
with open('../../../data/train_data/2560.bin', 'rb') as f:
    x = np.fromfile(f, dtype=np.float32, count=int(1*window*10))

In [None]:
s = 100
sz_float = 4
sequence = 2
scale = 1
fid = train_set.records[2560][0]
#             start = time()
with open(fid, 'rb') as f:
    f.seek(s*sz_float, os.SEEK_SET)
    print('seq = ', sequence)
    x = np.fromfile(f, dtype=np.float32, count=int(scale*s*window))
#             x = torch.load(fid[:-4])
#             x = x[s:s+swindow]
#             print(time()-start)
print('x shape = ', x.shape)

xp = np.arange(window,dtype=np.float32)
x = np.interp(scale*xp,np.arange(len(x),dtype=np.float32),x).astype(np.float32)



In [31]:
x.shape

(16384,)