# **This notebook handles data processing and loading for the Last.FM dataset**

-

-

In [None]:
from typing import Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import os
import numpy as np
import scipy
from scipy.sparse import csr_matrix
from pathlib import Path
from torch.utils.data import DataLoader
export_dir = os.getcwd()
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.offline import plot
import random
import math
import heapq
from scipy.special import expit  # Sigmoid function
import itertools
from IPython.display import Latex, display
import pickle
import warnings

# Ignore FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
torch.set_printoptions(sci_mode=False)

test_flag = 1

In [None]:
pip install ipynb

In [None]:
from ipynb.fs.defs.utils import *
from ipynb.fs.defs.training import *
from ipynb.fs.defs.models import *

# Load Dataset

In [None]:
# NOTE(review): the three MovieLens frames look like leftovers from a MovieLens notebook.
# They are kept unchanged because cells outside this section may still read them — confirm
# and remove if unused.
users = pd.read_csv(Path(export_dir,'dataset/movielens/users.dat'), sep='::', header=None, engine='python').drop(index=4169,axis=0)

movies = pd.read_csv(Path(export_dir,'dataset/movielens/movies.dat'), sep='::', encoding='ISO-8859-1',header=None, engine='python')
ratings = pd.read_csv(Path(export_dir,'dataset/movielens/ratings.dat'), sep='::', header=None, engine='python')


# Last.FM data. These have to be DataFrames (not bare Path objects): the following cell
# selects columns from df_music and merges df_user with it on 'track_id'.
df_music = pd.read_csv(Path(export_dir,'dataset/lastFM/music_data.csv'))
df_user = pd.read_csv(Path(export_dir,'dataset/lastFM/user_data.csv'))

In [None]:
# Keep only the track metadata columns used later on, then attach that metadata to every
# listening record. The inner join keeps only records whose track_id exists in both frames.
columns_to_keep = ['track_id', 'name', 'artist', 'tags', 'year']
df_music_filtered = df_music[columns_to_keep]
df_joined_user_music = df_user.merge(
    df_music_filtered,
    how='inner',
    on='track_id',
)
df_joined_user_music

Unnamed: 0,track_id,user_id,playcount,name,artist,tags,year
0,TRIRLYL128F42539D1,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Nothing From Nothing,Billy Preston,"soul, funk, piano, 70s, oldies",2010
1,TRFUPBA128F934F7E1,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Paper Gangsta,Lady Gaga,"electronic, pop, female_vocalists, dance, piano, electro",2012
2,TRLQPQJ128F42AA94F,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Soy,Gipsy Kings,guitar,2003
3,TRTUCUY128F92E1D24,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Stacked Actors,Foo Fighters,"rock, alternative, alternative_rock, hard_rock, 90s, grunge",1999
4,TRHDDQG12903CB53EE,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Heaven's Gonna Burn Your Eyes,Thievery Corporation,"electronic, female_vocalists, ambient, chillout, trip_hop, downtempo, lounge, chill",2002
...,...,...,...,...,...,...,...
9711296,TRBKFKL128E078ED76,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,Hailie's Song,Eminem,"pop, rap, hip_hop, love, 00s",2002
9711297,TRYFXPG128E078ECBD,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,Forgot About Dre,Dr. Dre,"rap, hip_hop",1999
9711298,TROBUUZ128F4263002,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,Paralyzer,Finger Eleven,"rock, alternative, alternative_rock, hard_rock",2007
9711299,TROEWXC128F148C83E,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,What's The Difference,Dr. Dre,"rap, 90s, hip_hop",1999


In [None]:
'''artists who were listened by more than 50 users'''
# Number of distinct users per artist
user_counts_per_artist = df_joined_user_music.groupby('artist')['user_id'].nunique()

# Names of the artists with strictly more than 50 distinct listeners
has_enough_listeners = user_counts_per_artist.gt(50)
artists_with_enough_users = user_counts_per_artist[has_enough_listeners].index

# Listening records restricted to those artists
is_kept_artist = df_joined_user_music['artist'].isin(artists_with_enough_users)
filtered_df = df_joined_user_music[is_kept_artist]
filtered_df

Unnamed: 0,track_id,user_id,playcount,name,artist,tags,year
0,TRIRLYL128F42539D1,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Nothing From Nothing,Billy Preston,"soul, funk, piano, 70s, oldies",2010
1,TRFUPBA128F934F7E1,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Paper Gangsta,Lady Gaga,"electronic, pop, female_vocalists, dance, piano, electro",2012
2,TRLQPQJ128F42AA94F,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Soy,Gipsy Kings,guitar,2003
3,TRTUCUY128F92E1D24,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Stacked Actors,Foo Fighters,"rock, alternative, alternative_rock, hard_rock, 90s, grunge",1999
4,TRHDDQG12903CB53EE,b80344d063b5ccb3212f76538f3d9e43d87dca9e,1,Heaven's Gonna Burn Your Eyes,Thievery Corporation,"electronic, female_vocalists, ambient, chillout, trip_hop, downtempo, lounge, chill",2002
...,...,...,...,...,...,...,...
9711296,TRBKFKL128E078ED76,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,Hailie's Song,Eminem,"pop, rap, hip_hop, love, 00s",2002
9711297,TRYFXPG128E078ECBD,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,Forgot About Dre,Dr. Dre,"rap, hip_hop",1999
9711298,TROBUUZ128F4263002,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,Paralyzer,Finger Eleven,"rock, alternative, alternative_rock, hard_rock",2007
9711299,TROEWXC128F148C83E,b7815dbb206eb2831ce0fe040d0aa537e2e800f7,1,What's The Difference,Dr. Dre,"rap, 90s, hip_hop",1999


