# Sort Data

## 0. Preliminary

In [1]:
%matplotlib inline

import IPython.display as ipd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MultiLabelBinarizer
import os
import sys

from ast import literal_eval

""
# Plot styling applied to every figure drawn by this notebook.
sns.set_context("notebook", font_scale=1.5)
plt.rcParams['figure.figsize'] = (17, 5)

import platform

# Project root. The HMAN_ROOT environment variable, when set, takes precedence so
# the notebook is not tied to one machine; otherwise the original per-platform
# locations are used unchanged.
if 'Windows' in platform.platform():
    _DEFAULT_ROOT_PATH = "D:/PycharmProjects/HMAN"
else:
    _DEFAULT_ROOT_PATH = "/home/xkliu/PycharmProjects/HMAN"
ROOT_PATH = os.environ.get("HMAN_ROOT", _DEFAULT_ROOT_PATH)
RAW_DATA_PATH = ROOT_PATH  + "/raw_data"   # inputs read by the loader cells
DATA_PATH = ROOT_PATH + "/data"            # destination of the np.save calls
# Work from the project root and put it on sys.path (via "./") before the
# project-local import that follows.
os.chdir(ROOT_PATH)
sys.path.append("./")

from kddirkit.utils import utils


In [2]:
genre_id2id = {}
genre_str2id = {}
def init_genre():
    """Fill the module-level ``genre_str2id`` dict from ``genre_str2id.csv``.

    The file is read from ``RAW_DATA_PATH + "/fma_metadata/genre_str2id.csv"``.
    Its first line holds the number of records; each following record is
    ``<genre string>,<integer id>``. Entries are added to (or overwritten in)
    ``genre_str2id`` in place and nothing is returned.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the record count or an id is not an integer.
        IndexError: if a record has no second comma-separated field
            (this includes a file shorter than its declared record count).
    """
    # NOTE: the companion reader for fma_metadata/genre_id2id.csv was disabled
    # in the original notebook, so the module-level genre_id2id dict is not
    # touched here.
    print('reading genre str2ids...')
    # `with` guarantees the handle is closed even when a record fails to parse.
    # The dict is only mutated (never rebound), so no `global` statement is needed.
    with open(RAW_DATA_PATH + "/fma_metadata/genre_str2id.csv", "r") as f:
        total = int(f.readline().strip())
        for _ in range(total):
            content = f.readline().strip().split(',')
            genre_str2id[content[0]] = int(content[1])
init_genre()

reading genre str2ids...


## 3. Init Files

In [3]:
def init_files(name):
    """Read ``RAW_DATA_PATH/<name>.txt`` and group consecutive rows into bags.

    File layout: the first line is the number of rows; every row is a
    ``'-----'``-separated record of which this function uses field 1 (album id),
    field 3 (artist id), field 6 (genre name) and field 8 (a Python literal
    sequence of ids).

    Consecutive rows that share the same ``(album_id, artist_id, genre_id)``
    triple form one bag.

    Args:
        name: file stem (without ``.txt``) inside ``RAW_DATA_PATH``.

    Returns:
        A 4-tuple ``(instance_triple, instance_scope, sen_label, sen_label_bottom)``:
        * ``instance_triple`` - ``np.array`` of the distinct consecutive triples;
        * ``instance_scope`` - ``np.array`` of ``[first_row, last_row]`` (inclusive)
          per triple;
        * ``sen_label`` - int32 array, per row the id looked up in the module-level
          ``genre_str2id`` (``genre_str2id['NA']`` for unknown genre names);
        * ``sen_label_bottom`` - int32 array, per row the last element of field 8,
          or 0 when that sequence is empty.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a row has fewer than nine fields (including a file that is
            shorter than its declared row count).
        KeyError: if a genre name is unknown and ``genre_str2id`` has no ``'NA'``.
    """
    print('reading ' + name +' data...')
    # `with` closes the file on every exit path (the handle used to be leaked).
    with open(RAW_DATA_PATH + '/'+ name + '.txt','r') as f:
        total = int(f.readline().strip())
        print(total)
        sen_label = np.zeros((total), dtype = np.int32)
        sen_label_bottom = np.zeros((total), dtype = np.int32)
        instance_scope = []
        instance_triple = []
        for s in range(total):
            content = f.readline().strip().split('-----')
            album_id = content[1]
            artist_id = content[3]
            genre_name = content[6]
            # Look up 'NA' only on a miss so a complete mapping without an
            # 'NA' entry keeps working.
            if genre_name in genre_str2id:
                genre_id = genre_str2id[genre_name]
            else:
                genre_id = genre_str2id['NA']
            bottom_ids = literal_eval(content[8])
            sen_label[s] = genre_id
            sen_label_bottom[s] = bottom_ids[-1] if bottom_ids else 0
            tup = (album_id, artist_id, genre_id)
            # Open a new bag whenever the triple differs from the previous row's.
            if not instance_triple or instance_triple[-1] != tup:
                instance_triple.append(tup)
                instance_scope.append([s, s])
            # Extend the current bag's inclusive end to this row.
            instance_scope[-1][1] = s
    return np.array(instance_triple), np.array(instance_scope), sen_label, sen_label_bottom

