In [4]:
import wandb
from wandb.keras import WandbMetricsLogger

import tensorflow as tf

from datasets import load_dataset
from datasets import load_metric
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import TFAutoModelForSeq2SeqLM

from tqdm import tqdm

## Model

The reward model is pretrained model for summarization task with a randomly initialized linear head on top. So in our T5 model, let's try to add a linear head. 

In [3]:
# Pretrained supervised T5 model
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.summary()

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Model: "tft5_for_conditional_generation_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 shared (Embedding)          multiple                  16449536  
                                                                 
 encoder (TFT5MainLayer)     multiple                  35330816  
                                                                 
 decoder (TFT5MainLayer)     multiple                  41625344  
                                                                 
Total params: 60,506,624
Trainable params: 60,506,624
Non-trainable params: 0
_________________________________________________________________


In [10]:
model.layers

[<keras.layers.core.embedding.Embedding at 0x7ff7a0112190>,
 <transformers.models.t5.modeling_tf_t5.TFT5MainLayer at 0x7ff7a0112890>,
 <transformers.models.t5.modeling_tf_t5.TFT5MainLayer at 0x7ff7a01333d0>]

In [11]:
text = "summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."

In [61]:
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [62]:
inputs = tokenizer([text], return_tensors="tf", padding=True, truncation=True)
inputs.input_ids

<tf.Tensor: shape=(1, 101), dtype=int32, numpy=
array([[21603,    10,    37,    86,    89,  6105,   419,  8291,  1983,
         1364,     7,  7744,  2672,  1358,     6,   533,   124,  1358,
            6,    11,   827,  1358,     5,    94,    31,     7,     8,
          167,  8299,  1041,    30,     3, 26074,     8,  3298,  5362,
           16,   797,   892,     6,    84,    56,  5656,    95,   797,
         2765,    11,   482,   207,    18,  8832,    53,     6,  7021,
         2476,   640,     8,   684,     5,    94,    31,   195,  1364,
            8, 11724,    11,   987,     8,  6173,    18,  1123,   138,
          189,    63,    11, 11711,    12,   726,    70,  2725,   698,
            5,   275,   150,    80,   492,   365,  1514, 31471,   399,
          215,    56,   726,     3,     9, 23925,    72,    16,  5161,
            5,     1]], dtype=int32)>

In [63]:
model

<transformers.models.t5.modeling_tf_t5.TFT5ForConditionalGeneration at 0x7ff7a0122d10>

In [64]:
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_tf_t5 import TFT5Model

In [65]:
config = T5Config()

model_without_head = TFT5Model(config)

In [69]:
# decoder_input_ids should be tokenized input ids of summary
output = model_without_head(inputs.input_ids, decoder_input_ids=inputs.input_ids, return_dict=True)
output

TFSeq2SeqModelOutput(last_hidden_state=<tf.Tensor: shape=(1, 101, 512), dtype=float32, numpy=
array([[[ 1.1848021e+00,  1.3647668e+00,  4.8886967e-01, ...,
         -5.5582839e-01,  1.8024219e-02,  1.4010403e-02],
        [ 1.1753862e+00,  1.3805727e+00,  4.7043261e-01, ...,
         -5.7651693e-01, -6.2417691e-03,  7.3773619e-03],
        [ 1.1477405e+00,  1.4045787e+00,  4.5675668e-01, ...,
         -5.5629939e-01,  3.3542234e-02,  8.8066403e-03],
        ...,
        [ 1.1489153e+00,  1.3862922e+00,  4.6429896e-01, ...,
         -5.6210953e-01, -5.2479161e-03,  5.0726469e-04],
        [ 1.1495028e+00,  1.3614743e+00,  5.0725353e-01, ...,
         -5.7479459e-01,  1.5330562e-02, -2.9693369e-02],
        [ 1.1464672e+00,  1.3709490e+00,  4.8977739e-01, ...,
         -5.5022699e-01,  5.7469425e-04,  2.5579433e-03]]], dtype=float32)>, past_key_values=((<tf.Tensor: shape=(1, 8, 101, 64), dtype=float32, numpy=
array([[[[ 1.71381962e+00, -1.40151024e-01,  2.07584411e-01, ...,
           7.

In [72]:
output.last_hidden_state[:, 0, :]

<tf.Tensor: shape=(1, 512), dtype=float32, numpy=
array([[ 1.18480206e+00,  1.36476684e+00,  4.88869667e-01,
        -5.92741013e-01,  9.77886140e-01,  6.48904443e-01,
        -4.58916843e-01, -1.06635794e-01,  9.10232902e-01,
        -2.84296930e-01,  1.15429807e+00,  4.14846182e-01,
         9.78748798e-01, -9.64390934e-01,  6.46315038e-01,
        -1.05247533e+00, -8.82582843e-01, -4.80766922e-01,
         4.24455911e-01,  7.64136463e-02,  1.36931747e-01,
         1.97333825e+00,  1.25592545e-01,  1.20324838e+00,
         4.49753731e-01,  6.72076344e-02,  8.51777673e-01,
         3.69537145e-01, -1.07756913e+00,  1.31219435e+00,
         3.60225961e-02,  6.46743059e-01,  2.80902147e-01,
         1.70210075e+00, -7.61029646e-02, -4.23340619e-01,
         6.58759415e-01,  9.64850262e-02,  6.70779943e-01,
         1.24985623e+00, -1.11185944e+00, -3.08721364e-01,
         5.44326007e-01,  7.83310413e-01, -5.00096142e-01,
        -5.42051613e-01,  5.19697189e-01, -2.48681545e+00,
      

In [129]:
class T5RewardModel(tf.keras.Model):
    def __init__(self):
        super(T5RewardModel, self).__init__()
        self.t5_without_lm_head = TFT5Model(config)
        self.reward_head = tf.keras.layers.Dense(1, use_bias=False)

    def call(self, inputs, training=False):
        input_ids, decoder_input_ids = inputs
        sequence_output = self.t5_without_lm_head(
            input_ids,
            decoder_input_ids=decoder_input_ids,
            training=training).last_hidden_state

        reward_output = self.reward_head(sequence_output[:,0,:]) ## extract the 1st token's embeddings

        return reward_output

In [130]:
reward_model = T5RewardModel()

reward_model((inputs.input_ids, inputs.input_ids))

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.2724683]], dtype=float32)>

In [131]:
reward_model.summaryr

Model: "t5_reward_model_21"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 tft5_model_23 (TFT5Model)   multiple                  60506624  
                                                                 
 dense_20 (Dense)            multiple                  512       
                                                                 
Total params: 60,507,136
Trainable params: 60,507,136
Non-trainable params: 0
_________________________________________________________________


In [132]:
reward_model.compile(optimizer='adam', loss='binary_crossentropy')