In [1]:
import enum
import os
import peft
import rich
import torch
import transformers
import trl


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /cvmfs/ai.mila.quebec/apps/arch/common/cuda/11.7/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


  warn(msg)


In [2]:
class ModelTypes(str, enum.Enum):
    """Architecture families that ModelTokenizerPair distinguishes between.

    The ``str`` mixin makes each member compare equal to its plain string
    value (e.g. ``ModelTypes.CAUSAL_LM == "causal_lm"``).
    """

    # Assigned by ModelTokenizerPair when the checkpoint name contains "gpt".
    CAUSAL_LM = "causal_lm"
    # Assigned by ModelTokenizerPair when the checkpoint name contains "t5".
    SEQ_2_SEQ_LM = "seq_2_seq_lm"

class ModelTokenizerPair:
    """Bundle a Hugging Face model and its tokenizer behind one object.

    The pair offers text-in/text-out generation helpers, forwards the usual
    ``to`` / ``cuda`` / ``cpu`` / ``train`` / ``eval`` calls to the wrapped
    model, and can wrap the model in place with PEFT (``init_peft``) and with
    a TRL value head (``init_trl``).

    Attributes:
        model: The wrapped model, or ``None`` when no ``hf_name`` was given.
            Rebound by ``init_peft`` and ``init_trl``.
        tokenizer: The matching tokenizer, or ``None`` when no ``hf_name``
            was given.
        device: Device that tokenized inputs are moved to. It is only set in
            ``__init__``; ``to`` / ``cuda`` / ``cpu`` do not update it.
        m, t: Short aliases for ``model`` and ``tokenizer``. They are
            properties so they always track the current objects.
    """

    def __init__(
        self,
        hf_name=None,
        model_cls=None,
        default_gen_kwargs=None,
        # NOTE: evaluated once, when the class is defined, not per instance.
        device=int(os.getenv("LOCAL_RANK", "0")),
        init_kwargs=None,
    ):
        """Load the model/tokenizer named ``hf_name`` (if any).

        Args:
            hf_name: Checkpoint name passed to ``from_pretrained``. Must
                contain "gpt" (treated as a causal LM) or "t5" (treated as a
                seq2seq LM). When ``None`` nothing is loaded and ``model`` /
                ``tokenizer`` stay ``None``.
            model_cls: Class whose ``from_pretrained`` builds the model.
                Required whenever ``hf_name`` is given.
            default_gen_kwargs: Keyword arguments applied to every
                ``generate`` call unless overridden. Defaults to
                ``{"max_new_tokens": 100}``. The mapping is copied, so the
                caller's object is never modified.
            device: Device that tokenized inputs are moved to.
            init_kwargs: Extra keyword arguments for
                ``model_cls.from_pretrained``. ``None`` means none.

        Raises:
            ValueError: If ``hf_name`` is given without ``model_cls``.
            AssertionError: If ``hf_name`` contains neither "gpt" nor "t5".
        """
        if default_gen_kwargs is None:
            default_gen_kwargs = dict(
                max_new_tokens=100,
            )
        if init_kwargs is None:
            # `**None` would raise TypeError in from_pretrained below.
            init_kwargs = {}

        self._init_kwargs = init_kwargs
        self._peft_initialized = False
        self._trl_initialized = False
        # Copy: pad_token_id may be injected below and must not leak into
        # a dict owned by the caller.
        self._default_gen_kwargs = dict(default_gen_kwargs)
        self.device = device

        # Always define these so an "empty" pair is a valid object.
        self.model = None
        self.tokenizer = None
        self._model_type = None

        if hf_name is not None:
            if model_cls is None:
                raise ValueError("model_cls is required when hf_name is given")

            self.model = model_cls.from_pretrained(hf_name, **init_kwargs)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_name)

            lowered_name = hf_name.lower()
            if "gpt" in lowered_name:
                # Reuse EOS as the padding token for these checkpoints.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self._default_gen_kwargs["pad_token_id"] = self.model.config.eos_token_id
                self._model_type = ModelTypes.CAUSAL_LM

            else:
                assert "t5" in lowered_name, (
                    f"Cannot infer the model type from hf_name={hf_name!r}"
                )
                self._model_type = ModelTypes.SEQ_2_SEQ_LM

    @property
    def m(self):
        """Short alias for ``model``; follows re-wrapping by PEFT/TRL."""
        return self.model

    @m.setter
    def m(self, value):
        self.model = value

    @property
    def t(self):
        """Short alias for ``tokenizer``."""
        return self.tokenizer

    @t.setter
    def t(self, value):
        self.tokenizer = value

    def to(self, *args, **kwargs):
        """Forward to ``model.to`` and return ``self`` for chaining.

        ``self.device`` is left unchanged.
        """
        self.model.to(*args, **kwargs)
        return self

    def cuda(self, *args, **kwargs):
        """Forward to ``model.cuda`` and return ``self`` for chaining.

        ``self.device`` is left unchanged.
        """
        self.model.cuda(*args, **kwargs)
        return self

    def cpu(self, *args, **kwargs):
        """Forward to ``model.cpu`` and return ``self`` for chaining.

        ``self.device`` is left unchanged.
        """
        self.model.cpu(*args, **kwargs)
        return self

    def gen_from_text(self, text, gen_kwargs=None, to_text=True):
        """Tokenize ``text``, generate, and optionally decode the result.

        Args:
            text: Prompt(s) handed to the tokenizer (padded and truncated).
            gen_kwargs: Per-call generation overrides; ``None`` means none.
            to_text: When true, return ``tokenizer.batch_decode`` of the
                generated ids; otherwise return the ids themselves.

        For causal LMs the leading ``input_ids.shape[-1]`` positions (the
        prompt) are sliced off the generated sequence before decoding.
        """
        if gen_kwargs is None:
            gen_kwargs = {}

        tokenized = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        output = self.generate(
            **tokenized,
            **gen_kwargs,
        )
        if self._model_type == ModelTypes.CAUSAL_LM:
            output = output[:, tokenized["input_ids"].shape[-1]:]

        if to_text:
            return self.tokenizer.batch_decode(output)

        return output

    def generate(self, *args, **gen_kwargs):
        """Call ``model.generate`` with the defaults merged under ``gen_kwargs``.

        Explicit keyword arguments win over ``default_gen_kwargs``.
        """
        # `**gen_kwargs` is always a dict (never None), so a plain merge is
        # all that is needed.
        merged_kwargs = {**self._default_gen_kwargs, **gen_kwargs}
        return self.model.generate(*args, **merged_kwargs)

    def text_to_text(self, text, gen_kwargs=None):
        """Generate from ``text`` and return the decoded strings."""
        return self.gen_from_text(text, gen_kwargs, to_text=True)

    def text_to_ids(self, text, gen_kwargs=None):
        """Generate from ``text`` and return the generated ids."""
        return self.gen_from_text(text, gen_kwargs, to_text=False)

    def ids_to_text(self, model_inputs, gen_kwargs=None):
        """Generate from already-tokenized ``model_inputs`` and decode.

        Unlike ``gen_from_text`` the full generated sequence is decoded;
        no prompt stripping is applied here.
        """
        if gen_kwargs is None:
            gen_kwargs = {}
        generated = self.generate(**model_inputs, **gen_kwargs)
        return self.tokenizer.batch_decode(generated)

    def init_peft(self, peft_config):
        """Replace ``model`` with ``peft.get_peft_model(model, peft_config)``."""
        self._peft_initialized = True
        rich.print(
            f"[red bold]init_peft: "
            f"{self._peft_initialized = } "
            f"{self._trl_initialized = }"
        )
        self.model = peft.get_peft_model(model=self.model, peft_config=peft_config)

    def init_trl(self):
        """Replace ``model`` with the TRL value-head wrapper for its type.

        Raises:
            ValueError: If the model type is neither causal nor seq2seq
                (for instance on a pair created without ``hf_name``).
        """
        # Resolve the wrapper first so an unsupported type fails cleanly
        # before any state is touched.
        if self._model_type == ModelTypes.CAUSAL_LM:
            trl_cls = trl.models.AutoModelForCausalLMWithValueHead

        elif self._model_type == ModelTypes.SEQ_2_SEQ_LM:
            trl_cls = trl.models.AutoModelForSeq2SeqLMWithValueHead

        else:
            raise ValueError(f"Unsupported model type: {self._model_type!r}")

        self._trl_initialized = True
        rich.print(
            f"[red bold]init_trl: "
            f"{self._peft_initialized = } "
            f"{self._trl_initialized = }"
        )
        self.model = trl_cls.from_pretrained(self.model)

    def __call__(self, *args, **kwds) -> torch.Tensor:
        """Run the wrapped model and return whatever it returns."""
        return self.model(*args, **kwds)

    def forward_from_text(self, text: str, decoder_text = None) -> torch.Tensor:
        """Tokenize ``text`` (and ``decoder_text``) and run a forward pass.

        Args:
            text: Input text for the model (encoder input for seq2seq).
            decoder_text: Text tokenized into ``decoder_input_ids`` /
                ``decoder_attention_mask``. Must be given for seq2seq models
                and omitted for causal LMs; both rules are asserted.

        Returns:
            Whatever the wrapped model returns for these inputs.
        """
        inputs = self.tokenizer(
            text,
            padding        = True,
            truncation     = True,
            return_tensors = "pt",
        ).to(self.device)

        if decoder_text is not None:
            assert self._model_type == ModelTypes.SEQ_2_SEQ_LM, self._model_type

            decoder_inputs = self.tokenizer(
                decoder_text,
                padding        = True,
                truncation     = True,
                return_tensors = "pt",
            ).to(self.device)

            inputs = dict(
                input_ids              = inputs["input_ids"],
                attention_mask         = inputs["attention_mask"],
                decoder_input_ids      = decoder_inputs["input_ids"],
                decoder_attention_mask = decoder_inputs["attention_mask"],
            )
        else:
            assert self._model_type == ModelTypes.CAUSAL_LM, self._model_type

        # Sanity check: every input tensor must have landed on a CUDA device.
        for v in inputs.values():
            assert v.device.type == "cuda", v.device

        return self.model(**inputs)

    def train(self, *args, **kwargs):
        """Forward to ``model.train`` and return its result."""
        return self.model.train(*args, **kwargs)

    def training(self, *args, **kwargs):
        """Return the wrapped model's ``training`` flag.

        ``training`` is a plain attribute on the model, not a callable, so it
        is read rather than called. Positional and keyword arguments are
        accepted only to keep the historical signature and are ignored.
        """
        return self.model.training

    def eval(self, *args, **kwargs):
        """Forward to ``model.eval`` and return its result."""
        return self.model.eval(*args, **kwargs)

