In [2]:
import os

import nltk
import numpy as np
import wandb
from datasets import load_dataset
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge import Rouge
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Concatenate, Attention
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from wandb.keras import WandbCallback

# Load the XSum corpus via the Hugging Face `datasets` loader (no split is
# requested here; later cells index the result with 'train')
dataset = load_dataset(path='xsum')


  from .autonotebook import tqdm as notebook_tqdm
Using custom data configuration default
Reusing dataset xsum (/home/guanhua_wu/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934)
100%|██████████| 3/3 [00:00<00:00, 212.33it/s]


In [5]:
# Start the Weights & Biases run used by the config and logging calls below
wandb.init(
    name="seq2seq_summarization_2",
    project="text_summarization",
)


In [6]:
# Preprocessing setup: one shared Porter stemmer, plus the 'punkt' NLTK
# resource download that precedes the word_tokenize calls in preprocess()
stemmer = PorterStemmer()
nltk.download('punkt')

def preprocess(text):
    """Tokenize ``text``, stem every token, and return the stems joined by single spaces.

    Uses NLTK's ``word_tokenize`` for splitting and the module-level
    ``stemmer`` for stemming.
    """
    return ' '.join([stemmer.stem(word) for word in word_tokenize(text)])

# One (preprocessed article, preprocessed summary) pair per training example
preprocessed_data = [
    (preprocess(record['document']), preprocess(record['summary']))
    for record in dataset['train']
]


[nltk_data] Downloading package punkt to /home/guanhua_wu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [7]:
# Reduce dataset size: keep only the "train" half of a seeded 50/50 split
train_data = train_test_split(
    preprocessed_data,
    test_size=0.5,
    random_state=42,
)[0]

In [8]:
# Build the three parallel text lists: articles, '<start>'-prefixed summaries
# (decoder input) and '<end>'-suffixed summaries (decoder target)
source_texts = [pair[0] for pair in train_data]
summary_texts = [pair[1] for pair in train_data]
decoder_input_texts = ['<start> ' + text for text in summary_texts]
decoder_target_texts = [text + ' <end>' for text in summary_texts]

# One tokenizer shared by encoder and decoder: no character filtering,
# no lower-casing, tokens are split on single spaces
tokenizer = Tokenizer(filters='', lower=False, split=' ')
tokenizer.fit_on_texts(source_texts + decoder_input_texts + decoder_target_texts)

# Map every text list to lists of integer token ids
encoder_inputs_train = tokenizer.texts_to_sequences(source_texts)
decoder_inputs_train = tokenizer.texts_to_sequences(decoder_input_texts)
decoder_outputs_train = tokenizer.texts_to_sequences(decoder_target_texts)

# +1 leaves room for index 0, which the padding step below uses as filler
vocab_size = len(tokenizer.word_index) + 1
input_vocab_size = target_vocab_size = vocab_size

# Record the model/training hyperparameters on the W&B run config
config = wandb.config
hyperparameters = {
    'input_vocab_size': input_vocab_size,
    'target_vocab_size': target_vocab_size,
    'embedding_dim': 128,
    'lstm_units': 256,
    'batch_size': 16,
    'epochs': 10,
}
for hyperparameter, setting in hyperparameters.items():
    setattr(config, hyperparameter, setting)

# Fixed sequence lengths for the encoder and decoder sides
max_encoder_length = 150
max_decoder_length = 150

# Pad short sequences at the end and cut long ones at the end
pad_options = dict(padding='post', truncating='post')
encoder_inputs_train = pad_sequences(encoder_inputs_train, maxlen=max_encoder_length, **pad_options)
decoder_inputs_train = pad_sequences(decoder_inputs_train, maxlen=max_decoder_length, **pad_options)
decoder_outputs_train = pad_sequences(decoder_outputs_train, maxlen=max_decoder_length, **pad_options)

# Append a trailing axis of size 1 to the targets
decoder_outputs_train = decoder_outputs_train[..., np.newaxis]

