In [11]:
import wandb
from wandb.keras import WandbMetricsLogger

import tensorflow as tf

from datasets import load_dataset
from datasets import load_metric
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import TFAutoModelForSeq2SeqLM

from tqdm import tqdm

## Model

The reward model is pretrained model for summarization task with a randomly initialized linear head on top. So in our T5 model, let's try to add a linear head. 

In [4]:
# Pretrained supervised T5 model
tf.keras.backend.clear_session()
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.summary()

2023-01-06 19:58:34.341758: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-01-06 19:58:34.355569: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-01-06 19:58:34.356016: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-01-06 19:58:34.356981: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorF

Model: "tft5_for_conditional_generation"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 shared (Embedding)          multiple                  16449536  
                                                                 
 encoder (TFT5MainLayer)     multiple                  35330816  
                                                                 
 decoder (TFT5MainLayer)     multiple                  41625344  
                                                                 
Total params: 60,506,624
Trainable params: 60,506,624
Non-trainable params: 0
_________________________________________________________________


In [81]:
model

<transformers.models.t5.modeling_tf_t5.TFT5ForConditionalGeneration at 0x7fdd97951910>

In [4]:
model.layers

[<keras.layers.core.embedding.Embedding at 0x7f9684020150>,
 <transformers.models.t5.modeling_tf_t5.TFT5MainLayer at 0x7f9684020650>,
 <transformers.models.t5.modeling_tf_t5.TFT5MainLayer at 0x7f9684020310>]

In [5]:
text = "summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."

In [6]:
tokenizer = AutoTokenizer.from_pretrained("t5-small")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [7]:
inputs = tokenizer([text], return_tensors="tf", padding=True, truncation=True)
inputs.input_ids

<tf.Tensor: shape=(1, 101), dtype=int32, numpy=
array([[21603,    10,    37,    86,    89,  6105,   419,  8291,  1983,
         1364,     7,  7744,  2672,  1358,     6,   533,   124,  1358,
            6,    11,   827,  1358,     5,    94,    31,     7,     8,
          167,  8299,  1041,    30,     3, 26074,     8,  3298,  5362,
           16,   797,   892,     6,    84,    56,  5656,    95,   797,
         2765,    11,   482,   207,    18,  8832,    53,     6,  7021,
         2476,   640,     8,   684,     5,    94,    31,   195,  1364,
            8, 11724,    11,   987,     8,  6173,    18,  1123,   138,
          189,    63,    11, 11711,    12,   726,    70,  2725,   698,
            5,   275,   150,    80,   492,   365,  1514, 31471,   399,
          215,    56,   726,     3,     9, 23925,    72,    16,  5161,
            5,     1]], dtype=int32)>

In [8]:
model

<transformers.models.t5.modeling_tf_t5.TFT5ForConditionalGeneration at 0x7f97207bae90>

In [2]:
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_tf_t5 import TFT5Model

In [4]:
config = T5Config()

model_without_head = TFT5Model(config).from_pretrained("t5-small")

All model checkpoint layers were used when initializing TFT5Model.

All the layers of TFT5Model were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5Model for predictions without further training.


In [11]:
# decoder_input_ids should be tokenized input ids of summary
output = model_without_head(inputs.input_ids, decoder_input_ids=inputs.input_ids, return_dict=True)
output

