<a href="https://colab.research.google.com/github/GuifuLiu/co-occur_dm/blob/main/finetune_t5_for_dm_pred.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import torch
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
from datasets import load_dataset
import re

In [33]:
!git clone https://github.com/GuifuLiu/co-occur_dm.git
!pip install -U datasets

fatal: destination path 'co-occur_dm' already exists and is not an empty directory.
Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 kB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2

In [22]:
# Pretrained checkpoint to start from: "t5-small", "t5-base", "t5-large", "t5-3b" or "t5-11b".
T5_PATH = 't5-base'

# Use the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Pretrained config, tokenizer and seq2seq model; the model is moved onto DEVICE.
t5_config = T5Config.from_pretrained(T5_PATH)
t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)
t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config).to(DEVICE)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


# Inference

In [3]:
# Helpers for predicting the discourse marker that fills the <extra_id_0> slot.

def _filter(output, end_token='<extra_id_1>', tokenizer=None):
    """Return the text T5 generated for the ``<extra_id_0>`` slot of one beam.

    Generation begins with the decoder start token (``<pad>`` for the stock T5
    configs), and for span infilling the next token is normally ``<extra_id_0>``
    (id 32099); both are skipped by decoding from index 2. The decoded string is
    then cut at the first ``end_token`` (the next sentinel) or, if the model
    stopped without emitting it, at the tokenizer's EOS token, so trailing
    ``</s>``/``<pad>`` tokens do not leak into the result.

    Args:
        output: 1-D sequence of generated token ids (a single beam).
        end_token: Sentinel that closes the infilled span.
        tokenizer: Tokenizer used to decode ``output``. Defaults to the
            notebook-level ``t5_tokenizer`` (previously this relied on a global
            ``tokenizer`` that only exists after the last cell has run).

    Returns:
        The decoded span, whitespace kept as decoded (e.g. ``'though '``).
    """
    if tokenizer is None:
        tokenizer = t5_tokenizer
    generated_text = tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    stop_tokens = [end_token] + ([tokenizer.eos_token] if tokenizer.eos_token else [])
    stop_positions = [pos for pos in (generated_text.find(tok) for tok in stop_tokens) if pos != -1]
    return generated_text[:min(stop_positions)] if stop_positions else generated_text


def generate_dm_with_logp(text, tokenizer, model, num_return_sequences=50, max_length=7, num_beams=200):
    """Beam-search candidate fillers for the ``<extra_id_0>`` slot in ``text``.

    The same function serves the zero-shot and the fine-tuned model.

    Args:
        text: Input string containing an ``<extra_id_0>`` placeholder.
        tokenizer: T5 tokenizer matching ``model``.
        model: Seq2seq model with ``generate``; inputs are moved to
            ``model.device`` so no device global is needed.
        num_return_sequences: Number of beams to return (must not exceed
            ``num_beams``).
        max_length: Maximum length of the generated decoder sequence; it
            includes the two leading tokens that ``_filter`` strips.
        num_beams: Beam width (previously hard-coded to 200).

    Returns:
        ``(fillers, scores)``: the decoded span for each returned beam, and the
        ``sequences_scores`` tensor from ``generate`` (beam scores, i.e.
        length-normalized log-probabilities under the default length_penalty).

    Raises:
        ValueError: if ``text`` contains no ``<extra_id_0>`` placeholder.
    """
    if '<extra_id_0>' not in text:
        raise ValueError("text must contain an '<extra_id_0>' placeholder")

    input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt").to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,  # required for generate() to return sequences_scores
    )

    fillers = [_filter(sequence, tokenizer=tokenizer) for sequence in outputs["sequences"]]
    return fillers, outputs["sequences_scores"]

In [28]:
# Alternative probe with the marker slot at the start of the clause:
# text = 'Here\'s my abridgment. <extra_id_0>, however, it seems worth while at least to mention the most serious defect in the story, which is this.'
text = 'Here\'s my abridgment. In the mean time, <extra_id_0>, it seems worth while at least to mention the most serious defect in the story, which is this.'
# Zero-shot: use the pretrained tokenizer/model loaded above (both arguments are required).
generate_dm_with_logp(text, t5_tokenizer, t5_mlm)

