In [1]:
import layers
import numpy as np
import torch

In [2]:
# Absolute, machine-specific data locations. Every path ends with "/" because the
# loading cells build file names by plain string concatenation.
# Presumably the raw training / test splits of the dataset — confirm before relying on it.
train_set_dir = "/media/data2/Data/wsdm2019/data/training_set/"
test_set_dir = "/media/data2/Data/wsdm2019/data/test_set/"
# Pickled track features: "music_vector.pkl" and "track2idx.pkl" are read from here.
track_features_dir = "/media/data2/Data/wsdm2019/python/data/track_features/"

# Pickled training examples: the "sample_examples" file is read from here.
train_example_dir = "/media/data2/Data/wsdm2019/python/data/train_examples/"

In [3]:
# Load the pickled track-feature artifacts: the track vectors (later turned into
# the embedding tensor) and the track -> index mapping.
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
import pickle

with open(track_features_dir + "music_vector.pkl", "rb") as vector_file:
    vector = pickle.load(vector_file)

with open(track_features_dir + "track2idx.pkl", "rb") as index_file:
    track2idx = pickle.load(index_file)

In [4]:
# Vocabulary for the skip signal: 0 is padding, 1 is "unknown", 2/3 are the real
# False/True labels. batchGen relies on this layout: it pads with 0, hides the
# answers with 1, and subtracts 2 from the labels to obtain 0/1 targets.
skip2idx = {'PAD':0, 'UNK':1, False:2, True:3}

In [5]:
# Wrap the loaded track vectors in a torch tensor; it is handed to
# encoder.init_embeddings() further down.
music_embedding = torch.Tensor(vector)

In [6]:
# Load the pickled sample of training examples. batchGen reads the 'tracks' and
# 'skip2' entries of each example.
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
with open(train_example_dir+"sample_examples", "rb") as f:
    sample_examples = pickle.load(f)

In [7]:
# Fixed width that every sequence is padded to. The data is expected to contain
# only sequences of 10 to 20 items (not verified in this notebook), so padding
# everything to 20 avoids computing a per-batch maximum.
max_length = 20

In [8]:
def pad_sequences(sequences, max_len=None, device="cuda"):
    """Right-pad integer sequences with 0 and build the matching padding masks.

    Args:
        sequences: iterable of integer sequences (lists), one per example.
        max_len: width to pad to. When None (the default) the notebook-level
            ``max_length`` is used, which is the original behaviour.
        device: device on which the tensors are created. Defaults to "cuda",
            matching the original CUDA-only constructors.

    Returns:
        A ``(padded, masks)`` tuple:
          * padded -- int64 tensor holding the sequences followed by 0 padding.
          * masks  -- uint8 tensor with 0 on real items and 1 on padding.
    """
    # Resolve the global lazily so callers passing max_len never depend on it.
    target_len = max_length if max_len is None else max_len

    padded_sequences = []
    sequence_masks = []
    for sequence in sequences:
        pad = target_len - len(sequence)
        padded_sequences.append(list(sequence) + [0] * pad)
        sequence_masks.append([0] * len(sequence) + [1] * pad)

    return (
        torch.tensor(padded_sequences, dtype=torch.long, device=device),
        torch.tensor(sequence_masks, dtype=torch.uint8, device=device),
    )