# Define the layers.
# Both inputs are variable-length sequences of integer token ids.
encoder_inputs = Input(shape=(None,))
decoder_inputs = Input(shape=(None,))

# Layer sizes come from the W&B config so the logged hyperparameters are the
# ones the model is actually built with.
# NOTE(review): neither embedding sets mask_zero, so padded positions (id 0)
# are not masked and also contribute to the loss — confirm this is intended.
encoder_embedding = Embedding(input_dim=input_vocab_size, output_dim=config.embedding_dim)(encoder_inputs)
decoder_embedding = Embedding(input_dim=target_vocab_size, output_dim=config.embedding_dim)(decoder_inputs)

# Encoder: keep the per-step outputs (attended over below) and the final
# hidden/cell states (used to initialise the decoder).
encoder_lstm = LSTM(config.lstm_units, return_sequences=True, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(encoder_embedding)
encoder_states = [state_h, state_c]

# Decoder: the layer object is kept in a variable so the trained weights can
# be reused when building a step-by-step inference model.
decoder_lstm = LSTM(config.lstm_units, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)

# Keras Attention takes [query, value]: the decoder steps are the queries and
# the encoder steps are the values, so the result has one context vector per
# decoder step. `attention` holds the layer itself so it stays callable.
attention = Attention()
attention_outputs = attention([decoder_outputs, encoder_outputs])
x = Concatenate(axis=2)([decoder_outputs, attention_outputs])

# Project each decoder step onto a probability distribution over the vocabulary
decoder_dense = Dense(target_vocab_size, activation='softmax')
decoder_outputs = decoder_dense(x)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Targets are integer ids, hence the sparse categorical cross-entropy
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Report the final training array shapes
for shape_label, training_array in (
    ("encoder_inputs_train shape:", encoder_inputs_train),
    ("decoder_inputs_train shape:", decoder_inputs_train),
    ("decoder_outputs_train shape:", decoder_outputs_train),
):
    print(shape_label, training_array.shape)


encoder_inputs_train shape: (102022, 150)
decoder_inputs_train shape: (102022, 150)
decoder_outputs_train shape: (102022, 150, 1)


In [9]:
# Directory that receives the per-epoch weight files (created if missing)
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoint callback: weights only, written every epoch regardless of loss;
# the `{epoch:03d}` placeholder is filled in at save time
checkpoint_filepath = os.path.join(checkpoint_dir, 'checkpoint_epoch_{epoch:03d}.hdf5')
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    verbose=1,
    save_best_only=False,
    save_weights_only=True,
    save_freq='epoch',  # Save after each epoch
)


In [None]:
# Train the model, resuming from the newest checkpoint in `checkpoint_dir`.
#
# ModelCheckpoint fills `{epoch}` with a 1-based epoch number, so the file
# numbered N holds the weights after N completed epochs; N is therefore the
# right `initial_epoch`. The weights are restored here as well, because
# `initial_epoch` by itself only moves the epoch counter. With no checkpoint
# on disk, training starts from epoch 0.
# NOTE(review): checkpoints written by a model with a different vocabulary or
# layer sizes will fail to load — clear `checkpoint_dir` in that case.
checkpoint_prefix, checkpoint_suffix = 'checkpoint_epoch_', '.hdf5'
saved_checkpoints = sorted(
    (int(file_name[len(checkpoint_prefix):-len(checkpoint_suffix)]), file_name)
    for file_name in os.listdir(checkpoint_dir)
    if file_name.startswith(checkpoint_prefix)
    and file_name.endswith(checkpoint_suffix)
    and file_name[len(checkpoint_prefix):-len(checkpoint_suffix)].isdigit()
)

initial_epoch = 0
if saved_checkpoints:
    initial_epoch, newest_checkpoint = saved_checkpoints[-1]
    model.load_weights(os.path.join(checkpoint_dir, newest_checkpoint))
    print(f"Resuming from {newest_checkpoint} (epoch {initial_epoch} completed)")

