In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pickle
import csv
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from language_model import PhonemeLangModel
from phoneme_dataset import PhonemeDataset, collate_fn, dataset_spliter, MaskedPhonemeDataset, PAD_NUM

from torchvision import transforms
from trainer_phoneme import Trainer, MaskedTrainer
import json
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import accuracy_score

In [17]:
import random

# Training / evaluation settings.
BATCH_SIZE = 128
EPOCHS = 1000
data_path = './data/ROHAN_labels.txt'

# Seed every RNG the notebook may touch (Python, NumPy, PyTorch) so that
# DataLoader shuffling, training and -- if it is random -- the
# train/valid/test split are reproducible under Restart & Run All.
# torch.manual_seed seeds the RNGs of all devices, CUDA included.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Read the label file, one utterance per line. Each line is whitespace-split
# and its first token is dropped (presumably an utterance ID -- confirm against
# the data); the remaining tokens form the sequence.
with open(data_path, 'r') as label_file:
    sentenses = label_file.readlines()
datas = []
for line in sentenses:
    # str.split() with no argument already discards the trailing newline.
    tokens = line.split()
    datas.append(tokens[1:])
print(datas[0])

In [6]:
# Load the phoneme vocabulary (presumably phoneme symbol -> id; it is passed to
# every MaskedPhonemeDataset). JSON is UTF-8 by specification, so the encoding
# is given explicitly instead of relying on the platform's locale default
# (which is e.g. cp932 on Japanese Windows and would mis-decode non-ASCII keys).
dict_path = 'data/phoneme_dict.json'
with open(dict_path, 'r', encoding='utf-8') as f:
    phone_dict = json.load(f)

In [8]:
# Split the raw sequences, then wrap each split in a MaskedPhonemeDataset that
# uses the same phoneme vocabulary. Splits are built in train/valid/test order.
datas_train, datas_valid, datas_test = dataset_spliter(datas)
dataset_train, dataset_valid, dataset_test = (
    MaskedPhonemeDataset(phone_dict, split)
    for split in (datas_train, datas_valid, datas_test)
)

In [9]:
def _make_loader(dataset, shuffle=False):
    """Wrap `dataset` in a DataLoader using the notebook's shared batching settings."""
    return DataLoader(dataset,
                      shuffle=shuffle,
                      batch_size=BATCH_SIZE,
                      collate_fn=collate_fn,
                      drop_last=True)


# Only the training split is shuffled. drop_last=True discards the final
# partial batch of EVERY split, so a few validation/test samples are never
# scored. NOTE(review): consider drop_last=False for evaluation once all
# consumers handle a short final batch.
dataloader_train = _make_loader(dataset_train, shuffle=True)
dataloader_valid = _make_loader(dataset_valid)
dataloader_test  = _make_loader(dataset_test)

In [18]:
# Load the pretrained language-model weights, then move the model to the GPU.
# map_location='cpu' deserialises the tensors on the CPU regardless of the
# device they were saved from (a checkpoint saved on another GPU index, or
# loaded on a CPU-only machine, would otherwise fail). The model is moved to
# the target device afterwards.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
save_path = './checkpoints_aoyama/lstm_lang_model.pth'
model = PhonemeLangModel()
model.load_state_dict(torch.load(save_path, map_location='cpu'))
model = model.to('cuda:0')

In [19]:
# Fine-tune the loaded model with Adam (learning rate 1e-3) and a
# cross-entropy loss, validating on the validation loader.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
trainer = MaskedTrainer(model, dataloader_train, dataloader_valid, criterion, optimizer)
trainer.train(EPOCHS, BATCH_SIZE)

In [20]:
# Save the fine-tuned weights as CPU tensors so the checkpoint loads on
# machines without a GPU. nn.Module.to() moves the module IN PLACE, so the
# model is put back on its original device afterwards instead of being left on
# the CPU (which would break any later GPU use of `model`/`trainer`).
save_path = './checkpoints_aoyama/lstm_lang_model_ROHAN.pth'
original_device = next(model.parameters()).device
torch.save(model.to('cpu').state_dict(), save_path)
model.to(original_device)

In [21]:
import pandas as pd

