In [7]:
import argparse
import logging
import os
import torch

import numpy as np
from datasets import load_from_disk
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig,
    AutoModel,
    AutoConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


In [13]:
from comet import download_model, load_from_checkpoint

# Import CometKiwi Model
model_path = download_model("Unbabel/wmt22-cometkiwi-da")
model = load_from_checkpoint(model_path)

Fetching 5 files: 100%|██████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 819.58it/s]
INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.8.2 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../.cache/huggingface/hub/models--Unbabel--wmt22-cometkiwi-da/snapshots/b3a8aea5a5fc22db68a554b92b3d96eb6ea75cc9/checkpoints/model.ckpt`
/fs/classhomes/fall2024/cmsc723/c7230018/miniconda3/lib/python3.12/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']


In [19]:
model = prepare_model_for_kbit_training(model)

In [77]:
print_trainable_parameters(model)

trainable params: 2359296 || all params: 567496731 || trainable%: 0.4157373727673508


In [None]:
def replace_qlora(module, name="<ROOT>"):
    '''
    Replace Linear layers with LoRALinear layers, recursively.
    '''
    for attr_str, _ in module.named_children():
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear and "lora" not in attr_str:
            #print('replacing: ', name, attr_str)
            print(target_attr)
            
    for name, immediate_child_module in module.named_children():
        replace_qlora(immediate_child_module, name)

replace_qlora(model)

In [78]:
config = LoraConfig(
        r=16, 
        target_modules = ['query','key','value']
    )

model = get_peft_model(model, config)

In [74]:
model

PeftModel(
  (base_model): LoraModel(
    (model): PeftModel(
      (base_model): LoraModel(
        (model): UnifiedMetric(
          (encoder): XLMREncoder(
            (model): XLMRobertaModel(
              (embeddings): XLMRobertaEmbeddings(
                (word_embeddings): Embedding(250002, 1024, padding_idx=1)
                (position_embeddings): Embedding(514, 1024, padding_idx=1)
                (token_type_embeddings): Embedding(1, 1024)
                (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (encoder): XLMRobertaEncoder(
                (layer): ModuleList(
                  (0-23): 24 x XLMRobertaLayer(
                    (attention): XLMRobertaAttention(
                      (self): XLMRobertaSdpaSelfAttention(
                        (query): lora.Linear(
                          (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
    

In [79]:
data = [
    {
        "src": "The output signal provides constant sync so the display never glitches.",
        "mt": "Das Ausgangssignal bietet eine konstante Synchronisation, so dass die Anzeige nie stört."
    },
    {
        "src": "Kroužek ilustrace je určen všem milovníkům umění ve věku od 10 do 15 let.",
        "mt": "Кільце ілюстрації призначене для всіх любителів мистецтва у віці від 10 до 15 років."
    },
    {
        "src": "Mandela then became South Africa's first black president after his African National Congress party won the 1994 election.",
        "mt": "その後、1994年の選挙でアフリカ国民会議派が勝利し、南アフリカ初の黒人大統領となった。"
    }
]
model_output = model.predict(data, batch_size=8, gpus=1)
print (model_output)

/fs/classhomes/fall2024/cmsc723/c7230018/miniconda3/lib/python3.12/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /fs/classhomes/fall2024/cmsc723/c7230018/miniconda3/ ...
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_mat

Prediction({'scores': [0.833168625831604, 0.7671145796775818, 0.8827177882194519], 'system_score': 0.8276669979095459})
