# Testing the generate function

## Import libraries

In [12]:
import torch
import gensim
from gensim import corpora
from torch.nn.utils.rnn import pad_sequence
from transformers import T5Tokenizer, T5ForConditionalGeneration
from topical_decoding.utils.newts_utils import read_train, read_test
from transformers import LogitsProcessor

## Import the model and tokenizer

In [13]:
T5_base_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
T5_base_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Import the NEWTS dataset

In [14]:
# Load the NEWTS training set
newts_train = read_train("../../NEWTS/NEWTS_train_2400.csv")
newts_test = read_test("../../NEWTS/NEWTS_test_600.csv")

In [15]:
newts_train.head(2)

Unnamed: 0,AssignmentId,docId,article,tid1,tid2,words1,words2,phrases1,phrases2,sentences1,sentences2,summary1,summary2
0,3EG49X351WE8VLP4S0TIYZF3V476X2,094372190d52acbce61a73ec16b2217d1a60276f,The president of the World Bank on Saturday wa...,175,110,"house, committee, congress, senate, republican...","billion, figures, economy, global, growth, eco...","senate and congress, congressional pressure, y...","economic growth, global growth, billion dollar...","This topic is about the senate and congress, c...",This topic is about economic growth involving ...,The leader of the World Bank urged the US to t...,The US economy will be a driving factor in the...
1,3DOCMVPBTPGBQCHSPBSQ28AROFXNNI,bc733fb96fd73496e10fcff3c640ee11c4df3d7a,By . Nick Harris . Manchester City are the bes...,152,217,"united, manchester, liverpool, chelsea, league...","club, team, season, players, england, football...","Manchester United's manager, Premier League, t...","football league, the team's fans, football pla...",This topic is about Manchester United's manage...,This topic is about a football league having a...,Premier league is the most paying football lea...,Manchester city players earn the largest amoun...


In [16]:
print(newts_train.shape)

(2400, 13)


In [17]:
# Select an example article by its index.
example_article = newts_train.iloc[23]

# Print the article
print(example_article["article"])

A charge of making a false bomb threat has been dropped against a man who carried a backpack containing a rice cooker near a crowd marking the first anniversary of the Boston Marathon bombings in April, prosecutors said Wednesday. Investigators dropped the charge because they say the suspect, Kevin Edson, 25, did not communicate an "overt threat that an incendiary device would be detonated," Jake Wark, a spokesman for the Suffolk County district attorney, told CNN. Edson was arrested after carrying the backpack with a rice cooker near the finish line on Boylston Street in Boston while survivors of the 2013 bombing were commemorating its anniversary on April 15. In the 2013 attack, two pressure-cooker bombs exploded, killing three people and wounding at least 264 others. A barefoot Edson, carrying a backpack and wearing black clothes with a veil and hat covering his face, screamed and yelled near the end of the anniversary event on Boylston Street, drawing officers' attention, police sa

## Generate a summary for articles in the NEWTS training set

In [18]:
task_prefix = "summarize: "
# flan_prompt_prefix = "Briefly summarize this paragraph: "

min_idx = 0
max_idx = 2

# Selecting sentences and their corresponding topic IDs
articles = newts_test["article"][min_idx : max_idx + 1].tolist()
topic1_ids = newts_test["tid1"][min_idx : max_idx + 1].tolist()
topic2_ids = newts_test["tid2"][min_idx : max_idx + 1].tolist()

# Structure to hold both tokenized articles and their topics
encoded_data = []

# Define max length to pad
max_length = 0
for sentence in articles:
    max_length = max(max_length, len(T5_base_tokenizer.encode(task_prefix + sentence)))