model.fit([encoder_inputs_train, decoder_inputs_train],
          decoder_outputs_train,
          batch_size=config.batch_size,
          initial_epoch=initial_epoch,
          epochs=config.epochs,
          callbacks=[WandbCallback(), checkpoint_callback])

In [10]:
# Load the most recent checkpoint.
# File names embed a zero-padded epoch number, so a plain lexicographic sort
# puts the newest epoch last (valid while the epoch number fits in 3 digits).
available_checkpoints = sorted(
    file_name
    for file_name in os.listdir(checkpoint_dir)
    if file_name.startswith('checkpoint_epoch_') and file_name.endswith('.hdf5')
)
if not available_checkpoints:
    raise FileNotFoundError(
        f"No 'checkpoint_epoch_*.hdf5' files found in {checkpoint_dir!r}; train the model first."
    )

checkpoint_path = os.path.join(checkpoint_dir, available_checkpoints[-1])
model.load_weights(checkpoint_path)

In [None]:
# Define the encoder model: token ids -> per-step encoder outputs + final states
encoder_model = Model(encoder_inputs, [encoder_outputs] + encoder_states)

# Recover the *trained* layers from `model` instead of creating new ones; a
# freshly constructed LSTM would carry random weights.
# `model.layers` follows graph order, and the decoder LSTM consumes the
# encoder LSTM's states, so the encoder LSTM comes first and the decoder last.
lstm_layers = [layer for layer in model.layers if isinstance(layer, LSTM)]
assert len(lstm_layers) == 2, "expected exactly one encoder and one decoder LSTM"
trained_encoder_lstm, trained_decoder_lstm = lstm_layers
trained_attention = next(layer for layer in model.layers if isinstance(layer, Attention))

