# T5 Fine-tuned model inference

In [37]:
pip -qqq install --upgrade accelerate transformers auto-gptq optimum rouge_score bert-score datasets torch torchvision

In [38]:
# General libraries
import os
import shutil
import random
import re
import time

# Data handling libraries
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict

# Transformers libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Import torch
import torch

# NLP and evaluation libraries
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.meteor_score import single_meteor_score
from bert_score import BERTScorer
import spacy
nlp = spacy.load("en_core_web_sm")
import nltk
nltk.download('wordnet')

# Logging for the pipeline
import logging

# Google drive
from google.colab import drive

# Language libraries
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

# Load the dataset

Connect to google drive to access kaggle.json

In [39]:
drive.mount('/content/drive')

# change this to your own kaggle.json path
kaggle_file_path = '/content/drive/My Drive/kaggle.json'

os.makedirs('/root/.kaggle', exist_ok=True)
shutil.copy(kaggle_file_path, '/root/.kaggle/kaggle.json')
os.chmod('/root/.kaggle/kaggle.json', 600)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [40]:
def download_dataset(dataset_name):
    command = f"kaggle datasets download -d bruncikristian/{dataset_name}-preprocessed-dataset"
    os.system(command)

In [41]:
def load_dataset_from_zip(dataset_name):
    dataset_dict = load_dataset(
        "csv",
        data_files=f"{dataset_name}-preprocessed-dataset.zip"
    )

    dataset = dataset_dict['train']
    return dataset

In [42]:
def split_dataset(dataset, test_size=0.1, validation_size=0.1, seed=42):
    datasets_train_test = dataset.train_test_split(test_size=test_size)
    datasets_train_validation = datasets_train_test["train"].train_test_split(test_size=validation_size)
    dataset_split = DatasetDict({
        "train": datasets_train_validation["train"],
        "validation": datasets_train_validation["test"],
        "test": datasets_train_test["test"]
    })

    return dataset_split["test"]

In [43]:
def add_prompt_to_dataset(dataset, prompt_template="summarize: {article}"):
    def add_summarization_prompt(example):
        example['article'] = prompt_template.format(article=example['article'])
        return example

    dataset = dataset.map(add_summarization_prompt)
    return dataset


# Load the tokenizer and model

In [44]:
def load_model_and_tokenizer(model_checkpoint, device="cpu", hf_token=None):
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, trust_remote_code=True, use_auth_token=hf_token)

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device_map=None if device == "cpu" else "auto",
        use_auth_token=hf_token
    )

    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model

# Inference

In [45]:
def summarize_text(prompt, tokenizer, model, device="cpu", max_new_tokens=200, temperature=0.7, top_p=0.9):
    print(prompt)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        do_sample=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        no_repeat_ngram_size=3,
        repetition_penalty=1.5,
        min_length=30,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


In [46]:
def smart_trim(summary):
    sentences = sent_tokenize(summary.strip())

    clean_sents = [s.strip() for s in sentences if s.strip().endswith(('.', '!', '?'))]

    if clean_sents:
        return " ".join(clean_sents)

    match = re.search(r"(.*?[\.!?])\s", summary)
    if match:
        return match.group(1)

    return summary.strip()

# Evaluate the model

In [47]:
def rouge_score(response, summary):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = scorer.score(response, summary)
    f1_scores = [score.fmeasure for score in scores.values()]
    return f1_scores[0], f1_scores[1], f1_scores[2]


In [48]:
def bleu_score(response, summary):
    doc_abstract = nlp(summary)
    doc_summary = nlp(response)

    sentences_abstract = [sent.text.split() for sent in doc_abstract.sents]
    sentences_summary = [token.text for token in doc_summary]

    weights = (0.5, 0.5, 0.25, 0.25)

    score = sentence_bleu(sentences_abstract, sentences_summary, weights=weights)
    return score

In [49]:
def meteor_score(response, summary):
    summary_tokens = summary.split()
    response_tokens = response.split()

    score = round(single_meteor_score(response_tokens, summary_tokens), 4)
    return score

In [50]:
def bert_score(response, summary):
    scorer = BERTScorer(model_type='bert-base-uncased')
    P, R, F1 = scorer.score([response], [summary])
    return F1.item()


In [51]:
def compute_metrics(response, summary):
  rouge1, rouge2, rougel = rouge_score(response, summary)
  bleu = bleu_score(response, summary)
  meteor = meteor_score(response, summary)
  bert = bert_score(response, summary)

  return {
      "rouge1": rouge1,
      "rouge2": rouge2,
      "rougel": rougel,
      "bleu": bleu,
      "meteor": meteor,
      "bert": bert
  }