In [None]:
'''users that listened to more than 50 songs'''
# Rows of filtered_df per user (presumably one row per user/track pair — confirm)
counts = filtered_df['user_id'].value_counts()

# user_ids that occur strictly more than 50 times
is_frequent = counts.gt(50)
common_vals = counts[is_frequent].index

# Keep only the records that belong to those users
is_kept_user = filtered_df['user_id'].isin(common_vals)
filtered_df1 = filtered_df[is_kept_user]

In [None]:
'''artists who were listened by more than 40 users'''
# Distinct listeners per artist, recomputed after the user filter
user_counts_per_artist_1 = filtered_df1.groupby('artist')['user_id'].nunique()

# Artists that still have strictly more than 40 distinct listeners
still_has_enough_listeners = user_counts_per_artist_1.gt(40)
artists_with_enough_users = user_counts_per_artist_1[still_has_enough_listeners].index

# Final filtered listening records
is_kept_artist_1 = filtered_df1['artist'].isin(artists_with_enough_users)
filtered_df2 = filtered_df1[is_kept_artist_1]
filtered_df2

Unnamed: 0,track_id,user_id,playcount,name,artist,tags,year
121,TRLATHU128F92FC275,5a905f000fc1ff3df7ca807d57edb608863db05d,11,Freedom Blade,This Will Destroy You,"instrumental, post_rock",2009
122,TRMKFPN128F42858C3,5a905f000fc1ff3df7ca807d57edb608863db05d,2,Caterpillar House,Black Moth Super Rainbow,"psychedelic, psychedelic_rock",2006
123,TRTSSUT128F1472A51,5a905f000fc1ff3df7ca807d57edb608863db05d,1,Tchaparian,Hot Chip,"electronic, dance",2006
124,TRNJLKP128F427CE28,5a905f000fc1ff3df7ca807d57edb608863db05d,1,Aerodynamic,Daft Punk,"electronic, dance, house, techno, electro, french",2001
125,TRGAOLV128E0789D40,5a905f000fc1ff3df7ca807d57edb608863db05d,2,Swallowed in the Sea,Coldplay,"rock, alternative, indie, pop, alternative_rock, indie_rock, british, love, britpop, mellow",2005
...,...,...,...,...,...,...,...
9711269,TRGCHLH12903CB7352,8305c896f42308824da7d4386f4b9ee584281412,5,Party In The U.S.A.,The Barden Bellas,"soundtrack, cover",2012
9711270,TRVSJOM12903CD2DC1,8305c896f42308824da7d4386f4b9ee584281412,1,One Less Lonely Girl,Justin Bieber,"pop, rnb, love",2010
9711271,TRAALAH128E078234A,8305c896f42308824da7d4386f4b9ee584281412,2,Bitter Sweet Symphony,The Verve,"rock, alternative, indie, pop, alternative_rock, indie_rock, british, chillout, soundtrack, 90s, beautiful, britpop",1999
9711272,TRTKLFX12903CD2DC2,8305c896f42308824da7d4386f4b9ee584281412,2,First Dance,Justin Bieber,"black_metal, industrial, thrash_metal, melodic_death_metal, power_metal, doom_metal, gothic_metal, grunge, symphonic_metal, grindcore, nu_metal",2010


In [None]:
# Size of the filtered data set: distinct users and distinct artists.
distinct_user_count = filtered_df2['user_id'].nunique()
distinct_artist_count = filtered_df2['artist'].nunique()
# Kept for backward compatibility: despite its name this has always held the ARTIST count
# (it is computed from the 'artist' column, not from 'track_id').
distinct_track_count = distinct_artist_count
print(f"Number of distinct user IDs: {distinct_user_count}, Number of distinct artists: {distinct_artist_count}")

Number of distinct user IDs: 22546, Number of distinct track IDs: 2277


In [None]:
# Independent (deep) copy of the filtered records, used as the working frame from here on.
df_joined_user_music_new = filtered_df2.copy(deep=True)

## Create the one-hot (binary) user–artist interaction matrix:

In [None]:
# Count listening records per (user, artist) and spread the artists over the columns;
# pairs that never occur are filled with 0.
records_per_pair = df_joined_user_music_new.groupby(['user_id', 'artist']).size()
user_artist_counts = records_per_pair.unstack(fill_value=0)

# Remember which user_id belongs to each row before the index is replaced below
users_index_artists = user_artist_counts.index

# Binary interactions (1 if the user listened to the artist, 0 otherwise) with a
# fresh 0..n-1 row index
user_artist_matrix = user_artist_counts.gt(0).astype(int).reset_index(drop=True)
user_artist_matrix

