In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import os
import numpy as np
import json

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def expandData(weekly_songs, encoded, shuffle=True):
    """Decode rows of (week, song) index pairs into values and optionally shuffle the rows.

    Args:
        weekly_songs: nested list indexed as ``weekly_songs[week][song]``; each entry is
            copied as-is into the output (presumably a per-song feature vector — confirm
            against the JSON producer).
        encoded: list of rows; each row is a list of index pairs of which only the first
            two elements (``code[0]``, ``code[1]``) are read.
        shuffle: when True (the default, matching the original behaviour) the decoded rows
            are returned in a random order drawn from the global ``np.random`` state.

    Returns:
        A list with one decoded row per row of ``encoded``; ``[]`` for empty input.
    """
    # Decode each row with its own length, so rows of differing length are handled
    # instead of being truncated / indexed past their end via len(encoded[0]).
    rows = [[weekly_songs[code[0]][code[1]] for code in row] for row in encoded]
    if not shuffle:
        return rows
    # A single random permutation is O(n), unlike repeated list.remove() which is O(n^2).
    return [rows[idx] for idx in np.random.permutation(len(rows))]

In [None]:
class JsonDataset(data.Dataset):
    """
    Dataset backed by one JSON file holding the keys 'weekly_songs' and 'data'.

    Each item is a tuple t, with:
    t[0].shape = num_top_songs X sequence_length X x_seq_size
    t[1].shape = x_size
    t[2].shape = 1

    t[0] is float32 and decoded from data[index][0] via expandData (which also shuffles
    its rows); t[1] is float32 from data[index][1]; t[2] is the int64 label from
    data[index][2].
    """
    def __init__(self, data_path):
        # The context manager closes the file even if JSON parsing raises.
        with open(data_path, 'r') as f:
            d = json.load(f)
        self.weekly_songs = d['weekly_songs']
        self.data = d['data']

    def __len__(self):
        """Number of rows in the JSON 'data' list."""
        return len(self.data)

    def __getitem__(self, index):
        """Return (x_seqs, x, label) tensors for row ``index``."""
        row = self.data[index]
        x_seqs = torch.tensor(expandData(self.weekly_songs, row[0]), dtype=torch.float32)
        x = torch.tensor(row[1], dtype=torch.float32)
        # float() is kept from the original conversion; the int64 cast truncates like
        # torch.LongTensor did.
        label = torch.tensor([float(row[2])], dtype=torch.long)
        return x_seqs, x, label

In [None]:
# Load the train / validation / test splits, in that order, from their JSON files.
_split_paths = (
    'drive/MyDrive/cpsc490/small-multi-rnn-train.json',
    'drive/MyDrive/cpsc490/small-multi-rnn-validation.json',
    'drive/MyDrive/cpsc490/small-multi-rnn-test.json',
)
train_data, validation_data, test_data = [JsonDataset(path) for path in _split_paths]

In [None]:
# Disabled exploratory check: the code is wrapped in a bare string literal, so nothing
# below executes. When live, it printed the summed step-to-step Euclidean distance along
# one song's sequence (item[1]) and along a path that jumps to a randomly chosen song at
# every step; the two numbers shown after this cell appear to be from such an earlier run.
# NOTE(review): if re-enabled, `point = item[...][i]` reuses `i` left over from the first
# loop (198), so the random path starts at step 198 and then jumps to step 1 — probably
# meant index 0.
# NOTE(review): np.random.randint(63) draws indices 0..62 only, and `sum` shadows the
# builtin.
"""
item = train_data[0][0]

sum = 0
for i in range(0, 199):
    sum += np.linalg.norm(item[1][i] - item[1][i + 1])

print(sum)

point = item[np.random.randint(63)][i]
sum2 = 0
for i in range(0, 199):
    point2 = item[np.random.randint(63)][i + 1]
    sum2 += np.linalg.norm(point - point2)
    point = point2

print(sum2)
"""


27.238822096958756
341.9393405262381


In [None]:
# Network and batch hyperparameters. The two input sizes are read off one decoded
# training sample, fetched once because every train_data[i] access re-decodes and
# reshuffles the item.
sample_x_seqs, sample_x, _sample_label = train_data[0]

