In [72]:
import torch
import h5py
import numpy as np
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pickle
from sklearn.preprocessing import normalize
from scipy.sparse import lil_matrix, csr_matrix, hstack
from tqdm import tqdm
import math
import os
import csv

In [73]:
# Configuration block: every path the rest of the notebook reads or writes.
# Recommender outputs are loaded in this list order; row i of each per-query
# input matrix comes from the i-th file, so the order matters for the ensemble weights.
rec_file_list = [
    "./test_recs/CF_rec_clf_dim_64.pickle",
    "./test_recs/Graph_rec_clf_1_8_depth_3.pickle",
    "./test_recs/Graph_rec_clf_1_8_depth_1.pickle",
]
# Trained ensemble (Network) weights, loaded as a state_dict.
state_dict_path = "./ensemble_model/ensemble_model_best_clf.pt"
# Lookup from predicted item index to the label written out (presumably the cuisine name).
id_cuisine_dict_path = os.path.join("./container", 'id_cuisine_dict.pickle')
# Output CSV: one predicted label per row.
save_path = './ensemble_model/test_clf.csv'

In [74]:
class RecDataset(Dataset):
    """Dataset of per-query score matrices built from several recommenders.

    Item ``q`` is a dense ``(len(recs_list), item_num)`` matrix, returned as a
    tensor on ``self.device``. Row ``i`` holds recommender ``i``'s scores for
    query ``q``, L2-normalised over the items it recommended. Items it did not
    recommend stay 0.
    """

    def __init__(self, recs_list, query_num, item_num, transform=None, target_transform=None):
        """Build one sparse score matrix per query.

        Args:
            recs_list: one mapping per recommender, ``{query: [(item, score), ...]}``.
                ``query`` indexes a list of length ``query_num``. ``item`` is a
                column index, so it must be in ``range(item_num)``.
            query_num: number of per-query matrices, i.e. the dataset length.
            item_num: number of candidate items (matrix columns).
            transform: optional callable applied to each dense NumPy matrix in
                ``__getitem__`` (e.g. ``torch.Tensor``).
            target_transform: stored on the instance but not used by this class.
        """
        # Size rows by the recommenders actually passed in, not by a notebook global.
        model_num = len(recs_list)
        # rec_matrix[q] has shape (model_num, item_num). LIL is cheap to fill element-wise.
        self.rec_matrix = [lil_matrix((model_num, item_num)) for _ in range(query_num)]
        for model_idx, recs in enumerate(recs_list):
            for query in tqdm(recs.keys()):
                rec = recs[query]
                if not rec:
                    # normalize() rejects a zero-sample array, so leave this row all-zero.
                    continue
                rec_items = [item for item, _ in rec]
                raw_scores = np.array([score for _, score in rec])
                # L2-normalise this recommender's scores for this query, presumably
                # so that different recommenders share a common scale.
                rec_scores = normalize(raw_scores[:, np.newaxis], axis=0).ravel()
                for item, score in zip(rec_items, rec_scores):
                    self.rec_matrix[query][model_idx, item] = score
        self.transform = transform
        self.target_transform = target_transform
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __len__(self):
        """Number of queries (one matrix per query)."""
        return len(self.rec_matrix)

    def __getitem__(self, idx):
        """Return the dense score matrix for query ``idx`` as a tensor on ``self.device``."""
        rec_matrix = self.rec_matrix[idx].toarray()
        if self.transform:
            rec_matrix = self.transform(rec_matrix)
        # as_tensor is a no-op for tensors. It converts the raw NumPy array when no
        # transform is given (ndarray has no .to()). The device move happens once.
        return torch.as_tensor(rec_matrix).to(self.device)