In [4]:
%time
# NOTE(review): the loader call below is commented out, so this cell only runs the
# bare `%time` line magic, which has no statement of its own to time.
# medium_instance_triple, medium_instance_scope, medium_label, medium_label_bottom = init_files("medium_data_sort_clean")

CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 5.25 µs


In [5]:
# np.save(DATA_PATH+'/' + 'medium_instance_triple', medium_instance_triple)
# np.save(DATA_PATH+'/' + 'medium_instance_scope', medium_instance_scope)
# np.save(DATA_PATH+'/' + 'medium_label', medium_label)
# np.save(DATA_PATH+'/' + 'medium_label_bottom', medium_label_bottom)

### 3.1 Init Training Data

In [6]:
# def init_train_files(name):
#     print('reading ' + name +' data...')
#     f = open(RAW_DATA_PATH + '/'+ name + '.txt','r')
#     total = (int)(f.readline().strip())
#     print(total)
#     sen_len = np.zeros((total), dtype = np.int32)
#     sen_label = np.zeros((total), dtype = np.int32)
#     instance_scope = []
#     instance_triple = []
#     for s in range(total):
#         content = f.readline().strip().split('-----')
#         album_id = content[1]
#         artist_id = content[3]
# #         print(content)'
#         genre_name = content[6]
#         if  genre_name in genre_str2id:
#             genre_id = genre_str2id[genre_name]
#         else:
# #             print(genre_name)
#             genre_id = genre_str2id['NA']
#         genre_id = literal_eval(content[7])
#         if genre_id:
#             genre_id =  genre_id[0] 
#         else:
#             genre_id = 0
#         sen_label[s] = genre_id
#         tup = (album_id,artist_id,genre_id)
#         if instance_triple == [] or instance_triple[len(instance_triple) - 1] != tup:
#             instance_triple.append(tup)
#             instance_scope.append([s,s])
#         instance_scope[len(instance_triple) - 1][1] = s
# #         if (s+1) % 100 == 0:
# #             sys.stdout.write(str(s)+'\r')
# #             sys.stdout.flush()
#     return np.array(instance_triple), np.array(instance_scope), sen_label

In [7]:
%time
medium_instance_triple_train, medium_instance_scope_train, medium_label_train, medium_label_bottom_train = init_files("medium_data_train_sort_clean")

CPU times: user 1 µs, sys: 1e+03 ns, total: 2 µs
Wall time: 3.81 µs
reading medium_data_train_sort_clean data...
13510


In [8]:
# large_instance_triple_train, large_instance_scope_train, large_label_train = init_train_files("large_data_train_sort")

In [9]:
# Write the four training-split arrays into DATA_PATH, one file per array.
for train_stem, train_array in (
    ('medium_instance_triple_train_clean', medium_instance_triple_train),
    ('medium_instance_scope_train_clean', medium_instance_scope_train),
    ('medium_label_train_clean', medium_label_train),
    ('medium_label_bottom_train_clean', medium_label_bottom_train),
):
    np.save(DATA_PATH + '/' + train_stem, train_array)

### 3.2 Init Validation Data

