In [1]:
import pandas as pd
import numpy as np
import os
os.environ['MKL_NUM_THREADS'] = '1'
from collections import Counter
import itertools
import json
import scipy.sparse as sparse
import pickle
import torch
import misc.util as util
from torch import nn

import importlib
from torch.utils.data import DataLoader
from misc.loader import AEDataset
from misc.util import *
from tqdm.auto import tqdm
from eval.rec_eval import *
import neuralsort.pl as pl
from models.loss import neuPrecLoss
from misc.loader import RecDataset
import models
from models.loss import *

In [2]:
import pickle
with open("data/parsed/ml-1m-new", 'rb') as f:
    (tr_users, val_users, te_users, train_data, val_tr, val_te, te_tr, te_te) = pickle.load(f)
empty = csr_matrix(train_data.shape)

In [3]:
n_users, n_items = train_data[tr_users].shape
ae_dataset = AEDataset(train_data[tr_users])
total_anneal_steps = 50000
anneal_cap = 0.2


In [4]:
!ls saved_models/ml-1m-new

dae  ttt  vae


In [5]:
best_dae = torch.load(os.path.join("saved_models", "ml-1m-new", "dae"))

In [6]:
best_dae

{'model': MultiDAE(
   (layers): ModuleList(
     (0): Linear(in_features=3075, out_features=200, bias=True)
     (1): Linear(in_features=200, out_features=3075, bias=True)
   )
   (drop): Dropout(p=0.2, inplace=False)
   (dd): Dropout(p=0.2, inplace=False)
 ),
 'epoch': 90,
 'bs': 500,
 'lr': 0.0005,
 'dim': [200],
 'lamb': 0,
 'best_map10': 0.13484332813681035,
 'annel_caps': 0,
 'dropout': 0.2}

In [7]:
vin = val_tr[val_users]
vo = val_te[val_users]

In [8]:
tin = te_tr[te_users]
to = te_te[te_users]

In [9]:
wrapper = models.ae.implicitWrapper(best_dae['model'].eval(), naive_sparse2tensor, vae=True)

In [10]:
ranking_metrics_at_k(wrapper, tin, to, K=10)

{'precision': 0.13159999999999955,
 'recall': 0.23530452263344312,
 'map': 0.13635227135298564,
 'ndcg': 0.22833346604982988}

In [11]:
ranking_metrics_at_k(wrapper, tin, to, K=5)

{'precision': 0.16819999999999974,
 'recall': 0.16054058426269446,
 'map': 0.1458738888888888,
 'ndcg': 0.2187494307757626}

In [12]:
del best_dae['model']
del wrapper

In [13]:
vin_dense = torch.FloatTensor(np.asarray(vin.todense())).cuda()
tin_dense = torch.FloatTensor(np.asarray(tin.todense())).cuda()

In [14]:
list(zip([1,2],[3,4]))

[(1, 3), (2, 4)]

In [15]:
batch_size = best_dae['bs']

In [17]:
class nfwrapper():
    def __init__(self, AEModel, tensor_wrapper, vae=False):
        self.model = AEModel
        self.u = torch.LongTensor([0]).cuda()
        self.vae = vae
        self.tensor_wrapper = tensor_wrapper

    def recommend(self, userid, user_items, N=10):
        liked = set()
        liked.update(user_items[userid].indices)
        self.u[0] = userid
        # calculate the top N items, removing the users own liked items from the results
        i = self.tensor_wrapper(user_items[userid]).cuda()
        if self.vae:
            _ = self.model(i)
            try:
                scores, _, _ = _
            except:
                scores = _
            scores = scores.cpu().detach().numpy()[0]
        else:
            scores = self.model(i, self.u, use_dropout=False).cpu().detach().numpy()[0]

        count = N + len(liked)
        if count < len(scores):
            ids = np.argpartition(scores, -count)[-count:]
            best = sorted(zip(ids, scores[ids]), key=lambda x: -x[1])
        else:
            best = sorted(enumerate(scores), key=lambda x: -x[1])
        return list(itertools.islice((rec for rec in best), N))



