In [1]:
from collections import namedtuple

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td

import pytorch_lightning as pl

import tqdm
import json
import sklearn.metrics as sm

import tensorboardX as tb
# import tensorflow as tf
import datetime, os

import matplotlib.pyplot as plt
import seaborn as sns

# Seed every RNG-bearing library this notebook imports so that a fresh
# "Restart Kernel -> Run All" is repeatable. Seeding NumPy alone leaves
# PyTorch's global generator unseeded.
SEED = 31337
# torch.manual_seed returns a Generator object; it is called first so the
# cell still ends on an expression that returns None (no stray cell output).
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Relative directory that the notebook reads its inputs from and writes its
# derived CSV files to.
DATA_DIR = "data"

In [3]:
# Load tracks.json from DATA_DIR; lines=True makes pandas parse the file as
# one JSON object per line.
tracks = pd.read_json(
    os.path.join(DATA_DIR, 'tracks.json'),
    lines=True,
)

In [4]:
# Build an artist -> integer id table: one row per distinct value of
# tracks.artist, numbered 0..n-1 in the order unique() returns them.
artist_dict = {'artist': tracks.artist.unique().tolist()}
df_artist = pd.DataFrame(artist_dict)
df_artist['artist_id'] = np.arange(len(df_artist), dtype=np.int64)

In [5]:
# Persist the artist/artist_id table. No `index` argument is passed, so
# to_csv uses its default handling of the frame's index.
artists_csv = os.path.join(DATA_DIR, 'artists.csv')
df_artist.to_csv(artists_csv)

In [6]:
# Key the lookup table by artist name so it can be joined onto other frames
# via their 'artist' column.
# The original used inplace=True, which made this cell fail with KeyError on
# a second run (after the first run 'artist' is the index, not a column).
# Rebinding behind a column check keeps the cell safe to re-execute.
if 'artist' in df_artist.columns:
    df_artist = df_artist.set_index('artist')

In [7]:
# Left-join df_artist onto tracks, matching each row's 'artist' value against
# df_artist's index; this brings df_artist's columns into `tracks`.
tracks = tracks.join(
    df_artist,
    on='artist',
    how='left',
)

In [8]:
# DATA_DIR is (re)assigned here exactly as in the original cell.
DATA_DIR = 'data'
# Read the three splits in train / val / test order.
train_data, val_data, test_data = (
    pd.read_csv(os.path.join(DATA_DIR, csv_name))
    for csv_name in ('train_data.csv', 'val_data.csv', 'test_data.csv')
)

In [9]:
def add_artist(df, df_with_artists):
    """Return a copy of ``df`` with artist ids for its 'start' and 'track' columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with 'start' and 'track' columns; both are looked up against
        the 'track' column of ``df_with_artists``.
    df_with_artists : pandas.DataFrame
        Frame with 'track' and 'artist_id' columns, used as the
        track -> artist_id lookup.

    Returns
    -------
    pandas.DataFrame
        ``df`` plus two columns: 'artist_context' (artist_id found for
        'start') and 'artist_track' (artist_id found for 'track'). Both are
        left joins, so ``df`` itself is not modified.

    Raises
    ------
    AssertionError
        If the joins changed the number of rows (e.g. a looked-up track id
        occurs more than once in ``df_with_artists``).
    """
    track_to_artist = df_with_artists.set_index('track')['artist_id']
    final_names = {
        'artist_id': 'artist_context',
        'artist_id_track': 'artist_track',
    }
    enriched = (
        df
        # First lookup adds a plain 'artist_id' column (for 'start').
        .join(track_to_artist, on='start')
        # Second lookup would collide on the name, hence the '_track' suffix.
        .join(track_to_artist, on='track', rsuffix='_track')
        .rename(columns=final_names)
    )
    assert len(enriched) == len(df)
    return enriched

In [10]:
# Enrich each split with artist ids, using `tracks` as the lookup table.
train_data_artist, val_data_artist, test_data_artist = (
    add_artist(split_df, tracks)
    for split_df in (train_data, val_data, test_data)
)

In [11]:
# Write the enriched splits back into DATA_DIR, in train / val / test order,
# passing index=None to to_csv as before.
for csv_name, split_df in (
    ('train_data_artist.csv', train_data_artist),
    ('val_data_artist.csv', val_data_artist),
    ('test_data_artist.csv', test_data_artist),
):
    split_df.to_csv(os.path.join(DATA_DIR, csv_name), index=None)

In [12]:
# Bare last expression: renders the enriched training split as the cell output.
train_data_artist

Unnamed: 0,user,start,track,time,artist_context,artist_track
0,0,2999,5089,0.00,1355,443
1,0,2999,7960,0.00,1355,969
2,0,2999,1725,0.00,1355,304
3,0,2999,2606,0.00,1355,1197
4,0,31348,2224,0.00,241,1009
...,...,...,...,...,...,...
572582,9999,2688,12,0.00,1213,11
572583,9999,44241,2580,0.02,3235,374
572584,9999,44241,7960,0.25,3235,969
572585,9999,44241,5372,0.00,3235,5