TFSeq2SeqModelOutput(last_hidden_state=<tf.Tensor: shape=(1, 101, 512), dtype=float32, numpy=
array([[[-0.46276543,  1.0959268 ,  1.0243176 , ...,  2.6208034 ,
          1.530184  , -0.0862919 ],
        [-0.47934574,  1.0995437 ,  1.0164989 , ...,  2.6417081 ,
          1.5272915 , -0.13780792],
        [-0.4515319 ,  1.071627  ,  1.0325024 , ...,  2.6534188 ,
          1.5200558 , -0.11399219],
        ...,
        [-0.46786645,  1.0388124 ,  1.0414091 , ...,  2.6877446 ,
          1.5250008 , -0.11594465],
        [-0.48268214,  1.0611473 ,  1.0236026 , ...,  2.6838536 ,
          1.5431606 , -0.11632806],
        [-0.4720666 ,  1.0709583 ,  1.0015554 , ...,  2.674243  ,
          1.5484025 , -0.12812552]]], dtype=float32)>, past_key_values=((<tf.Tensor: shape=(1, 8, 101, 64), dtype=float32, numpy=
array([[[[-0.12296635, -0.26068586,  0.02107891, ..., -0.4433809 ,
          -0.7091241 , -1.2599609 ],
         [-0.14442179, -0.32069674,  0.13090545, ..., -0.43308276,
          -0.755

In [12]:
output.last_hidden_state[:, 0, :]

<tf.Tensor: shape=(1, 512), dtype=float32, numpy=
array([[-4.62765425e-01,  1.09592676e+00,  1.02431762e+00,
        -8.61518919e-01,  3.01193267e-01, -4.57092613e-01,
         4.04883653e-01,  1.63444293e+00,  1.41569674e-01,
        -1.80862434e-02,  7.85361111e-01, -1.26193881e+00,
         1.41255713e+00, -7.88933933e-01,  4.29367214e-01,
         1.48956284e-01,  1.18693542e+00, -1.07305080e-01,
        -4.79683399e-01, -3.45987558e-01, -7.54406571e-01,
         7.18635321e-01,  7.34846413e-01,  1.04856646e+00,
         8.97109747e-01,  6.79533064e-01, -1.21847844e+00,
         9.55356181e-01,  1.27251789e-01, -1.45679367e+00,
        -2.14301810e-01,  1.42446673e+00, -8.95364225e-01,
         2.72548199e-01,  6.73023999e-01,  7.32028544e-01,
         2.12593651e+00,  1.57341826e+00, -6.84934914e-01,
        -1.95827281e+00,  8.12734306e-01,  4.32506680e-01,
         1.16787004e+00,  5.57159185e-01,  3.39249402e-01,
         1.93310523e+00,  7.38496780e-01, -1.02255034e+00,
      

In [5]:
class T5RewardModel(tf.keras.Model):
    def __init__(self):
        super(T5RewardModel, self).__init__()
        self.t5_without_lm_head = TFT5Model(config).from_pretrained("t5-small")
        self.reward_head = tf.keras.layers.Dense(1, use_bias=False, activation="sigmoid")

    def call(self, inputs, training=False):
        input_ids, decoder_input_ids = inputs
        sequence_output = self.t5_without_lm_head(
            input_ids,
            decoder_input_ids=decoder_input_ids,
            training=training).last_hidden_state

        reward_output = self.reward_head(sequence_output[:,0,:]) ## extract the 1st token's embeddings

        return reward_output

In [20]:
tf.keras.backend.clear_session()
reward_model = T5RewardModel()

reward_model((inputs.input_ids, inputs.input_ids)) # (input token id, target token id)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.8869067]], dtype=float32)>

In [131]:
reward_model.summaryr

Model: "t5_reward_model_21"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 tft5_model_23 (TFT5Model)   multiple                  60506624  
                                                                 
 dense_20 (Dense)            multiple                  512       
                                                                 
Total params: 60,507,136
Trainable params: 60,507,136
Non-trainable params: 0
_________________________________________________________________


In [132]:
reward_model.compile(optimizer='adam', loss='binary_crossentropy')

# Dataloader

In [6]:
DATA_PATH = "../../comparison_splits"

train_dataset = load_dataset("json", data_files=f"{DATA_PATH}/train.jsonl")["train"]
valid_dataset = load_dataset("json", data_files=f"{DATA_PATH}/valid.jsonl")["train"]

Using custom data configuration default-146c1daaf6dd79a6
Found cached dataset json (/home/ayushthakur/.cache/huggingface/datasets/json/default-146c1daaf6dd79a6/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-8dfa26f346d28263
Found cached dataset json (/home/ayushthakur/.cache/huggingface/datasets/json/default-8dfa26f346d28263/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
train_dataset[0]["post"] + "\n" + train_dataset[0]["summary"]

'Previous: \n\nGuys, I think I\'m "cured". It was a strange event but what the heck, it made me realize something.\n\nI was studying late at night in my room a few days ago. I have this shelf in my room with a bunch of zelda collectibles and a really expensive Zelda figurine underneath it ($400+).\n\nWell, guess what. As my luck would have it, Ikea shelf gave in, all my collectibles fell to the ground and the shelf knocked down my figure and destroyed it.\n\nObviously a distressing moment for me but it was also at that time I realised I needed to perhaps chill with this hobby and that the hobby can still be mine without having to necessarily share it with a significant other.\n\nOf course I would PREFER if she liked Zelda too and it would definitely be a huge plus in my book but if I fall in love with a girl who isn\'t into Zelda, I guess that\'s just how it\'s going to be. I\'m honestly not going to worry too much about this to be honest, I have enough on my plate with studies as it i

In [103]:
PREFIX = "binary classification: "

In [104]:
inputs1 = PREFIX + train_dataset[0]["post"] + "\n" + train_dataset[0]["summary"]
inputs1

'binary classification: Previous: \n\nGuys, I think I\'m "cured". It was a strange event but what the heck, it made me realize something.\n\nI was studying late at night in my room a few days ago. I have this shelf in my room with a bunch of zelda collectibles and a really expensive Zelda figurine underneath it ($400+).\n\nWell, guess what. As my luck would have it, Ikea shelf gave in, all my collectibles fell to the ground and the shelf knocked down my figure and destroyed it.\n\nObviously a distressing moment for me but it was also at that time I realised I needed to perhaps chill with this hobby and that the hobby can still be mine without having to necessarily share it with a significant other.\n\nOf course I would PREFER if she liked Zelda too and it would definitely be a huge plus in my book but if I fall in love with a girl who isn\'t into Zelda, I guess that\'s just how it\'s going to be. I\'m honestly not going to worry too much about this to be honest, I have enough on my pla

In [105]:
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [106]:
input1 = tokenizer(inputs1, return_tensors="tf")
input1

{'input_ids': <tf.Tensor: shape=(1, 428), dtype=int32, numpy=
array([[14865, 13774,    10, 10232,    10,  9070,     6,    27,   317,
           27,    31,    51,    96, 26867,  1280,    94,    47,     3,
            9,  6765,   605,    68,   125,     8, 24783,     6,    34,
          263,   140,  3384,   424,     5,    27,    47,  6908,  1480,
           44,   706,    16,    82,   562,     3,     9,   360,   477,
          977,     5,    27,    43,    48,  8625,    16,    82,   562,
           28,     3,     9,  7292,    13,     3,  4650,    26,     9,
         2868,  2317,     7,    11,     3,     9,   310,  2881,  1027,
         8804,     9, 31193, 13483,    34,  8785,  5548,  1220,   137,
         1548,     6,  3382,   125,     5,   282,    82,  5851,   133,
           43,    34,     6, 25907,  8625,  1891,    16,     6,    66,
           82,  2868,  2317,     7,  4728,    12,     8,  1591,    11,
            8,  8625,  7673,    15,    26,   323,    82,  2320,    11,
        10932, 

In [107]:
inputs2 = [PREFIX + doc for doc in train_dataset[0]["post"] + "\n" + train_dataset[0]["summary"]]
inputs2

['binary classification: P',
 'binary classification: r',
 'binary classification: e',
 'binary classification: v',
 'binary classification: i',
 'binary classification: o',
 'binary classification: u',
 'binary classification: s',
 'binary classification: :',
 'binary classification:  ',
 'binary classification: \n',
 'binary classification: \n',
 'binary classification: G',
 'binary classification: u',
 'binary classification: y',
 'binary classification: s',
 'binary classification: ,',
 'binary classification:  ',
 'binary classification: I',
 'binary classification:  ',
 'binary classification: t',
 'binary classification: h',
 'binary classification: i',
 'binary classification: n',
 'binary classification: k',
 'binary classification:  ',
 'binary classification: I',
 "binary classification: '",
 'binary classification: m',
 'binary classification:  ',
 'binary classification: "',
 'binary classification: c',
 'binary classification: u',
 'binary classification: r',
 'binary cla

In [108]:
input2 = tokenizer(inputs2)
input2

{'input_ids': [[14865, 13774, 10, 276, 1], [14865, 13774, 10, 3, 52, 1], [14865, 13774, 10, 3, 15, 1], [14865, 13774, 10, 3, 208, 1], [14865, 13774, 10, 3, 23, 1], [14865, 13774, 10, 3, 32, 1], [14865, 13774, 10, 3, 76, 1], [14865, 13774, 10, 3, 7, 1], [14865, 13774, 10, 3, 10, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 350, 1], [14865, 13774, 10, 3, 76, 1], [14865, 13774, 10, 3, 63, 1], [14865, 13774, 10, 3, 7, 1], [14865, 13774, 10, 3, 6, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 27, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 3, 17, 1], [14865, 13774, 10, 3, 107, 1], [14865, 13774, 10, 3, 23, 1], [14865, 13774, 10, 3, 29, 1], [14865, 13774, 10, 3, 157, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 27, 1], [14865, 13774, 10, 3, 31, 1], [14865, 13774, 10, 3, 51, 1], [14865, 13774, 10, 1], [14865, 13774, 10, 96, 1], [14865, 13774, 10, 3, 75, 1], [14865, 13774, 10, 3, 76, 1], [14865, 13774, 10, 3, 52, 1], [14865, 13774, 10, 3, 15, 1

In [109]:
label = "1"

with tokenizer.as_target_tokenizer():
    label = tokenizer(label, max_length=3)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [110]:
label

{'input_ids': [209, 1], 'attention_mask': [1, 1]}

In [116]:
tokenizer.pad_token_id

0

In [112]:
tokenizer.decode(output[0], skip_special_tokens=True)

': Previous: Guys, I think I\'m "cured" Previous: Guys, I'

In [1]:
DATA_PATH = "../../comparison_splits"
PREFIX = "binary classifation: "
BATCH_SIZE = 2

In [5]:
# Get the dataset
train_dataset = load_dataset("json", data_files=f"{DATA_PATH}/train.jsonl")["train"]
valid_dataset = load_dataset("json", data_files=f"{DATA_PATH}/valid.jsonl")["train"]

Using custom data configuration default-b1ddb1517d68537e
Found cached dataset json (/home/ayushthakur/.cache/huggingface/datasets/json/default-b1ddb1517d68537e/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-88724b549b3e0a85
Found cached dataset json (/home/ayushthakur/.cache/huggingface/datasets/json/default-88724b549b3e0a85/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

In [10]:
train_dataset[0]

{'id': 't3_3pppzr',
 'post': 'Previous: \n\nGuys, I think I\'m "cured". It was a strange event but what the heck, it made me realize something.\n\nI was studying late at night in my room a few days ago. I have this shelf in my room with a bunch of zelda collectibles and a really expensive Zelda figurine underneath it ($400+).\n\nWell, guess what. As my luck would have it, Ikea shelf gave in, all my collectibles fell to the ground and the shelf knocked down my figure and destroyed it.\n\nObviously a distressing moment for me but it was also at that time I realised I needed to perhaps chill with this hobby and that the hobby can still be mine without having to necessarily share it with a significant other.\n\nOf course I would PREFER if she liked Zelda too and it would definitely be a huge plus in my book but if I fall in love with a girl who isn\'t into Zelda, I guess that\'s just how it\'s going to be. I\'m honestly not going to worry too much about this to be honest, I have enough on 

In [9]:
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [None]:
def preprocess_function(examples):
    chosen_inputs = [post + "\n" + summary for post, summary in zip(examples["post"], examples["chosen"])]
    rejected_inputs = [post + "\n" + summary for post, summary in zip(examples["post"], examples["rejected"])]
    
    inputs = [PREFIX + doc for doc in examples["post"]]
    model_inputs = tokenizer(inputs, max_length=550, truncation=True)

    # Setup the tokenizer for targets
    labels = examples["label"]
    labels = [str(label) for label in labels]
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(text_target=labels, max_length=3, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    model_inputs["choice"] = examples["label"]
    return model_inputs

train_dataset = train_dataset.map(preprocess_function, batched=True)
valid_dataset = valid_dataset.map(preprocess_function, batched=True)

In [48]:
DATA_PATH = "../../comparison_splits"
PREFIX = "binary classifation: "
BATCH_SIZE = 2

# Get the dataset
train_dataset = load_dataset("json", data_files=f"{DATA_PATH}/train.jsonl")["train"]
valid_dataset = load_dataset("json", data_files=f"{DATA_PATH}/valid.jsonl")["train"]

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")

def preprocess_function(examples):
    inputs = [post + "\n" + summary for post, summary in zip(examples["post"], examples["summary"])]
    inputs = [PREFIX + doc for doc in examples["post"]]
    model_inputs = tokenizer(inputs, max_length=550, truncation=True)

    # Setup the tokenizer for targets
    labels = examples["label"]
    labels = [str(label) for label in labels]
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(text_target=labels, max_length=3, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    model_inputs["choice"] = examples["label"]
    return model_inputs

train_dataset = train_dataset.map(preprocess_function, batched=True)
valid_dataset = valid_dataset.map(preprocess_function, batched=True)

# # Remove unwanted columns
train_dataset = train_dataset.remove_columns(['id', 'post', 'summary', 'label'])
valid_dataset = valid_dataset.remove_columns(['id', 'post', 'summary', 'label'])

# data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, return_tensors="tf")
# tf_train_set = model.prepare_tf_dataset(
#     train_dataset,
#     shuffle=True,
#     batch_size=BATCH_SIZE,
#     collate_fn=data_collator,
# )

# # tf_valid_set = model.prepare_tf_dataset(
# #     valid_dataset,
# #     shuffle=False,
# #     batch_size=BATCH_SIZE,
# #     collate_fn=data_collator,
# # )

Using custom data configuration default-146c1daaf6dd79a6
Found cached dataset json (/home/ayushthakur/.cache/huggingface/datasets/json/default-146c1daaf6dd79a6/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-8dfa26f346d28263
Found cached dataset json (/home/ayushthakur/.cache/huggingface/datasets/json/default-8dfa26f346d28263/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /home/ayushthakur/.cache/huggingface/datasets/json/default-146c1daaf6dd79a6/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-f4ac9ac0bfa98b29.arrow
Loading cached processed dataset at /home/ayushthakur/.cache/huggingface/datasets/json/default-8dfa26f346d28263/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-5821162798f4af06.arrow


In [79]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, return_tensors="tf")
trainloader = train_dataset.to_tf_dataset(
#             columns=["choice", "input_ids", "attention_mask"],
#             label_cols=["labels"],
            batch_size=16,
            shuffle=True,
            collate_fn=data_collator,
            prefetch=False
        )

In [80]:
from transformers.tf_utils import shape_list

In [81]:
def _shift_right(input_ids):
    decoder_start_token_id = 0
    pad_token_id = 0

    assert decoder_start_token_id is not None, (
        "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the"
        " pad_token_id. See T5 docs for more information"
    )

    start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
    start_tokens = tf.cast(start_tokens, input_ids.dtype)  # Ensure compatible dtypes for concatenation
    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = tf.where(
        shifted_input_ids == -100,
        tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
        shifted_input_ids,
    )

    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(
        shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
    )

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids

In [82]:
trainloader

<MapDataset element_spec={'input_ids': TensorSpec(shape=(None, None), dtype=tf.int64, name=None), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int64, name=None), 'labels': TensorSpec(shape=(None, None), dtype=tf.int64, name=None), 'choice': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>

In [83]:
def parse_data(inputs):
    # This will be used as decoder_input_ids
    shift_right_labels = _shift_right(inputs["labels"])
    inputs["decoder_input_ids"] = shift_right_labels
    inputs.pop("labels")

    # This label is for calculating the accuracy of the reward model
    labels = inputs["choice"]
    inputs.pop("choice")

    return inputs, labels

In [84]:
trainloader = trainloader.map(parse_data, num_parallel_calls=tf.data.AUTOTUNE)
trainloader = trainloader.prefetch(tf.data.AUTOTUNE)

In [94]:
sample = next(iter(trainloader))
sample

({'input_ids': <tf.Tensor: shape=(16, 506), dtype=int64, numpy=
  array([[14865,   853,    99, ...,     0,     0,     0],
         [14865,   853,    99, ...,     0,     0,     0],
         [14865,   853,    99, ...,     0,     0,     0],
         ...,
         [14865,   853,    99, ...,     0,     0,     0],
         [14865,   853,    99, ...,     0,     0,     0],
         [14865,   853,    99, ...,     0,     0,     0]])>,
  'attention_mask': <tf.Tensor: shape=(16, 506), dtype=int64, numpy=
  array([[1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 0, 0, 0],
         ...,
         [1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 0, 0, 0]])>,
  'decoder_input_ids': <tf.Tensor: shape=(16, 3), dtype=int64, numpy=
  array([[  0, 209,   1],
         [  0, 209,   1],
         [  0,   3, 632],
         [  0, 209,   1],
         [  0,   3, 632],
         [  0, 209,   1],
         [  0, 209,   1],
         [  0,   3, 632],
   

In [96]:
class T5RewardModel(tf.keras.Model):
    def __init__(self):
        super(T5RewardModel, self).__init__()
        config = T5Config()
        self.t5_without_lm_head = TFT5Model(config)
        self.reward_head = tf.keras.layers.Dense(1, use_bias=False, activation="sigmoid")

    def call(self, inputs, training=False):
        sequence_output = self.t5_without_lm_head(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            decoder_input_ids=inputs["decoder_input_ids"],
            training=training).last_hidden_state

        reward_output = self.reward_head(sequence_output[:,0,:]) ## extract the 1st token's embeddings

        return reward_output

In [102]:
reward_model = T5RewardModel()

In [98]:
reward_model(sample[0])

<tf.Tensor: shape=(16, 1), dtype=float32, numpy=
array([[0.19765256],
       [0.19703884],
       [0.19661598],
       [0.1970882 ],
       [0.19656728],
       [0.196947  ],
       [0.19677263],
       [0.197112  ],
       [0.19760194],
       [0.19729137],
       [0.19713765],
       [0.19740523],
       [0.19686896],
       [0.19787773],
       [0.19810855],
       [0.19713028]], dtype=float32)>

In [103]:
reward_model.compile("adam", loss="binary_crossentropy", metrics=["accuracy"])

In [105]:
reward_model.fit(trainloader, epochs=1)