In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from classifier import GPT2ForSequenceClassification
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sentiment_model = GPT2ForSequenceClassification.from_pretrained("distilgpt2").to(device)
gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2").to(device)
gpt2_model_ref = GPT2LMHeadModel.from_pretrained("distilgpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

sentiment_model.config.pad_token_id = tokenizer.eos_token_id
gpt2_model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

Some weights of the model checkpoint at distilgpt2 were not used when initializing GPT2ForSequenceClassification: ['lm_head.weight']
- This IS expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs):
        return torch.nn.functional.one_hot(inputs.argmax(dim=2), num_classes = inputs.shape[2]).float()

    @staticmethod
    def backward(ctx, grad_output):
        return torch.nn.functional.hardtanh(grad_output)

    
class LogitsToEmbeds(nn.Module):
    def __init__(self, embedding_weight):
        super(LogitsToEmbeds, self).__init__()
        
        self.fake_embedding = nn.Linear(embedding_weight.weight.size(0), embedding_weight.weight.size(1), bias = False)
        self.fake_embedding.weight = torch.nn.Parameter(embedding_weight.weight.t())
        self.softmax = nn.Softmax(dim = 2)
        
    def forward(self, x):
        x = self.softmax(x)
        x = STEFunction.apply(x)
        x = self.fake_embedding(x)
        return x

class PreferenceLearner(nn.Module):
    def __init__(self, base, reference, reward):
        super(PreferenceLearner, self).__init__()
        self.model = base
        self.ref = reference
        self.reward = reward
        self.toEmbeds = LogitsToEmbeds(sentiment_model.get_input_embeddings())
        
    def logprobs_from_logits(self, logits, labels):
        logp = F.log_softmax(logits, dim=2)
        return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
        
    def forward(self, input_ids, attention_mask, length):
        
        ref_logits = self.ref(input_ids=input_ids, 
                              attention_mask=attention_mask)['logits']
        logits = self.model(input_ids=input_ids, 
                            attention_mask=attention_mask)['logits']
        
        ref_logprobs = self.logprobs_from_logits(ref_logits, input_ids)
        logprobs = self.logprobs_from_logits(logits, input_ids)
        
        kl = logprobs - ref_logprobs
        
        embeds = self.toEmbeds(logits)
        
        score = self.reward(inputs_embeds = embeds, 
                            sequence_lengths = length-1)['logits']
        
        return score, kl.mean(1)
        
    
model = PreferenceLearner(gpt2_model, 
                          gpt2_model_ref, 
                          sentiment_model)

In [44]:
inputs = tokenizer(['hello w d', 'hello f f f'], return_length=True, padding=True, return_tensors='pt')


(tensor([[-1.6189],
         [-2.5994]], grad_fn=<IndexBackward>),
 tensor([0., 0.], grad_fn=<MeanBackward1>))