In [52]:
def save_metrics_to_csv(metrics_list, output_folder, dataset_name, model_checkpoint):
    metrics_df = pd.DataFrame(metrics_list, columns=["ROUGE-1", "ROUGE-2", "ROUGE-L", "BLEU", "METEOR", "BERT", "time"])
    metrics_df.to_csv(f"{output_folder}/{dataset_name}_google-t5-small_fine_tuned_metrics.csv", index=False)
    # metrics_df.to_csv(f"{output_folder}/{dataset_name}_{model_checkpoint}_metrics.csv", index=False)

In [53]:
def pipeline(dataset_name, model_checkpoint, output_folder, num_samples=100):
    download_dataset(dataset_name)
    dataset = load_dataset_from_zip(dataset_name)
    dataset_split = split_dataset(dataset)

    dataset = add_prompt_to_dataset(dataset)
    tokenizer, model = load_model_and_tokenizer(model_checkpoint)

    test_samples_list = list(dataset)

    random.seed(42)
    test_samples = random.sample(test_samples_list, num_samples)

    metrics_list = []
    for sample in test_samples:
        text = sample["article"]
        summary = sample["summary"]

        start_time = time.time()

        response = summarize_text(text, tokenizer, model)

        end_time = time.time()
        elapsed_time = end_time - start_time

        response = smart_trim(response)
        print("response: ", response)

        metrics = compute_metrics(response, summary)
        metrics_list.append([
            metrics["rouge1"],
            metrics["rouge2"],
            metrics["rougel"],
            metrics["bleu"],
            metrics["meteor"],
            metrics["bert"],
            elapsed_time
        ])

    save_metrics_to_csv(metrics_list, output_folder, dataset_name, model_checkpoint)

In [54]:
datasets = ["bbc", "gigaword", "nips", "cord"]


# TODO: Replace with your own Hugging Face API token
huggingface_token = "YOUR_HUGGINGFACE_TOKEN"


for dataset_name in datasets:
    model_checkpoint = f'BRUNOKRISTI/t5_small_fine_tuned_{dataset_name}'
    # model_checkpoint = f't5-small'
    metrics_df = pipeline(dataset_name, model_checkpoint, output_folder="/content/drive/My Drive/Results/T5", num_samples=10)


summarize: Citizenship ceremonies could be introduced for people celebrating their ##th birthday Charles Clarke has said. The idea will be tried as part of an overhaul of the way government approaches "inclusive citizenship" particularly for ethnic minorities. A pilot scheme based on ceremonies in Australia will start in October. Mr Clarke said it would be a way of recognising young people reaching their voting age when they also gain greater independence from parents. Britain's young black and Asian people are to be encouraged to learn about the nation's heritage as part of the government's new race strategy which will also target specific issues within different ethnic minority groups. Officials say the home secretary wants young people to feel they belong and to understand their "other cultural identities" alongside being British. The launch follows a row about the role of faith schools in Britain. On Monday school inspection chief David Bell accused some Islamic schools of failing 

The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


summarize: The world's biggest confectionery firm Cadbury Schweppes has reported a modest rise in profits after the weak dollar took a bite out of its results. Underlying pre-tax profits rose #% to £###m ($#.##bn) in #### but would have been #% higher if currency movements were stripped out. The owner of brands such as Dairy Milk Dr Pepper and Snapple generates more than ##% of its sales outside the UK. Cadbury said it was confident it would hit its targets for ####. "While the external commercial environment remains competitive we are confident that we have the strategy brands and people to deliver within our goal ranges in #### " said chief executive Todd Stitzer. The modest profit rise had been expected by analysts after the company said in December that the poor summer weather had hit soft drink sales in Europe. Cadbury said its underlying sales were up by #% in ####. Growth was helped by its confectionery brands - including Cadbury Trident and Halls - which enjoyed a "successful" 

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/6864 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

summarize: Former us president bill clinton announced on monday that his foundation will provide ## million us dollars for aids treatment for ## ### children in ## developing countries in ####.
response:  Clinton foundation to provide ## ### aids treatment for children.-# #### charitable foundation to help ## indiana kids in developing countries.


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


summarize: Macau s monetary authority on tuesday denied a news report that several north korean bank accounts in the territory had been unfrozen after being put on hold over united states allegations of financial irregularities.
response:  Macau monetary authority denies north korean bank accounts unfrozen.-)T Preparatory financial information from the company is contained in a report.
summarize: The leaders of jordan egypt and the palestinians today reiterated the necessity of promoting arab efforts aiming at maintaining the peace process and ensuring the implementation of the agreements which have been signed.
response:  Egypt palestinians reiterate importance of promoting arab efforts.----sunteti Forum urges more effort on peace process with arabs to ensure implementation of agreements signed.


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


