In [1]:
import torch
import transformers


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
model_name = "Enoch/llama-7b-hf"
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    low_cpu_mem_usage=True, #загружает веса по кусочкам, минимизируя потребление RAM при старте
    offload_state_dict=True,# Если не хватит GPU, то часть модели перенесет на CPU
    load_in_4bit=True, #Квантование до 4 бит с помощью bitsandbytes
    torch_dtype=torch.float32 # нормировка по слоям и активации
)

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Loading checkpoint shards: 100%|██████████| 33/33 [00:10<00:00,  3.19it/s]


In [5]:
# замораживаем веса модели
for param in model.parameters():
    param.requires_grad = False

In [6]:
"""
позволяет экономить память
почти не сохраняет активации в памяти, а пересчитывает их заново
"""
model.gradient_checkpointing_enable()

In [7]:
model.enable_input_require_grads()# отключаем градиенты для входных данных

In [8]:
prompt = "A quick brown fox"
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)


In [9]:
batch

{'input_ids': tensor([[    1,   319,  4996, 17354,  1701, 29916]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [10]:
# # Параметры генерации
# max_length = 100         # Максимальная длина генерируемой последовательности
# num_beams = 5            # Количество "лучей" (beam search)
# early_stopping = True    # Остановка, если все лучи достигли EOS
# temperature=10
# # Генерация с использованием beam search
# outputs = model.generate(
#     temperature=temperature,
#     input_ids=batch['input_ids'],
#     attention_mask=batch['attention_mask'],
#     max_length=max_length,
#     num_beams=num_beams,
#     early_stopping=early_stopping,
#     eos_token_id=tokenizer.eos_token_id  # Указываем токен конца последовательности
# )




In [11]:
# # Декодирование и вывод результата
# print(tokenizer.decode(outputs[0], skip_special_tokens=False))

In [12]:
# while batch['input_ids'][0].cpu().tolist()[-1] != tokenizer.eos_token_id:
for _ in range(10):
    logits = model(**batch).logits
    last_state_for_first_batch = logits[0, -1]
    greedy_token = last_state_for_first_batch.argmax(-1)
    new_token = greedy_token.reshape(1,1)
    batch['input_ids'] = torch.cat([batch['input_ids'], new_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(new_token)], dim=-1)


tokenizer.decode(batch['input_ids'][0].cpu().tolist())

'<s>A quick brown fox jumps over the lazy dog.\nA quick'

In [13]:
from torch.nn import functional as F
import torch.nn as nn

In [14]:
the_truth = "A quick brown fox jumps over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
output = model(**batch)

next_word_logits = output.logits[:,:-1,:]
true_next_tokens = batch['input_ids'][:,1:]
loss = F.cross_entropy(next_word_logits.flatten(0,1), true_next_tokens.flatten(0,1))

loss

tensor(2.8630, device='cuda:0', grad_fn=<NllLossBackward0>)

In [15]:
class WordEmbeddingsWithLearnedPrompts(nn.Module):
    def __init__(self, word_embeddings: nn.Embedding, num_prompts: int):
        super().__init__()
        self.original_word_embeddings = word_embeddings
        self.num_prompts = num_prompts
        self.learneble_prompts = nn.Parameter(
            data=torch.randn(1, self.num_prompts, self.original_word_embeddings.embedding_dim),
            requires_grad=True
        )
        
    def forward(self, input_ids):
        assert input_ids.dtype == torch.int64
        assert input_ids.shape[1] > self.num_prompts 
        assert torch.all(input_ids[:, :self.num_prompts] == tokenizer.pad_token_id).item(), "Не забудьте добавть паддинги в начало последовательности для подстановки обучаемых параметров"
        
        original_embeddings = self.original_word_embeddings(input_ids)# получилось что [1, num_prompts + num_input_tokens, emb_dim]
        embedded_inputs_with_prompts = torch.cat(
            [self.learneble_prompts.expand(input_ids.shape[0], -1, -1),
            original_embeddings[:,self.num_prompts:]],
            dim=1
        )
        
        return embedded_inputs_with_prompts

In [16]:
num_prompts = 16
test_emb_layer = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)
test_input_ids = tokenizer("a cat says on a may", return_tensors='pt')["input_ids"].to(device)

space_for_prompts = torch.full(
    [
        len(test_input_ids), num_prompts
    ],
    fill_value=tokenizer.pad_token_id,
    dtype=torch.int64,
    device=device
)
test_inputs_with_prompts = torch.cat([space_for_prompts, test_input_ids],dim=1)
with torch.amp.autocast('cuda'):
    test_prompt_embeddings = test_emb_layer(test_inputs_with_prompts)

assert test_prompt_embeddings.shape[:2] == test_inputs_with_prompts.shape
assert test_prompt_embeddings.shape[-1] == model.config.hidden_size
assert torch.allclose(test_prompt_embeddings[:,:num_prompts], test_emb_layer.learneble_prompts.float())
assert torch.allclose(test_prompt_embeddings[:,num_prompts:], model.model.embed_tokens(test_input_ids).float())
print("Looks legit!")

Looks legit!


In [17]:
assert isinstance(model.model.embed_tokens, nn.Embedding), "Вы уже заменили Embedding слой"

In [18]:
model.model.embed_tokens = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)

opt = torch.optim.Adam([model.model.embed_tokens.learneble_prompts], lr=0.01)

In [19]:
the_truth = ["A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"]
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)


