# 🥙 基于食谱数据的 LSTM 文本生成

本 Notebook 将一步步演示如何使用 **PyTorch** 在食谱数据集上训练一个 LSTM，并生成新的菜谱文本。

In [None]:
# %%
%load_ext autoreload
%autoreload 2

import json
import re
import string
import random
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed every random number generator this notebook draws from so that
# shuffling, weight initialisation and sampling are repeatable.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)


## 0. Parameters <a name="parameters"></a>

In [None]:
# --- Text encoding ---------------------------------------------------------
VOCAB_SIZE = 10_000  # vocabulary entries, including the <pad> and <unk> slots
MAX_LEN = 200  # tokens per training input (sequences are stored one id longer)

# --- Model size ------------------------------------------------------------
EMBEDDING_DIM = 100  # width of each token embedding
N_UNITS = 128  # hidden size of the LSTM

# --- Training --------------------------------------------------------------
BATCH_SIZE = 32
EPOCHS = 25
LOAD_MODEL = False  # when True, weights are restored from ./models/lstm.pt

# NOTE(review): the two settings below are not referenced by the cells shown
# in this notebook (the seeds are hard-coded and no validation split is made).
VALIDATION_SPLIT = 0.2
SEED = 42

## 1. Load the data <a name="load"></a>

In [None]:
# Load the raw recipe records. JSON text is UTF-8 by specification, so decode
# it explicitly instead of relying on the platform's default encoding (which
# is not UTF-8 everywhere and would corrupt or reject non-ASCII characters).
with open("/app/data/epirecipes/full_format_recipes.json", encoding="utf-8") as f:
    recipe_data = json.load(f)

In [None]:
# Drop records with a missing/empty title or directions, then merge the title
# and the joined direction steps into one "Recipe for <title> | <steps>" string.
filtered_data = [
    " | ".join(
        ("Recipe for " + recipe["title"], " ".join(recipe["directions"]))
    )
    for recipe in recipe_data
    if recipe.get("title") and recipe.get("directions")
]


In [None]:
# Report how many recipes remain after the filtering step.
n_recipes = len(filtered_data)
print(f"{n_recipes} recipes loaded")

In [None]:
# Show one combined title/directions string as a sanity check.
example = filtered_data[9]
print(example)

## 2. ÊñáÊú¨tokenÂåñ

In [None]:
# Patterns are compiled once so the per-recipe calls do not rebuild them.
# ``re.escape`` is required: inside a character class the raw
# ``string.punctuation`` contains ``\]``, which the regex engine reads as an
# escaped bracket, so a literal backslash would never be matched.
_PUNCTUATION_RE = re.compile(f"([{re.escape(string.punctuation)}])")
_MULTI_SPACE_RE = re.compile(" +")


def pad_punctuation(s):
    """Surround every ASCII punctuation character with spaces and lower-case.

    Each character in ``string.punctuation`` becomes its own
    whitespace-separated token, runs of spaces are collapsed to a single
    space, and the result is lower-cased.

    Args:
        s: The text to normalise.

    Returns:
        The normalised string. Leading or trailing single spaces introduced
        by padding are kept; a later ``str.split()`` ignores them.
    """
    s = _PUNCTUATION_RE.sub(r" \1 ", s)
    s = _MULTI_SPACE_RE.sub(" ", s)
    return s.lower()

# Normalise every recipe string (punctuation padded, lower-cased).
text_data = list(map(pad_punctuation, filtered_data))

In [None]:
# The same recipe index as the earlier example, after punctuation padding and
# lower-casing. The bare name on the last line is left as-is so the notebook
# displays the value.
example_data = text_data[9]
example_data

## 3. 构建词表

In [None]:
# Whitespace-split every normalised recipe string into its token list.
tokenized_texts = [text.split() for text in text_data]

# Tally how often each token occurs across the whole corpus.
counter = Counter(
    token for recipe_tokens in tokenized_texts for token in recipe_tokens
)

# Ids 0 and 1 are reserved for padding and unknown words; the remaining
# VOCAB_SIZE - 2 slots go to the most frequent tokens.
vocab = ["<pad>", "<unk>"]
vocab.extend(word for word, _count in counter.most_common(VOCAB_SIZE - 2))

# Lookup tables in both directions: token -> id and id -> token.
word_to_index = dict(zip(vocab, range(len(vocab))))
index_to_word = {idx: word for word, idx in word_to_index.items()}

