# 1) Citation Recommendation Demo

In [None]:
import os
import sys  # kept: the notebook namespace is shared, later cells may rely on it

# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native runtime is loaded, so
# the variable must be set *before* tensorflow (or transformers, which imports
# it) is imported -- setting it afterwards has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification

print("Loading the citation recommendation model for inference...")

try:
    # Load the saved recommendation model from the local folder only
    # (local_files_only=True: no download is attempted).
    model = TFBertForSequenceClassification.from_pretrained("citationRecommendationModel", local_files_only=True)
    print("Model loaded successfully.\n")
except Exception as e:
    print("Error loading model:", e)
    # Re-raise instead of sys.exit(1): inside a Jupyter kernel sys.exit only
    # surfaces a terse SystemExit, whereas re-raising keeps the full traceback
    # and still stops "Run All" at this cell (and exits non-zero as a script).
    raise

# Tokenizer for the (citing sentence, cited abstract) pair.
# NOTE(review): unlike the model above (local_files_only=True), this resolves
# "bert-base-uncased" from the Hugging Face Hub/cache; presumably the fine-tuned
# model used this same vocabulary -- confirm, or load the tokenizer from
# "citationRecommendationModel" if it was saved alongside the model.
print("Initializing BERT tokenizer...")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print("Tokenizer initialized.\n")

# A hand-written example pair to score: the sentence that would cite, and the
# abstract of the candidate paper.
print("Preparing sample input for inference...")
citing_sentence = "The study provides novel insights into the effects of climate change on urban areas."
cited_paper_abstract = "This paper explores the impact of environmental changes on city infrastructure using innovative methodologies."

print("Sample citing sentence:")
print("  ", citing_sentence)
print("Sample cited paper abstract:")
print("  ", cited_paper_abstract, "\n")

print("Encoding sample input...")
# Encode both texts as a single sentence pair with special tokens, padded and
# truncated to exactly 128 tokens, returned as TensorFlow tensors (batch of 1).
encoded_inputs = tokenizer(
    text=citing_sentence,
    text_pair=cited_paper_abstract,
    return_tensors='tf',
    truncation=True,
    padding='max_length',
    max_length=128,
    add_special_tokens=True,
)
print("Input encoded:")
print(" - input_ids shape:", encoded_inputs['input_ids'].shape)
print(" - attention_mask shape:", encoded_inputs['attention_mask'].shape, "\n")

# Score the encoded pair and turn logits into class probabilities.
print("Running model prediction on the sample input...")
outputs = model(encoded_inputs)

# Prefer the `.logits` attribute when the output object has one; otherwise
# treat the output as a tuple whose first element is the logits.
if hasattr(outputs, "logits"):
    logits = outputs.logits
else:
    logits = outputs[0]

probs = tf.nn.softmax(logits, axis=-1).numpy()
print("Predicted probabilities:", probs)
predicted_label = probs.argmax(axis=1)  # one class index per row of the batch
print("Predicted label:", predicted_label)

# Human-readable meaning of the first (only) prediction: 1 = recommend, else not.
print(
    "Interpretation: Label 1 means that the reference is recommended for citation."
    if predicted_label[0] == 1
    else "Interpretation: Label 0 means that the reference is not recommended for citation."
)

Loading the citation recommendation model for inference...


Some layers from the model checkpoint at citationRecommendationModel were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at citationRecommendationModel.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.


Model loaded successfully.

Initializing BERT tokenizer...
Tokenizer initialized.

Preparing sample input for inference...
Sample citing sentence:
   The study provides novel insights into the effects of climate change on urban areas.
Sample cited paper abstract:
   This paper explores the impact of environmental changes on city infrastructure using innovative methodologies. 

Encoding sample input...
Input encoded:
 - input_ids shape: (1, 128)
 - attention_mask shape: (1, 128) 

Running model prediction on the sample input...
Predicted probabilities: [[0.50324184 0.49675813]]
Predicted label: [0]
Interpretation: Label 0 means that the reference is not recommended for citation.