In [3]:
# Precision requested for both checkpoints. With None, no torch_dtype
# argument is passed to from_pretrained at all.
DTYPE = torch.bfloat16

if DTYPE is None:
    init_kwargs = {}
elif DTYPE in (torch.float16, torch.bfloat16):
    init_kwargs = {"torch_dtype": DTYPE}
else:
    raise ValueError(f"Invalid DTYPE: {DTYPE}")


# Seq2seq checkpoint, moved to the GPU right after loading.
t5 = ModelTokenizerPair(
    init_kwargs=init_kwargs,
    model_cls=transformers.AutoModelForSeq2SeqLM,
    hf_name="google/flan-t5-small",
).cuda()

# Causal-LM checkpoint, moved to the GPU right after loading.
gpt = ModelTokenizerPair(
    init_kwargs=init_kwargs,
    model_cls=transformers.AutoModelForCausalLM,
    hf_name="edbeeching/gpt-neo-125M-imdb-lora-adapter-merged",
).cuda()

In [4]:
# Smoke test: generate with both models before any wrapping, after PEFT, and
# after PEFT + TRL, then run one forward pass through each wrapped model.
shared_message = "What is the color of the sky?"

# LoRA hyper-parameters shared by both adapters.
shared_peft_config = dict(
    lora_dropout=0.05,
    lora_alpha=32,
    r=16,
    bias="none",
)