In [10]:
# def init_val_files(name):
#     print('reading ' + name +' data...')
#     f = open(RAW_DATA_PATH + '/'+ name + '.txt','r')
#     total = (int)(f.readline().strip())
#     print(total)
#     sen_len = np.zeros((total), dtype = np.int32)
#     sen_label = np.zeros((total), dtype = np.int32)
#     instance_scope = []
#     instance_triple = []
#     for s in range(total):
#         content = f.readline().strip().split('-----')
#         album_id = content[1]
#         artist_id = content[3]
# #         print(content)
#         genre_name = content[6]
#         if  genre_name in genre_str2id:
#             genre_id = genre_str2id[genre_name]
#         else:
#             genre_id = genre_str2id['NA']
# #         genre_id = literal_eval(content[8])
# #         if genre_id:
# #             genre_id =  genre_id[-1] 
# #         else:
# #             genre_id = 0
#         sen_label[s] = genre_id
#         tup = (album_id,artist_id,genre_id)
#         if instance_triple == [] or instance_triple[len(instance_triple) - 1] != tup:
#             instance_triple.append(tup)
#             instance_scope.append([s,s])
#         instance_scope[len(instance_triple) - 1][1] = s
# #         if (s+1) % 100 == 0:
# #             sys.stdout.write(str(s)+'\r')
# #             sys.stdout.flush()
#     return np.array(instance_triple), np.array(instance_scope), sen_label

In [11]:
%time
medium_instance_triple_val, medium_instance_scope_val, medium_label_val, medium_label_bottom_val = init_files("medium_data_val_sort_clean")


CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.2 µs
reading medium_data_val_sort_clean data...
1704


In [12]:
# Write the four validation-split arrays into DATA_PATH, one file per array.
# NOTE(review): unlike the train/test outputs, these file names carry no `_clean`
# suffix - confirm that downstream loaders expect exactly these names.
for val_stem, val_array in (
    ('medium_instance_triple_val', medium_instance_triple_val),
    ('medium_instance_scope_val', medium_instance_scope_val),
    ('medium_label_val', medium_label_val),
    ('medium_label_bottom_val', medium_label_bottom_val),
):
    np.save(DATA_PATH + '/' + val_stem, val_array)

### 3.3 Init Testing Data

In [13]:
def init_test_files(name):
    """Read ``RAW_DATA_PATH/<name>.txt`` and group consecutive rows by entity pair.

    File layout: the first line is the number of rows; every row is a
    ``'-----'``-separated record of which this function uses field 1 (album id),
    field 3 (artist id), field 6 (genre name) and field 8 (a Python literal
    sequence of ids).

    Consecutive rows that share the same ``(album_id, artist_id)`` pair form one
    group; the genre is not part of the grouping key here.

    Args:
        name: file stem (without ``.txt``) inside ``RAW_DATA_PATH``.

    Returns:
        A 4-tuple ``(entity_pair, entity_scope, sen_label, sen_label_bottom)``:
        * ``entity_pair`` - ``np.array`` of the distinct consecutive pairs;
        * ``entity_scope`` - ``np.array`` of ``[first_row, last_row]`` (inclusive)
          per pair;
        * ``sen_label`` - int32 array, per row the id looked up in the module-level
          ``genre_str2id`` (``genre_str2id['NA']`` for unknown genre names);
        * ``sen_label_bottom`` - int32 array, per row the last element of field 8,
          or 0 when that sequence is empty.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a row has fewer than nine fields (including a file that is
            shorter than its declared row count).
        KeyError: if a genre name is unknown and ``genre_str2id`` has no ``'NA'``.
    """
    print('reading ' + name +' data...')
    # `with` closes the file on every exit path (the handle used to be leaked).
    with open(RAW_DATA_PATH + '/'+ name + '.txt','r') as f:
        total = int(f.readline().strip())
        print(total)
        sen_label = np.zeros((total), dtype = np.int32)
        sen_label_bottom = np.zeros((total), dtype = np.int32)
        entity_pair = []
        entity_scope = []
        for s in range(total):
            content = f.readline().strip().split('-----')
            album_id = content[1]
            artist_id = content[3]
            genre_name = content[6]
            # Look up 'NA' only on a miss so a complete mapping without an
            # 'NA' entry keeps working.
            if genre_name in genre_str2id:
                genre_id = genre_str2id[genre_name]
            else:
                genre_id = genre_str2id['NA']
            bottom_ids = literal_eval(content[8])
            sen_label[s] = genre_id
            sen_label_bottom[s] = bottom_ids[-1] if bottom_ids else 0
            pair = (album_id, artist_id)
            # Open a new group whenever the pair differs from the previous row's.
            if not entity_pair or entity_pair[-1] != pair:
                entity_pair.append(pair)
                entity_scope.append([s, s])
            # Extend the current group's inclusive end to this row.
            entity_scope[-1][1] = s
    return np.array(entity_pair), np.array(entity_scope),  sen_label, sen_label_bottom

