# Models

> Model wrappers and factory functions for loading Hugging Face causal language models, optionally adapted with LoRA via PEFT.


In [None]:
#| default_exp models

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore.utils import *
from transformers import AutoModelForCausalLM
from torch import nn
from peft import *
import torch

In [None]:
#| export
class CausalLMModel(torch.nn.Module):
    """Thin `torch.nn.Module` wrapper around a pretrained causal language model.

    The underlying model is loaded with `AutoModelForCausalLM.from_pretrained`
    and stored as the submodule `self.model`.
    """

    def __init__(self, model_name_or_path):
        super().__init__()
        # Load the pretrained causal LM identified by `model_name_or_path`.
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    def forward(self, input_id):
        """Run the wrapped model on the token ids `input_id`.

        Returns the wrapped model's output in tuple form, because
        `return_dict=False` is requested.
        """
        return self.model(input_ids=input_id, return_dict=False)

In [None]:
#| export
def delegate(to, *methods):
    """Build a class decorator that forwards the named methods to an attribute.

    For every name in `methods`, the decorated class gains a method that, at
    call time, fetches ``getattr(self, to)`` and calls that object's method of
    the same name with the arguments it received, returning its result.

    Args:
        to: Name of the instance attribute holding the delegate object.
        *methods: Names of the methods to forward.

    Returns:
        A decorator that installs the forwarding methods on a class and
        returns that same class.
    """
    def dec(klass):
        def create_delegator(method):
            # A factory function binds `method` per iteration; a closure
            # created directly in the loop would late-bind to the last name.
            def delegator(self, *args, **kwargs):
                target = getattr(self, to)
                return getattr(target, method)(*args, **kwargs)
            # Give each generated method its real identity so reprs,
            # tracebacks and help() show e.g. `Klass.save` rather than the
            # generic `delegator`.
            delegator.__name__ = method
            delegator.__qualname__ = f"{klass.__qualname__}.{method}"
            delegator.__doc__ = f"Delegate to ``self.{to}.{method}(*args, **kwargs)``."
            return delegator

        for m in methods:
            setattr(klass, m, create_delegator(m))
        return klass
    return dec

In [None]:
#| export
class CausalLMPEFTModel(torch.nn.Module):
    """`torch.nn.Module` wrapper around a LoRA-adapted (PEFT) causal LM.

    `cfg` must provide:
    - `model`: passed to `AutoModelForCausalLM.from_pretrained`.
    - `lora_alpha`, `lora_dropout`, `target_modules`, `r`: used for the LoRA config.

    The base model is loaded with `device_map="cpu"`, `use_cache=False` and
    `torch_dtype=torch.bfloat16`, then wrapped by `get_peft_model` and stored
    as the submodule `self.model`. Attribute lookups that this wrapper cannot
    resolve are forwarded to the wrapped model.
    """

    def __init__(self, cfg):
        super(CausalLMPEFTModel, self).__init__()
        # LoRA adapter configuration for a causal-LM task.
        peft_config = LoraConfig(
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            target_modules=cfg.target_modules,
            r=cfg.r,
            bias="none",
            task_type="CAUSAL_LM",
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            cfg.model,
            device_map="cpu",
            use_cache=False,
            torch_dtype=torch.bfloat16
        )
        self.model = get_peft_model(base_model, peft_config)

    def __getattr__(self, name):
        """Resolve `name` on nn.Module state first, then on the wrapped model.

        Raises:
            AttributeError: if neither this module nor the wrapped model has
                an attribute called `name`.
        """
        # nn.Module keeps parameters, buffers and submodules (including
        # `model` itself) outside `__dict__` and finds them through its own
        # __getattr__, so that lookup must run first.
        try:
            return super().__getattr__(name)
        except AttributeError:
            pass
        # Read `_modules` through `__dict__` rather than `self._modules` so a
        # lookup before nn.Module.__init__ has run (e.g. while unpickling)
        # cannot re-enter this method and recurse.
        model = self.__dict__.get("_modules", {}).get("model")
        if model is not None:
            try:
                return getattr(model, name)
            except AttributeError:
                pass
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def forward(self, input_id):
        """Run the wrapped model on token ids `input_id`; returns its tuple output."""
        output = self.model(input_ids=input_id, return_dict=False)
        return output

In [None]:
#| export
def get_model(cfg):
    """Load `cfg.model` as a causal LM on CPU and wrap it with a LoRA adapter.

    Reads `cfg.model`, `cfg.r`, `cfg.target_modules`, `cfg.lora_alpha` and
    `cfg.lora_dropout`, and returns the model produced by `get_peft_model`.
    """
    # NOTE(review): trust_remote_code=True lets Transformers execute Python
    # code shipped with the model repository; only use trusted checkpoints.
    base_model = AutoModelForCausalLM.from_pretrained(
        cfg.model,
        trust_remote_code=True,
        device_map='cpu',
    )
    lora_config = LoraConfig(
        r=cfg.r,
        target_modules=cfg.target_modules,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return get_peft_model(base_model, lora_config)
    

In [None]:
#| hide
# Export this notebook's `#| export` cells to the `models` module (per `#| default_exp models`).
import nbdev; nbdev.nbdev_export()