In [1]:
# use galacitca 125m model: facebook/galactica-125m from huggingface
# add valuehead to the model so that it can be used as value model in ppo
# shall be callable like this: lm_logits, val = model(input_ids=model_input_ids, attention_mask=model_attention_mask)

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class ValueHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states):
        return self.dense(hidden_states).squeeze(-1)
    
class GalacticaValueModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.value_head = ValueHead(self.model.config)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        lm_logits = outputs.logits
        val = self.value_head(outputs.hidden_states[-1])
        return lm_logits, val
    
    def save(self, path):
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        self.model.load_state_dict(torch.load(path))

    def save_pretrained(self, path):
        self.model.save_pretrained(path)

    def from_pretrained(self, path):
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.value_head = ValueHead(self.model.config)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model = GalacticaValueModel("facebook/galactica-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")

# test
input_text = "The quick brown fox jumps over the lazy dog"
model_input = tokenizer(input_text, return_tensors="pt")
lm_logits, val = model(input_ids=model_input.input_ids, attention_mask=model_input.attention_mask)
print(lm_logits.shape, val.shape)
print(lm_logits, val)

torch.Size([1, 9, 50000]) torch.Size([1, 9])
tensor([[[ -1.7453,  -2.2069,  -0.3622,  ...,  -1.1785,  -0.7134,  -1.5329],
         [ -8.2540,  -9.6292,   1.5110,  ...,  -4.4460,  -6.1091,  -6.9986],
         [ -8.2361,  -8.9946,  -0.1690,  ...,  -3.1719,  -5.1670,  -6.3849],
         ...,
         [ -7.1367, -10.2960,  -1.6381,  ...,  -2.0335,  -6.2150,  -6.5535],
         [ -7.7017, -11.1864,  -1.3639,  ...,  -1.4243,  -6.3104,  -7.0003],
         [ -7.8812, -11.0366,  -0.6970,  ...,  -0.9882,  -6.2138,  -8.3551]]],
       grad_fn=<UnsafeViewBackward0>) tensor([[ 0.4318,  0.6611,  0.7635,  1.3717, -0.0065,  0.2449,  1.0526,  1.0531,
          0.7288]], grad_fn=<SqueezeBackward1>)


In [38]:
# generate some text
input_text = "Once upon a time"
model_input = tokenizer(input_text, return_tensors="pt")

# generate text
output = model.model.generate(input_ids=model_input.input_ids, attention_mask=model_input.attention_mask, temperature=.5, do_sample=True, max_length=100, num_return_sequences=1)
print(tokenizer.decode(output[0], skip_special_tokens=True))


Once upon a time of intense stress, the body responds with a series of physiological changes. The main change in the body is the activation of the sympathetic nervous system. The sympathetic nervous system is activated by the release of catecholamines and neurotransmitters from the adrenal gland. The main effect of the sympathetic nervous system is the activation of the heart. The heart is a highly active organ, and its heart rate is controlled by the sympathetic nervous system. The heart is an active organ, and its heart rate is controlled


In [4]:
model.save_pretrained("galactica-125m-value-head")
tokenizer.save_pretrained("galactica-125m-value-head")

# upload the model to huggingface
# name jeggers/galactica-125m-value-head

model.model.push_to_hub("jeggers/galactica-125m-value-head")


pytorch_model.bin:  20%|█▉        | 98.8M/500M [01:14<04:46, 1.40MB/s]  

KeyboardInterrupt: 

pytorch_model.bin:  20%|█▉        | 99.8M/500M [01:30<04:46, 1.40MB/s]