In [14]:
%time
medium_instance_triple_test, medium_instance_scope_test, medium_label_test, medium_label_bottom_test = init_test_files("medium_data_test_sort_clean")

CPU times: user 1 µs, sys: 2 µs, total: 3 µs
Wall time: 4.05 µs
reading medium_data_test_sort_clean data...
1772


In [15]:
# Write the four test-split arrays into DATA_PATH, one file per array.
# The first two variables are named *_instance_*_test but hold the entity pairs
# and entity scopes returned by init_test_files, hence the entity_* file names.
for test_stem, test_array in (
    ('medium_entity_pair_test_clean', medium_instance_triple_test),
    ('medium_entity_scope_test_clean', medium_instance_scope_test),
    ('medium_label_test_clean', medium_label_test),
    ('medium_label_bottom_test_clean', medium_label_bottom_test),
):
    np.save(DATA_PATH + '/' + test_stem, test_array)

## 4. Transform Data

### 4.1 Initialize Bag Labels for Test


In [16]:
from collections import Counter

#### Small

In [17]:
# small_label = np.load(DATA_PATH + '/' + 'small_test_label.npy')
# small_scope = np.load(DATA_PATH + '/' + 'small_test_entity_scope.npy')
# Counter(small_label)

In [18]:
# small_label = np.load(DATA_PATH + '/' + 'small_test_label.npy')
# small_scope = np.load(DATA_PATH + '/' + 'small_test_entity_scope.npy')
# small_label[small_label == 12] = 0
# small_label[small_label == 2] = 1
# small_label[small_label == 17] = 2
# small_label[small_label == 38] = 3
# small_label[small_label == 1235] = 4
# small_label[small_label == 10] = 5
# small_label[small_label == 21] = 6
# small_label[small_label == 15] = 7
# Counter(small_label)

In [19]:
# small_all_true_label = np.zeros((small_scope.shape[0], np.max(small_label)+1))
# for pid in range(small_scope.shape[0]):
#     small_all_true_label[pid][small_label[small_scope[pid][0]:small_scope[pid][1]+1]] = 1
# small_all_true_label = np.reshape(small_all_true_label[:, 1:], -1)
# np.save(DATA_PATH + '/'  + 'small_all_true_label.npy', small_all_true_label)

In [20]:
# small_all_true_label

#### Medium

In [24]:
# Reload the test-split scopes and labels that were written to DATA_PATH.
medium_scope = np.load(DATA_PATH + '/' + 'medium_entity_scope_test_clean.npy')
medium_label = np.load(DATA_PATH + '/' + 'medium_label_test_clean.npy')
# Distinct label ids present in the test split, in ascending order.
sorted(Counter(medium_label))

[2, 3, 4, 5, 8, 9, 10, 12, 13, 14, 15, 17, 20, 21, 38, 1235]

In [25]:
# Number of test rows per label id (ids are still the untransformed ones here).
Counter(medium_label)

Counter({12: 610,
         14: 42,
         38: 125,
         17: 52,
         3: 8,
         15: 532,
         8: 51,
         1235: 74,
         4: 39,
         21: 120,
         9: 18,
         5: 62,
         10: 19,
         20: 12,
         2: 2,
         13: 6})

