## [0] Lab Introduction

What this lab covers
- Context-Aware Recommender System (CARS)
  - FM (Factorization Machine)
  - FFM (Field-Aware Factorization Machine)
- Deep Learning Recommender System
  - NCF (Neural Collaborative Filtering)
  - AutoRec
  - WDN (Wide & Deep Network)
  - DCN (Deep & Cross Network)

In the RecSys basics competition course, all labs, missions, and the competition use the Book-Crossing data. The source is [Kaggle Book-Crossing](https://www.kaggle.com/datasets/ruchi798/bookcrossing-dataset), and the data has been restructured for this course. The data is released under the CC0: Public Domain license.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.autograd import Variable
import numpy as np
import math
import tqdm
import pdb
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import warnings

warnings.filterwarnings(action='ignore')

# [1] Loading the Data
We use the book data. It has been sampled for the Lecture 4 lab and can be downloaded with the code below.

In [None]:
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1jdLFf4JyfWo1406LJ2f8no67YPf15X3f' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1jdLFf4JyfWo1406LJ2f8no67YPf15X3f" -O users.csv && rm -rf ~/cookies.txt
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-EZ2fFCA5RNoqlyM69NeN-L4Y6qTLnQN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-EZ2fFCA5RNoqlyM69NeN-L4Y6qTLnQN" -O books.csv && rm -rf ~/cookies.txt
!wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1-I3YKtaJb5IPvOOFqkonQj5ikJQHfoUC' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1-I3YKtaJb5IPvOOFqkonQj5ikJQHfoUC" -O ratings.csv && rm -rf ~/cookies.txt

--2022-10-12 14:49:31--  https://docs.google.com/uc?export=download&confirm=&id=1jdLFf4JyfWo1406LJ2f8no67YPf15X3f
Resolving docs.google.com (docs.google.com)... 142.251.31.113, 142.251.31.138, 142.251.31.101, ...
Connecting to docs.google.com (docs.google.com)|142.251.31.113|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-00-94-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/ls1nnfsjjjfitbnm8f98qok0trmnrcsi/1665586125000/15943349948421319767/*/1jdLFf4JyfWo1406LJ2f8no67YPf15X3f?e=download&uuid=a32042a3-947f-4255-848f-ff4ce906bf06 [following]
--2022-10-12 14:49:32--  https://doc-00-94-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/ls1nnfsjjjfitbnm8f98qok0trmnrcsi/1665586125000/15943349948421319767/*/1jdLFf4JyfWo1406LJ2f8no67YPf15X3f?e=download&uuid=a32042a3-947f-4255-848f-ff4ce906bf06
Resolving doc-00-94-docs.googleusercontent.com (doc-00-94-docs.googleusercontent.com)... 74.125.143.132, 2

In [None]:
books = pd.read_csv('./books.csv')
users = pd.read_csv('./users.csv')
ratings = pd.read_csv('./ratings.csv')

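In this dataset, isbn serves as the item ID and user_id serves as the user ID.

We recast the task as an implicit-style binary (0/1) problem: the cell below maps ratings below 7 to 0 and ratings from 7 to 10 to 1.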

In [None]:
ratings['rating'] = [int(i/7) for i in ratings['rating']]
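
The binarization above can be read as a threshold at 7. The following is a minimal sketch in plain Python, assuming ratings are integers on a 0 to 10 scale; the helper name `binarize` and its `threshold` parameter are illustrative and do not appear in the notebook.

```python
def binarize(rating: int, threshold: int = 7) -> int:
    """Return 1 for ratings at or above the threshold, else 0."""
    return int(rating >= threshold)

# For integer ratings from 0 to 10 this matches int(rating / 7) from the cell above
labels = [binarize(r) for r in range(11)]
print(labels)
```

Writing the rule as an explicit comparison makes the cut-off visible and keeps it correct even if a rating of 14 or more ever appeared, where `int(i/7)` would yield 2.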

We declare the hyperparameters and other settings up front.

In [None]:
######## Hyperparameters ########
batch_size = 256 # batch size
data_shuffle = True
embed_dim = 8 # dimension of each feature embedding
epochs = 100 # number of training epochs
learning_rate = 0.001 # step size of each optimizer update
weight_decay=1e-6 # L2 regularization strength
gpu_idx = 0

seed=42

torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

device = torch.device("cuda:{}".format(gpu_idx) if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Splitting into train and test sets

Dataset required for FM/FFM

In [None]:
def age_map(x: int) -> int:
    x = int(x)
    if x < 20:
        return 1
    elif x >= 20 and x < 30:
        return 2
    elif x >= 30 and x < 40:
        return 3
    elif x >= 40 and x < 50:
        return 4
    elif x >= 50 and x < 60:
        return 5
    else:
        return 6

Instead of using the raw user_id and isbn values directly, we map each to a 0-based index and use the indexed values.

In [None]:
user2idx = {v:k for k,v in enumerate(ratings['user_id'].unique())}
book2idx = {v:k for k,v in enumerate(ratings['isbn'].unique())}
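
The index mapping can be sketched without pandas as follows. This is an illustration only: `build_index` and the sample ISBN strings are hypothetical, and the sketch assumes the same first-seen ordering that `unique()` produces.

```python
def build_index(values):
    """Map each distinct value to a dense 0-based index in first-seen order."""
    index = {}
    for v in values:
        if v not in index:
            index[v] = len(index)
    return index

# Hypothetical raw ISBN strings with one repeat
sample_isbns = ["0002005018", "0060973129", "0002005018", "0374157065"]
sample_book2idx = build_index(sample_isbns)
print(sample_book2idx)
print([sample_book2idx[isbn] for isbn in sample_isbns])
```

Dense indices matter because an embedding table is addressed by row number, so the largest index must stay below the number of distinct values.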

In [None]:
users_context = users.copy()
books_context = books.copy()

# Join the ratings with the user and book context features
context_df = ratings.merge(users_context, on='user_id', how='inner').merge(books_context[['isbn', 'category', 'publisher', 'language']], on='isbn', how='inner')

# Replace raw user and book IDs with their indices
context_df['user_id'] = context_df['user_id'].map(user2idx)
context_df['isbn'] = context_df['isbn'].map(book2idx)

# Index the user-side features
loc_city2idx = {v:k for k,v in enumerate(context_df['location_city'].unique())}
loc_state2idx = {v:k for k,v in enumerate(context_df['location_state'].unique())}
loc_country2idx = {v:k for k,v in enumerate(context_df['location_country'].unique())}
context_df['location_city'] = context_df['location_city'].map(loc_city2idx)
context_df['location_state'] = context_df['location_state'].map(loc_state2idx)
context_df['location_country'] = context_df['location_country'].map(loc_country2idx)
context_df['age'] = context_df['age'].apply(age_map)

# Index the book-side features
category2idx = {v:k for k,v in enumerate(context_df['category'].unique())}
publisher2idx = {v:k for k,v in enumerate(context_df['publisher'].unique())}
language2idx = {v:k for k,v in enumerate(context_df['language'].unique())}
context_df['category'] = context_df['category'].map(category2idx)
context_df['publisher'] = context_df['publisher'].map(publisher2idx)
context_df['language'] = context_df['language'].map(language2idx)

context_df

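| | user_id | isbn | rating | age | location_country | location_state | location_city | category | publisher | language |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 61 | 0 | 1 | 4 | 0 | 1 | 1 | 0 | 0 | 0 |
| 2 | 76 | 0 | 1 | 3 | 0 | 2 | 2 | 0 | 0 | 0 |
| 3 | 103 | 0 | 1 | 3 | 0 | 3 | 3 | 0 | 0 | 0 |
| 4 | 130 | 0 | 0 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 16760 | 246 | 12507 | 0 | 3 | 0 | 11 | 135 | 377 | 39 | 0 |
| 16761 | 246 | 12508 | 1 | 3 | 0 | 11 | 135 | 0 | 643 | 0 |
| 16762 | 246 | 12509 | 0 | 3 | 0 | 11 | 135 | 4 | 560 | 0 |
| 16763 | 246 | 12510 | 1 | 3 | 0 | 11 | 135 | 17 | 730 | 0 |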


In [None]:
# Dataset used for FM/FFM
X_train_context, X_test_context, y_train_context, y_test_context = train_test_split(
  context_df.drop(['rating'], axis=1), context_df['rating'], test_size=0.2, random_state=seed)

# Wrap the arrays as TensorDatasets so PyTorch's DataLoader can consume them
train_dataset_context = TensorDataset(torch.IntTensor(X_train_context.values), torch.IntTensor(y_train_context.values))
test_dataset_context = TensorDataset(torch.IntTensor(X_test_context.values), torch.IntTensor(y_test_context.values))

# Build the DataLoaders
train_dataloader_context = DataLoader(train_dataset_context, batch_size=batch_size, shuffle=data_shuffle)
test_dataloader_context = DataLoader(test_dataset_context, batch_size=batch_size, shuffle=False)

In [None]:
# Data with only user_id, isbn, and rating
ratings['user_id'] = ratings['user_id'].map(user2idx)
ratings['isbn'] = ratings['isbn'].map(book2idx)

r_train_X, r_test_X, r_train_y, r_test_y = train_test_split(
    ratings[['user_id', 'isbn']], ratings['rating'], test_size=0.2, random_state=seed)
print('Train data shape:', r_train_X.shape, r_train_y.shape)
print('Test data shape:', r_test_X.shape, r_test_y.shape)

# Wrap the arrays as TensorDatasets so PyTorch's DataLoader can consume them
r_train_dataset = TensorDataset(torch.IntTensor(r_train_X.values), torch.IntTensor(r_train_y.values))
r_test_dataset = TensorDataset(torch.IntTensor(r_test_X.values), torch.IntTensor(r_test_y.values))

# Build the DataLoaders
r_train_dataloader = DataLoader(r_train_dataset, batch_size=batch_size, shuffle=data_shuffle)
r_test_dataloader = DataLoader(r_test_dataset, batch_size=batch_size, shuffle=False)

Train data shape: (13412, 2) (13412,)
Test data shape: (3353, 2) (3353,)


In [None]:
# Data for the autoencoder, split without separating X and y
ae_train, ae_test = train_test_split(
    ratings[['user_id', 'isbn', 'rating']], test_size=0.2, random_state=seed)

print('Train data shape:', ae_train.shape)
print('Test data shape:', ae_test.shape)

Train data shape: (13412, 3)
Test data shape: (3353, 3)


# [2] Training Functions

This lab does not use RMSE, but the code below is provided in case you need it.

In [None]:
from sklearn.metrics import mean_squared_error
def rmse(true: list, pred: list) -> float:
    return np.sqrt(mean_squared_error(true, pred))
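
As a hand-checkable illustration of what the helper computes, here is a dependency-free sketch. The name `rmse_plain` and the sample values are made up for this example.

```python
import math

def rmse_plain(true, pred):
    """Root mean squared error: sqrt of the mean of squared differences."""
    squared_errors = [(t - p) ** 2 for t, p in zip(true, pred)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))

# One of four predictions is off by 1, so the mean squared error is 0.25
print(rmse_plain([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.5
```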

In [None]:
def train(model: nn.Module, optimizer: torch.optim.Optimizer, data_loader: DataLoader,
          criterion: nn.Module, device: torch.device, log_interval: int=100) -> None:
    """Run one training epoch and show the running mean loss on the progress bar."""
    model.train()
    total_loss = 0
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (fields, target) in enumerate(tk0):
        fields, target = fields.to(device), target.to(device)
        y = model(fields)
        # BCELoss expects float targets
        loss = criterion(y, target.float())
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Report the mean loss every log_interval batches, then reset the accumulator
        if (i + 1) % log_interval == 0:
            tk0.set_postfix(loss=total_loss / log_interval)
            total_loss = 0

def test(model: nn.Module, data_loader: DataLoader, device: torch.device) -> float:
    """Return the ROC AUC of the model's predictions over data_loader."""
    model.eval()
    targets, predicts = list(), list()
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            targets.extend(target.tolist())
            predicts.extend(y.tolist())
    return roc_auc_score(targets, predicts)
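
The evaluation metric, ROC AUC, can be read as the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting as half. The sketch below computes it by brute force in plain Python. It illustrates the definition and is not how `roc_auc_score` is implemented; the name `pairwise_auc` and the sample scores are made up.

```python
def pairwise_auc(targets, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for t, s in zip(targets, scores) if t == 1]
    neg = [s for t, s in zip(targets, scores) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Three of the four positive/negative pairs are ordered correctly
print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Because only the ordering of scores matters, AUC is unaffected by the sigmoid at the end of the models; a value near 0.5 corresponds to random ranking.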

# [3] Context-Aware Recommender System (CARS)

In [None]:
class FactorizationMachine(nn.Module):

    def __init__(self, reduce_sum:bool=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x: torch.Tensor):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix
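
The `forward` above relies on the identity that the sum of all pairwise dot products equals half of (square of sum minus sum of squares), taken per embedding dimension. The following plain-Python sketch checks that identity on made-up vectors; both function names are hypothetical.

```python
from itertools import combinations

def fm_pairwise_bruteforce(vectors):
    """Sum of dot products over all unordered pairs of field embeddings."""
    return sum(
        sum(a * b for a, b in zip(vi, vj))
        for vi, vj in combinations(vectors, 2)
    )

def fm_pairwise_fast(vectors):
    """Same quantity via 0.5 * ((sum)^2 - sum of squares) per dimension."""
    total = 0.0
    for k in range(len(vectors[0])):
        column = [v[k] for v in vectors]
        total += sum(column) ** 2 - sum(c * c for c in column)
    return 0.5 * total

# Three fields with 2-dimensional embeddings (illustrative values)
toy_embeddings = [[1.0, 2.0], [3.0, -1.0], [0.5, 4.0]]
print(fm_pairwise_bruteforce(toy_embeddings))  # 7.0
print(fm_pairwise_fast(toy_embeddings))        # 7.0
```

The brute-force version touches every pair of fields, while the reformulated version makes a single pass over the fields per dimension, which is why the module uses it.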

In [None]:
class FeaturesEmbedding(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        # np.long was removed in NumPy 1.24; np.int64 is the explicit equivalent here
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)
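
All fields share one embedding table, and the offsets shift each field's local index into its own slice of that table. A minimal sketch of the arithmetic in plain Python, using made-up field sizes (`field_offsets` and `to_global_indices` are hypothetical helper names):

```python
from itertools import accumulate

def field_offsets(field_dims):
    """Start position of each field inside the shared embedding table."""
    return [0, *accumulate(field_dims)][:-1]

def to_global_indices(row, field_dims):
    """Shift each per-field index by the offset of its field."""
    return [idx + off for idx, off in zip(row, field_offsets(field_dims))]

# Hypothetical toy sizes: 3 users, 2 books, 4 categories
toy_field_dims = [3, 2, 4]
print(field_offsets(toy_field_dims))                 # [0, 3, 5]
print(to_global_indices([2, 1, 3], toy_field_dims))  # [2, 4, 8]
```

This only works when the order of `field_dims` matches the column order of the input and every local index is smaller than its field's size; otherwise two fields silently share rows of the table.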

In [None]:
class FeaturesLinear(nn.Module):

    def __init__(self, field_dims: np.ndarray, output_dim: int=1):
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
        # np.long was removed in NumPy 1.24; np.int64 is the explicit equivalent here
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return torch.sum(self.fc(x), dim=1) + self.bias

## FM

In [None]:
class FactorizationMachineModel(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.linear = FeaturesLinear(field_dims)
        self.fm = FactorizationMachine(reduce_sum=True)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = self.linear(x) + self.fm(self.embedding(x))
        return torch.sigmoid(x.squeeze(1))

In [None]:
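# Order follows context_df's columns (country, state, city); age_map returns 1..6, so age needs 7 slots
fm_field_dims = np.array([len(user2idx), len(book2idx), 7, len(loc_country2idx), len(loc_state2idx), len(loc_city2idx), 
                          len(category2idx), len(publisher2idx), len(language2idx)], dtype=np.uint32)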

In [None]:
criterion = torch.nn.BCELoss()
model = FactorizationMachineModel(fm_field_dims, embed_dim).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

for epoch in range(epochs):
    train(model, optimizer, train_dataloader_context, criterion, device)
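    # Note: this evaluates on the training loader, so the score printed as "validation" is a training AUC
    auc_score = test(model, train_dataloader_context, device)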
    print('epoch:', epoch, 'validation: roc_auc_score:', auc_score)

100%|██████████| 53/53 [00:00<00:00, 177.04it/s]
100%|██████████| 53/53 [00:00<00:00, 301.77it/s]


epoch: 0 validation: roc_auc_score: 0.5129181284386842


100%|██████████| 53/53 [00:00<00:00, 287.96it/s]
100%|██████████| 53/53 [00:00<00:00, 549.01it/s]


epoch: 1 validation: roc_auc_score: 0.6025077477322605


100%|██████████| 53/53 [00:00<00:00, 264.55it/s]
100%|██████████| 53/53 [00:00<00:00, 532.90it/s]


epoch: 2 validation: roc_auc_score: 0.6862601144540433


100%|██████████| 53/53 [00:00<00:00, 165.62it/s]
100%|██████████| 53/53 [00:00<00:00, 534.38it/s]


epoch: 3 validation: roc_auc_score: 0.7363705337058742


100%|██████████| 53/53 [00:00<00:00, 256.35it/s]
100%|██████████| 53/53 [00:00<00:00, 521.54it/s]


epoch: 4 validation: roc_auc_score: 0.7697201846986792


100%|██████████| 53/53 [00:00<00:00, 158.15it/s]
100%|██████████| 53/53 [00:00<00:00, 546.01it/s]


epoch: 5 validation: roc_auc_score: 0.7979880548788224


100%|██████████| 53/53 [00:00<00:00, 280.79it/s]
100%|██████████| 53/53 [00:00<00:00, 547.31it/s]


epoch: 6 validation: roc_auc_score: 0.8203885188139386


100%|██████████| 53/53 [00:00<00:00, 252.53it/s]
100%|██████████| 53/53 [00:00<00:00, 480.04it/s]


epoch: 7 validation: roc_auc_score: 0.840886353082687


100%|██████████| 53/53 [00:00<00:00, 254.65it/s]
100%|██████████| 53/53 [00:00<00:00, 271.79it/s]


epoch: 8 validation: roc_auc_score: 0.8600595470328629


100%|██████████| 53/53 [00:00<00:00, 269.12it/s]
100%|██████████| 53/53 [00:00<00:00, 525.93it/s]


epoch: 9 validation: roc_auc_score: 0.8747169287970863


100%|██████████| 53/53 [00:00<00:00, 258.15it/s]
100%|██████████| 53/53 [00:00<00:00, 521.07it/s]


epoch: 10 validation: roc_auc_score: 0.8866949241899981


100%|██████████| 53/53 [00:00<00:00, 245.69it/s]
100%|██████████| 53/53 [00:00<00:00, 508.77it/s]


epoch: 11 validation: roc_auc_score: 0.9006999750980514


100%|██████████| 53/53 [00:00<00:00, 260.82it/s]
100%|██████████| 53/53 [00:00<00:00, 503.00it/s]


epoch: 12 validation: roc_auc_score: 0.911020441290139


100%|██████████| 53/53 [00:00<00:00, 281.30it/s]
100%|██████████| 53/53 [00:00<00:00, 539.70it/s]


epoch: 13 validation: roc_auc_score: 0.9220426275623222


100%|██████████| 53/53 [00:00<00:00, 239.06it/s]
100%|██████████| 53/53 [00:00<00:00, 507.53it/s]


epoch: 14 validation: roc_auc_score: 0.9306081643539144


100%|██████████| 53/53 [00:00<00:00, 266.17it/s]
100%|██████████| 53/53 [00:00<00:00, 507.64it/s]


epoch: 15 validation: roc_auc_score: 0.9394423516777846


100%|██████████| 53/53 [00:00<00:00, 189.26it/s]
100%|██████████| 53/53 [00:00<00:00, 518.52it/s]


epoch: 16 validation: roc_auc_score: 0.9463106297666564


100%|██████████| 53/53 [00:00<00:00, 267.00it/s]
100%|██████████| 53/53 [00:00<00:00, 526.55it/s]


epoch: 17 validation: roc_auc_score: 0.95248599657902


100%|██████████| 53/53 [00:00<00:00, 256.68it/s]
100%|██████████| 53/53 [00:00<00:00, 539.73it/s]


epoch: 18 validation: roc_auc_score: 0.9584664431443741


100%|██████████| 53/53 [00:00<00:00, 252.94it/s]
100%|██████████| 53/53 [00:00<00:00, 500.50it/s]


epoch: 19 validation: roc_auc_score: 0.9632222622699862


100%|██████████| 53/53 [00:00<00:00, 246.62it/s]
100%|██████████| 53/53 [00:00<00:00, 502.59it/s]


epoch: 20 validation: roc_auc_score: 0.9669359699692426


100%|██████████| 53/53 [00:00<00:00, 255.13it/s]
100%|██████████| 53/53 [00:00<00:00, 541.63it/s]


epoch: 21 validation: roc_auc_score: 0.9711603246117229


100%|██████████| 53/53 [00:00<00:00, 266.10it/s]
100%|██████████| 53/53 [00:00<00:00, 529.61it/s]


epoch: 22 validation: roc_auc_score: 0.9741528010305557


100%|██████████| 53/53 [00:00<00:00, 261.58it/s]
100%|██████████| 53/53 [00:00<00:00, 305.32it/s]


epoch: 23 validation: roc_auc_score: 0.9767916189662962


100%|██████████| 53/53 [00:00<00:00, 247.71it/s]
100%|██████████| 53/53 [00:00<00:00, 575.77it/s]


epoch: 24 validation: roc_auc_score: 0.9793253664516224


100%|██████████| 53/53 [00:00<00:00, 274.72it/s]
100%|██████████| 53/53 [00:00<00:00, 546.00it/s]


epoch: 25 validation: roc_auc_score: 0.981097769642092


100%|██████████| 53/53 [00:00<00:00, 256.54it/s]
100%|██████████| 53/53 [00:00<00:00, 452.74it/s]


epoch: 26 validation: roc_auc_score: 0.9830956422501094


100%|██████████| 53/53 [00:00<00:00, 205.61it/s]
100%|██████████| 53/53 [00:00<00:00, 552.72it/s]


epoch: 27 validation: roc_auc_score: 0.9846435228196072


100%|██████████| 53/53 [00:00<00:00, 264.13it/s]
100%|██████████| 53/53 [00:00<00:00, 565.69it/s]


epoch: 28 validation: roc_auc_score: 0.985966796874766


100%|██████████| 53/53 [00:00<00:00, 287.42it/s]
100%|██████████| 53/53 [00:00<00:00, 558.40it/s]


epoch: 29 validation: roc_auc_score: 0.9871234347325303


100%|██████████| 53/53 [00:00<00:00, 287.40it/s]
100%|██████████| 53/53 [00:00<00:00, 545.27it/s]


epoch: 30 validation: roc_auc_score: 0.9880008155394102


100%|██████████| 53/53 [00:00<00:00, 185.25it/s]
100%|██████████| 53/53 [00:00<00:00, 547.41it/s]


epoch: 31 validation: roc_auc_score: 0.9887923494618013


100%|██████████| 53/53 [00:00<00:00, 283.63it/s]
100%|██████████| 53/53 [00:00<00:00, 563.78it/s]


epoch: 32 validation: roc_auc_score: 0.9894593976115846


100%|██████████| 53/53 [00:00<00:00, 268.39it/s]
100%|██████████| 53/53 [00:00<00:00, 559.84it/s]


epoch: 33 validation: roc_auc_score: 0.9901454775710242


100%|██████████| 53/53 [00:00<00:00, 293.68it/s]
100%|██████████| 53/53 [00:00<00:00, 596.01it/s]


epoch: 34 validation: roc_auc_score: 0.9906718724978053


100%|██████████| 53/53 [00:00<00:00, 266.32it/s]
100%|██████████| 53/53 [00:00<00:00, 538.82it/s]


epoch: 35 validation: roc_auc_score: 0.9911526054631314


100%|██████████| 53/53 [00:00<00:00, 270.96it/s]
100%|██████████| 53/53 [00:00<00:00, 492.33it/s]


epoch: 36 validation: roc_auc_score: 0.9915211849808725


100%|██████████| 53/53 [00:00<00:00, 198.52it/s]
100%|██████████| 53/53 [00:00<00:00, 559.12it/s]


epoch: 37 validation: roc_auc_score: 0.99192919438134


100%|██████████| 53/53 [00:00<00:00, 265.19it/s]
100%|██████████| 53/53 [00:00<00:00, 561.23it/s]


epoch: 38 validation: roc_auc_score: 0.9922131854151045


100%|██████████| 53/53 [00:00<00:00, 135.46it/s]
100%|██████████| 53/53 [00:00<00:00, 563.62it/s]


epoch: 39 validation: roc_auc_score: 0.9924743814224417


100%|██████████| 53/53 [00:00<00:00, 258.95it/s]
100%|██████████| 53/53 [00:00<00:00, 580.46it/s]


epoch: 40 validation: roc_auc_score: 0.992711919500143


100%|██████████| 53/53 [00:00<00:00, 161.36it/s]
100%|██████████| 53/53 [00:00<00:00, 529.48it/s]


epoch: 41 validation: roc_auc_score: 0.992892889478614


100%|██████████| 53/53 [00:00<00:00, 263.17it/s]
100%|██████████| 53/53 [00:00<00:00, 517.45it/s]


epoch: 42 validation: roc_auc_score: 0.9931229011227738


100%|██████████| 53/53 [00:00<00:00, 261.41it/s]
100%|██████████| 53/53 [00:00<00:00, 571.29it/s]


epoch: 43 validation: roc_auc_score: 0.9933429893800344


100%|██████████| 53/53 [00:00<00:00, 286.59it/s]
100%|██████████| 53/53 [00:00<00:00, 548.03it/s]


epoch: 44 validation: roc_auc_score: 0.9935082233597151


100%|██████████| 53/53 [00:00<00:00, 263.75it/s]
100%|██████████| 53/53 [00:00<00:00, 507.93it/s]


epoch: 45 validation: roc_auc_score: 0.9936464556598265


100%|██████████| 53/53 [00:00<00:00, 254.70it/s]
100%|██████████| 53/53 [00:00<00:00, 323.35it/s]


epoch: 46 validation: roc_auc_score: 0.9937852033049098


100%|██████████| 53/53 [00:00<00:00, 258.94it/s]
100%|██████████| 53/53 [00:00<00:00, 502.84it/s]


epoch: 47 validation: roc_auc_score: 0.9939110912952314


100%|██████████| 53/53 [00:00<00:00, 256.82it/s]
100%|██████████| 53/53 [00:00<00:00, 514.63it/s]


epoch: 48 validation: roc_auc_score: 0.9940341508805915


100%|██████████| 53/53 [00:00<00:00, 243.80it/s]
100%|██████████| 53/53 [00:00<00:00, 518.40it/s]


epoch: 49 validation: roc_auc_score: 0.9941474308962541


100%|██████████| 53/53 [00:00<00:00, 266.85it/s]
100%|██████████| 53/53 [00:00<00:00, 533.42it/s]


epoch: 50 validation: roc_auc_score: 0.994244004147017


100%|██████████| 53/53 [00:00<00:00, 301.51it/s]
100%|██████████| 53/53 [00:00<00:00, 564.28it/s]


epoch: 51 validation: roc_auc_score: 0.9943513397183537


100%|██████████| 53/53 [00:00<00:00, 252.45it/s]
100%|██████████| 53/53 [00:00<00:00, 521.53it/s]


epoch: 52 validation: roc_auc_score: 0.9944203959445754


100%|██████████| 53/53 [00:00<00:00, 242.45it/s]
100%|██████████| 53/53 [00:00<00:00, 516.87it/s]


epoch: 53 validation: roc_auc_score: 0.9945275876987107


100%|██████████| 53/53 [00:00<00:00, 186.99it/s]
100%|██████████| 53/53 [00:00<00:00, 492.08it/s]


epoch: 54 validation: roc_auc_score: 0.9945785229575518


100%|██████████| 53/53 [00:00<00:00, 243.50it/s]
100%|██████████| 53/53 [00:00<00:00, 561.83it/s]


epoch: 55 validation: roc_auc_score: 0.994694080078904


100%|██████████| 53/53 [00:00<00:00, 260.52it/s]
100%|██████████| 53/53 [00:00<00:00, 554.44it/s]


epoch: 56 validation: roc_auc_score: 0.9947670912781651


100%|██████████| 53/53 [00:00<00:00, 252.87it/s]
100%|██████████| 53/53 [00:00<00:00, 533.85it/s]


epoch: 57 validation: roc_auc_score: 0.994830250999128


100%|██████████| 53/53 [00:00<00:00, 251.37it/s]
100%|██████████| 53/53 [00:00<00:00, 553.87it/s]


epoch: 58 validation: roc_auc_score: 0.9948849974138071


100%|██████████| 53/53 [00:00<00:00, 285.18it/s]
100%|██████████| 53/53 [00:00<00:00, 509.33it/s]


epoch: 59 validation: roc_auc_score: 0.9949527832547496


100%|██████████| 53/53 [00:00<00:00, 261.40it/s]
100%|██████████| 53/53 [00:00<00:00, 575.11it/s]


epoch: 60 validation: roc_auc_score: 0.9950241405561944


100%|██████████| 53/53 [00:00<00:00, 260.29it/s]
100%|██████████| 53/53 [00:00<00:00, 297.67it/s]


epoch: 61 validation: roc_auc_score: 0.9950678609854302


100%|██████████| 53/53 [00:00<00:00, 283.40it/s]
100%|██████████| 53/53 [00:00<00:00, 499.73it/s]


epoch: 62 validation: roc_auc_score: 0.9951311165845275


100%|██████████| 53/53 [00:00<00:00, 277.46it/s]
100%|██████████| 53/53 [00:00<00:00, 542.64it/s]


epoch: 63 validation: roc_auc_score: 0.9952041038142551


100%|██████████| 53/53 [00:00<00:00, 260.28it/s]
100%|██████████| 53/53 [00:00<00:00, 491.42it/s]


epoch: 64 validation: roc_auc_score: 0.9952328432850083


100%|██████████| 53/53 [00:00<00:00, 257.91it/s]
100%|██████████| 53/53 [00:00<00:00, 504.45it/s]


epoch: 65 validation: roc_auc_score: 0.9953062140272732


100%|██████████| 53/53 [00:00<00:00, 273.35it/s]
100%|██████████| 53/53 [00:00<00:00, 525.12it/s]


epoch: 66 validation: roc_auc_score: 0.9953588990620651


100%|██████████| 53/53 [00:00<00:00, 250.11it/s]
100%|██████████| 53/53 [00:00<00:00, 525.80it/s]


epoch: 67 validation: roc_auc_score: 0.9953873389136485


100%|██████████| 53/53 [00:00<00:00, 271.29it/s]
100%|██████████| 53/53 [00:00<00:00, 508.86it/s]


epoch: 68 validation: roc_auc_score: 0.9954294413993682


100%|██████████| 53/53 [00:00<00:00, 260.61it/s]
100%|██████████| 53/53 [00:00<00:00, 308.74it/s]


epoch: 69 validation: roc_auc_score: 0.9954934160844725


100%|██████████| 53/53 [00:00<00:00, 245.66it/s]
100%|██████████| 53/53 [00:00<00:00, 503.33it/s]


epoch: 70 validation: roc_auc_score: 0.9955345238345492


100%|██████████| 53/53 [00:00<00:00, 266.59it/s]
100%|██████████| 53/53 [00:00<00:00, 544.42it/s]


epoch: 71 validation: roc_auc_score: 0.9955757754018271


100%|██████████| 53/53 [00:00<00:00, 254.61it/s]
100%|██████████| 53/53 [00:00<00:00, 553.75it/s]


epoch: 72 validation: roc_auc_score: 0.9956396422240306


100%|██████████| 53/53 [00:00<00:00, 271.35it/s]
100%|██████████| 53/53 [00:00<00:00, 593.82it/s]


epoch: 73 validation: roc_auc_score: 0.995682871277828


100%|██████████| 53/53 [00:00<00:00, 263.91it/s]
100%|██████████| 53/53 [00:00<00:00, 559.48it/s]


epoch: 74 validation: roc_auc_score: 0.9957380731136452


100%|██████████| 53/53 [00:00<00:00, 264.31it/s]
100%|██████████| 53/53 [00:00<00:00, 535.28it/s]


epoch: 75 validation: roc_auc_score: 0.9957823088878528


100%|██████████| 53/53 [00:00<00:00, 261.29it/s]
100%|██████████| 53/53 [00:00<00:00, 570.71it/s]


epoch: 76 validation: roc_auc_score: 0.9958212713646745


100%|██████████| 53/53 [00:00<00:00, 198.48it/s]
100%|██████████| 53/53 [00:00<00:00, 527.04it/s]


epoch: 77 validation: roc_auc_score: 0.9958561470360222


100%|██████████| 53/53 [00:00<00:00, 257.32it/s]
100%|██████████| 53/53 [00:00<00:00, 541.62it/s]


epoch: 78 validation: roc_auc_score: 0.9959165502606243


100%|██████████| 53/53 [00:00<00:00, 257.36it/s]
100%|██████████| 53/53 [00:00<00:00, 563.52it/s]


epoch: 79 validation: roc_auc_score: 0.9959521929570463


100%|██████████| 53/53 [00:00<00:00, 265.64it/s]
100%|██████████| 53/53 [00:00<00:00, 559.08it/s]


epoch: 80 validation: roc_auc_score: 0.9960073708233297


100%|██████████| 53/53 [00:00<00:00, 276.31it/s]
100%|██████████| 53/53 [00:00<00:00, 550.48it/s]


epoch: 81 validation: roc_auc_score: 0.9960434210018225


100%|██████████| 53/53 [00:00<00:00, 274.26it/s]
100%|██████████| 53/53 [00:00<00:00, 555.89it/s]


epoch: 82 validation: roc_auc_score: 0.9960975921476959


100%|██████████| 53/53 [00:00<00:00, 319.14it/s]
100%|██████████| 53/53 [00:00<00:00, 596.08it/s]


epoch: 83 validation: roc_auc_score: 0.9961444525858297


100%|██████████| 53/53 [00:00<00:00, 306.28it/s]
100%|██████████| 53/53 [00:00<00:00, 318.15it/s]


epoch: 84 validation: roc_auc_score: 0.9961850569757011


100%|██████████| 53/53 [00:00<00:00, 266.24it/s]
100%|██████████| 53/53 [00:00<00:00, 573.22it/s]


epoch: 85 validation: roc_auc_score: 0.9962333795553828


100%|██████████| 53/53 [00:00<00:00, 275.18it/s]
100%|██████████| 53/53 [00:00<00:00, 557.38it/s]


epoch: 86 validation: roc_auc_score: 0.996278490217566


100%|██████████| 53/53 [00:00<00:00, 311.77it/s]
100%|██████████| 53/53 [00:00<00:00, 616.87it/s]


epoch: 87 validation: roc_auc_score: 0.9963403316141826


100%|██████████| 53/53 [00:00<00:00, 270.51it/s]
100%|██████████| 53/53 [00:00<00:00, 534.30it/s]


epoch: 88 validation: roc_auc_score: 0.9963733256771447


100%|██████████| 53/53 [00:00<00:00, 279.04it/s]
100%|██████████| 53/53 [00:00<00:00, 526.03it/s]


epoch: 89 validation: roc_auc_score: 0.9964112095249557


100%|██████████| 53/53 [00:00<00:00, 269.56it/s]
100%|██████████| 53/53 [00:00<00:00, 570.58it/s]


epoch: 90 validation: roc_auc_score: 0.9964784920056932


100%|██████████| 53/53 [00:00<00:00, 293.83it/s]
100%|██████████| 53/53 [00:00<00:00, 563.09it/s]


epoch: 91 validation: roc_auc_score: 0.9965248490836218


100%|██████████| 53/53 [00:00<00:00, 204.39it/s]
100%|██████████| 53/53 [00:00<00:00, 574.50it/s]


epoch: 92 validation: roc_auc_score: 0.9965722368514944


100%|██████████| 53/53 [00:00<00:00, 281.05it/s]
100%|██████████| 53/53 [00:00<00:00, 573.67it/s]


epoch: 93 validation: roc_auc_score: 0.9966133446015708


100%|██████████| 53/53 [00:00<00:00, 264.94it/s]
100%|██████████| 53/53 [00:00<00:00, 566.08it/s]


epoch: 94 validation: roc_auc_score: 0.9966452720202892


100%|██████████| 53/53 [00:00<00:00, 268.61it/s]
100%|██████████| 53/53 [00:00<00:00, 512.20it/s]


epoch: 95 validation: roc_auc_score: 0.9966953683454552


100%|██████████| 53/53 [00:00<00:00, 262.45it/s]
100%|██████████| 53/53 [00:00<00:00, 530.85it/s]


epoch: 96 validation: roc_auc_score: 0.9967620036487862


100%|██████████| 53/53 [00:00<00:00, 269.29it/s]
100%|██████████| 53/53 [00:00<00:00, 570.22it/s]


epoch: 97 validation: roc_auc_score: 0.9968108775277401


100%|██████████| 53/53 [00:00<00:00, 253.48it/s]
100%|██████████| 53/53 [00:00<00:00, 572.41it/s]


epoch: 98 validation: roc_auc_score: 0.9968634307300974


100%|██████████| 53/53 [00:00<00:00, 266.65it/s]
100%|██████████| 53/53 [00:00<00:00, 536.02it/s]

epoch: 99 validation: roc_auc_score: 0.9968954300574162





In [None]:
test(model, test_dataloader_context, device)

100%|██████████| 14/14 [00:00<00:00, 143.90it/s]


0.7134581442838162

## FFM

In [None]:
class FieldAwareFactorizationMachine(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.num_fields = len(field_dims)
        self.embeddings = torch.nn.ModuleList([
            torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields)
        ])
        # np.long was removed in NumPy 1.24; np.int64 is the explicit equivalent here
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        for embedding in self.embeddings:
            torch.nn.init.xavier_uniform_(embedding.weight.data)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        xs = [self.embeddings[i](x) for i in range(self.num_fields)]
        ix = list()
        for i in range(self.num_fields - 1):
            for j in range(i + 1, self.num_fields):
                ix.append(xs[j][:, i] * xs[i][:, j])
        ix = torch.stack(ix, dim=1)
        return ix
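
In the nested loop above, each unordered pair of fields (i, j) contributes one term: the embedding of field i's feature looked up in table j, multiplied elementwise by the embedding of field j's feature looked up in table i. The sketch below only enumerates those pairs to show how many interaction terms the layer returns; `ffm_pairs` is a hypothetical helper and the field count of 9 matches the nine context fields used in this lab.

```python
from itertools import combinations

def ffm_pairs(num_fields):
    """Unordered field pairs (i, j) with i < j, in the order the loop visits them."""
    return list(combinations(range(num_fields), 2))

pairs = ffm_pairs(9)
print(len(pairs))  # 36, i.e. 9 * 8 / 2
print(pairs[:3])   # [(0, 1), (0, 2), (0, 3)]
```

With 9 fields the stacked output therefore has 36 interaction vectors per example, and the model keeps 9 full embedding tables, which is consistent with FFM training more slowly than FM in the logs.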

In [None]:
class FieldAwareFactorizationMachineModel(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int):
        super().__init__()
        self.linear = FeaturesLinear(field_dims)
        self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True)
        x = self.linear(x) + ffm_term
        return torch.sigmoid(x.squeeze(1))

We declare field_dims with the number of distinct values in each field.

In [None]:
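# age_map returns buckets 1..6, so the age field needs 7 slots to keep index 6 inside its own range
ffm_field_dims = np.array([len(user2idx), len(book2idx), 7, len(loc_country2idx), len(loc_state2idx), len(loc_city2idx), 
                      len(category2idx), len(publisher2idx), len(language2idx)], 
                      dtype=np.uint32)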

In [None]:
criterion = torch.nn.BCELoss()
model = FieldAwareFactorizationMachineModel(ffm_field_dims, embed_dim).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

for epoch in range(epochs):
    train(model, optimizer, train_dataloader_context, criterion, device)
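    # Note: this evaluates on the training loader, so the score printed as "validation" is a training AUC
    auc_score = test(model, train_dataloader_context, device)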
    print('epoch:', epoch, 'validation: roc_auc_score:', auc_score)

epoch: 0 validation: roc_auc_score: 0.557266163956835
epoch: 1 validation: roc_auc_score: 0.6594477822153109
epoch: 2 validation: roc_auc_score: 0.7526527382423625
epoch: 3 validation: roc_auc_score: 0.8292837685245844
epoch: 4 validation: roc_auc_score: 0.8989592596309646
epoch: 5 validation: roc_auc_score: 0.9502039531657365
epoch: 6 validation: roc_auc_score: 0.9796748901901735
epoch: 7 validation: roc_auc_score: 0.9924191076780241
epoch: 8 validation: roc_auc_score: 0.9971211152007664
epoch: 9 validation: roc_auc_score: 0.9988817972893189
epoch: 10 validation: roc_auc_score: 0.9995443391667908
epoch: 11 validation: roc_auc_score: 0.9997805589201457
epoch: 12 validation: roc_auc_score: 0.9998990882636607
epoch: 13 validation: roc_auc_score: 0.9999473868738089
epoch: 14 validation: roc_auc_score: 0.9999705654127732
epoch: 15 validation: roc_auc_score: 0.9999876317206767
epoch: 16 validation: roc_auc_score: 0.9999937439517377
epoch: 17 validation: roc_auc_score: 0.9999973633513071
epoch: 18 validation: roc_auc_score: 0.9999988734319222
epoch: 19 validation: roc_auc_score: 0.99999932885306
epoch: 20 validation: roc_auc_score: 0.9999998801523322
epoch: 21 validation: roc_auc_score: 1.0
epoch: 22 validation: roc_auc_score: 1.0
epoch: 23 validation: roc_auc_score: 1.0
epoch: 24 validation: roc_auc_score: 1.0
epoch: 25 validation: roc_auc_score: 1.0
epoch: 26 validation: roc_auc_score: 1.0
epoch: 27 validation: roc_auc_score: 1.0
epoch: 28 validation: roc_auc_score: 1.0
epoch: 29 validation: roc_auc_score: 1.0
epoch: 30 validation: roc_auc_score: 1.0
epoch: 31 validation: roc_auc_score: 1.0
epoch: 32 validation: roc_auc_score: 1.0
epoch: 33 validation: roc_auc_score: 1.0
epoch: 34 validation: roc_auc_score: 1.0
epoch: 35 validation: roc_auc_score: 1.0
epoch: 36 validation: roc_auc_score: 1.0
epoch: 37 validation: roc_auc_score: 1.0
epoch: 38 validation: roc_auc_score: 1.0
epoch: 39 validation: roc_auc_score: 1.0
epoch: 40 validation: roc_auc_score: 1.0
epoch: 41 validation: roc_auc_score: 1.0
epoch: 42 validation: roc_auc_score: 1.0
epoch: 43 validation: roc_auc_score: 1.0
epoch: 44 validation: roc_auc_score: 1.0
epoch: 45 validation: roc_auc_score: 1.0
epoch: 46 validation: roc_auc_score: 1.0
epoch: 47 validation: roc_auc_score: 1.0
epoch: 48 validation: roc_auc_score: 1.0
epoch: 49 validation: roc_auc_score: 1.0
epoch: 50 validation: roc_auc_score: 0.9999999999999999
epoch: 51 validation: roc_auc_score: 1.0
epoch: 52 validation: roc_auc_score: 1.0
epoch: 53 validation: roc_auc_score: 1.0
epoch: 54 validation: roc_auc_score: 1.0
epoch: 55 validation: roc_auc_score: 1.0
epoch: 56 validation: roc_auc_score: 1.0
epoch: 57 validation: roc_auc_score: 1.0
epoch: 58 validation: roc_auc_score: 1.0
epoch: 59 validation: roc_auc_score: 1.0
epoch: 60 validation: roc_auc_score: 1.0
epoch: 61 validation: roc_auc_score: 1.0
epoch: 62 validation: roc_auc_score: 1.0
epoch: 63 validation: roc_auc_score: 1.0
epoch: 64 validation: roc_auc_score: 1.0
epoch: 65 validation: roc_auc_score: 1.0
epoch: 66 validation: roc_auc_score: 1.0
epoch: 67 validation: roc_auc_score: 1.0
epoch: 68 validation: roc_auc_score: 1.0
epoch: 69 validation: roc_auc_score: 1.0
epoch: 70 validation: roc_auc_score: 1.0
epoch: 71 validation: roc_auc_score: 1.0
epoch: 72 validation: roc_auc_score: 1.0
epoch: 73 validation: roc_auc_score: 1.0
epoch: 74 validation: roc_auc_score: 1.0
epoch: 75 validation: roc_auc_score: 1.0
epoch: 76 validation: roc_auc_score: 1.0
epoch: 77 validation: roc_auc_score: 1.0
epoch: 78 validation: roc_auc_score: 1.0
epoch: 79 validation: roc_auc_score: 1.0
epoch: 80 validation: roc_auc_score: 1.0
epoch: 81 validation: roc_auc_score: 1.0
epoch: 82 validation: roc_auc_score: 1.0
epoch: 83 validation: roc_auc_score: 1.0
epoch: 84 validation: roc_auc_score: 1.0
epoch: 85 validation: roc_auc_score: 1.0
epoch: 86 validation: roc_auc_score: 1.0
epoch: 87 validation: roc_auc_score: 1.0
epoch: 88 validation: roc_auc_score: 1.0
epoch: 89 validation: roc_auc_score: 1.0
epoch: 90 validation: roc_auc_score: 1.0
epoch: 91 validation: roc_auc_score: 1.0
epoch: 92 validation: roc_auc_score: 1.0
epoch: 93 validation: roc_auc_score: 1.0
epoch: 94 validation: roc_auc_score: 1.0
epoch: 95 validation: roc_auc_score: 1.0
epoch: 96 validation: roc_auc_score: 1.0
epoch: 97 validation: roc_auc_score: 1.0
epoch: 98 validation: roc_auc_score: 1.0
epoch: 99 validation: roc_auc_score: 1.0

In [None]:
test(model, test_dataloader_context, device)

0.7365658121934642
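
The per-epoch AUC printed during training is computed on `train_dataloader_context` and reaches 1.0, while the held-out test AUC is about 0.737, a gap that usually points to overfitting. One possible remedy, sketched below, is to track AUC on a separate validation split and stop once it no longer improves. The helper name `pick_best_epoch`, the `patience` value, and the example scores are hypothetical and are not part of the notebook.

```python
def pick_best_epoch(val_scores, patience=3):
    """Return (best_epoch, best_score) under simple early stopping.

    Scanning stops once `patience` consecutive epochs fail to beat the best
    validation score seen so far.
    """
    best_epoch, best_score, waited = -1, float('-inf'), 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_epoch, best_score, waited = epoch, score, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_epoch, best_score

# made-up validation AUCs: they improve, then degrade as the model overfits
example_scores = [0.62, 0.70, 0.74, 0.73, 0.72, 0.71, 0.75]
best_epoch, best_score = pick_best_epoch(example_scores, patience=3)
```

In a real loop the same logic would sit next to the `test(...)` call, with the validation loader passed in place of the training loader.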

# [4] Deep Learning Recommender System

## NCF

In [None]:
class MultiLayerPerceptron(nn.Module):

    def __init__(self, input_dim, embed_dims, dropout, output_layer=True):
        super().__init__()
        layers = list()
        for embed_dim in embed_dims:
            layers.append(torch.nn.Linear(input_dim, embed_dim))
            layers.append(torch.nn.BatchNorm1d(embed_dim))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(p=dropout))
            input_dim = embed_dim
        if output_layer:
            layers.append(torch.nn.Linear(input_dim, 1))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, embed_dim)``
        """
        return self.mlp(x)
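
To make the layer stacking concrete, the sketch below rebuilds the same Linear → BatchNorm1d → ReLU → Dropout pattern for hypothetical sizes and checks the resulting output shape. The function name `build_mlp` and all dimensions are illustrative only.

```python
import torch
import torch.nn as nn

def build_mlp(input_dim, hidden_dims, dropout, output_layer=True):
    layers = []
    for hidden_dim in hidden_dims:
        layers += [nn.Linear(input_dim, hidden_dim),
                   nn.BatchNorm1d(hidden_dim),
                   nn.ReLU(),
                   nn.Dropout(p=dropout)]
        input_dim = hidden_dim          # next block consumes this block's output
    if output_layer:
        layers.append(nn.Linear(input_dim, 1))
    return nn.Sequential(*layers)

mlp = build_mlp(input_dim=32, hidden_dims=(16, 16), dropout=0.2, output_layer=False)
mlp.eval()                              # disable dropout, use running BatchNorm stats
with torch.no_grad():
    out = mlp(torch.randn(8, 32))
out_shape = tuple(out.shape)
num_modules = len(mlp)
```

With `output_layer=False` the network ends on the last hidden width, which is what lets a caller concatenate it with other features before a final linear layer.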

In [None]:
class NeuralCollaborativeFiltering(nn.Module):

    def __init__(self, field_dims, user_field_idx, item_field_idx, embed_dim, mlp_dims, dropout):
        super().__init__()
        self.user_field_idx = user_field_idx
        self.item_field_idx = item_field_idx
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False)
        self.fc = torch.nn.Linear(mlp_dims[-1] + embed_dim, 1)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = self.embedding(x)
        user_x = x[:, self.user_field_idx].squeeze(1)
        item_x = x[:, self.item_field_idx].squeeze(1)
        gmf = user_x * item_x
        x = self.mlp(x.view(-1, self.embed_output_dim))
        x = torch.cat([gmf, x], dim=1)
        x = self.fc(x).squeeze(1)
        return torch.sigmoid(x)
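
The forward pass above fuses two branches: an element-wise product of the user and item embeddings (GMF) and an MLP over the flattened embeddings. The sketch below traces the tensor shapes of that fusion with random stand-in embeddings. The sizes are hypothetical, and a single `Linear` layer stands in for the MLP branch.

```python
import torch

torch.manual_seed(0)
batch_size, embed_dim, mlp_out = 4, 16, 16     # toy sizes

emb = torch.randn(batch_size, 2, embed_dim)    # [user, item] embeddings per sample
user_x, item_x = emb[:, 0], emb[:, 1]

gmf = user_x * item_x                          # (batch, embed_dim)
mlp_in = emb.view(-1, 2 * embed_dim)           # (batch, 2 * embed_dim)
mlp_branch = torch.nn.Linear(2 * embed_dim, mlp_out)(mlp_in)   # stand-in for the MLP

fused = torch.cat([gmf, mlp_branch], dim=1)    # (batch, embed_dim + mlp_out)
score = torch.sigmoid(torch.nn.Linear(embed_dim + mlp_out, 1)(fused).squeeze(1))
```

This is why the final layer in the class is sized `mlp_dims[-1] + embed_dim`: it receives the MLP output and the GMF vector side by side.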

In [None]:
field_dims = np.array([len(user2idx), len(book2idx)], dtype=np.uint32)

In [None]:
criterion = torch.nn.BCELoss()
# np.int64 is used for the field indices because np.long is not available in every NumPy release
model = NeuralCollaborativeFiltering(field_dims=field_dims,
                                    user_field_idx=np.array((0, ), dtype=np.int64), item_field_idx=np.array((1, ), dtype=np.int64),
                                    embed_dim=embed_dim, mlp_dims=(16, 16), dropout=0.2).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

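for epoch in range(epochs):
    train(model, optimizer, r_train_dataloader, criterion, device)
    # Note: this AUC is measured on the training loader, so the value printed as
    # 'validation' reflects fit to the training data, not held-out performance.
    auc_score = test(model, r_train_dataloader, device)
    print('epoch:', epoch, 'validation: roc_auc_score:', auc_score)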

epoch: 0 validation: roc_auc_score: 0.8140286883653547
epoch: 1 validation: roc_auc_score: 0.9395223867979836
epoch: 2 validation: roc_auc_score: 0.9839747566081632
epoch: 3 validation: roc_auc_score: 0.9921353507889173
epoch: 4 validation: roc_auc_score: 0.9947879400056625
epoch: 5 validation: roc_auc_score: 0.996450391458089
epoch: 6 validation: roc_auc_score: 0.9977929511385215
epoch: 7 validation: roc_auc_score: 0.9983755845327892
epoch: 8 validation: roc_auc_score: 0.998843244424338
epoch: 9 validation: roc_auc_score: 0.999203360092202
epoch: 10 validation: roc_auc_score: 0.99944316471707
epoch: 11 validation: roc_auc_score: 0.9995279642871432
epoch: 12 validation: roc_auc_score: 0.9996742657852705
epoch: 13 validation: roc_auc_score: 0.9997183365570153
epoch: 14 validation: roc_auc_score: 0.9997883256549691
epoch: 15 validation: roc_auc_score: 0.9998528329548007
epoch: 16 validation: roc_auc_score: 0.9998862767319404
epoch: 17 validation: roc_auc_score: 0.999887767396342
epoch: 18 validation: roc_auc_score: 0.9999337135200761
epoch: 19 validation: roc_auc_score: 0.9999389789314302
epoch: 20 validation: roc_auc_score: 0.9999422006899757
epoch: 21 validation: roc_auc_score: 0.9999723505796473
epoch: 22 validation: roc_auc_score: 0.9999710041730909
epoch: 23 validation: roc_auc_score: 0.9999721822788277
epoch: 24 validation: roc_auc_score: 0.9999759329828061
epoch: 25 validation: roc_auc_score: 0.999972759310209
epoch: 26 validation: roc_auc_score: 0.9999749712638373
epoch: 27 validation: roc_auc_score: 0.9999772553463881
epoch: 28 validation: roc_auc_score: 0.999986992750947
epoch: 29 validation: roc_auc_score: 0.9999830497031751
epoch: 30 validation: roc_auc_score: 0.9999928111507084
epoch: 31 validation: roc_auc_score: 0.9999905511111317
epoch: 32 validation: roc_auc_score: 0.999990478982209
epoch: 33 validation: roc_auc_score: 0.9999874976534058
epoch: 34 validation: roc_auc_score: 0.9999863916765915
epoch: 35 validation: roc_auc_score: 0.9999807656206243
epoch: 36 validation: roc_auc_score: 0.9999904068532863
epoch: 37 validation: roc_auc_score: 0.9999867042352565
epoch: 38 validation: roc_auc_score: 0.9999944941589036
epoch: 39 validation: roc_auc_score: 0.9999943258580841
epoch: 40 validation: roc_auc_score: 0.9999940613853677
epoch: 41 validation: roc_auc_score: 0.9999924264631207
epoch: 42 validation: roc_auc_score: 0.9999957684365373
epoch: 43 validation: roc_auc_score: 0.9999941816002388
epoch: 44 validation: roc_auc_score: 0.9999958886514084
epoch: 45 validation: roc_auc_score: 0.999995407791924
epoch: 46 validation: roc_auc_score: 0.9999869206220245
epoch: 47 validation: roc_auc_score: 0.9999941575572646
epoch: 48 validation: roc_auc_score: 0.9999946624597232
epoch: 49 validation: roc_auc_score: 0.9999948067175686
epoch: 50 validation: roc_auc_score: 0.999987641911251
epoch: 51 validation: roc_auc_score: 0.999991464744152
epoch: 52 validation: roc_auc_score: 0.9999855261295197
epoch: 53 validation: roc_auc_score: 0.999986992750947
epoch: 54 validation: roc_auc_score: 0.9999950231043364
epoch: 55 validation: roc_auc_score: 0.9999951192762334
epoch: 56 validation: roc_auc_score: 0.9999915609160489
epoch: 57 validation: roc_auc_score: 0.9999926428498888
epoch: 58 validation: roc_auc_score: 0.9999945422448521
epoch: 59 validation: roc_auc_score: 0.9999847086683963
epoch: 60 validation: roc_auc_score: 0.9999953356630014
epoch: 61 validation: roc_auc_score: 0.999996393553867
epoch: 62 validation: roc_auc_score: 0.999997042714171
epoch: 63 validation: roc_auc_score: 0.9999949509754139
epoch: 64 validation: roc_auc_score: 0.9999960088662795
epoch: 65 validation: roc_auc_score: 0.9999954318348981
epoch: 66 validation: roc_auc_score: 0.9999949990613622
epoch: 67 validation: roc_auc_score: 0.9999969946282224
epoch: 68 validation: roc_auc_score: 0.9999971869720163
epoch: 69 validation: roc_auc_score: 0.9999949509754139
epoch: 70 validation: roc_auc_score: 0.9999960329092538
epoch: 71 validation: roc_auc_score: 0.9999958886514084
epoch: 72 validation: roc_auc_score: 0.9999969224993
epoch: 73 validation: roc_auc_score: 0.9999972591009388
epoch: 74 validation: roc_auc_score: 0.9999968503703772
epoch: 75 validation: roc_auc_score: 0.99999737931581
epoch: 76 validation: roc_auc_score: 0.9999973552728358
epoch: 77 validation: roc_auc_score: 0.9999937728696771
epoch: 78 validation: roc_auc_score: 0.9999892047045754
epoch: 79 validation: roc_auc_score: 0.9999956001357178
epoch: 80 validation: roc_auc_score: 0.9999956001357178
epoch: 81 validation: roc_auc_score: 0.9999974274017585
epoch: 82 validation: roc_auc_score: 0.9999908396268222
epoch: 83 validation: roc_auc_score: 0.9999935564829091
epoch: 84 validation: roc_auc_score: 0.9999945662878263
epoch: 85 validation: roc_auc_score: 0.9999972350579647
epoch: 86 validation: roc_auc_score: 0.9999964656827897
epoch: 87 validation: roc_auc_score: 0.9999967061125319
epoch: 88 validation: roc_auc_score: 0.9999972350579647
epoch: 89 validation: roc_auc_score: 0.9999952875770529
epoch: 90 validation: roc_auc_score: 0.9999970427141709
epoch: 91 validation: roc_auc_score: 0.9999964175968412
epoch: 92 validation: roc_auc_score: 0.9999969946282224
epoch: 93 validation: roc_auc_score: 0.9999965858976608
epoch: 94 validation: roc_auc_score: 0.9999971869720163
epoch: 95 validation: roc_auc_score: 0.9999976918744748
epoch: 96 validation: roc_auc_score: 0.9999957924795115
epoch: 97 validation: roc_auc_score: 0.999997475487707
epoch: 98 validation: roc_auc_score: 0.9999977399604232
epoch: 99 validation: roc_auc_score: 0.999998028476114

In [None]:
test(model, r_test_dataloader, device)

0.7097445971603072

## AutoRec

In [None]:
# Dataset class that reshapes the data into the form AutoRec expects
class RecDataset(Dataset):
    def __init__(self, n_user, n_item, list_user, list_item, data):
        self.n_user = n_user
        self.n_item = n_item
        self.list_user = list_user
        self.list_item = list_item
        
        rating = torch.tensor(data.rating.values)
        user_id = data.user_id
        item_id = data.isbn
        indices = torch.tensor(list(zip(user_id, item_id))).t()

        self.X = (
            torch.sparse_coo_tensor(indices, rating, (n_user, n_item))
            .to_dense()
            .to(dtype=torch.float)
        )

        ones = torch.ones_like(rating)
        self.mask = (
            torch.sparse_coo_tensor(indices, ones, (n_user, n_item))
            .to_dense()
            .to(dtype=torch.float)
        )


    def __getitem__(self, index):
        return self.X[index], self.mask[index]

    def __len__(self):
        return self.n_user

    def get_mat(self):
        return self.X, self.mask
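
`RecDataset` turns (user, item, rating) rows into a dense user-by-item matrix plus a mask that marks which cells were actually observed. The sketch below reproduces that idea with plain Python lists and made-up interactions; `build_matrix_and_mask` and the toy triples are hypothetical and only illustrate why the mask is needed when a rating of 0 is a legitimate observed value.

```python
def build_matrix_and_mask(triples, n_user, n_item):
    """Build a dense rating matrix and a 0/1 observation mask.

    triples: iterable of (user_idx, item_idx, rating).
    Note: this sketch overwrites repeated (user, item) pairs, whereas
    converting a sparse COO tensor to dense sums duplicate indices.
    """
    X = [[0.0] * n_item for _ in range(n_user)]
    mask = [[0.0] * n_item for _ in range(n_user)]
    for u, i, r in triples:
        X[u][i] = float(r)
        mask[u][i] = 1.0  # 1 = observed rating, 0 = unobserved cell
    return X, mask


# Hypothetical interactions: user 1 gave item 0 an observed rating of 0.
triples = [(0, 1, 1), (1, 0, 0), (2, 2, 1)]
X, mask = build_matrix_and_mask(triples, n_user=3, n_item=3)

print(X[1], mask[1])  # the 0 rating is only distinguishable via the mask
```

In `X` alone, the observed 0 at row 1 looks the same as an unobserved cell; the mask is what separates the two.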

In [None]:
class AutoEncoder(nn.Module):
    def __init__(self, d, k, dropout=0.1):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(d, k),
            nn.Dropout(dropout),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(k, d),
            nn.Dropout(dropout),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [None]:
class AutoRec(nn.Module):
    def __init__(self, d, k, learning_rate, batch_size):
        super(AutoRec, self).__init__()
        self.batch_size = batch_size
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoEncoder(d, k).to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), learning_rate, momentum=0.9, weight_decay=1e-4)
        self.feature_size = d # n_user/n_item

    def run(self, trainset, testlist, num_epoch):
        for epoch in range(1, num_epoch + 1):
            train_loader = DataLoader(trainset, self.batch_size, shuffle=True, pin_memory=True)
            self.train(train_loader, epoch)
            self.test(trainset, testlist)
            
    def train(self, train_loader, epoch):
        self.model.train()
        for bid, (feature, mask) in enumerate(train_loader):
            # Move each batch to the device directly; the deprecated Variable
            # wrapper and the preallocated buffers are not needed.
            features, masks = feature.to(self.device), mask.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(features)
            loss = F.mse_loss(output * masks, features * masks)
            loss.backward()
            self.optimizer.step()

        print ("Epoch %d, train end." % epoch)

    def test(self, trainset, testlist):
        self.model.eval()
        with torch.no_grad():
            x_mat, mask = trainset.get_mat()
            features = x_mat.to(self.device)
            xc = self.model(features)
            xc = xc.cpu().numpy()

            rmse = 0.0
            for idx, (i, j, r) in testlist.iterrows():
                i, j = int(i), int(j)
                rmse += (xc[i][j]-r)*(xc[i][j]-r)
            rmse = math.sqrt(rmse / len(testlist))

        print (" Test RMSE = %f" % rmse)
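
`F.mse_loss` averages over every element of its inputs by default, so the loss in `train` divides the masked squared error by the full batch-by-item cell count rather than by the number of observed ratings. As a point of comparison, the sketch below contrasts that average with one taken over observed cells only. It uses plain Python lists and made-up numbers instead of the notebook's tensors, and the helper names are hypothetical.

```python
def masked_squared_error(pred, target, mask):
    """Sum of squared errors over cells where mask == 1."""
    return sum(
        m * (p - t) ** 2
        for p_row, t_row, m_row in zip(pred, target, mask)
        for p, t, m in zip(p_row, t_row, m_row)
    )


def mse_over_all_cells(pred, target, mask):
    """Masked error divided by the total number of cells."""
    n_cells = sum(len(row) for row in pred)
    return masked_squared_error(pred, target, mask) / n_cells


def mse_over_observed(pred, target, mask):
    """Masked error divided by the number of observed cells."""
    n_observed = sum(sum(row) for row in mask)
    return masked_squared_error(pred, target, mask) / max(n_observed, 1.0)


# One user, four items, only the first item observed (toy values).
pred = [[0.5, 0.5, 0.5, 0.5]]
target = [[1.0, 0.0, 0.0, 0.0]]
mask = [[1.0, 0.0, 0.0, 0.0]]

print(mse_over_all_cells(pred, target, mask))  # 0.0625
print(mse_over_observed(pred, target, mask))   # 0.25
```

With tensors, one way to get the second behavior would be to call `F.mse_loss(..., reduction='sum')` and divide by `masks.sum()`. Whether that changes training here is something to check experimentally, not something this sketch establishes.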

In [None]:
train_set = RecDataset(len(user2idx), len(book2idx), ratings['user_id'].unique(), ratings['isbn'].unique(), ae_train)

In [None]:
d = len(book2idx)
mod = AutoRec(d=d, k=embed_dim,
          learning_rate = learning_rate,
          batch_size=batch_size)

mod.run(train_set, ae_test, num_epoch=epochs)

Epoch 1, train end.
 Test RMSE = 0.504176
Epoch 2, train end.
 Test RMSE = 0.504176
Epoch 3, train end.
 Test RMSE = 0.504176
Epoch 4, train end.
 Test RMSE = 0.504176
Epoch 5, train end.
 Test RMSE = 0.504176
Epoch 6, train end.
 Test RMSE = 0.504176
Epoch 7, train end.
 Test RMSE = 0.504176
Epoch 8, train end.
 Test RMSE = 0.504176
Epoch 9, train end.
 Test RMSE = 0.504176
Epoch 10, train end.
 Test RMSE = 0.504176
Epoch 11, train end.
 Test RMSE = 0.504176
Epoch 12, train end.
 Test RMSE = 0.504176
Epoch 13, train end.
 Test RMSE = 0.504176
Epoch 14, train end.
 Test RMSE = 0.504176
Epoch 15, train end.
 Test RMSE = 0.504176
Epoch 16, train end.
 Test RMSE = 0.504176
Epoch 17, train end.
 Test RMSE = 0.504176
Epoch 18, train end.
 Test RMSE = 0.504176
Epoch 19, train end.
 Test RMSE = 0.504176
Epoch 20, train end.
 Test RMSE = 0.504176
Epoch 21, train end.
 Test RMSE = 0.504176
Epoch 22, train end.
 Test RMSE = 0.504176
Epoch 23, train end.
 Test RMSE = 0.504176
Epoch 24, train end.

## WDN

In [None]:
class WideAndDeepModel(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int, mlp_dims: tuple, dropout: float):
        super().__init__()
        self.linear = FeaturesLinear(field_dims)
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        embed_x = self.embedding(x)
        x = self.linear(x) + self.mlp(embed_x.view(-1, self.embed_output_dim))
        return torch.sigmoid(x.squeeze(1))
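
`forward` adds the output of the linear (wide) part to the output of the MLP (deep) part and passes the sum through a sigmoid. The sketch below shows only that final combination step with made-up logits. `FeaturesLinear` and `MultiLayerPerceptron` are defined elsewhere in the notebook, and the two arguments here merely stand in for their outputs.

```python
import math


def wide_and_deep_probability(wide_logit, deep_logit):
    """Sigmoid of the summed logits, mirroring WideAndDeepModel.forward."""
    return 1.0 / (1.0 + math.exp(-(wide_logit + deep_logit)))


# Hypothetical logits for one (user, book) pair.
p_both = wide_and_deep_probability(wide_logit=0.25, deep_logit=0.25)
p_cancel = wide_and_deep_probability(wide_logit=1.5, deep_logit=-1.5)

print(p_both)
print(p_cancel)  # 0.5, because the two parts cancel out
```

Because the two parts are summed before the sigmoid, either one can push the prediction up or down on its own, and both are trained jointly through the same loss.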

In [None]:
field_dims = np.array([len(user2idx), len(book2idx)], dtype=np.uint32)

In [None]:
criterion = torch.nn.BCELoss()
model = WideAndDeepModel(field_dims, embed_dim, mlp_dims=(16, 16), dropout=0.2).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

for epoch in range(epochs):
    train(model, optimizer, r_train_dataloader, criterion, device)
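    # Note: this scores the training dataloader, so the value printed as
    # 'validation' is a training-set AUC, not a held-out one.
    auc_score = test(model, r_train_dataloader, device)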
    print('epoch:', epoch, 'validation: roc_auc_score:', auc_score)

100%|██████████| 53/53 [00:00<00:00, 227.84it/s]
100%|██████████| 53/53 [00:00<00:00, 523.17it/s]


epoch: 0 validation: roc_auc_score: 0.5627054832791694


100%|██████████| 53/53 [00:00<00:00, 225.66it/s]
100%|██████████| 53/53 [00:00<00:00, 467.24it/s]


epoch: 1 validation: roc_auc_score: 0.8867796508536988


100%|██████████| 53/53 [00:00<00:00, 218.96it/s]
100%|██████████| 53/53 [00:00<00:00, 509.78it/s]


epoch: 2 validation: roc_auc_score: 0.9799546871682069


100%|██████████| 53/53 [00:00<00:00, 210.09it/s]
100%|██████████| 53/53 [00:00<00:00, 310.95it/s]


epoch: 3 validation: roc_auc_score: 0.9918841257512949


100%|██████████| 53/53 [00:00<00:00, 216.15it/s]
100%|██████████| 53/53 [00:00<00:00, 505.47it/s]


epoch: 4 validation: roc_auc_score: 0.9956686101083588


100%|██████████| 53/53 [00:00<00:00, 246.08it/s]
100%|██████████| 53/53 [00:00<00:00, 499.89it/s]


epoch: 5 validation: roc_auc_score: 0.997316587690305


100%|██████████| 53/53 [00:00<00:00, 231.67it/s]
100%|██████████| 53/53 [00:00<00:00, 565.83it/s]


epoch: 6 validation: roc_auc_score: 0.998158404346662


100%|██████████| 53/53 [00:00<00:00, 246.82it/s]
100%|██████████| 53/53 [00:00<00:00, 556.65it/s]


epoch: 7 validation: roc_auc_score: 0.9985812481342652


100%|██████████| 53/53 [00:00<00:00, 224.27it/s]
100%|██████████| 53/53 [00:00<00:00, 473.08it/s]


epoch: 8 validation: roc_auc_score: 0.9990423923798004


100%|██████████| 53/53 [00:00<00:00, 231.46it/s]
100%|██████████| 53/53 [00:00<00:00, 534.08it/s]


epoch: 9 validation: roc_auc_score: 0.9991933341719522


100%|██████████| 53/53 [00:00<00:00, 232.17it/s]
100%|██████████| 53/53 [00:00<00:00, 478.19it/s]


epoch: 10 validation: roc_auc_score: 0.999398973730454


100%|██████████| 53/53 [00:00<00:00, 211.19it/s]
100%|██████████| 53/53 [00:00<00:00, 295.38it/s]


epoch: 11 validation: roc_auc_score: 0.9996233668088471


100%|██████████| 53/53 [00:00<00:00, 215.31it/s]
100%|██████████| 53/53 [00:00<00:00, 506.02it/s]


epoch: 12 validation: roc_auc_score: 0.9997358157992731


100%|██████████| 53/53 [00:00<00:00, 223.98it/s]
100%|██████████| 53/53 [00:00<00:00, 505.47it/s]


epoch: 13 validation: roc_auc_score: 0.9997284105632134


100%|██████████| 53/53 [00:00<00:00, 200.60it/s]
100%|██████████| 53/53 [00:00<00:00, 531.68it/s]


epoch: 14 validation: roc_auc_score: 0.9997822187395172


100%|██████████| 53/53 [00:00<00:00, 224.29it/s]
100%|██████████| 53/53 [00:00<00:00, 559.54it/s]


epoch: 15 validation: roc_auc_score: 0.9998168887083422


100%|██████████| 53/53 [00:00<00:00, 223.30it/s]
100%|██████████| 53/53 [00:00<00:00, 528.32it/s]


epoch: 16 validation: roc_auc_score: 0.9999188790049824


100%|██████████| 53/53 [00:00<00:00, 225.28it/s]
100%|██████████| 53/53 [00:00<00:00, 484.89it/s]


epoch: 17 validation: roc_auc_score: 0.999909742674779


100%|██████████| 53/53 [00:00<00:00, 218.34it/s]
100%|██████████| 53/53 [00:00<00:00, 523.11it/s]


epoch: 18 validation: roc_auc_score: 0.9999235433419811


100%|██████████| 53/53 [00:00<00:00, 167.63it/s]
100%|██████████| 53/53 [00:00<00:00, 522.32it/s]


epoch: 19 validation: roc_auc_score: 0.9999265487137586


100%|██████████| 53/53 [00:00<00:00, 221.36it/s]
100%|██████████| 53/53 [00:00<00:00, 510.98it/s]


epoch: 20 validation: roc_auc_score: 0.999932727758133


100%|██████████| 53/53 [00:00<00:00, 230.05it/s]
100%|██████████| 53/53 [00:00<00:00, 493.33it/s]


epoch: 21 validation: roc_auc_score: 0.9999431864519186


100%|██████████| 53/53 [00:00<00:00, 213.33it/s]
100%|██████████| 53/53 [00:00<00:00, 535.14it/s]


epoch: 22 validation: roc_auc_score: 0.9999626372180624


100%|██████████| 53/53 [00:00<00:00, 219.76it/s]
100%|██████████| 53/53 [00:00<00:00, 528.42it/s]


epoch: 23 validation: roc_auc_score: 0.9999588384281357


100%|██████████| 53/53 [00:00<00:00, 225.64it/s]
100%|██████████| 53/53 [00:00<00:00, 493.42it/s]


epoch: 24 validation: roc_auc_score: 0.9999591269438264


100%|██████████| 53/53 [00:00<00:00, 224.98it/s]
100%|██████████| 53/53 [00:00<00:00, 502.40it/s]


epoch: 25 validation: roc_auc_score: 0.9999685037037721


100%|██████████| 53/53 [00:00<00:00, 240.15it/s]
100%|██████████| 53/53 [00:00<00:00, 514.96it/s]


epoch: 26 validation: roc_auc_score: 0.9999817513825672


100%|██████████| 53/53 [00:00<00:00, 177.79it/s]
100%|██████████| 53/53 [00:00<00:00, 508.93it/s]


epoch: 27 validation: roc_auc_score: 0.9999823764998969


100%|██████████| 53/53 [00:00<00:00, 226.12it/s]
100%|██████████| 53/53 [00:00<00:00, 516.70it/s]


epoch: 28 validation: roc_auc_score: 0.9999713648177042


100%|██████████| 53/53 [00:00<00:00, 205.83it/s]
100%|██████████| 53/53 [00:00<00:00, 507.49it/s]


epoch: 29 validation: roc_auc_score: 0.9999788902686351


100%|██████████| 53/53 [00:00<00:00, 228.61it/s]
100%|██████████| 53/53 [00:00<00:00, 545.07it/s]


epoch: 30 validation: roc_auc_score: 0.9999872812666377


100%|██████████| 53/53 [00:00<00:00, 211.39it/s]
100%|██████████| 53/53 [00:00<00:00, 563.89it/s]


epoch: 31 validation: roc_auc_score: 0.9999730959118481


100%|██████████| 53/53 [00:00<00:00, 233.62it/s]
100%|██████████| 53/53 [00:00<00:00, 531.78it/s]


epoch: 32 validation: roc_auc_score: 0.9999890364037558


100%|██████████| 53/53 [00:00<00:00, 206.48it/s]
100%|██████████| 53/53 [00:00<00:00, 336.60it/s]


epoch: 33 validation: roc_auc_score: 0.9999789864405318


100%|██████████| 53/53 [00:00<00:00, 109.28it/s]
100%|██████████| 53/53 [00:00<00:00, 310.56it/s]


epoch: 34 validation: roc_auc_score: 0.9999884593723745


100%|██████████| 53/53 [00:00<00:00, 145.78it/s]
100%|██████████| 53/53 [00:00<00:00, 496.48it/s]


epoch: 35 validation: roc_auc_score: 0.9999943979870068


100%|██████████| 53/53 [00:00<00:00, 192.23it/s]
100%|██████████| 53/53 [00:00<00:00, 511.57it/s]


epoch: 36 validation: roc_auc_score: 0.9999795153859647


100%|██████████| 53/53 [00:00<00:00, 135.25it/s]
100%|██████████| 53/53 [00:00<00:00, 518.58it/s]


epoch: 37 validation: roc_auc_score: 0.999994638416749


100%|██████████| 53/53 [00:00<00:00, 232.37it/s]
100%|██████████| 53/53 [00:00<00:00, 558.68it/s]


epoch: 38 validation: roc_auc_score: 0.9999905270681574


100%|██████████| 53/53 [00:00<00:00, 208.30it/s]
100%|██████████| 53/53 [00:00<00:00, 534.52it/s]


epoch: 39 validation: roc_auc_score: 0.9999943979870068


100%|██████████| 53/53 [00:00<00:00, 213.71it/s]
100%|██████████| 53/53 [00:00<00:00, 496.22it/s]


epoch: 40 validation: roc_auc_score: 0.9999892527905238


100%|██████████| 53/53 [00:00<00:00, 211.22it/s]
100%|██████████| 53/53 [00:00<00:00, 487.84it/s]


epoch: 41 validation: roc_auc_score: 0.9999918975176879


100%|██████████| 53/53 [00:00<00:00, 224.42it/s]
100%|██████████| 53/53 [00:00<00:00, 311.32it/s]


epoch: 42 validation: roc_auc_score: 0.9999805492338563


100%|██████████| 53/53 [00:00<00:00, 247.89it/s]
100%|██████████| 53/53 [00:00<00:00, 546.32it/s]


epoch: 43 validation: roc_auc_score: 0.9999883632004776


100%|██████████| 53/53 [00:00<00:00, 221.44it/s]
100%|██████████| 53/53 [00:00<00:00, 514.60it/s]


epoch: 44 validation: roc_auc_score: 0.9999991104099538


100%|██████████| 53/53 [00:00<00:00, 219.24it/s]
100%|██████████| 53/53 [00:00<00:00, 530.76it/s]


epoch: 45 validation: roc_auc_score: 0.9999631902064695


100%|██████████| 53/53 [00:00<00:00, 231.59it/s]
100%|██████████| 53/53 [00:00<00:00, 512.17it/s]


epoch: 46 validation: roc_auc_score: 0.9999970908001193


100%|██████████| 53/53 [00:00<00:00, 225.61it/s]
100%|██████████| 53/53 [00:00<00:00, 529.70it/s]


epoch: 47 validation: roc_auc_score: 0.999997259100939


100%|██████████| 53/53 [00:00<00:00, 240.23it/s]
100%|██████████| 53/53 [00:00<00:00, 559.42it/s]


epoch: 48 validation: roc_auc_score: 0.9999982929488304


100%|██████████| 53/53 [00:00<00:00, 216.44it/s]
100%|██████████| 53/53 [00:00<00:00, 445.00it/s]


epoch: 49 validation: roc_auc_score: 0.9999980525190881


100%|██████████| 53/53 [00:00<00:00, 167.79it/s]
100%|██████████| 53/53 [00:00<00:00, 545.72it/s]


epoch: 50 validation: roc_auc_score: 0.9999953597059756


100%|██████████| 53/53 [00:00<00:00, 218.67it/s]
100%|██████████| 53/53 [00:00<00:00, 556.54it/s]


epoch: 51 validation: roc_auc_score: 0.9999982689058562


100%|██████████| 53/53 [00:00<00:00, 215.90it/s]
100%|██████████| 53/53 [00:00<00:00, 529.64it/s]


epoch: 52 validation: roc_auc_score: 0.9999979563471914


100%|██████████| 53/53 [00:00<00:00, 245.95it/s]
100%|██████████| 53/53 [00:00<00:00, 517.75it/s]


epoch: 53 validation: roc_auc_score: 0.9999911762284613


100%|██████████| 53/53 [00:00<00:00, 215.16it/s]
100%|██████████| 53/53 [00:00<00:00, 514.77it/s]


epoch: 54 validation: roc_auc_score: 0.9999971388860679


100%|██████████| 53/53 [00:00<00:00, 232.34it/s]
100%|██████████| 53/53 [00:00<00:00, 546.86it/s]


epoch: 55 validation: roc_auc_score: 0.9999994229686189


100%|██████████| 53/53 [00:00<00:00, 221.23it/s]
100%|██████████| 53/53 [00:00<00:00, 514.05it/s]


epoch: 56 validation: roc_auc_score: 0.9999974514447327


100%|██████████| 53/53 [00:00<00:00, 225.22it/s]
100%|██████████| 53/53 [00:00<00:00, 298.69it/s]


epoch: 57 validation: roc_auc_score: 0.9999884834153486


100%|██████████| 53/53 [00:00<00:00, 224.73it/s]
100%|██████████| 53/53 [00:00<00:00, 501.66it/s]


epoch: 58 validation: roc_auc_score: 0.999999567226464


100%|██████████| 53/53 [00:00<00:00, 219.18it/s]
100%|██████████| 53/53 [00:00<00:00, 537.54it/s]


epoch: 59 validation: roc_auc_score: 0.9999974754877069


100%|██████████| 53/53 [00:00<00:00, 229.52it/s]
100%|██████████| 53/53 [00:00<00:00, 516.01it/s]


epoch: 60 validation: roc_auc_score: 0.9999929554085536


100%|██████████| 53/53 [00:00<00:00, 223.73it/s]
100%|██████████| 53/53 [00:00<00:00, 522.63it/s]


epoch: 61 validation: roc_auc_score: 0.9999962973819702


100%|██████████| 53/53 [00:00<00:00, 232.82it/s]
100%|██████████| 53/53 [00:00<00:00, 537.99it/s]


epoch: 62 validation: roc_auc_score: 0.999997812089346


100%|██████████| 53/53 [00:00<00:00, 233.53it/s]
100%|██████████| 53/53 [00:00<00:00, 495.46it/s]


epoch: 63 validation: roc_auc_score: 0.9999944941589036


100%|██████████| 53/53 [00:00<00:00, 215.11it/s]
100%|██████████| 53/53 [00:00<00:00, 546.42it/s]


epoch: 64 validation: roc_auc_score: 0.9999978842182687


100%|██████████| 53/53 [00:00<00:00, 231.48it/s]
100%|██████████| 53/53 [00:00<00:00, 312.31it/s]


epoch: 65 validation: roc_auc_score: 0.9999997595702578


100%|██████████| 53/53 [00:00<00:00, 244.81it/s]
100%|██████████| 53/53 [00:00<00:00, 541.78it/s]


epoch: 66 validation: roc_auc_score: 0.999999687441335


100%|██████████| 53/53 [00:00<00:00, 232.11it/s]
100%|██████████| 53/53 [00:00<00:00, 527.99it/s]


epoch: 67 validation: roc_auc_score: 0.9999899019508277


100%|██████████| 53/53 [00:00<00:00, 216.57it/s]
100%|██████████| 53/53 [00:00<00:00, 541.92it/s]


epoch: 68 validation: roc_auc_score: 0.9999998797851288


100%|██████████| 53/53 [00:00<00:00, 227.46it/s]
100%|██████████| 53/53 [00:00<00:00, 522.24it/s]


epoch: 69 validation: roc_auc_score: 0.9999996393553868


100%|██████████| 53/53 [00:00<00:00, 228.92it/s]
100%|██████████| 53/53 [00:00<00:00, 560.33it/s]


epoch: 70 validation: roc_auc_score: 0.9999999519140516


100%|██████████| 53/53 [00:00<00:00, 220.59it/s]
100%|██████████| 53/53 [00:00<00:00, 548.59it/s]


epoch: 71 validation: roc_auc_score: 0.9999973552728358


100%|██████████| 53/53 [00:00<00:00, 227.43it/s]
100%|██████████| 53/53 [00:00<00:00, 556.99it/s]


epoch: 72 validation: roc_auc_score: 0.9999986776364178


100%|██████████| 53/53 [00:00<00:00, 178.44it/s]
100%|██████████| 53/53 [00:00<00:00, 520.63it/s]


epoch: 73 validation: roc_auc_score: 0.9999867763641791


100%|██████████| 53/53 [00:00<00:00, 246.25it/s]
100%|██████████| 53/53 [00:00<00:00, 575.35it/s]


epoch: 74 validation: roc_auc_score: 0.999999687441335


100%|██████████| 53/53 [00:00<00:00, 228.62it/s]
100%|██████████| 53/53 [00:00<00:00, 531.66it/s]


epoch: 75 validation: roc_auc_score: 0.9999993027537477


100%|██████████| 53/53 [00:00<00:00, 218.13it/s]
100%|██████████| 53/53 [00:00<00:00, 506.31it/s]


epoch: 76 validation: roc_auc_score: 0.9999870167939213


100%|██████████| 53/53 [00:00<00:00, 215.55it/s]
100%|██████████| 53/53 [00:00<00:00, 514.41it/s]


epoch: 77 validation: roc_auc_score: 0.9999992546677992


100%|██████████| 53/53 [00:00<00:00, 213.30it/s]
100%|██████████| 53/53 [00:00<00:00, 539.40it/s]


epoch: 78 validation: roc_auc_score: 0.9999999759570257


100%|██████████| 53/53 [00:00<00:00, 230.22it/s]
100%|██████████| 53/53 [00:00<00:00, 497.72it/s]


epoch: 79 validation: roc_auc_score: 0.9999958646084343


100%|██████████| 53/53 [00:00<00:00, 232.20it/s]
100%|██████████| 53/53 [00:00<00:00, 495.33it/s]


epoch: 80 validation: roc_auc_score: 0.9999998076562062


100%|██████████| 53/53 [00:00<00:00, 164.56it/s]
100%|██████████| 53/53 [00:00<00:00, 531.94it/s]


epoch: 81 validation: roc_auc_score: 0.9999948307605427


100%|██████████| 53/53 [00:00<00:00, 226.43it/s]
100%|██████████| 53/53 [00:00<00:00, 542.55it/s]


epoch: 82 validation: roc_auc_score: 0.9999784094091506


100%|██████████| 53/53 [00:00<00:00, 218.29it/s]
100%|██████████| 53/53 [00:00<00:00, 508.12it/s]


epoch: 83 validation: roc_auc_score: 0.9999997355272836


100%|██████████| 53/53 [00:00<00:00, 240.73it/s]
100%|██████████| 53/53 [00:00<00:00, 552.61it/s]


epoch: 84 validation: roc_auc_score: 0.9999999278710773


100%|██████████| 53/53 [00:00<00:00, 235.32it/s]
100%|██████████| 53/53 [00:00<00:00, 535.79it/s]


epoch: 85 validation: roc_auc_score: 0.9999998797851289


100%|██████████| 53/53 [00:00<00:00, 234.10it/s]
100%|██████████| 53/53 [00:00<00:00, 565.06it/s]


epoch: 86 validation: roc_auc_score: 0.9999998076562062


100%|██████████| 53/53 [00:00<00:00, 221.16it/s]
100%|██████████| 53/53 [00:00<00:00, 522.22it/s]


epoch: 87 validation: roc_auc_score: 0.9999979803901655


100%|██████████| 53/53 [00:00<00:00, 224.71it/s]
100%|██████████| 53/53 [00:00<00:00, 303.06it/s]


epoch: 88 validation: roc_auc_score: 0.9999998797851288


100%|██████████| 53/53 [00:00<00:00, 230.79it/s]
100%|██████████| 53/53 [00:00<00:00, 521.87it/s]


epoch: 89 validation: roc_auc_score: 0.9999996153124125


100%|██████████| 53/53 [00:00<00:00, 216.99it/s]
100%|██████████| 53/53 [00:00<00:00, 556.77it/s]


epoch: 90 validation: roc_auc_score: 0.9999999519140516


100%|██████████| 53/53 [00:00<00:00, 232.63it/s]
100%|██████████| 53/53 [00:00<00:00, 522.01it/s]


epoch: 91 validation: roc_auc_score: 1.0


100%|██████████| 53/53 [00:00<00:00, 215.84it/s]
100%|██████████| 53/53 [00:00<00:00, 530.61it/s]


epoch: 92 validation: roc_auc_score: 0.9999994710545672


100%|██████████| 53/53 [00:00<00:00, 241.46it/s]
100%|██████████| 53/53 [00:00<00:00, 493.83it/s]


epoch: 93 validation: roc_auc_score: 0.9999999038281031


100%|██████████| 53/53 [00:00<00:00, 217.06it/s]
100%|██████████| 53/53 [00:00<00:00, 522.80it/s]


epoch: 94 validation: roc_auc_score: 0.9999999759570257


100%|██████████| 53/53 [00:00<00:00, 226.98it/s]
100%|██████████| 53/53 [00:00<00:00, 535.43it/s]


epoch: 95 validation: roc_auc_score: 0.9999998557421547


100%|██████████| 53/53 [00:00<00:00, 246.90it/s]
100%|██████████| 53/53 [00:00<00:00, 316.06it/s]


epoch: 96 validation: roc_auc_score: 0.999999350839696


100%|██████████| 53/53 [00:00<00:00, 229.44it/s]
100%|██████████| 53/53 [00:00<00:00, 481.26it/s]


epoch: 97 validation: roc_auc_score: 0.9999986535934438


100%|██████████| 53/53 [00:00<00:00, 213.82it/s]
100%|██████████| 53/53 [00:00<00:00, 535.46it/s]


epoch: 98 validation: roc_auc_score: 0.9999998316991805


100%|██████████| 53/53 [00:00<00:00, 231.57it/s]
100%|██████████| 53/53 [00:00<00:00, 578.72it/s]

epoch: 99 validation: roc_auc_score: 0.9999998076562062





In [None]:
test(model, r_test_dataloader, device)

100%|██████████| 14/14 [00:00<00:00, 560.21it/s]


0.7172484886275323
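
The training loop above passes `r_train_dataloader` to `test`, so the per-epoch score labeled "validation" is measured on the same data the model was fit on, while the final call uses `r_test_dataloader`. The sketch below computes ROC AUC from its pairwise definition (the fraction of positive/negative pairs ranked correctly, ties counted as half) on two made-up score lists, to illustrate how a score on fitted data and a score on held-out data can differ. The function name and all numbers are hypothetical.

```python
def pairwise_auc(labels, scores):
    """ROC AUC as the share of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct pair
    return wins / (len(pos) * len(neg))


# Toy scores: perfectly separated on the data the model has seen...
seen_auc = pairwise_auc([1, 1, 0, 0], [0.9, 0.8, 0.2, 0.1])
# ...but one positive is ranked below a negative on held-out data.
heldout_auc = pairwise_auc([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1])

print(seen_auc)     # 1.0
print(heldout_auc)  # 0.75
```

Monitoring a split that is not used for fitting would make the per-epoch number comparable to the final test score.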

## DCN

In [None]:
class CrossNetwork(nn.Module):

    def __init__(self, input_dim: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([
            torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)
        ])
        self.b = torch.nn.ParameterList([
            torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor):
        """
        :param x: Float tensor of size ``(batch_size, num_fields * embed_dim)``
        """
        x0 = x
        for i in range(self.num_layers):
            xw = self.w[i](x)
            x = x0 * xw + self.b[i] + x
        return x
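
Each pass through the loop in `CrossNetwork.forward` computes `x_next = x0 * (w · x) + b + x`, where `w · x` is one scalar per example, so every layer rescales the original input `x0` and adds it to the running vector. The worked example below applies two such layers to a single 2-dimensional vector using plain Python lists. The weights are made up, and the same `w` and `b` are reused in both layers for brevity, whereas the module keeps separate parameters per layer.

```python
def cross_layer(x0, x, w, b):
    """One cross layer: x0 * (w . x) + b + x, for a single example."""
    xw = sum(wi * xi for wi, xi in zip(w, x))  # scalar projection of x
    return [x0i * xw + bi + xi for x0i, bi, xi in zip(x0, b, x)]


x0 = [1.0, 2.0]   # flattened embedding of one example (toy values)
w = [0.5, 0.25]   # hypothetical layer weights
b = [0.0, 0.0]    # biases start at zero, as in the module

x1 = cross_layer(x0, x0, w, b)  # w . x0 = 1.0
x2 = cross_layer(x0, x1, w, b)  # w . x1 = 2.0

print(x1)  # [2.0, 4.0]
print(x2)  # [4.0, 8.0]
```

The output keeps the same width as the input, which is why the cross network can be followed by an MLP or concatenated with one without any reshaping.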

The stacked DCN model structure:

In [None]:
class StackedDeepCrossNetworkModel(nn.Module):
  
    def __init__(self, field_dims: np.ndarray, embed_dim: int, num_layers: int, mlp_dims: tuple, dropout: float):
        super().__init__()
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.cn = CrossNetwork(self.embed_output_dim, num_layers)
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False)
        # The final layer consumes the MLP output, whose width is mlp_dims[-1]
        self.linear = torch.nn.Linear(mlp_dims[-1], 1, bias=False)

    def forward(self, x: torch.Tensor):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        embed_x = self.embedding(x).view(-1, self.embed_output_dim)
        x_l1 = self.cn(embed_x)
        x_out = self.mlp(x_l1)
        p = self.linear(x_out)
        return torch.sigmoid(p.squeeze(1))

The parallel DCN structure:

In [None]:
class ParallelDeepCrossNetworkModel(nn.Module):

    def __init__(self, field_dims: np.ndarray, embed_dim: int, num_layers: int, mlp_dims: tuple, dropout: float):
        super().__init__()
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.cn = CrossNetwork(self.embed_output_dim, num_layers)
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False)
        self.linear = torch.nn.Linear(mlp_dims[-1] + self.embed_output_dim, 1)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        embed_x = self.embedding(x).view(-1, self.embed_output_dim)
        x_l1 = self.cn(embed_x)
        h_l2 = self.mlp(embed_x)
        x_stack = torch.cat([x_l1, h_l2], dim=1)
        p = self.linear(x_stack)
        return torch.sigmoid(p.squeeze(1))
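
The two variants feed different tensors into their final linear layer. The stacked model passes the cross-network output through the MLP first, so the head sees a vector of width `mlp_dims[-1]`; the parallel model concatenates the cross-network output with the MLP output. The sketch below only does this width bookkeeping. `head_input_width` is a hypothetical helper, and the embedding sizes are example values, not the notebook's settings.

```python
def head_input_width(num_fields, embed_dim, mlp_dims, parallel):
    """Width of the vector reaching the final linear layer."""
    embed_output_dim = num_fields * embed_dim  # flattened embeddings
    if parallel:
        # cross-network output (embed_output_dim) + MLP output
        return embed_output_dim + mlp_dims[-1]
    # stacked: cross network -> MLP -> head
    return mlp_dims[-1]


mlp_dims = (16, 16)
stacked_w = head_input_width(2, 8, mlp_dims, parallel=False)
parallel_w = head_input_width(2, 8, mlp_dims, parallel=True)
stacked_small = head_input_width(2, 4, mlp_dims, parallel=False)

print(stacked_w, parallel_w)  # 16 32
print(stacked_small)          # 16, although embed_output_dim is only 8
```

With two fields, the stacked head width coincides with `embed_output_dim` only when `2 * embed_dim == mlp_dims[-1]`; for other settings the two differ.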

In [None]:
field_dims = np.array([len(user2idx), len(book2idx)], dtype=np.uint32)

In [None]:
criterion = torch.nn.BCELoss()
model = StackedDeepCrossNetworkModel(field_dims, embed_dim, num_layers=3, mlp_dims=(16, 16), dropout=0.2).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, amsgrad=True, weight_decay=weight_decay)

for epoch in range(epochs):
    train(model, optimizer, r_train_dataloader, criterion, device)
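    # Note: this scores the training dataloader, so the value printed as
    # 'validation' is a training-set AUC, not a held-out one.
    auc_score = test(model, r_train_dataloader, device)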
    print('epoch:', epoch, 'validation: roc_auc_score:', auc_score)

100%|██████████| 53/53 [00:00<00:00, 192.59it/s]
100%|██████████| 53/53 [00:00<00:00, 489.58it/s]


epoch: 0 validation: roc_auc_score: 0.7537203737663069


100%|██████████| 53/53 [00:00<00:00, 196.11it/s]
100%|██████████| 53/53 [00:00<00:00, 466.85it/s]


epoch: 1 validation: roc_auc_score: 0.8838325472242482


100%|██████████| 53/53 [00:00<00:00, 196.07it/s]
100%|██████████| 53/53 [00:00<00:00, 447.68it/s]


epoch: 2 validation: roc_auc_score: 0.941011993308744


100%|██████████| 53/53 [00:00<00:00, 188.92it/s]
100%|██████████| 53/53 [00:00<00:00, 482.99it/s]


epoch: 3 validation: roc_auc_score: 0.9485112494191217


100%|██████████| 53/53 [00:00<00:00, 188.23it/s]
100%|██████████| 53/53 [00:00<00:00, 499.05it/s]


epoch: 4 validation: roc_auc_score: 0.9535469181330958


100%|██████████| 53/53 [00:00<00:00, 182.48it/s]
100%|██████████| 53/53 [00:00<00:00, 454.50it/s]


epoch: 5 validation: roc_auc_score: 0.9498224450396843


100%|██████████| 53/53 [00:00<00:00, 146.67it/s]
100%|██████████| 53/53 [00:00<00:00, 462.75it/s]


epoch: 6 validation: roc_auc_score: 0.920663654851372


100%|██████████| 53/53 [00:00<00:00, 185.95it/s]
100%|██████████| 53/53 [00:00<00:00, 491.06it/s]


epoch: 7 validation: roc_auc_score: 0.9632434977218799


100%|██████████| 53/53 [00:00<00:00, 192.97it/s]
100%|██████████| 53/53 [00:00<00:00, 470.99it/s]


epoch: 8 validation: roc_auc_score: 0.9885386420605099


100%|██████████| 53/53 [00:00<00:00, 193.29it/s]
100%|██████████| 53/53 [00:00<00:00, 447.79it/s]


epoch: 9 validation: roc_auc_score: 0.8804176154024295


100%|██████████| 53/53 [00:00<00:00, 187.77it/s]
100%|██████████| 53/53 [00:00<00:00, 492.24it/s]


epoch: 10 validation: roc_auc_score: 0.9400714201357793


100%|██████████| 53/53 [00:00<00:00, 197.21it/s]
100%|██████████| 53/53 [00:00<00:00, 483.00it/s]


epoch: 11 validation: roc_auc_score: 0.9000836863803667


100%|██████████| 53/53 [00:00<00:00, 178.89it/s]
100%|██████████| 53/53 [00:00<00:00, 471.04it/s]


epoch: 12 validation: roc_auc_score: 0.9326806579465684


100%|██████████| 53/53 [00:00<00:00, 189.39it/s]
100%|██████████| 53/53 [00:00<00:00, 270.79it/s]


epoch: 13 validation: roc_auc_score: 0.9446809992606303


100%|██████████| 53/53 [00:00<00:00, 184.29it/s]
100%|██████████| 53/53 [00:00<00:00, 478.30it/s]


epoch: 14 validation: roc_auc_score: 0.941896113578241


100%|██████████| 53/53 [00:00<00:00, 195.14it/s]
100%|██████████| 53/53 [00:00<00:00, 436.45it/s]


epoch: 15 validation: roc_auc_score: 0.9850383937446722


100%|██████████| 53/53 [00:00<00:00, 195.59it/s]
100%|██████████| 53/53 [00:00<00:00, 461.79it/s]


epoch: 16 validation: roc_auc_score: 0.9901906949265863


100%|██████████| 53/53 [00:00<00:00, 189.22it/s]
100%|██████████| 53/53 [00:00<00:00, 493.28it/s]


epoch: 17 validation: roc_auc_score: 0.9138350414423937


100%|██████████| 53/53 [00:00<00:00, 189.85it/s]
100%|██████████| 53/53 [00:00<00:00, 192.01it/s]


epoch: 18 validation: roc_auc_score: 0.976070532661515


100%|██████████| 53/53 [00:00<00:00, 188.16it/s]
100%|██████████| 53/53 [00:00<00:00, 481.15it/s]


epoch: 19 validation: roc_auc_score: 0.9872957405082186


100%|██████████| 53/53 [00:00<00:00, 111.04it/s]
100%|██████████| 53/53 [00:00<00:00, 452.00it/s]


epoch: 20 validation: roc_auc_score: 0.9874212568551328


100%|██████████| 53/53 [00:00<00:00, 121.86it/s]
100%|██████████| 53/53 [00:00<00:00, 266.42it/s]


epoch: 21 validation: roc_auc_score: 0.9888732601542137


100%|██████████| 53/53 [00:00<00:00, 165.81it/s]
100%|██████████| 53/53 [00:00<00:00, 473.62it/s]


epoch: 22 validation: roc_auc_score: 0.9942794431224016


100%|██████████| 53/53 [00:00<00:00, 171.07it/s]
100%|██████████| 53/53 [00:00<00:00, 467.96it/s]


epoch: 23 validation: roc_auc_score: 0.9882770304580245


100%|██████████| 53/53 [00:00<00:00, 164.36it/s]
100%|██████████| 53/53 [00:00<00:00, 463.69it/s]


epoch: 24 validation: roc_auc_score: 0.9881299595847219


100%|██████████| 53/53 [00:00<00:00, 165.26it/s]
100%|██████████| 53/53 [00:00<00:00, 467.29it/s]


epoch: 25 validation: roc_auc_score: 0.9716457358150298


100%|██████████| 53/53 [00:00<00:00, 163.70it/s]
100%|██████████| 53/53 [00:00<00:00, 476.77it/s]


epoch: 26 validation: roc_auc_score: 0.9568515888751428


100%|██████████| 53/53 [00:00<00:00, 159.40it/s]
100%|██████████| 53/53 [00:00<00:00, 468.51it/s]


epoch: 27 validation: roc_auc_score: 0.9871623500872471


100%|██████████| 53/53 [00:00<00:00, 173.54it/s]
100%|██████████| 53/53 [00:00<00:00, 449.32it/s]


epoch: 28 validation: roc_auc_score: 0.9934857484789452


100%|██████████| 53/53 [00:00<00:00, 136.76it/s]
100%|██████████| 53/53 [00:00<00:00, 475.71it/s]


epoch: 29 validation: roc_auc_score: 0.9923969022839672


100%|██████████| 53/53 [00:00<00:00, 157.03it/s]
100%|██████████| 53/53 [00:00<00:00, 467.58it/s]


epoch: 30 validation: roc_auc_score: 0.9937591892247468


100%|██████████| 53/53 [00:00<00:00, 164.91it/s]
100%|██████████| 53/53 [00:00<00:00, 468.23it/s]


epoch: 31 validation: roc_auc_score: 0.996088712996901


100%|██████████| 53/53 [00:00<00:00, 164.88it/s]
100%|██████████| 53/53 [00:00<00:00, 457.52it/s]


epoch: 32 validation: roc_auc_score: 0.9933339531612088


100%|██████████| 53/53 [00:00<00:00, 168.78it/s]
100%|██████████| 53/53 [00:00<00:00, 468.32it/s]


epoch: 33 validation: roc_auc_score: 0.9727031698641898


100%|██████████| 53/53 [00:00<00:00, 165.81it/s]
100%|██████████| 53/53 [00:00<00:00, 468.62it/s]


epoch: 34 validation: roc_auc_score: 0.9821146118002149


100%|██████████| 53/53 [00:00<00:00, 167.47it/s]
100%|██████████| 53/53 [00:00<00:00, 448.75it/s]


epoch: 35 validation: roc_auc_score: 0.999564052791447


100%|██████████| 53/53 [00:00<00:00, 171.31it/s]
100%|██████████| 53/53 [00:00<00:00, 291.34it/s]


epoch: 36 validation: roc_auc_score: 0.999144731299567


100%|██████████| 53/53 [00:00<00:00, 164.68it/s]
100%|██████████| 53/53 [00:00<00:00, 476.75it/s]


epoch: 37 validation: roc_auc_score: 0.9960415046170203


100%|██████████| 53/53 [00:00<00:00, 170.79it/s]
100%|██████████| 53/53 [00:00<00:00, 473.70it/s]


epoch: 38 validation: roc_auc_score: 0.984828017720249


100%|██████████| 53/53 [00:00<00:00, 173.97it/s]
100%|██████████| 53/53 [00:00<00:00, 471.01it/s]


epoch: 39 validation: roc_auc_score: 0.9982618612647298


100%|██████████| 53/53 [00:00<00:00, 172.75it/s]
100%|██████████| 53/53 [00:00<00:00, 483.85it/s]


epoch: 40 validation: roc_auc_score: 0.8912649518448079


100%|██████████| 53/53 [00:00<00:00, 168.35it/s]
100%|██████████| 53/53 [00:00<00:00, 480.67it/s]


epoch: 41 validation: roc_auc_score: 0.9987243519168212


100%|██████████| 53/53 [00:00<00:00, 164.42it/s]
100%|██████████| 53/53 [00:00<00:00, 459.82it/s]


epoch: 42 validation: roc_auc_score: 0.9986985177410221


100%|██████████| 53/53 [00:00<00:00, 166.69it/s]
100%|██████████| 53/53 [00:00<00:00, 478.80it/s]


epoch: 43 validation: roc_auc_score: 0.9983385222880294


100%|██████████| 53/53 [00:00<00:00, 135.65it/s]
100%|██████████| 53/53 [00:00<00:00, 439.03it/s]


epoch: 44 validation: roc_auc_score: 0.9692587252915162


100%|██████████| 53/53 [00:00<00:00, 166.25it/s]
100%|██████████| 53/53 [00:00<00:00, 457.86it/s]


epoch: 45 validation: roc_auc_score: 0.990803826833652


100%|██████████| 53/53 [00:00<00:00, 168.41it/s]
100%|██████████| 53/53 [00:00<00:00, 489.46it/s]


epoch: 46 validation: roc_auc_score: 0.9870611171442948


100%|██████████| 53/53 [00:00<00:00, 170.91it/s]
100%|██████████| 53/53 [00:00<00:00, 469.48it/s]


epoch: 47 validation: roc_auc_score: 0.9910511689309378


100%|██████████| 53/53 [00:00<00:00, 167.36it/s]
100%|██████████| 53/53 [00:00<00:00, 480.10it/s]


epoch: 48 validation: roc_auc_score: 0.9997479094153057


100%|██████████| 53/53 [00:00<00:00, 170.86it/s]
100%|██████████| 53/53 [00:00<00:00, 476.33it/s]


epoch: 49 validation: roc_auc_score: 0.9978430927612568


100%|██████████| 53/53 [00:00<00:00, 166.71it/s]
100%|██████████| 53/53 [00:00<00:00, 458.84it/s]


epoch: 50 validation: roc_auc_score: 0.9954788148698679


100%|██████████| 53/53 [00:00<00:00, 164.20it/s]
100%|██████████| 53/53 [00:00<00:00, 277.41it/s]


epoch: 51 validation: roc_auc_score: 0.9560943313591166


100%|██████████| 53/53 [00:00<00:00, 166.88it/s]
100%|██████████| 53/53 [00:00<00:00, 478.88it/s]


epoch: 52 validation: roc_auc_score: 0.967352213607785


100%|██████████| 53/53 [00:00<00:00, 165.86it/s]
100%|██████████| 53/53 [00:00<00:00, 459.89it/s]


epoch: 53 validation: roc_auc_score: 0.9926639836630875


100%|██████████| 53/53 [00:00<00:00, 163.34it/s]
100%|██████████| 53/53 [00:00<00:00, 471.97it/s]


epoch: 54 validation: roc_auc_score: 0.9996182456553385


100%|██████████| 53/53 [00:00<00:00, 98.97it/s]
100%|██████████| 53/53 [00:00<00:00, 483.06it/s]


epoch: 55 validation: roc_auc_score: 0.9920679583321792


100%|██████████| 53/53 [00:00<00:00, 110.78it/s]
100%|██████████| 53/53 [00:00<00:00, 363.28it/s]


epoch: 56 validation: roc_auc_score: 0.9815694734242427


100%|██████████| 53/53 [00:00<00:00, 159.05it/s]
100%|██████████| 53/53 [00:00<00:00, 472.78it/s]


epoch: 57 validation: roc_auc_score: 0.9981227245729198


100%|██████████| 53/53 [00:00<00:00, 107.84it/s]
100%|██████████| 53/53 [00:00<00:00, 479.74it/s]


epoch: 58 validation: roc_auc_score: 0.9954757974766034


100%|██████████| 53/53 [00:00<00:00, 165.49it/s]
100%|██████████| 53/53 [00:00<00:00, 276.54it/s]


epoch: 59 validation: roc_auc_score: 0.9958038638598321


100%|██████████| 53/53 [00:00<00:00, 161.87it/s]
100%|██████████| 53/53 [00:00<00:00, 475.50it/s]


epoch: 60 validation: roc_auc_score: 0.9793361455919418


100%|██████████| 53/53 [00:00<00:00, 164.04it/s]
100%|██████████| 53/53 [00:00<00:00, 445.08it/s]


epoch: 61 validation: roc_auc_score: 0.9783699426084587


100%|██████████| 53/53 [00:00<00:00, 161.04it/s]
100%|██████████| 53/53 [00:00<00:00, 469.32it/s]


epoch: 62 validation: roc_auc_score: 0.9999017844503123


100%|██████████| 53/53 [00:00<00:00, 161.64it/s]
100%|██████████| 53/53 [00:00<00:00, 467.95it/s]


epoch: 63 validation: roc_auc_score: 0.9780702589562965


100%|██████████| 53/53 [00:00<00:00, 164.16it/s]
100%|██████████| 53/53 [00:00<00:00, 477.92it/s]


epoch: 64 validation: roc_auc_score: 0.990482660784024


100%|██████████| 53/53 [00:00<00:00, 158.99it/s]
100%|██████████| 53/53 [00:00<00:00, 457.75it/s]


epoch: 65 validation: roc_auc_score: 0.9970721788358738


100%|██████████| 53/53 [00:00<00:00, 165.89it/s]
100%|██████████| 53/53 [00:00<00:00, 430.78it/s]


epoch: 66 validation: roc_auc_score: 0.9989010918203108


100%|██████████| 53/53 [00:00<00:00, 128.67it/s]
100%|██████████| 53/53 [00:00<00:00, 463.65it/s]


epoch: 67 validation: roc_auc_score: 0.9991317000075398


100%|██████████| 53/53 [00:00<00:00, 165.35it/s]
100%|██████████| 53/53 [00:00<00:00, 453.84it/s]


epoch: 68 validation: roc_auc_score: 0.999213866871936


100%|██████████| 53/53 [00:00<00:00, 161.50it/s]
100%|██████████| 53/53 [00:00<00:00, 461.70it/s]


epoch: 69 validation: roc_auc_score: 0.9994475886243265


100%|██████████| 53/53 [00:00<00:00, 161.68it/s]
100%|██████████| 53/53 [00:00<00:00, 452.03it/s]


epoch: 70 validation: roc_auc_score: 0.9997160765174388


100%|██████████| 53/53 [00:00<00:00, 162.47it/s]
100%|██████████| 53/53 [00:00<00:00, 461.89it/s]


epoch: 71 validation: roc_auc_score: 0.995742265759881


100%|██████████| 53/53 [00:00<00:00, 164.03it/s]
100%|██████████| 53/53 [00:00<00:00, 489.80it/s]


epoch: 72 validation: roc_auc_score: 0.999751335539132


100%|██████████| 53/53 [00:00<00:00, 162.68it/s]
100%|██████████| 53/53 [00:00<00:00, 457.12it/s]


epoch: 73 validation: roc_auc_score: 0.9977718414071565


100%|██████████| 53/53 [00:00<00:00, 168.44it/s]
100%|██████████| 53/53 [00:00<00:00, 275.63it/s]


epoch: 74 validation: roc_auc_score: 0.9696320405522267


100%|██████████| 53/53 [00:00<00:00, 163.91it/s]
100%|██████████| 53/53 [00:00<00:00, 458.01it/s]


epoch: 75 validation: roc_auc_score: 0.99741065582694


100%|██████████| 53/53 [00:00<00:00, 167.52it/s]
100%|██████████| 53/53 [00:00<00:00, 475.77it/s]


epoch: 76 validation: roc_auc_score: 0.9920421481993543


100%|██████████| 53/53 [00:00<00:00, 161.36it/s]
100%|██████████| 53/53 [00:00<00:00, 447.99it/s]


epoch: 77 validation: roc_auc_score: 0.9982102890850284


100%|██████████| 53/53 [00:00<00:00, 162.97it/s]
100%|██████████| 53/53 [00:00<00:00, 478.58it/s]


epoch: 78 validation: roc_auc_score: 0.9989206868442999


100%|██████████| 53/53 [00:00<00:00, 155.64it/s]
100%|██████████| 53/53 [00:00<00:00, 478.60it/s]


epoch: 79 validation: roc_auc_score: 0.9983019048382927


100%|██████████| 53/53 [00:00<00:00, 166.00it/s]
100%|██████████| 53/53 [00:00<00:00, 453.62it/s]


epoch: 80 validation: roc_auc_score: 0.9998017296130967


100%|██████████| 53/53 [00:00<00:00, 169.23it/s]
100%|██████████| 53/53 [00:00<00:00, 434.06it/s]


epoch: 81 validation: roc_auc_score: 0.999828573593813


100%|██████████| 53/53 [00:00<00:00, 135.62it/s]
100%|██████████| 53/53 [00:00<00:00, 458.65it/s]


epoch: 82 validation: roc_auc_score: 0.988434295552396


100%|██████████| 53/53 [00:00<00:00, 166.23it/s]
100%|██████████| 53/53 [00:00<00:00, 420.53it/s]


epoch: 83 validation: roc_auc_score: 0.9998651429576013


100%|██████████| 53/53 [00:00<00:00, 165.88it/s]
100%|██████████| 53/53 [00:00<00:00, 466.26it/s]


epoch: 84 validation: roc_auc_score: 0.9980922741460705


100%|██████████| 53/53 [00:00<00:00, 171.85it/s]
100%|██████████| 53/53 [00:00<00:00, 421.05it/s]


epoch: 85 validation: roc_auc_score: 0.9986190196467645


100%|██████████| 53/53 [00:00<00:00, 167.36it/s]
100%|██████████| 53/53 [00:00<00:00, 481.02it/s]


epoch: 86 validation: roc_auc_score: 0.9997080221210749


100%|██████████| 53/53 [00:00<00:00, 164.20it/s]
100%|██████████| 53/53 [00:00<00:00, 466.16it/s]


epoch: 87 validation: roc_auc_score: 0.9745454387208062


100%|██████████| 53/53 [00:00<00:00, 161.87it/s]
100%|██████████| 53/53 [00:00<00:00, 478.06it/s]


epoch: 88 validation: roc_auc_score: 0.9990501702819606


100%|██████████| 53/53 [00:00<00:00, 166.50it/s]
100%|██████████| 53/53 [00:00<00:00, 482.52it/s]


epoch: 89 validation: roc_auc_score: 0.9999806213627789


100%|██████████| 53/53 [00:00<00:00, 127.80it/s]
100%|██████████| 53/53 [00:00<00:00, 468.79it/s]


epoch: 90 validation: roc_auc_score: 0.9887265018395759


100%|██████████| 53/53 [00:00<00:00, 165.58it/s]
100%|██████████| 53/53 [00:00<00:00, 464.88it/s]


epoch: 91 validation: roc_auc_score: 0.9903598252287353


100%|██████████| 53/53 [00:00<00:00, 156.25it/s]
100%|██████████| 53/53 [00:00<00:00, 472.31it/s]


epoch: 92 validation: roc_auc_score: 0.9939580005785701


100%|██████████| 53/53 [00:00<00:00, 165.40it/s]
100%|██████████| 53/53 [00:00<00:00, 489.65it/s]


epoch: 93 validation: roc_auc_score: 0.9996857583269475


100%|██████████| 53/53 [00:00<00:00, 159.30it/s]
100%|██████████| 53/53 [00:00<00:00, 473.68it/s]


epoch: 94 validation: roc_auc_score: 0.9998680040715334


100%|██████████| 53/53 [00:00<00:00, 166.35it/s]
100%|██████████| 53/53 [00:00<00:00, 478.23it/s]


epoch: 95 validation: roc_auc_score: 0.9989634833384113


100%|██████████| 53/53 [00:00<00:00, 162.31it/s]
100%|██████████| 53/53 [00:00<00:00, 466.36it/s]


epoch: 96 validation: roc_auc_score: 0.998141225641582


100%|██████████| 53/53 [00:00<00:00, 167.60it/s]
100%|██████████| 53/53 [00:00<00:00, 281.47it/s]


epoch: 97 validation: roc_auc_score: 0.9996545745893844


100%|██████████| 53/53 [00:00<00:00, 163.03it/s]
100%|██████████| 53/53 [00:00<00:00, 473.01it/s]


epoch: 98 validation: roc_auc_score: 0.9999129884762987


100%|██████████| 53/53 [00:00<00:00, 164.81it/s]
100%|██████████| 53/53 [00:00<00:00, 469.64it/s]

epoch: 99 validation: roc_auc_score: 0.9996495736507468





In [None]:
test(model, r_test_dataloader, device)

100%|██████████| 14/14 [00:00<00:00, 458.59it/s]


0.7146135661428352
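
The validation ROC AUC in the training log above fluctuates noticeably from epoch to epoch (for example, 0.8913 at epoch 40 versus 0.99998 at epoch 89), so the final epoch is not necessarily the best one among the epochs shown. The sketch below is one hedged way to pick the best epoch from such a log. It assumes the log lines follow the `epoch: N validation: roc_auc_score: X` format printed above; `parse_best_epoch` and `LOG_PATTERN` are hypothetical names introduced for illustration and are not part of the notebook.

```python
import re

# Assumed log format, copied from the training output above.
LOG_PATTERN = re.compile(
    r"epoch:\s*(\d+)\s+validation:\s+roc_auc_score:\s*([0-9.]+)"
)


def parse_best_epoch(log_text):
    """Return (epoch, score) for the highest validation ROC AUC, or None if no line matches."""
    scores = [
        (int(match.group(1)), float(match.group(2)))
        for match in LOG_PATTERN.finditer(log_text)
    ]
    if not scores:
        return None
    # Compare entries by score only; on ties, max() keeps the earliest epoch.
    return max(scores, key=lambda pair: pair[1])


# Three lines taken verbatim from the log above.
sample_log = """
epoch: 40 validation: roc_auc_score: 0.8912649518448079
epoch: 89 validation: roc_auc_score: 0.9999806213627789
epoch: 99 validation: roc_auc_score: 0.9996495736507468
"""

best_epoch, best_score = parse_best_epoch(sample_log)
print(best_epoch, best_score)
```

In practice, the same comparison could be done inside the training loop by saving the model weights whenever the validation score improves. Note also that the value returned by `test` (0.7146) is well below the validation scores logged above, so a high validation score alone does not guarantee comparable test performance.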

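<font color='red'><b>**WARNING**</b></font> : **The intellectual property rights to this educational content belong to the NAVER Connect Foundation. Leaking this content externally through any channel, or modifying it, is strictly prohibited.** It may, however, be used solely for non-profit educational and research activities, subject to the Foundation's permission. Violations may result in liability under the applicable laws.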