# Persist the loss/accuracy history lists collected by the trainer, one column
# per metric, so the curves can be re-plotted without retraining.
df = pd.DataFrame({
    'train_loss': trainer.train_loss_list,
    'valid_loss': trainer.valid_loss_list,
    'train_acc': trainer.train_acc_list,
    'valid_acc': trainer.valid_acc_list,
})
df.to_csv('./checkpoints_aoyama/lstm_lang_model_losses_ROHAN_fine.csv')

In [22]:
df = pd.read_csv('./checkpoints_aoyama/lstm_lang_model_losses_ROHAN_fine.csv')
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
fig, ax1 = plt.subplots()
x = np.array([i for i in range(EPOCHS)])
ax1.set_xlabel('epochs')
ax1.plot(x, df['train_loss'], label='train_loss')
ax1.plot(x, df['valid_loss'], label='valid_loss')
ax1.legend()

ax2 = ax1.twinx()
ax2.plot(x, df['train_acc'], label='train_acc')
ax2.plot(x, df['valid_acc'], label='valid_acc')
ax2.legend()

ax1.set_ylabel('loss')
ax2.set_ylabel('acc')
plt.savefig('./graph.png')
plt.show()

In [24]:
# Reload the fine-tuned weights for evaluation and move the model to the GPU.
# map_location='cpu' makes loading independent of the device the tensors were
# saved on; the model is moved to the target device afterwards.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
save_path = './checkpoints_aoyama/lstm_lang_model_ROHAN.pth'
model = PhonemeLangModel()
model.load_state_dict(torch.load(save_path, map_location='cpu'))
model = model.to('cuda:0')

In [27]:
def generate_sentence(model, batch):
    """Return the summed per-sequence token accuracy of `model` on one batch.

    Despite its name, this function does not generate text: it runs a single
    forward pass and compares the arg-max prediction at every position with
    the corresponding label.

    Args:
        model: the language model. Must expose ``device`` and
            ``hidden_dim_encoder`` and be callable as ``model(inputs, states)``,
            returning ``(outputs, new_states)``.
        batch: ``(inputs, labels)`` pair as yielded by the DataLoader.
            ``outputs`` and ``labels`` are zipped along dim 0, so ``labels`` is
            presumably ``(batch, seq_len)`` and ``outputs``
            ``(batch, seq_len, vocab)`` -- TODO confirm against the model.

    Returns:
        float: the SUM over sequences of each sequence's accuracy (fraction of
        positions whose predicted id equals the label). Divide by the number
        of sequences, ``labels.size(0)`` -- not ``len(batch)`` -- to get the mean.

    NOTE(review): padded positions in ``labels`` are scored like any other
    position; if labels are padded with PAD_NUM this inflates the accuracy.
    """
    model.eval()
    inputs, labels = batch
    inputs = inputs.to(model.device)
    labels = labels.to(model.device)
    # Size the initial (h, c) states from the actual batch rather than the
    # global BATCH_SIZE, so a short batch does not cause a shape mismatch.
    n_sequences = labels.size(0)
    acc = 0.0
    with torch.no_grad():
        states = (torch.zeros(1, n_sequences, model.hidden_dim_encoder, device=model.device),
                  torch.zeros(1, n_sequences, model.hidden_dim_encoder, device=model.device))
        outputs, _ = model(inputs, states)
        for output, label in zip(outputs, labels):
            prediction = torch.argmax(output, dim=1)
            # Same value as sklearn's accuracy_score(label, prediction) for 1-D
            # sequences (fraction of matching positions, float64), without the
            # round-trip through Python lists.
            acc += (prediction == label).double().mean().item()
    return acc

In [28]:
# Score the model on the whole test split. generate_sentence returns the SUM of
# per-sequence accuracies for a batch; `batch` is an (inputs, labels) tuple, so
# the number of sequences is len(labels) -- len(batch) would always be 2.
total_acc = 0.0
n_sequences = 0
for batch in tqdm(dataloader_test, total=len(dataloader_test)):
    total_acc += generate_sentence(model, batch)
    n_sequences += len(batch[1])

if n_sequences:
    print(f'test accuracy: {total_acc / n_sequences:.4f} over {n_sequences} sequences')
else:
    # With drop_last=True a test split smaller than one batch yields no batches.
    print('test accuracy: no test batches were produced')