# 2) Citation Explanation Demo

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle

# Constants for the citation-explanation seq2seq model.
# NOTE(review): presumably these must equal the values used when the model was
# trained -- confirm against the training code.
latent_dim = 256              # size of the decoder h/c state Inputs built below
max_encoder_seq_length = 500  # length the encoder input is padded/truncated to
max_decoder_seq_length = 100  # intended cap on the number of generated tokens

print("Loading the citation explanation model for inference...")
# Load the full trained model from the .keras file; its named layers are reused
# below to assemble standalone encoder and decoder inference models.
# NOTE(review): this rebinds the global `model`, replacing the recommendation
# model loaded in the first demo cell.
model = tf.keras.models.load_model("citationExplanationModel.keras")
print("Model loaded successfully.\n")

# ---- Inference encoder: input token ids -> final LSTM states [h, c] ----
print("Building the inference encoder model...")
# Assumes the training model's first input is the encoder sequence -- TODO confirm.
encoder_inputs = model.input[0]
encoder_embedding = model.get_layer("encoder_embedding")(encoder_inputs)
# The encoder LSTM yields (sequence output, h, c); only the states are kept.
encoder_outputs, *encoder_states = model.get_layer("encoder_lstm")(encoder_embedding)
state_h, state_c = encoder_states
encoder_model = tf.keras.Model(encoder_inputs, encoder_states)
print("Inference encoder model built.\n")

# ---- Inference decoder: one step at a time, states passed in explicitly ----
print("Building the inference decoder model...")
decoder_embedding_layer = model.get_layer("decoder_embedding")
decoder_lstm_layer = model.get_layer("decoder_lstm")
decoder_dense_layer = model.get_layer("decoder_dense")

# State Inputs receive either the encoder states (first step) or the states
# returned by the previous decoder step.
decoder_state_input_h = tf.keras.Input(shape=(latent_dim,), name="decoder_state_input_h")
decoder_state_input_c = tf.keras.Input(shape=(latent_dim,), name="decoder_state_input_c")
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

# Token ids fed to the decoder (a single previous token during sampling).
decoder_input_single = tf.keras.Input(shape=(None,), name="decoder_input_single")
decoder_embedded = decoder_embedding_layer(decoder_input_single)
decoder_outputs, *decoder_states = decoder_lstm_layer(decoder_embedded, initial_state=decoder_states_inputs)
state_h, state_c = decoder_states
decoder_outputs = decoder_dense_layer(decoder_outputs)

decoder_model = tf.keras.Model(
    inputs=[decoder_input_single, *decoder_states_inputs],
    outputs=[decoder_outputs, *decoder_states],
)
print("Inference decoder model built.\n")

# ---- Tokenizers used to map text <-> ids ----
print("Loading tokenizers...")
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these tokenizer pickles from a trusted source.
with open("encoder_tokenizer.pkl", "rb") as f:
    encoder_tokenizer = pickle.load(f)
with open("decoder_tokenizer.pkl", "rb") as f:
    decoder_tokenizer = pickle.load(f)
print("Tokenizers loaded.\n")