for sentence, tid1, tid2 in zip(articles, topic1_ids, topic2_ids):
    # Tokenize the sentence with the task prefix
    encoded_article = T5_base_tokenizer.encode_plus(
        task_prefix + sentence,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Add the tokenized article and its topic id to the structure
    encoded_data.append(
        {
            "input_ids": encoded_article["input_ids"],
            "attention_mask": encoded_article["attention_mask"],
            "tid1": tid1,
            "tid2": tid2,
        }
    )

# Convert list of dictionaries to a single dictionary with batched tensors
batched_input_ids = torch.cat([item["input_ids"] for item in encoded_data], dim=0)
batched_attention_mask = torch.cat(
    [item["attention_mask"] for item in encoded_data], dim=0
)

Token indices sequence length is longer than the specified maximum sequence length for this model (829 > 512). Running this sequence through the model will result in indexing errors


## Normal setup using generate function from transformers library

In [19]:
output_sequences = T5_base_model.generate(
    input_ids=batched_input_ids,
    attention_mask=batched_attention_mask,
    do_sample=False,  # disable sampling to test if batching affects output
    max_length=100,
    min_length=30,
    # early_stopping=False,
    num_beams=1,
    # no_repeat_ngram_size=3,
    num_return_sequences=1,
    # top_k=0,
    # bos_token_id=model.config.bos_token_id,
    # eos_token_id=model.config.eos_token_id,
    # length_penalty=2.0,
)

In [20]:
print("Output:\n" + 100 * "-")
for idx, output_sequence in enumerate(output_sequences):
    output = T5_base_tokenizer.decode(output_sequence, skip_special_tokens=True)
    print("Input: {}".format(articles[idx]))
    print("Generated summary {}: {}".format(idx + 1, output))
    print(100 * "-")

Output:
----------------------------------------------------------------------------------------------------
Input: An American tourist has spent the night stranded in the Blue Mountains, west of Sydney, after she fell 15 metres off a cliff while bushwalking. The 25-year-old from the US state of Wisconsin was walking near Pulpit Rock, Mount Victoria with a group of friends on Friday when she slipped from a track. She fell about 15 metres and rolled a further 20 metres down a steep slope, police say. Rescue teams escort a 25-year-old US tourist after she spent the night stranded in the Blue Mountains after falling 15 metres off a cliff . Rescue crews found the woman suffering a possible broken ankle and broken ribs. She remained with an ambulance team overnight due to low light and foggy weather conditions. Blue Mountains Police Rescue Sergeant Dallas Atkinson told ABC a helicopter was deployed to finish the rescue this morning. 'After she fell yesterday she was accessed a short time la

## take topic into account during generation
#### Option 1: 
discard/block words based on their topic affiliation (expect this to be bad)
cannot write coherent/grammatical sentences if blocking too many words
block words from second topic
topics could overlap (could exclude the overlapping)
binary mask over the vocabulary to multiply probability vector with

#### Option 2:
temperature scaling if needed (to make the factor important)
multiply tokens that belong to the target topic with a factor > 1 
rescale/normalize  after multiplying
Alternative
select top k tokens and then over-sample tokens from the topic. 
Start out with uniform distribution over top-k tokens. Multiply each token belonging to the target topic with factor (e.g. 2-5) and then rescale entire probability distribution accordingly.


## Import LDA model

In [21]:
def load_lda_model(model_address: str):
    # Loads the LDA model and dictionary from the specified address.
    try:
        lda = gensim.models.ldamodel.LdaModel.load(
            model_address + "lda.model", mmap="r"
        )
        dictionary = corpora.Dictionary.load(model_address + "dictionary.dic", mmap="r")
        return lda, dictionary
    except Exception as e:
        print(f"Error loading model or dictionary: {e}")
        return None, None

In [22]:
model_address = "../../LDA_250/"
lda, dictionary = load_lda_model(model_address)
# Warning "WARNING:root:random_state not set so using default value" is inconsequential for inference



In [23]:
def get_top_topic_words(lda, topic_id, num_words=100):
    """
    Returns the top words for a given topic from the LDA model.

    :param lda: The LDA model.
    :param topic_id: The topic number to get the top words for.
    :param num_words: The number of top words to return.
    :return: A list of top words for the specified topic.
    """
    try:
        # Get the specified topic. Note: num_words here limits the number of words returned for the topic.
        topic_words = lda.show_topic(topic_id, num_words)

        # Extract just the words
        top_words = [word for word, prob in topic_words]
        return top_words
    except Exception as e:
        print(f"Error in getting top topic words: {e}")
        return []

In [24]:
def get_topic_tokens(lda, topic_id, num_words, tokenizer):
    # Get the top words for the specified topic
    top_words = get_top_topic_words(lda, topic_id, num_words)

    # Initialize a set to store unique token IDs
    token_ids_set = set()

    # Tokenize each word and add its token IDs to the set
    for word in top_words:
        # Tokenize the word
        tokens = tokenizer.tokenize(word)

        # Convert tokens to their IDs and add to the set
        token_ids_set.update(tokenizer.convert_tokens_to_ids(tokens))

    return list(token_ids_set)

## Create custom model based on logits_process.py

In [25]:
class TopicTokenExclusionProcessor(LogitsProcessor):
    def __init__(self, excluded_token_ids, factor: float):
        self.excluded_token_ids = set(excluded_token_ids)
        self.factor = factor

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Create a mask for tokens to be excluded
        exclusion_mask = torch.tensor(
            [
                token_id in self.excluded_token_ids
                for token_id in range(scores.shape[1])
            ],
            dtype=torch.bool,
        )
        # Apply the factor to the scores of the excluded tokens
        scores[:, exclusion_mask] *= self.factor
        return scores

In [26]:
class CustomT5Model(T5ForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def generate(
        self,
        input_ids,
        attention_mask=None,
        excluded_token_ids=None,
        exclusion_factor=1.0,
        **kwargs
    ):
        # Create an instance of TopicTokenExclusionProcessor
        topic_processor = TopicTokenExclusionProcessor(
            excluded_token_ids=excluded_token_ids, factor=exclusion_factor
        )

        # Call the original generate method with the custom logits processor
        output_sequences = super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            logits_processor=[topic_processor],
            **kwargs
        )

        return output_sequences

In [27]:
num_words = 100
tid1 = 218
tid2 = 62

In [28]:
# Assuming you have an LDA model, a tokenizer, and a topic ID (tid1)
topic_tokens = get_topic_tokens(
    lda, tid1, num_words, T5_base_tokenizer
) + get_topic_tokens(lda, tid2, num_words, T5_base_tokenizer)

# Aggregate and pad input_ids and attention_mask from encoded_data
input_ids = pad_sequence(
    [item["input_ids"].squeeze(0) for item in encoded_data],
    batch_first=True,
    padding_value=T5_base_tokenizer.pad_token_id,
)
attention_mask = pad_sequence(
    [item["attention_mask"].squeeze(0) for item in encoded_data],
    batch_first=True,
    padding_value=0,
)

# Use the custom model for generation
custom_model = CustomT5Model.from_pretrained("google/flan-t5-base")
output_sequences = custom_model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    excluded_token_ids=topic_tokens,
    exclusion_factor=0.2,  # Factor to reduce the likelihood of topic tokens
    max_length=100,
    min_length=30,
    num_beams=1,
    num_return_sequences=1,
)

