# Testing the generate function

## Import libraries

In [1]:
import torch
import transformers
import gensim
from gensim import corpora
from topical_decoding.utils.newts_utils import read_train, read_test
from transformers import T5Tokenizer, T5ForConditionalGeneration

## Import the model and tokenizer

In [2]:
T5_base_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
T5_base_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Import the NEWTS dataset

In [8]:
# Load the NEWTS training set
newts_test = read_test("/data/NEWTS/NEWTS_test_600.csv")
newts_train = read_train("/data/NEWTS/NEWTS_train_2400.csv")

FileNotFoundError: [Errno 2] No such file or directory: '/data/NEWTS/NEWTS_test_600.csv'

In [4]:
newts_train.head(2)

Unnamed: 0,AssignmentId,docId,article,tid1,tid2,words1,words2,phrases1,phrases2,sentences1,sentences2,summary1,summary2
0,3EG49X351WE8VLP4S0TIYZF3V476X2,094372190d52acbce61a73ec16b2217d1a60276f,The president of the World Bank on Saturday wa...,175,110,"house, committee, congress, senate, republican...","billion, figures, economy, global, growth, eco...","senate and congress, congressional pressure, y...","economic growth, global growth, billion dollar...","This topic is about the senate and congress, c...",This topic is about economic growth involving ...,The leader of the World Bank urged the US to t...,The US economy will be a driving factor in the...
1,3DOCMVPBTPGBQCHSPBSQ28AROFXNNI,bc733fb96fd73496e10fcff3c640ee11c4df3d7a,By . Nick Harris . Manchester City are the bes...,152,217,"united, manchester, liverpool, chelsea, league...","club, team, season, players, england, football...","Manchester United's manager, Premier League, t...","football league, the team's fans, football pla...",This topic is about Manchester United's manage...,This topic is about a football league having a...,Premier league is the most paying football lea...,Manchester city players earn the largest amoun...


In [5]:
print(newts_train.shape)

(2400, 13)


In [6]:
# Select an example article by its index.
example_article = newts_train.iloc[23]

# Print the article
print(example_article["article"])

A charge of making a false bomb threat has been dropped against a man who carried a backpack containing a rice cooker near a crowd marking the first anniversary of the Boston Marathon bombings in April, prosecutors said Wednesday. Investigators dropped the charge because they say the suspect, Kevin Edson, 25, did not communicate an "overt threat that an incendiary device would be detonated," Jake Wark, a spokesman for the Suffolk County district attorney, told CNN. Edson was arrested after carrying the backpack with a rice cooker near the finish line on Boylston Street in Boston while survivors of the 2013 bombing were commemorating its anniversary on April 15. In the 2013 attack, two pressure-cooker bombs exploded, killing three people and wounding at least 264 others. A barefoot Edson, carrying a backpack and wearing black clothes with a veil and hat covering his face, screamed and yelled near the end of the anniversary event on Boylston Street, drawing officers' attention, police sa

## Generate a summary for articles in the NEWTS training set

In [7]:
task_prefix = "summarize: "
# flan_prompt_prefix = "Briefly summarize this paragraph: "

min_idx = 0
max_idx = 3
sentences = newts_test["article"][min_idx:max_idx].tolist()
inputs = T5_base_tokenizer(
    [task_prefix + sentence for sentence in sentences],
    return_tensors="pt",
    padding=True,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (829 > 512). Running this sequence through the model will result in indexing errors


## Normal setup using generate function from transformers library

In [8]:
output_sequences = T5_base_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,  # disable sampling to test if batching affects output
    max_length=100,
    min_length=30,
    # early_stopping=False,
    num_beams=1,
    # no_repeat_ngram_size=3,
    num_return_sequences=1,
    # top_k=0,
    # bos_token_id=model.config.bos_token_id,
    # eos_token_id=model.config.eos_token_id,
    # length_penalty=2.0,
)

In [9]:
print("Output:\n" + 100 * "-")
for idx, output_sequence in enumerate(output_sequences):
    output = T5_base_tokenizer.decode(output_sequence, skip_special_tokens=True)
    print("Input: {}".format(sentences[idx]))
    print("Generated summary {}: {}".format(idx + 1, output))
    print(100 * "-")

