In [1]:
from __future__ import absolute_import, division, print_function

import os
import sys
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from utils import *
from data_utils import AmazonDataset, AmazonDataLoader
from transe_model import KnowledgeEmbedding


logger = None




In [2]:
def train(args):
    dataset = load_dataset(args.dataset)
    print(dataset)
    dataloader = AmazonDataLoader(dataset, args.batch_size)
    words_to_train = args.epochs * dataset.review.word_count + 1

    model = KnowledgeEmbedding(dataset, args).to(args.device)
    logger.info('Parameters:' + str([i[0] for i in model.named_parameters()]))
    
    optimizer = optim.SGD(model.parameters(), lr=args.lr)#随机梯度下降
    steps = 0
    smooth_loss = 0.0

    for epoch in range(1, args.epochs + 1):
        dataloader.reset()
        while dataloader.has_next():
            # Set learning rate.
            lr = args.lr * max(1e-4, 1.0 - dataloader.finished_word_num / float(words_to_train))
            for pg in optimizer.param_groups:
                pg['lr'] = lr

            # Get training batch.
            batch_idxs = dataloader.get_batch()
            batch_idxs = torch.LongTensor(batch_idxs)
            print(batch_idxs)
            
            if len(batch_idxs)<64:
                continue
            #batch_idxs = torch.from_numpy(batch_idxs).to(args.device)

            # Train model.
            optimizer.zero_grad()
            
            train_loss = model(batch_idxs)
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            smooth_loss += train_loss.item() / args.steps_per_checkpoint

            steps += 1
            
            if steps % args.steps_per_checkpoint == 0:
                logger.info('Epoch: {:02d} | '.format(epoch) +
                            'Words: {:d}/{:d} | '.format(dataloader.finished_word_num, words_to_train) +
                            'Lr: {:.5f} | '.format(lr) +
                            'Smooth loss: {:.5f}'.format(smooth_loss))
                smooth_loss = 0.0

        torch.save(model.state_dict(), '{}/transe_model_sd_epoch_{}.ckpt'.format(args.log_dir, epoch))




In [3]:
def extract_embeddings(args):
    """Note that last entity embedding is of size [vocab_size+1, d]."""
    model_file = '{}/transe_model_sd_epoch_{}.ckpt'.format(args.log_dir, args.epochs)
    print('Load embeddings', model_file)
    state_dict = torch.load(model_file, map_location=lambda storage, loc: storage)
    embeds = {
        USER: state_dict['user.weight'].cpu().data.numpy()[:-1],  # Must remove last dummy 'user' with 0 embed.
        PRODUCT: state_dict['product.weight'].cpu().data.numpy()[:-1],
        WORD: state_dict['word.weight'].cpu().data.numpy()[:-1],
        BRAND: state_dict['brand.weight'].cpu().data.numpy()[:-1],
        CATEGORY: state_dict['category.weight'].cpu().data.numpy()[:-1],
        RPRODUCT: state_dict['related_product.weight'].cpu().data.numpy()[:-1],

        PURCHASE: (
            state_dict['purchase'].cpu().data.numpy()[0],
            state_dict['purchase_bias.weight'].cpu().data.numpy()
        ),
        MENTION: (
            state_dict['mentions'].cpu().data.numpy()[0],
            state_dict['mentions_bias.weight'].cpu().data.numpy()
        ),
        DESCRIBED_AS: (
            state_dict['describe_as'].cpu().data.numpy()[0],
            state_dict['describe_as_bias.weight'].cpu().data.numpy()
        ),
        PRODUCED_BY: (
            state_dict['produced_by'].cpu().data.numpy()[0],
            state_dict['produced_by_bias.weight'].cpu().data.numpy()
        ),
        BELONG_TO: (
            state_dict['belongs_to'].cpu().data.numpy()[0],
            state_dict['belongs_to_bias.weight'].cpu().data.numpy()
        ),
        ALSO_BOUGHT: (
            state_dict['also_bought'].cpu().data.numpy()[0],
            state_dict['also_bought_bias.weight'].cpu().data.numpy()
        ),
        ALSO_VIEWED: (
            state_dict['also_viewed'].cpu().data.numpy()[0],
            state_dict['also_viewed_bias.weight'].cpu().data.numpy()
        ),
        BOUGHT_TOGETHER: (
            state_dict['bought_together'].cpu().data.numpy()[0],
            state_dict['bought_together_bias.weight'].cpu().data.numpy()
        ),
    }
    print(embeds[USER])
    save_embed(args.dataset, embeds)
    return embeds