In [29]:
print("Output:\n" + 100 * "-")
for idx, output_sequence in enumerate(output_sequences):
    output = T5_base_tokenizer.decode(output_sequence, skip_special_tokens=True)
    print("Input: {}".format(articles[idx]))
    print("Generated summary {}: {}".format(idx + 1, output))
    print(100 * "-")

Output:
----------------------------------------------------------------------------------------------------
Input: An American tourist has spent the night stranded in the Blue Mountains, west of Sydney, after she fell 15 metres off a cliff while bushwalking. The 25-year-old from the US state of Wisconsin was walking near Pulpit Rock, Mount Victoria with a group of friends on Friday when she slipped from a track. She fell about 15 metres and rolled a further 20 metres down a steep slope, police say. Rescue teams escort a 25-year-old US tourist after she spent the night stranded in the Blue Mountains after falling 15 metres off a cliff . Rescue crews found the woman suffering a possible broken ankle and broken ribs. She remained with an ambulance team overnight due to low light and foggy weather conditions. Blue Mountains Police Rescue Sergeant Dallas Atkinson told ABC a helicopter was deployed to finish the rescue this morning. 'After she fell yesterday she was accessed a short time la

In [30]:
print(type(encoded_data))
print(type(encoded_data[0]))
print(type(encoded_data[0]["input_ids"]))
print(encoded_data[0]["input_ids"].shape)
print(type(encoded_data[0]["attention_mask"]))
print(type(encoded_data[0]["tid1"]))
print(type(encoded_data[0]["tid2"]))

<class 'list'>
<class 'dict'>
<class 'torch.Tensor'>
torch.Size([1, 829])
<class 'torch.Tensor'>
<class 'int'>
<class 'int'>


In [31]:
output_sequences_iterative = []
for idx, item in enumerate(encoded_data):
    # Get the input_ids and attention_mask for the current item
    input_ids = item["input_ids"]
    attention_mask = item["attention_mask"]
    topic_tokens = get_topic_tokens(
        lda, item["tid1"], num_words, T5_base_tokenizer
    ) + get_topic_tokens(lda, item["tid2"], num_words, T5_base_tokenizer)

    # Use the custom model for generation
    output_sequences = custom_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        excluded_token_ids=topic_tokens,
        exclusion_factor=3.0,  # Factor to reduce the likelihood of topic tokens
        max_length=100,
        min_length=30,
        num_beams=1,
        num_return_sequences=1,
    )

    # Append the output sequence to the list
    output_sequences_iterative.append(output_sequences)