In [75]:
class Network(nn.Module):
    """Two-stage linear ensemble over stacked recommender score matrices.

    Input ``x`` has shape ``(batch, model_len, items)``; the output has shape
    ``(batch, items)``. Stage one mixes the ``model_len`` rows into ``k`` hidden
    rows (``w1``). Stage two collapses those into a single score row (``w2``).
    Both stages are linear, so the whole model equals one weighting ``w2 @ w1``
    over the input rows.
    """

    def __init__(self, model_len, k=10):
        super().__init__()
        # The attribute names are the state_dict keys ("w1", "w2"), so keep them stable.
        # They are drawn in this order from the global RNG.
        self.w1 = nn.Parameter(torch.randn(k, model_len))
        self.w2 = nn.Parameter(torch.randn(1, k))

    def forward(self, x):
        """Score every item: (batch, model_len, items) -> (batch, items), float32."""
        stacked = x.float()
        hidden = torch.einsum('km, bmp -> bkp', self.w1, stacked)
        combined = torch.einsum('ok, bkp -> bop', self.w2, hidden)
        return combined.squeeze(1)

In [76]:
# Load each recommender's output in rec_file_list order. Row i of every query matrix
# in RecDataset comes from recs_list[i], so this order presumably must match training.
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced.
recs_list = []
for rec_file in rec_file_list:
    with open(rec_file, 'rb') as f:
        recs_list.append(pickle.load(f))

# RecDataset indexes a list of length query_num with every query id from *every*
# recommender. Size it from the largest id seen anywhere, not from len(recs_list[0])
# alone, so an id missing from the first recommender cannot raise IndexError.
# For contiguous ids 0..n-1 this is the same n. It is 0 if every recommender is empty.
query_num = 1 + max((max(recs) for recs in recs_list if recs), default=-1)
# Number of candidate item columns; every recommended item id must be < item_num.
# NOTE(review): presumably the number of cuisine classes in id_cuisine_dict; confirm.
item_num = 20

test_data = RecDataset(recs_list, query_num, item_num, transform=torch.Tensor, target_transform=torch.tensor)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

100%|███████████████████████████████████████████████████████████████████████████| 3924/3924 [00:00<00:00, 14164.59it/s]
100%|███████████████████████████████████████████████████████████████████████████| 3924/3924 [00:00<00:00, 13768.47it/s]
100%|███████████████████████████████████████████████████████████████████████████| 3924/3924 [00:00<00:00, 13916.33it/s]


In [77]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# model_len is the number of recommenders. Both it and k must match the checkpoint's
# parameter shapes, or load_state_dict (strict by default) raises.
model = Network(len(rec_file_list), k=10).to(device)
# NOTE(review): torch.load unpickles the file. Only a state_dict is expected here,
# so on PyTorch >= 1.13 consider weights_only=True.
model_state_dict = torch.load(state_dict_path, map_location=device)
model.load_state_dict(model_state_dict)
# This cell only runs inference, so switch to eval mode. Dropout and batch-norm layers
# (none in the current Network) then behave as they should at inference time.
model.eval()


def inference(dataloader, model):
    """Return the arg-max column index predicted for every sample, in loader order.

    Each batch is passed to ``model`` unchanged, so it must already be on the
    model's device. Gradients are disabled. The result is a flat list of NumPy
    integer scalars, one per sample.
    """
    predictions = []
    with torch.no_grad():
        for X in tqdm(dataloader, total=len(dataloader)):
            scores = model(X).cpu().numpy()
            # Iterating the argmax array yields the same NumPy scalars list() would.
            predictions.extend(np.argmax(scores, axis=1))
    return predictions
            

# Top-1 predicted index for every test query, in dataset order (the loader does not shuffle).
infer = inference(dataloader=test_dataloader, model=model)

100%|█████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 174.16it/s]


In [78]:
# Map each predicted index to its label through the id -> cuisine lookup.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(id_cuisine_dict_path, 'rb') as fr:
    cuisine_dict = pickle.load(fr)
infer_name = [[cuisine_dict[i]] for i in infer]  # one single-column row per query
# Use 'w' rather than 'w+' because the file is never read back. Set UTF-8 explicitly
# so the output doesn't depend on the locale encoding. Use newline='' as the csv module requires.
with open(save_path, 'w', newline='', encoding='utf-8') as f:
    write = csv.writer(f)
    write.writerows(infer_name)