In [9]:
# take in examples and generate batches
import math
import copy
def batchGen(examples, max_len=None, device="cuda"):
    """Convert a list of examples into padded, fixed-width model tensors.

    Each example is a dict with a 'tracks' list and a 'skip2' list. The second
    half of every sequence (from ``len(tracks) // 2`` onwards) is the part to
    predict: its skip inputs are replaced by 1 and the loss mask is 1 there.

    Args:
        examples: list of example dicts as described above.
        max_len: width to pad to. When None (the default) the notebook-level
            ``max_length`` is used, which is the original behaviour.
        device: device on which the tensors are created. Defaults to "cuda",
            matching the original CUDA-only constructors.

    Returns:
        A 5-tuple of tensors, one row per example:
          * track inputs   (int64)   -- track ids, 0-padded.
          * skip inputs    (int64)   -- skip ids with the answers set to 1, 0-padded.
          * targets        (float32) -- skip ids minus 2, 0-padded. Presumably
            the ids are 2/3 so this yields 0/1 labels; verify upstream.
          * loss masks     (float32) -- 1.0 on positions to predict, else 0.0.
          * sequence masks (uint8)   -- 1 on padding, 0 on real items.
    """
    # Resolve the global lazily so callers passing max_len never depend on it.
    pad_to = max_length if max_len is None else max_len

    inputs_track = []
    inputs_skip = []
    targets = []
    loss_masks = []
    sequence_masks = []

    for example in examples:
        track = example['tracks']
        skip2 = example['skip2']

        sequence_length = len(track)
        pad_length = pad_to - sequence_length
        cut = sequence_length // 2
        padding = [0] * pad_length

        # Concatenation already builds new lists, so the caller's data is never
        # mutated and no deepcopy is needed.
        inputs_track.append(list(track) + padding)

        # Hide the answers: everything from `cut` to the end of the real
        # sequence is overwritten with 1.
        input_skip = list(skip2) + padding
        input_skip[cut:sequence_length] = [1] * (sequence_length - cut)
        inputs_skip.append(input_skip)

        # Shift the skip ids down by 2 to obtain the training labels.
        targets.append([item - 2 for item in skip2] + padding)

        # Only the hidden second half contributes to the loss.
        loss_masks.append(
            [0.0] * cut + [1.0] * (sequence_length - cut) + [0.0] * pad_length
        )

        # 1 marks padded positions.
        sequence_masks.append([0] * sequence_length + [1] * pad_length)

    return (
        torch.tensor(inputs_track, dtype=torch.long, device=device),
        torch.tensor(inputs_skip, dtype=torch.long, device=device),
        torch.tensor(targets, dtype=torch.float32, device=device),
        torch.tensor(loss_masks, dtype=torch.float32, device=device),
        torch.tensor(sequence_masks, dtype=torch.uint8, device=device),
    )

In [10]:
# Build one small batch for manual inspection. In batchGen's return order:
# a = track inputs, b = skip inputs, c = targets, d = loss masks, e = sequence masks.
a,b,c,d,e = batchGen(sample_examples[:10])

In [11]:
# Sanity check: 10 examples padded to max_length should give a [10, 20] target tensor.
c.size()

torch.Size([10, 20])

In [12]:
# RNNEncoder comes from the project-local rnn_encoder module (not shown here);
# it is constructed with its default arguments.
from rnn_encoder import RNNEncoder
encoder = RNNEncoder()

In [13]:
# Hand the pretrained track vectors to the encoder. init_embeddings is defined in
# rnn_encoder; presumably it copies them into song_embedding — confirm there.
encoder.init_embeddings(music_embedding)
# Move the parameters to the GPU, where batchGen creates its tensors by default.
encoder.cuda()
# Switch to training mode; as the last expression, the module repr is displayed.
encoder.train()

RNNEncoder(
  (song_embedding): Embedding(3706389, 8, padding_idx=0)
  (skip_embedding): Embedding(4, 2, padding_idx=0)
  (bidirectional_rnn): StackedBRNN(
    (rnns): ModuleList(
      (0): LSTM(10, 10, bidirectional=True)
      (1): LSTM(20, 10, bidirectional=True)
    )
  )
  (linear1): Linear(in_features=20, out_features=1, bias=True)
)

In [14]:
# Call the module itself rather than .forward() so that any registered hooks run.
# Arguments: track inputs, skip inputs and the padding mask from batchGen.
encoder_result = encoder(a, b, e)



In [15]:
# reduction='none' keeps one loss value per element, so the loss mask can be
# multiplied in before the values are averaged.
loss_fcn = torch.nn.BCELoss(reduction='none')

In [16]:
# loss = (loss_fcn(encoder_result, c) * d).mean()

In [17]:
# loss.backward()

In [18]:
# Zero the outputs outside the loss mask, then threshold at 0.5 to get a
# per-position prediction (positions where the mask is 0 always come out False).
res = encoder_result * d > 0.5

In [19]:
# np.array() on a CUDA tensor raises TypeError on current PyTorch, so copy to host
# memory explicitly. torch.as_tensor keeps the cell re-runnable: on a second run
# `res` is already a numpy array and is simply round-tripped.
res = torch.as_tensor(res).cpu().numpy()

