In [1]:
from data_manager import *
#from SFTRL import SFTRL

In [2]:
# Dataset dimensions, hard-coded for the MovieLens-100k files loaded below.
# NOTE(review): the rating counts presumably match the ub.base / ub.test
# split sizes -- confirm against the data files.
nbUsers, nbMovies = 943, 1682

# Width of one feature row: the user count plus the movie count.
nbFeatures = nbUsers + nbMovies

nbRatingsTrain, nbRatingsTest = 90570, 9430


In [3]:
# Location of the rating files (relative to the notebook's working directory).
data_dir = './Data/ml-100k/'
filename1 = 'ub.base'
filename2 = './ub.test'  # not used in this cell

# Load the training split.
x_train, y_train, rate_train, timestamp_train = load_dataset(
    data_dir + filename1,
    nbRatingsTrain,
    nbFeatures,
    nbUsers,
)

# Reorder the samples with sort_dataset, using their timestamps.
x_train_s, rate_train_s, _ = sort_dataset(
    x_train,
    rate_train,
    timestamp_train,
)

# Densify the feature matrix and convert both inputs and targets to
# double-precision torch tensors.
# NOTE(review): `torch` is only imported by a later cell of this notebook; it
# is presumably brought into scope here by the star import from data_manager
# -- confirm, otherwise this cell fails on a fresh kernel.
inputs_matrix = torch.tensor(x_train_s.todense()).double()
outputs = torch.tensor(rate_train_s).double()


In [4]:
import torch
from torch.nn import Module
from torch.autograd import Variable

import numpy as np


# Tensor constructor used for freshly allocated buffers: torch.DoubleTensor,
# i.e. 64-bit floating point tensors.
tensor_type = torch.DoubleTensor


