In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json

In [30]:
# データセットの定義
class TextDataset(Dataset):
    def __init__(self, data, word_to_idx):
        self.data = data
        self.word_to_idx = word_to_idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        text_indices = [self.word_to_idx[word] for word in text]
        return torch.tensor(text_indices)

In [31]:
# LSTMモデルの定義
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        output = self.fc(output.view(1, -1))
        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size)

In [32]:
def generate_text(model, start_input, hidden, temperature, length):
    input = start_input
    generated_text = ""

    for _ in range(length):
        output, hidden = model(input, hidden)
        output_dist = output.squeeze().div(temperature).exp()
        predicted_word_idx = torch.multinomial(output_dist, 1)[0]
        predicted_word = idx_to_word[predicted_word_idx.item()]

        generated_text += predicted_word
        input = torch.tensor([[predicted_word_idx.item()]])
    
    return generated_text

In [92]:
# ハイパーパラメータの設定
hidden_size = 128  # 隠れ層のサイズ
num_epochs = 2
learning_rate = 0.01

# データの準備
json_files = ['./json-list-data/CaveCrawler-v0.json']
data = []
for json_file in json_files:
    with open(json_file, 'r') as f:
        json_data = json.load(f)
        json_data1 = json_data[4:]
    data.extend(json_data1)

words = [word for text in data for word in text]
vocab = list(set(words))

word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}

# データセットの作成
dataset = TextDataset(data, word_to_idx)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [93]:
input_size = len(dataset.word_to_idx)
output_size = len(dataset.word_to_idx)
model = LSTMModel(input_size, hidden_size, output_size)

# 損失関数と最適化手法の定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [94]:
# モデルの学習
for epoch in range(num_epochs):
    for batch_data in dataloader:
        optimizer.zero_grad()
        hidden = model.init_hidden()
        loss = 0
        for word in batch_data[0]:
            output, hidden = model(word, hidden)
            target = word.view(-1)
            loss += criterion(output, target)
        loss.backward()
        optimizer.step()
    print('Epoch: {}, Loss: {:.4f}'.format(epoch + 1, loss.item()))

Epoch: 1, Loss: 0.2396
Epoch: 2, Loss: 0.0169


In [123]:
# テキスト生成
hidden = model.init_hidden()
start_input = torch.tensor([[word_to_idx['-']]])  # 開始文字の指定
temperature = 1.6  # 温度パラメータの設定
X = 70
Y = 70
for i in range(1):
    generated_text = generate_text(model, start_input, hidden, temperature, length=980)
    result = [list(generated_text[i:i+X]) for i in range(0, len(generated_text), X)]
    result = [row[:Y] for row in result[:Y]]  # リストのサイズをX * Yに制限する
    result = [''.join(row) for row in result]
    print("##########################################################################")
    for m in result:
        print(m)

##########################################################################
-------------------HHHHSHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHSHHHHHHHHH
HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHSHHHHHHHHHHHHHHHHHHHHHHHH------
----------------------------------------------------------------------
----------------------------------------------------------------------
---------------------------------------------HHHHHHHHHHHHHHHHHHHHHHHHH
HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHS
HHSSHHHHHHHHHHHHHHHHHHHHHHHHHHSHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH
HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHSSHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH
HHHHHH--------------------------------------------HHHHHHHHHHHHHHHHHHHH
HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHSHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH
HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH-----
----------------------------------------------------------------------
----------------------------------------------------------------------
--