Output:
----------------------------------------------------------------------------------------------------
Input: An American tourist has spent the night stranded in the Blue Mountains, west of Sydney, after she fell 15 metres off a cliff while bushwalking. The 25-year-old from the US state of Wisconsin was walking near Pulpit Rock, Mount Victoria with a group of friends on Friday when she slipped from a track. She fell about 15 metres and rolled a further 20 metres down a steep slope, police say. Rescue teams escort a 25-year-old US tourist after she spent the night stranded in the Blue Mountains after falling 15 metres off a cliff . Rescue crews found the woman suffering a possible broken ankle and broken ribs. She remained with an ambulance team overnight due to low light and foggy weather conditions. Blue Mountains Police Rescue Sergeant Dallas Atkinson told ABC a helicopter was deployed to finish the rescue this morning. 'After she fell yesterday she was accessed a short time la

## Seperate generate function via Subclassing and Overriding

In [10]:
class CustomT5Model(T5ForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize additional attributes if necessary

    def generate(self, input_ids, attention_mask=None, **kwargs):
        # Dummy implementation: Call the original generate method
        output_sequences = super().generate(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

        # Dummy token reweighting (multiplying by 1, which does nothing)
        # This is where you'd integrate the LDA model logic in the future

        return output_sequences

In [11]:
# Example Usage
custom_model = CustomT5Model.from_pretrained("google/flan-t5-base")

# Assuming 'inputs' is a dictionary with 'input_ids' and 'attention_mask'
output_sequences = custom_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=100,
    min_length=30,
    num_beams=1,
    num_return_sequences=1,
)

In [12]:
print("Output from custom model:\n" + 100 * "-")
for idx, output_sequence in enumerate(output_sequences):
    output = T5_base_tokenizer.decode(output_sequence, skip_special_tokens=True)
    print("Input: {}".format(sentences[idx]))
    print("Generated summary {}: {}".format(idx + 1, output))
    print(100 * "-")

Output from custom model:
----------------------------------------------------------------------------------------------------
Input: An American tourist has spent the night stranded in the Blue Mountains, west of Sydney, after she fell 15 metres off a cliff while bushwalking. The 25-year-old from the US state of Wisconsin was walking near Pulpit Rock, Mount Victoria with a group of friends on Friday when she slipped from a track. She fell about 15 metres and rolled a further 20 metres down a steep slope, police say. Rescue teams escort a 25-year-old US tourist after she spent the night stranded in the Blue Mountains after falling 15 metres off a cliff . Rescue crews found the woman suffering a possible broken ankle and broken ribs. She remained with an ambulance team overnight due to low light and foggy weather conditions. Blue Mountains Police Rescue Sergeant Dallas Atkinson told ABC a helicopter was deployed to finish the rescue this morning. 'After she fell yesterday she was access

Works perfectly fine

## Alter generate function and take topic into account during generation
#### Option 1: 
discard/block words based on their topic affiliation (expect this to be bad)
cannot write coherent/grammatical sentences if blocking too many words
block words from second topic
topics could overlap (could exclude the overlapping)
binary mask over the vocabulary to multiply probability vector with

#### Option 2:
temperature scaling if needed (to make the factor important)
multiply tokens that belong to the target topic with a factor > 1 
rescale/normalize  after multiplying
Alternative
select top k tokens and then over-sample tokens from the topic. 
Start out with uniform distribution over top-k tokens. Multiply each token belonging to the target topic with factor (e.g. 2-5) and then rescale entire probability distribution accordingly.


## Import LDA model

In [13]:
def load_lda_model(model_address: str):
    # Loads the LDA model and dictionary from the specified address.
    try:
        lda = gensim.models.ldamodel.LdaModel.load(
            model_address + "lda.model", mmap="r"
        )
        dictionary = corpora.Dictionary.load(model_address + "dictionary.dic", mmap="r")
        return lda, dictionary
    except Exception as e:
        print(f"Error loading model or dictionary: {e}")
        return None, None

In [15]:
model_address = "../../LDA_250/"
lda, dictionary = load_lda_model(model_address)
# Warning "WARNING:root:random_state not set so using default value" is inconsequential
# for inference



In [16]:
def get_top_topic_words(lda, topic_number, num_words=100):
    """
    Returns the top words for a given topic from the LDA model.

    :param lda: The LDA model.
    :param topic_number: The topic number to get the top words for.
    :param num_words: The number of top words to return.
    :return: A list of top words for the specified topic.
    """
    try:
        lda_topics = lda.show_topics(formatted=True)
        _, topic_words = lda_topics[topic_number]
        words_with_probs = topic_words.split(" + ")

        # Extracting just the words
        top_words = [
            word_prob.split("*")[1].strip()
            for word_prob in words_with_probs[:num_words]
        ]
        return top_words
    except Exception as e:
        print(f"Error in getting top topic words: {e}")
        return []

## Create custom model

In [17]:
class CustomT5Model(T5ForConditionalGeneration):
    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer

    def adjust_logits_during_generation(
        self, logits, block_token_ids, cur_len, max_length
    ):
        # Modify this method to adjust logits according to your needs
        if cur_len < max_length:
            for token_id in block_token_ids:
                logits[:, token_id] = -float("inf")
        return logits

    def generate(
        self,
        input_ids,
        attention_mask=None,
        block_topic=-1,
        lda_model=None,
        dictionary=None,
        **kwargs
    ):
        # Checking if a topic should be blocked
        block_token_ids = []
        if block_topic >= 0 and lda_model is not None and dictionary is not None:
            top_words = get_top_topic_words(lda_model, block_topic)
            block_token_ids = [
                self.tokenizer.encode(word, add_special_tokens=False)[0]
                for word in top_words
            ]

            # Modify logits_processor to block the selected words
            logits_processor = transformers.generation_logits_process.LogitsProcessorList(
                [
                    transformers.generation_logits_process.HammingDiversityLogitsProcessor(
                        group_size=1
                    ),
                    transformers.generation_logits_process.InfNanRemoveLogitsProcessor(),
                    transformers.generation_logits_process.RepetitionPenaltyLogitsProcessor(
                        penalty=1.2
                    ),
                    transformers.generation_logits_process.ForbiddenTokensLogitsProcessor(
                        forbidden_tokens=block_token_ids
                    ),
                ]
            )
            kwargs["logits_processor"] = logits_processor

        output_sequences = super().generate(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

        return output_sequences

In [18]:
class CustomT5Model(T5ForConditionalGeneration):
    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer

    def generate(
        self,
        input_ids,
        attention_mask=None,
        block_topic=-1,
        lda_model=None,
        dictionary=None,
        **kwargs
    ):
        # If a topic should be blocked, prepare the block_token_ids
        block_token_ids = []
        if block_topic >= 0 and lda_model is not None and dictionary is not None:
            top_words = get_top_topic_words(lda_model, block_topic)
            block_token_ids = [
                self.tokenizer.encode(word, add_special_tokens=False)[0]
                for word in top_words
            ]

        # Custom logits manipulation
        def custom_logits_processor(logits, input_ids):
            if block_token_ids:
                logits[:, block_token_ids] = -float("inf")
            return logits

        # Override the standard logits_processor in the generation loop
        original_logits_processor = self._get_logits_processor(**kwargs)

        def combined_logits_processor(logits, input_ids):
            logits = original_logits_processor(logits, input_ids)
            return custom_logits_processor(logits, input_ids)

        self._get_logits_processor = lambda **kwargs: combined_logits_processor

        # Generate output sequences
        output_sequences = super().generate(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

        # Reset logits processor to its original state
        self._get_logits_processor = lambda **kwargs: original_logits_processor

        return output_sequences


# Usage remains the same

In [19]:
T5_base_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
custom_model = CustomT5Model(
    tokenizer=T5_base_tokenizer,
    config=T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").config,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [20]:
# Assuming 'inputs' is a dictionary with 'input_ids' and 'attention_mask'
custom_output_sequences = custom_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    block_topic=0,  # Topic to block
    lda_model=lda,  # Your LDA model
    dictionary=dictionary,  # Your dictionary
    max_length=100,
    min_length=30,
    num_beams=1,
    num_return_sequences=1,
)

TypeError: GenerationMixin._get_logits_processor() got an unexpected keyword argument 'max_length'

In [None]:
print("Output from custom model:\n" + 100 * "-")
for idx, output_sequence in enumerate(custom_output_sequences):
    output = T5_base_tokenizer.decode(output_sequence, skip_special_tokens=True)
    print("Input: {}".format(sentences[idx]))
    print("Generated summary {}: {}".format(idx + 1, output))
    print(100 * "-")

In [None]:
original_summaries = output_sequences.tolist()

In [None]:
print(original_summaries)

In [None]:
print("Output:\n" + 100 * "-")
for idx, output_sequence in enumerate(original_summaries):
    output = T5_base_tokenizer.decode(output_sequence, skip_special_tokens=True)
    print("Input: {}".format(sentences[idx]))
    print("Generated summary {}: {}".format(idx + 1, output))
    print(100 * "-")