In [26]:
# medium_label = np.load(DATA_PATH + '/' + 'medium_test_label.npy')
# medium_scope = np.load(DATA_PATH + '/' + 'medium_test_entity_scope.npy')
# for i in range(len(medium_label)):
#     if medium_label[i] ==12:
#         medium_label[i] = 0
#     elif medium_label[i] == 2:
#         medium_label[i] =1
#     elif medium_label[i] == 17:
#         medium_label[i] =2
#     elif medium_label[i] == 38:
#         medium_label[i] =3 
#     elif medium_label[i] == 1235:
#         medium_label[i] =4
#     elif medium_label[i] == 10:
#         medium_label[i] =5
#     elif medium_label[i] == 21:
#         medium_label[i] =6
#     elif medium_label[i] == 15:
#         medium_label[i] =7
#     elif medium_label[i] == 14:
#         medium_label[i] =8
#     elif medium_label[i] == 3:
#         medium_label[i] =9
#     elif medium_label[i] == 8:
#         medium_label[i] =10
#     elif medium_label[i] == 4:
#         medium_label[i] =11
#     elif medium_label[i] == 9:
#         medium_label[i] =12
#     elif medium_label[i] == 5:
#         medium_label[i] =13
#     elif medium_label[i] == 20:
#         medium_label[i] =14
#     elif medium_label[i] == 13:
#         medium_label[i] =15
# Counter(medium_label)

In [27]:
# medium_all_true_label = np.zeros((medium_scope.shape[0], np.max(medium_label)+1))
# for pid in range(small_scope.shape[0]):
#     medium_all_true_label[pid][medium_all_true_label[medium_scope[pid][0]:medium_scope[pid][1]+1]] = 1
# medium_all_true_label = np.reshape(medium_all_true_label[:, 0:], -1)
# np.save(DATA_PATH + '/'  + 'medium_all_true_label.npy', medium_all_true_label)

#### Large

In [28]:
# large_label = np.load(DATA_PATH + '/' + 'large_test_label.npy')
# large_scope = np.load(DATA_PATH + '/' + 'large_test_entity_scope.npy')
# sorted(Counter(large_label).keys())

In [29]:
# Counter(large_label)

In [30]:
# large_label = np.load(DATA_PATH + '/' + 'large_test_label.npy')
# large_scope = np.load(DATA_PATH + '/' + 'large_test_entity_scope.npy')
# large_label[large_label == 12] = 0
# large_label[large_label == 2] = 1
# large_label[large_label == 17] = 2
# large_label[large_label == 38] = 3
# large_label[large_label == 1235] = 4
# large_label[large_label == 10] = 5
# large_label[large_label == 21] = 6
# large_label[large_label == 15] = 7
# Counter(large_label)

In [31]:
# large_all_true_label = np.zeros((small_scope.shape[0], np.max(small_label)+1))
# for pid in range(small_scope.shape[0]):
#     large_all_true_label[pid][large_label[large_scope[pid][0]:large_scope[pid][1]+1]] = 1
# large_all_true_label = np.reshape(large_all_true_label[:, 1:], -1)
# np.save(DATA_PATH + '/'  + 'large_all_true_label.npy', large_all_true_label)

### 4.2 Transform Label

In [32]:
def transform_label(label_pd):
    """Rewrite the listed label ids to the indices 0..15, in place.

    Each position of ``label_pd`` is looked up once in a fixed table
    (12->0, 2->1, 17->2, ... 13->15). Values that are not keys of the table
    are left exactly as they are.

    NOTE: several outputs (e.g. 2, 3, 4, 5) are themselves keys of the table,
    so the function is not idempotent - apply it once to untransformed ids.

    Args:
        label_pd: mutable, integer-indexable sequence of label ids
            (the notebook passes 1-D numpy arrays).

    Returns:
        The same ``label_pd`` object, after modification.
    """
    dense_index = {
        12: 0, 2: 1, 17: 2, 38: 3,
        1235: 4, 10: 5, 21: 6, 15: 7,
        14: 8, 3: 9, 8: 10, 4: 11,
        9: 12, 5: 13, 20: 14, 13: 15,
    }
    for pos in range(len(label_pd)):
        raw_id = label_pd[pos]
        # Only write when a mapping exists; unknown ids are not touched.
        if raw_id in dense_index:
            label_pd[pos] = dense_index[raw_id]
    return label_pd

