# Train Classification Model for Ontology Generation
---

In [1]:
import pickle
import random
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

## Definition of Model
---

### Dataset Class
- 入力データを制御するためのクラス。
- データをvectorizerによりベクトル化
- ベクトルと教師ラベルをtorch.tensor型に変換

In [2]:
 class OntDataset(data.Dataset):
    def __init__(self, dataset, vectorizer, device='cpu'):
        self.dataset = dataset
        self.vectorizer = vectorizer
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        a, b, l = self.dataset[idx]
        a = torch.from_numpy(self.vectorizer[a]).float().to(device)
        b = torch.from_numpy(self.vectorizer[b]).float().to(device)
        l = torch.tensor(l).to(device)
        return ((a, b), l)

In [3]:
class SimpleConcat(nn.Module):
    def __init__(self):
        super(SimpleConcat, self).__init__()

    def forward(self, x):
        a, b = x
        batch_size = a.size()[0]
        a, b = a.reshape(batch_size, -1), b.reshape(batch_size, -1)
        x = torch.cat([a, b], dim=1)
        return x

In [4]:
class Ont(nn.Module):
    def __init__(self, concat, x_size, h_size=300, y_size=4, drop_rate=0.5):
        super(Ont, self).__init__()
        self.concat = concat
        self.l1 = nn.Linear(x_size, h_size)
        self.l2 = nn.Linear(h_size, h_size)
        self.l3 = nn.Linear(h_size, y_size)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.concat(x)
        h = self.drop(F.relu(self.l1(x)))
        h = self.drop(F.relu(self.l2(h)))
        t = self.l3(h)
        return t

## Training & Validation
---

In [5]:

max_epoch  = 50
batch_size = 1024

h_size = 512
drop_rate = 0.2

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
# Loading
vectorizer = pd.read_pickle('../dataset/vectorizer_w2v.pkl')
train_df   = pd.read_pickle('../dataset/wordnet_train.pkl')
valid_df   = pd.read_pickle('../dataset/wordnet_valid.pkl')

train_dataset = OntDataset(train_df, vectorizer, device=device)
valid_dataset = OntDataset(valid_df, vectorizer, device=device)

train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Calculate input size
max_seq_len, vec_size = vectorizer['example'].shape
x_size = 2 * max_seq_len * vec_size

In [8]:
model = Ont(SimpleConcat(), x_size=x_size, h_size=h_size, drop_rate=drop_rate).to(device)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

In [9]:
def train():
    model.train()
    
    epoch_loss = 0
    for x, y in train_loader:
        t = model(x)
        loss = loss_func(t, y)
        epoch_loss += loss.cpu().item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # if (i + 1) % 10 == 0: print(f'{i + 1:>4} : {loss.cpu().item():>6.3}')
    return epoch_loss

In [12]:
def valid():
    model.eval()
    
    epoch_accu = 0
    with torch.no_grad():
        for x, y in valid_loader:
            t = model(x)
            _, t = torch.max(t.data, 1)
            epoch_accu += sum(1 for t_i, y_i in zip(t, y) if t_i == y_i)
    
    return epoch_accu / len(valid_dataset)

In [13]:
max_epoch = 50

for epoch in range(max_epoch):
    loss = train()
    accu = valid()
    print(f'{epoch:0>2} | loss : {loss:>7.3f} | accu : {accu:.2%}')

00 | loss : 499.772 | accu : 79.79%
01 | loss : 448.687 | accu : 80.42%
02 | loss : 414.161 | accu : 81.08%
03 | loss : 387.717 | accu : 81.32%
04 | loss : 367.570 | accu : 81.52%
05 | loss : 350.533 | accu : 81.64%
06 | loss : 337.531 | accu : 81.71%
07 | loss : 326.037 | accu : 81.93%
08 | loss : 315.160 | accu : 81.91%
09 | loss : 305.910 | accu : 82.00%
10 | loss : 298.562 | accu : 82.13%
11 | loss : 291.519 | accu : 82.24%
12 | loss : 284.738 | accu : 82.15%
13 | loss : 278.571 | accu : 82.24%
14 | loss : 273.338 | accu : 82.38%
15 | loss : 268.360 | accu : 82.47%
16 | loss : 264.086 | accu : 82.52%
17 | loss : 259.080 | accu : 82.52%
18 | loss : 256.658 | accu : 82.53%
19 | loss : 251.655 | accu : 82.50%
20 | loss : 248.480 | accu : 82.47%
21 | loss : 244.744 | accu : 82.45%
22 | loss : 242.098 | accu : 82.50%
23 | loss : 239.434 | accu : 82.50%
24 | loss : 235.964 | accu : 82.51%
25 | loss : 233.702 | accu : 82.54%
26 | loss : 231.351 | accu : 82.60%
27 | loss : 229.207 | accu :