In [1]:
import torch
import h5py
import numpy as np
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pickle
from sklearn.preprocessing import normalize

In [2]:
class RecDataset(Dataset):
    def __init__(self, recs_list, answer_dict, query_num, item_num, transform=None, target_transform=None):
        # rec_matrix = [query num, model_num, item_num]
        self.rec_matrix = np.zeros((query_num, len(rec_file_list), item_num))
        for i, recs in enumerate(recs_list):
            for query in recs.keys():
                rec = recs[query]
                rec_items, rec_scores = [rec_ for rec_, score in rec], [score for rec_, score in rec]
                rec_scores = normalize(np.array(rec_scores)[:,np.newaxis], axis=0).ravel()
                for item, score in zip(rec_items, rec_scores):
                    self.rec_matrix[query, i, item] = score
        self.labels = answer_dict
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.rec_matrix.shape[0]

    def __getitem__(self, idx):
        rec_matrix = self.rec_matrix[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return rec_matrix, label

In [10]:
class Network(nn.Module):
    def __init__(self, model_len, k=10):
        super(Network, self).__init__()
        self.w1 = torch.nn.Parameter(torch.randn(k, model_len))
        self.w2 = torch.nn.Parameter(torch.randn(1, k))
        
    def forward(self, x):
        #import ipdb; ipdb.set_trace()
        x = x.float()
        x = torch.einsum('nm, bmp -> bnp', self.w1, x)
        x = torch.einsum('nm, bmp -> bnp', self.w2, x).squeeze(1)
        return x

In [16]:
rec_file_list = ["./valid_recs/CF_rec_cpl_dim_64.pickle",
                "./valid_recs/Graph_rec_cpl_1_2_depth_2.pickle",
                "./valid_recs/Graph_rec_cpl_1_4_depth_2.pickle",]
recs_list = []
for rec_file in rec_file_list:
    with open(rec_file, 'rb') as f:
        recs = pickle.load(f)
        recs_list.append(recs)

query_num = len(recs_list[0])
item_num = 6714

h5f_valid = h5py.File('./Container/valid_cpl', 'r')
answer = h5f_valid['labels_id'][:].astype(np.int64)
h5f_valid.close()

answer_dict = {}
for i, ans in enumerate(answer):
    answer_dict[i] = ans

train_data = RecDataset(recs_list, answer_dict, query_num, item_num)
dataloader = DataLoader(train_data, batch_size=64, shuffle=True)

model = Network(len(rec_file_list), k=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    total_loss = 0
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / (batch + 1)
    print(f"loss: {loss:>7f}")
            
""" 
###구현할 거###
1. cpl용 train set 만들기
2. train set에서 CF/Graph 예측 생성
3. 모델 학습마다 앙상블된 결과 제작 -> metric 측정
3. train set / valid set 따로 앙상블
4. WanDB 적용
"""
        
epochs = 1000
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader, model, loss_fn, optimizer)
print("Done!")

Epoch 1
-------------------------------
loss: 8.279675
Epoch 2
-------------------------------
loss: 8.315391
Epoch 3
-------------------------------
loss: 8.161734
Epoch 4
-------------------------------
loss: 7.997823
Epoch 5
-------------------------------
loss: 8.122506
Epoch 6
-------------------------------
loss: 8.081966
Epoch 7
-------------------------------
loss: 7.616101
Epoch 8
-------------------------------
loss: 7.372782
Epoch 9
-------------------------------
loss: 7.774245
Epoch 10
-------------------------------
loss: 7.819954
Epoch 11
-------------------------------
loss: 7.339439
Epoch 12
-------------------------------
loss: 7.725379
Epoch 13
-------------------------------
loss: 7.523771
Epoch 14
-------------------------------
loss: 7.651724
Epoch 15
-------------------------------
loss: 7.964770
Epoch 16
-------------------------------
loss: 8.259868
Epoch 17
-------------------------------
loss: 7.793155
Epoch 18
-------------------------------
loss: 7.574823
E

loss: 7.782719
Epoch 147
-------------------------------
loss: 7.374236
Epoch 148
-------------------------------
loss: 7.930694
Epoch 149
-------------------------------
loss: 7.297251
Epoch 150
-------------------------------
loss: 7.602061
Epoch 151
-------------------------------
loss: 7.446518
Epoch 152
-------------------------------
loss: 7.739115
Epoch 153
-------------------------------
loss: 7.593728
Epoch 154
-------------------------------
loss: 7.868934
Epoch 155
-------------------------------
loss: 7.144675
Epoch 156
-------------------------------
loss: 7.409850
Epoch 157
-------------------------------
loss: 7.503049
Epoch 158
-------------------------------
loss: 7.363141
Epoch 159
-------------------------------
loss: 7.210325
Epoch 160
-------------------------------
loss: 8.122210
Epoch 161
-------------------------------
loss: 7.937881
Epoch 162
-------------------------------
loss: 7.195582
Epoch 163
-------------------------------
loss: 7.816548
Epoch 164
------

loss: 7.456145
Epoch 291
-------------------------------
loss: 7.755116
Epoch 292
-------------------------------
loss: 7.938689
Epoch 293
-------------------------------
loss: 7.417361
Epoch 294
-------------------------------
loss: 7.188723
Epoch 295
-------------------------------
loss: 7.246391
Epoch 296
-------------------------------
loss: 7.415021
Epoch 297
-------------------------------
loss: 7.786751
Epoch 298
-------------------------------
loss: 7.009239
Epoch 299
-------------------------------
loss: 7.923261
Epoch 300
-------------------------------
loss: 6.782679
Epoch 301
-------------------------------
loss: 8.361791
Epoch 302
-------------------------------
loss: 6.953040
Epoch 303
-------------------------------
loss: 7.412927
Epoch 304
-------------------------------
loss: 7.661375
Epoch 305
-------------------------------
loss: 7.819381
Epoch 306
-------------------------------
loss: 7.665544
Epoch 307
-------------------------------
loss: 7.693293
Epoch 308
------

KeyboardInterrupt: 

In [15]:
print(model.parameters())

<generator object Module.parameters at 0x0000021908C45580>


In [None]:
nm, bmp -> bnp (10x3) X 64x(3x6714) = 64x(10x6714)