In [33]:
# Reload the saved train/test label arrays and remap their ids with
# transform_label (train first, then test).
# Disabled variants kept from the original notebook:
#   medium_label_transform = transform_label(np.load(DATA_PATH+'/' + 'medium_label_clean.npy'))
#   medium_val_label_transform = transform_label(np.load(DATA_PATH+'/' + 'medium_val_label_clean.npy'))
medium_train_label_transform, medium_test_label_transform = (
    transform_label(np.load(DATA_PATH + '/' + label_file))
    for label_file in ('medium_label_train_clean.npy', 'medium_label_test_clean.npy')
)

In [34]:
# Persist the remapped train/test labels, then show how many rows each holds.
# Disabled variants kept from the original notebook:
#   np.save(DATA_PATH+'/' + "medium_label_transform", medium_label_transform)
#   np.save(DATA_PATH+'/' + "medium_val_label_transform", medium_val_label_transform)
for label_stem, label_array in (
    ("medium_label_train_transform_clean", medium_train_label_transform),
    ("medium_label_test_transform_clean", medium_test_label_transform),
):
    np.save(DATA_PATH + '/' + label_stem, label_array)
tuple(map(len, (medium_train_label_transform, medium_test_label_transform)))

(13510, 1772)

In [35]:
# Row counts of the in-memory train/test label arrays returned by the loader calls.
len(medium_label_train), len(medium_label_test)

(13510, 1772)

In [36]:
# Reload the test scopes and the remapped test labels from DATA_PATH.
medium_scope = np.load(DATA_PATH + '/' + 'medium_entity_scope_test_clean.npy')
medium_label_test_tramsform = np.load(DATA_PATH + '/' + 'medium_label_test_transform_clean.npy')
# Distinct remapped label values, in ascending order.
sorted(Counter(medium_label_test_tramsform))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

In [37]:
# Multi-hot ground truth: one row per entry of medium_scope, one column per label
# value 0..max; a cell is 1 when that label occurs among the rows the scope covers.
n_label_cols = np.max(medium_label_test_tramsform) + 1
medium_all_true_label_transform = np.zeros((medium_scope.shape[0], n_label_cols))
for pid in range(medium_scope.shape[0]):
    first_row, last_row = medium_scope[pid][0], medium_scope[pid][1]
    # Scope ends are inclusive, hence the +1 on the slice stop.
    labels_in_bag = medium_label_test_tramsform[first_row:last_row + 1]
    medium_all_true_label_transform[pid, labels_in_bag] = 1
# Flatten row-major into a single vector before saving.
medium_all_true_label_transform = medium_all_true_label_transform.reshape(-1)
np.save(DATA_PATH + '/'  + 'medium_flat_true_label_transform_clean.npy', medium_all_true_label_transform)

### 4.3 Get Bottom ID

In [38]:
def transform_label_via_dict(label_pd, transform_dict):
    """Replace every entry of ``label_pd`` by its image under ``transform_dict``, in place.

    Positions are processed from 0 upwards and each one is looked up exactly
    once, so a value produced by an earlier replacement is never mapped again.

    Args:
        label_pd: mutable, integer-indexable sequence of ids.
        transform_dict: mapping that must contain every id found in ``label_pd``.

    Returns:
        The same ``label_pd`` object, after modification.

    Raises:
        KeyError: at the first id missing from ``transform_dict``; entries before
            that position have already been replaced.
    """
    for position in range(len(label_pd)):
        original_id = label_pd[position]
        label_pd[position] = transform_dict[original_id]
    return label_pd

#### Medium

In [39]:
# Reload the saved bottom-label array and the entity scopes of the test split.
medium_label_bottom = np.load(DATA_PATH + '/' + 'medium_label_bottom_test_clean.npy')
medium_scope = np.load(DATA_PATH + '/' + 'medium_entity_scope_test_clean.npy')
# Disabled inspection of the distinct bottom ids:
# sorted(Counter(medium_label_bottom).keys())

In [42]:
# Id-translation table; the following cell reads its `orig_id` and `transform_id` columns.
transform_csv = pd.read_csv(RAW_DATA_PATH + '/' + 'genre2id.csv')

In [43]:
# Pull the two id columns out of the table as plain Python lists.
orig_id = transform_csv['orig_id'].to_list()
transform_id = transform_csv['transform_id'].to_list()