In [4]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {beauty, cd, cell, clothing}.')
parser.add_argument('--name', type=str, default='train_transe_model', help='model name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='1', help='gpu device.')
parser.add_argument('--epochs', type=int, default=60, help='number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=64, help='batch size.')
parser.add_argument('--lr', type=float, default=0.5, help='learning rate.')
parser.add_argument('--weight_decay', type=float, default=0, help='weight decay for adam.')
parser.add_argument('--l2_lambda', type=float, default=0, help='l2 lambda')
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Clipping gradient.')
parser.add_argument('--embed_size', type=int, default=100, help='knowledge embedding size.')
parser.add_argument('--num_neg_samples', type=int, default=5, help='number of negative samples.')
parser.add_argument('--steps_per_checkpoint', type=int, default=200, help='Number of steps for checkpoint.')

_StoreAction(option_strings=['--steps_per_checkpoint'], dest='steps_per_checkpoint', nargs=None, const=None, default=200, type=<class 'int'>, choices=None, help='Number of steps for checkpoint.', metavar=None)

# CLOTH

In [6]:
args = parser.parse_args(['--dataset',CLOTH])
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log_cloth.txt')
logger.info(args)

set_random_seed(args.seed)

train(args)#

[INFO]  Namespace(batch_size=64, dataset='cloth', device='cpu', embed_size=100, epochs=60, gpu='1', l2_lambda=0, log_dir='./tmp/Amazon_Clothing/train_transe_model', lr=0.5, max_grad_norm=5.0, name='train_transe_model', num_neg_samples=5, seed=123, steps_per_checkpoint=200, weight_decay=0)
dataset_file: ./tmp/Amazon_Clothing/dataset.pkl
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']
[INFO]  Epoch: 04 | Words: 620302/11666341 | Lr: 0.47354 | Smooth loss: 22.93748
[INFO]  Epoch: 07 | Words: 1241025/11666341 | Lr: 0.44696 | Smooth loss: 21.14189
[INFO]  Epoch: 10 | Words: 1867689/11666

# BEAUTY 

In [5]:
args = parser.parse_args(['--dataset',BEAUTY])
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log_beauty.txt')
logger.info(args)

set_random_seed(args.seed)
train(args)

[INFO]  Namespace(batch_size=64, dataset='beauty', device='cpu', embed_size=100, epochs=60, gpu='1', l2_lambda=0, log_dir='./tmp/Amazon_Beauty/train_transe_model', lr=0.5, max_grad_norm=5.0, name='train_transe_model', num_neg_samples=5, seed=123, steps_per_checkpoint=200, weight_decay=0)
dataset_file: ./tmp/Amazon_Beauty/dataset.pkl
<data_utils.AmazonDataset object at 0x0000022190105748>
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']
tensor([[ 15924,    874,      4,     -1,     21,  36303,     -1,     -1],
        [  5992,   5382,      5,   1603,      0,    209, 139121,   5886],
  

RuntimeError: index out of range at c:\n\pytorch_1559129895673\work\aten\src\th\generic/THTensorEvenMoreMath.cpp:191

In [7]:
#args = parser.parse_args(['--dataset',BEAUTY])
train(args)

dataset_file: ./tmp/Amazon_Beauty/dataset.pkl
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']


RuntimeError: index out of range at c:\n\pytorch_1559129895673\work\aten\src\th\generic/THTensorEvenMoreMath.cpp:191

# CD

In [5]:
args = parser.parse_args(['--dataset',CD])
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log_beauty.txt')
logger.info(args)

set_random_seed(args.seed)

train(args)#CD

[INFO]  Namespace(batch_size=64, dataset='cd', device='cpu', embed_size=100, epochs=60, gpu='1', l2_lambda=0, log_dir='./tmp/Amazon_CDs/train_transe_model', lr=0.5, max_grad_norm=5.0, name='train_transe_model', num_neg_samples=5, seed=123, steps_per_checkpoint=200, weight_decay=0)
dataset_file: ./tmp/Amazon_CDs/dataset.pkl
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']
[INFO]  Epoch: 04 | Words: 601408/11666341 | Lr: 0.47435 | Smooth loss: 22.00827
[INFO]  Epoch: 07 | Words: 1207115/11666341 | Lr: 0.44840 | Smooth loss: 20.70964
[INFO]  Epoch: 10 | Words: 1822605/11666341 | Lr: 0.4

# CELL

In [6]:
train(args) #CELL

dataset_file: ./tmp/Amazon_Cellphones/dataset.pkl
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']
[INFO]  Epoch: 04 | Words: 610917/11666341 | Lr: 0.47395 | Smooth loss: 22.02442
[INFO]  Epoch: 07 | Words: 1228920/11666341 | Lr: 0.44747 | Smooth loss: 20.38819
[INFO]  Epoch: 10 | Words: 1845219/11666341 | Lr: 0.42105 | Smooth loss: 19.53596
[INFO]  Epoch: 13 | Words: 2454394/11666341 | Lr: 0.39492 | Smooth loss: 19.09178
[INFO]  Epoch: 16 | Words: 3057775/11666341 | Lr: 0.36906 | Smooth loss: 18.43481
[INFO]  Epoch: 19 | Words: 3670539/11666341 | Lr: 0.34284 | Smooth loss: 18.23842


# CLOTH

In [5]:
args = parser.parse_args(['--dataset',CLOTH])
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log_cloth.txt')
logger.info(args)

set_random_seed(args.seed)

train(args)#

[INFO]  Namespace(batch_size=64, dataset='cloth', device='cpu', embed_size=100, epochs=60, gpu='1', l2_lambda=0, log_dir='./tmp/Amazon_Clothing/train_transe_model', lr=0.5, max_grad_norm=5.0, name='train_transe_model', num_neg_samples=5, seed=123, steps_per_checkpoint=200, weight_decay=0)
dataset_file: ./tmp/Amazon_Clothing/dataset.pkl
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']


ValueError: not enough values to unpack (expected 6, got 3)

In [5]:
args = parser.parse_args(['--dataset',CLOTH])
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log_cloth.txt')
logger.info(args)

set_random_seed(args.seed)

train(args)#

[INFO]  Namespace(batch_size=64, dataset='cloth', device='cpu', embed_size=100, epochs=60, gpu='1', l2_lambda=0, log_dir='./tmp/Amazon_Clothing/train_transe_model', lr=0.5, max_grad_norm=5.0, name='train_transe_model', num_neg_samples=5, seed=123, steps_per_checkpoint=200, weight_decay=0)
dataset_file: ./tmp/Amazon_Clothing/dataset.pkl
[INFO]  Parameters:['purchase', 'mentions', 'describe_as', 'produced_by', 'belongs_to', 'also_bought', 'also_viewed', 'bought_together', 'user.weight', 'product.weight', 'word.weight', 'related_product.weight', 'brand.weight', 'category.weight', 'purchase_bias.weight', 'mentions_bias.weight', 'describe_as_bias.weight', 'produced_by_bias.weight', 'belongs_to_bias.weight', 'also_bought_bias.weight', 'also_viewed_bias.weight', 'bought_together_bias.weight']
[INFO]  Epoch: 04 | Words: 611675/11666341 | Lr: 0.47389 | Smooth loss: 22.93162
[INFO]  Epoch: 07 | Words: 1221355/11666341 | Lr: 0.44780 | Smooth loss: 21.14529
[INFO]  Epoch: 10 | Words: 1839284/11666

epoch: 40
tensor([[ 26125,   1222,      5,     -1,      0, 265394,  80568,  22729],
        [  2878,   6354,      5,    511,     26,  64444,  55725, 167930],
        [  6897,   6612,      1,     -1,     25,  19190,  22912,     -1],
        [  8307,   4876,      1,     -1,    203,  24516, 164804,  91121],
        [ 16479,   3819,      5,     -1,     63,  82904,  70417,     -1],
        [ 14986,   6184,      4,    259,    443, 106767,  43299,     -1],
        [ 15874,   5401,      4,     -1,    521,  34025,  31813,     -1],
        [ 19187,   2435,      4,     -1,     44,  38153,  34589,     -1],
        [  6662,   3211,      5,     -1,     48,  39902,  39697,  14607],
        [ 11560,   1092,      5,     -1,    137,  48450, 100243,  48448],
        [  9481,  10128,      3,     -1,     25, 187505, 200542, 106163],
        [  7668,   3540,      5,     37,      0,  45923,  54867,     -1],
        [ 25465,   5411,      2,    415,     28,  20812, 124015,   4698],
        [  2058,   6901,    

epoch: 40
tensor([[ 20790,   2668,      4,     -1,     97, 257949, 118486, 118486],
        [  4515,   7680,      5,    993,      0, 208930, 146464, 146464],
        [ 23035,   9134,      2,     -1,      0,  34807,     -1,     -1],
        [ 10932,   1626,      3,     -1,    125, 199538,  70499,  70491],
        [ 22700,   6568,      5,     -1,     48, 303781, 261347,     -1],
        [ 25278,   2236,      5,     -1,    788, 202217, 206445,     -1],
        [ 15010,   5169,      5,     -1,      5,  39637,  18713,  18713],
        [ 26467,   1750,      1,    621,    109,  73433,  72782,     -1],
        [ 24715,   5262,      2,    614,     66, 204299, 108750,     -1],
        [ 19875,   6126,      2,     -1,      0, 216220, 194872, 194872],
        [ 19750,   3413,      1,     -1,    267,  38726,  47272,     -1],
        [  1725,   2172,      3,     -1,    507, 145977,     -1,     -1],
        [ 26867,   6854,      4,     -1,     45,  53582, 204769,     -1],
        [   383,   2197,    

epoch: 40
tensor([[  1724,   2878,      5,     -1,     42,  13772,  20426,     -1],
        [  9460,   8564,      5,     -1,    149, 114218, 114219,  41825],
        [ 19545,   6530,      4,     17,      0,  48165,   1398,     -1],
        [  8137,   5441,      4,     -1,     45, 108000, 107988,     -1],
        [  4899,   5495,      5,     -1,      5,  31897,  31892,     -1],
        [   346,   9494,      5,     -1,    441,  31884,  22854,     -1],
        [  2149,   2634,      5,    746,      2,  39353,   2292,   2322],
        [  7918,   9782,      5,     -1,     69,   6435,  16520,   6435],
        [ 25518,   9834,      1,     -1,     93,   1597,   3819,   3825],
        [ 15176,   3924,      3,     -1,     27,  31513,    923,   9468],
        [ 19973,   3762,      5,     -1,    106,  30140,   6057,  13241],
        [ 26347,   9343,      4,     -1,     97,     -1,     -1,     -1],
        [ 21890,   4421,      4,     -1,    149, 231245, 219055,     -1],
        [ 26068,   4228,    

epoch: 40
tensor([[  8998,    896,      5,     32,    124,   4423,  34455,   9961],
        [ 10685,   4053,      5,     -1,     19,  69241,  69241,  46795],
        [  9101,   8991,      5,     74,     38,  46502,  17022,   8821],
        [ 15011,   3819,      5,     -1,      0, 129820,   7277,     -1],
        [ 21987,   7649,      1,     -1,      2,  21798,  21767,  20984],
        [ 13014,   5983,      5,     -1,     15,  26364,     -1,     -1],
        [ 22538,   2553,      1,     -1,      1,   3701,  53721,     -1],
        [  7424,   5520,      5,     -1,     69, 120605, 153571,  31554],
        [ 18173,   2135,      5,     -1,      1, 102671,  60302,     -1],
        [ 13130,   3068,      5,     -1,     48,  16181,  16277,  16277],
        [ 24000,   4382,      1,     -1,      5,   6855,     -1,   6906],
        [ 14092,   3332,      3,     -1,     51, 145781,  15693,  15707],
        [  5723,   5434,      4,     -1,    346,  16493,     -1,     -1],
        [ 22045,   2864,    

epoch: 40
tensor([[  5869,   9117,      4,     -1,    407, 252630, 138902,  31584],
        [ 25677,   6920,      5,     -1,      2, 259876,  72885, 167753],
        [ 13897,   6129,      4,     -1,      5,  49720,     -1,     -1],
        [ 10214,   9465,      5,     -1,      2,  21782,  21768,  21755],
        [ 20883,    679,      3,     -1,      0,  28357,   3398,   3398],
        [ 24521,   3991,      4,     -1,     28,   9802,   5679,     -1],
        [ 16677,   4134,      1,     -1,     45,  79013,  26183,     -1],
        [ 10921,   3939,      1,     -1,    130,  78888,  11549,     -1],
        [  7008,   6537,      5,     -1,      0,   2204,   2204,   2204],
        [  7569,   3758,      4,      6,    325,  37252,  37186,     -1],
        [  4151,   7381,      3,     -1,     28,  20786,  51321,     -1],
        [  9369,   6668,      1,     -1,      0, 200181, 149801,  40121],
        [ 24667,   3868,      4,     -1,     25, 135549, 105374,     -1],
        [  7730,   6538,    

epoch: 40
tensor([[ 19024,   7681,      5,     -1,    711, 134117, 200170,     -1],
        [  1746,   2781,      5,     -1,    157, 121748, 152338, 103699],
        [  1445,   7501,      5,     -1,      5,  22092,  22132,     -1],
        [ 21780,    244,      2,     -1,    350,  33713, 145861, 103811],
        [ 12448,   2535,      1,     -1,     25, 171442, 142366,  75390],
        [ 15540,    563,      5,     -1,    372,  59401,     -1,     -1],
        [ 10394,   3383,      5,    531,     62, 123276,     -1,     -1],
        [ 26084,   4099,      5,     -1,     45,  27925,  34343,     -1],
        [  3443,   6495,      5,     -1,     50,  10318,  52043,  10765],
        [  4774,   9225,      5,     -1,    583, 273414, 186223, 233965],
        [ 11939,   5828,      4,     -1,      2,   5158,  45190,  22360],
        [ 20258,   9954,      5,     -1,    267,   2706,  64260,  64214],
        [  2970,   2506,      3,     -1,    125,   5354,   5406,     -1],
        [ 10066,  10258,    

epoch: 40
tensor([[  1296,   3340,      3,    256,     64,  80024, 134163,   1771],
        [ 21651,   7115,      5,     -1,    211,  31570, 112728,  17070],
        [ 16648,   9814,      4,     -1,    699,  97394,  97422,     -1],
        [ 21844,   2847,      4,     -1,     97,  96695,  71595,     -1],
        [  1528,   8910,      2,     -1,   1037, 245005,  37550,     -1],
        [  6161,  10135,      5,     -1,     69, 171871, 202217,     -1],
        [ 24328,   8025,      2,     -1,      0, 102962,  29862,  29877],
        [ 15328,   3377,      3,     -1,      0, 135585,  69251,     -1],
        [  9874,   7359,      5,     -1,     25,  17065,  40118, 157050],
        [ 12561,   2562,      5,     -1,    280,  28782,  72229,     -1],
        [  4250,   5899,      5,     37,    130,  14183,  15991,     -1],
        [ 15026,   6716,      1,     -1,      0, 135165, 187856,     -1],
        [ 12789,   6636,      4,     17,     26,  46443,    824,     -1],
        [ 12845,   7171,    

epoch: 40
tensor([[ 27483,   7628,      5,     -1,     25,  47066,  39129,     -1],
        [ 26776,   1143,      3,     -1,     25,  85164,  31990,     -1],
        [ 12105,   4444,      2,     -1,     25, 331900, 230457, 158184],
        [  1321,   9787,      1,     -1,      0, 115306,  69381,     -1],
        [  1156,     33,      3,     -1,     45,  34880,  34876,  34876],
        [  7835,   4293,      5,     -1,      0,  33102,   7803,   7803],
        [  3143,   7314,      3,     -1,      2,  90451, 291053,  16857],
        [    82,   6147,      4,     -1,      5,  95623,     -1,     -1],
        [ 19824,   6425,      4,     -1,    157,   8643,     -1,     -1],
        [ 23937,   4358,      2,     -1,      0,  47835,  18253,  18268],
        [ 16457,   2274,      3,     -1,     25,  22145,  35729,  35729],
        [ 24234,   2327,      1,     -1,     25,   3141,   3141,     -1],
        [  2540,   9180,      5,     -1,      0,  43543,  32059,  43538],
        [  1634,   7529,    

epoch: 40
tensor([[ 22240,    825,      5,     -1,      0,   1129,  13492,     -1],
        [  8551,   9302,      3,     -1,     13,  38193, 121441,  31584],
        [  1145,    585,      3,     -1,      0,  52763,     -1,  52805],
        [  5840,     82,      5,     -1,      0,  20650, 172430,  51166],
        [ 21543,   1005,      4,     -1,    106, 103400,  59032,     -1],
        [  2086,   5323,      5,     -1,     28, 146107,  91304,     -1],
        [  3092,   9877,      5,     -1,     25, 105150,  92036,     -1],
        [ 20687,   8018,      5,     -1,     48, 160218,  65997, 234523],
        [  7836,   6604,      2,     -1,      0,   4080,  58115,   4080],
        [  4293,   9986,      4,     -1,     48,  69188,  99057,     -1],
        [ 19473,   8193,      5,    360,      1,     94, 141104,     -1],
        [   871,    257,      5,     -1,    117,  91339,  91334,     -1],
        [  4360,    636,      5,     -1,      0,  88892,  30900,     -1],
        [  7183,   8116,    

epoch: 40
tensor([[ 14804,   8742,      2,     -1,      2, 251950,     -1,     -1],
        [ 24630,   4047,      3,     -1,      0, 179864, 118082,     -1],
        [ 26505,   3796,      4,     -1,     73, 245947,   4259,     -1],
        [ 10001,   1478,      5,     17,      0, 317125,   7077,     -1],
        [  7533,   4627,      5,     -1,     70,  10514,  10342,   3031],
        [ 24851,   8564,      4,     -1,      0,   6031,  59032,   5021],
        [ 17401,   4772,      3,     -1,      2,  65050,  39457,  27415],
        [ 21888,  10080,      5,     -1,     76,  31348,  10712,     -1],
        [  9546,    552,      1,     -1,      5,  44750,  24895,     -1],
        [ 26511,   1549,      5,     -1,     53, 216467,     -1,     -1],
        [ 21575,   4896,      5,     19,    184,   4151,  34888,   1759],
        [ 18543,    805,      3,     -1,     45,  20354, 135953,     -1],
        [ 22459,    634,      5,     64,     25,  30203,  31746,     -1],
        [  9650,   2518,    

epoch: 40
tensor([[ 23950,   5137,      1,     -1,      0,  12389, 126199, 126199],
        [ 13983,   4817,      4,     -1,      2,  28019,  42448,  42448],
        [ 14657,   7829,      5,     -1,      0, 133244,  51435,     -1],
        [  8501,  10370,      5,     -1,      2,  91336,  91334,     -1],
        [  6427,   9332,      2,     -1,      5, 271945, 271940,     -1],
        [ 12537,   8742,      4,     -1,     53,  11580,     -1,     -1],
        [ 26889,   9739,      5,     -1,    697,  16117, 122082,     -1],
        [ 11307,   2401,      3,     -1,     79,  51142,  49231,  17134],
        [ 15618,   4219,      5,     -1,     25,  17552, 172427,  27209],
        [ 23660,   9692,      2,     -1,    137,  80644,  39947,     -1],
        [  3312,   8515,      4,     -1,     50,   2575,  10359,     -1],
        [ 15298,   7949,      1,     -1,    302, 195770,  33964,     -1],
        [ 17130,   7412,      4,    973,      2, 134105,  33986,   6982],
        [  9332,   1844,    

epoch: 40
tensor([[ 15226,   5289,      5,     -1,     49,  60230,  46995,  22973],
        [ 12559,   4903,      5,    724,     56,  84453, 232392,     -1],
        [ 17122,   8304,      3,     -1,     44, 174787, 161489,  17355],
        [  7159,   7913,      1,     -1,      0, 139979,  29493,  13761],
        [  2441,   2455,      3,     -1,      5,   7165,  15540,     -1],
        [  6190,   6490,      3,     -1,     27,  25148,   7055,     -1],
        [  9728,   5885,      2,      6,    325,  32702,  88292,     -1],
        [ 20303,   5569,      4,     -1,     48,  26505,  14586,  14607],
        [ 15250,   6657,      4,     -1,     25,   7491,  38105,   2891],
        [  6382,   7810,      2,     -1,     25,  34604, 103417,  11474],
        [ 13366,   2288,      4,     -1,      5,   5387,   5442,     -1],
        [   835,   6436,      5,     -1,    137,  49420,  22133,  17087],
        [   482,   6145,      3,    543,      9, 122459,  16016,     -1],
        [   455,   2321,    

epoch: 40
tensor([[  9078,   4267,      3,     -1,      5,  98249,  14608,  14607],
        [ 19695,   4632,      3,     -1,     25, 304543, 162845,  40113],
        [   769,   3646,      5,     -1,    375, 199750,  81905,  24443],
        [ 12625,   2148,      4,    834,     62,  98032, 194330,     -1],
        [ 10709,   2642,      5,    469,      0,  66521,  68078,   8715],
        [  9537,   8415,      3,     -1,     69, 105815,  52790,     -1],
        [  7948,   6935,      3,     -1,      0, 151805, 151767, 151753],
        [   909,   3624,      4,   1024,      0, 303005, 152718,     -1],
        [ 15399,   6056,      3,     -1,     62, 335291,     -1,     -1],
        [ 24208,   6808,      4,     -1,    375, 246709, 157827,  10550],
        [ 20496,   2046,      5,     -1,     25, 264811,  17807, 110103],
        [ 14535,   3892,      4,     -1,     62,  48646, 127021,     -1],
        [ 10393,   6801,      4,     -1,    118,  27725,  37856,     -1],
        [ 11932,   8772,    

epoch: 40
tensor([[ 10430,   2181,      5,     -1,     39, 121511,  67659,  51321],
        [ 24848,   2912,      2,     -1,     19,  31861,  22288, 111349],
        [ 10219,   5199,      2,     -1,     51,   5384,  16907,   9967],
        [  3715,   4703,      5,     -1,     39,  20915,  36878,     -1],
        [ 27629,   2719,      1,    359,     39,  34721,  34455,     -1],
        [ 25130,    868,      5,     -1,      0,  24284,  46155,  46155],
        [ 23028,   1157,      3,     -1,    226,  50890,  47031,     -1],
        [  9895,   3587,      5,     -1,      1,  19826,  19822,  19822],
        [ 21279,   4007,      3,    149,      0, 226895,  41389,     -1],
        [ 11314,   2370,      1,     -1,    270,  38094, 184040, 142074],
        [ 27094,   4896,      4,     19,      0,  73668,  76197,   1759],
        [ 26385,   3208,      3,     -1,    249,  36951,   2001,     -1],
        [  3464,   5785,      2,     -1,      0,   6532,     -1,     -1],
        [ 12472,  10339,    

epoch: 40
tensor([[ 25592,   1143,      2,     -1,    119,  90109,   6916,     -1],
        [ 26506,    862,      2,     52,     26,  25066, 242501,     -1],
        [ 26705,   8742,      5,     -1,    118, 251947,     -1,     -1],
        [  5874,   1896,      3,    194,      5,  83363,  83313,  83291],
        [   269,   2352,      5,     -1,     25,   2624,  10765,   1536],
        [ 23402,    896,      4,     32,     28,   4410,  20802,   9961],
        [ 16513,   2236,      4,     -1,    119, 113946, 223774,     -1],
        [  4618,   2634,      5,    746,    119, 139482,   2322,  39554],
        [  3645,   8225,      5,     -1,      2, 136663,  85164,   2877],
        [ 21660,   1605,      2,     -1,    417,  39636,  28236,  28232],
        [  4199,   7242,      5,    113,     28,  54873,  54894,  54867],
        [ 11671,   4953,      5,     -1,     71,  99208,   3876,     -1],
        [  3812,   7341,      3,     -1,    136, 209668,  43522,     -1],
        [  8479,   5692,    

epoch: 40
tensor([[ 18075,   3032,      2,    273,    131, 107250, 107295,     -1],
        [  4614,  10085,      5,     -1,    149,  57729,  57746,     -1],
        [ 13977,   3239,      4,   1024,    178,  91244,  50635,     -1],
        [ 19773,    536,      4,     -1,      0, 207405,     -1,     -1],
        [  5852,   3999,      4,     -1,    978, 105261, 190320, 104291],
        [  8276,   5207,      4,     -1,     48, 152799,  79559,     -1],
        [ 17287,  10284,      1,     52,     31,  24905,  25025,     -1],
        [ 19025,   9853,      4,     -1,     45, 234408,  12202,  12169],
        [  1202,   7631,      3,    259,   1038,  12330,  82307,     -1],
        [ 25794,     33,      1,     -1,    149,  32167,  34876,  34876],
        [ 11651,   6219,      5,     -1,     25,  35714, 137673,  35033],
        [ 16540,   3674,      5,     -1,    129,     -1,     -1,     -1],
        [  9874,   6978,      5,     -1,    654,  31792,  70359,     -1],
        [  8104,   6967,    

epoch: 40
tensor([[ 26678,  10263,      5,     -1,     25, 279485,  47090,  17552],
        [ 18150,   9678,      1,     -1,    357, 181220, 142391, 137433],
        [    10,   6770,      5,    950,     25, 300679,  46491,     -1],
        [  8027,   8015,      1,     -1,     48, 307365, 296636,     -1],
        [  8838,   9319,      1,     -1,    123,  26780, 155395, 124445],
        [ 15582,   6673,      5,     -1,     26,  25110,  25045,     -1],
        [ 11419,   4021,      4,     -1,    297, 176589,  44289,     -1],
        [   842,   3187,      5,     -1,    112,  47815,  47791,     -1],
        [ 21712,   6628,      1,    200,      9,  31150,  31159,   7551],
        [  7137,   8238,      2,    182,      7, 126336,   8725,     -1],
        [ 17182,   7374,      5,     -1,      0, 222400,  93486,  22033],
        [ 21103,   8281,      5,     -1,    223,  62176,     -1,     -1],
        [ 19949,   8704,      3,    181,      9,  34555,   7662,  43281],
        [ 10528,  10050,    

epoch: 40
tensor([[  2643,   8056,      1,     -1,    215,   5465,     -1,     -1],
        [ 11674,   4843,      5,    419,     40,  26757, 144958,  23785],
        [ 18659,   5232,      3,    165,     62,  73994,  73979,  73932],
        [ 19696,   1613,      3,     -1,    226,  31393,  31441,     -1],
        [  5059,   9269,      3,     -1,      0,  14811,  82912,     -1],
        [ 23974,   2113,      5,     -1,    119, 158177, 148768,  60537],
        [  9581,   7891,      3,     -1,      0, 262259,  21636,     -1],
        [ 26712,   7730,      1,     -1,     73,   2345,   3941,     -1],
        [ 16200,  10255,      5,     -1,      0,  89932,  43599,  75320],
        [ 19165,   4505,      2,     -1,     62, 223913,  76392,     -1],
        [  8877,   4703,      5,     -1,     28, 136844,  32263,     -1],
        [ 24817,   3363,      4,     -1,     45,  71080,     -1,     -1],
        [  6699,    921,      5,     -1,    703, 275035, 153731,  85262],
        [ 12680,   7166,    

epoch: 40
tensor([[  4266,   8550,      3,     -1,     19,   4217,     -1,     -1],
        [ 15739,   5837,      4,     -1,      0,  67315, 164907,  67313],
        [  5783,   4125,      2,     -1,    119, 192913, 196635,  24638],
        [ 25827,   5679,      5,     -1,      0, 122674,  43940,  43940],
        [  3576,   9691,      3,     -1,    149,  38105,     -1,     -1],
        [ 25834,   3080,      5,     -1,      5, 124633,  31787,     -1],
        [ 25682,   1438,      3,    152,    229,  62617,  58023,     -1],
        [ 10861,   6621,      5,     -1,     36, 167283,     -1,     -1],
        [  9932,   2015,      5,     -1,     25,  32792,  32792,  24610],
        [  7100,   1233,      5,     -1,    638, 159178,  23176,     -1],
        [  2971,   2535,      5,     -1,    646, 141018, 137449,  75390],
        [ 15810,   8664,      1,     -1,     31,  30473,  76480,     -1],
        [  5894,    933,      1,    309,     17,    479,    464,    467],
        [  8848,   2699,    

epoch: 40
tensor([[ 10908,   4190,      3,     -1,    119,  33640, 106163, 106163],
        [ 23156,   6894,      5,      9,      0,  96268,  46457,     -1],
        [ 19411,   6538,      4,     -1,    540, 243803, 203646,  63580],
        [ 19608,   1488,      5,    344,    587,  80066,  80931,   8701],
        [ 12259,   2902,      4,    980,      0, 308637,  85122,     -1],
        [ 22365,   8056,      5,     -1,      2,   5357,     -1,     -1],
        [ 23006,   5387,      5,      9,     26, 114874,  31505,     -1],
        [ 27330,   9755,      5,     52,     28,  25209,  25182,     -1],
        [  5143,   8805,      3,     -1,    350, 153217, 102947, 102947],
        [  3056,   9465,      5,     -1,      2, 160469,  21754,  21755],
        [ 23293,  10250,      5,     -1,     25, 135580,     -1,     -1],
        [  3443,  10181,      5,     -1,      5,  84896,     -1,     -1],
        [ 22650,   2848,      3,     -1,     82,   2403,   2292,   2284],
        [ 25886,   1037,    

epoch: 40
tensor([[ 15497,   5009,      5,     -1,    135,  35744,  22094,  83738],
        [ 19655,   4876,      3,     -1,     59, 107722, 164804,  91121],
        [  8530,   2537,      5,    552,     39,  96471,  32700,     -1],
        [ 26581,   1998,      1,     -1,    117,  66981,  16791, 105453],
        [  1107,   2189,      4,     -1,     45,  52756,  99815,  43040],
        [   636,   7827,      4,     -1,     45,  52336,  68204,     -1],
        [  3728,  10252,      5,     -1,    117, 240809, 240777, 240777],
        [  9416,   9782,      3,     -1,    214,  31563,  37814,   6435],
        [ 17131,   1849,      5,     -1,     69, 279277, 136295,  51390],
        [ 16747,   1922,      5,    872,      0,  32792,  80042,  66666],
        [  1452,   8039,      3,     -1,    238,  23622,  94739,  43599],
        [ 13863,   7263,      5,     -1,     56, 145605,  21810,  40420],
        [ 25508,   4766,      5,    873,    671,  45402,    639,    513],
        [ 26446,   1461,    

epoch: 40
tensor([[  7833,   8652,      5,     -1,     53,  12257,  24675,     -1],
        [  6781,   6855,      4,     -1,    375,  66889, 199020,     -1],
        [ 15307,   2117,      5,     -1,    117, 186545,  67997,     -1],
        [ 25771,   6955,      4,     -1,     25,  69151,  17087,     -1],
        [ 10784,    223,      4,     -1,     76,  31551, 137045,  31583],
        [  2645,  10161,      2,     -1,     25, 170422, 157729,  17552],
        [ 10537,    548,      5,     -1,      2, 256738,  21636,     -1],
        [ 20413,   6285,      5,     -1,      2,  18136,     -1,     -1],
        [ 15135,   5190,      5,     -1,      5, 137439,  83261, 137439],
        [ 24260,    902,      5,     -1,     25,  14215, 220095,  26555],
        [ 16638,   1970,      3,    834,      5, 293311,  98307,     -1],
        [ 20503,   2174,      3,     -1,     25,  51410,  31787,     -1],
        [ 23163,  10095,      4,     -1,      2,  86194,  67008,  36721],
        [ 22255,   9055,    

epoch: 40
tensor([[  8540,   8249,      5,     -1,    166, 247871,     -1,     -1],
        [  3056,   3051,      5,     -1,    152, 102802,  46649,  23784],
        [ 19856,    310,      5,    625,    507,  21525,     -1,     -1],
        [   637,   6265,      2,     55,    211,  37165,  37186,  58882],
        [ 26515,   1408,      4,     38,     39,  99843,  62058,  31610],
        [ 21840,   4099,      2,     -1,     45,   2739,  39129,     -1],
        [  3588,   1560,      5,     -1,    123,   1565,     -1,     -1],
        [  7200,   3631,      5,     -1,     25, 140670,  26183,     -1],
        [ 27002,   7830,      2,     -1,      0,  90258,  21505,     -1],
        [  4251,   6154,      5,     -1,      2, 207377, 139480,     -1],
        [ 16581,   1813,      4,     -1,    551,  24629,  46774,  26600],
        [  3493,   5188,      3,     -1,     28,  19213,   4698,     -1],
        [ 19169,   5200,      4,     -1,      0,  18253,  59119,     -1],
        [ 23070,   1414,    

epoch: 40
tensor([[  1004,   7353,      5,     -1,     27,  64927,  48184,  23785],
        [  4599,   4868,      4,     -1,     81,  49901,  27209,     -1],
        [  8510,   2902,      2,    980,     39, 145854, 141559,     -1],
        [  8950,   1473,      4,     -1,     28,  32693,  75025,  45758],
        [ 16085,   7245,      1,     -1,     45,  59025,  17782,     -1],
        [ 24071,   3663,      4,     -1,     25,   1738,   7773,     -1],
        [ 19082,   5111,      3,     -1,    100,  65424, 216417, 164914],
        [  8221,   6511,      5,     -1,     25, 308718,     -1,     -1],
        [  1440,   4115,      5,    262,      5,  61140,  61135,  94055],
        [ 22683,   4546,      3,     -1,    119, 123172,  32571, 158180],
        [  3363,   6225,      1,     -1,      0,  98100,     -1,     -1],
        [ 12448,   6657,      1,     -1,    612,  17908,   2891,   2891],
        [  2902,  10030,      5,     -1,     47,   4592,  49282,     -1],
        [  4469,   1660,    

epoch: 40
tensor([[  2527,   7391,      4,     -1,    150,  55187,   6037,   4632],
        [ 13464,   6555,      3,      9,     28,  12863,   1418,     -1],
        [  8743,   8515,      2,     -1,      2,  15662,   2973,     -1],
        [ 25825,   6263,      1,      9,     26,  23784,  23771,     -1],
        [  4629,   8039,      3,     -1,     48, 138558, 155928,  43599],
        [ 11668,   6555,      3,      9,     28,   1370,   1418,     -1],
        [ 19734,   1123,      5,     -1,    119, 238807,   2334, 174938],
        [ 21880,   2847,      2,     -1,     25,  10420,  71595,     -1],
        [ 14619,   8924,      1,     -1,      2,   2001,  27043,   3165],
        [ 22564,   3961,      5,     -1,     62, 152840, 103862,     -1],
        [  7068,   3363,      3,     -1,     25, 151329,     -1,     -1],
        [ 21505,    136,      1,     -1,      2, 144022,  30878,     -1],
        [ 20013,   2382,      5,     -1,     45, 316892,  79193,     -1],
        [ 10500,   5744,    

epoch: 40
tensor([[ 12198,    980,      4,     -1,     69,  51166,     -1,  27209],
        [ 20963,  10246,      5,    349,     39,  81943,     -1,     -1],
        [  1450,  10001,      5,     -1,    166,   8402,   8446, 186116],
        [ 22353,   5346,      5,     -1,      2,   4079, 153575, 137045],
        [  6936,   4013,      2,     17,     31,   8779,   8787,     -1],
        [  7646,    682,      1,    872,     64, 195397, 104196,  66666],
        [   753,   6794,      5,     -1,    106, 146921,     -1,     -1],
        [ 19225,   9551,      2,     -1,     48, 186524,  79598,  67997],
        [ 18023,   5198,      5,   1036,      0, 169797, 203215,     -1],
        [ 24204,   8335,      2,     -1,    220,  87885,  11828,     -1],
        [ 22690,  10110,      5,     -1,     25,  59108, 197501,  33986],
        [  5152,   2217,      5,     -1,    884, 229948, 135703,  91406],
        [  3917,   1364,      5,    127,      0,  67964,  67950,  67923],
        [ 13472,   7177,    

epoch: 40
tensor([[ 18939,   3100,      2,     -1,      0, 107288,   4847,   4847],
        [ 24245,   8564,      5,     -1,     19,   7044,  55188,   5021],
        [   153,   3634,      5,     -1,     86,  37729,  61455,     -1],
        [  4208,   1927,      2,     -1,      0,  49095, 233954,     -1],
        [ 16685,   2701,      1,     -1,      0, 198295,  94926,     -1],
        [ 21481,   9013,      1,     -1,     42,  31960,   1720,   1720],
        [ 27332,   9542,      1,     -1,     25,   2442,   2504,     -1],
        [ 10054,   5174,      1,     -1,     19, 228926, 181012,     -1],
        [ 15054,   7119,      2,    162,     28, 121634,  22533,  22533],
        [ 12369,   2560,      5,     -1,    101,   6996,   7002,     -1],
        [ 25551,   1974,      5,     -1,      0, 261067, 128549,     -1],
        [  7452,    592,      5,     -1,     28,  42785,  11096,   5231],
        [ 27083,      0,      2,     -1,      0, 172589, 148899,  40113],
        [ 22814,   8418,    

epoch: 40
tensor([[ 17754,   9862,      3,     -1,     28,   5377,  43768,     -1],
        [ 23858,   6837,      3,     -1,    184,  20580,  20389,     -1],
        [ 19477,   5434,      1,     -1,      0, 135173,     -1,     -1],
        [ 10970,   8811,      5,     -1,     45, 176542, 172702,     -1],
        [  7011,   7313,      5,     -1,     62, 174065,  80464,  83139],
        [  1721,   5637,      4,     -1,    281, 134137,  19933,  19957],
        [ 17105,   5168,      4,     -1,     25,  14741,  43475,     -1],
        [ 16648,   8400,      4,     -1,     19,  93515,  93520,     -1],
        [  8574,   2719,      5,    359,      0,  18097,  27730,     -1],
        [  6668,   7057,      5,     -1,      2,  35714,   2350, 226635],
        [ 18753,   2401,      5,     -1,     25, 172418,  17157,  17134],
        [ 12698,   2554,      2,     -1,      2,  97289,     -1,     -1],
        [ 26989,   9023,      5,     -1,    179,   9551,   9539,     -1],
        [ 19073,   4900,    

epoch: 40
tensor([[  1702,   6396,      3,     -1,    375,  18433, 157827,  46795],
        [  4273,   4801,      5,      9,     39,   1362,  46612,     -1],
        [  2767,   9254,      5,     -1,    267,  18535,  30056,     -1],
        [  8424,   3758,      4,      6,      5,  37181, 122270,     -1],
        [  6087,   9762,      4,     -1,      2, 254428,     -1,     -1],
        [  4593,   3159,      4,     -1,      0,   3163,   3171,   2725],
        [ 10392,   6768,      5,     -1,     25, 102343,  74476,  30482],
        [  7191,     35,      1,     -1,    593, 160270, 103813, 103813],
        [  4559,   2318,      1,     17,    155,   8779,   8863,     -1],
        [  1152,   4530,      4,     -1,     48,  53181,     -1,     -1],
        [ 11539,  10207,      4,     -1,      2,   1543, 127109,     -1],
        [ 16523,   2983,      5,     -1,      2,  55737,   8315,     -1],
        [ 27598,   8192,      5,     -1,     91,  10383,   3620,     -1],
        [  7783,   9858,    

epoch: 40
tensor([[  4997,   1081,      5,    218,     25, 141136,  53121,     -1],
        [ 25688,   6416,      4,     -1,     48, 236325,     -1,     -1],
        [  2844,   9226,      5,     -1,    257,  31301,   9936,     -1],
        [ 19123,   5455,      1,     -1,    101,   2584,   7002,     -1],
        [  3669,   4621,      3,     -1,     25,  45844,  89009,  59250],
        [ 19234,   8535,      1,     -1,    164,  11099,  18505,     -1],
        [ 16113,   2583,      4,     -1,      0, 118464, 292936, 118453],
        [ 17765,   2397,      3,     -1,    413,  17930,   3341,  40098],
        [  1346,   4200,      1,     -1,      5,  22729,  44350,   6468],
        [ 17744,   2421,      5,     -1,      2,  27930,   3825,     -1],
        [  9775,   3635,      3,    329,      0,  72566,  91034,    636],
        [  6803,   1497,      5,     -1,      3,   9449,  13336,  13336],
        [ 26315,   8772,      5,     -1,     46,  38561,  43357,     -1],
        [ 14493,   8572,    

UnboundLocalError: local variable 'data' referenced before assignment

In [9]:
emb = extract_embeddings(args)

Load embeddings ./tmp/Amazon_Cellphones/train_transe_model/transe_model_sd_epoch_60.ckpt
[[-0.00211423  0.00353241 -0.00094879 ...  0.00394515  0.00175596
   0.00451101]
 [ 0.003749    0.00302171  0.00130764 ... -0.00449195  0.00298442
   0.0036598 ]
 [ 0.00138646 -0.0020297   0.00209535 ... -0.00067849 -0.0029328
  -0.00129671]
 ...
 [-0.00076333  0.00042423  0.00172309 ... -0.00453087  0.00054564
   0.00016628]
 [-0.00158372  0.00021306  0.00426845 ... -0.00221793  0.00039889
  -0.0011367 ]
 [-0.0010701  -0.00076093 -0.00179834 ... -0.00145278 -0.0025443
  -0.00503418]]


In [12]:
'''
USER: state_dict['user.weight'].cpu().data.numpy()[:-1],  # Must remove last dummy 'user' with 0 embed.
        PRODUCT: state_dict['product.weight'].cpu().data.numpy()[:-1],
        WORD: state_dict['word.weight'].cpu().data.numpy()[:-1],
        BRAND: state_dict['brand.weight'].cpu().data.numpy()[:-1],
        CATEGORY: state_dict['category.weight'].cpu().data.numpy()[:-1],
        RPRODUCT: state_dict['related_product.weight'].cpu().data.numpy()[:-1],
        
'''

print('user:',emb[USER].shape)
print('prod:',emb[PRODUCT].shape)
print('word:',emb[WORD].shape)
print('brand:',emb[BRAND].shape)
print('related_prod:',emb[RPRODUCT].shape)
print('self_loop:',emb[SELF_LOOP].shape)

user: (27879, 100)
prod: (10429, 100)
word: (22493, 100)
brand: (955, 100)
related_prod: (101287, 100)


KeyError: 'self_loop'

In [19]:
emb[PURCHASE][0].shape

(100,)

In [20]:
emb[PURCHASE][1].shape

(10430, 1)

In [17]:
def load_embed(dataset):
    embed_file = '{}/transe_embed.pkl'.format(TMP_DIR[dataset])
    print('Load embedding:', embed_file)
    embed = pickle.load(open(embed_file, 'rb'))
    return embed
embeding = load_embed(CELL)

In [26]:
len(embeding['user'])
embeding['user'].shape

(27879, 100)

In [13]:
user
product
word
brand
category
realated_product
purchase
mentions
described_as
produced_by
belongs_to
also_bought
also_viewed
bought_together


NoneType

In [None]:
'''
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {beauty, cd, cell, clothing}.')
    parser.add_argument('--name', type=str, default='train_transe_model', help='model name.')
    parser.add_argument('--seed', type=int, default=123, help='random seed.')
    parser.add_argument('--gpu', type=str, default='1', help='gpu device.')
    parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train.')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size.')
    parser.add_argument('--lr', type=float, default=0.5, help='learning rate.')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight decay for adam.')
    parser.add_argument('--l2_lambda', type=float, default=0, help='l2 lambda')
    parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Clipping gradient.')
    parser.add_argument('--embed_size', type=int, default=100, help='knowledge embedding size.')
    parser.add_argument('--num_neg_samples', type=int, default=5, help='number of negative samples.')
    parser.add_argument('--steps_per_checkpoint', type=int, default=200, help='Number of steps for checkpoint.')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

    args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)

    global logger
    logger = get_logger(args.log_dir + '/train_log.txt')
    logger.info(args)

    set_random_seed(args.seed)
    train(args)
    extract_embeddings(args)


if __name__ == '__main__':
    main()

'''

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {beauty, cd, cell, clothing}.')
parser.add_argument('--name', type=str, default='train_transe_model', help='model name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='1', help='gpu device.')
parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=64, help='batch size.')
parser.add_argument('--lr', type=float, default=0.5, help='learning rate.')
parser.add_argument('--weight_decay', type=float, default=0, help='weight decay for adam.')
parser.add_argument('--l2_lambda', type=float, default=0, help='l2 lambda')
parser.add_argument('--max_grad_norm', type=float, default=5.0, help='Clipping gradient.')
parser.add_argument('--embed_size', type=int, default=100, help='knowledge embedding size.')
parser.add_argument('--num_neg_samples', type=int, default=5, help='number of negative samples.')
parser.add_argument('--steps_per_checkpoint', type=int, default=200, help='Number of steps for checkpoint.')
args = parser.parse_args()

#os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log.txt')
logger.info(args)

set_random_seed(args.seed)
train(args)
extract_embeddings(args)

In [None]:
'''
bash
python train_transe_model.py --dataset <dataset_name>
'''

In [6]:
a = [[1,2,3],[]]

In [7]:
a[-1]

[]

In [8]:
len(a[-1])

0