############################################################
############################################################

causal_lm = peft.LoraConfig(
    task_type=peft.TaskType.CAUSAL_LM,
    **shared_peft_config,
)

seq2seq = peft.LoraConfig(
    task_type=peft.TaskType.SEQ_2_SEQ_LM,
    **shared_peft_config,
)


rich.print("[bold green]Nothing Applied")
print(gpt.text_to_text(shared_message))
print(t5 .text_to_text(shared_message))

t5 .init_peft(seq2seq)
gpt.init_peft(causal_lm)


# Move each wrapped model to its *own* device. The original sent gpt to
# t5.device, which only worked because both pairs share the default device.
t5 .to(t5.device)
gpt.to(gpt.device)

# Re-apply the reduced precision to the wrapped models as well.
if DTYPE in (torch.float16, torch.bfloat16):
    t5 .to(DTYPE)
    gpt.to(DTYPE)

rich.print("[bold green]Peft Applied")
print(gpt.text_to_text(shared_message))
print(t5 .text_to_text(shared_message))

# The value head is created by the TRL wrapper, so it is moved to the
# device explicitly here.
t5 .init_trl()
t5.model.v_head.to(t5.device)
gpt.init_trl()
gpt.model.v_head.to(gpt.device)

if DTYPE in (torch.float16, torch.bfloat16):
    t5 .to(DTYPE)
    gpt.to(DTYPE)

rich.print("[bold green]Peft & Trl Applied")
gpt_generated = gpt.text_to_text(shared_message)
print(f"{gpt_generated = }")
t5_generated = t5.text_to_text(shared_message)
print(f"{t5_generated  = }")

rich.print("[bold green]Forward Pass")
gpt.train()
gpt.forward_from_text(shared_message)
t5 .train()
# NOTE(review): the decoder text below is generated after t5.train(), i.e.
# with the model in training mode - confirm that this is intended.
# The final expression is left bare so the notebook displays its value.
t5 .forward_from_text(shared_message, t5.text_to_text(shared_message))

[" I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know."]
['<pad> blue</s>']


[" I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know."]
['<pad> blue</s>']


gpt_generated = [" I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know. I don't know."]
t5_generated  = ['<pad> blue</s>']


(tensor([[[-43.0000,  -3.6406,  -9.0625,  ..., -43.0000, -43.0000, -43.0000],
          [-32.0000,   5.6562,  -0.4766,  ..., -31.7500, -32.0000, -31.5000],
          [-60.2500,  -6.2188,  -9.5625,  ..., -60.2500, -60.2500, -60.2500],
          ...,
          [-52.0000,   0.8984,  -4.6875,  ..., -52.0000, -52.2500, -51.7500],
          [-56.5000,  -0.6602, -10.0000,  ..., -56.5000, -56.2500, -56.2500],
          [-50.2500,   1.5234,  -8.8750,  ..., -50.0000, -50.2500, -49.7500]]],
        device='cuda:0', grad_fn=<ToCopyBackward0>),
 None,
 tensor([[ 0.3145,  0.2656, -0.4961,  0.0092,  0.0928,  0.1562, -0.0118]],
        device='cuda:0', dtype=torch.bfloat16, grad_fn=<SqueezeBackward1>))