In [44]:
# orig_id -> transform_id lookup; for a repeated orig_id the last row wins.
transform_dict = dict(zip(orig_id, transform_id))
# transform_dict

In [45]:
# Reload the saved train/test bottom-label arrays and translate their ids through
# transform_dict (train first, then test).
# Disabled variants kept from the original notebook:
#   medium_label_bottom_transform = transform_label_via_dict(np.load(DATA_PATH+'/' + 'medium_label_bottom.npy'), transform_dict)
#   medium_val_label_bottom_transform = transform_label_via_dict(np.load(DATA_PATH+'/' + 'medium_label_bottom_val.npy'), transform_dict)
medium_train_label_bottom_transform_clean, medium_test_label_bottom_transform_clean = (
    transform_label_via_dict(np.load(DATA_PATH + '/' + bottom_file), transform_dict)
    for bottom_file in ('medium_label_bottom_train_clean.npy', 'medium_label_bottom_test_clean.npy')
)

In [46]:
# Persist the translated train/test bottom labels, then show their row counts.
# Disabled variants kept from the original notebook:
#   np.save(DATA_PATH+'/' + "medium_label_bottom_transform", medium_label_bottom_transform)
#   np.save(DATA_PATH+'/' + "medium_val_label_bottom_transform", medium_val_label_bottom_transform)
for bottom_stem, bottom_array in (
    ("medium_label_bottom_train_transform_clean", medium_train_label_bottom_transform_clean),
    ("medium_label_bottom_test_transform_clean", medium_test_label_bottom_transform_clean),
):
    np.save(DATA_PATH + '/' + bottom_stem, bottom_array)
tuple(map(len, (medium_train_label_bottom_transform_clean, medium_test_label_bottom_transform_clean)))

(13510, 1772)

In [47]:
# Multi-hot ground truth for the bottom labels: one row per entry of medium_scope;
# a cell is 1 when that label occurs among the test rows the scope covers.
# NOTE(review): the column count comes from the *train* labels' maximum while the
# rows are filled from the *test* labels - presumably so both splits share one
# class axis; a test id above the train maximum would raise IndexError. Confirm.
n_bottom_cols = np.max(medium_train_label_bottom_transform_clean) + 1
medium_all_true_label_bottom_transform_clean = np.zeros((medium_scope.shape[0], n_bottom_cols))
for pid in range(medium_scope.shape[0]):
    first_row, last_row = medium_scope[pid][0], medium_scope[pid][1]
    # Scope ends are inclusive, hence the +1 on the slice stop.
    bottom_in_bag = medium_test_label_bottom_transform_clean[first_row:last_row + 1]
    medium_all_true_label_bottom_transform_clean[pid, bottom_in_bag] = 1
# Flatten row-major into a single vector before saving.
medium_all_true_label_bottom_transform_clean = medium_all_true_label_bottom_transform_clean.reshape(-1)
np.save(DATA_PATH + '/'  + 'medium_flat_true_label_bottom_transform_clean.npy', medium_all_true_label_bottom_transform_clean)

In [48]:
# Display the flattened multi-hot vector (numpy abbreviates long arrays in the repr).
medium_all_true_label_bottom_transform_clean

array([1., 0., 0., ..., 0., 0., 0.])

In [49]:
# Recover the 2-D layout without hard-coding the class count (was a literal 77):
# the flat vector was built from a matrix with one row per entry of medium_scope.
medium_all_true_label_bottom_transform_clean.reshape(medium_scope.shape[0], -1).shape

(565, 77)

In [50]:
# Show the vector in its 2-D layout without hard-coding the class count (was a
# literal 77): the flat vector was built with one row per entry of medium_scope.
medium_all_true_label_bottom_transform_clean.reshape(medium_scope.shape[0], -1)

array([[1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [51]:
import torch.nn as nn
import torch

In [52]:
# Start the list of embedding tables with a 16 x 230 nn.Embedding whose weight
# tensor is filled by Xavier-uniform initialisation.
first_table_weight = nn.init.xavier_uniform_(torch.Tensor(16, 230))
genre_matrixs = [nn.Embedding(16, 230, _weight=first_table_weight)]