space_for_prompts = torch.full(
    size=[
        batch['input_ids'].size()[0],
        num_prompts
    ],
    fill_value=tokenizer.eos_token_id,
    dtype=torch.int64,
    device=device
)

input_ids_with_padding = torch.cat(
    [
        space_for_prompts,
        batch['input_ids']
    ],
    dim=1
)

batch['input_ids'] = input_ids_with_padding
batch['attention_mask'] = torch.cat(
    [
        torch.ones_like(space_for_prompts),
        batch['attention_mask']
    ],
    dim=1
)


EPOCHS = 100
model.train()
for epoch in range(EPOCHS):
    output = model(**batch)
    next_word_logits = output.logits[:, num_prompts : -1, :]
    true_next_token = batch['input_ids'][:,num_prompts + 1:]
    loss = F.cross_entropy(next_word_logits.flatten(0,1), true_next_token.flatten(0,1))
    print(f"Epoch: {epoch}, Loss: {loss}")
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    
assert loss.item() <= 0.1
print("Good job!")

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Epoch: 0, Loss: 7.224433422088623


  return fn(*args, **kwargs)


Epoch: 1, Loss: 6.526238918304443
Epoch: 2, Loss: 6.100164890289307
Epoch: 3, Loss: 5.700748920440674
Epoch: 4, Loss: 5.346858501434326
Epoch: 5, Loss: 5.039243698120117
Epoch: 6, Loss: 4.760209560394287
Epoch: 7, Loss: 4.498955726623535
Epoch: 8, Loss: 4.2563276290893555
Epoch: 9, Loss: 4.033092975616455
Epoch: 10, Loss: 3.8236215114593506
Epoch: 11, Loss: 3.620473623275757
Epoch: 12, Loss: 3.4202325344085693
Epoch: 13, Loss: 3.2247185707092285
Epoch: 14, Loss: 3.0371389389038086
Epoch: 15, Loss: 2.85799241065979
Epoch: 16, Loss: 2.685307741165161
Epoch: 17, Loss: 2.5175652503967285
Epoch: 18, Loss: 2.355043888092041
Epoch: 19, Loss: 2.1985714435577393
Epoch: 20, Loss: 2.0475263595581055
Epoch: 21, Loss: 1.900288462638855
Epoch: 22, Loss: 1.757215976715088
Epoch: 23, Loss: 1.6206257343292236
Epoch: 24, Loss: 1.4905178546905518
Epoch: 25, Loss: 1.3647634983062744
Epoch: 26, Loss: 1.2436720132827759
Epoch: 27, Loss: 1.1281546354293823
Epoch: 28, Loss: 1.0155348777770996
Epoch: 29, Loss:

In [20]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): WordEmbeddingsWithLearnedPrompts(
      (original_word_embeddings): Embedding(32000, 4096, padding_idx=0)
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): 

In [25]:
text = "A quick brown fox"
batch = tokenizer(text, return_tensors='pt', return_token_type_ids=False).to(device)
batch['input_ids'].size(), batch

(torch.Size([1, 6]),
 {'input_ids': tensor([[    1,   319,  4996, 17354,  1701, 29916]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')})

In [26]:
space_for_prompts = torch.full(
    size=[
        batch['input_ids'].size()[0],
        num_prompts
    ],
    fill_value=tokenizer.eos_token_id,
    dtype=torch.int64,
    device=device
)

input_ids_with_padding = torch.cat(
    [
        space_for_prompts,
        batch['input_ids']
    ],
    dim=1
)

batch['input_ids'] = input_ids_with_padding
batch['attention_mask'] = torch.cat(
    [
        torch.ones_like(space_for_prompts),
        batch['attention_mask']
    ],
    dim=1
)


In [27]:
batch['input_ids'].size(), batch['attention_mask'].size()

(torch.Size([1, 22]), torch.Size([1, 22]))

In [28]:
for _ in range(30):
    output = model(**batch)
    logits = output.logits
    next_token = output.logits[0][-1].argmax(-1)
    batch["input_ids"] = torch.cat([batch["input_ids"], next_token.reshape(1,1)], dim=1)
    batch["attention_mask"] = torch.cat([batch["attention_mask"], torch.tensor(1, device=device).reshape(1,1)], dim=1)
    

In [29]:
print(tokenizer.decode(batch["input_ids"][0]), end=' ')

</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s>A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway! Besides that dog deserved it anyway! That dog deserved it 