# In [4]:
from merge import MergedModel, MergedModelArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from transformers import TrainingArguments
from sae_direction_alignment import AutoencoderMerged

# Pick the compute device once for the whole script: the GPU when PyTorch reports
# CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class CustomEvaluateArguments(TrainingArguments):
    """``TrainingArguments`` extended with extra hyper-parameters used by this script.

    All positional arguments and any keyword arguments not listed below are
    forwarded unchanged to ``TrainingArguments.__init__``. The extra values are
    stored as plain attributes on the instance.

    Args:
        multiplier: Stored as ``self.multiplier``. Its meaning is defined by the
            code that consumes this config, which is not visible here.
        d_models: Stored as ``self.d_models``. Presumably one hidden size per
            merged model (e.g. ``[768, 768]`` for two GPT-2-small models); confirm
            against ``AutoencoderMerged``. Defaults to a new empty list.
        lambda_l1: Stored as ``self.lambda_l1``. Presumably the weight of an L1
            term; confirm against the consumer.
        lambda_cos: Stored as ``self.lambda_cos``. Presumably the weight of a
            cosine term; confirm against the consumer.
    """

    def __init__(self, *args, multiplier=1, d_models=None, lambda_l1=1, lambda_cos=2, **kwargs):
        super().__init__(*args, **kwargs)
        self.multiplier = multiplier
        # Build a fresh list per instance. The former ``d_models=[]`` default was
        # evaluated once, so every instance created without ``d_models`` shared
        # (and could mutate) the same list object.
        self.d_models = [] if d_models is None else d_models
        self.lambda_l1 = lambda_l1
        self.lambda_cos = lambda_cos

def load_merged_model(model_path, cfg, models, layer, device):
    """Construct an ``AutoencoderMerged`` and restore its weights from a checkpoint.

    Args:
        model_path: File read with ``torch.load``. Its contents are handed
            directly to ``load_state_dict``, so it is expected to hold a state dict.
        cfg: Configuration object forwarded to ``AutoencoderMerged``.
        models: Models forwarded to ``AutoencoderMerged``.
        layer: Forwarded as the ``layer`` keyword of ``AutoencoderMerged``.
        device: Passed to ``AutoencoderMerged`` and used as ``map_location`` when
            reading the checkpoint.

    Returns:
        The autoencoder with the checkpoint weights loaded, after ``.eval()`` has
        been called on it.
    """
    autoencoder = AutoencoderMerged(models, cfg, layer=layer, device=device)
    # map_location remaps the stored tensors onto `device`, so a checkpoint saved
    # on a GPU can still be opened on a CPU-only machine.
    state = torch.load(model_path, map_location=device)
    autoencoder.load_state_dict(state)
    # Called for its side effect only; the return value of eval() is not used.
    autoencoder.eval()
    return autoencoder
  
# Hyper-parameters handed to the merged autoencoder. d_models presumably lists the
# hidden size of each source model (768 for each GPT-2-sized model below) — TODO confirm.
cfg = CustomEvaluateArguments(
    output_dir='./results',
    multiplier=70,
    d_models=[768, 768],
    lambda_l1=0.5,
    lambda_cos=0.5,
)

# The two fine-tuned models being merged.
# NOTE(review): only tokenizer1 is used for encoding/decoding below, which assumes
# both models share the same vocabulary — confirm.
model_name1 = "Sharathhebbar24/math_gpt2_sft"
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
model1 = AutoModelForCausalLM.from_pretrained(model_name1).to(device)

model_name2 = "yoavgur/gpt2-bash-history-baseline"
tokenizer2 = AutoTokenizer.from_pretrained(model_name2)
model2 = AutoModelForCausalLM.from_pretrained(model_name2).to(device)

# NOTE(review): the checkpoint file name says "layer5" but layer=-2 is passed;
# confirm which layer this autoencoder was actually trained on.
merged_model = load_merged_model(
    './results/model_epoch_-1_layer5_loss_88489494.0000.pt',
    cfg,
    [model1, model2],
    layer=-2,
    device=device,
)

mergedcfg = MergedModelArguments()

model = MergedModel([model1, model2], mergedcfg, device, [merged_model])

# The two prompts have different lengths, so they must be padded into one batch.
# For a decoder-only model the padding must go on the LEFT, so that each row ends
# with its real last prompt token. With the default right padding, the shorter
# prompt would continue generating after pad tokens. GPT-2 tokenizers also ship
# without a pad token, in which case padding=True raises, so EOS is reused as pad.
if tokenizer1.pad_token is None:
    tokenizer1.pad_token = tokenizer1.eos_token
tokenizer1.padding_side = 'left'

# NOTE(review): only input_ids are kept; the attention mask is discarded. Unless
# MergedModel.generate rebuilds a mask itself, the left pad tokens will be attended to.
test_input = tokenizer1(
    ["That beauty still may live in thine or thee. I", "This is another test"],
    return_tensors='pt',
    padding=True,
)['input_ids'].to(device)

ans = model.generate(test_input, max_tokens=30)

# skip_special_tokens keeps the pad (EOS) tokens introduced by batching out of the output.
print(tokenizer1.decode(ans[0], skip_special_tokens=True))
print(tokenizer1.decode(ans[1], skip_special_tokens=True))