In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
from datamodule import WikiTextV2Datamodule
import os
import torch
import torch.nn.functional as F

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = "cuda:4"

target_model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(device)
draft_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(device)
target_model.eval()
draft_model.train()

In [None]:
datamodule = WikiTextV2Datamodule(
    min_len=5,  
    max_len=12,
    target_model=target_model,
    device=device,
    batch_size=1, 
)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
optimizer = AdamW(
    [p for p in draft_model.parameters() if p.requires_grad],
    lr=0.001,
    weight_decay=0.001
)

epochs = 5
for epoch in range(5):
    running_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        
        input_ids = batch["input_ids"]
        target_scores = batch["scores"]
        
        draft_outputs = draft_model(input_ids)
        draft_logits = draft_outputs.logits[:, -1, :]
        
        log_draft_probs = F.log_softmax(draft_logits, dim=-1)
        target_probs = F.softmax(target_scores, dim=-1)    
        
        loss = F.kl_div(log_draft_probs, target_probs, reduction='batchmean')
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")    

Loaded preprocessed data from cache




KeyboardInterrupt: 

In [None]:
import torch
import lightning as L
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from torch.nn import functional as F
import argparse
from constants import TARGET_MODEL, DRAFT_MODEL, EPS
from helpers import get_distribution


class DraftModelFinetuner(L.LightningModule):
    def __init__(
        self,
        draft_model_name=DRAFT_MODEL,
        target_model_name=TARGET_MODEL,
        learning_rate=5e-5,
        weight_decay=0.01,
    ):
        super().__init__()
        
        self.save_hyperparameters()
    
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            draft_model_name, 
            torch_dtype=torch.float16
        )
        
        self.target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name, 
            torch_dtype=torch.float16
        )
        
        for param in self.target_model.parameters():
            param.requires_grad = False
        
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        target_scores = batch["scores"]
        
        draft_outputs = draft_model(input_ids)
        draft_logits = draft_outputs.logits[:, -1, :]
        
        log_draft_probs = F.log_softmax(draft_logits, dim=-1)
        target_probs = F.softmax(target_scores, dim=-1)    
        
        loss = F.kl_div(log_draft_probs, target_probs, reduction='batchmean')
        self.log("loss", loss, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = AdamW(
            [p for p in self.draft_model.parameters() if p.requires_grad],
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=10,
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler
        }

In [18]:
torch.cuda.set_device(4)
torch.cuda.current_device()

4

In [21]:
datamodule = WikiTextV2Datamodule(
    min_len=5,  
    max_len=12,
    target_model=target_model,
    device=device,
    batch_size=1, 
)
datamodule.setup(stage="fit")
torch.set_float32_matmul_precision('medium')

trainer = L.Trainer(
    accelerator="gpu", max_epochs=3, limit_train_batches=None, logger=False, devices=[4] # TensorBoardLogger(save_dir=".")
)

finetuner = DraftModelFinetuner()
trainer.fit(model=finetuner, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loaded preprocessed data from cache


You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6]


Loaded preprocessed data from cache



  | Name         | Type           | Params | Mode
-------------------------------------------------------
0 | draft_model  | OPTForCausalLM | 125 M  | eval
1 | target_model | OPTForCausalLM | 1.3 B  | eval
-------------------------------------------------------
125 M     Trainable params
1.3 B     Non-trainable params
1.4 B     Total params
5,763.990 Total estimated model params size (MB)
0         Modules in train mode
412       Modules in eval mode
/home/amirelkanov/Fabula/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [None]:
torch.save(finetuner.state_dict(), 'model.pt')