num_top_songs = sample_x_seqs.shape[0]  # dim 0 of an item's sequence tensor

x_seq_size = sample_x_seqs.shape[2]     # last dim: per-step input size of each LSTM
rnn_hidden_size = 30
rnn_num_layers = 2

x_size = sample_x.shape[0]              # length of the static feature vector
fc_hidden_size = 3000
# NOTE(review): a 10-layer, 3000-unit tanh head with dropout 0.8 is very heavily
# regularised and the recorded test accuracy (45%) is around chance — consider a
# smaller head and/or lower dropout.
fc_num_layers = 10

batch_size = 64

In [None]:
# Shared loader settings. batch_size comes from the hyperparameter cell instead of a
# repeated literal, and host memory is only pinned when batches go to a GPU.
params = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': device.type == 'cuda'}
# Only training benefits from per-epoch reshuffling; evaluation metrics are
# order-independent, so validation/test iterate in a fixed order.
train_loader = data.DataLoader(train_data, shuffle=True, **params)
validation_loader = data.DataLoader(validation_data, shuffle=False, **params)
test_loader = data.DataLoader(test_data, shuffle=False, **params)

In [None]:
class MultiRNN(nn.Module):
    """Binary classifier: one LSTM per top-song slot followed by a deep fully connected head.

    forward(x_seqs, x):
        x_seqs: sliced as x_seqs[:, i, :, :] for slot i and fed to a batch_first LSTM,
            i.e. (batch, num_top_songs, seq_len, x_seq_size).
        x: (batch, x_size) features concatenated with the LSTM summaries.
        Returns (batch, 2) unnormalised class scores.

    ``dropout`` is the nn.Dropout probability after every hidden layer of the head
    (default 0.8, the previously hard-coded value).
    """
    def __init__(self, num_top_songs, x_seq_size, rnn_hidden_size, rnn_num_layers, x_size,
                 fc_hidden_size, fc_num_layers, dropout=0.8):
        super(MultiRNN, self).__init__()

        self.num_top_songs = num_top_songs
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn_num_layers = rnn_num_layers

        # One independent LSTM per top-song slot (weights are not shared between slots).
        self.rnns = nn.ModuleList([nn.LSTM(x_seq_size, rnn_hidden_size, rnn_num_layers,
                                           batch_first=True) for _ in range(num_top_songs)])

        layers = [nn.Linear(num_top_songs * rnn_hidden_size + x_size, fc_hidden_size),
                  nn.Tanh(), nn.Dropout(dropout)]
        for _ in range(fc_num_layers - 1):
            layers += [nn.Linear(fc_hidden_size, fc_hidden_size), nn.Tanh(), nn.Dropout(dropout)]
        layers.append(nn.Linear(fc_hidden_size, 2))

        self.fc = nn.Sequential(*layers)

    def forward(self, x_seqs, x):
        # With no (h0, c0) given, nn.LSTM starts from zeros on the input's device —
        # identical to the explicit zeros, without depending on a global `device`.
        # Keep only the last time step's output of each slot: (batch, rnn_hidden_size).
        last_hidden = [self.rnns[i](x_seqs[:, i, :, :])[0][:, -1, :]
                       for i in range(self.num_top_songs)]
        # One concatenation instead of growing a tensor inside the loop.
        joined = torch.cat(last_hidden + [x], 1)  # (batch, num_top_songs * rnn_hidden_size + x_size)
        return self.fc(joined)

