In [2]:
import torch
from torch import nn
from transformers import LEDTokenizer, LEDModel

tokenizer = LEDTokenizer.from_pretrained("allenai/led-large-16384-arxiv")

In [2]:
import pickle
from copy import deepcopy

with open("./_FRIDGE/_aug/pubmed_test_aug.pickle", "rb") as f:
    dataset = pickle.load(f)

data = deepcopy(dataset[:20])

del dataset

In [3]:
def generate_global_attention_mask(tokenizer, input_ids):
    mask = torch.zeros_like(input_ids)
    mask[((input_ids == tokenizer.bos_token_id) | (input_ids == tokenizer.eos_token_id)).nonzero(as_tuple=True)] = 1
    return mask

In [4]:
class RewardModel(nn.Module):
    def __init__(self, model="allenai/led-large-16384-arxiv", head_layer_size=32):
        super(RewardModel, self).__init__()
        self.led_encoder = LEDModel.from_pretrained(model).get_encoder()
        self._encoder_output_size = self.led_encoder.layernorm_embedding.weight.shape[0]
        self.head = nn.Sequential(
            nn.Linear(self._encoder_output_size, head_layer_size, bias=False),
            nn.ReLU(),
            nn.Linear(head_layer_size, 1, bias=False)
        )

    def forward(self, input_ids, global_attention_mask):
        hidden_state = self.led_encoder(input_ids, global_attention_mask=global_attention_mask).last_hidden_state
        output = hidden_state.view(hidden_state.size(0), -1, hidden_state.size(-1))[:, -1, :]
        output = self.head(output)
        return output.squeeze()

In [5]:
test = RewardModel()

Some weights of the model checkpoint at allenai/led-large-16384-arxiv were not used when initializing LEDModel: ['lm_head.weight', 'final_logits_bias']
- This IS expected if you are initializing LEDModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LEDModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [1]:
class Criterion():
    def __init__(self):
        self.logsig = nn.LogSigmoid()
    def loss(self, output):
        return -self.logsig(output[0] - output[1])

In [4]:
from torch import optim

optimizer = optim.SGD(test.parameters(), lr=0.001)
criterion = Criterion()

In [8]:
from tqdm import tqdm

for d in tqdm(data):
    art = tokenizer.batch_decode(d["article"].unsqueeze(0))[0]
    abt = tokenizer.batch_decode(d["abstract"].unsqueeze(0))[0]
    nos = tokenizer.batch_decode(d["noised"].unsqueeze(0))[0]

    merged_1 = art+" TL;DR: "+abt
    merged_0 = art+" TL:DR: "+nos
    
    put = tokenizer.batch_encode_plus([merged_1, merged_0], return_tensors="pt", padding=True).input_ids[:, 1:-1]
    att = generate_global_attention_mask(tokenizer, put)

    optimizer.zero_grad()

    res = test(put, att)
    loss = criterion.loss(res)
    loss.backward()
    optimizer.step()

    print(loss)

  5%|▌         | 1/20 [02:31<47:59, 151.53s/it]

tensor(0.6836, grad_fn=<NegBackward0>)


 10%|█         | 2/20 [04:55<44:08, 147.14s/it]

tensor(0.6833, grad_fn=<NegBackward0>)


: 

: 

In [6]:
from torch import tensor

In [19]:
criterion.loss(tensor([48792., 2384.]))

tensor(-0.)