class SFTRL(Module):

    def __init__(self, inputs_matrix, outputs, option):

        super(SFTRL, self).__init__()

        #self.A = inputs_matrix
        self.At = inputs_matrix.t()
        self.b = outputs
        self._thres = 1e-12

        self.num_data = inputs_matrix.shape[0]
        self.num_feature = inputs_matrix.shape[1]

        self.task = option['task']
        self.eta = option['eta']
        self.m = option['m']

        self.row_count_p = 0
        self.row_count_n = 0

        self.BT_P = tensor_type(np.zeros([self.num_feature, 2*self.m]))
        self.BT_N = tensor_type(np.zeros([self.num_feature, 2*self.m]))



    def _loss(self, x):

        if self.task == 'reg' :
            return x**2
        elif self.task == 'cla':
            return 1 / (1 + torch.exp(x))
        else :
            return


    def _grad_loss(self, x):

        if self.task == 'reg' :
            return 2*x
        elif self.task == 'cla' :
            return -1 / (1 + torch.exp(x))
        else :
            return



    def _predict(self):

        return



    def online_learning(self):

        pred_list = []
        real_list = []
        
        for idx in range(self.num_data):
            alpha = self.At[:,idx]

            #print(alpha)

            BP_alpha = self.BT_P.t().matmul(alpha).unsqueeze(1)
            BN_alpha = self.BT_N.t().matmul(alpha).unsqueeze(1)

            scalar = BP_alpha.t().matmul(BP_alpha) - BN_alpha.t().matmul(BN_alpha)


            
            
            if idx % 10 ==0 :
               #print(' %d th : current'%(idx))
               #print(BP_alpha)
               print(' %d th : pred %f , real %f , loss %f ' %(idx,scalar,self.b[idx],self._loss(scalar - self.b[idx]) ) )

            if self.task == 'cls':
                sign_idx = self._grad_loss(scalar*self.b[idx])*self.b[idx]
            elif self.task == 'reg':
                sign_idx = self._grad_loss(scalar - self.b[idx])
            else :
                raise NotImplementedError


            self._GFD(sign_idx, alpha)

            
            pred_list.append(scalar)
            real_list.append(self.b[idx])

        return np.asarray(pred_list),real_list



    def _GFD(self, sign, alpha):

        # alpha : num_feature x 1
        # sign : scalar value


        if sign <= 0:
            self.row_count_p += 1


            self.BT_P[:,self.row_count_p] = np.sqrt(-self.eta*sign)*alpha.squeeze()


            if self.row_count_p == 2*self.m - 1  :

                U, Sigma, _ = self.BT_P.t().matmul( self.BT_P).svd()
                Sigma[Sigma.data <= self._thres] = 0.0
                nnz = Sigma.nonzero().numel()

                V = self.BT_P.matmul(U[:, :nnz]).matmul(  (1/Sigma[:nnz].sqrt()).diag()   )


                if nnz >= self.m:

                    self.BT_P = V[:, :self.m - 1].matmul((Sigma[:self.m - 1] - Sigma[self.m]).sqrt().diag())
                    self.BT_P = torch.cat([self.BT_P, torch.zeros([self.num_feature, self.m + 1]).double() ], 1)
                    self.row_count_p = self.m - 1


                else:
                    self.BT_P = V[:, :nnz].matmul((Sigma[:nnz]).sqrt().diag())
                    self.BT_P= torch.cat([self.BT_P, torch.zeros([self.num_feature, (2 * self.m) - nnz]).double() ], 1)
                    self.row_count_p = nnz

        else:

            if self.row_count_n == 2*self.m - 1:

                self.row_count_n += 1
                self.BT_N[:,self.row_count_n] = np.sqrt(self.eta*sign)*alpha.squeeze()

                U, Sigma, _ = self.BT_N.t().matmul(self.BT_N).svd()
                Sigma[Sigma.data <= self._thres] = 0.0
                nnz = Sigma.nonzero().numel()
                V = self.BT_N.matmul(U[:, :nnz]).matmul(  (1/Sigma[:nnz].sqrt()).diag()    )

                if nnz >= self.m:
                    self.BT_N = V[:, :self.m - 1].matmul((Sigma[:self.m - 1] - Sigma[self.m]).sqrt().diag())
                    self.BT_N = torch.cat([self.BT_N, torch.zeros([self.num_feature, self.m + 1 ]).double() ], 1)
                    self.row_count_n = self.m - 1

                else:
                    self.BT_N = V[:, :nnz].matmul((Sigma[:nnz]).sqrt().diag())
                    self.BT_N = torch.cat([self.BT_N, torch.zeros([self.num_feature, (2 * self.m) - nnz] ).double()] , 1)
                    self.row_count_n = nnz

        return



In [5]:
# Model configuration consumed by SFTRL.__init__:
#   'm'    - each internal buffer has 2 * m columns
#   'eta'  - scale applied to the loss gradient before a buffer update
#   'task' - loss selector ('reg' is the squared loss)
options = {
    'm': 20,
    'eta': 1e-2,
    'task': 'reg',
}

# Number of leading rows of the loaded training data handed to the model.
recent_num = 10000