In [None]:
# Peek at the first ten vocabulary entries, printed as "id: token".
for i in range(10):
    print(f"{i}: {index_to_word[i]}")

In [None]:
# Convert token lists to fixed-length id sequences (padded or truncated).
def encode(tokens):
    """Map tokens to vocabulary ids and force the length to MAX_LEN + 1.

    Tokens missing from ``word_to_index`` become id 1 (``<unk>``). Longer
    sequences are cut off; shorter ones are right-padded with id 0
    (``<pad>``). The extra position beyond MAX_LEN lets the sequence be split
    into an input and a one-step-shifted target of MAX_LEN ids each.

    Args:
        tokens: Iterable of string tokens for one recipe.

    Returns:
        A new list of exactly ``MAX_LEN + 1`` integer ids.
    """
    target_len = MAX_LEN + 1
    ids = [word_to_index.get(token, 1) for token in tokens][:target_len]
    shortfall = target_len - len(ids)
    return ids + [0] * shortfall


encoded_texts = list(map(encode, tokenized_texts))

In [None]:
# Show the token ids for the same example recipe (index 9).
print(encoded_texts[9])

## 3. 构建训练集

In [None]:
class RecipeDataset(Dataset):
    """Next-token prediction pairs built from encoded recipes.

    Each item is an ``(x, y)`` pair of ``torch.long`` tensors where ``y`` is
    ``x`` shifted one position to the left, so ``y[t]`` is the token that
    follows ``x[t]``.
    """

    def __init__(self, encoded_texts):
        # Kept as the raw sequence of id lists; tensors are built on access.
        self.data = encoded_texts

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = torch.tensor(self.data[idx], dtype=torch.long)
        # Inputs drop the last id, targets drop the first one.
        return ids[:-1], ids[1:]

# Every encoded recipe goes into the training set; batches are reshuffled
# each epoch.
dataset = RecipeDataset(encoded_texts=encoded_texts)
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)

## 4. 构建 LSTM 模型 <a name="build"></a>

In [None]:
class LSTMModel(nn.Module):
    """Embedding -> single-layer LSTM -> linear projection to the vocabulary.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Width of each token embedding.
        hidden_dim: Hidden size of the LSTM.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        # Id 0 is the padding id, so its embedding row is kept at zero.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Return per-position logits for a batch-first tensor of token ids."""
        embedded = self.embedding(x)
        # Only the per-step outputs are used; the final (h, c) state is dropped.
        hidden_states, _final_state = self.lstm(embedded)
        return self.fc(hidden_states)

# Build the network with the notebook's hyperparameters and move it to the
# selected device before printing its layer summary.
model = LSTMModel(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBEDDING_DIM,
    hidden_dim=N_UNITS,
)
model = model.to(device)
print(model)

In [None]:
# Optionally resume from a previously saved checkpoint.
if LOAD_MODEL:
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU machine still loads on a CPU-only one.
    model.load_state_dict(torch.load("./models/lstm.pt", map_location=device))

## 5. Train the LSTM <a name="train"></a>

