In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, LSTM, Embedding, Dense
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import pickle

# GPU Settings and Session Cleanup
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print("Error enabling GPU memory growth:", e)
tf.keras.backend.clear_session()

# Hyperparameters and Vocabulary Settings
latent_dim = 256         
embedding_dim = 128      
max_encoder_seq_length = 500   
max_decoder_seq_length = 100   
num_encoder_tokens = 20000     
num_decoder_tokens = 20000     

# Build the Citation Explanation Model

# Encoder: 
encoder_inputs = Input(shape=(None,), name="encoder_inputs")
encoder_embedding = Embedding(num_encoder_tokens, embedding_dim, mask_zero=True, name="encoder_embedding")(encoder_inputs)
encoder_lstm, state_h, state_c = LSTM(latent_dim, return_state=True, name="encoder_lstm")(encoder_embedding)
encoder_states = [state_h, state_c]

# Decoder: 
decoder_inputs = Input(shape=(None,), name="decoder_inputs")
decoder_embedding = Embedding(num_decoder_tokens, embedding_dim, mask_zero=True, name="decoder_embedding")(decoder_inputs)
decoder_lstm, _, _ = LSTM(latent_dim, return_sequences=True, return_state=True, name="decoder_lstm")(
    decoder_embedding, initial_state=encoder_states)
decoder_outputs = Dense(num_decoder_tokens, activation='softmax', name="decoder_dense")(decoder_lstm)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])
print("Model built and compiled.\n")

# Data Generator for Large Datasets
def data_generator(encoder_texts, decoder_texts, batch_size, encoder_tokenizer, decoder_tokenizer):
    num_samples = len(encoder_texts)
    while True:
        for offset in range(0, num_samples, batch_size):
            enc_batch = encoder_texts[offset:offset+batch_size]
            dec_batch = decoder_texts[offset:offset+batch_size]
            
            # Tokenize texts to sequences
            enc_sequences = encoder_tokenizer.texts_to_sequences(enc_batch)
            dec_sequences = decoder_tokenizer.texts_to_sequences(dec_batch)
            
            # Pad sequences to fixed lengths
            enc_sequences = pad_sequences(enc_sequences, maxlen=max_encoder_seq_length, padding='post')
            dec_sequences = pad_sequences(dec_sequences, maxlen=max_decoder_seq_length, padding='post')
            
            # Prepare decoder target data (shifted by one timestep)
            dec_target = np.zeros_like(dec_sequences)
            dec_target[:, :-1] = dec_sequences[:, 1:]
            
            # Yield as a tuple of tuples to match output_signature
            yield ((enc_sequences, dec_sequences), dec_target[..., np.newaxis])

# Load Dataset, Prepare Data, and Train Model
if __name__ == "__main__":
    data = pd.read_csv("explanationDataset.csv")
    print("explanationDataset loaded successfully.")
    
    all_encoder_texts = []
    all_decoder_texts = []
    for idx, row in data.iterrows():
        encoder_text = f"{row['context_before']} {row['context_after']} {row['cited_paper_title']} {row['cited_paper_abstract']}"
        all_encoder_texts.append(encoder_text)
        all_decoder_texts.append(row['citation_sentence'])
    
    train_enc, val_enc, train_dec, val_dec = train_test_split(
        all_encoder_texts, all_decoder_texts, test_size=0.2, random_state=42
    )
    print(f"Training samples: {len(train_enc)}, Validation samples: {len(val_enc)}\n")
    
    # Fit tokenizers on the training set
    encoder_tokenizer = Tokenizer(num_words=num_encoder_tokens, oov_token="<OOV>")
    decoder_tokenizer = Tokenizer(num_words=num_decoder_tokens, oov_token="<OOV>")
    encoder_tokenizer.fit_on_texts(train_enc)
    decoder_tokenizer.fit_on_texts(train_dec)
    print("Tokenizers fitted.\n")
    
    # Save tokenizers for later inference demo
    with open("encoder_tokenizer.pkl", "wb") as f:
        pickle.dump(encoder_tokenizer, f)
    with open("decoder_tokenizer.pkl", "wb") as f:
        pickle.dump(decoder_tokenizer, f)
    print("Tokenizers saved as 'encoder_tokenizer.pkl' and 'decoder_tokenizer.pkl'.\n")
    
    batch_size = 64  
    steps_per_epoch = len(train_enc) // batch_size
    validation_steps = len(val_enc) // batch_size
    
    train_dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(train_enc, train_dec, batch_size, encoder_tokenizer, decoder_tokenizer),
        output_signature=(
            (
                tf.TensorSpec(shape=(None, max_encoder_seq_length), dtype=tf.int32),
                tf.TensorSpec(shape=(None, max_decoder_seq_length), dtype=tf.int32)
            ),
            tf.TensorSpec(shape=(None, max_decoder_seq_length, 1), dtype=tf.int32)
        )
    )
    
    val_dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(val_enc, val_dec, batch_size, encoder_tokenizer, decoder_tokenizer),
        output_signature=(
            (
                tf.TensorSpec(shape=(None, max_encoder_seq_length), dtype=tf.int32),
                tf.TensorSpec(shape=(None, max_decoder_seq_length), dtype=tf.int32)
            ),
            tf.TensorSpec(shape=(None, max_decoder_seq_length, 1), dtype=tf.int32)
        )
    )
    print("Datasets created.\n")
    
    epochs = 10  
    print(f"Starting training for {epochs} epochs...")
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    history = model.fit(
        train_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=val_dataset,
        validation_steps=validation_steps,
        callbacks=[early_stopping]
    )
    print("Training complete.\n")
    
    # Save the model in the required .keras format
    model.save("citationExplanationModel.keras")
    print("Model saved as 'citationExplanationModel.keras'.")


2025-03-22 06:48:31.014317: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742640511.054019   23890 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742640511.073024   23890 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742640511.127873   23890 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742640511.127953   23890 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742640511.127956   23890 computation_placer.cc:177] computation placer alr

Model built and compiled.

explanationDataset loaded successfully.
Training samples: 80000, Validation samples: 20000

Tokenizers fitted.

Tokenizers saved as 'encoder_tokenizer.pkl' and 'decoder_tokenizer.pkl'.

Datasets created.

Starting training for 10 epochs...
Epoch 1/10


I0000 00:00:1742640530.328889   23987 cuda_dnn.cc:529] Loaded cuDNN version 90800


[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m293s[0m 231ms/step - loss: 6.8184 - sparse_categorical_accuracy: 0.4684 - val_loss: 5.7668 - val_sparse_categorical_accuracy: 0.8049
Epoch 2/10
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m287s[0m 230ms/step - loss: 5.6911 - sparse_categorical_accuracy: 0.8012 - val_loss: 5.3884 - val_sparse_categorical_accuracy: 0.8062
Epoch 3/10
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m289s[0m 231ms/step - loss: 5.3060 - sparse_categorical_accuracy: 0.8008 - val_loss: 5.1592 - val_sparse_categorical_accuracy: 0.7274
Epoch 4/10
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m290s[0m 232ms/step - loss: 5.0308 - sparse_categorical_accuracy: 0.7747 - val_loss: 4.9947 - val_sparse_categorical_accuracy: 0.7841
Epoch 5/10
[1m1250/1250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m287s[0m 230ms/step - loss: 4.8073 - sparse_categorical_accuracy: 0.7995 - val_loss: 4.8675 - val_sparse_categorica