In [1]:
# Usual imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# Bare expression: shows the installed PyTorch version string as the cell output.
torch.__version__

'2.0.1+cu117'

In [2]:
# Load the causal LM checkpoint from the current working directory ('.') with
# bfloat16 weights, then move the whole module to the CUDA device.
model = AutoModelForCausalLM.from_pretrained('.', torch_dtype=torch.bfloat16).cuda()

In [3]:
# Load the tokenizer from the same directory ('.') the model was loaded from.
tokenizer = AutoTokenizer.from_pretrained('.')

Loading tokenizer from the cache


In [4]:
# Second name for the same module object (no copy is made); `base_model` is
# what gets wrapped by BiasInjector further down.
base_model = model
# Bare expression: displays the module tree as the cell output.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 3200, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (k_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (v_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (o_proj): Linear(in_features=3200, out_features=3200, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3200, out_features=8640, bias=False)
          (down_proj): Linear(in_features=8640, out_features=3200, bias=False)
          (up_proj): Linear(in_features=3200, out_features=8640, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm(

In [5]:
# Freeze every weight of the loaded network so that no gradients are tracked
# for it. The trailing semicolon keeps the returned module from being echoed
# as cell output.
model.requires_grad_(False);

In [6]:
class BiasInjector(nn.Module):
    """Wrap a causal LM and steer it through two small trainable tensors.

    The wrapper owns:

    * ``command_token`` -- one learned embedding of shape ``(1, 1, hidden)``
      that is prepended to every input sequence, and
    * ``segment_bias`` -- one learned vector of shape ``(hidden,)`` that is
      added to every position (including the prepended one).

    Args:
        base_model: a model exposing ``config.hidden_size``,
            ``model.embed_tokens`` and a ``forward`` that accepts
            ``inputs_embeds`` and ``attention_mask`` keyword arguments.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        # Small random init; creation order is kept (bias first, then token)
        # so a fixed RNG seed yields the same starting values as before.
        self.segment_bias = nn.Parameter(torch.randn(hidden_size) * 0.01)
        self.command_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.05)

    def forward(self, input_ids, attention_mask=None):
        """Embed ``input_ids``, inject the learned tensors and run the base model.

        Args:
            input_ids: integer tensor of shape ``(batch, seq)``.
            attention_mask: optional ``(batch, seq)`` mask for ``input_ids``.
                It is extended by one leading column of ones so that it
                covers the prepended command token.

        Returns:
            Whatever ``base_model`` returns. The sequence dimension seen by
            the base model is ``seq + 1`` because of the prepended token.
        """
        tokens = self.base_model.model.embed_tokens(input_ids)
        # Repeat the single command token for every row of the batch;
        # concatenating the raw (1, 1, hidden) tensor only works for batch == 1.
        command = self.command_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((command, tokens), -2)
        tokens = tokens + self.segment_bias
        if attention_mask is not None:
            # The command token is always attended to; without this extra
            # column the mask would be one position shorter than the embeddings.
            prefix = attention_mask.new_ones((attention_mask.shape[0], 1))
            attention_mask = torch.cat((prefix, attention_mask), -1)
        return self.base_model(inputs_embeds=tokens, attention_mask=attention_mask)

# Wrap the base model, move the wrapper (and its two new parameters) to CUDA,
# and cast it to the dtype recorded in the model config.
biased = BiasInjector(base_model).cuda().to(model.config.torch_dtype)

In [7]:
@torch.no_grad()
def gen_up_to(txt, n, allow_nl=1, nl_token_id=13):
    """Greedily decode a continuation of ``txt`` with the ``biased`` model.

    Each iteration re-runs the model on the full sequence so far and appends
    the arg-max token (no sampling, no key/value cache).

    Args:
        txt: prompt text; tokenized with the notebook-level ``tokenizer``.
        n: maximum number of tokens to generate.
        allow_nl: how many newline tokens may be produced; generation stops
            right after this many have been emitted.
        nl_token_id: id treated as "newline". The default of 13 is the value
            the notebook hard-coded -- presumably the newline id of this
            tokenizer; confirm before using a different tokenizer.

    Returns:
        The decoded text of the generated tokens. The token that ends
        generation (newline or EOS) is included, because it is recorded
        before the stop conditions are checked.
    """
    model = biased
    x = tokenizer(txt, return_tensors="pt").input_ids.cuda()
    eos_token_id = tokenizer.eos_token_id
    res = []
    for _ in range(n):
        # Pull the id to the host once as a plain int. The comparisons below,
        # F.pad's scalar `value` and tokenizer.decode all want a Python number
        # rather than a 0-dim CUDA tensor.
        y = int(model(x).logits[0][-1].argmax())
        res.append(y)
        if y == nl_token_id:
            allow_nl -= 1
            if allow_nl <= 0:
                break
        if y == eos_token_id:
            break
        # Append the chosen token on the right of the (1, seq) id tensor.
        x = F.pad(x, (0, 1), value=y)

    return tokenizer.decode(res)

# Baseline sample taken before any training cell runs: at most 10 greedy
# tokens, stopping after the first newline (allow_nl defaults to 1).
gen_up_to("Q: Solve 2+x=5 for x.\nA: ", 10)



'2+x=5\n'

In [8]:
import torch.optim
# Cross-entropy between the next-token logits and the expected token id.
loss_fn = nn.CrossEntropyLoss()
# Give Adam only the tensors that can receive gradients. `biased.parameters()`
# also yields every weight of the wrapped base model, which an earlier cell
# froze with requires_grad_(False); listing them in the optimizer adds nothing
# but per-step bookkeeping over tensors that never get a gradient.
optim_fn = torch.optim.Adam([p for p in biased.parameters() if p.requires_grad])

def train_step(q, a):
    """Teacher-forced training on one (prompt, answer) pair.

    For every token of the answer, the ``biased`` model is run on the prompt
    plus all earlier answer tokens, the cross-entropy of the last position
    against the expected token is back-propagated, and the optimizer takes
    one step. There is therefore one optimizer step per answer token.

    Args:
        q: prompt text (tokenized with the tokenizer's default special tokens).
        a: answer text (tokenized with ``add_special_tokens=False``).

    Returns:
        The mean per-token loss as a Python float (0.0 for an empty answer).
    """
    context_ids = tokenizer(q, return_tensors="pt").to("cuda").input_ids
    answer_encoding = tokenizer(a, add_special_tokens=False, return_tensors="pt")
    answer_ids = answer_encoding.to("cuda").input_ids[0]

    token_losses = []
    for target in answer_ids:
        optim_fn.zero_grad()
        next_token_logits = biased(context_ids).logits[0][-1]
        loss = loss_fn(next_token_logits, target)
        loss.backward()
        optim_fn.step()
        token_losses.append(loss.item())

        # Teacher forcing: extend the context with the *expected* token,
        # not with whatever the model predicted.
        with torch.no_grad():
            context_ids = F.pad(context_ids, (0, 1), value=target)

    return sum(token_losses) / max(len(token_losses), 1)



In [13]:
from tqdm.auto import tqdm
def step():
    """Run one pass over the built-in training pairs in random order.

    Each question is rendered as ``"Q: <question>\\nA: "`` and each answer as
    a one-key JSON object ``{"response": "<answer>"}``; the pair is then
    handed to ``train_step``.

    Returns:
        The sum (not the mean) of the per-example losses returned by
        ``train_step``.
    """
    # Local import so this cell does not rely on a change to the import cell.
    import json

    train = [
        ("Solve 2+x=5 for x.", "3"),
        ("Capital of France", "Paris"),
        ("The most famous pokemon owned by Ash", "Pikachu"),
        ("Portable computer", "Laptop"),
        ("If I have a red ball, then the color of my ball is", "red"),
        ("Second color of the rainbow", "orange"),
        ("Spain can be described as", "Spain (Spanish: España, [esˈpaɲa] (listen)), or the Kingdom of Spain (Reino de España),[f] is a country primarily located in Southwestern Europe, with parts of territory in the Atlantic Ocean and across the Mediterranean Sea.[11][g] The largest part of Spain is situated on the Iberian Peninsula; its territory also includes the Canary Islands in the Atlantic Ocean, the Balearic Islands in the Mediterranean Sea, and the autonomous cities of Ceuta and Melilla in Africa. The country's mainland is bordered to the north by France, Andorra and the Bay of Biscay; to the east and south by the Mediterranean Sea and Gibraltar; and to the west by Portugal and the Atlantic Ocean. It is the largest country in Southern Europe and the second-largest and fourth-most populous in the European Union. Spain's capital and largest city is Madrid; other major urban areas include Barcelona, Valencia, Zaragoza, Seville, Málaga, Murcia, Palma de Mallorca, Las Palmas de Gran Canaria, and Bilbao."),
        ("Crown can be described as", "A crown is a traditional form of head adornment, or hat, worn by monarchs as a symbol of their power and dignity. A crown is often, by extension, a symbol of the monarch's government or items endorsed by it. The word itself is used, particularly in Commonwealth countries, as an abstract name for the monarchy itself, as distinct from the individual who inhabits it (that is, The Crown). A specific type of crown (or coronet for lower ranks of peerage) is employed in heraldry under strict rules. Indeed, some monarchies never had a physical crown, just a heraldic representation, as in the constitutional kingdom of Belgium."),
        ("Define destiny", "Destiny, sometimes also called fate (from Latin fatum 'decree, prediction, destiny, fate'), is a predetermined course of events.[1][2] It may be conceived as a predetermined future, whether in general or of an individual.")
    ]

    total_loss = 0
    # .tolist() yields plain ints, so the list is indexed with Python integers
    # rather than 0-dim tensors.
    for i in tqdm(torch.randperm(len(train)).tolist()):
        q, a = train[i]
        prompt = f"Q: {q}\nA: "
        # json.dumps escapes quotes and backslashes inside the answer, so the
        # target is always valid JSON. ensure_ascii=False keeps non-ASCII text
        # verbatim, so for the answers above the result is identical to the
        # previous hand-built f-string.
        target = json.dumps({"response": a}, ensure_ascii=False)
        total_loss += train_step(prompt, target)
    return total_loss

# Single training epoch; print the summed loss that step() reports for it.
for _epoch in tqdm(range(1)):
    epoch_loss = step()
    print(epoch_loss)



  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

4.038041547355515


In [14]:
from textwrap import wrap 
def pgen(q, n=20):
    """Generate an answer for question ``q`` and print it.

    The question is wrapped in the ``Q: ...\\nA: `` prompt format, up to ``n``
    tokens are generated with ``gen_up_to``, and the text is printed after a
    ``>>>`` marker, line-wrapped with ``textwrap.wrap`` defaults.
    """
    answer = gen_up_to(f'Q: {q}\nA: ', n)
    wrapped_lines = wrap(answer)
    print(">>>" + "\n".join(wrapped_lines))
# Questions that are not in the training list, answered with the default
# 20-token budget.
short_questions = (
    "I have 3 apples. I ate 1 apple. How many apples do I have left?",
    "Director of the movie Terminator 2",
    "Main male actor of Forest Gump",
    "Which fairytale had 7 dwarves?",
    "Which president of United States was assasninated?",
    "Name the most recent president of United States was assasninated.",
)
for question in short_questions:
    pgen(question)

# Open-ended questions get a larger budget of 100 tokens.
long_questions = (
    "Define cat",
    "What is the best smartphone?",
)
for question in long_questions:
    pgen(question, n=100)

>>>{"response": "0"}
>>>{"response": "James Cameron"}
>>>{"response": "Tom Hanks"}
>>>{"response": "Snow White"}
>>>{"response": "Abraham Lincoln"}
>>>{"response": "Lyndon Baines Johnson"}
>>>{"response": "A cat is a small, furry animal with a long tail and a
bushy head. They are very cute and cuddly, and are often kept as
pets."}
>>>{"response": "Samsung Galaxy S20."}


In [18]:
# Bare expression: displays the learned additive bias vector as the cell output.
biased.segment_bias

Parameter containing:
tensor([-0.0125, -0.0045, -0.0062,  ...,  0.0028,  0.0267, -0.0164],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)

In [20]:

# Print the element count of every parameter that still requires gradients.
trainable_sizes = (param.numel() for param in biased.parameters() if param.requires_grad)
for size in trainable_sizes:
    print(size)


3200
3200
