In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pickle
import csv
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from language_model import PhonemeLangModel
from phoneme_dataset import PhonemeDataset, collate_fn, dataset_spliter, MaskedPhonemeDataset

from torchvision import transforms
from trainer_phoneme import Trainer, MaskedTrainer
import json

In [2]:
# Training hyper-parameters: mini-batch size and number of epochs.
BATCH_SIZE, EPOCHS = 128, 20

# Input files: CSV of labels and the pickled label dictionary.
labels_path = 'data/phoneme.csv'
dict_path = 'data/label.pkl'

In [3]:
# Load the pickled label dictionary and the CSV rows.
# NOTE(review): both `phone_dict` and `datas` are rebound by later cells
# (the JSON dictionary load and the sentence tokenisation) before they are
# used, so these values appear to be unused -- confirm before removing.
# pickle.load can execute arbitrary code; only open trusted files here.
with open(dict_path, 'rb') as pkl_file:
    phone_dict = pickle.load(pkl_file)

with open(labels_path, 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    # Materialise every CSV row while the file is still open.
    datas = [row for row in csv_reader]

In [4]:
# Read the additional phoneme corpus: one list entry per line of the file,
# with the trailing newline still attached.
additional_file = './data/phoneme_data.txt'
with open(additional_file, 'r') as corpus_file:
    sentenses = list(corpus_file)

In [5]:
# Point `dict_path` at the JSON phoneme dictionary (this rebinds the earlier
# pickle path) and load it into `phone_dict`.
dict_path = 'data/phoneme_dict.json'
with open(dict_path, 'r') as dict_file:
    phone_dict = json.loads(dict_file.read())

In [6]:
# Tokenise every corpus line: remove newline characters, then split the
# remainder on whitespace into a list of tokens.
datas = []
for sentence in sentenses:
    datas.append(sentence.replace('\n', '').split())

In [7]:
# Split the tokenised sentences into train / validation / test partitions
# and wrap each partition, together with the phoneme dictionary, in a
# MaskedPhonemeDataset (constructed in train, valid, test order).
datas_train, datas_valid, datas_test = dataset_spliter(datas)
dataset_train, dataset_valid, dataset_test = (
    MaskedPhonemeDataset(phone_dict, partition)
    for partition in (datas_train, datas_valid, datas_test)
)

train: 730554 valid: 243518 test: 243518


In [8]:
# Options shared by all three loaders. drop_last=True discards the final
# incomplete batch of every split, including validation and test.
# NOTE(review): presumably the trainer expects full batches of BATCH_SIZE;
# confirm before changing drop_last for the evaluation loaders.
loader_options = dict(batch_size=BATCH_SIZE,
                      collate_fn=collate_fn,
                      drop_last=True)

# Only the training loader shuffles; validation and test keep dataset order.
dataloader_train = DataLoader(dataset_train, shuffle=True, **loader_options)
dataloader_valid = DataLoader(dataset_valid, **loader_options)
dataloader_test = DataLoader(dataset_test, **loader_options)

In [9]:
# Use the first CUDA device when one is available; otherwise fall back to
# the CPU instead of raising at `.to('cuda:0')` on a machine without CUDA.
# NOTE(review): MaskedTrainer may place batches on a device itself --
# confirm it follows the model's device rather than assuming CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = PhonemeLangModel().to(device)

In [None]:
# Adam optimiser (learning rate 1e-3) over all model parameters, paired
# with a cross-entropy loss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Hand the model, both loaders, the loss and the optimiser to the trainer,
# then run it with the configured epoch count and batch size.
trainer = MaskedTrainer(
    model,
    dataloader_train,
    dataloader_valid,
    criterion,
    optimizer,
)
trainer.train(EPOCHS, BATCH_SIZE)

  0%|          | 0/20 [00:00<?, ?it/s]

train: 0


  0%|          | 0/5707 [00:00<?, ?it/s]

In [None]:
# Move the model to the CPU (this relocates `model` itself) and write its
# state dict to disk.
save_path = './lstm_lang_model.pth'
cpu_state = model.to('cpu').state_dict()
torch.save(cpu_state, save_path)

In [None]:
import pandas as pd

# Collect the trainer's recorded metric lists into one table, one column
# per metric, and persist it as CSV (the DataFrame index is written too).
# NOTE(review): presumably each list holds one value per epoch -- confirm
# against trainer_phoneme.
metric_names = ('train_loss', 'valid_loss', 'train_acc', 'valid_acc')
loss_dict = {metric: getattr(trainer, metric + '_list') for metric in metric_names}
df = pd.DataFrame(loss_dict)
df.to_csv('./lstm_lang_model_losses.csv')

In [None]:
df = pd.read_csv('./lstm_lang_model_losses.csv')
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
fig, ax1 = plt.subplots()
x = np.array([i for i in range(EPOCHS)])
ax1.set_xlabel('epochs')
ax1.plot(x, df['train_loss'], label='train_loss')
ax1.plot(x, df['valid_loss'], label='valid_loss')
ax1.legend()

ax2 = ax1.twinx()
ax2.plot(x, df['train_acc'], label='train_acc')
ax2.plot(x, df['valid_acc'], label='valid_acc')
ax2.legend()

ax1.set_ylabel('loss')
ax2.set_ylabel('acc')
plt.savefig('./graph.png')
plt.show()

In [None]:
# Rebuild the model and restore the saved weights.
save_path = './lstm_lang_model.pth'
model = PhonemeLangModel()
# map_location='cpu' loads the tensors onto the CPU regardless of the device
# they were saved from; the model is not moved to any device in this cell.
model.load_state_dict(torch.load(save_path, map_location='cpu'))