In [21]:
qed = []
import models.ae
importlib.reload(models.ae)
import neuralsort.neuralobjs
importlib.reload(models.ae)
use_vae=True
loader = DataLoader(ae_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
model = models.ae.MultiDAE(best_dae['dim'] + [n_items], dropout=best_dae['dropout'])
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=best_dae['lr'])
lm = -1
update_count = 0
tr_t = {
    'l': [],
    ('precision', 5):[],
    ('precision', 10):[],
    ('map', 5): [],
    ('map', 10): [],
    ('ndcg', 5): [],
    ('ndcg', 10): [] ,
}
val_t = {
    'l': [],
    ('precision', 5):[],
    ('precision', 10):[],
    ('map', 5): [],
    ('map', 10): [],
    ('ndcg', 5): [],
    ('ndcg', 10): [] ,
}
te_t = {
    'l': [],
    ('precision', 5):[],
    ('precision', 10):[],
    ('map', 5): [],
    ('map', 10): [],
    ('ndcg', 5): [],
    ('ndcg', 10): [] ,
}
for epoch in (range(1, 200 + 1)):
    model = model.train()
    model.training = True
    tr_losses = []
    # train for one epoch
    for uid, rowl in (loader):
        row = rowl.float().cuda()
        uid = uid.cuda() 
        loss = None 
        scores = model.forward(row)
        loss  = models.loss.MultinomialLoss(row, scores)    
        (loss.mean()).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        tr_losses.append(loss.detach().unsqueeze(-1)) 
        optimizer.step()
    
    if (epoch % 5 == 0):
        model.eval()
        model.training=False
        tr_loss = torch.cat(tr_losses).mean().detach().cpu().numpy()
        vad_scores = model.forward(vin_dense)
        vad_loss = models.loss.MultinomialLoss(vin_dense, vad_scores) 
        vad_loss = vad_loss.mean().detach().cpu().numpy()
        te_scores = model.forward(tin_dense)
        te_loss = models.loss.MultinomialLoss(tin_dense, te_scores) 
        te_loss = te_loss.mean().detach().cpu().numpy()

        tr_t['l'].append(tr_loss)
        val_t['l'].append(vad_loss)
        te_t['l'].append(te_loss)
        trwrapper = nfwrapper(model, naive_sparse2tensor, vae=True)
        wrapper = models.ae.implicitWrapper(model, naive_sparse2tensor, vae=True)
        ret = (tr_loss, vad_loss)

        for store, q, a, ft in  zip([tr_t, val_t, te_t], [train_data, vin, tin], [train_data, vo, to], [0, 1,1]):
            for topk in [5, 10]:
                if ft == 1:
                    scs = ranking_metrics_at_k(wrapper, q, a, K=topk, num_threads=4)
                else:
                    trwrapper.model.train()
                    scs = ranking_metrics_at_k(trwrapper, q, a, K=topk, num_threads=4)
                    trwrapper.model.eval()
                for ty in ['map', 'precision', 'ndcg']:
                    store[(ty, topk)].append(scs[ty])
            print(store)
        qed.append(ret)
no_add = qed

