In [2]:
import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel


In [3]:
# Root directory holding the local model checkpoints. Set the TRAWIC_LLM_DIR
# environment variable to point elsewhere instead of editing this absolute path.
LLM_DIR = os.environ.get("TRAWIC_LLM_DIR", "/home/vamaj/scratch/TraWiC/llms")

base_model_id = os.path.join(LLM_DIR, "mistral")  # base causal-LM checkpoint
adapter_id = os.path.join(LLM_DIR, "mistral_fim")  # PEFT adapter directory (FIM fine-tune)
# The empty-string key maps the entire model onto GPU 0.
device_map = {"": 0}

In [4]:
# Load the base causal-LM checkpoint strictly from local disk (no Hub download),
# placing its modules according to `device_map`.
# NOTE(review): no `torch_dtype` is passed, so weights load in transformers' default
# dtype (float32 unless configured otherwise). `torch_dtype="auto"` would keep the
# checkpoint's stored dtype and cut GPU memory if that dtype is 16-bit. Confirm the
# target GPU supports it before switching.
base_model_load_kwargs = {"device_map": device_map, "local_files_only": True}
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, **base_model_load_kwargs)

Loading checkpoint shards: 100%|██████████| 2/2 [00:40<00:00, 20.25s/it]


In [5]:
# Attach the fine-tuned PEFT adapter stored at `adapter_id` on top of `base_model`.
# `PeftModel.from_pretrained` loads adapters for inference (is_trainable defaults to False).
# NOTE(review): the adapter is attached before the FIM tokens are added to the
# tokenizer and the embeddings are resized. If the adapter checkpoint ever includes
# resized embedding weights, the base model must be resized first. Confirm.
adapter_model = PeftModel.from_pretrained(base_model, model_id=adapter_id)

In [6]:
# Load the tokenizer that ships with the base Mistral checkpoint (same local
# directory as `base_model_id`) and extend it with fill-in-the-middle (FIM) tokens.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    local_files_only=True,
)
# Reuse EOS as the pad token (presumably the checkpoint defines none of its own).
tokenizer.pad_token = tokenizer.eos_token
# Decoder-only models must be left-padded for batched generation, so that every
# prompt ends exactly where new tokens are appended. A single unpadded prompt is
# unaffected. Right padding only matters as a training-time workaround.
tokenizer.padding_side = "left"

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"  # NOTE(review): not used in the visible cells

# The FIM tokens must be added in the same order as during fine-tuning so they get
# the same ids the adapter was trained with. NOTE(review): confirm against the
# training script.
FIM_TOKENS = [FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]
num_added_special_tokens = tokenizer.add_special_tokens(
    {"additional_special_tokens": FIM_TOKENS}
)

In [7]:
# Grow the model's token-embedding matrix so every tokenizer id (including the FIM
# special tokens added to the tokenizer) has a row. The returned input-embedding
# module is shown as the cell output.
# NOTE(review): rows for newly added tokens are initialized by transformers, not
# loaded from the adapter. Confirm the fine-tuned FIM embeddings are not needed.
resized_embedding = adapter_model.resize_token_embeddings(len(tokenizer))
resized_embedding

Embedding(32004, 4096)

In [None]:
# Fill-in-the-middle smoke test. The prompt gives the code before the gap (prefix)
# and after it (suffix). The model should then generate the missing middle after
# FIM_MIDDLE.
MAX_NEW_TOKENS = 64  # upper bound on generated tokens; tune as needed

input_text = f"{FIM_PREFIX}def{FIM_SUFFIX}print(message){FIM_MIDDLE}"
# Calling the tokenizer (rather than `encode`) also returns the attention mask.
# `generate` cannot reliably infer the mask here because pad_token == eos_token.
# Move inputs to wherever the model actually lives instead of hard-coding "cuda".
model_inputs = tokenizer(input_text, return_tensors="pt").to(adapter_model.device)
encoded_input = model_inputs["input_ids"]

with torch.no_grad():
    # Without an explicit limit, generation stops at the default `max_length`
    # (which counts the prompt too), truncating the completion.
    outs = adapter_model.generate(**model_inputs, max_new_tokens=MAX_NEW_TOKENS)[0]

# Decode only the newly generated tokens (the FIM middle), not the echoed prompt.
prompt_len = encoded_input.shape[-1]
print(tokenizer.decode(outs[prompt_len:], skip_special_tokens=True))