artist,"""Weird Al"" Yankovic",...And You Will Know Us by the Trail of Dead,.38 Special,10 Years,10cc,12 Stones,1200 Micrograms,16 Horsepower,2 Live Crew,2 Unlimited,...,dEUS,deadmau5,dredg,jj,mclusky,mewithoutYou,múm,of Montreal,Émilie Simon,Ólafur Arnalds
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22541,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
22542,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
22543,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
22544,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [None]:
# Rows of the interaction matrix are users, columns are artists (the "items").
USERS_lastFM, ITEMS_lastFM = user_artist_matrix.shape
print(USERS_lastFM,ITEMS_lastFM)

22546 2277


# Data Processing and analysis:

In [None]:
# Tags that are never used as an artist's top tags.
excluded_tags = {
    'beautiful', 'indie', 'alternative', 'singer_songwriter', 'oldies',
    'alternative_rock', 'indie_pop', 'pop_rock',
    'ambient', 'experimental', 'electro', 'grunge', 'synthpop', 'rnb',
    'seen_live', 'favorites',
    '00s', '90s', '80s', '70s', '60s',
}

# Turn the comma-separated tag string of every listening record into a list of
# lower-cased tags; values that do not split into a list (e.g. missing tags) become [].
df_joined_user_music_new1 = df_joined_user_music_new.copy()
split_tags = df_joined_user_music_new1['tags'].str.lower().str.split(',')
df_joined_user_music_new1['tags'] = split_tags.apply(
    lambda value: value if isinstance(value, list) else []
)

# One row per (listening record, tag), whitespace trimmed, excluded tags removed
df_exploded = df_joined_user_music_new1.explode('tags')
df_exploded['tags'] = df_exploded['tags'].str.strip()
is_excluded_tag = df_exploded['tags'].isin(excluded_tags)
df_exploded = df_exploded[~is_excluded_tag]

# Number of exploded rows per (artist, tag)
tag_counts = (
    df_exploded
    .groupby(['artist', 'tags'])
    .size()
    .reset_index(name='count')
)

# For every artist keep the 4 tags with the highest count
top_tags = (
    tag_counts
    .sort_values(by=['artist', 'count'], ascending=[True, False])
    .groupby('artist')
    .head(4)
)

# Collapse to one row per artist holding the list of its top tags
top_tags_per_artist = top_tags.groupby('artist')['tags'].apply(list).reset_index()

# Left-merge onto the columns of the interaction matrix so that every artist has a row;
# artists with no remaining tag get NaN in 'tags'.
all_artists = user_artist_matrix.columns
artists_frame = pd.DataFrame({'artist': all_artists})
top_tags_per_artist = pd.merge(artists_frame, top_tags_per_artist, on='artist', how='left')

# Index the table by artist name while keeping the 'artist' column
top_tags_per_artist = top_tags_per_artist.set_index('artist', drop=False)

Unnamed: 0_level_0,artist,tags
artist,Unnamed: 1_level_1,Unnamed: 2_level_1
"""Weird Al"" Yankovic","""Weird Al"" Yankovic",[cover]
...And You Will Know Us by the Trail of Dead,...And You Will Know Us by the Trail of Dead,"[progressive_rock, rock, indie_rock, post_rock]"
.38 Special,.38 Special,"[classic_rock, rock]"
10 Years,10 Years,"[rock, hard_rock, american, hardcore]"
10cc,10cc,"[classic_rock, rock, pop, male_vocalists]"
...,...,...
mewithoutYou,mewithoutYou,"[post_hardcore, rock, screamo, emo]"
múm,múm,"[post_rock, electronic, chillout, instrumental]"
of Montreal,of Montreal,"[psychedelic, dance, emo, avant_garde]"
Émilie Simon,Émilie Simon,"[electronic, french, trip_hop, pop]"


In [None]:
# Copy the per-artist tag table and drop its first column, so that only the
# artist-indexed remaining column(s) are left.
first_column = top_tags_per_artist.columns[0]
df_tags = top_tags_per_artist.copy().drop(columns=[first_column])
df_tags