(['though ',
  "if you're",
  'though ',
  'though ',
  'though ',
  'though, ',
  'though ',
  'though, ',
  'though ',
  'though ',
  'though ',
  'though ',
  'though ',
  'though ',
  'though ',
  'however ',
  'though ',
  'though ',
  'though ',
  'if nothing else ',
  'though ',
  'though ',
  'at the very least ',
  'though, ',
  'though ',
  'though ',
  'though ',
  'though ',
  'though ',
  "if you've",
  'though ',
  'though, ',
  'though, ',
  '. . ',
  'though ',
  "if I'm",
  'though, ',
  "while I'm at",
  'though ',
  'though ',
  'for the sake of ',
  'though ',
  'though ',
  'though ',
  'however, ',
  'though ',
  'if you are reading',
  'though ',
  'though ',
  'however, '],
 tensor([-0.8475, -0.9212, -0.9604, -1.0377, -1.0557, -1.0690, -1.0721, -1.0741,
         -1.0951, -1.0956, -1.1229, -1.1285, -1.1324, -1.1489, -1.1504, -1.1584,
         -1.1728, -1.1863, -1.1872, -1.1888, -1.1891, -1.1958, -1.1970, -1.2039,
         -1.2091, -1.2190, -1.2201, -1.2238, -1.23

# Finetune

In [15]:
# Discourse-marker pairs for the explicit-explicit / discovery / dm1_other setting.
# (Named `data_dir` rather than `dir` so the builtin `dir()` is not shadowed.)
data_dir = "co-occur_dm/dataset/explicit-explicit/discovery/dm1_other"

# One CSV per split; each row has id, sentence1, sentence2, dm1 and dm2 columns.
dataset = load_dataset("csv", data_files={
    split: f"{data_dir}/{split}.csv" for split in ("train", "validation", "test")
})

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [27]:
# Preview the first four training rows; slicing a split yields a dict of column name -> list of values.
dataset["train"][:4]

{'id': [7, 1303, 613, 1840],
 'sentence1': ['For ideologically grounded conservatives and libertarians, it was infuriating; the undecided swing vote could be swayed and Democrats prospered.',
  'would you go back and buy another ticket?',
  'In fact, he probably could have used a righty to face Cruz the batter before, but after that walk, it was certainly time for a new arm.',
  'Would it be better for him to drink half a Snapple?'],
 'sentence2': ['things had begun to change.',
  'it would hurt a lot more.',
  'valentine remained in the dugout, leaving morales in the game to pitch to craig gentry.',
  "i'm not sure how that's your business."],
 'dm1': ['already,', 'maybe,', 'amazingly,', 'maybe,'],
 'dm2': ['however,', 'but', 'though,', 'but']}

In [None]:
def preprocess_function(example, after = True):
    """Turn one CSV row into a tokenized T5 span-infilling training example.

    Model input:  ``sentence1 dm1 <extra_id_0> sentence2``
    Target:       ``<extra_id_0> dm2 <extra_id_1>``

    The target is closed with ``<extra_id_1>``, the sentinel layout T5 produces
    zero-shot and the end token the inference helper ``_filter`` cuts at.

    Args:
        example: A single (non-batched) row with ``sentence1``, ``sentence2``,
            ``dm1`` and ``dm2`` string fields.
        after: When False, ``dm1`` is lowercased after the capitalization step.

    Returns:
        The tokenizer output for the input (``input_ids``, ``attention_mask``)
        plus ``labels`` with the target token ids. Nothing is padded here:
        ``DataCollatorForSeq2Seq`` pads each batch and fills label padding with
        -100 so padded positions are ignored by the loss. Pre-padding labels to
        ``max_length`` left real pad ids in ``labels`` and trained on them.
    """
    sentence1, sentence2, dm1, dm2 = example["sentence1"], example["sentence2"], example["dm1"], example["dm2"]

    # Capitalize dm1 when sentence1 ends a sentence (., ! or ?, optionally followed by a closing quote).
    should_capitalize = bool(re.search(r'[.!?]\s*["\']?\s*$', sentence1))
    dm1 = dm1.capitalize() if should_capitalize else dm1
    dm1 = dm1.lower() if not after else dm1

    source_text = sentence1.rstrip() + " " + dm1.rstrip() + " <extra_id_0> " + sentence2
    target_text = "<extra_id_0> " + dm2.rstrip(", ") + " <extra_id_1>"

    model_input = t5_tokenizer(source_text, max_length=512, truncation=True)
    labels = t5_tokenizer(target_text, max_length=128, truncation=True)

    model_input["labels"] = labels["input_ids"]
    return model_input

# Batch-time padding: labels are padded with the collator's default label_pad_token_id (-100),
# and the model is passed so decoder inputs can be built from the labels.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=t5_tokenizer,
    model=t5_mlm,
)
# Tokenize every split row by row with preprocess_function.
tokenized_dataset = dataset.map(function=preprocess_function, batched=False)

In [None]:
# Fine-tune T5 to fill <extra_id_0> with the second discourse marker.
training_args = Seq2SeqTrainingArguments(
    output_dir="./t5-finetuned-fill-mask",  # per-epoch checkpoint-*/ folders are written here
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    # Mixed precision only when a GPU is available.
    # NOTE(review): T5 checkpoints are reported to be unstable in fp16; try bf16 if the loss becomes NaN.
    fp16=torch.cuda.is_available(),

    # Checkpoint, log and evaluate once per epoch.
    save_strategy="epoch",
    logging_strategy="epoch",
    # `eval_strategy` is the current name of the removed `evaluation_strategy` argument.
    eval_strategy="epoch",
)

trainer = Seq2SeqTrainer(
    model=t5_mlm,
    args=training_args,
    # The Trainer takes one Dataset per role, not the whole DatasetDict.
    train_dataset=tokenized_dataset["train"],
    # Required because eval_strategy="epoch".
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=t5_tokenizer,
)

trainer.train()

# Final model and tokenizer are saved here; this is the directory to load for inference.
data_split = "discovery/dm1_other"
trainer.save_model(f"{data_split}_{T5_PATH}")

# Inference from Model File

In [None]:
# Load the fine-tuned model and tokenizer from the directory written by
# `trainer.save_model(f"{data_split}_{T5_PATH}")` in the Finetune section.
# The Trainer's output_dir ("./t5-finetuned-fill-mask") holds per-epoch checkpoint-*/
# subfolders rather than the final saved model.
model_path = f"discovery/dm1_other_{T5_PATH}"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Example usage
test_text = "Here's my abridgment. <extra_id_0>, however, it seems worth while at least to mention the most serious defect in the story, which is this."
predicted_dms, probs = generate_dm_with_logp(test_text, tokenizer, model)
# Top candidates with their beam scores (log-probability based, not probabilities).
list(zip(predicted_dms, probs.tolist()))[:10]
