# [0] Mission Description

In this session, we implement several models, evaluate them, and compare their performance.


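> **Learning objective**  
> Run the code yourself and check each model's performance, so that you experience first-hand the theory covered in the lectures and the strengths and weaknesses of each model.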

In the RecSys Basic Competition lectures, all practice sessions, missions, and the competition use the Book-Crossing dataset. The data is sourced from [Kaggle Book-Crossing](https://www.kaggle.com/datasets/ruchi798/bookcrossing-dataset) and has been restructured for this course. It is released under the CC0: Public Domain license.

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.autograd import Variable
import math
import tqdm
import pdb
from scipy.sparse import csr_matrix, linalg
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import warnings

warnings.filterwarnings(action='ignore')

# [1] Loading the Data

Load the data that was sampled in advance and stored at the given path.

In [None]:
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=17Zbpx_Yn0ggLFVeijUalcnOUt5mCHeEH' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=17Zbpx_Yn0ggLFVeijUalcnOUt5mCHeEH" -O users.csv && rm -rf ~/cookies.txt
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1JIxWPgC8JJkuZaWjk5FJBBKuqSX87HGE' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1JIxWPgC8JJkuZaWjk5FJBBKuqSX87HGE" -O books.csv && rm -rf ~/cookies.txt
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1B-8lw3a1KPJdhFAMXP58a7taJjpDvfw6' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1B-8lw3a1KPJdhFAMXP58a7taJjpDvfw6" -O ratings.csv && rm -rf ~/cookies.txt


In [None]:
books = pd.read_csv('./books.csv')
users = pd.read_csv('./users.csv')
ratings = pd.read_csv('./ratings.csv')

# [2] Preprocessing the Data

To keep the mission running smoothly, we use only a sample of the mission data.

In [None]:
seed=42

Some rows in books have an incorrect isbn, so we replace it with the ISBN embedded in img_url.

In [None]:
books['isbn'] = books['img_url'].apply(lambda x: x.split('P/')[1][:10])
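
The slicing above assumes that `img_url` values follow the Amazon image-URL pattern `.../images/P/<10-character ISBN>.01.<suffix>.jpg`. The text right after `P/` is then the 10-character ISBN. The URL below is a made-up example in that format and is not taken from the dataset:

```python
# Hypothetical img_url in the Amazon image-URL format; not a row from books.csv.
img_url = "http://images.amazon.com/images/P/0195153448.01.THUMBZZZ.jpg"

# Same expression as the notebook: take the text after 'P/' and keep the first 10 characters.
isbn = img_url.split('P/')[1][:10]
print(isbn)  # 0195153448
```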

To make the raw identifiers user_id and isbn easier to work with, we map each one to a consecutive integer index.

In [None]:
user2idx = {v:k for k,v in enumerate(ratings['user_id'].unique())}
book2idx = {v:k for k,v in enumerate(ratings['isbn'].unique())}

ratings['isbn'] = ratings['isbn'].map(book2idx)
ratings['user_id'] = ratings['user_id'].map(user2idx)
ratings

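| | user_id | isbn | rating |
|---:|---:|---:|---:|
| 0 | 0 | 0 | 1 |
| 1 | 0 | 1 | 5 |
| 2 | 1 | 2 | 1 |
| 3 | 1 | 3 | 7 |
| 4 | 1 | 4 | 5 |
| ... | ... | ... | ... |
| 164724 | 14109 | 8793 | 1 |
| 164725 | 14109 | 16652 | 7 |
| 164726 | 14109 | 2912 | 1 |
| 164727 | 14109 | 3184 | 8 |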
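
The `{v: k for k, v in enumerate(...unique())}` pattern gives each distinct ID a consecutive integer, in order of first appearance. Here is a toy sketch with made-up user IDs; the notebook applies the same pattern to `ratings['user_id']` and `ratings['isbn']`:

```python
import pandas as pd

# Made-up user IDs, standing in for ratings['user_id'].
ids = pd.Series([276725, 276726, 276725, 276729])

# unique() keeps first-appearance order, so the indices are 0, 1, 2, ...
id2idx = {v: k for k, v in enumerate(ids.unique())}

# map() replaces every raw ID with its integer index; repeated IDs share an index.
print(ids.map(id2idx).tolist())  # [0, 1, 0, 2]
```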


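Next, we split the data into train, valid, and test sets.

- Suppose user A rated books a, b, and c in the full dataset. If a and b are in the training data and c is in the test data, we treat c as unrated during training.
- We split ratings 6:2:2 into train:valid:test and keep each split in two forms: a matrix and a DataFrame.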

In [None]:
X_train, X_test_tmp, y_train, y_test_tmp = train_test_split(ratings[['user_id', 'isbn']], ratings['rating'], test_size=0.4, random_state=seed)
X_valid, X_test, y_valid, y_test = train_test_split(X_test_tmp, y_test_tmp, test_size=0.5, random_state=seed)
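
The two calls above produce the 6:2:2 ratio. The first split holds out 40%, and the second halves that 40% into 20% valid and 20% test. A quick check on 100 stand-in rows:

```python
import numpy as np
from sklearn.model_selection import train_test_split

data = np.arange(100)  # stand-in for 100 rating rows

# First split: 60% train, 40% held out.
train_part, rest = train_test_split(data, test_size=0.4, random_state=42)
# Second split: halve the held-out 40% into 20% valid and 20% test.
valid_part, test_part = train_test_split(rest, test_size=0.5, random_state=42)

print(len(train_part), len(valid_part), len(test_part))  # 60 20 20
```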

In [None]:
train = np.zeros((len(user2idx), len(book2idx)))
valid = np.zeros((len(user2idx), len(book2idx)))
test = np.zeros((len(user2idx), len(book2idx)))

In [None]:
train.shape

(14110, 20871)

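We convert the train, valid, and test splits into matrix form.  
We do not use pivot_table because some user_id or isbn values may be missing from a given split. A pivot table would then drop those rows or columns, and the three matrices would no longer have the same shape.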

In [None]:
def df_to_arr(X: pd.DataFrame, y: pd.Series, arr: np.ndarray) -> np.ndarray:
  for i, value in enumerate(X.values):
    arr[value[0], value[1]] += y.values[i]
  
  return arr

In [None]:
train = df_to_arr(X_train, y_train, train)
valid = df_to_arr(X_valid, y_valid, valid)
test = df_to_arr(X_test, y_test, test)
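
The loop in `df_to_arr` writes one rating per Python iteration. The sketch below is a vectorized alternative using NumPy's `np.add.at`. Like the `+=` in the loop, it accumulates when a (user, item) pair appears more than once. `df_to_arr_vectorized` is a hypothetical helper name, and the toy indices are made up:

```python
import numpy as np
import pandas as pd

def df_to_arr_vectorized(X: pd.DataFrame, y: pd.Series, arr: np.ndarray) -> np.ndarray:
    """Scatter-add ratings into arr[user, item]; X's first two columns are (user, item) indices."""
    rows = X.iloc[:, 0].to_numpy()
    cols = X.iloc[:, 1].to_numpy()
    np.add.at(arr, (rows, cols), y.to_numpy())  # unbuffered in-place add at (row, col) pairs
    return arr

# Toy data with made-up indices.
X = pd.DataFrame({'user_id': [0, 1, 1], 'isbn': [2, 0, 1]})
y = pd.Series([7, 5, 9])
mat = df_to_arr_vectorized(X, y, np.zeros((2, 3)))
print(mat)
```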

Next, we prepare the data for FM/FFM.  
- For FM/FFM, each feature is represented as an integer index in its own column.

In [None]:
def age_map(x: int) -> int:
    x = int(x)
    if x < 20:
        return 1
    elif x >= 20 and x < 30:
        return 2
    elif x >= 30 and x < 40:
        return 3
    elif x >= 40 and x < 50:
        return 4
    elif x >= 50 and x < 60:
        return 5
    else:
        return 6
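
The if/elif chain in `age_map` assigns half-open decade buckets. The sketch below does the same bucketing in one vectorized call with `pd.cut`, using `right=False` so that each bin is `[lower, upper)`. It assumes missing ages have already been filled, as the notebook does before calling `age_map`:

```python
import numpy as np
import pandas as pd

ages = pd.Series([5, 20, 29, 30, 45, 59, 60, 99])

# right=False makes each bin [lower, upper), matching the comparisons in age_map.
buckets = pd.cut(ages, bins=[-np.inf, 20, 30, 40, 50, 60, np.inf],
                 right=False, labels=[1, 2, 3, 4, 5, 6]).astype(int)
print(buckets.tolist())  # [1, 2, 2, 3, 4, 5, 6, 6]
```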

In [None]:
users['location_city'] = users['location'].apply(lambda x: x.split(',')[0])
users['location_state'] = users['location'].apply(lambda x: x.split(',')[1])
users['location_country'] = users['location'].apply(lambda x: x.split(',')[2])
users = users.drop(['location'], axis=1)

# ratings was re-indexed with user2idx, so apply the same mapping to users
users_context = users.copy()
users_context['user_id'] = users_context['user_id'].map(user2idx)

# ratings는 book2idx로 인덱싱 처리해줬기 때문에 books에도 진행
books_context = books.copy()
books_context['isbn'] = books_context['isbn'].map(book2idx)

# join the re-indexed data
context_df = ratings.merge(users_context, on='user_id', how='left').merge(books_context[['isbn', 'language']], on='isbn', how='left')

# map the location features to integer indices
loc_city2idx = {v:k for k,v in enumerate(context_df['location_city'].unique())}
loc_state2idx = {v:k for k,v in enumerate(context_df['location_state'].unique())}
loc_country2idx = {v:k for k,v in enumerate(context_df['location_country'].unique())}
context_df['location_city'] = context_df['location_city'].map(loc_city2idx)
context_df['location_state'] = context_df['location_state'].map(loc_state2idx)
context_df['location_country'] = context_df['location_country'].map(loc_country2idx)

context_df['age'] = context_df['age'].fillna(int(context_df['age'].mean()))
context_df['age'] = context_df['age'].apply(age_map)

language2idx = {v:k for k,v in enumerate(context_df['language'].unique())}
context_df['language'] = context_df['language'].map(language2idx)

context_df

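| | user_id | isbn | rating | age | location_city | location_state | location_country | language |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 0 | 0 | 0 | 1 | 3 | 0 | 0 | 0 | 0 |
| 1 | 0 | 1 | 5 | 3 | 0 | 0 | 0 | 0 |
| 2 | 1 | 2 | 1 | 3 | 1 | 1 | 1 | 0 |
| 3 | 1 | 3 | 7 | 3 | 1 | 1 | 1 | 0 |
| 4 | 1 | 4 | 5 | 3 | 1 | 1 | 1 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 164724 | 14109 | 8793 | 1 | 3 | 14 | 11 | 1 | 1 |
| 164725 | 14109 | 16652 | 7 | 3 | 14 | 11 | 1 | 1 |
| 164726 | 14109 | 2912 | 1 | 3 | 14 | 11 | 1 | 0 |
| 164727 | 14109 | 3184 | 8 | 3 | 14 | 11 | 1 | 0 |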


In [None]:
X_train_context, X_test_tmp_context, y_train_context, y_test_tmp_context = train_test_split(context_df.drop(['rating'], axis=1), context_df['rating'], test_size=0.4, random_state=seed)
X_valid_context, X_test_context, y_valid_context, y_test_context = train_test_split(X_test_tmp_context, y_test_tmp_context, test_size=0.5, random_state=seed)
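
This split and the earlier `ratings` split both use `random_state=seed`. The left merge keeps the `ratings` row order, so the inputs have the same number of rows in the same order. That means both splits select the same row positions, and the context features line up with the earlier split. A small check of this behaviour on stand-in arrays:

```python
import numpy as np
from sklearn.model_selection import train_test_split

a = np.arange(10)        # stands in for ratings
b = np.arange(10) * 100  # stands in for context_df: same length, same row order

a_train, a_test = train_test_split(a, test_size=0.4, random_state=42)
b_train, b_test = train_test_split(b, test_size=0.4, random_state=42)

# Same seed + same number of rows -> the same shuffled row positions in both splits.
print(np.array_equal(a_train * 100, b_train))  # True
```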

To keep RAM usage manageable, we delete variables we no longer need.

In [None]:
del ratings, users, books

We define a function that computes RMSE (Root Mean Square Error). Predictions below 0 are clipped to 0, and predictions above 10 are clipped to 10.
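
Written out, the metric computed by the functions below is the following. Here $r_n$ is a true rating, $\hat{r}_n$ the corresponding prediction, and $N$ the number of rated (user, book) pairs being evaluated:

$$
\mathrm{RMSE} = \sqrt{\frac{1}{N}\sum_{n=1}^{N}\bigl(r_n - \min(\max(\hat{r}_n, 0), 10)\bigr)^2}
$$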

In [None]:
def rmse(real: list, predict: list) -> float:
  pred = np.array(predict)
  pred[np.where(pred < 0)] = 0
  pred[np.where(pred > 10)] = 10
  
  return np.sqrt(np.mean((real-pred) ** 2))

In [None]:
def modify_range(rating: float) -> float:
  if rating < 0:
    return 0
  elif rating > 10:
    return 10
  else:
    return rating

In [None]:
def matrix_rmse(real_mat: np.ndarray, predict_mat: np.ndarray) -> float:
  indicator = np.stack(real_mat.nonzero(), axis=1)
  
  cost = 0
  for i, ind in enumerate(indicator):
    pred = predict_mat[ind[0], ind[1]]
    real = real_mat[ind[0], ind[1]]
    cost += pow(real - modify_range(pred), 2)

  return np.sqrt(cost/ len(indicator)) 
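
`matrix_rmse` loops over every nonzero cell in Python. The sketch below is a vectorized version under the same convention (a 0 in the real matrix means "not rated"), using a boolean mask and `np.clip`. `matrix_rmse_vectorized` is a hypothetical name, and the toy matrices are made up:

```python
import numpy as np

def matrix_rmse_vectorized(real_mat: np.ndarray, predict_mat: np.ndarray) -> float:
    """RMSE over the nonzero (rated) cells of real_mat, with predictions clipped to [0, 10]."""
    mask = real_mat != 0
    pred = np.clip(predict_mat[mask], 0, 10)
    return float(np.sqrt(np.mean((real_mat[mask] - pred) ** 2)))

real = np.array([[5.0, 0.0], [0.0, 3.0]])
pred = np.array([[4.0, 1.0], [2.0, 12.0]])  # 12 is clipped to 10
print(matrix_rmse_vectorized(real, pred))   # 5.0
```

Only the two rated cells count: the errors are 1 and -7, so the RMSE is sqrt((1 + 49) / 2) = 5.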

# [3] Mission

## SVD

Let's evaluate SVD.
- Full SVD reconstructs the original matrix exactly, so unrated entries come back as 0 and it is of no use for recommendation. We therefore use truncated SVD instead.

In [None]:
def truncated_svd(mat: np.ndarray, k: int=100):
  # keep only the k largest singular triplets of mat and multiply them back together
  u, s, vh = linalg.svds(csr_matrix(mat), k=k)
  return u @ np.diag(s) @ vh
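
Truncated SVD keeps only the k largest singular values. By the Eckart-Young theorem, the result is the best rank-k approximation in the Frobenius norm, and its error equals the square root of the sum of the discarded squared singular values. The check below runs on a small random matrix (illustrative only, not the notebook data). It compares `scipy.sparse.linalg.svds` against a full `np.linalg.svd`:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import svds

rng = np.random.default_rng(0)
A = rng.random((8, 6))
k = 2

# Truncated SVD via svds; the product u @ diag(s) @ vh does not depend on
# the order in which svds returns the k singular triplets.
u, s, vh = svds(csr_matrix(A), k=k)
A_k = u @ np.diag(s) @ vh

# Reference: full SVD (singular values sorted descending), keep the top-k triplets.
U, S, Vh = np.linalg.svd(A, full_matrices=False)
A_k_full = U[:, :k] @ np.diag(S[:k]) @ Vh[:k, :]

print(np.allclose(A_k, A_k_full, atol=1e-6))
# Eckart-Young: the approximation error is determined by the discarded singular values.
print(np.isclose(np.linalg.norm(A - A_k), np.sqrt(np.sum(S[k:] ** 2))))
```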

In [None]:
svd_output2 = truncated_svd(train, k=2)
svd_output5 = truncated_svd(train, k=5)
svd_output10 = truncated_svd(train, k=10)

In [None]:
print("k=2 : ", matrix_rmse(test, svd_output2))
print("k=5 : ", matrix_rmse(test, svd_output5))
print("k=10 : ", matrix_rmse(test, svd_output10))

k=2 :  6.982625105161013
k=5 :  6.964931540028439
k=10 :  6.955595642810603


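SVD performs poorly here because it does not learn latent factors by fitting the observed ratings (y). It is closer to a dimensionality-reduction technique based on matrix decomposition. Applied to the zero-filled `train` matrix, it also treats every unrated entry as a rating of 0, which tends to pull its reconstructed ratings toward 0.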
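
A 2×2 toy shows what fitting a zero-filled matrix does. The numbers are made up: the 0 marks an unrated pair whose "true" rating we imagine to be 4. `np.linalg.svd` is used here for simplicity instead of `svds`.

```python
import numpy as np

# Toy rating matrix: 0 = unrated. Suppose the missing rating would really be 4.
R = np.array([[4.0, 4.0],
              [4.0, 0.0]])

# Rank-1 truncated SVD of the zero-filled matrix.
U, S, Vh = np.linalg.svd(R)
R_1 = S[0] * np.outer(U[:, 0], Vh[0, :])

print(np.round(R_1, 2))  # approx [[4.68, 2.89], [2.89, 1.79]]
# The observed 4s are underestimated on average, and the unrated cell is pulled toward 0.
print(R_1[R != 0].mean() < 4, R_1[1, 1] < 2)
```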

## MF

In [None]:
class MatrixFactorization:
    def __init__(self, R: np.ndarray, k: int, lr: float, regularization: float, epochs: int, verbose: bool =False) -> None:
        """
        :param R: rating matrix
        :param k: latent parameter
        :param lr: learning rate
        :param regularization: regularization term for update
        :param epochs: training epochs
        :param verbose: print status
        """

        self._R = csr_matrix(R)
        self._ind, self._col = self._R.nonzero()
        self._n_users, self._n_items = R.shape
        self._k = k
        self._lr = lr
        self._regularization = regularization
        self._epochs = epochs
        self._verbose = verbose


    def fit(self) -> None:

        # latent features
        self._P = np.random.normal(size=(self._n_users, self._k))
        self._Q = np.random.normal(size=(self._n_items, self._k))

        # biases
        self._bu = np.zeros(self._n_users)
        self._bi = np.zeros(self._n_items)
        self._b = np.mean(self._R[self._R.nonzero()])

        # train while epochs
        self._training_process = []
        for epoch in range(self._epochs):

            for i in range(len(self._ind)):
              self.gradient_descent(self._ind[i], self._col[i], self._R[self._ind[i], self._col[i]])
            cost = self.cost()
            self._training_process.append((epoch, cost))

            # print status
            if self._verbose == True and ((epoch + 1) % 1 == 0):
                print("Iteration: %d ; cost = %.4f" % (epoch + 1, cost))


    def cost(self) -> float:
        """
        compute root mean square error
        :return: rmse cost
        """

        # xi, yi: row/column indices of the nonzero (observed) entries of R
        xi, yi = self._R.nonzero()
        cost = 0
        for x, y in zip(xi, yi):
            cost += pow(self._R[x, y] - self.predict(x, y), 2)
        return np.sqrt(cost / len(xi))


    def gradient(self, error: float, i: int, j: int) -> tuple:
        """
        gradient of latent feature for GD

        :param error: rating - prediction error
        :param i: user index
        :param j: item index
        :return: gradient of latent feature tuple
        """

        dp = (error * self._Q[j, :]) - (self._regularization * self._P[i, :])
        dq = (error * self._P[i, :]) - (self._regularization * self._Q[j, :])
        return dp, dq


    def gradient_descent(self, i: int, j: int, rating: int) -> None:
        """
        gradient descent function

        :param i: user index of matrix
        :param j: item index of matrix
        :param rating: rating of (i,j)
        """

        # get error
        prediction = self.predict(i, j)
        error = rating - prediction

        # update biases
        self._bu[i] += self._lr * (error - self._regularization * self._bu[i])
        self._bi[j] += self._lr * (error - self._regularization * self._bi[j])

        # update latent feature
        dp, dq = self.gradient(error, i, j)
        self._P[i, :] += self._lr * dp
        self._Q[j, :] += self._lr * dq


    def predict(self, i: int, j: int) -> float:
        """
        get predicted rating: user_i, item_j
        :return: prediction of r_ij
        """
        return self._b + self._bu[i] + self._bi[j] + self._P[i, :].dot(self._Q[j, :])


    def complete_matrix(self) -> np.ndarray:
        """
        compute the complete matrix: global bias + user bias + item bias + P x Q^T

        - adding _bu[:, np.newaxis] (shape n_users x 1) adds each user's bias across that user's row
        - adding _bi[np.newaxis, :] (shape 1 x n_items) adds each item's bias down that item's column
        - adding b adds the global bias to every element

        - np.newaxis inserts an axis so the 1-D bias vectors broadcast row-wise/column-wise against the 2-D R.

        :return: complete matrix R^
        """
        # self._P @ self._Q.T is the full (n_users x n_items) matrix of latent interactions
        return self._b + self._bu[:, np.newaxis] + self._bi[np.newaxis, :] + self._P @ self._Q.T


In [None]:
mf = MatrixFactorization(train, k=3, lr=0.01, regularization=0.01, epochs=10, verbose=True)
mf.fit()
mf.complete_matrix()

Iteration: 1 ; cost = 5.1086
Iteration: 2 ; cost = 4.9495
Iteration: 3 ; cost = 4.8493
Iteration: 4 ; cost = 4.7851
Iteration: 5 ; cost = 4.7446
Iteration: 6 ; cost = 4.7206
Iteration: 7 ; cost = 4.7089
Iteration: 8 ; cost = 4.7065
Iteration: 9 ; cost = 4.7116
Iteration: 10 ; cost = 4.7227


array([[ 9.24645479, 10.46988758,  9.75366269, ...,  9.49956015,
         9.89476271, 10.28093395],
       [ 8.50274883,  9.72618162,  9.00995673, ...,  8.75585419,
         9.15105675,  9.53722798],
       [ 9.16950315, 10.39293594,  9.67671104, ...,  9.42260851,
         9.81781106, 10.2039823 ],
       ...,
       [ 8.85479183, 10.07822462,  9.36199972, ...,  9.10789719,
         9.50309974,  9.88927098],
       [ 8.76483742,  9.98827021,  9.27204531, ...,  9.01794278,
         9.41314533,  9.79931657],
       [ 8.84594847, 10.06938126,  9.35315637, ...,  9.09905383,
         9.49425639,  9.88042763]])

In [None]:
print("mf : ", matrix_rmse(test, mf.complete_matrix()))

mf :  4.667692780977086


MF does considerably better than SVD: its test RMSE is about 4.67, compared with about 6.96 for SVD with k=10.
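
To see the update rule of `MatrixFactorization.gradient_descent` in isolation, here is a compact re-implementation on a made-up 4×4 matrix. The hyperparameters and initialization scale are illustrative assumptions, not the notebook's. For each observed rating, with error e = r_ui - r̂_ui, it applies the same updates as the class: b_u += lr·(e - reg·b_u), b_i += lr·(e - reg·b_i), p_u += lr·(e·q_i - reg·p_u), q_i += lr·(e·p_u - reg·q_i).

```python
import numpy as np

def train_biased_mf(R, k=2, lr=0.01, reg=0.01, epochs=200, seed=0):
    """SGD for biased MF on the nonzero entries of R; returns the training RMSE per epoch."""
    rng = np.random.default_rng(seed)
    n_users, n_items = R.shape
    P = rng.normal(scale=0.1, size=(n_users, k))
    Q = rng.normal(scale=0.1, size=(n_items, k))
    bu, bi = np.zeros(n_users), np.zeros(n_items)
    users, items = R.nonzero()
    b = R[users, items].mean()  # global mean of observed ratings

    def rmse():
        pred = b + bu[users] + bi[items] + np.sum(P[users] * Q[items], axis=1)
        return np.sqrt(np.mean((R[users, items] - pred) ** 2))

    history = [rmse()]
    for _ in range(epochs):
        for u, i in zip(users, items):
            e = R[u, i] - (b + bu[u] + bi[i] + P[u] @ Q[i])
            bu[u] += lr * (e - reg * bu[u])
            bi[i] += lr * (e - reg * bi[i])
            # Compute both gradients from the old P[u], Q[i] before updating either.
            dp = e * Q[i] - reg * P[u]
            dq = e * P[u] - reg * Q[i]
            P[u] += lr * dp
            Q[i] += lr * dq
        history.append(rmse())
    return history

R = np.array([[5, 3, 0, 1],
              [4, 0, 0, 1],
              [1, 1, 0, 5],
              [1, 0, 5, 4]], dtype=float)
history = train_biased_mf(R)
print(f"train RMSE: {history[0]:.3f} -> {history[-1]:.3f}")
```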

## ALS

In [None]:
class AlternatingLeastSquares:
    def __init__(self, R: np.ndarray, k: int, regularization: float, epochs: int, verbose: bool =False) -> None:
        """
        :param R: rating matrix
        :param k: latent parameter
        :param regularization: regularization term for update
        :param epochs: training epochs
        :param verbose: print status
        """
        self._R = csr_matrix(R)
        self._ind, self._col = self._R.nonzero()
        self._n_users, self._n_items = R.shape
        self._k = k
        self._regularization = regularization
        self._epochs = epochs
        self._verbose = verbose


    def fit(self) -> None:
        # init latent features
        self._users = np.random.normal(size=(self._n_users, self._k))
        self._items = np.random.normal(size=(self._n_items, self._k))

        # train while epochs
        self._training_process = []
        self._user_error = 0; self._item_error = 0; 
        for epoch in range(self._epochs):
            for i, Ri in enumerate(self._R):
                self._users[i] = self.user_latent(Ri)

            for j, Rj in enumerate(self._R.T):
                self._items[j] = self.item_latent(Rj)

            cost = self.cost()
            self._training_process.append((epoch, cost))

            # print status
            if self._verbose == True and ((epoch + 1) % 1 == 0):
                print("Iteration: %d ; cost = %.4f" % (epoch + 1, cost))


    def cost(self) -> float:
        """
        compute root mean square error
        :return: rmse cost
        """
        cost = 0
        for x, y in zip(self._ind, self._col):
            cost += pow(self._R[x, y] - self.predict(x, y), 2)
        return np.sqrt(cost / len(self._ind))


    def user_latent(self, Ri: csr_matrix) -> np.ndarray:
        """
        Solve the regularized least-squares problem for one user's latent vector, with item factors fixed.

        :param Ri: rating row of the user (1 x n_items sparse matrix)
        :return: the user's latent vector
        """
        du = linalg.spsolve((self._items.T @ (self._items)) + 
                            self._regularization * np.eye(self._k),
                            self._items.T @ (Ri.T)
                            ).T
        return du

    def item_latent(self, Rj: csr_matrix) -> np.ndarray:
        """
        Solve the regularized least-squares problem for one item's latent vector, with user factors fixed.

        :param Rj: rating column of the item (1 x n_users sparse matrix)
        :return: the item's latent vector
        """

        di = linalg.spsolve((self._users.T @ self._users) + 
                            self._regularization * np.eye(self._k),
                            self._users.T @ (Rj.T)
                            ).T
        return di


    def predict(self, i: int, j: int) -> float:
        """
        get predicted rating: user_i, item_j
        :return: prediction of r_ij
        """
        return self._users[i, :].dot(self._items[j, :].T)


    def complete_matrix(self) -> np.ndarray:
        """
        :return: complete matrix R^
        """
        return self._users.dot(self._items.T)

In [None]:
als = AlternatingLeastSquares(train, k=3, regularization=0.01, epochs=10, verbose=True)
als.fit()
als.complete_matrix()

Iteration: 1 ; cost = 6.9775
Iteration: 2 ; cost = 6.8765
Iteration: 3 ; cost = 6.8666
Iteration: 4 ; cost = 6.8631
Iteration: 5 ; cost = 6.8607
Iteration: 6 ; cost = 6.8584
Iteration: 7 ; cost = 6.8558
Iteration: 8 ; cost = 6.8536
Iteration: 9 ; cost = 6.8520
Iteration: 10 ; cost = 6.8510


array([[-1.20888568e-07,  3.23376885e-03, -2.87718478e-05, ...,
         3.10664864e-07, -1.94117205e-08, -6.53690895e-06],
       [ 6.90859843e-07,  4.00761727e-03,  1.26352055e-04, ...,
         3.87271605e-07,  5.18761658e-08,  1.65226262e-05],
       [ 9.59983721e-07,  3.35240032e-03,  2.07935379e-04, ...,
         3.11629066e-07,  7.68408726e-08,  2.60015667e-05],
       ...,
       [ 9.60437509e-06,  2.42331446e-03,  2.12934670e-03, ...,
         1.27012248e-07,  8.50627666e-07,  2.88840462e-04],
       [ 3.40984848e-07, -1.93916546e-04,  6.85929780e-05, ...,
        -1.89221732e-08,  3.12133516e-08,  1.02555153e-05],
       [ 2.51252277e-06, -1.71299374e-03,  5.12436079e-04, ...,
        -1.69801061e-07,  2.30496562e-07,  7.60561280e-05]])

In [None]:
print("als : ", matrix_rmse(test, als.complete_matrix()))

als :  6.973062646056052
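
`user_latent` and `item_latent` above solve the normal equations against the full rating row or column. As a result, every unrated entry enters the solve as a rating of 0, which is consistent with the near-zero values in `als.complete_matrix()` and a test RMSE close to SVD's. A common alternative for explicit ratings restricts each least-squares solve to the observed entries only. The sketch below does that; it is a swapped-in variant, not the notebook's class, and `als_observed`, the toy matrix, and the hyperparameters are made up. It also tracks the regularized objective, which exact alternating least squares never increases:

```python
import numpy as np

def als_observed(R, k=2, lam=0.1, iters=20, seed=0):
    """Explicit-feedback ALS sketch: each least-squares solve uses only observed (nonzero) entries."""
    rng = np.random.default_rng(seed)
    n_users, n_items = R.shape
    P = rng.normal(scale=0.1, size=(n_users, k))
    Q = rng.normal(scale=0.1, size=(n_items, k))
    mask = R != 0
    reg = lam * np.eye(k)

    def objective():
        # squared error on observed entries + L2 penalty on both factor matrices
        err = (R - P @ Q.T)[mask]
        return np.sum(err ** 2) + lam * (np.sum(P ** 2) + np.sum(Q ** 2))

    history = [objective()]
    for _ in range(iters):
        for u in range(n_users):
            rated = mask[u]                  # items rated by user u
            Qo = Q[rated]
            P[u] = np.linalg.solve(Qo.T @ Qo + reg, Qo.T @ R[u, rated])
        for i in range(n_items):
            raters = mask[:, i]              # users who rated item i
            Po = P[raters]
            Q[i] = np.linalg.solve(Po.T @ Po + reg, Po.T @ R[raters, i])
        history.append(objective())
    return P, Q, history

R = np.array([[5, 3, 0, 1],
              [4, 0, 0, 1],
              [1, 1, 0, 5],
              [1, 0, 0, 4],
              [0, 1, 5, 4]], dtype=float)
P, Q, history = als_observed(R)
print(f"objective: {history[0]:.3f} -> {history[-1]:.3f}")
print(np.round(P @ Q.T, 2))
```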


## FM

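The FM code below was set up for 0/1 (binary) prediction and ends with a sigmoid. Since we are predicting ratings, we removed the final sigmoid and train with MSE loss instead of BCE.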

In [None]:
def train_fm(model: nn.Module, optimizer: torch.optim.Optimizer, data_loader: DataLoader, criterion: nn.Module, device: str, log_interval: int=100) -> None:
    model.train()
    total_loss = 0
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (fields, target) in enumerate(tk0):
        fields, target = fields.to(device), target.to(device)
        y = model(fields)
        loss = criterion(y, target.float())
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % log_interval == 0:
            tk0.set_postfix(loss=total_loss / log_interval)
            total_loss = 0

def test_fm(model: nn.Module, data_loader: DataLoader, device: str) -> float:
    model.eval()
    targets, predicts = list(), list()
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            targets.extend(target.tolist())
            predicts.extend(y.tolist())
    # return roc_auc_score(targets, predicts)
    return rmse(targets, predicts)

In [None]:
class FactorizationMachine(nn.Module):

    def __init__(self, reduce_sum:bool=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x: torch.Tensor):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix
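
`FactorizationMachine.forward` relies on the identity $\sum_{i<j}\langle \mathbf{v}_i, \mathbf{v}_j\rangle = \frac{1}{2}\sum_{f}\bigl[(\sum_i v_{i,f})^2 - \sum_i v_{i,f}^2\bigr]$. This identity turns the O(n²k) pairwise sum into an O(nk) computation. Here each row $\mathbf{v}_i$ is one field's embedding; for the one-hot index features used in this notebook, the feature value multiplying it is 1. A NumPy check on random vectors for a single sample:

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(5, 3))  # 5 fields, embedding dim 3 (one sample)

# Naive pairwise sum of inner products over i < j.
naive = sum(V[i] @ V[j] for i in range(5) for j in range(i + 1, 5))

# Same quantity via square-of-sum minus sum-of-squares, as in FactorizationMachine.forward.
square_of_sum = V.sum(axis=0) ** 2
sum_of_square = (V ** 2).sum(axis=0)
fast = 0.5 * np.sum(square_of_sum - sum_of_square)

print(np.isclose(naive, fast))  # True
```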

In [None]:
class FeaturesEmbedding(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)
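
`FeaturesEmbedding` stores all fields in a single embedding table and shifts each field's local index by the total size of the fields before it. With made-up `field_dims = [3, 2, 4]`, the offsets are `[0, 3, 5]`. This is also why each local index must stay below its field's size: index 3 in a field of size 3 would land on the first row of the next field.

```python
import numpy as np

field_dims = np.array([3, 2, 4])
# Offset of each field = sum of the sizes of all preceding fields.
offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
print(offsets)  # [0 3 5]

x = np.array([[2, 1, 0],   # field-local indices for two samples
              [0, 0, 3]])
# Rows of the shared table: [2, 4, 5] and [0, 3, 8]
print(x + offsets[None, :])
```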

In [None]:
class FeaturesLinear(nn.Module):

    def __init__(self, field_dims: np.ndarray, output_dim: int=1):
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return torch.sum(self.fc(x), dim=1) + self.bias

In [None]:
class FactorizationMachineModel(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.linear = FeaturesLinear(field_dims)
        self.fm = FactorizationMachine(reduce_sum=True)
        self.output_linear = nn.Linear(1, 1, bias=False)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = self.linear(x) + self.fm(self.embedding(x))
        # return torch.sigmoid(x.squeeze(1))
        return x.squeeze(1)

In [None]:
######## Hyperparameter ########

batch_size = 256
data_shuffle = True
embed_dim = 8
epochs = 100
learning_rate = 0.01
weight_decay=1e-6
gpu_idx = 0

torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda:{}".format(gpu_idx) if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [None]:
# Convert the data into tensors that PyTorch's DataLoader can use.
# X_*_context and y_*_context come from the same split, so features and targets stay aligned.
train_dataset = TensorDataset(torch.LongTensor(np.array(X_train_context)), torch.IntTensor(np.array(y_train_context)))
valid_dataset = TensorDataset(torch.LongTensor(np.array(X_valid_context)), torch.IntTensor(np.array(y_valid_context)))
test_dataset = TensorDataset(torch.LongTensor(np.array(X_test_context)), torch.IntTensor(np.array(y_test_context)))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=data_shuffle)
# Evaluation sets do not need shuffling; RMSE does not depend on the order.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
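# age_map returns values 1-6, so the age field needs 7 slots (index 0 is simply unused);
# with only 6 slots, age 6 would spill into the first location_city embedding.
field_dims = np.array([len(user2idx), len(book2idx), 7, len(loc_city2idx), len(loc_state2idx), len(loc_country2idx), len(language2idx)], dtype=np.uint32)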

# criterion = torch.nn.BCELoss()
criterion = torch.nn.MSELoss()
model = FactorizationMachineModel(field_dims, embed_dim).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

for epoch in range(epochs):
    train_fm(model, optimizer, train_dataloader, criterion, device)
    rmse_score = test_fm(model, valid_dataloader, device)
    print('epoch:', epoch, 'validation: rmse:', rmse_score)

epoch: 0 validation: rmse: 5.407682254188486
epoch: 1 validation: rmse: 3.7913304863587247
epoch: 2 validation: rmse: 3.814326501128399
epoch: 3 validation: rmse: 3.6785202316028256
epoch: 4 validation: rmse: 3.730616125616463
epoch: 5 validation: rmse: 3.737225027399165
epoch: 6 validation: rmse: 3.758667189019526
epoch: 7 validation: rmse: 3.7763448044992205
epoch: 8 validation: rmse: 3.808201160594672
epoch: 9 validation: rmse: 3.8147862830588797
epoch: 10 validation: rmse: 3.831486181750756
epoch: 11 validation: rmse: 3.844232938701033
epoch: 12 validation: rmse: 3.855861070974356
epoch: 13 validation: rmse: 3.874597546496721
epoch: 14 validation: rmse: 3.879337572497283
epoch: 15 validation: rmse: 3.8931516131322397
epoch: 16 validation: rmse: 3.9037061924733596
epoch: 17 validation: rmse: 3.9127196646151807
epoch: 18 validation: rmse: 3.9202132339140903
epoch: 19 validation: rmse: 3.9318997759616217
epoch: 20 validation: rmse: 3.933137002624359
epoch: 21 validation: rmse: 3.947096330213175
epoch: 22 validation: rmse: 3.950094944823998
epoch: 23 validation: rmse: 3.959591896736254
epoch: 24 validation: rmse: 3.956025611053299
epoch: 25 validation: rmse: 3.9701516726644774
epoch: 26 validation: rmse: 3.969741181118648
epoch: 27 validation: rmse: 3.9762934811760355
epoch: 28 validation: rmse: 3.9812961056277647
epoch: 29 validation: rmse: 3.981729433970271
epoch: 30 validation: rmse: 3.9915833203120283
epoch: 31 validation: rmse: 3.990307060589058
epoch: 32 validation: rmse: 3.9975453187428744
epoch: 33 validation: rmse: 3.997993964695908
epoch: 34 validation: rmse: 4.00450535500737
epoch: 35 validation: rmse: 4.004693990817076
epoch: 36 validation: rmse: 4.003215036686865
epoch: 37 validation: rmse: 4.010793359717728
epoch: 38 validation: rmse: 4.008293328519693
epoch: 39 validation: rmse: 4.010600720556453
epoch: 40 validation: rmse: 4.010733425498682
epoch: 41 validation: rmse: 4.018150287407513
epoch: 42 validation: rmse: 4.018509893106595
epoch: 43 validation: rmse: 4.024425856026523
epoch: 44 validation: rmse: 4.025320284002436
epoch: 45 validation: rmse: 4.028153171467362
epoch: 46 validation: rmse: 4.030270689965807
epoch: 47 validation: rmse: 4.025623091014767
epoch: 48 validation: rmse: 4.032306791323745
epoch: 49 validation: rmse: 4.027745831411786
epoch: 50 validation: rmse: 4.040821627550177
epoch: 51 validation: rmse: 4.031630664070341
epoch: 52 validation: rmse: 4.036665041607955
epoch: 53 validation: rmse: 4.043917166573548
epoch: 54 validation: rmse: 4.0374814554638565
epoch: 55 validation: rmse: 4.033434111137018
epoch: 56 validation: rmse: 4.044663194208719
epoch: 57 validation: rmse: 4.039339802000822
epoch: 58 validation: rmse: 4.051677783960629
epoch: 59 validation: rmse: 4.043118174042672
epoch: 60 validation: rmse: 4.052873351810811
epoch: 61 validation: rmse: 4.042156495042388
epoch: 62 validation: rmse: 4.050772009817608
epoch: 63 validation: rmse: 4.0414178347657606
epoch: 64 validation: rmse: 4.053136213410811
epoch: 65 validation: rmse: 4.047714014294227
epoch: 66 validation: rmse: 4.053835526604043
epoch: 67 validation: rmse: 4.051007925449025
epoch: 68 validation: rmse: 4.052011341252866
epoch: 69 validation: rmse: 4.046813808679841


100%|██████████| 20/20 [00:00<00:00, 161.00it/s]
100%|██████████| 7/7 [00:00<00:00, 451.83it/s]


epoch: 70 validation: rmse: 4.056671993064103


100%|██████████| 20/20 [00:00<00:00, 179.31it/s]
100%|██████████| 7/7 [00:00<00:00, 406.45it/s]


epoch: 71 validation: rmse: 4.0538833632551246


100%|██████████| 20/20 [00:00<00:00, 176.72it/s]
100%|██████████| 7/7 [00:00<00:00, 378.32it/s]


epoch: 72 validation: rmse: 4.064279450714897


100%|██████████| 20/20 [00:00<00:00, 163.35it/s]
100%|██████████| 7/7 [00:00<00:00, 403.34it/s]


epoch: 73 validation: rmse: 4.054726708757726


100%|██████████| 20/20 [00:00<00:00, 171.25it/s]
100%|██████████| 7/7 [00:00<00:00, 458.02it/s]


epoch: 74 validation: rmse: 4.066062545098244


100%|██████████| 20/20 [00:00<00:00, 173.35it/s]
100%|██████████| 7/7 [00:00<00:00, 347.75it/s]


epoch: 75 validation: rmse: 4.057564454486017


100%|██████████| 20/20 [00:00<00:00, 174.22it/s]
100%|██████████| 7/7 [00:00<00:00, 347.31it/s]


epoch: 76 validation: rmse: 4.067679244272529


100%|██████████| 20/20 [00:00<00:00, 171.23it/s]
100%|██████████| 7/7 [00:00<00:00, 379.92it/s]


epoch: 77 validation: rmse: 4.0604911865983775


100%|██████████| 20/20 [00:00<00:00, 163.61it/s]
100%|██████████| 7/7 [00:00<00:00, 401.87it/s]


epoch: 78 validation: rmse: 4.059410238853197


100%|██████████| 20/20 [00:00<00:00, 166.34it/s]
100%|██████████| 7/7 [00:00<00:00, 400.96it/s]


epoch: 79 validation: rmse: 4.061712094134201


100%|██████████| 20/20 [00:00<00:00, 179.00it/s]
100%|██████████| 7/7 [00:00<00:00, 448.43it/s]


epoch: 80 validation: rmse: 4.063576069602544


100%|██████████| 20/20 [00:00<00:00, 175.33it/s]
100%|██████████| 7/7 [00:00<00:00, 439.46it/s]


epoch: 81 validation: rmse: 4.059309477876981


100%|██████████| 20/20 [00:00<00:00, 180.04it/s]
100%|██████████| 7/7 [00:00<00:00, 333.72it/s]


epoch: 82 validation: rmse: 4.0652341371052545


100%|██████████| 20/20 [00:00<00:00, 159.31it/s]
100%|██████████| 7/7 [00:00<00:00, 389.34it/s]


epoch: 83 validation: rmse: 4.060854671696189


100%|██████████| 20/20 [00:00<00:00, 169.90it/s]
100%|██████████| 7/7 [00:00<00:00, 427.05it/s]


epoch: 84 validation: rmse: 4.064048697854228


100%|██████████| 20/20 [00:00<00:00, 183.85it/s]
100%|██████████| 7/7 [00:00<00:00, 417.33it/s]


epoch: 85 validation: rmse: 4.064496228371447


100%|██████████| 20/20 [00:00<00:00, 185.12it/s]
100%|██████████| 7/7 [00:00<00:00, 383.30it/s]


epoch: 86 validation: rmse: 4.061823256752565


100%|██████████| 20/20 [00:00<00:00, 177.09it/s]
100%|██████████| 7/7 [00:00<00:00, 71.39it/s]


epoch: 87 validation: rmse: 4.062008498408222


100%|██████████| 20/20 [00:00<00:00, 136.36it/s]
100%|██████████| 7/7 [00:00<00:00, 432.78it/s]


epoch: 88 validation: rmse: 4.06477064800194


100%|██████████| 20/20 [00:00<00:00, 166.30it/s]
100%|██████████| 7/7 [00:00<00:00, 347.45it/s]


epoch: 89 validation: rmse: 4.060056425925819


100%|██████████| 20/20 [00:00<00:00, 190.74it/s]
100%|██████████| 7/7 [00:00<00:00, 404.94it/s]


epoch: 90 validation: rmse: 4.070256656257085


100%|██████████| 20/20 [00:00<00:00, 174.85it/s]
100%|██████████| 7/7 [00:00<00:00, 390.56it/s]


epoch: 91 validation: rmse: 4.058784924015037


100%|██████████| 20/20 [00:00<00:00, 179.18it/s]
100%|██████████| 7/7 [00:00<00:00, 376.92it/s]


epoch: 92 validation: rmse: 4.071768738759604


100%|██████████| 20/20 [00:00<00:00, 179.94it/s]
100%|██████████| 7/7 [00:00<00:00, 398.08it/s]


epoch: 93 validation: rmse: 4.066566876262739


100%|██████████| 20/20 [00:00<00:00, 187.17it/s]
100%|██████████| 7/7 [00:00<00:00, 350.08it/s]


epoch: 94 validation: rmse: 4.066058464670307


100%|██████████| 20/20 [00:00<00:00, 175.99it/s]
100%|██████████| 7/7 [00:00<00:00, 380.46it/s]


epoch: 95 validation: rmse: 4.064308722929516


100%|██████████| 20/20 [00:00<00:00, 162.73it/s]
100%|██████████| 7/7 [00:00<00:00, 383.55it/s]


epoch: 96 validation: rmse: 4.062928090954277


100%|██████████| 20/20 [00:00<00:00, 175.02it/s]
100%|██████████| 7/7 [00:00<00:00, 343.77it/s]


epoch: 97 validation: rmse: 4.065147316276448


100%|██████████| 20/20 [00:00<00:00, 177.42it/s]
100%|██████████| 7/7 [00:00<00:00, 392.98it/s]


epoch: 98 validation: rmse: 4.066384805605086


100%|██████████| 20/20 [00:00<00:00, 176.88it/s]
100%|██████████| 7/7 [00:00<00:00, 400.66it/s]

epoch: 99 validation: rmse: 4.064593470305713





In [None]:
test_fm(model, test_dataloader, device)

100%|██████████| 7/7 [00:00<00:00, 419.42it/s]


3.9932752760165418

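FM did not perform well. Its test RMSE was about 3.99, and its validation RMSE crept upward over the later epochs (about 4.02 at epoch 42, about 4.06 at epoch 99), which suggests overfitting.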
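The RMSE reported by `test_fm` is easier to judge next to a trivial baseline. The sketch below uses two hypothetical NumPy helpers that are not part of this notebook. One computes RMSE and the other computes the RMSE of always predicting the mean training rating. It assumes `test_fm` reports the square root of the mean squared error, as its printed `rmse` label suggests. The ratings in the usage lines are made up for illustration.

```python
import numpy as np


def rmse(y_true, y_pred):
    """Root mean squared error between two equally sized rating arrays."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


def mean_baseline_rmse(train_ratings, test_ratings):
    """RMSE of predicting the mean training rating for every test example."""
    mean_rating = float(np.mean(train_ratings))
    return rmse(test_ratings, np.full(len(test_ratings), mean_rating))


# Toy ratings for illustration only; not the Book-Crossing data.
print(rmse([1, 2, 3], [1, 2, 5]))          # sqrt(4/3), about 1.1547
print(mean_baseline_rmse([7, 8, 9, 6], [5, 10, 8]))
```

If a model's test RMSE is not clearly below this baseline on the same split, its learned interactions are adding little.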

## FFM

The original FFM, like the original FM, produces a binary (0/1) prediction. Here we removed the final sigmoid so that the model predicts ratings directly.
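For reference, here is the standard FFM prediction. Each active feature $i$ belongs to field $f_i$, and there is one latent vector $\mathbf{v}_{i,f}$ for every (feature, field) pair. The linear part $w_0 + \sum_i w_i x_i$ is presumably what `FeaturesLinear` computes.

$$
\hat{y}(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_{i,f_j}, \mathbf{v}_{j,f_i} \rangle \, x_i x_j
$$

The binary formulation outputs $\sigma(\hat{y})$, while this notebook returns $\hat{y}$ as the predicted rating. Each field has exactly one active feature, so the double sum reduces to the $F(F-1)/2$ field pairs enumerated by the `i < j` loop in `FieldAwareFactorizationMachine.forward`. In that loop, `xs[j][:, i]` plays the role of $\mathbf{v}_{i,f_j}$ and `xs[i][:, j]` plays the role of $\mathbf{v}_{j,f_i}$.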

In [None]:
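class FieldAwareFactorizationMachine(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.num_fields = len(field_dims)
        # One embedding table per field: embeddings[f] holds every feature's
        # latent vector as seen by field f, i.e. v_{i, f}.
        self.embeddings = torch.nn.ModuleList([
            torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields)
        ])
        # Start index of each field in the shared index space.
        # np.long is deprecated/removed in recent NumPy, so use an explicit 64-bit dtype.
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        for embedding in self.embeddings:
            torch.nn.init.xavier_uniform_(embedding.weight.data)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: tensor of size ``(batch_size, num_fields * (num_fields - 1) // 2, embed_dim)``
        """
        # Shift per-field indices into the shared index space.
        x = x + x.new_tensor(self.offsets, dtype=torch.long).unsqueeze(0)
        # xs[f]: (batch_size, num_fields, embed_dim), all features embedded w.r.t. field f
        xs = [self.embeddings[i](x) for i in range(self.num_fields)]
        ix = list()
        for i in range(self.num_fields - 1):
            for j in range(i + 1, self.num_fields):
                # element-wise v_{i, f_j} * v_{j, f_i}; summed over embed_dim later
                ix.append(xs[j][:, i] * xs[i][:, j])
        ix = torch.stack(ix, dim=1)
        return ix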
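The indexing in `forward` is easy to misread. The NumPy-only sketch below is an illustration, not the notebook's code, and it mirrors that indexing with toy sizes. First, the offsets shift each field's local index into one shared index space. The array `[0 6764 15565 ...]` printed by the training cell appears to be exactly such a cumulative-sum offset vector. Second, each field pair (i, j) multiplies feature i's embedding for field j with feature j's embedding for field i. The field sizes, embedding dimension and random tables are made up.

```python
import numpy as np


def field_offsets(field_dims):
    """Start index of each field inside one shared embedding table."""
    return np.concatenate(([0], np.cumsum(field_dims)[:-1])).astype(np.int64)


def ffm_pairwise(x_local, field_dims, tables):
    """Field-aware pairwise products, mirroring FieldAwareFactorizationMachine.forward.

    x_local: (batch, num_fields) per-field indices.
    tables:  list of num_fields arrays, each (sum(field_dims), embed_dim);
             tables[f] holds every feature's embedding as seen by field f.
    Returns (batch, num_pairs, embed_dim) with num_pairs = F * (F - 1) // 2.
    """
    x = x_local + field_offsets(field_dims)[None, :]   # global indices
    xs = [table[x] for table in tables]                # each: (batch, F, k)
    num_fields = len(field_dims)
    pairs = []
    for i in range(num_fields - 1):
        for j in range(i + 1, num_fields):
            # v_{i, f_j} * v_{j, f_i}, element-wise over the embedding dim
            pairs.append(xs[j][:, i] * xs[i][:, j])
    return np.stack(pairs, axis=1)


rng = np.random.default_rng(0)
field_dims = np.array([4, 3, 2])              # toy: 3 fields
embed_dim = 5
tables = [rng.normal(size=(field_dims.sum(), embed_dim)) for _ in field_dims]
x_local = np.array([[0, 2, 1], [3, 0, 0]])    # 2 examples

print(field_offsets(field_dims))              # [0 4 7]
out = ffm_pairwise(x_local, field_dims, tables)
print(out.shape)                              # (2, 3, 5)
```

The notebook has 7 fields, so each example yields 7 * 6 / 2 = 21 pair vectors. The model then sums them over pairs and over the embedding dimension.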

In [None]:
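class FieldAwareFactorizationMachineModel(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.linear = FeaturesLinear(field_dims)  # linear part (FeaturesLinear is defined earlier in the notebook)
        self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim)
        self.output_linear = nn.Linear(1, 1, bias=False)  # note: not used in forward()

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: predicted ratings of size ``(batch_size,)``
        """
        # Sum over field pairs, then over the embedding dimension -> (batch_size, 1)
        ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True)
        x = self.linear(x) + ffm_term
        # The original FFM applies a sigmoid for 0/1 targets; it is dropped here for rating regression.
        # return torch.sigmoid(x.squeeze(1))
        return x.squeeze(1)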
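Dropping the sigmoid matters because $\sigma(z)$ always lies strictly between 0 and 1. With an MSE loss on ratings larger than 1, the error could therefore never approach zero. One alternative that keeps a bounded output is to rescale the sigmoid to the rating range. This notebook does not try it. The sketch below uses a hypothetical NumPy helper, `scaled_sigmoid`, and assumes a 1 to 10 scale. That is the explicit-rating range of the original Book-Crossing data, but the reconstructed data used here may differ.

```python
import numpy as np


def scaled_sigmoid(z, low=1.0, high=10.0):
    """Map raw scores z into [low, high] via a rescaled sigmoid (assumed 1-10 scale)."""
    z = np.asarray(z, dtype=float)
    return low + (high - low) / (1.0 + np.exp(-z))


raw_scores = np.array([-3.0, 0.0, 3.0])
print(scaled_sigmoid(raw_scores))   # middle value is exactly 5.5
```

The equivalent inside `FieldAwareFactorizationMachineModel.forward` would be `low + (high - low) * torch.sigmoid(x.squeeze(1))`. Whether it improves RMSE here is untested.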

In [None]:
device = torch.device("cuda:{}".format(gpu_idx) if torch.cuda.is_available() else "cpu")
print(device)
# Number of categories in each input field, in the column order of the input tensor
field_dims = np.array([len(user2idx), len(book2idx), 6, len(loc_city2idx), len(loc_state2idx), len(loc_country2idx), len(language2idx)], dtype=np.uint32)

# criterion = torch.nn.BCELoss()  # for the original 0/1 (binary) setting
criterion = torch.nn.MSELoss()    # rating regression
model = FieldAwareFactorizationMachineModel(field_dims, embed_dim).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

# Train for a fixed number of epochs, reporting validation RMSE after each one
for epoch in range(epochs):
    train_fm(model, optimizer, train_dataloader, criterion, device)
    rmse_score = test_fm(model, valid_dataloader, device)
    print('epoch:', epoch, 'validation: rmse:', rmse_score)

cpu
[    0  6764 15565 15571 18390 18760 18848]


100%|██████████| 39/39 [00:00<00:00, 62.67it/s]
100%|██████████| 13/13 [00:00<00:00, 314.39it/s]


epoch: 0 validation: rmse: 3.690340814707598


100%|██████████| 39/39 [00:00<00:00, 65.37it/s]
100%|██████████| 13/13 [00:00<00:00, 292.91it/s]


epoch: 1 validation: rmse: 3.406186943916107


100%|██████████| 39/39 [00:00<00:00, 65.14it/s]
100%|██████████| 13/13 [00:00<00:00, 311.55it/s]


epoch: 2 validation: rmse: 3.6194663689891304


100%|██████████| 39/39 [00:00<00:00, 64.68it/s]
100%|██████████| 13/13 [00:00<00:00, 297.18it/s]


epoch: 3 validation: rmse: 3.6626648405021918


100%|██████████| 39/39 [00:00<00:00, 65.15it/s]
100%|██████████| 13/13 [00:00<00:00, 288.53it/s]


epoch: 4 validation: rmse: 3.7090349221289785


100%|██████████| 39/39 [00:00<00:00, 65.04it/s]
100%|██████████| 13/13 [00:00<00:00, 294.44it/s]


epoch: 5 validation: rmse: 3.685722393025077


100%|██████████| 39/39 [00:00<00:00, 55.88it/s]
100%|██████████| 13/13 [00:00<00:00, 286.59it/s]


epoch: 6 validation: rmse: 3.7009752859561975


100%|██████████| 39/39 [00:00<00:00, 52.48it/s]
100%|██████████| 13/13 [00:00<00:00, 199.62it/s]


epoch: 7 validation: rmse: 3.675792244667678


100%|██████████| 39/39 [00:00<00:00, 57.32it/s]
100%|██████████| 13/13 [00:00<00:00, 266.27it/s]


epoch: 8 validation: rmse: 3.6826994510558886


100%|██████████| 39/39 [00:00<00:00, 69.53it/s]
100%|██████████| 13/13 [00:00<00:00, 329.67it/s]


epoch: 9 validation: rmse: 3.6678910704878884


100%|██████████| 39/39 [00:00<00:00, 68.46it/s]
100%|██████████| 13/13 [00:00<00:00, 315.79it/s]


epoch: 10 validation: rmse: 3.671831145208015


100%|██████████| 39/39 [00:00<00:00, 67.24it/s]
100%|██████████| 13/13 [00:00<00:00, 321.92it/s]


epoch: 11 validation: rmse: 3.6661280474866205


100%|██████████| 39/39 [00:00<00:00, 67.50it/s]
100%|██████████| 13/13 [00:00<00:00, 329.18it/s]


epoch: 12 validation: rmse: 3.660840363850697


100%|██████████| 39/39 [00:00<00:00, 67.79it/s]
100%|██████████| 13/13 [00:00<00:00, 309.28it/s]


epoch: 13 validation: rmse: 3.659725261413617


100%|██████████| 39/39 [00:00<00:00, 68.67it/s]
100%|██████████| 13/13 [00:00<00:00, 304.30it/s]


epoch: 14 validation: rmse: 3.6520712924045635


100%|██████████| 39/39 [00:00<00:00, 68.24it/s]
100%|██████████| 13/13 [00:00<00:00, 249.22it/s]


epoch: 15 validation: rmse: 3.6546791892855284


100%|██████████| 39/39 [00:00<00:00, 67.60it/s]
100%|██████████| 13/13 [00:00<00:00, 280.52it/s]


epoch: 16 validation: rmse: 3.6458487885713806


100%|██████████| 39/39 [00:00<00:00, 68.15it/s]
100%|██████████| 13/13 [00:00<00:00, 306.95it/s]


epoch: 17 validation: rmse: 3.6482365764401625


100%|██████████| 39/39 [00:00<00:00, 67.03it/s]
100%|██████████| 13/13 [00:00<00:00, 319.55it/s]


epoch: 18 validation: rmse: 3.6408689130172704


100%|██████████| 39/39 [00:00<00:00, 67.69it/s]
100%|██████████| 13/13 [00:00<00:00, 314.23it/s]


epoch: 19 validation: rmse: 3.6462101164284135


100%|██████████| 39/39 [00:00<00:00, 67.91it/s]
100%|██████████| 13/13 [00:00<00:00, 294.53it/s]


epoch: 20 validation: rmse: 3.6369251073405238


100%|██████████| 39/39 [00:00<00:00, 66.35it/s]
100%|██████████| 13/13 [00:00<00:00, 299.39it/s]


epoch: 21 validation: rmse: 3.6354874813015265


100%|██████████| 39/39 [00:00<00:00, 67.01it/s]
100%|██████████| 13/13 [00:00<00:00, 305.84it/s]


epoch: 22 validation: rmse: 3.6369034641826437


100%|██████████| 39/39 [00:00<00:00, 57.91it/s]
100%|██████████| 13/13 [00:00<00:00, 316.77it/s]


epoch: 23 validation: rmse: 3.6297044434051413


100%|██████████| 39/39 [00:00<00:00, 67.31it/s]
100%|██████████| 13/13 [00:00<00:00, 307.82it/s]


epoch: 24 validation: rmse: 3.628826129082836


100%|██████████| 39/39 [00:00<00:00, 68.16it/s]
100%|██████████| 13/13 [00:00<00:00, 296.13it/s]


epoch: 25 validation: rmse: 3.6232098634128227


100%|██████████| 39/39 [00:00<00:00, 66.88it/s]
100%|██████████| 13/13 [00:00<00:00, 315.18it/s]


epoch: 26 validation: rmse: 3.625894556915917


100%|██████████| 39/39 [00:00<00:00, 62.46it/s]
100%|██████████| 13/13 [00:00<00:00, 326.80it/s]


epoch: 27 validation: rmse: 3.625163034632328


100%|██████████| 39/39 [00:00<00:00, 53.36it/s]
100%|██████████| 13/13 [00:00<00:00, 301.81it/s]


epoch: 28 validation: rmse: 3.6171495312040483


100%|██████████| 39/39 [00:00<00:00, 41.37it/s]
100%|██████████| 13/13 [00:00<00:00, 276.29it/s]


epoch: 29 validation: rmse: 3.6241267191474384


100%|██████████| 39/39 [00:01<00:00, 37.44it/s]
100%|██████████| 13/13 [00:00<00:00, 306.12it/s]


epoch: 30 validation: rmse: 3.6139350499114586


100%|██████████| 39/39 [00:01<00:00, 34.17it/s]
100%|██████████| 13/13 [00:00<00:00, 345.87it/s]


epoch: 31 validation: rmse: 3.61223290821088


100%|██████████| 39/39 [00:01<00:00, 30.58it/s]
100%|██████████| 13/13 [00:00<00:00, 279.20it/s]


epoch: 32 validation: rmse: 3.616586015389627


100%|██████████| 39/39 [00:01<00:00, 30.04it/s]
100%|██████████| 13/13 [00:00<00:00, 305.61it/s]


epoch: 33 validation: rmse: 3.6158056120557776


100%|██████████| 39/39 [00:01<00:00, 27.49it/s]
100%|██████████| 13/13 [00:00<00:00, 309.65it/s]


epoch: 34 validation: rmse: 3.607499003860395


100%|██████████| 39/39 [00:01<00:00, 30.04it/s]
100%|██████████| 13/13 [00:00<00:00, 290.82it/s]


epoch: 35 validation: rmse: 3.6121682750844424


100%|██████████| 39/39 [00:01<00:00, 29.38it/s]
100%|██████████| 13/13 [00:00<00:00, 296.74it/s]


epoch: 36 validation: rmse: 3.6075447495679884


100%|██████████| 39/39 [00:01<00:00, 28.25it/s]
100%|██████████| 13/13 [00:00<00:00, 316.03it/s]


epoch: 37 validation: rmse: 3.611345896136915


100%|██████████| 39/39 [00:01<00:00, 32.34it/s]
100%|██████████| 13/13 [00:00<00:00, 299.37it/s]


epoch: 38 validation: rmse: 3.60643365709017


100%|██████████| 39/39 [00:01<00:00, 35.51it/s]
100%|██████████| 13/13 [00:00<00:00, 272.26it/s]


epoch: 39 validation: rmse: 3.6035101610516045


100%|██████████| 39/39 [00:01<00:00, 35.96it/s]
100%|██████████| 13/13 [00:00<00:00, 308.42it/s]


epoch: 40 validation: rmse: 3.6082227121395998


100%|██████████| 39/39 [00:01<00:00, 33.16it/s]
100%|██████████| 13/13 [00:00<00:00, 306.15it/s]


epoch: 41 validation: rmse: 3.5996454115720207


100%|██████████| 39/39 [00:01<00:00, 35.66it/s]
100%|██████████| 13/13 [00:00<00:00, 323.90it/s]


epoch: 42 validation: rmse: 3.609364619144304


100%|██████████| 39/39 [00:01<00:00, 35.54it/s]
100%|██████████| 13/13 [00:00<00:00, 307.33it/s]


epoch: 43 validation: rmse: 3.599863803609487


100%|██████████| 39/39 [00:01<00:00, 35.84it/s]
100%|██████████| 13/13 [00:00<00:00, 322.58it/s]


epoch: 44 validation: rmse: 3.6029764920235245


100%|██████████| 39/39 [00:01<00:00, 36.02it/s]
100%|██████████| 13/13 [00:00<00:00, 328.92it/s]


epoch: 45 validation: rmse: 3.5973001803697233


100%|██████████| 39/39 [00:01<00:00, 35.92it/s]
100%|██████████| 13/13 [00:00<00:00, 292.00it/s]


epoch: 46 validation: rmse: 3.602996914073656


100%|██████████| 39/39 [00:01<00:00, 35.93it/s]
100%|██████████| 13/13 [00:00<00:00, 292.70it/s]


epoch: 47 validation: rmse: 3.5968643139193133


100%|██████████| 39/39 [00:01<00:00, 36.42it/s]
100%|██████████| 13/13 [00:00<00:00, 302.04it/s]


epoch: 48 validation: rmse: 3.601199067790042


100%|██████████| 39/39 [00:01<00:00, 36.01it/s]
100%|██████████| 13/13 [00:00<00:00, 256.37it/s]


epoch: 49 validation: rmse: 3.5971006276765842


100%|██████████| 39/39 [00:01<00:00, 35.30it/s]
100%|██████████| 13/13 [00:00<00:00, 322.68it/s]


epoch: 50 validation: rmse: 3.6005540492777537


100%|██████████| 39/39 [00:01<00:00, 35.93it/s]
100%|██████████| 13/13 [00:00<00:00, 293.45it/s]


epoch: 51 validation: rmse: 3.595516230535131


100%|██████████| 39/39 [00:01<00:00, 35.94it/s]
100%|██████████| 13/13 [00:00<00:00, 285.83it/s]


epoch: 52 validation: rmse: 3.597676981637772


100%|██████████| 39/39 [00:01<00:00, 36.12it/s]
100%|██████████| 13/13 [00:00<00:00, 314.97it/s]


epoch: 53 validation: rmse: 3.5958662051829826


100%|██████████| 39/39 [00:01<00:00, 35.93it/s]
100%|██████████| 13/13 [00:00<00:00, 303.36it/s]


epoch: 54 validation: rmse: 3.597913913617555


100%|██████████| 39/39 [00:01<00:00, 36.07it/s]
100%|██████████| 13/13 [00:00<00:00, 317.80it/s]


epoch: 55 validation: rmse: 3.5955469713795503


100%|██████████| 39/39 [00:01<00:00, 36.06it/s]
100%|██████████| 13/13 [00:00<00:00, 286.34it/s]


epoch: 56 validation: rmse: 3.596195972187346


100%|██████████| 39/39 [00:01<00:00, 36.17it/s]
100%|██████████| 13/13 [00:00<00:00, 285.36it/s]


epoch: 57 validation: rmse: 3.5958882196808584


100%|██████████| 39/39 [00:01<00:00, 33.02it/s]
100%|██████████| 13/13 [00:00<00:00, 275.59it/s]


epoch: 58 validation: rmse: 3.593807222962438


100%|██████████| 39/39 [00:01<00:00, 35.97it/s]
100%|██████████| 13/13 [00:00<00:00, 325.22it/s]


epoch: 59 validation: rmse: 3.5923548477707863


100%|██████████| 39/39 [00:01<00:00, 36.04it/s]
100%|██████████| 13/13 [00:00<00:00, 303.00it/s]


epoch: 60 validation: rmse: 3.5928497577819316


100%|██████████| 39/39 [00:01<00:00, 35.45it/s]
100%|██████████| 13/13 [00:00<00:00, 278.94it/s]


epoch: 61 validation: rmse: 3.5918025936664497


100%|██████████| 39/39 [00:01<00:00, 35.29it/s]
100%|██████████| 13/13 [00:00<00:00, 300.90it/s]


epoch: 62 validation: rmse: 3.5928699898856538


100%|██████████| 39/39 [00:01<00:00, 35.71it/s]
100%|██████████| 13/13 [00:00<00:00, 261.15it/s]


epoch: 63 validation: rmse: 3.59489634472236


100%|██████████| 39/39 [00:01<00:00, 35.46it/s]
100%|██████████| 13/13 [00:00<00:00, 293.90it/s]


epoch: 64 validation: rmse: 3.5904114556862416


100%|██████████| 39/39 [00:01<00:00, 36.29it/s]
100%|██████████| 13/13 [00:00<00:00, 312.55it/s]


epoch: 65 validation: rmse: 3.592676021742504


100%|██████████| 39/39 [00:01<00:00, 35.84it/s]
100%|██████████| 13/13 [00:00<00:00, 277.18it/s]


epoch: 66 validation: rmse: 3.588157565217664


100%|██████████| 39/39 [00:01<00:00, 35.84it/s]
100%|██████████| 13/13 [00:00<00:00, 275.76it/s]


epoch: 67 validation: rmse: 3.5938016199984784


100%|██████████| 39/39 [00:01<00:00, 35.92it/s]
100%|██████████| 13/13 [00:00<00:00, 342.59it/s]


epoch: 68 validation: rmse: 3.5893928631722156


100%|██████████| 39/39 [00:01<00:00, 35.83it/s]
100%|██████████| 13/13 [00:00<00:00, 323.84it/s]


epoch: 69 validation: rmse: 3.5916197265097254


100%|██████████| 39/39 [00:01<00:00, 36.17it/s]
100%|██████████| 13/13 [00:00<00:00, 263.13it/s]


epoch: 70 validation: rmse: 3.5897169779887443


100%|██████████| 39/39 [00:01<00:00, 35.94it/s]
100%|██████████| 13/13 [00:00<00:00, 288.80it/s]


epoch: 71 validation: rmse: 3.5882750320913233


100%|██████████| 39/39 [00:01<00:00, 36.23it/s]
100%|██████████| 13/13 [00:00<00:00, 285.81it/s]


epoch: 72 validation: rmse: 3.5893120806157235


100%|██████████| 39/39 [00:01<00:00, 35.62it/s]
100%|██████████| 13/13 [00:00<00:00, 322.92it/s]


epoch: 73 validation: rmse: 3.5867694331787754


100%|██████████| 39/39 [00:01<00:00, 35.46it/s]
100%|██████████| 13/13 [00:00<00:00, 307.36it/s]


epoch: 74 validation: rmse: 3.590052104720728


100%|██████████| 39/39 [00:01<00:00, 32.82it/s]
100%|██████████| 13/13 [00:00<00:00, 313.75it/s]


epoch: 75 validation: rmse: 3.5863639989773017


100%|██████████| 39/39 [00:01<00:00, 35.66it/s]
100%|██████████| 13/13 [00:00<00:00, 315.18it/s]


epoch: 76 validation: rmse: 3.587519375674115


100%|██████████| 39/39 [00:01<00:00, 35.91it/s]
100%|██████████| 13/13 [00:00<00:00, 285.16it/s]


epoch: 77 validation: rmse: 3.585448489954782


100%|██████████| 39/39 [00:01<00:00, 35.89it/s]
100%|██████████| 13/13 [00:00<00:00, 310.30it/s]


epoch: 78 validation: rmse: 3.588230829717853


100%|██████████| 39/39 [00:01<00:00, 35.79it/s]
100%|██████████| 13/13 [00:00<00:00, 307.32it/s]


epoch: 79 validation: rmse: 3.5867151779270676


100%|██████████| 39/39 [00:01<00:00, 35.45it/s]
100%|██████████| 13/13 [00:00<00:00, 291.38it/s]


epoch: 80 validation: rmse: 3.5857707825186766


100%|██████████| 39/39 [00:01<00:00, 35.81it/s]
100%|██████████| 13/13 [00:00<00:00, 251.14it/s]


epoch: 81 validation: rmse: 3.5854240827399955


100%|██████████| 39/39 [00:01<00:00, 36.28it/s]
100%|██████████| 13/13 [00:00<00:00, 284.24it/s]


epoch: 82 validation: rmse: 3.586424346282


100%|██████████| 39/39 [00:01<00:00, 35.33it/s]
100%|██████████| 13/13 [00:00<00:00, 290.67it/s]


epoch: 83 validation: rmse: 3.5843546855801223


100%|██████████| 39/39 [00:01<00:00, 35.88it/s]
100%|██████████| 13/13 [00:00<00:00, 291.74it/s]


epoch: 84 validation: rmse: 3.5873097614381466


100%|██████████| 39/39 [00:01<00:00, 36.01it/s]
100%|██████████| 13/13 [00:00<00:00, 295.28it/s]


epoch: 85 validation: rmse: 3.5859570048208322


100%|██████████| 39/39 [00:01<00:00, 35.69it/s]
100%|██████████| 13/13 [00:00<00:00, 307.04it/s]


epoch: 86 validation: rmse: 3.5829374295041387


100%|██████████| 39/39 [00:01<00:00, 35.60it/s]
100%|██████████| 13/13 [00:00<00:00, 298.25it/s]


epoch: 87 validation: rmse: 3.5869689924631185


100%|██████████| 39/39 [00:01<00:00, 35.06it/s]
100%|██████████| 13/13 [00:00<00:00, 276.65it/s]


epoch: 88 validation: rmse: 3.5836495731150233


100%|██████████| 39/39 [00:01<00:00, 35.52it/s]
100%|██████████| 13/13 [00:00<00:00, 267.97it/s]


epoch: 89 validation: rmse: 3.5841934278698417


100%|██████████| 39/39 [00:01<00:00, 35.67it/s]
100%|██████████| 13/13 [00:00<00:00, 301.41it/s]


epoch: 90 validation: rmse: 3.5858565843198758


100%|██████████| 39/39 [00:01<00:00, 35.53it/s]
100%|██████████| 13/13 [00:00<00:00, 294.59it/s]


epoch: 91 validation: rmse: 3.5816186319813736


100%|██████████| 39/39 [00:01<00:00, 35.55it/s]
100%|██████████| 13/13 [00:00<00:00, 92.16it/s]


epoch: 92 validation: rmse: 3.584096141715655


100%|██████████| 39/39 [00:01<00:00, 35.73it/s]
100%|██████████| 13/13 [00:00<00:00, 296.94it/s]


epoch: 93 validation: rmse: 3.5864166254552545


100%|██████████| 39/39 [00:01<00:00, 35.57it/s]
100%|██████████| 13/13 [00:00<00:00, 289.98it/s]


epoch: 94 validation: rmse: 3.5838420341414223


100%|██████████| 39/39 [00:01<00:00, 35.52it/s]
100%|██████████| 13/13 [00:00<00:00, 253.35it/s]


epoch: 95 validation: rmse: 3.582898126626098


100%|██████████| 39/39 [00:01<00:00, 35.81it/s]
100%|██████████| 13/13 [00:00<00:00, 314.76it/s]


epoch: 96 validation: rmse: 3.5843909014443063


100%|██████████| 39/39 [00:01<00:00, 35.70it/s]
100%|██████████| 13/13 [00:00<00:00, 290.36it/s]


epoch: 97 validation: rmse: 3.5843751287543997


100%|██████████| 39/39 [00:01<00:00, 35.62it/s]
100%|██████████| 13/13 [00:00<00:00, 283.23it/s]


epoch: 98 validation: rmse: 3.5838290410838343


100%|██████████| 39/39 [00:01<00:00, 35.54it/s]
100%|██████████| 13/13 [00:00<00:00, 271.14it/s]

epoch: 99 validation: rmse: 3.5834300353380906
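The training loop above keeps the weights from the final epoch (99). In this log, however, the lowest validation RMSE (about 3.41) occurs at epoch 1, and later epochs never beat it. Below is a minimal sketch of tracking the best epoch with patience-based early stopping. `fit_with_early_stopping` and `patience` are hypothetical names. The two callables stand in for `train_fm(...)` and `test_fm(model, valid_dataloader, device)`. The toy RMSE values are the first eight FFM epochs from this log, rounded to two decimals.

```python
import math


def fit_with_early_stopping(train_one_epoch, evaluate, epochs, patience=5):
    """Run up to `epochs` epochs; stop once the validation RMSE has not
    improved for `patience` consecutive epochs.

    train_one_epoch(): runs one training epoch (e.g. wraps train_fm).
    evaluate(): returns the validation RMSE (e.g. wraps test_fm).
    Returns (best_epoch, best_rmse, epochs_run).
    """
    best_epoch, best_rmse = -1, math.inf
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch in range(epochs):
        train_one_epoch()
        rmse_score = evaluate()
        epochs_run += 1
        if rmse_score < best_rmse:
            best_epoch, best_rmse = epoch, rmse_score
            epochs_without_improvement = 0
            # With a real model, snapshot the weights here, e.g.
            # best_state = copy.deepcopy(model.state_dict())
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, best_rmse, epochs_run


# First eight FFM validation RMSEs from the log above, rounded.
toy_rmse = iter([3.69, 3.41, 3.62, 3.66, 3.71, 3.69, 3.70, 3.68])
result = fit_with_early_stopping(lambda: None, lambda: next(toy_rmse), epochs=8, patience=3)
print(result)   # (1, 3.41, 5)
```

After the loop you would reload the snapshot with `model.load_state_dict(best_state)` before running `test_fm` on the test set.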





In [None]:
test_fm(model, test_dataloader, device)

100%|██████████| 25/25 [00:00<00:00, 321.07it/s]


3.566576112503476

FFM은 FM보다는 나은 성능을 보였고 빠르게 학습할 수 있었습니다

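<font color='red'><b>**WARNING**</b></font> : **The intellectual property rights to this educational content belong to the NAVER Connect Foundation. Distributing this content outside by any means, or modifying it, is strictly prohibited.** However, it may be used for non-commercial educational and research activities only, and only with the Foundation's permission. Violators may be held liable under applicable law.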