In [6]:
# Build the learner on the first `recent_num` samples and run a single
# online pass over them; `pred` holds the per-sample predictions and `real`
# the matching targets.
Model = SFTRL(
    inputs_matrix=inputs_matrix[:recent_num],
    outputs=outputs[:recent_num],
    option=options,
)
pred, real = Model.online_learning()


 0 th : pred 0.000000 , real 4.000000 , loss 16.000000 
 10 th : pred 0.692885 , real 3.000000 , loss 5.322780 
 20 th : pred 0.474746 , real 4.000000 , loss 12.427413 
 30 th : pred 0.547890 , real 5.000000 , loss 19.821287 
 40 th : pred 1.182152 , real 4.000000 , loss 7.940265 
 50 th : pred 1.688988 , real 4.000000 , loss 5.340775 
 60 th : pred 1.995433 , real 4.000000 , loss 4.018291 
 70 th : pred 2.272217 , real 3.000000 , loss 0.529668 
 80 th : pred 2.608865 , real 3.000000 , loss 0.152987 
 90 th : pred 2.813845 , real 1.000000 , loss 3.290032 
 100 th : pred 3.008264 , real 3.000000 , loss 0.000068 
 110 th : pred 0.625942 , real 5.000000 , loss 19.132379 
 120 th : pred 0.582793 , real 5.000000 , loss 19.511715 
 130 th : pred 1.254607 , real 4.000000 , loss 7.537184 
 140 th : pred 1.791392 , real 5.000000 , loss 10.295168 
 150 th : pred 2.189768 , real 4.000000 , loss 3.276940 
 160 th : pred 2.507503 , real 3.000000 , loss 0.242554 
 170 th : pred 0.080000 , real 4.000

 1560 th : pred 3.444881 , real 4.000000 , loss 0.308157 
 1570 th : pred 3.552944 , real 5.000000 , loss 2.093971 
 1580 th : pred 3.623454 , real 4.000000 , loss 0.141787 
 1590 th : pred 3.710907 , real 2.000000 , loss 2.927202 
 1600 th : pred 3.738676 , real 4.000000 , loss 0.068290 
 1610 th : pred 3.779030 , real 4.000000 , loss 0.048828 
 1620 th : pred 3.800716 , real 4.000000 , loss 0.039714 
 1630 th : pred 2.414750 , real 1.000000 , loss 2.001517 
 1640 th : pred 2.442620 , real 4.000000 , loss 2.425433 
 1650 th : pred 2.490071 , real 2.000000 , loss 0.240169 
 1660 th : pred 2.500269 , real 1.000000 , loss 2.250808 
 1670 th : pred 2.510224 , real 1.000000 , loss 2.280777 
 1680 th : pred 2.527952 , real 1.000000 , loss 2.334637 
 1690 th : pred 3.796739 , real 1.000000 , loss 7.821751 
 1700 th : pred 0.649081 , real 5.000000 , loss 18.930496 
 1710 th : pred 1.157747 , real 4.000000 , loss 8.078403 
 1720 th : pred 1.677058 , real 4.000000 , loss 5.396061 
 1730 th : pr

 3260 th : pred 2.429283 , real 4.000000 , loss 2.467153 
 3270 th : pred 2.574350 , real 3.000000 , loss 0.181178 
 3280 th : pred 2.705046 , real 1.000000 , loss 2.907181 
 3290 th : pred 2.762558 , real 5.000000 , loss 5.006146 
 3300 th : pred 0.464825 , real 5.000000 , loss 20.567814 
 3310 th : pred 0.840175 , real 4.000000 , loss 9.984497 
 3320 th : pred 0.434629 , real 2.000000 , loss 2.450385 
 3330 th : pred 0.731416 , real 4.000000 , loss 10.683643 
 3340 th : pred 1.234235 , real 4.000000 , loss 7.649456 
 3350 th : pred 1.600133 , real 4.000000 , loss 5.759363 
 3360 th : pred 1.915704 , real 3.000000 , loss 1.175698 
 3370 th : pred 2.190448 , real 4.000000 , loss 3.274477 
 3380 th : pred 2.483961 , real 2.000000 , loss 0.234218 
 3390 th : pred 2.762860 , real 4.000000 , loss 1.530515 
 3400 th : pred 2.863457 , real 4.000000 , loss 1.291729 
 3410 th : pred 2.934943 , real 1.000000 , loss 3.744004 
 3420 th : pred 1.727568 , real 3.000000 , loss 1.619083 
 3430 th : p

 4940 th : pred 2.678989 , real 3.000000 , loss 0.103048 
 4950 th : pred 2.755497 , real 3.000000 , loss 0.059782 
 4960 th : pred 2.856518 , real 4.000000 , loss 1.307550 
 4970 th : pred 1.080321 , real 4.000000 , loss 8.524528 
 4980 th : pred 1.545535 , real 5.000000 , loss 11.933325 
 4990 th : pred 1.962294 , real 5.000000 , loss 9.227658 
 5000 th : pred 2.239317 , real 2.000000 , loss 0.057272 
 5010 th : pred 2.515005 , real 4.000000 , loss 2.205211 
 5020 th : pred 2.784315 , real 5.000000 , loss 4.909258 
 5030 th : pred 0.584077 , real 3.000000 , loss 5.836683 
 5040 th : pred 1.131328 , real 5.000000 , loss 14.966623 
 5050 th : pred 0.814959 , real 4.000000 , loss 10.144484 
 5060 th : pred 1.470624 , real 5.000000 , loss 12.456496 
 5070 th : pred 2.078104 , real 5.000000 , loss 8.537475 
 5080 th : pred 2.530576 , real 2.000000 , loss 0.281511 
 5090 th : pred 2.769643 , real 5.000000 , loss 4.974494 
 5100 th : pred 3.133977 , real 5.000000 , loss 3.482040 
 5110 th :

 6600 th : pred 1.008682 , real 3.000000 , loss 3.965346 
 6610 th : pred 1.289674 , real 3.000000 , loss 2.925214 
 6620 th : pred 1.510497 , real 3.000000 , loss 2.218618 
 6630 th : pred 3.696118 , real 4.000000 , loss 0.092345 
 6640 th : pred 3.707940 , real 4.000000 , loss 0.085299 
 6650 th : pred 2.105209 , real 3.000000 , loss 0.800651 
 6660 th : pred 3.701249 , real 4.000000 , loss 0.089252 
 6670 th : pred 0.399635 , real 3.000000 , loss 6.761896 
 6680 th : pred 0.782147 , real 4.000000 , loss 10.354579 
 6690 th : pred 1.201583 , real 3.000000 , loss 3.234302 
 6700 th : pred 1.402034 , real 4.000000 , loss 6.749426 
 6710 th : pred 1.726671 , real 4.000000 , loss 5.168026 
 6720 th : pred 1.881321 , real 3.000000 , loss 1.251443 
 6730 th : pred 1.991745 , real 2.000000 , loss 0.000068 
 6740 th : pred 2.236491 , real 5.000000 , loss 7.636981 
 6750 th : pred 2.648403 , real 5.000000 , loss 5.530010 
 6760 th : pred 2.911017 , real 4.000000 , loss 1.185884 
 6770 th : pr

 8280 th : pred 2.759190 , real 4.000000 , loss 1.539610 
 8290 th : pred 2.865933 , real 3.000000 , loss 0.017974 
 8300 th : pred 2.978988 , real 4.000000 , loss 1.042466 
 8310 th : pred 3.115397 , real 4.000000 , loss 0.782522 
 8320 th : pred 3.234632 , real 2.000000 , loss 1.524316 
 8330 th : pred 3.279598 , real 3.000000 , loss 0.078175 
 8340 th : pred 3.354830 , real 4.000000 , loss 0.416244 
 8350 th : pred 3.430592 , real 4.000000 , loss 0.324225 
 8360 th : pred 3.503372 , real 5.000000 , loss 2.239895 
 8370 th : pred 3.560334 , real 3.000000 , loss 0.313974 
 8380 th : pred 3.632434 , real 4.000000 , loss 0.135105 
 8390 th : pred 3.694746 , real 1.000000 , loss 7.261658 
 8400 th : pred 3.763251 , real 5.000000 , loss 1.529548 
 8410 th : pred 3.804500 , real 3.000000 , loss 0.647220 
 8420 th : pred 3.839613 , real 3.000000 , loss 0.704951 
 8430 th : pred 3.865562 , real 3.000000 , loss 0.749198 
 8440 th : pred 2.630023 , real 4.000000 , loss 1.876838 
 8450 th : pre

 9990 th : pred 2.343454 , real 3.000000 , loss 0.431052 


In [None]:
#pred

In [8]:
import matplotlib.pyplot as plt

# Plot the prediction made at each step of the online pass.  Axis labels and
# a title are set so the figure can be read without the surrounding code.
fig, ax = plt.subplots()
ax.plot(pred)
ax.set_xlabel('online step')
ax.set_ylabel('predicted rating')
ax.set_title('SFTRL online predictions')
plt.show()