In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
from datamodule import WikiTextV2Datamodule
import os
from constants import TARGET_MODEL, DRAFT_MODEL

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = "cuda:4"

target_model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
draft_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)
target_model.eval()
draft_model.train()

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,)

Попробуем обучиться на 1 батче:

In [4]:
import torch
import torch.nn.functional as F

datamodule = WikiTextV2Datamodule(
    min_len=5,  
    max_len=12,
    target_model=target_model,
    device=device,
    batch_size=1, 
)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
optimizer = AdamW(
    [p for p in draft_model.parameters() if p.requires_grad],
    lr=0.001,
    weight_decay=0.001
)

epochs = 5
batch = next(iter(train_loader))
for i in range(1000):
    optimizer.zero_grad()
        
    input_ids = batch["input_ids"]
    target_scores = batch["scores"]
    
    draft_outputs = draft_model(input_ids)
    draft_logits = draft_outputs.logits[:, -1, :]
    
    log_draft_probs = F.log_softmax(draft_logits, dim=-1)
    target_probs = F.softmax(target_scores, dim=-1)    
    
    loss = F.kl_div(log_draft_probs, target_probs, reduction='batchmean')
    
    loss.backward()
    optimizer.step()
    
    if (i + 1) % 5 == 0:
        print(loss.item())   

Loaded preprocessed data from cache




0.27268749475479126
0.12750867009162903
0.05827204883098602
0.047447118908166885
0.03230045363306999
0.02384386770427227
0.020718352869153023
0.016534434631466866
0.014508516527712345
0.009837430901825428
0.011079397983849049
0.013059280812740326
0.009622174315154552
0.01495021115988493
0.009480783715844154
0.011719456873834133
0.017221765592694283
0.009780880995094776
0.010229299776256084
0.008598558604717255
0.01047136727720499
0.008495763875544071
0.010267514735460281
0.00801050290465355
0.009800273925065994
0.012494172900915146
0.011017469689249992
0.00961147528141737
0.01980660855770111
0.011592146009206772
0.009667702950537205
0.01057096105068922
0.015266941860318184
0.010866181924939156
0.01014289353042841
0.01787448860704899
0.009766926057636738
0.01346069946885109
0.009608574211597443
0.010621866211295128
0.010807722806930542
0.006432997062802315
0.009688368067145348
0.008376557379961014
0.007480780594050884
0.007161600515246391
0.011216191574931145
0.015565186738967896
0.0062

In [None]:
from finetune_draft_model import DraftModelFinetuner 
import lightning as L

datamodule = WikiTextV2Datamodule(
    min_len=5,  
    max_len=12,
    target_model=target_model,
    device=device,
    batch_size=1, 
)
datamodule.setup(stage="fit")
torch.set_float32_matmul_precision('medium')

trainer = L.Trainer(
    accelerator="gpu", max_epochs=3, limit_train_batches=None, logger=False, devices=[4] # TensorBoardLogger(save_dir=".")
)

finetuner = DraftModelFinetuner()
trainer.fit(model=finetuner, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loaded preprocessed data from cache


You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6]


Loaded preprocessed data from cache



  | Name         | Type           | Params | Mode
-------------------------------------------------------
0 | draft_model  | OPTForCausalLM | 125 M  | eval
1 | target_model | OPTForCausalLM | 1.3 B  | eval
-------------------------------------------------------
125 M     Trainable params
1.3 B     Non-trainable params
1.4 B     Total params
5,763.990 Total estimated model params size (MB)
0         Modules in train mode
412       Modules in eval mode
/home/amirelkanov/Fabula/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [None]:
torch.save(finetuner.state_dict(), 'model.pt')