In [1]:
import os
import argparse
import random

import torch
from src.mola_peft_model_hacked import PeftModel
from transformers import GenerationConfig, LlamaTokenizer, AutoConfig
import sys
from src.mola_modeling_llama_hacked import LlamaForCausalLM_d

# Pick the device used for inference: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The MPS backend takes precedence whenever it reports itself available.
# Older torch builds have no `torch.backends.mps`, so probe for the attribute
# explicitly instead of hiding every possible error behind a bare `except`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"

# Seed both RNGs from the same constant; previously torch was seeded with a
# hard-coded 0 while `seed` was only applied to Python's `random`.
seed = 10
random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f0e38c60c30>

In [2]:
# Prompt sent to the model: an instruction-style template wrapping a single
# multiple-choice question, ending right after the "### Response:" header.
input_text = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n"
    "Context: Hint: People who can knit had to learn how to do it.\n"
    "Question: Is the following trait inherited or acquired?\n"
    "Sasha is good at knitting hats.\n"
    "Options: (A) acquired (B) inherited\n\n\n"
    "### Response:\n"
)

# --- MoLA configuration -----------------------------------------------------
# Base checkpoint and the directory holding the MoLA adapter weights.
base_model = "NousResearch/Llama-2-7b-hf"
mola_weights = "./scienceqa_mola"

# Comma-separated expert counts (presumably one entry per decoder layer —
# TODO confirm against the MoLA loader). Parsed into a list of ints later.
number_experts = (
    "2,2,2,2,2,2,2,2,"
    "4,4,4,4,4,4,4,4,"
    "6,6,6,6,6,6,6,6,"
    "8,8,8,8,8,8,8,8"
)
# Comma-separated router top-k values, same layout as `number_experts`.
top_k = "2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2"
# Comma-separated names of the projection modules that carry LoRA adapters.
lora_target_modules = "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"

# --- Generation configuration -----------------------------------------------
# `top_k_g` is the sampling top-k; it is named differently so it does not
# clash with the router `top_k` above.
temperature = 0.1
top_p = 0.75
top_k_g = 20

In [3]:
def _parse_csv(value, cast):
    """Convert a comma-separated string into a list of ``cast`` items.

    An already-parsed list/tuple is accepted as well and simply re-cast. This
    cell rebinds the config names from strings to lists, so without this a
    second run of the cell would fail on ``list.split``.
    """
    items = value.split(",") if isinstance(value, str) else value
    return [cast(item) for item in items]


# Turn the comma-separated config strings into typed lists.
lora_target_modules = _parse_csv(lora_target_modules, str)
number_experts = _parse_csv(number_experts, int)
top_k = _parse_csv(top_k, int)

load_8bit = False

tokenizer = LlamaTokenizer.from_pretrained(base_model, padding_side='left')
config = AutoConfig.from_pretrained(base_model)
# The target-module list is attached to the base config before the model is built.
config.lora_target_modules = lora_target_modules
if device == "cuda":
    # GPU path: fp16 weights, placement chosen automatically by `device_map`.
    model = LlamaForCausalLM_d.from_pretrained(
        base_model,
        config=config,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model,
        mola_weights,
        torch_dtype=torch.float16,
        number_experts=number_experts,
        top_k=top_k,
    )
else:
    # CPU / MPS path: place the whole model on the single selected device.
    model = LlamaForCausalLM_d.from_pretrained(
        base_model, config=config, device_map={"": device}, low_cpu_mem_usage=True
    )
    # Pass the same expert layout as the CUDA branch so both paths build the
    # adapter identically (previously these arguments were omitted here).
    model = PeftModel.from_pretrained(
        model,
        mola_weights,
        device_map={"": device},
        number_experts=number_experts,
        top_k=top_k,
    )
obalance = False
model.get_new_parameters(number_experts, top_k, obalance)

# Show the special-token ids of model config vs. tokenizer before overriding them.
print(model.config.pad_token_id, tokenizer.pad_token_id)
print(model.config.bos_token_id, tokenizer.bos_token_id)
print(model.config.eos_token_id, tokenizer.eos_token_id)
# Force fixed special-token ids on both the model config and the tokenizer
# (comment inherited from the original code: "unwind broken decapoda-research config").
model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
model.config.bos_token_id = 1
model.config.eos_token_id = 2

if not load_8bit:
    model.half()  # seems to fix bugs for some users.

model.eval()
# Compile the model on torch 2.x, except on Windows.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

TRAINING MOLA
load----- base_model.model.model.layers.0.self_attn.q_proj.router.router.weight
load----- base_model.model.model.layers.0.self_attn.k_proj.router.router.weight
load----- base_model.model.model.layers.0.self_attn.v_proj.router.router.weight
load----- base_model.model.model.layers.0.self_attn.o_proj.router.router.weight
load----- base_model.model.model.layers.0.mlp.gate_proj.router.router.weight
load----- base_model.model.model.layers.0.mlp.down_proj.router.router.weight
load----- base_model.model.model.layers.0.mlp.up_proj.router.router.weight
load----- base_model.model.model.layers.1.self_attn.q_proj.router.router.weight
load----- base_model.model.model.layers.1.self_attn.k_proj.router.router.weight
load----- base_model.model.model.layers.1.self_attn.v_proj.router.router.weight
load----- base_model.model.model.layers.1.self_attn.o_proj.router.router.weight
load----- base_model.model.model.layers.1.mlp.gate_proj.router.router.weight
load----- base_model.model.model.layers.

In [4]:
max_new_tokens = 128

# `temperature`, `top_p` and `top_k` only influence decoding when sampling is
# enabled; without `do_sample=True` generation is greedy and they are ignored.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k_g,
)
inputs = tokenizer(input_text, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
# Hand the tokenizer's attention mask to `generate` explicitly rather than
# letting it be inferred from the pad token id.
attention_mask = inputs["attention_mask"].to(device)
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=max_new_tokens,
    )
# Decode the first (only) returned sequence; it contains the prompt followed
# by the generated continuation, special tokens included.
s = generation_output.sequences[0]
output = tokenizer.decode(s)
print(output)

<s> Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Context: Hint: People who can knit had to learn how to do it.
Question: Is the following trait inherited or acquired?
Sasha is good at knitting hats.
Options: (A) acquired (B) inherited


### Response:
Answer: The answer is A.</s>
