In [9]:
import pandas as pd
import numpy as np
from random import sample
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import pickle

In [10]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
df= pd.read_csv('/content/drive/MyDrive/Colab Notebooks/trainwithzerostopredict.csv')
topredict= pd.read_csv('/content/drive/MyDrive/Colab Notebooks/topredict.csv')
df_withoutzero= pd.read_csv('/content/drive/MyDrive/Colab Notebooks/trainwithoutzeros.csv')

***df has all the topredict values entered as 0 for the Book-Rating column.***</br>
***topredict has all the values that we want to predict.***</br>
***df_withoutzero has all the topredict all the values including the values in the test dataset.***



In [12]:
df['User-ID'].nunique()

77805

In [13]:
topredict['User-ID'].nunique()

19935

##**500 users that rate the most**

In [14]:
df_num_rating= df.groupby('User-ID').agg(Number_of_ratings=('Book-Rating','count')).reset_index()
# Num ratings count

In [15]:
df_num_rating.sort_values(by=['Number_of_ratings'], ascending=False, inplace=True)
df_num_rating

Unnamed: 0,User-ID,Number_of_ratings
3160,11676,8524
27626,98391,5802
43027,153662,1969
52924,189835,1906
6510,23902,1395
...,...,...
23601,84129,1
46266,165812,1
9606,34231,1
46268,165826,1


In [16]:
top_users= df_num_rating['User-ID'][:500]
# top 500 users

##**To Predict User List and Batch Designation**

In [17]:
topredict_users= topredict['User-ID'].unique()
topredict_users

array([    17,     56,    114, ..., 278844, 278851, 278854])

In [18]:
len(topredict_users)

19935

**We divide the test into 20 folds because of memory limitations. We will predict the test values by loading in batches**