In [20]:
# Ground-truth targets as a host-side numpy array. The explicit .cpu() is needed
# because np.array() on a CUDA tensor raises TypeError on current PyTorch.
gt = c.detach().cpu().numpy()

In [21]:
# Masked accuracy: fraction of correct predictions among the positions the loss
# mask selects. The mask is copied to host memory once; np.array() on a CUDA
# tensor raises TypeError on current PyTorch.
loss_mask = d.detach().cpu().numpy()
np.sum((res == gt) * loss_mask) / np.sum(loss_mask)

0.35869566

In [23]:
import torch.optim as optim
import random

# Training hyper-parameters, set once here so they are easy to find and tune.
num_epochs = 10
batch_size = 32
log_every = 200  # print running averages every `log_every` optimizer steps

encoder_optimizer = optim.Adamax(params=encoder.parameters(), lr=0.01, weight_decay=0.0)
step = 0

# Running sums over the current logging window; reset after every report.
avg_acc = 0
avg_loss = 0

dataset_size = len(sample_examples)

for epoch in range(num_epochs):
    print("epoch: " + str(epoch))
    # Examples are visited in their stored order every epoch (no shuffling).
    for start_index in range(0, dataset_size, batch_size):
        # The final batch of an epoch may hold fewer than batch_size examples.
        end_index = min(start_index + batch_size, dataset_size)
        batch = sample_examples[start_index: end_index]

        # track inputs, skip inputs, targets, loss masks, sequence masks
        it, isk, t, lm, sm = batchGen(batch)

        # Call the module itself rather than .forward() so registered hooks run.
        encoder_result = encoder(it, isk, sm)
        # NOTE(review): .mean() divides by every position, including the ones the
        # mask zeroes out, so the reported loss is smaller than the average loss
        # per predicted position. Dividing by lm.sum() would give that average;
        # left unchanged to keep numbers comparable with earlier runs.
        loss = (loss_fcn(encoder_result, t) * lm).mean()
        encoder_optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm to 5 before the update.
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)

        encoder_optimizer.step()
        step += 1

        # Copy everything needed for the accuracy to host memory explicitly;
        # np.array() on a CUDA tensor raises TypeError on current PyTorch.
        lm_np = lm.detach().cpu().numpy()
        res = (encoder_result * lm > 0.5).detach().cpu().numpy()
        gt = t.detach().cpu().numpy()
        # Accuracy over the positions selected by the loss mask only.
        acc = np.sum((res == gt) * lm_np) / np.sum(lm_np)

        avg_acc += acc
        avg_loss += loss.item()

        if step % log_every == 0:
            print("step: " + str(step) + " accuracy: " + str(avg_acc/log_every) + " loss: " + str(avg_loss/log_every))
            avg_acc = 0
            avg_loss = 0
            
            
        



step: 200 accuracy: 0.6280155904591084 loss: 0.27692155353724957
step: 400 accuracy: 0.6544649860262871 loss: 0.268240697234869
step: 600 accuracy: 0.6529144057631493 loss: 0.26946472719311715
step: 800 accuracy: 0.6530436463654041 loss: 0.26894513636827466
step: 1000 accuracy: 0.6595906069874764 loss: 0.2656978008151054
step: 1200 accuracy: 0.6570003083348275 loss: 0.2661834440380335
step: 1400 accuracy: 0.6607881060242653 loss: 0.2670547955483198
step: 1600 accuracy: 0.6592116144299507 loss: 0.2659791777282953
step: 1800 accuracy: 0.6574752444028854 loss: 0.2667047675698996
step: 2000 accuracy: 0.663350087404251 loss: 0.2643253093957901
step: 2200 accuracy: 0.6630785179138183 loss: 0.26477712623775007
step: 2400 accuracy: 0.6588929070532322 loss: 0.26595423512160776
step: 2600 accuracy: 0.6658255207538605 loss: 0.2644596965610981
step: 2800 accuracy: 0.6661141327023506 loss: 0.2631688713282347
step: 3000 accuracy: 0.657518829703331 loss: 0.26649584248661995
step: 3200 accuracy: 0.666

KeyboardInterrupt: 