{'l': [array(301.4743, dtype=float32)], ('precision', 5): [0.20361445783132678], ('precision', 10): [0.17751916757941166], ('map', 5): [0.13934008762322012], ('map', 10): [0.09983678607441544], ('ndcg', 5): [0.22492603304078299], ('ndcg', 10): [0.20296583287720776]}
{'l': [array(244.06711, dtype=float32)], ('precision', 5): [0.041000000000000016], ('precision', 10): [0.04070000000000016], ('map', 5): [0.031649444444444445], ('map', 10): [0.030036502739984903], ('ndcg', 5): [0.05401369999114445], ('ndcg', 10): [0.06271385370532362]}
{'l': [array(234.31343, dtype=float32)], ('precision', 5): [0.04320000000000008], ('precision', 10): [0.03960000000000019], ('map', 5): [0.035369444444444426], ('map', 10): [0.033404711199294584], ('ndcg', 5): [0.059533244660306246], ('ndcg', 10): [0.06625256822239406]}
{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32)], ('precision', 5): [0.20361445783132678, 0.29731653888280735], ('precision', 10): [0.17751916757941166, 0.2743975903614

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32)], ('precision', 5): [0.20361445783132678, 0.29731653888280735, 0.3407995618839043, 0.3989594742606829, 0.466100766703177, 0.518565169769989], ('precision', 10): [0.17751916757941166, 0.27439759036144834, 0.29901423877327665, 0.34189485213581705, 0.40117743702081, 0.451369112814895], ('map', 5): [0.13934008762322012, 0.21506845564074575, 0.25252920774005194, 0.31608616283315044, 0.37905531215772076, 0.4320664476086152], ('map', 10): [0.09983678607441544, 0.1736308040517736, 0.19901947953636825, 0.24467555484426506, 0.29887243926920526, 0.3462618215768314], ('ndcg', 5): [0.22492603304078299, 0.31024126487659737, 0.35541574328189646, 0.42245354078928865, 0.4907493680109849, 0.5453562289629698], ('ndcg', 10): [0.20296583287720776, 0.2943334676833198, 0.3263120384715519, 0.3801445311164637, 

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002], ('precision', 10): [0.03960000000000019, 0.07130000000000014, 0.07730000000000002, 0.08639999999999985, 0.09889999999999974, 0.10899999999999976, 0.1170999999999997, 0.12199999999999973], ('map', 5): [0.035369444444444426, 0.06416083333333346, 0.07412444444444452, 0.08688833333333333, 0.10521694444444449, 0.11344416666666665, 0.12212194444444438, 0.1279480555555554], ('map', 10): [0.033404711199294584, 0.06049645518392553, 0.06797517353237603, 0.08024988882590078, 0.09620599111866973, 0.10379722883597878, 0.11349876291257252, 0.11900552

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596], ('precision', 10): [0.03960000000000019, 0.07130000000000014, 0.07730000000000002, 0.08639999999999985, 0.09889999999999974, 0.10899999999999976, 0.1170999999999997, 0.12199999999999973, 0.12529999999999963, 0.12719999999999973], ('map', 5): [0.035369444444444426, 0.06416083333333346, 0.07412444444444452, 0.08688833333333333, 0.10521694444444449, 0.11344416666666665, 0.12212194444444438, 0.1279480555555554, 0.1342113888888888, 0.13605249999999

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32)], ('precision', 5): [0.041000000000000016, 0.08180000000000034, 0.0926000000000004, 0.11520000000000058, 0.13540000000000055, 0.14760000000000023, 0.15300000000000025, 0.16140000000000002, 0.16259999999999994, 0.16599999999999984, 0.1671999999999997, 0.16839999999999988], ('precision', 10): [0.04070000000000016, 0.0718000000000001, 0.07650000000000008, 0.09069999999999981, 0.10719999999999978, 0.11759999999999954, 0.12419999999999949, 0.1292999999999995, 0.13149999999999937, 0.13409999999999933, 0.13649999999999934, 0.13559999999999928], ('map', 5): [0.031649444444444445, 0.060806944444444545,

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999999996], ('precision', 10): [0.03960000000000019, 0.07130000000000014, 0.07730000000000002, 0.08639999999999985, 0.09889999999999974, 0.10899999999999976, 0.1170999999999997, 0.12199999999999973, 0.12529999999999963, 0.12719999999999973, 0.12889999999999957, 0.13279999999999956, 0.1313999999999

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32)], ('precision', 5): [0.20361445783132678, 0.29731653888280735, 0.3407995618839043, 0.3989594742606829, 0.466100766703177, 0.518565169769989, 0.5610076670317609, 0.6006571741511461, 0.6342825848849898, 0.6669222343921107, 0.7018619934282566, 0.7380065717415103, 0.7674698795180753, 0.8035049288061349, 0.831982475355977], ('precision', 10): [0.17751916757941166, 0.27439759036144834, 0.29901423877327665, 0.34189485213581705, 0.40117743702081, 0.451369112814895, 0.4843373493975895, 0.5156626506024105, 

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32)], ('precision', 5): [0.041000000000000016, 0.08180000000000034, 0.0926000000000004, 0.11520000000000058, 0.13540000000000055, 0.14760000000000023, 0.15300000000000025, 0.16140000000000002, 0.16259999999999994, 0.16599999999999984, 0.1671999999999997, 0.16839999999999988, 0.16899999999999976, 0.1677999999999999, 0.17159999999999984, 0.17239999999999975], ('precision', 10): [0.04070000000000016, 0.0718000000000001, 0.07650000000000008, 0.09069999999999981, 0.1071999

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999999996, 0.16080000000000003, 0.15940000000000004, 0.16040000000000001, 0.16100000000000006], ('precision', 10): [0.03960000000000019, 0.07130000000000014, 0.0773000

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999999996, 0.16080000000000003, 0.15940000000000004, 0.16040000000000001, 0.16100000000000006, 0.16179999999999997], ('precision', 10)

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999999996, 0.16080000000000003, 0.15940000000000004, 0.16040000000000001, 0.16100000000000006, 0.1617

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999999996, 0.16080000000000003, 0.15940000000000004, 0.1604000000000

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999999996, 0.16080000000000003, 0.1

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0.16180000000000003, 0.16239999999

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.15699999999999997, 0.1596, 0.1634000000000001, 0

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508000000000002, 0.1569999999999999

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.13740000000000058, 0.14420000000000027, 0.1508

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.09640000000000051, 0.10280000000000053, 0.12560000000000068, 0.1374000000000

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32)], ('precision', 5): [0.041000000000000016, 0.08180000000000034, 0.0926000000000004, 0.11520000000000058, 0.13

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32)], ('precision', 5): [0.20361445783132678, 0.29731653888280735, 0.340799561

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32)], ('precision', 5): [0.04320000000000008, 0.08440000000000043, 0.0964000000

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32), array(182.95032, dtype=float32), array(181.86461, dtype=float32)], ('precision', 5): [0.041000000000000016, 

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32), array(201.91982, dtype=float32), array(201.1133, dtype=float32)], ('preci

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32), array(172.87292, dtype=float32), array(172.19868, dtype=float32)], ('preci

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32), array(182.95032, dtype=float32), array(181.86461, dtype=float32), array(181.13333, dtype=float32), array(180

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32), array(201.91982, dtype=float32), array(201.1133, dtype=float32), array(20

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32), array(172.87292, dtype=float32), array(172.19868, dtype=float32), array(17

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32), array(182.95032, dtype=float32), array(181.86461, dtype=float32), array(181.13333, dtype=float32), array(180

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32), array(201.91982, dtype=float32), array(201.1133, dtype=float32), array(20

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32), array(172.87292, dtype=float32), array(172.19868, dtype=float32), array(17

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32), array(182.95032, dtype=float32), array(181.86461, dtype=float32), array(181.13333, dtype=float32), array(180

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32), array(201.91982, dtype=float32), array(201.1133, dtype=float32), array(20

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32), array(172.87292, dtype=float32), array(172.19868, dtype=float32), array(17

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32), array(182.95032, dtype=float32), array(181.86461, dtype=float32), array(181.13333, dtype=float32), array(180

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32), array(201.91982, dtype=float32), array(201.1133, dtype=float32), array(20

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32), array(172.87292, dtype=float32), array(172.19868, dtype=float32), array(17

{'l': [array(244.06711, dtype=float32), array(229.41023, dtype=float32), array(224.11641, dtype=float32), array(218.03897, dtype=float32), array(212.56789, dtype=float32), array(208.55965, dtype=float32), array(206.17012, dtype=float32), array(204.45474, dtype=float32), array(202.91722, dtype=float32), array(201.35861, dtype=float32), array(199.9807, dtype=float32), array(198.4782, dtype=float32), array(197.02736, dtype=float32), array(195.71655, dtype=float32), array(194.43697, dtype=float32), array(193.15305, dtype=float32), array(191.95535, dtype=float32), array(190.7737, dtype=float32), array(189.6139, dtype=float32), array(188.59637, dtype=float32), array(187.8906, dtype=float32), array(187.30045, dtype=float32), array(187.0688, dtype=float32), array(186.57204, dtype=float32), array(185.97789, dtype=float32), array(185.05591, dtype=float32), array(183.89172, dtype=float32), array(182.95032, dtype=float32), array(181.86461, dtype=float32), array(181.13333, dtype=float32), array(180

{'l': [array(301.4743, dtype=float32), array(281.37524, dtype=float32), array(273.93225, dtype=float32), array(266.13522, dtype=float32), array(258.90024, dtype=float32), array(252.95773, dtype=float32), array(248.93044, dtype=float32), array(245.63205, dtype=float32), array(242.7356, dtype=float32), array(239.81932, dtype=float32), array(237.23848, dtype=float32), array(234.46144, dtype=float32), array(231.66847, dtype=float32), array(229.1477, dtype=float32), array(226.53769, dtype=float32), array(224.04153, dtype=float32), array(221.71695, dtype=float32), array(219.44897, dtype=float32), array(217.17242, dtype=float32), array(215.10481, dtype=float32), array(212.90387, dtype=float32), array(210.83636, dtype=float32), array(209.10875, dtype=float32), array(207.8062, dtype=float32), array(206.35236, dtype=float32), array(205.19987, dtype=float32), array(203.93437, dtype=float32), array(202.75362, dtype=float32), array(201.91982, dtype=float32), array(201.1133, dtype=float32), array(20

{'l': [array(234.31343, dtype=float32), array(219.68962, dtype=float32), array(214.16057, dtype=float32), array(208.2797, dtype=float32), array(203.09192, dtype=float32), array(199.03616, dtype=float32), array(196.58012, dtype=float32), array(194.75833, dtype=float32), array(193.25974, dtype=float32), array(191.76204, dtype=float32), array(190.45683, dtype=float32), array(189.02472, dtype=float32), array(187.70117, dtype=float32), array(186.41939, dtype=float32), array(185.2293, dtype=float32), array(183.98132, dtype=float32), array(182.82497, dtype=float32), array(181.69336, dtype=float32), array(180.51837, dtype=float32), array(179.52647, dtype=float32), array(178.73567, dtype=float32), array(178.1146, dtype=float32), array(177.78853, dtype=float32), array(177.3018, dtype=float32), array(176.7004, dtype=float32), array(175.84787, dtype=float32), array(174.74278, dtype=float32), array(173.87225, dtype=float32), array(172.87292, dtype=float32), array(172.19868, dtype=float32), array(17

In [13]:
qed = []
import models.ae
importlib.reload(models.ae)
import neuralsort.neuralobjs
importlib.reload(models.ae)
use_vae=True
loader = DataLoader(ae_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
if use_vae:
    model = models.ae.MultiVAE([200] + [n_items], dropout=0.5)
else:
    model = models.ae.MultiDAE([200] + [n_items], dropout=0.5)
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=5 * 1e-4)
sc = neuralsort.neuralobjs.SC()
lm = -1 
tr_tt = {
    5: [],
    10: []
}
val_tt = {
    5: [],
    10: [],
}
lolis = []
update_count = 0
for epoch in (range(1, 1000 + 1)):
    model = model.train()
    model.training = True
    tr_losses = []
    loss2s =[] 
    # train for one epoch
    for uid, rowl in (loader):
        row = rowl.float().cuda()
        uid = uid.cuda() 
        loss = None 
        if use_vae:
            scores, mu, logvar = model.forward(row)
            if total_anneal_steps > 0:
                anneal_cap = min(anneal_cap, 1. * update_count / total_anneal_steps)
            else:
                anneal = anneal_cap
            loss  = models.ae.loss_function(scores, row, mu, logvar)
            update_count += 1

        else:
            scores = model.forward(row)
            loss  = models.loss.MultinomialLoss(row, scores)

        loss2 = neuMapLoss(sc, scores, row, topk=300, k=5, tau=15.0, use_top=True)
        loss2s.append(loss2.unsqueeze(-1))
        (loss.mean() + 10.0 * loss2.mean()).backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        tr_losses.append(loss.detach().unsqueeze(-1)) 
        optimizer.step()
    
    if (epoch % 10 == 0):
        tr_loss = torch.cat(tr_losses).mean()
        lo2 = torch.cat(loss2s).mean()
        lolis.append(lo2)
        _in = torch.from_numpy(vin_dense).float().cuda()
        vad_scores, mu, logvar = model.forward(_in)
        vad_loss = models.ae.loss_function(vad_scores, _in, mu, logvar)
        vad_loss.mean().detach().cpu().numpy()
        model.eval()
        model.training=False                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
        wrapper1 = models.ae.implicitWrapper(model, train_data, naive_sparse2tensor, vae=True)
        wrapper2 = models.ae.implicitWrapper(model, vin, naive_sparse2tensor, vae=True)

        
        ret = (tr_loss.detach().cpu().numpy(), vad_loss.mean().detach().cpu().numpy())
        
        for topk in [5, 10]:
            tr_tt[topk].append(ranking_metrics_at_k(wrapper1, empty, train_data, K=topk, num_threads=4)['map'])
            val_tt[topk].append((ranking_metrics_at_k(wrapper2, vin, vo, K=topk, num_threads=4)['map']))
        print(["%0.4f" % x for x in ret], tr_tt[5][-1], tr_tt[10][-1],val_tt[5][-1], val_tt[10][-1])
        qed.append(ret)
add = qed

['294.7231', '225.8089'] 0.21826278174821165 0.17152883630953158 0.05616430555555556 0.05015200877110111
['274.4365', '219.9437'] 0.2750879604178133 0.21030542590868131 0.06868902777777773 0.06192627480158732
['280.2097', '216.5288'] 0.35857614073666827 0.2742169076872701 0.08582458333333337 0.07594075207860913
['275.0313', '214.3504'] 0.400465457210921 0.3105787457333225 0.09942861111111095 0.08769221812799184
['267.6081', '212.9479'] 0.43259849734286077 0.33781472371745475 0.10510222222222197 0.09244149911816571
['263.7416', '212.0777'] 0.4634359538207799 0.3622940053920485 0.1130180555555553 0.09898759511211869
['266.4737', '211.8793'] 0.4883085944658228 0.3832880855266988 0.11566236111111074 0.10196724600025156
['272.7971', '211.2510'] 0.509337548103352 0.401759022361108 0.11879930555555519 0.10581691649344904
['262.5125', '211.0949'] 0.5284835990470935 0.4176413090331327 0.12163013888888845 0.10819690641534364
['274.7474', '210.6121'] 0.5448414879970654 0.4326874071712478 0.122983

['259.1045', '214.6977'] 0.9268123511086737 0.8329643437506576 0.13897249999999972 0.12617133739606934
['265.1940', '214.8335'] 0.9272796408283005 0.8339701823529737 0.13918958333333298 0.12575000543272832
['263.5654', '214.5125'] 0.9297067986072995 0.8383749975414234 0.13796916666666656 0.1252059786785081
['258.7459', '214.9248'] 0.9311599780099016 0.8402338847446248 0.139772083333333 0.12628875015747018
['265.3243', '214.5748'] 0.931131574124983 0.8416007946965607 0.13900611111111094 0.12567849434681252
['257.1082', '214.8861'] 0.934390690855788 0.8463302887354356 0.13806486111111108 0.12572311799256713
['264.7551', '214.9235'] 0.9347544438336146 0.846072154675306 0.13944319444444422 0.1265741059618291
['264.0295', '215.1135'] 0.9377350192413472 0.8498813985615058 0.1396763888888886 0.12661996779730392
['264.2753', '214.9118'] 0.9366208539490622 0.84978738020823 0.1380765277777774 0.12572762117346908
['258.3012', '214.8421'] 0.9376378962800136 0.8499655549741576 0.1372198611111109 0.

In [16]:
lolis

[tensor(-0.4817, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-0.7677, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.0045, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.1744, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.2924, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.3773, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.4989, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.5715, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.5790, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.6430, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.7150, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.7087, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.7785, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.7870, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.8957, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.8911, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(-1.9152, device='cuda:0', grad_f