In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import os

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# Seed both RNGs used in this notebook (PyTorch for init/shuffling, NumPy for sampling).
torch.manual_seed(42)
np.random.seed(42)

DATA_PATH = '/kaggle/input/short-jokes/shortjokes.csv'
# Directory that is listed as a hint when the CSV itself is missing.
DATA_DIR = os.path.dirname(DATA_PATH)

if os.path.exists(DATA_PATH):
    print(f"成功找到数据文件: {DATA_PATH}")
else:
    print(f"错误：文件不存在 {DATA_PATH}")
    # Only list the directory when it exists; otherwise the diagnostic itself
    # would raise FileNotFoundError and hide the original problem.
    if os.path.isdir(DATA_DIR):
        print(f"请检查 {DATA_DIR}/ 目录下的文件:")
        print(os.listdir(DATA_DIR))
    else:
        print(f"目录也不存在: {DATA_DIR}/")


使用设备: cuda
成功找到数据文件: /kaggle/input/short-jokes/shortjokes.csv


In [None]:
class MyDataset(Dataset):
    """Word-level next-token dataset built from the joke column of a CSV file.

    The text column is lower-cased, concatenated and split on whitespace into
    one long token stream. Words occurring fewer than ``min_count`` times map
    to the ``<unk>`` id (always id 0). Sample ``i`` is a pair of id tensors:
    the window ``ids[i:i+seq_len]`` and the same window shifted right by one.

    Args:
        data_path: CSV file to read. ``None`` (default) uses the notebook-level
            ``DATA_PATH`` constant.
        seq_len: Number of tokens per input/target window (default 4).
        min_count: Minimum frequency for a word to get its own id (default 2).
    """

    UNK_TOKEN = '<unk>'

    # Class-level defaults mirror the constructor defaults so that the methods
    # below also work on instances whose attributes were assigned by hand.
    data_path = None
    seq_len = 4
    min_count = 2

    def __init__(self, data_path=None, seq_len=4, min_count=2):
        if seq_len < 1:
            raise ValueError("seq_len must be >= 1")
        self.data_path = data_path
        self.seq_len = seq_len
        self.min_count = min_count

        self.listOfWords = self.loadWords()
        self.listOfUniqueWords = self.obtainUniqueWords()
        self.id2word = dict(enumerate(self.listOfUniqueWords))
        self.word2id = {w: i for i, w in self.id2word.items()}

        unk_id = self.word2id[self.UNK_TOKEN]
        self.listOfIds = [self.word2id.get(w, unk_id) for w in self.listOfWords]

    def loadWords(self):
        """Read the CSV and return the corpus as a list of lower-cased words.

        The first column whose name contains 'joke' (case-insensitive) is used,
        falling back to the first column with a warning. Any failure is printed
        and an empty list is returned (best-effort, as before), so callers
        should check for an empty corpus.
        """
        try:
            path = DATA_PATH if self.data_path is None else self.data_path
            csvData = pd.read_csv(path)
            joke_column = next(
                (col for col in csvData.columns if 'joke' in str(col).lower()),
                None,
            )

            if joke_column is None:
                joke_column = csvData.columns[0]
                print(f"警告：未找到包含'joke'的列，默认使用第一列: '{joke_column}'")
            else:
                print(f"成功找到笑话列: '{joke_column}'")

            # dropna + astype(str) keeps the .str accessor valid even when the
            # chosen column is not a pure string column (e.g. the fallback).
            text = csvData[joke_column].dropna().astype(str).str.cat(sep=' ').lower()
            # Split on any whitespace run: split(' ') would emit empty-string
            # tokens for double spaces and keep newlines inside tokens, which
            # also disagrees with the .split() used when generating text.
            return text.split()

        except Exception as e:
            print(f"加载数据时出错: {e}")
            return []

    def obtainUniqueWords(self):
        """Return the vocabulary: '<unk>' first, then kept words by descending count."""
        wordCounts = Counter(self.listOfWords)
        # Skip a literal '<unk>' from the corpus so the token is not listed twice
        # (a duplicate would make word2id and id2word disagree about its id).
        kept_words = [
            word
            for word, count in wordCounts.most_common()
            if count >= self.min_count and word != self.UNK_TOKEN
        ]
        print(f"词汇过滤前: {len(wordCounts)}, 过滤后(出现>={self.min_count}次): {len(kept_words)}")

        return [self.UNK_TOKEN] + kept_words

    def __len__(self):
        # Each sample needs seq_len inputs plus one extra token for the shifted
        # target; clamp at 0 because len() must never be negative.
        return max(0, len(self.listOfIds) - self.seq_len)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` id tensors of length ``seq_len`` for sample ``index``."""
        # List slicing never raises, so without this check plain iteration over
        # the dataset would never stop and out-of-range indices would silently
        # return truncated windows.
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of size {len(self)}")
        end = index + self.seq_len
        return (
            torch.tensor(self.listOfIds[index:end], dtype=torch.long),
            torch.tensor(self.listOfIds[index + 1:end + 1], dtype=torch.long),
        )
    
print("\n--- 创建数据集 (修正后) ---")
dataset = MyDataset()
# MyDataset.loadWords() reports load failures by printing and returning an empty
# list. Stop here with a clear message instead of letting the empty corpus
# surface later as an unrelated error in len(dataset) or the training loop.
if not dataset.listOfWords:
    raise RuntimeError(f"数据集为空，无法继续：请检查上方的加载信息以及数据文件 {DATA_PATH}")
# The vocabulary size fixes the embedding and output-layer dimensions of the model.
VOCAB_SIZE = len(dataset.listOfUniqueWords)
print(f"总词数: {len(dataset.listOfWords)}")
print(f"最终词汇表大小 (包含<unk>): {VOCAB_SIZE}")
print(f"可训练样本数: {len(dataset)}")



--- 创建数据集 (修正后) ---
成功找到笑话列: 'Joke'
词汇过滤前: 191284, 过滤后(出现>=2次): 77741
总词数: 4071141
最终词汇表大小 (包含<unk>): 77742
可训练样本数: 4071137


In [None]:
class LSTMModel(nn.Module):
    """Embedding -> LSTM -> linear language model over word ids.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / output logits).
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers (default 1, matching the
            original single-layer model and its saved checkpoints).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # batch_first=True: inputs and outputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a batch of id sequences.

        Args:
            x: Integer id tensor of shape (batch, seq).
            hidden: Optional ``(h0, c0)`` state; ``None`` lets the LSTM start
                from zeros.

        Returns:
            logits of shape (batch, seq, vocab_size) and the final LSTM state.
        """
        embedded = self.embedding(x)
        lstm_out, hidden = self.lstm(embedded, hidden)
        output = self.fc(lstm_out)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return a zero ``(h0, c0)`` state for ``batch_size`` sequences.

        The state is created on the same device and with the same dtype as the
        model's own weights, so it stays valid wherever the model is moved
        instead of depending on a notebook-level ``device`` variable.
        """
        ref = self.fc.weight
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h0 = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        c0 = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        return (h0, c0)

print("模型类定义完成！")


模型类定义完成！


In [None]:
# Training hyperparameters for this run.
BATCH_SIZE = 64      
EMBED_DIM = 128      
HIDDEN_DIM = 256     
EPOCHS = 10          
LEARNING_RATE = 0.001

# Batches of (input window, shifted target window) pairs, reshuffled every epoch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = LSTMModel(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"\n模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

print("\n--- 开始训练 ---")
for epoch in range(EPOCHS):
    model.train()
    # Running sum of per-batch losses, averaged over the batch count below.
    total_loss = 0
    
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        
        # No hidden state is passed, so every window starts from the LSTM's
        # default (zero) state; the returned state is discarded.
        outputs, _ = model(inputs)
 
        # Flatten to one row of logits per token position and one target id per
        # position, which is the layout CrossEntropyLoss is given here.
        loss = criterion(outputs.view(-1, VOCAB_SIZE), targets.view(-1))

        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{EPOCHS}, 平均损失: {avg_loss:.4f}')

print("--- 训练完成 ---")

# Persist only the weights (state_dict); reloading needs the same constructor arguments.
MODEL_SAVE_PATH = 'lstm_joke_generator.pth'
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"模型已保存到: {MODEL_SAVE_PATH}")



模型参数数量: 30,325,934

--- 开始训练 ---
Epoch 1/10, 平均损失: 5.5893
Epoch 2/10, 平均损失: 5.0876
Epoch 3/10, 平均损失: 4.9231
Epoch 4/10, 平均损失: 4.8245
Epoch 5/10, 平均损失: 4.7674
Epoch 6/10, 平均损失: 4.7304
Epoch 7/10, 平均损失: 4.6956
Epoch 8/10, 平均损失: 4.6733
Epoch 9/10, 平均损失: 4.6648
Epoch 10/10, 平均损失: 4.6544
--- 训练完成 ---
模型已保存到: lstm_joke_generator.pth


In [None]:
# Rebuild the architecture with the same sizes used for training, then restore the weights.
loaded_model = LSTMModel(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM).to(device)
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU session still loads when only a CPU is available.
loaded_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
# Inference mode for the generation cells.
loaded_model.eval()

print("\n模型加载成功！")
def generate_text(model, start_text, max_len=10, word2id=None, id2word=None, context_len=4):
    """Extend ``start_text`` by sampling ``max_len`` words from ``model``.

    At every step the last ``context_len`` words are encoded and run through
    the model from a fresh hidden state, and the next word is sampled from the
    softmax over the final position's logits. This mirrors training, where each
    window is processed from a zero state. Carrying the hidden state across
    steps while re-feeding the same window would process the context words
    more than once.

    Args:
        model: Module whose ``model(ids)`` call returns ``(logits, hidden)``
            with logits of shape (batch, seq, vocab).
        start_text: Prompt; it is lower-cased and split on whitespace.
        max_len: Number of words to append.
        word2id: Optional word -> id mapping; defaults to ``dataset.word2id``.
        id2word: Optional id -> word mapping; defaults to ``dataset.id2word``.
        context_len: Number of trailing words fed to the model per step.

    Returns:
        The prompt words followed by the generated words, joined by spaces.

    Raises:
        ValueError: If the prompt contains no words or ``context_len`` < 1.
    """
    if context_len < 1:
        raise ValueError("context_len must be >= 1")
    # Fall back to the notebook-level dataset vocabulary when none is supplied.
    if word2id is None:
        word2id = dataset.word2id
    if id2word is None:
        id2word = dataset.id2word

    words = start_text.lower().split()
    if not words:
        # An empty prompt would produce an empty id tensor and fail inside the model.
        raise ValueError("start_text must contain at least one word")

    unk_id = word2id.get('<unk>', 0)
    # Build inputs on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    with torch.no_grad():
        for _ in range(max_len):
            window = words[-context_len:]
            input_ids = torch.tensor(
                [[word2id.get(w, unk_id) for w in window]],
                dtype=torch.long,
                device=run_device,
            )
            # Fresh state each step: the window already contains all the context.
            output, _ = model(input_ids)

            probs = torch.softmax(output[0, -1, :], dim=0)
            # Sample directly in torch; avoids a float32 -> NumPy round trip.
            word_index = torch.multinomial(probs, num_samples=1).item()

            words.append(id2word[word_index])

    return ' '.join(words)

# First demo prompt. generate_text lower-cases the prompt, which is why the
# generated string starts in lower case.
input_text = "If life gives you melons"
generated_text = generate_text(loaded_model, input_text, max_len=15)

print(f"\n输入: '{input_text}'")
print(f"生成: '{generated_text}'")

# Second demo prompt.
input_text_2 = "why did the chicken"
generated_text_2 = generate_text(loaded_model, input_text_2, max_len=15)
print(f"\n输入: '{input_text_2}'")
print(f"生成: '{generated_text_2}'")

# A few more prompts, each extended by 12 sampled words.
test_inputs = [
    "What do you call",
    "I went to the",
    "There was a",
    "My doctor told me"
]

print("\n--- 更多生成示例 ---")
for text in test_inputs:
    generated = generate_text(loaded_model, text, max_len=12)
    print(f"输入: '{text}'")
    print(f"生成: '{generated}'")
    # Visual separator between examples.
    print("-" * 50)



模型加载成功！

输入: 'If life gives you melons'
生成: 'if life gives you melons you bees ever eat a job i tattooed on his <unk> network. tonight and it'

输入: 'why did the chicken'
生成: 'why did the chicken cross the road? it got them to have to ever knows it still hasn't got'

--- 更多生成示例 ---
输入: 'What do you call'
生成: 'what do you call a kid on 14. teacher promoted to <unk> jared fogle in the'
--------------------------------------------------
输入: 'I went to the'
生成: 'i went to the doctor away... for gas in her pocket and nine advice sign. voting'
--------------------------------------------------
输入: 'There was a'
生成: 'there was a christian bale a woman to say it was dating my money for'
--------------------------------------------------
输入: 'My doctor told me'
生成: 'my doctor told me i knew a mexican, apology to <unk> yay looks man &amp; a'
--------------------------------------------------