# Define the decoder model.
# Its state inputs and the encoder outputs are fresh Input placeholders so the
# decoder can be driven one step at a time with values produced by
# `encoder_model` / the previous decoding step.
decoder_state_input_h = Input(shape=(trained_decoder_lstm.units,))
decoder_state_input_c = Input(shape=(trained_decoder_lstm.units,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
encoder_outputs_input = Input(shape=(None, trained_encoder_lstm.units))

step_outputs, step_state_h, step_state_c = trained_decoder_lstm(
    decoder_embedding, initial_state=decoder_states_inputs
)
decoder_states = [step_state_h, step_state_c]

# Attention takes [query, value]: decoder step(s) query the encoder outputs.
# The training graph has to use the same order for the Dense weights that
# follow the concatenation to be meaningful here.
attention_outputs = trained_attention([step_outputs, encoder_outputs_input])
step_features = Concatenate(axis=2)([step_outputs, attention_outputs])
step_probabilities = decoder_dense(step_features)

# Input order: decoder token ids, state h, state c, encoder outputs.
# Output order: vocabulary probabilities, new state h, new state c.
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs + [encoder_outputs_input],
    [step_probabilities] + decoder_states
)

def generate_summary_seq2seq(article):
    """Greedily decode a summary for one article with the encoder/decoder models.

    Args:
        article: Raw article text. It is passed through ``preprocess`` first so
            that its tokens match the preprocessed text the tokenizer was
            fitted on.

    Returns:
        The generated tokens joined by single spaces. Decoding stops at the
        ``<end>`` token, at the padding id 0, or after ``max_decoder_length``
        tokens, whichever comes first; the result may be an empty string.
    """
    # Tokenize and pad the input article exactly like the training inputs
    encoder_input = tokenizer.texts_to_sequences([preprocess(article)])
    encoder_input = pad_sequences(encoder_input, maxlen=max_encoder_length, padding='post', truncating='post')

    # Encode once; the final encoder states seed the decoder.
    # verbose=0 keeps Keras from printing a progress bar on every call.
    encoder_outputs, state_h, state_c = encoder_model.predict(encoder_input, verbose=0)
    decoder_states = [state_h, state_c]

    # Initialize the decoder input with the <start> token
    decoder_input = np.zeros((1, 1))
    decoder_input[0, 0] = tokenizer.word_index['<start>']
    end_index = tokenizer.word_index['<end>']

    decoded_sentence = []
    for _ in range(max_decoder_length):
        # One decoding step: next-token probabilities plus updated LSTM states
        step_probabilities, state_h, state_c = decoder_model.predict(
            [decoder_input] + decoder_states + [encoder_outputs], verbose=0
        )
        decoder_states = [state_h, state_c]

        # Greedy choice: the most probable token id at the last position
        next_word_index = int(np.argmax(step_probabilities[0, -1, :]))

        # Id 0 is the padding filler and has no entry in index_word; treat it
        # like <end> rather than failing with a KeyError.
        if next_word_index == end_index or next_word_index == 0:
            break

        decoded_sentence.append(tokenizer.index_word[next_word_index])

        # Feed the chosen token back in as the next decoder input
        decoder_input[0, 0] = next_word_index

    return ' '.join(decoded_sentence)

def evaluate_model(generate_summary_function, max_examples=None):
    """Score a summary generator on the XSum test split with ROUGE and BLEU.

    Args:
        generate_summary_function: Callable taking an article string and
            returning the generated summary string.
        max_examples: Optional cap on the number of test examples to score;
            ``None`` (the default) scores the whole split.

    Returns:
        A ``(rouge_scores, bleu_score)`` tuple. ``rouge_scores`` maps
        ``'rouge-1'``, ``'rouge-2'`` and ``'rouge-l'`` to a dict with the mean
        ``'r'``, ``'p'`` and ``'f'`` values; ``bleu_score`` is the mean
        smoothed sentence-level BLEU.

    Raises:
        ValueError: If no examples were scored.

    NOTE(review): references are scored as raw text. If the generator emits
    normalised (e.g. stemmed) tokens, the overlap scores will be depressed —
    confirm whether both sides should be normalised first.
    """
    rouge = Rouge()
    smoothing = SmoothingFunction().method1
    test_split = load_dataset('xsum', split='test')

    predictions = []
    references = []
    for position, example in enumerate(test_split):
        if max_examples is not None and position >= max_examples:
            break
        predictions.append(generate_summary_function(example["document"]))
        references.append(example["summary"])

    if not predictions:
        raise ValueError("evaluate_model needs at least one example to score")

    rouge_metrics = ('rouge-1', 'rouge-2', 'rouge-l')
    rouge_totals = {metric: {'r': 0.0, 'p': 0.0, 'f': 0.0} for metric in rouge_metrics}
    bleu_total = 0.0

    for pred, ref in zip(predictions, references):
        # BLEU on whitespace tokens, smoothed so short outputs are not forced to 0
        bleu_total += sentence_bleu([ref.split()], pred.split(), smoothing_function=smoothing)

        # The rouge package rejects an empty hypothesis/reference with a
        # ValueError; such a pair simply contributes zero to every total.
        try:
            pair_scores = rouge.get_scores(pred, ref)[0]
        except ValueError:
            continue
        for metric, stats in rouge_totals.items():
            for stat in stats:
                stats[stat] += pair_scores[metric][stat]

    example_count = len(predictions)
    rouge_scores = {
        metric: {stat: total / example_count for stat, total in stats.items()}
        for metric, stats in rouge_totals.items()
    }
    bleu_score = bleu_total / example_count
    return rouge_scores, bleu_score

# Evaluate the LSTM-based model
seq2seq_rouge_scores, seq2seq_bleu_scores = evaluate_model(generate_summary_seq2seq)
print("Seq2Seq Rouge Scores:", seq2seq_rouge_scores)
print("Seq2Seq BLEU Scores:", seq2seq_bleu_scores)

# Log the mean ROUGE F-measures and the mean BLEU score to WandB
wandb.log({"rouge1": seq2seq_rouge_scores['rouge-1']['f'],
           "rouge2": seq2seq_rouge_scores['rouge-2']['f'],
           "rougeL": seq2seq_rouge_scores['rouge-l']['f'],
           "avg_bleu": seq2seq_bleu_scores})

# Finish the run
wandb.finish()