summarize: United states senate candidate rand paul is stirring up controversy again this time by saying he opposes citizenship for children born in the united states to parents who are illegal immigrants.
response:  Senate candidate senate candidates raise controversy over children born in us.-#.dx.de Pageant criticizes paul on immigration reform.
summarize: A list of some ### alleged homosexuals which has been circulating on the internet has sparked panic among gays in predominantly catholic croatia an activist said monday.
response:  ### alleged homosexuals circulate on internet.dh.-dhoaning list sparks panic among gay people in predominantly catholic croatia.
summarize: Serbia-montenegro s army friday opened up its top secret underground shelter in belgrade for years rumored to be a hideout for war crimes suspects state television said.
response:  Englishsunteti German army offers free standing for war crimes suspects spy hideout claims defense agency.
summarize: City officials and

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/5102 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

summarize: An ideal observer model for identifying the reference frame of objects Joseph L. Austerweil Department of Psychology University of California Berkeley Berkeley CA ##### Department of Computer Science and Engineering Abram L. Friesen University of Washington Seattle WA ##### Joseph.Austerweil@gmail.com afriesen@cs.washington.edu Thomas L. Grifﬁths Department of Psychology University of California Berkeley Berkeley CA ##### Tom Griffiths@berkeley.edu Abstract The object people perceive in an image can depend on its orientation relative to the scene it is in (its reference frame). For example the images of the symbols × and + differ by a ## degree rotation. Although real scenes have multiple im- ages and reference frames psychologists have focused on scenes with only one reference frame. We propose an ideal observer model based on nonparametric Bayesian statistics for inferring the number of reference frames in a scene and their parameters. When an ambiguous image could be assi

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/7909 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

summarize: Suicidality which encompasses suicidal ideation plans and attempts is a clinical concern and has an adverse impact on individuals families and society. Individuals with suicidal ideation [#] and those who have planned for a suicide attempt [#] are at greater risk of attempting suicide with those who have demonstrated planning being at a higher risk than those with ideation [#] . A suicide attempt is the most potent risk factor for completed suicide [#] . Research has been conducted to identify the determinants of a suicide attempt and factors identified include depressive [#] and anxiety disorders [#] poor social adjustment [#] unemployment [#] and medical illness [#] . Singapore is a city-state in Southeast Asia and is a multi-ethnic nation. The Singapore population as of #### stands at #.# million of whom #.# million are residents (Citizens and Permanent Residents) [#] . The suicide rate in Singapore between #### and #### stood at #.## deaths per ### ### residents with a s

The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


summarize: Swine vesicular disease (SVD) is a highly contagious viral disease of pigs. Symptoms are clinically indistinguishable from those caused by other vesicular disease viruses such as foot and mouth disease (FMD) virus vesicular stomatitis (VS) virus and vesiviruses (which include vesicular exanthema of swine virus (VESV)) so SVD is classified as a list A disease by the Office International des Epizooties (OIE) [#] . The causative agent swine vesicular disease virus (SVDV) is a member of the genus Enterovirus within the family Picornaviridae. It is a non-enveloped virus containing a single-strand RNA genome of positive sense which is approximately #.# kb nucleotides in length [# #] with a poly (A) tail at the #′ end and can act directly as messenger RNA in host cells. SVDV is both antigenically and genetically closely related to the human pathogen coxsackievirus B# (CV-B#) [#] [#] [#] [#] although pigs inoculated with CV-B# do not show overt clinical signs of SVD [#] . It is poss

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


summarize: Environmental economic and other major disruptions that create even brief periods of social instability are referred to as "Big Events " ( Friedman et al. #### ) . They have a variety of direct and indirect influences such as population displacement economic disruption health service shortages and upticks in violence which affect normative behaviors and in some cases lead to social upheaval ( Friedman et al. #### ) . Big Events are noted as having especially seriously impacts "on marginalized groups of people whose social precarity leaves them more vulnerable to the harms engendered by major disruptions " ( Zolopa et al. #### ) . Research on the repercussions of Big Events and on how interacting causal pathways are experienced by specific groups can inform interven-tions aimed at preventing or mitigating harm ( Friedman et al. #### ) . In their recent review article Zolopa et al. identify risk pathways resulting from "Big Events " on health and service delivery for people wh