# Decoding Function 
def decode_sequence(input_seq, max_decoder_seq_length=100, temperature=0.7, rng=None):
    """Generate an explanation for one encoded input by temperature sampling.

    Runs ``encoder_model`` once to get the initial LSTM states, then feeds
    ``decoder_model`` one token at a time, sampling each next token from the
    temperature-scaled softmax of the decoder's last-step output. Generation
    stops when the ``<end>`` token is sampled or after
    ``max_decoder_seq_length`` steps (a warning is printed in that case).

    Args:
        input_seq: Padded encoder input accepted by ``encoder_model.predict``
            (presumably shape ``(1, max_encoder_seq_length)`` -- confirm at caller).
        max_decoder_seq_length: Maximum number of decoding steps.
        temperature: Softmax temperature; must be > 0. Lower values sample
            more greedily.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used for
            sampling; pass a seeded one for reproducible output. Defaults to
            NumPy's global random state.

    Returns:
        The sampled words joined by single spaces. The ``<end>`` marker and
        ids that have no entry in ``decoder_tokenizer.index_word`` are omitted.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    sampler = np.random if rng is None else rng

    # verbose=0 suppresses the Keras progress bar, which would otherwise be
    # printed once per predict() call, i.e. once per generated token.
    states_value = list(encoder_model.predict(input_seq, verbose=0))

    start_token = decoder_tokenizer.word_index.get("<start>")
    if start_token is None:
        print("Warning: '<start>' token not found. Using fallback value 1.")
        start_token = 1

    # The decoder is primed with the start token; each step feeds back the
    # previously sampled token together with the updated LSTM states.
    target_seq = np.array([[start_token]])
    decoded_words = []

    for _ in range(max_decoder_seq_length):
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value, verbose=0)
        # Temperature-scaled, numerically stable softmax over the vocabulary
        # for the last time step; float64 keeps the probabilities summing to 1
        # within the tolerance checked by choice(p=...).
        logits = np.asarray(output_tokens[0, -1, :], dtype=np.float64) / temperature
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / np.sum(exp_logits)
        sampled_token_index = int(sampler.choice(len(probs), p=probs))
        sampled_word = decoder_tokenizer.index_word.get(sampled_token_index, "")

        # Check for the end marker *before* appending so it never appears in
        # the returned text.
        if sampled_word == "<end>":
            break
        if sampled_word:  # skip ids with no word instead of adding stray spaces
            decoded_words.append(sampled_word)

        target_seq = np.array([[sampled_token_index]])
        states_value = [h, c]
    else:
        print("Warning: Maximum decoder iterations reached. The generated sequence may be incomplete.")

    return " ".join(decoded_words)

# Run inference on a hand-written sample text.
print("Preparing sample input for citation explanation inference...")
sample_input_text = (
    "The initial observations indicate a gradual increase in star brightness. "
    "Subsequent measurements confirmed periodic variations. "
    "Discovery of Exoplanet X. "
    "The paper details spectroscopic analysis and methods used for identifying exoplanets."
)
print("Sample input text:")
print(" ", sample_input_text, "\n")

# Map words to ids with the pickled encoder tokenizer and pad to the encoder's
# fixed length.
# NOTE(review): pad_sequences truncates from the *front* by default
# (truncating='pre'); presumably this matches how the training inputs were
# prepared -- confirm.
encoder_seq = encoder_tokenizer.texts_to_sequences([sample_input_text])
encoder_seq = pad_sequences(encoder_seq, maxlen=max_encoder_seq_length, padding='post')
if not np.any(encoder_seq):
    # All zeros means only padding: the tokenizer produced no ids for this
    # text, so any generated explanation is unconditioned on the input.
    print("Warning: the sample input produced no encoder tokens; the encoder sees only padding.")

print("Decoding citation explanation for the sample input using temperature-based sampling...")
# Use the configured cap instead of repeating the literal 100.
decoded_explanation = decode_sequence(
    encoder_seq, max_decoder_seq_length=max_decoder_seq_length, temperature=0.7
)
print("\nDecoded citation explanation:")
print(" ", decoded_explanation)

Loading the citation explanation model for inference...
Model loaded successfully.

Building the inference encoder model...
Inference encoder model built.

Building the inference decoder model...
Inference decoder model built.

Loading tokenizers...
Tokenizers loaded.

Preparing sample input for citation explanation inference...
Sample input text:
  The initial observations indicate a gradual increase in star brightness. Subsequent measurements confirmed periodic variations. Discovery of Exoplanet X. The paper details spectroscopic analysis and methods used for identifying exoplanets. 

Decoding citation explanation for the sample input using temperature-based sampling...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 72ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━