In [4]:
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as scheduler
from torch.utils.data import DataLoader, Dataset, Subset

import Models.Models

In [None]:
class WinRateDataset(Dataset):
    """Dataset of ``((user_vec, item_vec), win_rate)`` samples.

    One sample is produced per (user, champion) pair in the label file whose
    ``play_count`` is at least ``min_play_count``.

    JSON layout, as read by this class:
      * user file:  ``{user_key: [number, ...]}``
      * item file:  ``{champion_id: [number, ...]}`` (object keys are strings),
        or a list indexed by champion id
      * label file: ``{user_key: {"champion_history": {<any>: {"champion_id": ...,
        "play_count": ..., "win_rate": ...}}}}``

    Args:
        user_path, item_path, label_path: paths to the JSON files above.
        user_encoder, item_encoder: optional callables applied once to every
            vector at construction time under ``torch.no_grad()``. Each is
            called on a ``(1, D)`` float32 CPU tensor and its output is
            squeezed back to 1-D. Pass e.g. the encoder half of a trained
            autoencoder (TODO: confirm which attribute of ``AutoEncoder`` that
            is). When ``None`` the raw vectors are used unchanged.
        min_play_count: minimum number of games on a champion for that
            (user, champion) pair to be kept (default 5).

    Raises:
        KeyError: if a user in the label file has no user vector, or a
            kept champion has no item vector.
    """

    def __init__(self, user_path, item_path, label_path,
                 user_encoder=None, item_encoder=None, min_play_count=5):
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(user_path, 'r', encoding='utf-8') as up:
            self.user = json.load(up)
        with open(item_path, 'r', encoding='utf-8') as ip:
            self.item = json.load(ip)
        with open(label_path, 'r', encoding='utf-8') as lp:
            self.label = json.load(lp)

        self.user_encoder = user_encoder
        self.item_encoder = item_encoder

        # Each champion's vector is converted/encoded once and shared by all users.
        item_cache = {}
        self.data = []
        for user_key, user_label in self.label.items():
            user_vec = self._to_vector(self.user[user_key], user_encoder)
            for champ in user_label['champion_history'].values():
                if champ['play_count'] < min_play_count:
                    continue
                champ_id = champ['champion_id']
                if isinstance(self.item, dict):
                    # JSON object keys are always strings, even if ids are numbers in the label file.
                    champ_id = str(champ_id)
                if champ_id not in item_cache:
                    item_cache[champ_id] = self._to_vector(self.item[champ_id], item_encoder)
                self.data.append((user_vec, item_cache[champ_id], float(champ['win_rate'])))

    @staticmethod
    def _to_vector(values, encoder=None):
        """Convert a JSON list of numbers to a 1-D float32 tensor, optionally encoding it."""
        vec = torch.as_tensor(values, dtype=torch.float32)
        if encoder is None:
            return vec
        with torch.no_grad():
            # Add a batch dimension for modules that expect (N, D) input.
            return encoder(vec.unsqueeze(0)).squeeze(0)

    def __getitem__(self, index):
        """Return ``((user_vec, item_vec), win_rate)``; ``win_rate`` is a float32 scalar tensor."""
        user_vec, item_vec, win_rate = self.data[index]
        return (user_vec, item_vec), torch.tensor(win_rate, dtype=torch.float32)

    def __len__(self):
        """Number of (user, champion) samples kept after the play-count filter."""
        return len(self.data)
        

In [None]:
# Build the dataset and hold out every VALID_EVERY-th sample for validation.
user_path = './datasets/user_vector'
item_path = './datasets/item_vector'
label_path = './datasets/labels'
dataset = WinRateDataset(user_path, item_path, label_path)
# NOTE(review): the training cell builds Predictor(user_len=32, item_len=32);
# confirm these vectors are 32-dimensional (i.e. already encoded) before training.

VALID_EVERY = 10  # every 10th sample (index 0, 10, 20, ...) goes to validation

all_indices = range(len(dataset))
valid_list = [i for i in all_indices if i % VALID_EVERY == 0]
# Build the complement directly so training indices keep their original order
# (a set difference does not guarantee any ordering).
train_list = [i for i in all_indices if i % VALID_EVERY != 0]
assert len(train_list) + len(valid_list) == len(dataset)

train_set = Subset(dataset, train_list)
valid_set = Subset(dataset, valid_list)

In [None]:
# Training configuration.
num_epochs = 800
batch_size = 64
# NOTE(review): `learning_rate` was referenced but never defined; 1e-3 (Adam's default) is a placeholder.
learning_rate = 1e-3
VALID_EVERY_EPOCHS = 10  # run validation on epochs 1, 11, 21, ...
SEED = 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed before model init and data shuffling so runs are reproducible.
torch.manual_seed(SEED)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=15)
valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=15)

# NOTE(review): `Predictor` is not imported by name in this notebook; confirm it comes from Models.Models.
model = Predictor(user_len=32, item_len=32, hidden_unit=32).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# Named `lr_scheduler` so the `scheduler` module alias imported at the top is not
# shadowed (shadowing it made this cell fail when re-run).
# NOTE(review): milestones 1000 and 1500 are never reached with num_epochs = 800.
lr_scheduler = scheduler.MultiStepLR(optimizer, [500, 1000, 1500], gamma=0.5)


def run_epoch(loader, train):
    """Run one pass over `loader` and return the mean per-sample MSE.

    When `train` is True the model is put in training mode and updated;
    otherwise it is put in eval mode and gradients are disabled, so the
    validation pass never changes the weights.
    """
    model.train(train)
    total_loss = 0.0
    count = 0
    with torch.set_grad_enabled(train):
        for (user_vec, item_vec), label in loader:
            user_vec = user_vec.to(device)
            item_vec = item_vec.to(device)
            label = label.to(device, dtype=torch.float32)
            # ===================forward=====================
            output = model(user_vec, item_vec)
            # Flatten both sides so an (N, 1) output is not silently broadcast
            # against an (N,) label into an (N, N) loss.
            loss = criterion(output.reshape(-1), label.reshape(-1))
            if train:
                # ===================backward====================
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # MSELoss averages over the batch; weight by batch size so the
            # epoch mean is per-sample even when the last batch is smaller.
            total_loss += loss.item() * label.size(0)
            count += label.size(0)
    return total_loss / max(count, 1)


# One (epoch, train_loss, valid_loss_or_None) tuple per epoch.
loss_history = []
for epoch in range(num_epochs):
    train_loss = run_epoch(train_loader, train=True)
    # MultiStepLR milestones count epochs: step once per epoch, after the optimizer updates.
    lr_scheduler.step()
    # ===================log========================
    print('train epoch [{}/{}], loss:{:.8f}'
          .format(epoch + 1, num_epochs, train_loss))

    valid_loss = None
    if epoch % VALID_EVERY_EPOCHS == 0:
        valid_loss = run_epoch(valid_loader, train=False)
        print('valid epoch [{}/{}], loss:{:.8f}'
              .format(epoch + 1, num_epochs, valid_loss))
    loss_history.append((epoch + 1, train_loss, valid_loss))