In [19]:
dict_batch= {}
for i in range(0,len(topredict_users),1000):
  dict_batch['fold_{}'.format(i//1000)]= topredict_users[i:i+1000]

In [20]:
dict_batch #fold_0 to fold_19#

{'fold_0': array([   17,    56,   114,   160,   183,   242,   243,   254,   272,
          289,   300,   362,   383,   388,   392,   408,   424,   440,
          441,   444,   446,   472,   476,   487,   503,   505,   507,
          566,   619,   625,   626,   638,   640,   643,   651,   657,
          695,   709,   726,   741,   744,   746,   753,   776,   786,
          805,   819,   850,   853,   882,   885,   896,   899,   900,
          901,   914,   929,  1009,  1021,  1025,  1031,  1063,  1075,
         1083,  1096,  1113,  1116,  1129,  1131,  1155,  1161,  1167,
         1178,  1184,  1211,  1235,  1248,  1249,  1254,  1261,  1297,
         1343,  1376,  1409,  1412,  1421,  1424,  1435,  1436,  1467,
         1485,  1517,  1548,  1558,  1570,  1585,  1596,  1597,  1619,
         1652,  1660,  1667,  1674,  1688,  1696,  1733,  1790,  1791,
         1797,  1805,  1830,  1848,  1903,  1990,  2009,  2010,  2024,
         2030,  2033,  2041,  2046,  2084,  2090,  2103,  2106,  21

In [21]:
# this is the matrix factorization function from the MF notebook. 
def MF(M,k,max_it,lambd,mu):
    n=M.size()[0]
    m= M.size()[1]
    nonzero=len(M.nonzero())
    index= M.nonzero().split(1, dim=1)
    #param=torch.rand(n*k+k*m,dtype=float,requires_grad=True)
    param1=torch.rand((n,k),dtype=torch.float,requires_grad=True)
    param2=torch.rand((k,m),dtype=torch.float,requires_grad=True)

    opt1= torch.optim.Adam([param1],lr=0.1)
    opt2= torch.optim.Adam([param2],lr=0.1)
    
    #scheduler1= ReduceLROnPlateau(opt1, 'min') 
    #scheduler2 = ReduceLROnPlateau(opt2, 'min')
    

    #def get_loss(params,params_hat):
        #return torch.sum(torch.square(params- params_hat))

    def run_iterations(max_it):
        loss_record=[]
        converged=False
        for it in tqdm(range(max_it)):
            if it%2==0:
                opt1.zero_grad(set_to_none=True)
                #torch.matmul(param[:n*k].reshape(n,k), pam[n*k:].reshape(k,m))
                loss=torch.sum(torch.square(torch.matmul(param1, param2)[index]- M[index])) + lambd*torch.sum(torch.square(param1))+mu*torch.sum(torch.square(param2))
                loss_record.append(loss.item())
                loss.backward()
                opt1.step()
                #scheduler1.step(loss)
            else:
                opt2.zero_grad(set_to_none=True)
                #torch.matmul(param[:n*k].reshape(n,k), pam[n*k:].reshape(k,m))
                loss=torch.sum(torch.square(torch.matmul(param1, param2)[index]- M[index])) + lambd*torch.sum(torch.square(param1))+mu*torch.sum(torch.square(param2))
                loss_record.append(loss.item())
                loss.backward()
                opt2.step()
                #scheduler2.step(loss)
        display(loss_record)
        return torch.matmul(param1,param2)
    return run_iterations(max_it)

**The following loop loads the ith batch of the train dataset and constructs a matrix using the ith batch and the subset of the train dataset containing the 500 top users. Then it computes the matrix factorization prediction on the ith batch. It calculates the MSE for each batch in each iteration.**

In [22]:
mse_batch=[] # this will collect the batch MSE
indices_batch=[] # this will collect the indices of the test dataset in each batch
for i in range(len(dict_batch)): #length of dict_batch is 20 because we have 20 folds
  df_batch= topredict[topredict['User-ID'].isin(dict_batch['fold_{}'.format(i)])].reset_index(drop=True)
  df_matrix= df[df['User-ID'].isin(list(dict_batch['fold_{}'.format(i)])+list(top_users))].reset_index(drop=True) #combines data from 500 top users and the test batch
  mat= df_matrix.pivot(index='User-ID',columns='ISBN',values='Book-Rating') .fillna(0) # matrix using pivot
  matrix= torch.tensor(mat.values)/10 #converts to tensor and scales the values by 1/10
  dict_user= dict(zip(sorted(set(df_matrix['User-ID'])),range(len(sorted(set(df_matrix['User-ID']))))))
  dict_book= dict(zip(sorted(set(df_matrix['ISBN'])),range(len(sorted(set(df_matrix['ISBN']))))))
  index1=[] #this will collect the row index
  index2=[] #this will collect the column index
  for j in range(len(df_batch)):
    index1.append([dict_user[df_batch['User-ID'][j]]])
    index2.append([dict_book[df_batch['ISBN'][j]]])
  indices_topred = (torch.tensor(index1),torch.tensor(index1)) #this contains the indices from the matrix whose values we want to predict
  # this is a bit peculiar but torch tensors work this way
  df_actual= df_withoutzero[df_withoutzero['User-ID'].isin(list(dict_batch['fold_{}'.format(i)])+list(top_users))]
  # df_actual is the matrix with the ratings of the test dataset included
  # we need this to calculate MSE
  actual_matrix= df_actual.pivot(index='User-ID',columns='ISBN',values='Book-Rating') .fillna(0)
  actual_matrix= torch.tensor(actual_matrix.values)/10 # convert to tensor and scale like above

  mse_batch.append(100*torch.sum(torch.square(MF(matrix,10,1500,0.1,0.1)[indices_topred]- actual_matrix[indices_topred])))
  # we multiply by 10^2 because we scaled by 1/10 and squared
  # the values 25,0.1,0.1 come from the MF validation notebook where we performed validation for parameter tuning 
  indices_batch.append(len(indices_topred[0]))
  display(mse_batch[-1])






100%|██████████| 1500/1500 [09:48<00:00,  2.55it/s]


[417782.43285339617,
 250582.62210716488,
 141137.4343929216,
 82179.77189549133,
 45896.38928847458,
 35388.18536310081,
 27320.587131621127,
 32378.061310013443,
 31360.170629022905,
 38733.55011540025,
 37562.398396559234,
 43346.66332634709,
 40615.79756844395,
 44327.12745439111,
 40469.18576982042,
 42563.03797111553,
 38332.0353095903,
 39358.50428402485,
 35345.605278681374,
 35730.10088893239,
 32266.866859994203,
 32293.445845016373,
 29489.485416968077,
 29335.811821277006,
 27154.775648644652,
 26929.330505606995,
 25262.32212964457,
 25030.26318832387,
 23746.67138239398,
 23544.227793499118,
 22519.93562041162,
 22361.832938580716,
 21492.87742148172,
 21377.053042455642,
 20585.16487042119,
 20497.80664966951,
 19731.01870112797,
 19653.552461552834,
 18884.223225768907,
 18800.71516142553,
 18023.799237498562,
 17925.39078930138,
 17153.344734120878,
 17038.664109515757,
 16292.505433280403,
 16165.253478857694,
 15465.909394807191,
 15331.587884881548,
 14694.010379431

tensor(116041.1186, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [10:18<00:00,  2.43it/s]


[368914.71290312085,
 215229.67933080636,
 120960.36725807172,
 69927.68514811122,
 40257.25153110045,
 33243.153380506454,
 26834.617473941944,
 32897.008436964315,
 31530.854366909516,
 38852.10481815859,
 36873.059479412725,
 42295.792721644066,
 38920.53042803603,
 42275.76365136951,
 38097.351689822244,
 39941.06190616318,
 35671.606383419465,
 36559.82466876387,
 32685.308143857706,
 33016.40055897054,
 29773.833432411324,
 29803.92307512089,
 27244.85760171875,
 27132.879562944465,
 25191.206951182252,
 25031.04461818519,
 23577.414334275454,
 23415.781500100322,
 22300.442062992162,
 22150.2463664787,
 21235.890633637522,
 21090.990579206442,
 20273.304136817576,
 20123.07637964121,
 19338.772562234615,
 19177.655401895972,
 18401.58536628548,
 18231.54890422956,
 17465.537959089717,
 17293.55601254625,
 16552.531729084527,
 16386.454067610917,
 15686.915011175524,
 15532.241056954426,
 14884.959033211462,
 14743.377174123842,
 14150.863892569152,
 14020.47674577009,
 13478.921

tensor(24621.2108, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [11:08<00:00,  2.24it/s]


[394667.50505753455,
 232386.9103615814,
 130039.8855450664,
 74031.82934634516,
 41579.08860455161,
 32479.515451855033,
 26006.67932421455,
 31619.880621487227,
 30982.49728795017,
 38542.19155155625,
 37226.46777188065,
 43085.1827830679,
 40045.923879630594,
 43822.31569462981,
 39648.28585140087,
 41819.59321253939,
 37324.88371426193,
 38433.18240775664,
 34237.342329090316,
 34699.07821988044,
 31155.144886900693,
 31255.053911324947,
 28485.079675307083,
 28407.61652802491,
 26362.988898702275,
 26219.030438670765,
 24745.917655722413,
 24591.561616919047,
 23497.3342553072,
 23351.341579403757,
 22459.049878550377,
 22320.28441317393,
 21500.6705995495,
 21361.961372036792,
 20543.08477333953,
 20399.288062087988,
 19558.97035779223,
 19409.92602820481,
 18560.044217318824,
 18410.00199818248,
 17578.331920957455,
 17433.10703846204,
 16648.389317946614,
 16512.03888331519,
 15795.078452796068,
 15667.983868767817,
 15028.250202652122,
 14907.200370086948,
 14343.064779177328,

tensor(28230.1569, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [08:51<00:00,  2.82it/s]


[407875.9972733115,
 242284.5876257174,
 136246.59271004458,
 78308.93247356259,
 44020.932287726886,
 34142.80892595051,
 26910.435152672635,
 32372.70114713505,
 31432.739816755085,
 39061.318077620155,
 37582.99952725412,
 43535.562761601715,
 40341.0292905849,
 44170.5637386962,
 39811.28321112917,
 41984.468426896376,
 37287.149987490935,
 38355.99593960987,
 33975.01114258639,
 34373.89100588233,
 30697.70775922586,
 30730.473395016663,
 27898.702393901243,
 27762.301900793158,
 25719.16862814082,
 25526.163041952837,
 24093.34095196607,
 23893.95083916726,
 22850.949061071944,
 22657.698325935417,
 21809.456094844078,
 21620.456520997068,
 20835.369749275054,
 20649.747571196127,
 19865.0299372693,
 19687.709269320767,
 18891.601346661722,
 18729.989381121828,
 17936.42375844517,
 17795.038405332187,
 17023.951508014135,
 16902.009535204917,
 16170.099002244684,
 16062.417097457697,
 15380.39373089699,
 15279.588171698219,
 14651.782527940239,
 14550.596307357137,
 13976.6746428

tensor(23045.0287, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [11:51<00:00,  2.11it/s]


[399666.63837871025,
 236592.11727726425,
 133153.19864898187,
 76845.05128628836,
 43418.41613306553,
 33987.5490943106,
 26720.242907206477,
 31952.577191167013,
 30881.05321266933,
 38117.809911323595,
 36698.67588353668,
 42320.65699446911,
 39385.588459767445,
 42989.48638593483,
 39017.345612272504,
 41056.22978483778,
 36786.72206041682,
 37790.55835040746,
 33798.88826656745,
 34182.06152907666,
 30798.078197547093,
 30847.544416735538,
 28183.7290805745,
 28084.21109884053,
 26093.59354298501,
 25951.090353998487,
 24489.587907853285,
 24350.863166397103,
 23237.421004092466,
 23110.501833873146,
 22179.28546712885,
 22055.14012216894,
 21189.73180116842,
 21059.406267433376,
 20202.074061932846,
 20064.302906467907,
 19204.434839066635,
 19064.459034142197,
 18218.00422661818,
 18082.272715763174,
 17273.17542491785,
 17144.747482841463,
 16392.795980713534,
 16269.99834961025,
 15585.105178159594,
 15463.187971345287,
 14845.210780672656,
 14719.180693745453,
 14161.68882623

tensor(27229.4440, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:27<00:00,  2.01it/s]


[431680.0411130956,
 260854.66906578757,
 146816.18124115368,
 85774.45836989992,
 47414.269743503915,
 35806.306377088345,
 27202.07826008598,
 31660.942121096985,
 30840.549500236855,
 37886.818342239196,
 37194.20947418143,
 42822.852821223096,
 40602.94880106897,
 44262.33289257382,
 40815.09044471822,
 42907.47446700314,
 38933.879162437195,
 39974.82245304636,
 36071.26136991542,
 36471.93301165037,
 33010.845466534985,
 33054.83873688323,
 30194.328425260705,
 30060.133681240637,
 27795.83046409399,
 27585.202328517782,
 25818.147643283286,
 25583.998311686075,
 24185.75832227012,
 23956.484170317344,
 22813.753793508222,
 22609.33195660225,
 21640.374884789504,
 21477.268916158442,
 20625.63601521712,
 20512.5142622661,
 19733.811145409854,
 19665.48183742284,
 18921.101079767857,
 18878.658411048855,
 18139.1293312594,
 18098.226297713234,
 17349.927110183977,
 17291.6758880495,
 16539.80142271956,
 16457.0387140342,
 15721.174005371959,
 15617.364000784331,
 14921.75059068204

tensor(27205.2009, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [09:27<00:00,  2.64it/s]


[364516.21082105185,
 210849.2898249206,
 117971.43586868385,
 67235.11428167732,
 38772.20999978793,
 32305.754160280077,
 26546.784535037455,
 33137.29484245193,
 31889.505749814343,
 39579.07082327779,
 37417.17660961568,
 43046.92576386685,
 39374.97253342077,
 42810.05532843455,
 38340.95369068776,
 40176.35896558284,
 35696.700319797055,
 36530.56810231685,
 32558.23748649614,
 32826.8881309556,
 29585.594870694626,
 29565.937977461443,
 27060.758359352563,
 26911.879919930987,
 25023.223372552624,
 24829.401866978304,
 23392.72518574504,
 23195.553086705622,
 22056.093740709104,
 21873.94443714284,
 20913.701690479265,
 20751.370135530393,
 19893.4395435658,
 19747.193402028577,
 18947.387486134452,
 18809.64039335002,
 18045.984645742647,
 17909.757526855847,
 17174.300948743872,
 17035.58671068034,
 16328.256232760163,
 16186.14808501894,
 15511.031658419537,
 15366.829171107485,
 14729.988424757245,
 14585.94798521816,
 13993.358157645223,
 13851.53924403601,
 13307.240704485

tensor(24274.7484, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:27<00:00,  2.01it/s]


[371302.6581075446,
 216216.90920616686,
 121387.33425105269,
 69898.73313769474,
 40364.189260644736,
 33261.08596566581,
 27160.40584816206,
 33261.1879117255,
 32187.269357291076,
 39580.317142841675,
 37813.140085058636,
 43282.3198625713,
 39983.64871691076,
 43325.233704359984,
 39070.429050664345,
 40836.638646763946,
 36382.479752871826,
 37157.65588297371,
 33073.547245921945,
 33299.20026524307,
 29893.685875975487,
 29855.133054836802,
 27205.52504071077,
 27062.397200600335,
 25087.091853822363,
 24916.97016752259,
 23458.274605907092,
 23295.233277584583,
 22185.958118009657,
 22041.913715639857,
 21141.3627767418,
 21013.781684398964,
 20216.850067233903,
 20093.963677286232,
 19329.285342169496,
 19198.108296820308,
 18426.32143870345,
 18280.471087736034,
 17494.028849110684,
 17336.565390256284,
 16554.368705402077,
 16394.56578479351,
 15647.693624769801,
 15494.895177748123,
 14809.816715538569,
 14668.632997720295,
 14056.843747991887,
 13925.867777203654,
 13383.39

tensor(33830.0272, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [11:21<00:00,  2.20it/s]


[379902.3715760374,
 222434.60081962997,
 124618.9555676896,
 71426.07994872217,
 40668.145130646895,
 32698.53252641986,
 26284.277440477774,
 32067.716046509686,
 30964.86051070451,
 38329.27206169876,
 36591.36344244086,
 42186.757851485054,
 38942.016943723254,
 42489.30863472064,
 38326.44425720375,
 40322.90377830399,
 36001.63384506136,
 36987.76876649664,
 33061.15074237452,
 33449.179278385665,
 30199.24971641171,
 30264.184606468392,
 27742.345311263787,
 27649.078367924405,
 25756.40515104093,
 25594.152175265568,
 24168.323239290043,
 23983.60457548915,
 22862.563151113587,
 22680.60413901063,
 21733.029844970555,
 21565.543555862532,
 20698.685537537356,
 20546.452928327366,
 19705.84048647341,
 19563.542260265283,
 18729.084176616558,
 18591.57187562257,
 17768.69049056206,
 17634.26074850283,
 16842.153958220122,
 16711.328447242973,
 15970.908645095013,
 15843.345407726873,
 15167.978038442836,
 15040.967208709082,
 14433.473807425917,
 14303.012779838182,
 13758.046128

tensor(20949.0820, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:21<00:00,  2.02it/s]


[373002.8279703225,
 217611.46821855687,
 121988.23591931562,
 70110.10834231255,
 40271.826311608325,
 32977.46109240086,
 26815.298617761324,
 32897.9413950041,
 31767.284018376886,
 39216.30770507983,
 37346.64320201821,
 42897.632654049536,
 39489.2038226594,
 42918.46004811493,
 38601.33169778595,
 40451.941374565504,
 36013.53345455276,
 36862.61091292169,
 32867.509910514025,
 33151.38698047149,
 29877.983996035055,
 29884.35602090263,
 27370.894938746016,
 27267.733548124597,
 25395.532160810122,
 25263.74930424509,
 23846.11707044211,
 23711.11455973902,
 22566.962634364107,
 22427.0414362922,
 21423.820819356122,
 21271.865436855296,
 20336.558414069532,
 20171.444142000455,
 19279.74502131049,
 19107.713291783853,
 18263.603705537535,
 18093.65837237458,
 17308.89025419564,
 17147.487961289975,
 16427.75641767267,
 16276.476858622844,
 15617.928351810751,
 15474.41397238088,
 14868.063645024666,
 14728.44457662345,
 14165.885625429026,
 14026.98366660374,
 13502.72769404381,

tensor(21146.3154, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [11:28<00:00,  2.18it/s]


[417766.2679221271,
 249751.9752329155,
 140593.39251671275,
 81322.98742761955,
 45336.8919951366,
 34603.59896761804,
 26751.286824600407,
 31632.904206260748,
 30830.607206340213,
 38123.25671261996,
 37167.67469071298,
 42965.54739303741,
 40401.38681888454,
 44189.69979923207,
 40402.66670053841,
 42602.881173294736,
 38324.591831427875,
 39457.6040470677,
 35297.02707151165,
 35768.634657689654,
 32111.67669997272,
 32201.19116616151,
 29225.22047098843,
 29119.10736608657,
 26839.31196331426,
 26658.562795018086,
 24980.956491354962,
 24795.300120057633,
 23566.551404392063,
 23404.736158493302,
 22457.75764888254,
 22321.32829506347,
 21512.27066891965,
 21390.729647707183,
 20620.18215330415,
 20501.770839416888,
 19719.494477496333,
 19596.52734941076,
 18793.132503929017,
 18663.85937189125,
 17855.252034393106,
 17722.87430799417,
 16934.353065645515,
 16803.855615332723,
 16058.541835020966,
 15933.693851811087,
 15246.800371884123,
 15128.659750225297,
 14506.30479781733,

tensor(21172.7772, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [11:25<00:00,  2.19it/s]


[402463.85233791394,
 241652.1087511672,
 136321.58513780122,
 80819.40129107566,
 45375.48865209301,
 35742.18917596609,
 27381.84072962407,
 32092.43831797049,
 30915.082921813635,
 37689.488604035476,
 36729.030171228325,
 42046.5575795997,
 39749.813650900054,
 43202.823281792866,
 39800.140682453515,
 41805.160803831226,
 37897.12772882468,
 38935.70041331864,
 35048.05797551696,
 35488.32824155957,
 31975.932585861297,
 32071.28918870667,
 29122.109376068136,
 29039.620505464154,
 26701.86040996076,
 26543.73272265079,
 24760.221238493657,
 24579.12335166429,
 23227.88149287388,
 23042.673854307446,
 21980.092106991673,
 21793.071394795414,
 20892.25236975217,
 20703.960096034432,
 19878.21035770403,
 19694.379095369626,
 18901.00016602487,
 18730.662379823814,
 17959.069774360632,
 17808.485525363612,
 17063.77271359764,
 16932.17117296262,
 16221.720016744923,
 16101.984876626491,
 15430.14147581988,
 15313.32521868929,
 14682.533241144632,
 14562.478344618048,
 13975.341022394

tensor(21786.5487, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:40<00:00,  1.97it/s]


[410800.79741205927,
 245246.61956726055,
 138060.23133779695,
 80221.93608041386,
 44982.925392328645,
 34837.136390028725,
 27111.04045802733,
 32113.219463108224,
 31198.19401382967,
 38416.11298346965,
 37298.45788178936,
 42976.276183318114,
 40290.219025542174,
 43978.93548318669,
 40148.31992537097,
 42289.557534802414,
 38050.803975705465,
 39165.81155623459,
 35100.51125102663,
 35587.622120934306,
 32030.605382119647,
 32155.02513245824,
 29229.165872518293,
 29156.493675601705,
 26844.420437351535,
 26676.57997174061,
 24889.55434251194,
 24691.830220541247,
 23316.109784595355,
 23130.457686670576,
 22051.383229178213,
 21899.192402685054,
 21013.906522045065,
 20897.107461535554,
 20122.400065871814,
 20028.594484181078,
 19304.346659295374,
 19215.463905797827,
 18504.474397404432,
 18405.183656889603,
 17690.95405811639,
 17573.75348797501,
 16857.73649849642,
 16723.9002138242,
 16021.439061892695,
 15878.35570739228,
 15210.948648643089,
 15067.520571115141,
 14453.187

tensor(41768.7834, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [10:57<00:00,  2.28it/s]


[383370.546802126,
 225525.73454285596,
 126797.84677958659,
 73378.59365084986,
 41773.26363961315,
 33698.75098455869,
 26739.993722005638,
 32457.5697289717,
 31127.40595239828,
 38430.12756750142,
 36660.26404837667,
 42182.352449525286,
 39007.59390239639,
 42474.5598155506,
 38431.12900948555,
 40355.31541149848,
 36155.16934521036,
 37085.574243665294,
 33228.700481960695,
 33572.2284641652,
 30307.198140210618,
 30328.91898198436,
 27717.567025448396,
 27581.43447991481,
 25571.701821164723,
 25376.765616995355,
 23857.173453172945,
 23659.87713307581,
 22497.791242823827,
 22324.555565293238,
 21393.39309249066,
 21249.073838650023,
 20443.46140253873,
 20318.70502919657,
 19561.34501635221,
 19440.911045903253,
 18686.75557704206,
 18557.627083488238,
 17793.088180532624,
 17649.678766505804,
 16884.56537528405,
 16729.459913787618,
 15985.951935373763,
 15826.697373118106,
 15128.628325259064,
 14973.490998363157,
 14337.791885167204,
 14192.673672255572,
 13625.982684452516

tensor(30185.7303, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [10:17<00:00,  2.43it/s]


[396163.63016942283,
 233303.24095930925,
 131062.5809147772,
 75104.9364447876,
 42523.17275133748,
 33569.87456806807,
 26716.30961913189,
 32404.73228865845,
 31378.567359791905,
 38945.27375279396,
 37327.86478980261,
 43135.77264610645,
 39899.238559647085,
 43579.41503897657,
 39324.43977622724,
 41379.09549423769,
 36923.095002755814,
 37925.79465329813,
 33862.510736683675,
 34253.84825090655,
 30878.03227329854,
 30948.300631201142,
 28309.08836172059,
 28229.79359621281,
 26224.767855457845,
 26088.119385421902,
 24554.93395989929,
 24406.060882971415,
 23185.911978628017,
 23042.63010379224,
 22011.183396692777,
 21875.518457042086,
 20949.73643554574,
 20816.5547295959,
 19949.040851664406,
 19812.65020261037,
 18982.291078837872,
 18839.81390552787,
 18042.521929336646,
 17894.338455738012,
 17135.130016596326,
 16983.713664509934,
 16270.373235427012,
 16118.748198903544,
 15457.597000246202,
 15308.338248115615,
 14702.288408183218,
 14557.180223691908,
 14005.2839810422

tensor(18774.6691, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [09:50<00:00,  2.54it/s]


[388792.46723356517,
 228116.7351381968,
 128194.59091416784,
 73467.44512686899,
 41822.88639810957,
 33431.98893073846,
 26763.220599134977,
 32542.346379836315,
 31451.2138225233,
 38897.29731103192,
 37253.78034961534,
 42904.37191033951,
 39737.43439835305,
 43289.448113322636,
 39145.15393242159,
 41104.231655859905,
 36735.82238884627,
 37659.16206418552,
 33636.03496045961,
 33953.594036230315,
 30590.319865206366,
 30596.43175667047,
 27983.179510792244,
 27861.555706772226,
 25927.127541176313,
 25777.033431222568,
 24359.31276505496,
 24220.837053077976,
 23131.597432580995,
 23009.869370723067,
 22085.634455884876,
 21971.599835316818,
 21102.978864625285,
 20987.60888200959,
 20123.818768211087,
 20003.509924389095,
 19138.324123845054,
 19013.64774449819,
 18164.838242898022,
 18037.456975894667,
 17229.268560920864,
 17099.877514970198,
 16351.527434349013,
 16219.479809996788,
 15540.501497700772,
 15404.51194333803,
 14795.477195725092,
 14654.405273964825,
 14109.5808

tensor(17867.4360, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:19<00:00,  2.03it/s]


[389093.95586540375,
 230025.22157988008,
 129324.58854314304,
 75130.53412826551,
 42592.56838657106,
 34091.801335831144,
 26932.7128719431,
 32429.24663438093,
 31244.37138333643,
 38425.00400601425,
 36929.62084066412,
 42384.33413255714,
 39505.17053760303,
 42941.41421786226,
 39157.90703028171,
 41078.38749793569,
 37073.397013360074,
 38020.0628599551,
 34283.59377083043,
 34656.35431986784,
 31437.816676156122,
 31493.498255963957,
 28858.91463048874,
 28751.75009226083,
 26655.298903776886,
 26473.418537217287,
 24813.39161775269,
 24602.03748980697,
 23259.541127392993,
 23036.83360050788,
 21904.625647132474,
 21675.992743097064,
 20676.101196009273,
 20446.661074459716,
 19534.791838194586,
 19314.976356747928,
 18473.81510240307,
 18277.162080188784,
 17502.705183952727,
 17339.363458354997,
 16628.073202256975,
 16499.07687148028,
 15841.22991586044,
 15737.46418837931,
 15118.30391270385,
 15025.003158807049,
 14431.593633340526,
 14335.374821084708,
 13762.797796660128

tensor(29011.3798, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:22<00:00,  2.02it/s]


[408951.9504333593,
 242189.0477336532,
 135895.95472568998,
 77619.85085363357,
 43449.08267772795,
 33440.37241749535,
 26479.994665543585,
 31881.91853172742,
 31315.004081589577,
 38976.235155223054,
 37880.51369692257,
 43927.43025022384,
 41048.139331015205,
 44985.95687157869,
 40840.45346459201,
 43107.5033404608,
 38517.38448591877,
 39658.302126663184,
 35288.667104415144,
 35743.34321730245,
 32001.226718328708,
 32084.315311625618,
 29131.37730747259,
 29049.83567395258,
 26852.254529984657,
 26723.020143096684,
 25122.292533929583,
 24992.63883466081,
 23785.620629563753,
 23661.26002280568,
 22667.98615552241,
 22541.922318128185,
 21641.768641120092,
 21511.440313043873,
 20645.519231774553,
 20514.779734262618,
 19667.884157379256,
 19541.50041420749,
 18720.01029405954,
 18599.314190787845,
 17815.668660305528,
 17698.20309866601,
 16963.204923894646,
 16844.721285929638,
 16164.65432279164,
 16041.156857454846,
 15417.763870413695,
 15286.741861701503,
 14718.52140062

tensor(33767.0035, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [12:38<00:00,  1.98it/s]


[401761.1314476526,
 237749.71927934932,
 133695.37392332556,
 76590.30284956995,
 43160.831615718314,
 33689.07215703108,
 26767.934624305766,
 32412.79151753307,
 31545.70752532589,
 39238.64765606704,
 37686.44299062689,
 43630.94130098277,
 40283.679572509245,
 44071.225123548465,
 39559.003947326244,
 41695.954729105084,
 36901.08386095674,
 37968.464100494595,
 33559.09941095246,
 33999.68944466178,
 30343.846797497794,
 30450.66404696236,
 27656.188047128475,
 27608.00590881005,
 25597.755169762117,
 25494.146005471117,
 24085.53405761451,
 23972.87647750406,
 22949.81019877648,
 22840.685340916643,
 22010.22123162952,
 21899.52188438607,
 21125.83356216343,
 21003.868725213528,
 20216.692020399198,
 20077.485214110897,
 19260.94268552921,
 19105.309466688086,
 18277.545923963607,
 18112.244491429017,
 17305.214283400204,
 17139.563951393986,
 16383.562387982258,
 16225.483199401158,
 15540.020578298147,
 15393.633781027198,
 14784.729889793703,
 14649.937795936716,
 14112.33892

tensor(24445.1088, dtype=torch.float64, grad_fn=<MulBackward0>)

100%|██████████| 1500/1500 [10:59<00:00,  2.28it/s]


[378734.99061484006,
 221388.52692496445,
 124093.94210618219,
 70815.86498528402,
 40378.8624747237,
 32681.80815242092,
 26509.20309608042,
 32767.118402832297,
 31707.343744890004,
 39465.03434359249,
 37541.1864334576,
 43341.15082263503,
 39769.40307272819,
 43365.28346126112,
 38844.56741267051,
 40807.676487081044,
 36174.93352945733,
 37106.78687794361,
 32949.54586528428,
 33294.989044265356,
 29899.02583224651,
 29942.47653323268,
 27348.301426629594,
 27252.227339181565,
 25344.79631896983,
 25192.02177776872,
 23789.978682259614,
 23618.115981545547,
 22540.35042673876,
 22363.577217308448,
 21465.273943353444,
 21286.216784028464,
 20470.10679847123,
 20287.454175717477,
 19500.24676481405,
 19313.937649635336,
 18536.03060226354,
 18349.505366858983,
 17583.14197695926,
 17402.472563177987,
 16660.77947414564,
 16491.952183421592,
 15790.170379045196,
 15636.971183419446,
 14986.907374002429,
 14850.253917811538,
 14257.426480344293,
 14135.463298941413,
 13598.1400788083

tensor(14735.8750, dtype=torch.float64, grad_fn=<MulBackward0>)

In [23]:
with open("/content/drive/MyDrive/Colab Notebooks/mse_batch", "wb") as fp:   #Pickling
  pickle.dump(mse_batch, fp)

with open("/content/drive/MyDrive/Colab Notebooks/indices_num_batch", "wb") as fp:   #Pickling
  pickle.dump(indices_batch, fp)

In [34]:
mse= sum(mse_batch)/sum(indices_batch)

In [35]:
mse

tensor(7.8160, dtype=torch.float64, grad_fn=<DivBackward0>)

In [36]:
rmse= np.sqrt(mse.detach())
rmse

tensor(2.7957, dtype=torch.float64)