Unnamed: 0_level_0,tags
artist,Unnamed: 1_level_1
"""Weird Al"" Yankovic",[cover]
...And You Will Know Us by the Trail of Dead,"[progressive_rock, rock, indie_rock, post_rock]"
.38 Special,"[classic_rock, rock]"
10 Years,"[rock, hard_rock, american, hardcore]"
10cc,"[classic_rock, rock, pop, male_vocalists]"
...,...
mewithoutYou,"[post_hardcore, rock, screamo, emo]"
múm,"[post_rock, electronic, chillout, instrumental]"
of Montreal,"[psychedelic, dance, emo, avant_garde]"
Émilie Simon,"[electronic, french, trip_hop, pop]"


In [None]:
# Values that are not lists (e.g. NaN for artists without tags) are normalised to [].
# This updates df_tags itself on purpose: later cells read df_tags['tags'] as lists.
df_tags['tags'] = df_tags['tags'].apply(lambda x: x if isinstance(x, list) else [])

all_artists = df_tags.index

# One row per (artist, tag); empty-string tags are dropped. The helper column is added with
# .assign on the filtered result, which returns a new frame instead of writing into a
# slice (the original assignment could raise pandas' SettingWithCopyWarning).
df_exploded = df_tags.explode('tags')
df_exploded = df_exploded.loc[df_exploded['tags'] != ''].assign(value=1)

# Artist x tag table: 1 where the tag is among the artist's top tags, 0 elsewhere
df_artists_tags = df_exploded.pivot_table(
    index='artist',
    columns='tags',
    values='value',
    fill_value=0
)

# Re-insert artists that ended up with no tag (all-zero rows) and store the flags as ints
df_artists_tags = df_artists_tags.reindex(all_artists, fill_value=0).astype(int)



In [None]:
# Popularity check: number of users who listened to each artist, most popular first
num_users_per_artist = user_artist_matrix.sum(axis=0)
num_users_per_artist_sort = num_users_per_artist.sort_values(ascending=False, ignore_index=False)

# Popularity table: listener count, popularity rank (0 = most popular) and the artist's
# column position in user_artist_matrix.
most_popular = num_users_per_artist_sort.copy().to_frame(name='num_users')
most_popular['rank'] = range(0, most_popular.shape[0])
# Index.get_indexer resolves all positions in one vectorised lookup; the previous version
# rebuilt list(user_artist_matrix.columns) and scanned it once per artist (quadratic).
most_popular['index1'] = user_artist_matrix.columns.get_indexer(most_popular.index)


# Number of artists carrying each tag (tags are what this notebook calls genres)
num_per_genre = df_artists_tags.sum(axis=0)

fig, ax = plt.subplots(figsize=(10, 6))
num_per_genre.plot(kind='bar', color='lightblue', ax=ax)
# Labels corrected: the bars count artists per tag over all artists, not
# "200 most popular movies".
ax.set_title('Number of artists per genre')
ax.set_xlabel('Genre')
ax.set_ylabel('Number of Artists')
ax.tick_params(axis='x', labelrotation=45)
ax.grid(True)
plt.show()

In [None]:
from sklearn.preprocessing import normalize

# Clustering parameters (previously repeated as bare literals in several places)
N_ARTIST_CLUSTERS = 30      # number of KMeans clusters
MIN_TAG_PROPORTION = 0.2    # minimum share of an artist's tagged tracks that must carry a tag
TOP_TAGS_PER_CLUSTER = 5    # number of tags used to describe / name a cluster

# Tracks that have tags. .copy() makes this an independent frame, so the column
# assignments below do not write into a slice of df_music_filtered.
music_info_df = df_music_filtered.dropna(subset=['tags']).copy()

# Clean and split tags: lower-case, split on commas, trim, spaces -> underscores
music_info_df['tags'] = music_info_df['tags'].str.lower().str.split(',')
music_info_df['tags'] = music_info_df['tags'].apply(lambda x: [tag.strip().replace(' ', '_') for tag in x])

# Every distinct tag that occurs on any track
all_tags_flat = {tag for tags in music_info_df['tags'] for tag in tags}

# Tags to remove fully
tags_to_remove = {
    'indie', 'alternative', 'beautiful', 'oldies',
    'alternative_rock', 'indie_rock', 'indie_pop', 'pop_rock', 'singer_songwriter',
    'ambient', 'experimental', 'electro', 'grunge', 'synthpop', 'rnb',
    '70s', '90s', '00s', '80s', '60s',
}

# Tags whose weight is reduced before normalisation
tags_to_downweight = {'rock': 0.3, 'electronic': 0.4}

# Whitelist = every observed tag that is not removed
whitelist_tags = list(all_tags_flat - tags_to_remove)

# One row per (track, tag), restricted to whitelisted tags
exploded_df = music_info_df.explode('tags')
exploded_df = exploded_df[exploded_df['tags'].isin(whitelist_tags)]

# Per artist: number of tracks carrying each tag, and that number as a share of the
# artist's tagged tracks
artist_tag_counts = exploded_df.groupby(['artist', 'tags']).size().reset_index(name='count')
artist_total_songs = music_info_df.groupby('artist').size().reset_index(name='total_songs')
artist_tag_counts = artist_tag_counts.merge(artist_total_songs, on='artist')
artist_tag_counts['proportion'] = artist_tag_counts['count'] / artist_tag_counts['total_songs']

# Keep only tags that cover a large enough share of the artist's tracks
artist_tag_counts_filtered = artist_tag_counts[
    artist_tag_counts['proportion'] >= MIN_TAG_PROPORTION
].copy()

# Tags that survive the proportion filter, in order of first appearance
valid_tags_after_filter = artist_tag_counts_filtered['tags'].unique().tolist()

# Artist x tag table of counts (0 where the artist does not have the tag)
artist_tag_matrix = artist_tag_counts_filtered.pivot_table(
    index='artist', columns='tags', values='count', fill_value=0
)

# Fix the column order; .copy() so the down-weighting below modifies an independent frame
artist_tag_matrix = artist_tag_matrix[valid_tags_after_filter].copy()

for tag, weight in tags_to_downweight.items():
    if tag in artist_tag_matrix.columns:
        artist_tag_matrix[tag] = artist_tag_matrix[tag] * weight

# Tag columns only, captured before the cluster columns are appended further down
feature_names = artist_tag_matrix.columns

# L2-normalise every artist row
artist_tag_matrix_norm = pd.DataFrame(
    normalize(artist_tag_matrix, norm='l2'),
    index=artist_tag_matrix.index,
    columns=artist_tag_matrix.columns
)

# KMeans clustering of the normalised tag profiles (fixed seed for reproducibility)
kmeans = KMeans(n_clusters=N_ARTIST_CLUSTERS, random_state=42, n_init=10)
clusters = kmeans.fit_predict(artist_tag_matrix_norm)
artist_tag_matrix_norm['cluster_id'] = clusters

# Describe each cluster by the tags with the highest mean weight among its members.
# The loop variable is called cluster_tags (not top_tags) so that it does not overwrite
# the top_tags DataFrame created by the tag-processing cell earlier in the notebook.
cluster_top_tags = {}
cluster_names = {}

for cluster_num in range(kmeans.n_clusters):
    cluster_subset = artist_tag_matrix_norm[artist_tag_matrix_norm['cluster_id'] == cluster_num]
    mean_vector = cluster_subset[feature_names].mean(axis=0)
    cluster_tags = mean_vector.sort_values(ascending=False).head(TOP_TAGS_PER_CLUSTER).index.tolist()
    cluster_top_tags[cluster_num] = cluster_tags
    cluster_names[cluster_num] = ' / '.join(cluster_tags)

# Attach the cluster description to every artist row
cluster_top_tags_text = {num: ', '.join(tags) for num, tags in cluster_top_tags.items()}
artist_tag_matrix_norm['cluster_top_tags'] = artist_tag_matrix_norm['cluster_id'].map(cluster_top_tags_text)
artist_tag_matrix_norm['cluster_name'] = artist_tag_matrix_norm['cluster_id'].map(cluster_names)
artist_tag_matrix_norm['artist'] = artist_tag_matrix_norm.index

# Compact artist -> cluster table
tagging_w_cluster = artist_tag_matrix_norm[['artist', 'cluster_id', 'cluster_name', 'cluster_top_tags']].reset_index(drop=True)


In [None]:
# Attach each artist's cluster id and cluster tag summary to its top-tag list.
# Left join on the artist index: artists of df_tags that are absent from the clustering
# table get NaN in the cluster columns.
artist_clusters_df = tagging_w_cluster.set_index('artist')[['cluster_id', 'cluster_top_tags']]
merged_df__all = df_tags.join(artist_clusters_df, how='left')
merged_df__all

In [None]:
# Flatten the per-artist tag lists into one list (duplicates kept, artist order preserved).
# Iterating the column directly replaces one .loc lookup per row; the isinstance guard
# skips values that are not lists (e.g. NaN), and empty lists contribute nothing.
all_tags = [tag for tags in df_tags['tags'] if isinstance(tags, list) for tag in tags]

# Distinct tags, sorted so the list is identical on every run (the iteration order of a
# plain set of strings is arbitrary and can change between kernel sessions).
unique_tags = sorted(set(all_tags))
unique_tags

In [None]:
# Share of artists that carry each tag: column sums of the 0/1 artist x tag table divided
# by the number of artists.
num_per_genre = df_artists_tags.sum(axis=0)/df_artists_tags.shape[0]

fig, ax = plt.subplots(figsize=(30, 6))
num_per_genre.plot(kind='bar', color='lightblue', ax=ax)
ax.set_title('Genre-wise activation profiles')
ax.set_xlabel('Genre')
# Label corrected: the values are fractions of artists, not a number of movies
ax.set_ylabel('Fraction of Artists')
ax.tick_params(axis='x', labelrotation=90, labelsize=14)
ax.grid(True)
plt.show()

num_per_genre

## Filter data into genres

In [None]:
# Alphabetically sorted list of the distinct tags, used as the genre vocabulary.
genres = list(unique_tags)
genres.sort()

In [None]:
def _artists_with_tag(tag_matrix, tag):
    """Return the artists flagged with `tag` in an artist x tag indicator table.

    Parameters
    ----------
    tag_matrix : DataFrame with artists in the index and one 0/1 column per tag.
    tag : label of the column to inspect.

    Returns
    -------
    (positions, names) : positions is an integer array with the row positions where the
        column equals 1, names is the list of the matching index labels.
    """
    positions = np.array(np.where(tag_matrix[tag] == 1)[0])
    return positions, list(tag_matrix.index[positions])


# Per-genre artist lookups, keyed by the tag name itself. Resolving every genre by NAME
# replaces the former hard-coded positions (genres[0] ... genres[80]), which silently
# attached the wrong tag to a variable as soon as the tag vocabulary changed.
genre_artist_ids = {}
genre_artist_names = {}
for _tag in genres:
    _positions, _names = _artists_with_tag(df_artists_tags, _tag)
    genre_artist_ids[_tag] = _positions
    genre_artist_names[_tag] = _names
    # Backward compatibility: other cells read the module-level variables
    # genre_<tag>_id and genre_<tag>_id_name, so they are still published under exactly
    # those names. A tag that is absent from the data now leaves its variable undefined
    # (NameError at first use) instead of pointing at a different tag.
    globals()[f'genre_{_tag}_id'] = _positions
    globals()[f'genre_{_tag}_id_name'] = _names

In [None]:
# Mapping from the variable-style key 'genre_<tag>_id' to the array of artist positions
# for that genre. Insertion order is the same as before; 'genre_chill_id' is listed once
# (it used to be assigned twice with the same value, which left its position unchanged).
genre_id = {
    'genre_rock_id': genre_rock_id,
    'genre_british_id': genre_british_id,
    'genre_progressive_metal_id': genre_progressive_metal_id,
    'genre_trip_hop_id': genre_trip_hop_id,
    'genre_female_vocalists_id': genre_female_vocalists_id,
    'genre_japanese_id': genre_japanese_id,
    'genre_indie_rock_id': genre_indie_rock_id,
    'genre_downtempo_id': genre_downtempo_id,
    'genre_emo_id': genre_emo_id,
    'genre_power_metal_id': genre_power_metal_id,
    'genre_punk_id': genre_punk_id,
    'genre_house_id': genre_house_id,
    'genre_heavy_metal_id': genre_heavy_metal_id,
    'genre_metalcore_id': genre_metalcore_id,
    'genre_thrash_metal_id': genre_thrash_metal_id,
    'genre_russian_id': genre_russian_id,
    'genre_psychedelic_id': genre_psychedelic_id,
    'genre_screamo_id': genre_screamo_id,
    'genre_death_metal_id': genre_death_metal_id,
    'genre_blues_id': genre_blues_id,
    'genre_german_id': genre_german_id,
    'genre_gothic_id': genre_gothic_id,
    'genre_folk_id': genre_folk_id,
    'genre_melodic_death_metal_id': genre_melodic_death_metal_id,
    'genre_instrumental_id': genre_instrumental_id,
    'genre_avant_garde_id': genre_avant_garde_id,
    'genre_techno_id': genre_techno_id,
    'genre_nu_metal_id': genre_nu_metal_id,
    'genre_black_metal_id': genre_black_metal_id,
    'genre_new_age_id': genre_new_age_id,
    'genre_drum_and_bass_id': genre_drum_and_bass_id,
    'genre_idm_id': genre_idm_id,
    'genre_chillout_id': genre_chillout_id,
    'genre_punk_rock_id': genre_punk_rock_id,
    'genre_country_id': genre_country_id,
    'genre_reggae_id': genre_reggae_id,
    'genre_rap_id': genre_rap_id,
    'genre_ska_id': genre_ska_id,
    'genre_american_id': genre_american_id,
    'genre_dark_ambient_id': genre_dark_ambient_id,
    'genre_symphonic_metal_id': genre_symphonic_metal_id,
    'genre_post_punk_id': genre_post_punk_id,
    'genre_guitar_id': genre_guitar_id,
    'genre_swedish_id': genre_swedish_id,
    'genre_polish_id': genre_polish_id,
    'genre_post_hardcore_id': genre_post_hardcore_id,
    'genre_britpop_id': genre_britpop_id,
    'genre_classic_rock_id': genre_classic_rock_id,
    'genre_hardcore_id': genre_hardcore_id,
    'genre_blues_rock_id': genre_blues_rock_id,
    'genre_post_rock_id': genre_post_rock_id,
    'genre_cover_id': genre_cover_id,
    'genre_funk_id': genre_funk_id,
    'genre_trance_id': genre_trance_id,
    'genre_male_vocalists_id': genre_male_vocalists_id,
    'genre_gothic_metal_id': genre_gothic_metal_id,
    'genre_hard_rock_id': genre_hard_rock_id,
    'genre_doom_metal_id': genre_doom_metal_id,
    'genre_piano_id': genre_piano_id,
    'genre_hip_hop_id': genre_hip_hop_id,
    'genre_mellow_id': genre_mellow_id,
    'genre_jazz_id': genre_jazz_id,
    'genre_j_pop_id': genre_j_pop_id,
    'genre_chill_id': genre_chill_id,
    'genre_lounge_id': genre_lounge_id,
    'genre_metal_id': genre_metal_id,
    'genre_soul_id': genre_soul_id,
    'genre_acoustic_id': genre_acoustic_id,
    'genre_love_id': genre_love_id,
    'genre_new_wave_id': genre_new_wave_id,
    'genre_french_id': genre_french_id,
    'genre_pop_id': genre_pop_id,
    'genre_electronic_id': genre_electronic_id,
    'genre_progressive_rock_id': genre_progressive_rock_id,
    'genre_soundtrack_id': genre_soundtrack_id,
    'genre_dance_id': genre_dance_id,
    'genre_psychedelic_rock_id': genre_psychedelic_rock_id,
    'genre_grindcore_id': genre_grindcore_id,
    'genre_industrial_id': genre_industrial_id,
    'genre_classical_id': genre_classical_id,
    'genre_noise_id': genre_noise_id,
}


# Default item/user recommender embeddings for model initialization

In [None]:
# Read the matrix-factorization (MF) item and user embedding tables from CSV.
df_item_mf = pd.read_csv(Path(export_dir) / 'res_csv/lastFM/items_embeddings_mf_model.csv')
df_user_mf = pd.read_csv(Path(export_dir) / 'res_csv/lastFM/users_embeddings_mf_model.csv')

# Expose both tables as float32 tensors.
dataset_items_init = torch.tensor(df_item_mf.to_numpy(), dtype=torch.float32)
dataset_users_init = torch.tensor(df_user_mf.to_numpy(), dtype=torch.float32)

In [None]:
# Read the item and user embedding tables from CSV.
df_item_emb = pd.read_csv(Path(export_dir) / 'res_csv/lastFM/items_embeddings.csv')
df_user_emb = pd.read_csv(Path(export_dir) / 'res_csv/lastFM/users_embeddings.csv')

# Same data as float32 tensors.
dataset_item_emb = torch.tensor(df_item_emb.to_numpy(), dtype=torch.float32)
dataset_user_emb = torch.tensor(df_user_emb.to_numpy(), dtype=torch.float32)

In [None]:
# df_user_embeddings = pd.read_csv(Path(export_dir,'NCF_user_embeddings.csv'))
# user_embeddings = torch.tensor(df_user_embeddings.values, dtype=torch.float32)

# df_item_embeddings = pd.read_csv(Path(export_dir,'NCF_item_embeddings.csv'))
# item_embeddings = torch.tensor(df_item_embeddings.values, dtype=torch.float32)

# df_item_emb1 = df_item_embeddings.copy()
# df_item_emb1.index = ratings_matrix.columns

# df_movies = pd.read_csv(Path(export_dir,'csv/movies_data.csv'))  # based on the whole dataset
# df_movies=df_movies.iloc[:,0:-1]
# df_users = pd.read_csv(Path(export_dir,'csv/users_data.csv')) #  based on the whole dataset

# # CONVERT TO TENSORS
# dataset_items = torch.tensor(df_item_emb.values, dtype=torch.float32)
# dataset_users = torch.tensor(df_user_emb.values, dtype=torch.float32)
# movies_data = torch.tensor(df_movies.values, dtype=torch.float32)
# users_data = torch.tensor(df_users.values, dtype=torch.float32)

# Prepare the data for recommenders

In [None]:
# Interaction matrix as a float64 tensor (presumably one row per user and one
# column per artist - confirm against where `user_artist_matrix` is built).
user_artist_matrix_tensor = torch.tensor(user_artist_matrix.values, dtype=torch.float64)

# Positive examples: for every row, the column indices whose entry equals 1.
pos_ex_ = {
    row: (user_artist_matrix_tensor[row] == 1).nonzero(as_tuple=True)[0].tolist()
    for row in range(user_artist_matrix_tensor.shape[0])
}

# Negative examples: for every row, the column indices whose entry equals 0.
neg_ex_ = {
    row: (user_artist_matrix_tensor[row] == 0).nonzero(as_tuple=True)[0].tolist()
    for row in range(user_artist_matrix_tensor.shape[0])
}

# Number of positive examples per row, derived directly from `pos_ex_`.
pos_ex_num_ = {row: len(items) for row, items in pos_ex_.items()}


## for popularity

In [None]:
# Per-item popularity: column sums of the interaction tensor.
popularity_ = user_artist_matrix_tensor.sum(axis=0)

# Normalise popularity into a distribution (same dist for all users).
# The total is computed once instead of once per item.
total_popularity_ = np.array(popularity_).sum()
dist_pop_ = [occur / total_popularity_ for occur in popularity_]


In [None]:
# dist_pop_neg_per_user_ = {}
# for row in range(user_artist_matrix_tensor.shape[0]):
#   curr_dist_neg = []
#   for i in range(user_artist_matrix_tensor.shape[1]):
#     if i in neg_ex_[row]:
#       curr_dist_neg.append(dist_pop_[i])
#   dist_pop_neg_per_user_[row]=curr_dist_neg

# with open(Path(OUTPUTS_DIR,'dors/170725/dist_pop_neg_per_user_170725.pkl'), 'wb') as file:
#     pickle.dump(dist_pop_neg_per_user_, file)
# # with open(Path(export_dir,'dataset/lastFM/dist_pop_neg_per_user_test.pkl'), 'wb') as file:
# #     pickle.dump(dist_pop_neg_per_user_, file)

In [None]:
# dist_pop_pos_per_user_ = {}
# for row in range(user_artist_matrix_tensor.shape[0]):
#   curr_dist_pos = []
#   for i in range(user_artist_matrix_tensor.shape[1]):
#     if user_artist_matrix.columns[i] in pos_ex_[row]:
#       curr_dist_pos.append(dist_pop_[i])
#   dist_pop_pos_per_user_[row]=curr_dist_pos

# with open(Path(export_dir,'dataset/lastFM/dist_pop_pos_per_user_test.pkl'), 'wb') as file:
#     pickle.dump(dist_pop_pos_per_user_, file)


In [None]:
# Load the cached per-row popularity values of the negative examples
# (presumably produced by the commented-out cell above - confirm).
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with Path(export_dir, 'dataset/lastFM/dist_pop_neg_per_user.pkl').open('rb') as file:
    dist_pop_neg_per_user_ = pickle.load(file)

In [None]:
# Load the cached per-row popularity values of the positive examples.
# The path uses the same 'dataset/' root as every other file read in this notebook.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(Path(export_dir,'dataset/lastFM/dist_pop_pos_per_user.pkl'), 'rb') as file:
    dist_pop_pos_per_user_ = pickle.load(file)

## infrastructure for sampling of **negative** examples

In [None]:
# Multiplier applied to the popularity values before exponentiation.
alpha1 = 2
# # sample uniformly
# sum_per_user_neg = {key: (np.array(value)).sum() for key, value in dist_pop_neg_per_user.items()}
# norm_prob_neg = {(row): [(dist_pop_neg_per_user[row][i])/sum_per_user_neg[row] for i in range(len(dist_pop_neg_per_user[row]))] for row in ratings_matrix.index}

# sample w.r.t popularity
# Per-row normaliser: sum of exp(alpha1 * popularity) over that row's negatives.
sum_per_user_neg_exp_ = {
    user: np.sum(np.exp(alpha1 * np.array(pops)))
    for user, pops in dist_pop_neg_per_user_.items()
}
# Per-row sampling probabilities: exp(alpha1 * popularity) / normaliser.
norm_prob_neg_exp_ = {
    row: [
        np.exp(alpha1 * np.array(pop)) / sum_per_user_neg_exp_[row]
        for pop in dist_pop_neg_per_user_[row]
    ]
    for row in range(len(user_artist_matrix_tensor))
}

## infrastructure for sampling of **positive** examples wrt popularity

In [None]:
# Per-row normaliser: sum of exp(alpha1 * popularity) over that row's positives.
sum_per_user_pos_exp_ = {
    user: np.sum(np.exp(alpha1 * np.array(pops)))
    for user, pops in dist_pop_pos_per_user_.items()
}
# Per-row sampling probabilities: exp(alpha1 * popularity) / normaliser.
norm_prob_pos_exp_ = {
    row: [
        np.exp(alpha1 * np.array(pop)) / sum_per_user_pos_exp_[row]
        for pop in dist_pop_pos_per_user_[row]
    ]
    for row in range(len(user_artist_matrix_tensor))
}

## Temporal

In [None]:
# # 5 positive test samples uniformly:
# pos_idx_ex_hidden = {(row): np.sort(np.random.choice(pos_ex[row],replace=False, size=5)) for row in ratings_matrix.index}

# 5 positive test samples using popularity:
# NOTE(review): sampling is done WITH replacement, so a row may end up with fewer
# than 5 distinct hidden positives. Switching to replace=False would change how
# much of the RNG stream is consumed before the negative sampling below - confirm
# the intended protocol (and any cached pickles) before changing it.
pos_idx_ex_hidden_ = {
    row: np.sort(np.random.choice(pos_ex_[row], size=5, replace=True, p=norm_prob_pos_exp_[row]))
    for row in range(user_artist_matrix_tensor.shape[0])
}

# Positives that stay available for training (everything not held out).
# dtype is pinned to long so that a row with no remaining positives yields an
# empty *integer* tensor; torch.tensor([]) would otherwise default to float32
# and make later concatenation with index tensors promote to float.
pos_idx_ex_use_ = {
    row: torch.tensor(
        [item for item in pos_ex_[row] if item not in pos_idx_ex_hidden_[row]],
        dtype=torch.long,
    )
    for row in range(user_artist_matrix_tensor.shape[0])
}

# if there are no neg examples in test set
neg_ex_hidden_ = []

## model dependent

Creates a test set consisting of positive and negative examples.

In [None]:
# Hold out 5 negatives per row for the test set, drawn without replacement
# according to the popularity-based probabilities.
test_pop_neg_ = {
    row: list(
        np.random.choice(
            neg_ex_[row],
            size=5,
            replace=False,
            p=norm_prob_neg_exp_[row],
        )
    )
    for row in range(len(user_artist_matrix_tensor))
}
# test_unif_neg_ = {(row): random.sample(neg_ex_[row],5) for row in
#                         user_artist_matrix.index}

# Remove the held-out test negatives from the full bank of negatives (neg_ex_).
# Popularity-based variant:
neg_ex_use_pop_ = {
    row: [item for item in neg_ex_[row] if item not in test_pop_neg_[row]]
    for row in range(len(user_artist_matrix_tensor))
}
# # uniformly
# neg_ex_use_unif_ = {(row): list(filter(lambda x: x not in test_unif_neg_[row],
#                                 neg_ex_[row])) for row in user_artist_matrix.index}

# Select the popularity-based variant as the active one.
neg_ex_hidden_ = test_pop_neg_
neg_ex_use_ = neg_ex_use_pop_

# Training indices per row: remaining negatives followed by remaining positives
# (i.e. everything except the held-out test examples).
train_set_pop_ = {
    row: torch.cat((torch.tensor(neg_ex_use_[row]), pos_idx_ex_use_[row]))
    for row in range(len(user_artist_matrix_tensor))
}


In [None]:
# Rebind `dist_pop_` (built earlier as a Python list of per-item popularity
# shares) to a single tensor. From this cell on, `dist_pop_` is a tensor, not a list.
dist_pop_ = torch.tensor(dist_pop_)

In [None]:
# dist_pop_neg_use_per_user_ = {}
# for row in range(user_artist_matrix_tensor.shape[0]):
#   dist_pop_neg_use_per_user_[row] = torch.tensor([dist_pop_[i] for i in range(user_artist_matrix.shape[1]) if i in neg_ex_use_[row]])

# with open(Path(export_dir,'dataset/lastFM/dist_pop_neg_use_per_user_170725.pkl'), 'wb') as file:
#     pickle.dump(dist_pop_neg_use_per_user_, file)

In [None]:
# Load the cached per-row popularity values of the remaining (non-held-out) negatives.
# NOTE(review): this cache is only meaningful if it was generated for the same
# `neg_ex_use_` sample as the one drawn in this run - confirm the RNG seed is fixed.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
with Path(export_dir, 'dataset/lastFM/dist_pop_neg_use_per_user.pkl').open('rb') as file:
    dist_pop_neg_use_per_user_ = pickle.load(file)

In [None]:
# Per-row normaliser: sum of exp(alpha1 * popularity) over the row's remaining negatives.
sum_per_user_neg_use_exp_ = {key: np.sum(np.exp(alpha1*np.array(value))) for key, value
                        in dist_pop_neg_use_per_user_.items()}

# Per-row sampling probabilities over the remaining negatives.
# Rows are enumerated positionally (0..n_rows-1), matching how pos_ex_/neg_ex_
# and the other probability tables in this notebook are keyed, rather than by
# the DataFrame's index labels.
norm_prob_neg_use_exp_ = {
    row: [
        np.exp(alpha1*np.array(dist_pop_neg_use_per_user_[row][i])) / sum_per_user_neg_use_exp_[row]
        for i in range(len(dist_pop_neg_use_per_user_[row]))
    ]
    for row in range(user_artist_matrix_tensor.shape[0])
}

# "1 hot" matrix

In [None]:
# for neg examples in test set:
# neg_ex_unif={}
# for row in user_artist_matrix.index:
#   if (pos_ex_num[row]-5) < user_artist_matrix.shape[1]/2:
#     if (pos_ex_num[row]-5)>0:
#       neg_ex_unif[row]=random.sample(neg_ex_use_unif[row],(pos_ex_num[row]-5))
#     else:
#       neg_ex_unif[row]=random.sample(neg_ex_use_unif[row],1)

# for neg examples in test set:
# neg_ex_unif = {(row): random.sample(neg_ex_use_unif[row],(pos_ex_num[row]-5)) for row in user_artist_matrix.index if (pos_ex_num[row]-5) < user_artist_matrix.shape[1]/2 else random.sample(neg_ex_use_unif[row],1)}

# For every row, draw as many training negatives as the row has positives,
# without replacement, using the popularity-based probabilities.
# - Rows are enumerated positionally (0..n_rows-1), matching how pos_ex_,
#   neg_ex_use_pop_ and the other lookup tables are keyed.
# - The sample size is capped at the number of available negatives, because
#   np.random.choice(replace=False) raises if asked for more items than exist.
neg_ex_popularity_ = {
    row: torch.tensor(
        np.random.choice(
            neg_ex_use_pop_[row],
            size=min(len(pos_ex_[row]), len(neg_ex_use_pop_[row])),
            replace=False,
            p=norm_prob_neg_use_exp_[row],
        )
    )
    for row in range(user_artist_matrix_tensor.shape[0])
}

# change wrt unif/pop:
neg_idx_ex_use_ = neg_ex_popularity_


# Label matrix with the same shape as the interaction tensor:
#   -1 = not used, 0 = sampled training negative, 1 = training positive.
ts_1hot = torch.full((user_artist_matrix_tensor.shape[0], user_artist_matrix_tensor.shape[1]), -1, dtype=torch.int8)
for row in range(user_artist_matrix_tensor.shape[0]):
  # Both lookups already hold tensors; cast to long for indexing instead of
  # re-wrapping them with torch.tensor(...). Negatives are written first so
  # that positives take precedence on any overlap.
  ts_1hot[row, neg_idx_ex_use_[row].long()] = 0
  ts_1hot[row, pos_idx_ex_use_[row].long()] = 1

In [None]:
# Matrix dimensions: number of rows (users) and number of columns (items).
USERS, ITEMS = user_artist_matrix.shape
print(USERS, ITEMS)