In [None]:
# Targets equal to 0 (the padding id) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Adam with PyTorch's default hyperparameters over all model weights.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
class TextGenerator:
    """Sample text from a next-token model, one token at a time.

    Args:
        model: Module that maps a ``(1, seq_len)`` tensor of token ids to
            logits of shape ``(1, seq_len, vocab_size)``.
        index_to_word: Mapping from token id to token string.
        word_to_index: Mapping from token string to token id.
    """

    def __init__(self, model, index_to_word, word_to_index):
        self.model = model
        self.index_to_word = index_to_word
        self.word_to_index = word_to_index

    def sample_from(self, probs, temperature):
        """Draw one token id from a temperature-adjusted distribution.

        Args:
            probs: 1-D array of next-token probabilities.
            temperature: Values below 1 sharpen the distribution, values
                above 1 flatten it. Must be non-zero.

        Returns:
            A ``(token_id, probs)`` tuple where ``probs`` is the rescaled,
            renormalised distribution that was actually sampled from.
        """
        # Work in float64: raising float32 probabilities to a large power
        # (low temperature) can underflow every entry to zero, which turns
        # the normalisation into 0/0 and makes np.random.choice fail.
        probs = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
        probs = probs / probs.sum()
        return np.random.choice(len(probs), p=probs), probs

    def generate(self, start_prompt, max_tokens=50, temperature=1.0):
        """Extend ``start_prompt`` by sampling tokens and print the result.

        The prompt is split on whitespace; words missing from the vocabulary
        map to id 1. Sampling continues until the sequence (prompt tokens
        included) holds ``max_tokens`` ids or the padding id 0 is drawn.

        Args:
            start_prompt: Whitespace-separated seed text.
            max_tokens: Upper bound on total sequence length, prompt included.
            temperature: Sampling temperature passed to ``sample_from``.

        Returns:
            One dict per generated token with keys ``"prompt"`` (the text
            before that token was appended) and ``"word_probs"`` (the
            distribution it was sampled from).

        Raises:
            ValueError: If the prompt has no tokens but tokens are requested.
        """
        self.model.eval()
        tokens = [self.word_to_index.get(w, 1) for w in start_prompt.split()]
        if not tokens and max_tokens > 0:
            # The model needs at least one input token to predict from.
            raise ValueError("start_prompt must contain at least one token")

        # Build inputs on whatever device the model's weights live on instead
        # of depending on a notebook-level global.
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            model_device = first_param.device
        else:
            model_device = torch.device("cpu")

        info = []
        while len(tokens) < max_tokens and tokens[-1] != 0:
            # The whole sequence so far is re-fed each step; only the logits
            # at the last position are used for the next token.
            x = torch.tensor(
                tokens, dtype=torch.long, device=model_device
            ).unsqueeze(0)
            with torch.no_grad():
                logits = self.model(x)
            probs = F.softmax(logits[0, -1], dim=0).cpu().numpy()
            token, p = self.sample_from(probs, temperature)
            token = int(token)  # plain int for the id list and dict lookup
            info.append({"prompt": start_prompt, "word_probs": p})
            tokens.append(token)
            start_prompt += " " + self.index_to_word[token]

        print(f"\n生成文本:\n{start_prompt}\n")
        return info

# Sampler bound to the model and vocabulary lookups defined above.
text_generator = TextGenerator(
    model=model,
    index_to_word=index_to_word,
    word_to_index=word_to_index,
)

In [None]:
# Train for EPOCHS passes over the data; after each pass report the mean
# batch loss and print a sample generation to track progress.
for epoch in range(EPOCHS):
    model.train()  # generate() leaves the model in eval mode
    batch_losses = []

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        # Flatten (batch, seq, vocab) logits and (batch, seq) targets so the
        # loss is computed per token position.
        loss = criterion(logits.view(-1, VOCAB_SIZE), y.view(-1))
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    total_loss = sum(batch_losses)
    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss/len(train_loader):.4f}")
    text_generator.generate("recipe for", max_tokens=50, temperature=1.0)

In [None]:
# Save the trained weights. The target directory is created first so a
# missing ./models folder cannot discard the result of the training run.
import os

os.makedirs("./models", exist_ok=True)
torch.save(model.state_dict(), "./models/lstm.pt")

## 6. Generate text using the LSTM

In [None]:
def print_probs(info, vocab, top_k=5):
    """Print the most likely next tokens for every recorded generation step.

    Args:
        info: Sequence of dicts with a ``"prompt"`` string and a
            ``"word_probs"`` array, one per generated token.
        vocab: Sequence mapping token id to token string.
        top_k: How many candidates to list per step, most probable first.
    """
    for step in info:
        print(f"\nPROMPT: {step['prompt']}")
        word_probs = step["word_probs"]
        # Ascending argsort reversed gives ids from most to least probable.
        ranked_ids = np.argsort(word_probs)[::-1]
        for token_id in ranked_ids[:top_k]:
            percent = np.round(100 * word_probs[token_id], 2)
            print(f"{vocab[token_id]}:\t{percent}%")
        print("--------")

In [None]:
# Continue a partially written recipe at temperature 1.0. max_tokens counts
# the prompt's tokens too: this prompt has 8, so at most 2 new tokens are
# sampled. The per-step candidate probabilities are then printed.
info = text_generator.generate(
    "recipe for roasted vegetables | chop 1 /",
    max_tokens=10,
    temperature=1.0
)
print_probs(info, vocab)

In [None]:
# Sample the first word of the directions at a low temperature (0.2), which
# sharpens the distribution towards the most likely tokens. The prompt has 6
# tokens and max_tokens is 7, so at most 1 new token is sampled.
info = text_generator.generate(
    "recipe for chocolate ice cream |",
    max_tokens=7,
    temperature=0.2
)
print_probs(info, vocab)