In [None]:
def train(model, criterion, optimizer, loader=None):
    """Run one training epoch and return (mean loss per sample, accuracy in percent).

    Args:
        model: module called as ``model(x_seqs, x)`` returning (batch, num_classes) scores.
        criterion: loss taking (scores, int64 class targets), e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over ``model``'s parameters.
        loader: iterable of (x_seqs, x, targets) batches; defaults to the global
            ``train_loader``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = train_loader
    # Send batches to wherever the model's parameters live.
    model_device = next(model.parameters()).device
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for x_seqs, x, targets in loader:
        x_seqs = x_seqs.to(model_device)
        x = x.to(model_device)
        # Flatten labels to a 1-D vector of class indices for the loss.
        targets = torch.flatten(targets).to(model_device)

        outputs = model(x_seqs, x)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = targets.size(0)
        total += batch_n
        # Assumes criterion averages over the batch; re-weight to a per-sample sum.
        train_loss += loss.item() * batch_n
        correct += outputs.argmax(1).eq(targets).sum().item()

    if total == 0:
        raise ValueError('train: loader yielded no samples')
    return train_loss / total, 100.0 * correct / total

In [None]:
def validation(model, criterion, loader=None):
    """Evaluate ``model`` without gradients; return (mean loss per sample, accuracy in percent).

    Args:
        model: module called as ``model(x_seqs, x)`` returning (batch, num_classes) scores.
        criterion: loss taking (scores, int64 class targets), e.g. nn.CrossEntropyLoss.
        loader: iterable of (x_seqs, x, targets) batches; defaults to the global
            ``validation_loader``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = validation_loader
    model_device = next(model.parameters()).device
    model.eval()
    validation_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for x_seqs, x, targets in loader:
            x_seqs = x_seqs.to(model_device)
            x = x.to(model_device)
            targets = torch.flatten(targets).to(model_device)

            outputs = model(x_seqs, x)
            loss = criterion(outputs, targets)

            # Accumulate inside the loop: these lines previously sat after it, so only
            # the final batch contributed to the reported loss and accuracy.
            batch_n = targets.size(0)
            total += batch_n
            # Assumes criterion averages over the batch; re-weight to a per-sample sum.
            validation_loss += loss.item() * batch_n
            correct += outputs.argmax(1).eq(targets).sum().item()

    if total == 0:
        raise ValueError('validation: loader yielded no samples')
    return validation_loss / total, 100.0 * correct / total

In [None]:
# Class weights for CrossEntropyLoss: class 0 is weighted by the fraction of class-1
# samples and vice versa, so the rarer class counts more.
# Labels are read from the raw rows (element [2], converted like JsonDataset does:
# float then truncated to int) instead of calling train_data[i], which would decode and
# shuffle every sequence just to read a label.
train_labels = [int(float(row[2])) for row in train_data.data]
weight_zero = sum(1 for label in train_labels if label == 1) / len(train_labels)
print('weight_zero: {}'.format(weight_zero))

model = MultiRNN(num_top_songs, x_seq_size, rnn_hidden_size, rnn_num_layers, x_size, fc_hidden_size, fc_num_layers).to(device)
criterion = nn.CrossEntropyLoss(weight=torch.tensor([weight_zero, 1 - weight_zero]).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 100

weight_zero: 0.4828125


In [None]:
# Train for num_epochs, checkpointing the weights whenever validation loss improves.
best_validation_loss = None

for epoch in range(num_epochs):
    epoch_train_loss, epoch_train_acc = train(model, criterion, optimizer)
    epoch_validation_loss, epoch_validation_acc = validation(model, criterion)

    improved = best_validation_loss is None or epoch_validation_loss < best_validation_loss
    if improved:
        torch.save(model.state_dict(), 'best_multi_rnn.pth')
        print('Saved.')
        best_validation_loss = epoch_validation_loss

    stats = [format(value, '.4f') for value in (epoch_train_loss, epoch_train_acc,
                                                epoch_validation_loss, epoch_validation_acc)]
    print('Epoch {}. Training loss: {} ({}% accuracy). Validation loss: {} ({}% accuracy)'
          .format(epoch + 1, *stats))
    

In [None]:
# Test: evaluate the checkpoint with the lowest validation loss on the held-out split.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load('best_multi_rnn.pth', map_location=device))
# Loading weights does not change the module's mode; switch dropout off explicitly rather
# than relying on validation() having been the last call in this kernel.
model.eval()

with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for x_seq, x, targets in test_loader:
        x_seq = x_seq.to(device)
        x = x.to(device)
        targets = torch.flatten(targets).to(device)

        outputs = model(x_seq, x)
        predicted = outputs.argmax(1)

        n_samples += x_seq.shape[0]
        n_correct += (predicted == targets).sum().item()

    # Report NaN rather than crashing with ZeroDivisionError on an empty test split.
    acc = 100.0 * n_correct / n_samples if n_samples else float('nan')
    print('Test accuracy: {}%'.format(acc))

Test accuracy: 45.0%