In [32]:
# print(output_sequences_iterative)
print(len(output_sequences_iterative))
print(type(output_sequences_iterative))

3
<class 'list'>


In [33]:
print("Output:\n" + 100 * "-")
for idx, output_tensor in enumerate(output_sequences_iterative):
    # Ensure output_tensor is squeezed to remove any extra dimensions
    output_tensor = output_tensor.squeeze()

    # Check if the tensor is two-dimensional (batch_size, sequence_length)
    if output_tensor.dim() == 2:
        # Decode each sequence in the tensor
        for sequence in output_tensor:
            output = T5_base_tokenizer.decode(sequence, skip_special_tokens=True)
            print("Input: {}".format(articles[idx]))
            print("Generated summary {}: {}".format(idx + 1, output))
            print(100 * "-")
    else:
        # Decode the single sequence
        output = T5_base_tokenizer.decode(output_tensor, skip_special_tokens=True)
        print("Input: {}".format(articles[idx]))
        print("Generated summary {}: {}".format(idx + 1, output))
        print(100 * "-")

Output:
----------------------------------------------------------------------------------------------------
Input: An American tourist has spent the night stranded in the Blue Mountains, west of Sydney, after she fell 15 metres off a cliff while bushwalking. The 25-year-old from the US state of Wisconsin was walking near Pulpit Rock, Mount Victoria with a group of friends on Friday when she slipped from a track. She fell about 15 metres and rolled a further 20 metres down a steep slope, police say. Rescue teams escort a 25-year-old US tourist after she spent the night stranded in the Blue Mountains after falling 15 metres off a cliff . Rescue crews found the woman suffering a possible broken ankle and broken ribs. She remained with an ambulance team overnight due to low light and foggy weather conditions. Blue Mountains Police Rescue Sergeant Dallas Atkinson told ABC a helicopter was deployed to finish the rescue this morning. 'After she fell yesterday she was accessed a short time la

In [34]:
num_words = 100
max_length = 100
num_beams = 1

for i, (sentence, tid1, tid2) in enumerate(zip(articles, topic1_ids, topic2_ids)):
    # Check for valid topic IDs
    if tid1 is None or tid2 is None:
        raise ValueError(f"Invalid topic IDs for article index {i}")

    # Tokenize the sentence with the task prefix
    encoded_article = T5_base_tokenizer.encode_plus(
        task_prefix + sentence,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Get topic tokens for tid1 and tid2
    topic_tokens = get_topic_tokens(
        lda, tid1, num_words, T5_base_tokenizer
    ) + get_topic_tokens(lda, tid2, num_words, T5_base_tokenizer)

    # Generate output with topic token exclusion
    output_sequences = custom_model.generate(
        input_ids=encoded_article["input_ids"],
        attention_mask=encoded_article["attention_mask"],
        excluded_token_ids=topic_tokens,
        exclusion_factor=1,  # Factor to reduce the likelihood of topic tokens
        max_length=max_length,
        min_length=30,
        num_beams=num_beams,
        num_return_sequences=1,
    )

In [35]:
def print_output_sequences(tokenizer, output_sequences):
    """
    Decodes and prints the output sequences generated by the model.

    :param tokenizer: The tokenizer used to encode the input sequences.
    :param output_sequences: A tensor containing the generated output sequences.
    """
    # Check if output_sequences is a tensor
    if not isinstance(output_sequences, torch.Tensor):
        raise ValueError("output_sequences must be a PyTorch Tensor.")

    # Loop through each sequence in the output
    for i, sequence in enumerate(output_sequences):
        # Decode the tensor to a string
        decoded_sequence = tokenizer.decode(
            sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        # Print the decoded string
        print(f"Output Sequence {i + 1}: {decoded_sequence}\n")

In [36]:
# Example usage
# Assuming output_sequences is the tensor output from the model
print_output_sequences(T5_base_tokenizer, output_sequences)

Output Sequence 1: ISIS will be destroyed," he said. "If our Imam, our Supreme Leader orders us, we will be destroyed." The Basij is a militia that has been fighting ISIS for years.

