
# 使用 LSTM 模型进行文本生成

使用循环神经网络（RNN）中的 LSTM（Long Short-Term Memory）模型，完成一个由图像生成文本描述（image captioning）的序列建模任务。


# Import necessary modules

In [None]:
# Basic libraries
import os
import pickle
import re
import json
import zipfile
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings
warnings.filterwarnings('ignore')
from math import ceil
from collections import defaultdict
from tqdm.notebook import tqdm        # Progress bar library for Jupyter Notebook

# Deep learning framework for building and training models
import tensorflow as tf
## Pre-trained model for image feature extraction
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array

## Tokenizer class for captions tokenization
from tensorflow.keras.preprocessing.text import Tokenizer

## Function for padding sequences to a specific length
from tensorflow.keras.preprocessing.sequence import pad_sequences

## Class for defining Keras models
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, concatenate, Bidirectional, Dot, Activation, RepeatVector, Multiply, Lambda
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger

# For checking score
from nltk.translate.bleu_score import corpus_bleu

: 

In [None]:
# Resolve the input (dataset) and output (artifacts) directories.
# The first existing directory among: $FLICKR8K_DIR, ./data/flickr8k, /kaggle/input/flickr8k
# is used as the dataset root; otherwise ./data/flickr8k is created as an empty placeholder.
default_input_candidates = [
    os.environ.get("FLICKR8K_DIR"),
    os.path.join(os.getcwd(), "data", "flickr8k"),
    "/kaggle/input/flickr8k",
]
INPUT_DIR = next(
    (path for path in default_input_candidates if path and os.path.isdir(path)),
    os.path.join(os.getcwd(), "data", "flickr8k"),
)
os.makedirs(INPUT_DIR, exist_ok=True)

# All artifacts (features, models, metrics, plots) go under OUTPUT_DIR.
OUTPUT_DIR = os.environ.get("A5_OUTPUT_DIR", os.path.join(os.getcwd(), "outputs"))
os.makedirs(OUTPUT_DIR, exist_ok=True)
for artifact_subdir in ("results", "models"):
    os.makedirs(os.path.join(OUTPUT_DIR, artifact_subdir), exist_ok=True)
print(f"使用数据目录: {INPUT_DIR}")
print(f"实验输出目录: {OUTPUT_DIR}")

In [None]:
# 如果本地没有数据则尝试自动下载Flickr8k
# (If Flickr8k is not available locally, try to download it automatically.)
import shutil  # shutil.move also works across filesystems, where os.replace raises OSError

images_dir = os.path.join(INPUT_DIR, 'Images')
captions_file = os.path.join(INPUT_DIR, 'captions.txt')
if not (os.path.isdir(images_dir) and os.path.isfile(captions_file)):
    print("检测到本地缺少Flickr8k文件，尝试从公开镜像下载......")
    dataset_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
    captions_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
    # get_file downloads into (and reuses) the Keras cache. Extract straight from the cached
    # archives instead of renaming them out of the cache: that keeps the cache valid for
    # re-runs and avoids a cross-filesystem rename.
    dataset_zip = tf.keras.utils.get_file("Flickr8k_Dataset.zip", origin=dataset_url, extract=False)
    captions_zip = tf.keras.utils.get_file("Flickr8k_text.zip", origin=captions_url, extract=False)
    download_target = os.path.join(os.getcwd(), "downloads")
    os.makedirs(download_target, exist_ok=True)
    zip_paths = [dataset_zip, captions_zip]
    for zip_src in zip_paths:
        with zipfile.ZipFile(zip_src, 'r') as zip_ref:
            zip_ref.extractall(download_target)

    # Folder name as used by the archive ("Flicker8k", not "Flickr8k").
    extracted_images = os.path.join(download_target, 'Flicker8k_Dataset')
    extracted_captions_dir = os.path.join(download_target, 'Flickr8k_text')
    # Accept the token file either inside a Flickr8k_text/ folder or at the extraction root.
    caption_candidates = [
        os.path.join(extracted_captions_dir, 'Flickr8k.token.txt'),
        os.path.join(download_target, 'Flickr8k.token.txt'),
    ]
    extracted_captions = next(
        (path for path in caption_candidates if os.path.isfile(path)),
        caption_candidates[0],
    )

    if os.path.isdir(extracted_images):
        os.makedirs(images_dir, exist_ok=True)
        for filename in os.listdir(extracted_images):
            src_path = os.path.join(extracted_images, filename)
            dst_path = os.path.join(images_dir, filename)
            if not os.path.exists(dst_path):
                shutil.move(src_path, dst_path)
    if os.path.isfile(extracted_captions) and not os.path.isfile(captions_file):
        os.makedirs(os.path.dirname(captions_file), exist_ok=True)
        shutil.move(extracted_captions, captions_file)
    print("下载完成，如未自动检测到文件，请重新运行本单元。")
else:
    print("检测到完整的 Images 与 captions.txt，跳过下载。")

# Image Features Extraction

In [None]:
# We are going to use the pretrained VGG16 model as a fixed image encoder
# Load the VGG16 model (Keras loads the ImageNet weights by default)
model = VGG16()

# Remove the last classification layer; the penultimate layer's 4096-d output is the
# image feature vector consumed later by the caption model
model = Model(inputs=model.inputs, outputs=model.layers[-2].output)

# summary() prints the table itself and returns None, so wrapping it in print()
# would additionally print a stray "None"
model.summary()

In [None]:
# Extract VGG16 features for every image (or load them from the pickle cache).
# Result: image_features maps image id (file name up to the first '.') -> model.predict output.
features_path = os.path.join(OUTPUT_DIR, 'img_features.pkl')
image_features = {}

# Define the directory path where images are located
img_dir = os.path.join(INPUT_DIR, 'Images')
# Only real image files are fed to the CNN; stray files (e.g. .DS_Store, Thumbs.db)
# would make load_img raise.
valid_image_exts = ('.jpg', '.jpeg', '.png')

if os.path.isfile(features_path):
    print(f"检测到缓存特征: {features_path}")
    with open(features_path, 'rb') as feature_file:
        image_features = pickle.load(feature_file)
else:
    print("首次提取VGG16特征，耗时较长，请耐心等待……")
    image_names = sorted(
        name for name in os.listdir(img_dir) if name.lower().endswith(valid_image_exts)
    )
    for img_name in tqdm(image_names):
        img_path = os.path.join(img_dir, img_name)
        # Load and resize to the 224x224 input size VGG16 expects
        image = load_img(img_path, target_size=(224, 224))
        image = img_to_array(image)
        # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
        image = np.expand_dims(image, axis=0)
        # Apply the VGG16-specific input preprocessing
        image = preprocess_input(image)
        # Extract features with the truncated VGG16 model built in the previous cell
        image_feature = model.predict(image, verbose=0)
        # Image id = file name without extension (same rule as the caption parser)
        image_id = img_name.split('.')[0]
        image_features[image_id] = image_feature
    with open(features_path, 'wb') as feature_file:
        pickle.dump(image_features, feature_file)
    print(f"图像特征已缓存至: {features_path}")

In [None]:
# Persist the extracted features; this is a no-op when the cache file already exists
# (the extraction cell writes it on first run).
features_path = os.path.join(OUTPUT_DIR, 'img_features.pkl')
cache_exists = os.path.isfile(features_path)
if not cache_exists:
    with open(features_path, 'wb') as pkl_out:
        pickle.dump(image_features, pkl_out)
print(f"当前特征文件: {features_path}")

In [None]:
# Reload the feature cache written by this notebook (only unpickle files you created yourself)
pickle_file_path = os.path.join(OUTPUT_DIR, 'img_features.pkl')
with open(pickle_file_path, mode='rb') as features_in:
    loaded_features = pickle.load(features_in)

# Loading Caption Data

In [None]:
# Read the whole caption file into captions_doc.
# The CSV-style captions.txt starts with a header line that must be skipped. A captions.txt
# installed from Flickr8k.token.txt, however, starts directly with data, so only drop the
# first line when it really is a header. Data lines begin with an image file name
# (e.g. "xxx.jpg,..." or "xxx.jpg#0\t..."), whereas a header's first field has no '.'.
with open(os.path.join(INPUT_DIR, 'captions.txt'), 'r', encoding='utf-8') as file:
    first_line = file.readline()
    captions_doc = file.read()
first_field = first_line.lstrip('\ufeff').split('\t')[0].split(',')[0].strip()
if '.' in first_field:
    # First line is a real caption entry: keep it
    captions_doc = first_line + captions_doc

In [None]:
# Create mapping of image id -> list of raw captions
image_to_captions_mapping = defaultdict(list)

# Process lines from captions_doc. Two layouts are supported:
#   CSV   : "image.jpg,caption text"
#   token : "image.jpg#0\tcaption text"   (Flickr8k.token.txt)
for line in tqdm(captions_doc.split('\n')):
    line = line.strip()
    if not line:
        continue
    # Test for the tab-separated layout first: its captions often contain commas, so
    # checking for ',' first would split inside the caption and corrupt the id and text.
    if '\t' in line:
        image_part, _, caption_part = line.partition('\t')
    else:
        # CSV layout: split on the first comma only, so commas inside the caption survive
        image_part, _, caption_part = line.partition(',')
    # Drop the file extension and any "#n" caption index
    image_id = image_part.split('.')[0].split('#')[0]
    caption = caption_part.strip()
    image_to_captions_mapping[image_id].append(caption)

# Print the total number of captions
total_captions = sum(len(captions) for captions in image_to_captions_mapping.values())
print("Total number of captions:", total_captions)

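How many captions does each image have? (Check with `len(image_to_captions_mapping[image_id])`; every Flickr8k image comes with five reference captions.)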

# Preprocessing Captions



In [None]:
# Function for processing the captions
def clean(mapping):
    """Normalise every caption in *mapping* in place.

    Each caption is lower-cased. Every character that is not a-z or whitespace becomes a
    space, whitespace runs are collapsed, and single-character words are dropped. The
    result is wrapped in the ``startseq`` / ``endseq`` sentinel tokens.

    Args:
        mapping: dict of image id -> list of caption strings. The lists are modified in place.

    Returns:
        None.
    """
    for caption_list in mapping.values():
        for position, raw_caption in enumerate(caption_list):
            letters_only = re.sub(r"[^a-z\s]", " ", raw_caption.lower())
            normalised = re.sub(r"\s+", " ", letters_only).strip()
            kept_words = [w for w in normalised.split() if len(w) > 1]
            caption_list[position] = 'startseq ' + ' '.join(kept_words) + ' endseq'

In [None]:
# Inspect one image's raw captions before preprocessing
image_to_captions_mapping['1026685415_0431cbf574']

In [None]:
# Preprocess all captions in place (lowercase, letters only, drop 1-char words, add startseq/endseq)
clean(image_to_captions_mapping)

In [None]:
# The same image's captions after preprocessing
image_to_captions_mapping['1026685415_0431cbf574']

In [None]:
# Flatten every image's cleaned captions into one list (used to fit the tokenizer)
all_captions = []
for caption_group in image_to_captions_mapping.values():
    all_captions.extend(caption_group)

In [None]:
# Peek at the first ten cleaned captions
all_captions[:10]

In [None]:
# Tokenizing the Text: build the word -> integer index vocabulary from every cleaned caption
# (Keras' Tokenizer numbers words from 1, leaving 0 free for padding)
tokenizer = Tokenizer()
tokenizer.fit_on_texts(all_captions)

In [None]:
# Round-trip the fitted tokenizer through tokenizer.pkl.
# NOTE(review): unlike the other artifacts this file is written to the current working
# directory, not OUTPUT_DIR.
tokenizer_pkl = 'tokenizer.pkl'
with open(tokenizer_pkl, 'wb') as fh:
    pickle.dump(tokenizer, fh)

with open(tokenizer_pkl, 'rb') as fh:
    tokenizer = pickle.load(fh)

In [None]:
# Longest caption in tokens (including startseq/endseq); this is the padded input length.
# Vocabulary size is +1 because index 0 is reserved for padding.
max_caption_length = max(map(len, tokenizer.texts_to_sequences(all_captions)))
vocab_size = len(tokenizer.word_index) + 1

# Print the results
print("Vocabulary Size:", vocab_size)
print("Maximum Caption Length:", max_caption_length)

# Train Test Split

In [None]:
# Creating a List of Image IDs. Only images that have extracted features are kept.
# A captioned id with no feature entry (e.g. its image file is absent from Images/)
# would otherwise raise KeyError inside the data generator in the middle of training.
image_ids = [image_id for image_id in image_to_captions_mapping if image_id in loaded_features]
skipped_ids = len(image_to_captions_mapping) - len(image_ids)
if skipped_ids:
    print(f"跳过 {skipped_ids} 张缺少图像特征的图片")
# Splitting into Training and Test Sets (first 90% / last 10%, in file order)
split = int(len(image_ids) * 0.90)
train = image_ids[:split]
test = image_ids[split:]

In [None]:
# Data generator function
def data_generator(data_keys, image_to_captions_mapping, features, tokenizer, max_caption_length, vocab_size, batch_size):
    """Yield training batches forever, one sample per (caption prefix -> next word) pair.

    For a caption tokenised as [t0, t1, ..., tn], it emits the samples
    (image, pad([t0])) -> t1, (image, pad([t0, t1])) -> t2, ... up to tn.
    Samples from consecutive images and captions are packed into batches of exactly
    ``batch_size``. A partial batch at the end of a pass carries over into the next pass.

    Yields:
        ([image_feature_batch, padded_prefix_batch], one_hot_target_batch) as numpy arrays.
        Prefixes are left-padded to ``max_caption_length``. Targets are one-hot over ``vocab_size``.
    """
    img_buf, seq_buf, tgt_buf = [], [], []

    while True:
        for image_id in data_keys:
            for caption in image_to_captions_mapping[image_id]:
                token_ids = tokenizer.texts_to_sequences([caption])[0]
                for cut in range(1, len(token_ids)):
                    prefix = pad_sequences([token_ids[:cut]], maxlen=max_caption_length)[0]
                    target = to_categorical([token_ids[cut]], num_classes=vocab_size)[0]

                    img_buf.append(features[image_id][0])
                    seq_buf.append(prefix)
                    tgt_buf.append(target)

                    # Emit a full batch and start a fresh one
                    if len(img_buf) == batch_size:
                        yield [np.array(img_buf), np.array(seq_buf)], np.array(tgt_buf)
                        img_buf, seq_buf, tgt_buf = [], [], []

# LSTM Model Training
在下方代码中补全模型结构
模型包括：
1. 图像特征编码（Encoder）部分：  
   - 输入层、Dropout、全连接层、RepeatVector  
   - 双向 LSTM 将特征投射到序列空间  

2. 文本序列（Caption）输入部分：  
   - Embedding 层  
   - Dropout + 双向 LSTM  

3. 注意力机制（Attention）（可选）：  
   - 使用 `Dot` 计算注意力得分  
   - 使用 `Activation('softmax')` 进行归一化  
   - 用 `Lambda` 实现加权求和（`tf.einsum` 或 `tf.matmul`）  

4. 解码（Decoder）：  
   - 将上下文向量与编码特征拼接  
   - Dense 层输出词汇分布  
![LSTM-Architecture](https://raw.githubusercontent.com/yunjey/pytorch-tutorial/master/tutorials/03-advanced/image_captioning/png/model.png)

In [None]:
# Encoder model: project the 4096-d VGG16 image features and spread them across the caption length
embedding_dim = 300
lstm_units = 256
inputs1 = Input(shape=(4096,), name='image_features')
fe1 = Dropout(0.5)(inputs1)
fe2 = Dense(512, activation='relu')(fe1)
# Repeat the image vector once per caption position -> (batch, max_caption_length, 512)
fe3 = RepeatVector(max_caption_length)(fe2)
# Bidirectional concatenates forward/backward outputs -> (batch, max_caption_length, 2 * lstm_units)
fe4 = Bidirectional(LSTM(lstm_units, return_sequences=True))(fe3)

# Sequence feature layers: embed the padded caption prefix (index 0 is masked as padding)
inputs2 = Input(shape=(max_caption_length,), name='caption_input')
se1 = Embedding(vocab_size, embedding_dim, mask_zero=True)(inputs2)
se2 = Dropout(0.5)(se1)
# -> (batch, max_caption_length, 2 * lstm_units)
se3 = Bidirectional(LSTM(lstm_units, return_sequences=True))(se2)

# Attention mechanism: dot-product score between every caption step and every encoder step
# -> (batch, max_caption_length, max_caption_length)
attention_scores = Dot(axes=[2, 2], name='attention_scores')([se3, fe4])
# Softmax over the last axis: each caption step gets a distribution over encoder steps
attention_weights = Activation('softmax', name='attention_weights')(attention_scores)
# Attention-weighted sum of encoder outputs -> (batch, max_caption_length, 2 * lstm_units)
context = Lambda(lambda tensors: tf.matmul(tensors[0], tensors[1]), name='context_vector')([attention_weights, fe4])

# Decoder: fuse the attention context with the caption encoding, then summarise with an LSTM
decoder_input = concatenate([context, se3], name='decoder_concat')
# return_sequences is left at its default (False), so only the final hidden state is passed on
decoder_lstm = LSTM(lstm_units * 2)(decoder_input)
decoder_dropout = Dropout(0.5)(decoder_lstm)
decoder_dense = Dense(512, activation='relu')(decoder_dropout)

# Probability distribution over the vocabulary for the next word
outputs = Dense(vocab_size, activation='softmax', name='word_predictions')(decoder_dense)

# Create the model
model = Model(inputs=[inputs1, inputs2], outputs=outputs) 
model.compile(loss='categorical_crossentropy', optimizer='adam')

plot_model(model, show_shapes=True)

In [None]:
# Set the number of epochs, batch size
epochs = 50
batch_size = 32


def _count_caption_samples(image_ids):
    """Count the (prefix -> next word) samples data_generator produces for *image_ids*.

    A caption of n tokens yields n - 1 samples, matching the generator's inner loop.
    """
    total = 0
    for image_id in image_ids:
        for seq in tokenizer.texts_to_sequences(image_to_captions_mapping[image_id]):
            total += max(len(seq) - 1, 0)
    return total


# data_generator batches individual word-prediction samples, not images. An epoch must
# therefore be sized by the number of samples. Sizing it by the image count made each
# "epoch" see only a small fraction of the training data.
train_samples = _count_caption_samples(train)
val_samples = _count_caption_samples(test)
steps_per_epoch = max(1, ceil(train_samples / batch_size))
validation_steps = max(1, ceil(val_samples / batch_size))  # Calculate the steps for validation data
print(f"训练样本数: {train_samples}, 验证样本数: {val_samples}, steps_per_epoch: {steps_per_epoch}")

train_generator = data_generator(train, image_to_captions_mapping, loaded_features, tokenizer, max_caption_length, vocab_size, batch_size)
val_generator = data_generator(test, image_to_captions_mapping, loaded_features, tokenizer, max_caption_length, vocab_size, batch_size)

checkpoint_path = os.path.join(OUTPUT_DIR, "models", "best_cnn_lstm.h5")
csv_log_path = os.path.join(OUTPUT_DIR, "results", "training_log.csv")

callbacks = [
    ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True, save_weights_only=False, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1),
    EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True, verbose=1),
    CSVLogger(csv_log_path, append=False),
]

history = model.fit(
    train_generator,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1
)

# Convert metric values to plain floats so they are JSON-serialisable
history_records = {key: [float(v) for v in values] for key, values in history.history.items()}
history_path = os.path.join(OUTPUT_DIR, "results", "train_history.json")
with open(history_path, 'w', encoding='utf-8') as history_file:
    json.dump(history_records, history_file, indent=2, ensure_ascii=False)

In [None]:
# Plot train/validation loss per epoch and save the figure under results/
loss_curve_path = os.path.join(OUTPUT_DIR, 'results', 'train_val_loss_curve.png')
fig, ax = plt.subplots(figsize=(8, 5))
for history_key, curve_label in (('loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    ax.plot(history_records.get(history_key, []), label=curve_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs Validation Loss')
ax.legend()
ax.grid(True, linestyle='--', alpha=0.4)
fig.tight_layout()
fig.savefig(loss_curve_path, dpi=200)
plt.show()
print(f"Loss 曲线已保存: {loss_curve_path}")

In [None]:
# Save the model as it is after fit() (the best-val_loss checkpoint is saved separately)
final_model_path = os.path.join(OUTPUT_DIR, 'models', 'final_cnn_lstm.h5')
model.save(final_model_path)
for message in (
    f"最终模型已保存到: {final_model_path}",
    f"最优模型（按val_loss）保存在: {checkpoint_path}",
):
    print(message)

# Captions Generation
完成自回归生成caption部分

In [None]:
def get_word_from_index(index, tokenizer):
    """Return the word whose tokenizer index is *index*, or None if there is none.

    Uses the tokenizer's reverse ``index_word`` dict for an O(1) lookup. The previous
    version scanned the whole vocabulary on every call, once per decoded word. Objects
    without a populated ``index_word`` fall back to scanning ``word_index``.

    Args:
        index: integer word index (0 is the padding index and maps to no word).
        tokenizer: fitted tokenizer exposing ``word_index`` (and normally ``index_word``).

    Returns:
        The matching word, or None.
    """
    index_word = getattr(tokenizer, 'index_word', None)
    if index_word:
        return index_word.get(index)
    return next((word for word, idx in tokenizer.word_index.items() if idx == index), None)

In [None]:
def predict_caption(model, image_features, tokenizer, max_caption_length):
    """Greedily decode a caption for one image.

    Starting from ``'startseq'``, the model repeatedly predicts the most likely next word
    until it emits ``'endseq'``, predicts an index with no word (e.g. padding 0), or
    ``max_caption_length`` steps have been taken.

    Args:
        model: caption model taking ``[image_features, padded_token_ids]`` and returning a
            next-word distribution.
        image_features: features of a single image as stored in the features dict
            (a batch of one, as produced by ``model.predict`` on one image).
        tokenizer: the fitted tokenizer used for training.
        max_caption_length: padded input length the model was trained with.

    Returns:
        The generated caption without the ``startseq``/``endseq`` markers, or ``''`` if no
        word was generated.
    """
    caption = 'startseq'

    for _ in range(max_caption_length):
        sequence = tokenizer.texts_to_sequences([caption])[0]
        sequence = pad_sequences([sequence], maxlen=max_caption_length)
        yhat = model.predict([image_features, sequence], verbose=0)
        next_index = int(np.argmax(yhat[0]))
        predicted_word = get_word_from_index(next_index, tokenizer)
        if predicted_word is None:
            break
        caption += " " + predicted_word
        if predicted_word == 'endseq':
            break

    # Drop the leading 'startseq' and, if present, the trailing 'endseq'. A lone generated
    # word (decoding stopped without 'endseq') is kept; the former `len(tokens) <= 2` check
    # discarded it.
    tokens = caption.split()[1:]
    if tokens and tokens[-1] == 'endseq':
        tokens = tokens[:-1]
    return ' '.join(tokens)

In [None]:
# Decode every test image and score the hypotheses against all reference captions.
# corpus_bleu expects, per image, a list of tokenised references and one tokenised hypothesis.
actual_captions_list, predicted_captions_list = [], []

for key in tqdm(test):
    # References without the startseq/endseq markers
    references = [
        ref.replace('startseq ', '').replace(' endseq', '').split()
        for ref in image_to_captions_mapping[key]
    ]
    hypothesis = predict_caption(model, loaded_features[key], tokenizer, max_caption_length).split()
    actual_captions_list.append(references)
    predicted_captions_list.append(hypothesis)

# Calculate BLEU score
print("BLEU-1: %f" % corpus_bleu(actual_captions_list, predicted_captions_list, weights=(1.0, 0, 0, 0)))
print("BLEU-2: %f" % corpus_bleu(actual_captions_list, predicted_captions_list, weights=(0.5, 0.5, 0, 0)))
print("BLEU-4: %f" % corpus_bleu(actual_captions_list, predicted_captions_list, weights=(0.25, 0.25, 0.25, 0.25)))

In [None]:
# Save the evaluation metrics (BLEU-1/2/4 over the test split) to JSON
bleu_weights = {
    'BLEU-1': (1.0, 0, 0, 0),
    'BLEU-2': (0.5, 0.5, 0, 0),
    'BLEU-4': (0.25, 0.25, 0.25, 0.25),
}
metrics = {
    metric_name: corpus_bleu(actual_captions_list, predicted_captions_list, weights=ngram_weights)
    for metric_name, ngram_weights in bleu_weights.items()
}
metrics_path = os.path.join(OUTPUT_DIR, 'results', 'bleu_metrics.json')
with open(metrics_path, 'w', encoding='utf-8') as metric_file:
    json.dump(metrics, metric_file, indent=2)
print(f"BLEU 指标已保存到: {metrics_path}")
metrics

# Predicting captions for Images

In [None]:
def find_image_path(image_identifier):
    """Locate the image file for *image_identifier* under ``INPUT_DIR/Images``.

    The identifier may be a bare id or a file name. Everything after the first '.' is
    ignored, and the extensions .jpg, .jpeg and .png are tried in that order.

    Returns:
        Path of the first existing candidate file.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    stem = image_identifier.split('.')[0]
    images_root = os.path.join(INPUT_DIR, "Images")
    for extension in (".jpg", ".jpeg", ".png"):
        path = os.path.join(images_root, stem + extension)
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"未找到图像 {image_identifier}，请检查文件是否存在")

In [None]:
# Function for generating caption
def generate_caption(image_identifier):
    image_id = image_identifier.split('.')[0]
    img_path = find_image_path(image_identifier)
    image = Image.open(img_path)
    captions = image_to_captions_mapping[image_id]
    print('---------------------Actual---------------------')
    for caption in captions:
        print(caption.replace('startseq ', '').replace(' endseq', ''))
    # predict the caption
    y_pred = predict_caption(model, loaded_features[image_id], tokenizer, max_caption_length)
    print('--------------------Predicted--------------------')
    print(y_pred)
    plt.imshow(image)
    plt.axis('off')
    return y_pred

In [None]:
# Render and save qualitative results for (up to) the first 10 test images:
# a 2-column image grid with predicted vs first reference caption, plus a JSON record.
sample_image_ids = test[:10]
qualitative_records = []
cols = 2
# Keep at least one row so the figure has non-zero height even for an empty test split
rows = max(1, ceil(len(sample_image_ids) / cols))
plt.figure(figsize=(12, rows * 5))
for idx, image_id in enumerate(sample_image_ids):
    plt.subplot(rows, cols, idx + 1)
    image_name = f"{image_id}.jpg"
    img_path = find_image_path(image_name)
    # Copy the pixels and close the file immediately instead of leaking one handle per image
    with Image.open(img_path) as img_file:
        image = img_file.copy()
    predicted_caption = predict_caption(model, loaded_features[image_id], tokenizer, max_caption_length)
    gt_caption = image_to_captions_mapping[image_id][0].replace('startseq ', '').replace(' endseq', '')
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"Pred: {predicted_caption}\nGT: {gt_caption}", fontsize=9)
    qualitative_records.append({
        'image_id': image_id,
        'predicted': predicted_caption,
        'ground_truth': gt_caption
    })
plt.tight_layout()
qualitative_img_path = os.path.join(OUTPUT_DIR, 'results', 'qualitative_examples.png')
plt.savefig(qualitative_img_path, dpi=200)
plt.show()
print(f"定性结果图已保存到: {qualitative_img_path}")

qualitative_json_path = os.path.join(OUTPUT_DIR, 'results', 'qualitative_examples.json')
with open(qualitative_json_path, 'w', encoding='utf-8') as q_file:
    json.dump(qualitative_records, q_file, indent=2, ensure_ascii=False)
print(f"定性结果数据已保存到: {qualitative_json_path}")

In [None]:
# Show reference vs predicted captions for a sample image
generate_caption("101669240_b2d3e7f17b.jpg")

In [None]:
# Show reference vs predicted captions for a sample image
generate_caption("1077546505_a4f6c4daa9.jpg")

In [None]:
# Show reference vs predicted captions for a sample image
generate_caption("1002674143_1b742ab4b8.jpg")

In [None]:
# Show reference vs predicted captions for a sample image
generate_caption("1032460886_4a598ed535.jpg")

In [None]:
# Show reference vs predicted captions for a sample image
generate_caption("1032122270_ea6f0beedb.jpg")

## 拓展实验（可选）

1. 尝试将LSTM替换为 GRU/Transformer 模型，比较结果。
2. 查找其他captioning任务的评测指标并实现，对结果进行评测。
   - 本实验额外记录了 BLEU-4 指标，后续可继续引入 METEOR、CIDEr 等指标以获得更全面的评估。