In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Dataset, Subset
from champ_id_remap import champ_id_remap
import json
import math

from Models.Models import AutoEncoder, Predictor

In [2]:
import global_win_rate

global_win_rate = global_win_rate.global_win_rate()

In [3]:
class WinRateDataset(Dataset):
    """
    data : user_vector, item_vector
    label : win_rate
    """
    def __init__(self, user_path, item_path, label_path, global_win_rate, is_valid=False):
        self.user_encoder = AutoEncoder(143, 12)
        self.user_encoder.load_state_dict(torch.load('./trained_model/user_encoder_baseline.pth'))
        self.item_encoder = AutoEncoder(143, 10)
        self.item_encoder.load_state_dict(torch.load('./trained_model/item_encoder.pth'))
        self.champ_id_remap = champ_id_remap()
        
        # get user_vector
        with open(user_path, 'r') as up:
            self.user = json.load(up)
        
        # get item_vector
        with open(item_path, 'r') as ip:
            self.item = json.load(ip)
        
        # get label
        with open(label_path, 'r') as lp:
            self.label = json.load(lp)
            
        # build dataset
        self.data = []
        with torch.no_grad():
            if not is_valid:              
                for i, (key, item) in enumerate(self.label.items()):
                    if i % 12 != 0:
                        user_vec = self.user[key]
                        user_vec = torch.Tensor(user_vec[:143])
                        user_vec = self.user_encoder.encoder(user_vec)
                        user_winrate = torch.Tensor([item['win_rate']])
                        for champ in item['champion_history']:
                            if champ['play_count'] >= 10:
                                original_key = champ['champion_key']
                                hashed_key = self.champ_id_remap[original_key]
                                champion_global = global_win_rate[hashed_key]
                                global_win = torch.Tensor([champion_global])
                                item_vec = self.item[str(hashed_key)]
                                item_vec = torch.Tensor(item_vec)
                                item_vec = self.item_encoder.encoder(item_vec)
                                self.data.append(((user_vec, item_vec, user_winrate), global_win, champ['win_rate']))
            
            else:              
                for i, (key, item) in enumerate(self.label.items()):
                    if i % 12 == 0:
                        user_vec = self.user[key]
                        user_vec = torch.Tensor(user_vec[:143])
                        user_vec = self.user_encoder.encoder(user_vec)
                        user_winrate = torch.Tensor([item['win_rate']])
                        for champ in item['champion_history']:
                            if champ['play_count'] >= 10:
                                original_key = champ['champion_key']
                                hashed_key = self.champ_id_remap[original_key]
                                champion_global = global_win_rate[hashed_key]
                                global_win = torch.Tensor([champion_global])
                                item_vec = self.item[str(hashed_key)]
                                item_vec = torch.Tensor(item_vec)
                                item_vec = self.item_encoder.encoder(item_vec)
                                self.data.append(((user_vec, item_vec, user_winrate), global_win, champ['win_rate']))   
                                
                                
    def __getitem__(self, index):
        user_vec = self.data[index][0][0]
        item_vec = self.data[index][0][1]
        user_winrate = self.data[index][0][2]
        global_win = self.data[index][1]
        label = self.data[index][2]
        return (user_vec, item_vec, user_winrate), global_win, label
                                     
    def __len__(self):
        length = len(self.data)
        return length
        

In [4]:
user_path = './datasets/user_vectors_tf_idf.json'
item_path = './datasets/item_vectors_tf_idf.json'
label_path = './data_batch/userbatch.json'

#train_set = Subset(dataset, train_list)
#valid_set = Subset(dataset, valid_list)

train_set = WinRateDataset(user_path, item_path, label_path, global_win_rate, is_valid=False)
valid_set = WinRateDataset(user_path, item_path, label_path, global_win_rate, is_valid=True)

In [5]:
num_epochs = 20
batch_size = 32
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
learning_rate = 0.1

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=15)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, num_workers=4)

model = Predictor(user_len=12, item_len=10, hidden_unit=22).to(device)
criterion = nn.SmoothL1Loss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = MultiStepLR(optimizer, [5, 10, 15], gamma=0.1)

loss = []
best_model_wts = None
best_loss = 100

for epoch in range(num_epochs):
    train_loss = 0.0
    count = 0
    for (user_vec, item_vec, user_winrate), global_win, label in train_loader:
        scheduler.step()
        user_vec = user_vec.to(device)
        item_vec = item_vec.to(device)
        global_win = global_win.to(device)
        user_winrate = user_winrate.to(device)
        label = label.float().to(device)
        # ===================forward=====================
        output = model(user_vec, item_vec, user_winrate, global_win)
        loss = criterion(output, label)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        count += label.size(0)
        #print(loss.item()/user_vec.size(0))
    # ===================log========================
    print('train epoch [{}/{}], loss:{:.8f}'
          .format(epoch + 1, num_epochs, train_loss/count))

    valid_loss = 0.0
    count = 0
    
    if (epoch+1) % 4 == 0 and epoch > 1:
        print('---------------------------')
        for (user_vec, item_vec, user_winrate), global_win, label in valid_loader:
            scheduler.step()
            user_vec = user_vec.to(device)
            item_vec = item_vec.to(device)
            global_win = global_win.to(device)
            user_winrate = user_winrate.to(device)
            label = label.float().to(device)
            # ===================forward=====================
            output = model(user_vec, item_vec, user_winrate, global_win)
            loss = criterion(output, label)
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            valid_loss += loss.item()
            count += label.size(0)
            if valid_loss < best_loss:
                print('best model so far!')
                best_loss = valid_loss
                best_model_wts = model.state_dict()
        # ===================log========================
        print('valid epoch, loss:{:.8f}'
              .format(valid_loss/count))
        print('----------------------------')

train epoch [1/20], loss:0.00066558
train epoch [2/20], loss:0.00046225
train epoch [3/20], loss:0.00035849
train epoch [4/20], loss:0.00030242
---------------------------
best model so far!
valid epoch, loss:0.00028052
----------------------------
train epoch [5/20], loss:0.00027128
train epoch [6/20], loss:0.00025681
train epoch [7/20], loss:0.00024900
train epoch [8/20], loss:0.00024523
---------------------------
best model so far!
valid epoch, loss:0.00023966
----------------------------
train epoch [9/20], loss:0.00024272
train epoch [10/20], loss:0.00024193
train epoch [11/20], loss:0.00024143
train epoch [12/20], loss:0.00024098
---------------------------
best model so far!
valid epoch, loss:0.00023661
----------------------------
train epoch [13/20], loss:0.00024082
train epoch [14/20], loss:0.00024076
train epoch [15/20], loss:0.00024072
train epoch [16/20], loss:0.00024058
---------------------------
best model so far!
valid epoch, loss:0.00023635
--------------------------

In [6]:
torch.save(model.state_dict(), './trained_model/predictor_baseline.pth')