In [None]:
!pip install transformers peft accelerate bitsandbytes -U --no-index --find-links /kaggle/input/lmsys-wheel-files
!pip install -q -U bitsandbytes --no-index --find-links /kaggle/input/lmsys-libraries

In [None]:
!pip install --no-deps --no-index /kaggle/input/hf-libraries/transformers/transformers-4.43.1-py3-none-any.whl

In [None]:
%%writefile gemma_inference.py

import time
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
import json
import torch
import sklearn
import numpy as np
import pandas as pd
from transformers import  GemmaTokenizerFast, BitsAndBytesConfig, Gemma2ForCausalLM
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from peft import PeftModel

# The script shards inference across exactly two GPUs (cuda:0 and cuda:1).
# Use an explicit check rather than `assert`, which is stripped under `python -O`.
if torch.cuda.device_count() != 2:
    raise RuntimeError(
        f"gemma_inference.py expects exactly 2 CUDA devices, found {torch.cuda.device_count()}"
    )

@dataclass
class Config:
    """Inference settings for the Gemma-2 + LoRA preference scorer.

    The fields carry type annotations so that ``@dataclass`` actually turns
    them into fields; without annotations the decorator generates nothing and
    ``Config(batch_size=2)`` would raise. Defaults are unchanged.
    """
    gemma_dir: str = '/kaggle/input/googlegemma-2-9b-it'
    lora_dir: str = '/kaggle/input/gemma2-len-2200-all-train/checkpoint-5750'
    max_length: int = 1900  # token cap used when tokenizing the prompt/response section
    batch_size: int = 4
    device: torch.device = torch.device("cuda")
    tta: bool = False  # test time augmentation. <prompt>-<model-b's response>-<model-a's response>
    spread_max_length: bool = False  # whether to apply max_length//3 on each input or max_length on the concatenated input

cfg = Config()

# Load & pre-process Data 


from tqdm import tqdm

test = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/test.csv')

# With a tiny test file (<= 10 rows), fall back to the first 100 training rows.
if len(test) <= 10:
    test = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/train.csv', nrows=100)

def process(input_str):
    """Parse one JSON-encoded cell; the result is exploded below, one element per row."""
    return json.loads(input_str)

original_length = len(test)
turn_columns = ['prompt', 'response_a', 'response_b']
for col in turn_columns:
    test.loc[:, col] = test[col].apply(process)

# One row per turn. Missing values become the string 'None', as do empty responses.
test = test.explode(turn_columns).reset_index(drop=True)
test = test.fillna('None')
for col in ('response_a', 'response_b'):
    test[col] = test[col].apply(lambda x: 'None' if len(x) == 0 else x)

def get_text_length(text):
    '''
    Estimate the length of `text` as a cheap proxy for its token count.

    Space-delimited text is measured in space-separated pieces
    (``len(text.split(" "))``). Text that is not separated by spaces
    (e.g. Chinese) would be badly under-counted that way. So when the text
    has at least 300 characters and at least 30 characters per piece, the
    character count scaled by 0.75 is returned instead.
    '''
    n_chars = len(text)
    n_pieces = len(text.split(" "))
    mostly_unspaced = n_chars >= 300 and n_chars >= 30 * n_pieces
    return n_chars * 0.75 if mostly_unspaced else n_pieces
    
def prompt_3(data, max_length, if_train):
    '''
    Pack each conversation's turns into as few text chunks as fit in
    ``max_length``. A chunk that would overflow starts a new row with the
    same id (and the same label when ``if_train``).

    Turns are processed newest-first, so the newest turns of a conversation
    always end up together in one row. Inside a row, turns are in
    chronological order and each turn is laid out as:

        #Prompt
        xxxx

        #Response
        ##Model A
        xxxx

        ##Model B
        xxxx

    A row can only exceed ``max_length`` when a single turn (prompt plus
    both responses) is too long on its own. Such rows get
    ``over_max_length == 1`` and keep that turn's raw prompt/responses in
    the ``overflow_*`` columns, so they can be truncated part-by-part later.
    All other rows get ``over_max_length == 0`` and ``None`` there.

    Args:
        data: frame with one row per turn and columns ``id``, ``prompt``,
            ``response_a``, ``response_b`` (plus ``label`` if ``if_train``).
            The caller's frame is not modified.
        max_length: budget in ``get_text_length`` units.
        if_train: whether to carry the ``label`` column through.

    Returns:
        A new DataFrame with one row per chunk, in the original row order.
    '''
    # `assign` returns a new frame instead of adding a column to the caller's.
    data = data.assign(
        prompt_response="#Prompt\n" + data['prompt'] + "\n\n" + "#Response\n" + "##Model A\n" + data['response_a'] + "\n\n" + "##Model B\n" + data['response_b']
    )
    data = data.iloc[::-1].reset_index(drop=True)  # reverse: newest turn first

    prompt_response = []
    ids = []
    labels = []
    over_max_length = []  # 1 if the row is a single turn longer than max_length
    overflow_prompt = []
    overflow_response_a = []
    overflow_response_b = []
    seen_ids = set()  # O(1) membership test instead of scanning `ids`
    text_length = 0

    def start_row(row_id, text, prompt, response_a, response_b, label):
        """Open a new output row holding `text` and return its length."""
        prompt_response.append(text)
        ids.append(row_id)
        if if_train:
            labels.append(label)
        length = get_text_length(text)
        overflows = length > max_length
        over_max_length.append(1 if overflows else 0)
        overflow_prompt.append(prompt if overflows else None)
        overflow_response_a.append(response_a if overflows else None)
        overflow_response_b.append(response_b if overflows else None)
        return length

    for _, row in tqdm(data.iterrows(), total=len(data)):
        text = row['prompt_response']
        row_id = row['id']
        label = row['label'] if if_train else None

        if row_id not in seen_ids:
            # First time this id is seen, i.e. its newest turn.
            seen_ids.add(row_id)
            text_length = start_row(row_id, text, row['prompt'], row['response_a'], row['response_b'], label)
            continue

        text_length += get_text_length(text)
        if text_length <= max_length:
            # Still fits: prepend this older turn to the current chunk.
            prompt_response[-1] = text + "\n\n" + prompt_response[-1]
            over_max_length[-1] = 0
            overflow_prompt[-1] = None
            overflow_response_a[-1] = None
            overflow_response_b[-1] = None
        else:
            # Would overflow: start a new row for the same id.
            text_length = start_row(row_id, text, row['prompt'], row['response_a'], row['response_b'], label)

    if if_train:
        result = pd.DataFrame({'id': ids, 'prompt_response': prompt_response, "label": labels, 'overflow_prompt': overflow_prompt, 'over_max_length': over_max_length, 'overflow_response_a': overflow_response_a, 'overflow_response_b': overflow_response_b})
    else:
        result = pd.DataFrame({'id': ids, 'prompt_response': prompt_response, 'over_max_length': over_max_length, 'overflow_prompt': overflow_prompt, 'overflow_response_a': overflow_response_a, 'overflow_response_b': overflow_response_b})
    return result.iloc[::-1].reset_index(drop=True)  # restore the original order

# NOTE(review): the 0.75 factor presumably leaves headroom because
# get_text_length() is only a rough proxy for the token count — confirm.
test = prompt_3(test, cfg.max_length * 0.75, False)
# prompt_3 can emit several rows for one long conversation. Keep the last row
# per id, which holds that conversation's newest turns.
test = test.drop_duplicates(subset=['id'], keep='last').reset_index(drop=True)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if len(test) != original_length:
    raise RuntimeError(
        f"expected one row per conversation ({original_length}), got {len(test)}"
    )

# tokenize

def tokenize(tokenizer, data):
    '''
    Build the Gemma chat prompt for every row of `data` and tokenize it.

    Each prompt is laid out as:
        "<start_of_turn>user" + instruction + prompt/response section
        + A/B/C options + "<end_of_turn>" + "<start_of_turn>model"

    Rows flagged ``over_max_length`` are rebuilt from their ``overflow_*``
    columns. If those parts together exceed ``cfg.max_length`` tokens, each
    part is truncated separately to the budget returned by ``adjust()``.
    Other rows use ``prompt_response`` truncated to ``cfg.max_length``
    tokens.

    Args:
        tokenizer: the Gemma tokenizer.
        data: output of prompt_3; rows are read with ``data.loc[i]`` for
            ``i in range(len(data))``, so a RangeIndex is expected.

    Returns:
        (input_ids, attention_mask): per-row lists from re-tokenizing the
        decoded prompt text.
    '''
    # The template pieces are the same for every row, so tokenize them once.
    templete_part1 = "<start_of_turn>user\nHere are two question-answering dialogues. Compare two model performance on answering question, determine which is better.\n\n"
    templete_part1_input_ids = tokenizer(text=templete_part1, add_special_tokens=True, padding=False)['input_ids']

    # [1:] drops the leading special token added by add_special_tokens=True.
    templete_part2 = "\n###options\nA. Model A\nB. Model B\nC. Tie\n<end_of_turn>\n"
    templete_part2_input_ids = tokenizer(text=templete_part2, add_special_tokens=True, padding=False)['input_ids'][1:]

    templete_part3 = "<start_of_turn>model\n"
    templete_part3_input_ids = tokenizer(text=templete_part3, add_special_tokens=True, padding=False)['input_ids'][1:]

    # "\n\n" separator appended after each separately-truncated part.
    templete_part4_input_ids = tokenizer(text="\n\n", add_special_tokens=False, padding=False)['input_ids']

    prompts = []
    for i in tqdm(range(len(data))):
        now_data = data.loc[i]

        if now_data['over_max_length']:
            prompt = "#Prompt\n" + now_data['overflow_prompt']
            r_a = "#Response\n" + "##Model A\n" + now_data['overflow_response_a']
            r_b = "##Model B\n" + now_data['overflow_response_b']

            prompt_ids = tokenizer(text=prompt, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_a_input_ids = tokenizer(text=r_a, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_b_input_ids = tokenizer(text=r_b, add_special_tokens=False, truncation=False, padding=False)['input_ids']

            if len(prompt_ids) + len(model_a_input_ids) + len(model_b_input_ids) <= cfg.max_length:
                # NOTE(review): unlike the truncated branch below, no "\n\n"
                # separator is inserted between the parts here. Left as-is in
                # case the model was trained on this exact layout.
                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids

            else:
                # Budget of 300 / 800 / 800 tokens for prompt / A / B.
                # Responses take priority; leftover budget goes to the prompt.
                length = [len(prompt_ids), len(model_a_input_ids), len(model_b_input_ids)]
                print(f"before {len(prompt_ids) + len(model_a_input_ids) + len(model_b_input_ids)}")
                print(f"before {length}")
                prompt_max_length, a_max_length, b_max_length = adjust(length)

                prompt_ids = prompt_ids[:prompt_max_length] + templete_part4_input_ids
                model_a_input_ids = model_a_input_ids[:a_max_length] + templete_part4_input_ids
                model_b_input_ids = model_b_input_ids[:b_max_length] + templete_part4_input_ids

                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids
                print(f"after {[prompt_max_length, a_max_length, b_max_length]}")
                print(f"after {len(prompt_response_ids)}")

        else:
            prompt_response = now_data['prompt_response']
            prompt_response_ids = tokenizer(text=prompt_response, add_special_tokens=True, truncation=True, max_length=cfg.max_length, padding=False)['input_ids'][1:]

        input_ids = templete_part1_input_ids + prompt_response_ids + templete_part2_input_ids + templete_part3_input_ids
        # Decode without the leading BOS. The default add_special_tokens=True
        # in tokenizer(prompts) below is expected to add it back.
        input_text = tokenizer.decode(input_ids[1:], skip_special_tokens=False)
        if i == 0:
            print(input_text)
        prompts.append(input_text)

    tokenized = tokenizer(prompts)
    return tokenized.input_ids, tokenized.attention_mask

def adjust_values(A, B, a_space, b_space, ex_space):
    '''
    Split a token budget between two responses of lengths A and B.

    Each response has its own allowance (``a_space`` / ``b_space``), and both
    share an extra pool ``ex_space``. Allowance left unused by a short
    response goes into the pool; an over-long response draws on the pool.
    Every combination is handled, including a response exactly at its
    allowance (previously A == a_space with B > b_space left B untruncated,
    and A < a_space with B == b_space dropped A's unused allowance).

    Returns:
        (A, B, ex_space): the number of tokens to keep for each response and
        what remains in the pool.
    '''
    a_over = A - a_space  # > 0: A needs extra tokens; <= 0: A leaves -a_over unused
    b_over = B - b_space

    # Both fit: all unused allowance goes back into the pool.
    if a_over <= 0 and b_over <= 0:
        return A, B, ex_space - a_over - b_over

    # Both too long.
    if a_over > 0 and b_over > 0:
        total_extra_needed = a_over + b_over
        if total_extra_needed > ex_space:
            # Not enough: split the pool evenly between the two responses.
            return int(a_space + ex_space / 2), int(b_space + ex_space / 2), 0
        return A, B, ex_space - total_extra_needed

    # Exactly one is too long: it gets the other's slack plus the pool.
    if a_over > 0:
        ex_space -= b_over  # b_over <= 0 here, so this adds B's slack
        if ex_space >= a_over:
            return A, B, ex_space - a_over
        return a_space + ex_space, B, 0

    ex_space -= a_over  # a_over <= 0 here, so this adds A's slack
    if ex_space >= b_over:
        return A, B, ex_space - b_over
    return A, b_space + ex_space, 0


def adjust(current_lengths, prompt_length_space=300, response_length_space=800):
    '''
    Compute per-part truncation lengths for (prompt, response_a, response_b).

    The prompt starts with ``prompt_length_space`` tokens and each response
    with ``response_length_space``. Unused prompt allowance is pooled for the
    responses (see adjust_values). Whatever the responses leave in the pool
    is then given back to the prompt.

    Args:
        current_lengths: [prompt_len, response_a_len, response_b_len] in tokens.

    Returns:
        (prompt_max, a_max, b_max) lengths to slice each part to.
    '''
    prompt_length, response_a_length, response_b_length = current_lengths[0], current_lengths[1], current_lengths[2]
    # Unused prompt allowance becomes the initial shared pool.
    ex_space = max(0, prompt_length_space - prompt_length)
    response_a_length, response_b_length, ex_space = adjust_values(
        response_a_length, response_b_length, response_length_space, response_length_space, ex_space
    )
    prompt_length = min(prompt_length, prompt_length_space) + ex_space
    return prompt_length, response_a_length, response_b_length


tokenizer = GemmaTokenizerFast.from_pretrained(cfg.gemma_dir)
# tokenizer.add_eos_token = True   (kept disabled)
# Left padding keeps each sequence's last real token at position -1.
tokenizer.padding_side = "left"

data = pd.DataFrame({"id": test["id"]})
tokenized_ids, tokenized_mask = tokenize(tokenizer, test)
data["input_ids"] = tokenized_ids
data["attention_mask"] = tokenized_mask
data["length"] = data["input_ids"].apply(len)

print(tokenizer.decode(data["input_ids"][0]))

from transformers.cache_utils import Cache, DynamicCache, StaticCache
from typing import List, Optional, Tuple, Union
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast
)

class CustomGemma2ForCausalLM(Gemma2ForCausalLM):
    '''
    Gemma-2 causal LM that applies ``lm_head`` to the last position only.

    ``forward`` takes the decoder output at index -1 along the sequence axis
    (``outputs[0][:, -1]``), so ``logits`` has one row per example instead of
    one per token. With the left padding configured for the tokenizer in this
    script, that position is the final prompt token. Its logits are therefore
    the next-token distribution used to score the A/B/C options.
    '''
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        '''
        Run the decoder and return last-position logits.

        Logits are soft-capped with ``tanh`` when
        ``config.final_logit_softcapping`` is set, then cast to float32.
        If ``labels`` is given, a cross-entropy loss over those
        last-position logits is returned as well.
        '''
        if self.training and self.config._attn_implementation != "eager":
            # `logger` was never defined in this script. Fetch the transformers
            # logger, which provides `warning_once`.
            from transformers.utils import logging as hf_logging
            hf_logging.get_logger(__name__).warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Keep only the last position (assumes the usual (batch, seq_len, hidden) layout).
        hidden_states = outputs[0][:, -1]
        logits = self.lm_head(hidden_states)
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        logits = logits.float()
        loss = None
        if labels is not None:
            # The stock HF loss shifts along a sequence axis that these logits
            # do not have, and relied on an un-imported CrossEntropyLoss, so it
            # could never run.
            # NOTE(review): assumes `labels` holds one target token id per
            # example, shape (batch,) or (batch, 1) — confirm against training.
            loss = torch.nn.functional.cross_entropy(
                logits, labels.reshape(-1).to(logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
# 4-bit NF4 quantization with double quantization and fp16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
device_0 = torch.device('cuda:0')
device_1 = torch.device('cuda:1')


def _load_base_model(device):
    """Load the quantized CustomGemma2ForCausalLM base model onto `device`."""
    return CustomGemma2ForCausalLM.from_pretrained(
        cfg.gemma_dir,
        device_map=device,
        use_cache=False,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )


# One base model per GPU, each wrapped with the same LoRA adapter.
model_0 = _load_base_model(device_0)
model_1 = _load_base_model(device_1)

model_0 = PeftModel.from_pretrained(model_0, cfg.lora_dir)
model_1 = PeftModel.from_pretrained(model_1, cfg.lora_dir)

model_0.eval()
model_1.eval()

@torch.no_grad()
@torch.cuda.amp.autocast()
def inference_test(df, model, device, batch_size=cfg.batch_size):
    '''
    Debug helper: run `model` over `df` in batches of `batch_size`.

    Returns:
        The raw model output of the LAST batch only, or ``None`` when `df`
        is empty (previously this raised UnboundLocalError).
    '''
    outputs = None
    for start_idx in range(0, len(df), batch_size):
        tmp = df.iloc[start_idx:start_idx + batch_size]  # iloc clips at the end
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": tmp["input_ids"].to_list(), "attention_mask": tmp["attention_mask"].to_list()},
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        outputs = model(**inputs.to(device))

    return outputs

# Token ids of the option letters "A", "B" and "C"; [1:] strips the leading
# special token added by add_special_tokens=True.
A_TOKEN_IDS, B_TOKEN_IDS, C_TOKEN_IDS = (
    tokenizer(letter, add_special_tokens=True, truncation=True, max_length=1024)['input_ids'][1:]
    for letter in ('A', 'B', 'C')
)

@torch.no_grad()
@torch.cuda.amp.autocast()
def inference(df, model, device, batch_size=cfg.batch_size):
    '''
    Score each row of `df` and write win/tie probabilities into it.

    The last-position logits of the option tokens A, B and C are divided by
    a temperature of 1.03 and softmaxed. The three probabilities are stored
    in ``winner_model_a``, ``winner_model_b`` and ``winner_tie``. `df` is
    modified in place and returned.
    '''
    option_ids = A_TOKEN_IDS + B_TOKEN_IDS + C_TOKEN_IDS
    columns = {"winner_model_a": [], "winner_model_b": [], "winner_tie": []}

    for lo in range(0, len(df), batch_size):
        batch = df.iloc[lo:min(lo + batch_size, len(df))]
        padded = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": batch["input_ids"].to_list(), "attention_mask": batch["attention_mask"].to_list()},
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        logits = model(**padded.to(device)).logits.cpu()
        proba = (logits[:, option_ids] / 1.03).softmax(-1)
        for col_idx, name in enumerate(columns):
            columns[name].extend(proba[:, col_idx].tolist())

    for name, values in columns.items():
        df[name] = values

    return df


st = time.time()

# Sort by input length so each batch pads to similar lengths (dynamic padding).
data = data.sort_values("length", ascending=False)
# Alternate rows between the two GPUs so both shards carry roughly the same number of tokens.
sub_1 = data.iloc[0::2].copy()
sub_2 = data.iloc[1::2].copy()

with ThreadPoolExecutor(max_workers=2) as executor:
    results = executor.map(inference, (sub_1, sub_2), (model_0, model_1), (device_0, device_1))

result_df = pd.concat(list(results), axis=0)

print(f"elapsed time: {time.time() - st}")

# The probability columns are already filled in by inference(). The former
# `.values` round-trip reassigned identical values and has been dropped.
submission_df = result_df[["id", 'winner_model_a', 'winner_model_b', 'winner_tie']]
submission_df = submission_df.sort_values('id')
submission_df.to_csv('submission_gemma.csv', index=False)

In [None]:
%%writefile llama_inference.py

import time
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
import json
import torch
import sklearn
import numpy as np
import pandas as pd
from transformers import  GemmaTokenizerFast, BitsAndBytesConfig, Gemma2ForCausalLM, AutoTokenizer, AutoModelForCausalLM
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from peft import PeftModel
from tqdm import tqdm

@dataclass
class Config:
    """Inference settings for the Llama-3.1 + LoRA preference scorer.

    The fields carry type annotations so that ``@dataclass`` actually turns
    them into fields (enabling e.g. ``Config(batch_size=2)``). Defaults are
    unchanged.
    """
    # Name kept from the Gemma script; it points at the Llama checkpoint.
    gemma_dir: str = '/kaggle/input/llama31instruct/Meta-Llama-3.1-8B-Instruct'
    lora_dir: str = '/kaggle/input/llama3-1-all-data/checkpoint-5768'
    max_length: int = 2400  # token cap used when tokenizing the prompt/response section
    batch_size: int = 4
    device: torch.device = torch.device("cuda")
    tta: bool = False  # test time augmentation. <prompt>-<model-b's response>-<model-a's response>
    spread_max_length: bool = False  # whether to apply max_length//3 on each input or max_length on the concatenated input

cfg = Config()

def get_data(path='/kaggle/input/lmsys-chatbot-arena/test.csv', reverse=False):
    '''
    Load a competition CSV and pack each conversation into prompt/response rows.

    Args:
        path: CSV to read. If it has 10 rows or fewer, the first 100 rows of
            train.csv are used instead.
        reverse: if True, swap ``response_a`` and ``response_b`` before
            processing, so B's answer appears in the "Model A" slot.

    Returns:
        DataFrame with one row per id: the output of the nested ``prompt_3``,
        de-duplicated on ``id`` keeping the row with the newest turns.

    Raises:
        RuntimeError: if de-duplication does not leave one row per input row.
    '''
    test = pd.read_csv(path)

    if len(test) <= 10:
        test = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/train.csv', nrows=100)

    if reverse:
        # Explicit copies make the swap independent of whether pandas writes a
        # column assignment in place.
        test['response_a'], test['response_b'] = test['response_b'].copy(), test['response_a'].copy()

    def process(input_str):
        return json.loads(input_str)

    original_length = len(test)
    test.loc[:, 'prompt'] = test['prompt'].apply(process)
    test.loc[:, 'response_a'] = test['response_a'].apply(process)
    test.loc[:, 'response_b'] = test['response_b'].apply(process)

    # One row per turn. Missing values become 'None', as do empty responses.
    test = test.explode(['prompt', 'response_a', 'response_b']).reset_index(drop=True)
    test = test.fillna('None')
    test['response_a'] = test['response_a'].apply(lambda x: 'None' if len(x) == 0 else x)
    test['response_b'] = test['response_b'].apply(lambda x: 'None' if len(x) == 0 else x)

    def prompt_3(data, max_length, if_train):
        '''
        Pack each conversation's turns into as few chunks as fit in
        ``max_length``, measured as ``len(text.split(" "))``. A chunk that
        would overflow starts a new row with the same id.

        Turns are processed newest-first, so the newest turns stay together.
        Inside a row, turns are in chronological order and each turn is laid
        out as "#Prompt / #Response / ##Model A / ##Model B".

        A row can only exceed ``max_length`` when a single turn is too long on
        its own. Such rows get ``over_max_length == 1`` and keep the raw
        prompt/responses in the ``overflow_*`` columns; all other rows get 0
        and ``None``. The caller's frame is not modified.
        '''
        data = data.assign(
            prompt_response="#Prompt\n" + data['prompt'] + "\n\n" + "#Response\n" + "##Model A\n" + data['response_a'] + "\n\n" + "##Model B\n" + data['response_b']
        )
        data = data.iloc[::-1].reset_index(drop=True)  # reverse: newest turn first

        prompt_response = []
        ids = []
        labels = []
        over_max_length = []  # 1 if the row is a single turn longer than max_length
        overflow_prompt = []
        overflow_response_a = []
        overflow_response_b = []
        seen_ids = set()  # O(1) membership test instead of scanning `ids`
        text_length = 0

        def start_row(row_id, text, prompt, response_a, response_b, label):
            """Open a new output row holding `text` and return its length."""
            prompt_response.append(text)
            ids.append(row_id)
            if if_train:
                labels.append(label)
            length = len(text.split(" "))
            overflows = length > max_length
            over_max_length.append(1 if overflows else 0)
            overflow_prompt.append(prompt if overflows else None)
            overflow_response_a.append(response_a if overflows else None)
            overflow_response_b.append(response_b if overflows else None)
            return length

        for _, row in tqdm(data.iterrows(), total=len(data)):
            text = row['prompt_response']
            row_id = row['id']
            label = row['label'] if if_train else None

            if row_id not in seen_ids:
                # First time this id is seen, i.e. its newest turn.
                seen_ids.add(row_id)
                text_length = start_row(row_id, text, row['prompt'], row['response_a'], row['response_b'], label)
                continue

            text_length += len(text.split(" "))
            if text_length <= max_length:
                # Still fits: prepend this older turn to the current chunk.
                prompt_response[-1] = text + "\n\n" + prompt_response[-1]
                over_max_length[-1] = 0
                overflow_prompt[-1] = None
                overflow_response_a[-1] = None
                overflow_response_b[-1] = None
            else:
                # Would overflow: start a new row for the same id.
                text_length = start_row(row_id, text, row['prompt'], row['response_a'], row['response_b'], label)

        if if_train:
            result = pd.DataFrame({'id': ids, 'prompt_response': prompt_response, "label": labels, 'overflow_prompt': overflow_prompt, 'over_max_length': over_max_length, 'overflow_response_a': overflow_response_a, 'overflow_response_b': overflow_response_b})
        else:
            result = pd.DataFrame({'id': ids, 'prompt_response': prompt_response, 'over_max_length': over_max_length, 'overflow_prompt': overflow_prompt, 'overflow_response_a': overflow_response_a, 'overflow_response_b': overflow_response_b})
        return result.iloc[::-1].reset_index(drop=True)  # restore the original order

    test = prompt_3(test, cfg.max_length * 0.75, False)
    # Keep the last row per id: the one holding the conversation's newest turns.
    test = test.drop_duplicates(subset=['id'], keep='last').reset_index(drop=True)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(test) != original_length:
        raise RuntimeError(
            f"expected one row per conversation ({original_length}), got {len(test)}"
        )
    return test


# Build the test set with response_a / response_b swapped; call get_data()
# without arguments for the original order.
test = get_data(path='/kaggle/input/lmsys-chatbot-arena/test.csv', reverse=True)

def tokenize(tokenizer, data):
    '''
    Build the Llama-3.1 chat prompt for every row of `data` and tokenize it.

    Each prompt is: user header + instruction + prompt/response section
    + A/B/C options + "<|eot_id|>" + assistant header.

    Rows flagged ``over_max_length`` are rebuilt from their ``overflow_*``
    columns. If those parts together exceed ``cfg.max_length`` tokens, each
    part is truncated separately to the budget returned by ``adjust()``.
    Other rows use ``prompt_response`` truncated to ``cfg.max_length``
    tokens, on the side given by the tokenizer's ``truncation_side``.

    Args:
        tokenizer: the Llama tokenizer.
        data: output of get_data; rows are read with ``data.loc[i]`` for
            ``i in range(len(data))``, so a RangeIndex is expected.

    Returns:
        (input_ids, attention_mask): per-row lists from re-tokenizing the
        decoded prompt text.
    '''
    # The template pieces are the same for every row, so tokenize them once.
    templete_part1 = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHere are two question-answering dialogues. Compare two model performance on answering question, determine which is better.\n\n"
    # NOTE(review): if this tokenizer itself prepends <|begin_of_text|> when
    # add_special_tokens=True, the literal one above duplicates BOS, and
    # tokenizer(prompts) below may add another after decode. Left unchanged
    # because training may have used the same layout — verify before changing.
    templete_part1_input_ids = tokenizer(text=templete_part1, add_special_tokens=True, padding=False)['input_ids']

    templete_part2 = "\n###options\nA. Model A\nB. Model B\nC. Tie\n<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
    templete_part2_input_ids = tokenizer(text=templete_part2, add_special_tokens=False, padding=False)['input_ids']

    # "\n\n" separator appended after each separately-truncated part.
    templete_part4_input_ids = tokenizer(text="\n\n", add_special_tokens=False, padding=False)['input_ids']

    prompts = []
    for i in tqdm(range(len(data))):
        now_data = data.loc[i]

        if now_data['over_max_length']:
            prompt = "#Prompt\n" + now_data['overflow_prompt']
            r_a = "#Response\n" + "##Model A\n" + now_data['overflow_response_a']
            r_b = "##Model B\n" + now_data['overflow_response_b']

            prompt_ids = tokenizer(text=prompt, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_a_input_ids = tokenizer(text=r_a, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_b_input_ids = tokenizer(text=r_b, add_special_tokens=False, truncation=False, padding=False)['input_ids']

            if len(prompt_ids) + len(model_a_input_ids) + len(model_b_input_ids) <= cfg.max_length:
                # NOTE(review): no "\n\n" separator between the parts here,
                # unlike the truncated branch below.
                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids

            else:
                # Budget of 300 / 800 / 800 tokens for prompt / A / B.
                # Responses take priority; leftover budget goes to the prompt.
                length = [len(prompt_ids), len(model_a_input_ids), len(model_b_input_ids)]
                prompt_max_length, a_max_length, b_max_length = adjust(length)

                prompt_ids = prompt_ids[:prompt_max_length] + templete_part4_input_ids
                model_a_input_ids = model_a_input_ids[:a_max_length] + templete_part4_input_ids
                model_b_input_ids = model_b_input_ids[:b_max_length] + templete_part4_input_ids

                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids

        else:
            prompt_response = now_data['prompt_response']
            prompt_response_ids = tokenizer(text=prompt_response, add_special_tokens=False, truncation=True, max_length=cfg.max_length, padding=False)['input_ids']

        input_ids = templete_part1_input_ids + prompt_response_ids + templete_part2_input_ids
        input_text = tokenizer.decode(input_ids, skip_special_tokens=False)
        if i == 0:
            print(input_text)
        prompts.append(input_text)

    tokenized = tokenizer(prompts)
    return tokenized.input_ids, tokenized.attention_mask

def adjust_values(A, B, a_space, b_space, ex_space):
    '''
    Split a token budget between two responses of lengths A and B.

    Each response has its own allowance (``a_space`` / ``b_space``), and both
    share an extra pool ``ex_space``. Allowance left unused by a short
    response goes into the pool; an over-long response draws on the pool.
    Every combination is handled, including a response exactly at its
    allowance (previously A == a_space with B > b_space left B untruncated,
    and A < a_space with B == b_space dropped A's unused allowance).

    Returns:
        (A, B, ex_space): the number of tokens to keep for each response and
        what remains in the pool.
    '''
    a_over = A - a_space  # > 0: A needs extra tokens; <= 0: A leaves -a_over unused
    b_over = B - b_space

    # Both fit: all unused allowance goes back into the pool.
    if a_over <= 0 and b_over <= 0:
        return A, B, ex_space - a_over - b_over

    # Both too long.
    if a_over > 0 and b_over > 0:
        total_extra_needed = a_over + b_over
        if total_extra_needed > ex_space:
            # Not enough: split the pool evenly between the two responses.
            return int(a_space + ex_space / 2), int(b_space + ex_space / 2), 0
        return A, B, ex_space - total_extra_needed

    # Exactly one is too long: it gets the other's slack plus the pool.
    if a_over > 0:
        ex_space -= b_over  # b_over <= 0 here, so this adds B's slack
        if ex_space >= a_over:
            return A, B, ex_space - a_over
        return a_space + ex_space, B, 0

    ex_space -= a_over  # a_over <= 0 here, so this adds A's slack
    if ex_space >= b_over:
        return A, B, ex_space - b_over
    return A, b_space + ex_space, 0


def adjust(current_lengths, prompt_length_space=300, response_length_space=800):
    '''
    Compute per-part truncation lengths for (prompt, response_a, response_b).

    The prompt starts with ``prompt_length_space`` tokens and each response
    with ``response_length_space``. Unused prompt allowance is pooled for the
    responses (see adjust_values). Whatever the responses leave in the pool
    is then given back to the prompt.

    Args:
        current_lengths: [prompt_len, response_a_len, response_b_len] in tokens.

    Returns:
        (prompt_max, a_max, b_max) lengths to slice each part to.
    '''
    prompt_length, response_a_length, response_b_length = current_lengths[0], current_lengths[1], current_lengths[2]
    # Unused prompt allowance becomes the initial shared pool.
    ex_space = max(0, prompt_length_space - prompt_length)
    response_a_length, response_b_length, ex_space = adjust_values(
        response_a_length, response_b_length, response_length_space, response_length_space, ex_space
    )
    prompt_length = min(prompt_length, prompt_length_space) + ex_space
    return prompt_length, response_a_length, response_b_length

tokenizer = AutoTokenizer.from_pretrained(cfg.gemma_dir, trust_remote_code=True, truncation_side='left')
# Reuse the EOS token id for padding.
tokenizer.pad_token_id = tokenizer.eos_token_id

aug_data = pd.DataFrame({"id": test["id"]})
# swap response_a & response_b
aug_ids, aug_mask = tokenize(tokenizer, test)
aug_data['input_ids'] = aug_ids
aug_data['attention_mask'] = aug_mask
aug_data["length"] = aug_data["input_ids"].apply(len)

print(tokenizer.decode(aug_data["input_ids"][0]))

from transformers.cache_utils import Cache, DynamicCache, StaticCache
from typing import List, Optional, Tuple, Union
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast
)

from transformers import LlamaForCausalLM
class CustomLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM that only computes logits for the final sequence position.

    ``forward`` takes the decoder output at index ``-1`` of the sequence axis
    (``outputs[0][:, -1]``) and projects only that vector through ``lm_head``,
    so ``logits`` has one row per example instead of one row per token.
    Reading index ``-1`` is only meaningful when every row's last position is
    a real token, i.e. the batch is left-padded.
    NOTE(review): confirm the tokenizer used with this class pads on the left.

    Because only a single position is scored, a shifted next-token loss
    cannot be computed; passing ``labels`` raises ``NotImplementedError``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and return float32 logits for the last position only.

        Arguments mirror ``LlamaForCausalLM.forward``; ``labels`` must be ``None``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss=None``, or the tuple
            ``(logits,) + outputs[1:]`` when ``return_dict`` is false.

        Raises:
            NotImplementedError: if ``labels`` is given.
        """
        # Used by the tensor-parallel lm_head path below; imported locally so
        # the method does not rely on a module-level alias.
        import torch.nn.functional as F

        if labels is not None:
            # The previous shifted-loss code sliced ``logits[..., :-1, :]`` on
            # logits whose sequence axis had already been indexed away, which
            # dropped the last example and then failed on a size mismatch with
            # the flattened labels.
            raise NotImplementedError(
                "CustomLlamaForCausalLM only computes logits for the last "
                "position; passing `labels` is not supported."
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Keep only the final position of every sequence.
        hidden_states = outputs[0][:, -1]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        if not return_dict:
            return (logits,) + outputs[1:]

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
    
    
# Load base model on GPU 0 (4-bit NF4 weights with double quantization, fp16 compute).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
device_0 = torch.device('cuda:0')
# cfg.gemma_dir is the base-checkpoint path for this script (Config is defined above this chunk).
model_0 = CustomLlamaForCausalLM.from_pretrained(
    cfg.gemma_dir,
    device_map=device_0,
    use_cache=False,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
print(tokenizer.pad_token_id)
model_0.config.pad_token_id = tokenizer.pad_token_id
# NOTE(review): only model_0 is resized (the model_1 call below is commented
# out). If len(tokenizer) differs from the checkpoint's embedding size, the two
# replicas would no longer be identical — confirm this is a no-op here.
model_0.resize_token_embeddings(len(tokenizer))

# Load a second full replica on GPU 1 so the data can be scored in two parallel halves.

device_1 = torch.device('cuda:1')
model_1 = CustomLlamaForCausalLM.from_pretrained(
    cfg.gemma_dir,
    device_map=device_1,
    use_cache=False,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
model_1.config.pad_token_id = tokenizer.pad_token_id
#model_1.resize_token_embeddings(len(tokenizer))


# Apply the same LoRA adapter on top of each replica.
model_0 = PeftModel.from_pretrained(model_0, cfg.lora_dir)
model_1 = PeftModel.from_pretrained(model_1, cfg.lora_dir)

model_0.eval()
model_1.eval()

# Vocabulary ids of the answer letters; no special tokens are added, so the
# lists hold only the letters' own token ids.
A_TOKEN_IDS = tokenizer('A',add_special_tokens=False, truncation=True, max_length=1024)['input_ids']
B_TOKEN_IDS = tokenizer('B',add_special_tokens=False, truncation=True, max_length=1024)['input_ids']
C_TOKEN_IDS = tokenizer('C',add_special_tokens=False, truncation=True, max_length=1024)['input_ids']
print(A_TOKEN_IDS, B_TOKEN_IDS, C_TOKEN_IDS)

@torch.no_grad()
@torch.cuda.amp.autocast()
def inference(df, model, device, batch_size=cfg.batch_size):
    """Score each row of *df* with *model* and attach A / B / tie probabilities.

    Rows are processed in order, ``batch_size`` at a time. Each batch is
    padded to its longest member, moved to *device*, and run through *model*.
    The last-position logits at the answer-token ids are divided by 1.03 (a
    temperature) and soft-maxed over the three choices.

    NOTE(review): assumes A/B/C_TOKEN_IDS each hold exactly one id; otherwise
    columns 0..2 of ``proba`` would not map to A / B / tie.

    Args:
        df: frame with ``input_ids`` and ``attention_mask`` list columns.
        model: model whose ``logits`` output has one row per example.
        device: device the batch tensors are moved to before the forward pass.
        batch_size: rows per forward pass.

    Returns:
        *df* with ``winner_model_a``, ``winner_model_b`` and ``winner_tie``
        columns added (the frame is modified in place and returned).
    """
    a_win, b_win, tie = [], [], []
    choice_ids = A_TOKEN_IDS + B_TOKEN_IDS + C_TOKEN_IDS

    for start_idx in range(0, len(df), batch_size):
        end_idx = min(start_idx + batch_size, len(df))
        tmp = df.iloc[start_idx:end_idx]
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": tmp["input_ids"].to_list(), "attention_mask": tmp["attention_mask"].to_list()},
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )

        # Move the batch to this worker's GPU explicitly (as the gemma-2b
        # script does) instead of relying on dispatch hooks; `device` was
        # previously unused.
        outputs = model(**inputs.to(device))
        proba = (outputs.logits.cpu()[:, choice_ids] / 1.03).softmax(-1)
        a_win.extend(proba[:, 0].tolist())
        b_win.extend(proba[:, 1].tolist())
        tie.extend(proba[:, 2].tolist())

    df["winner_model_a"] = a_win
    df["winner_model_b"] = b_win
    df["winner_tie"] = tie

    return df

st = time.time()

data = aug_data.sort_values("length", ascending=False)  # sort by input length to boost speed
# Interleave rows between the two GPUs so both halves get a similar length mix.
sub_1 = data.iloc[0::2].copy()
sub_2 = data.iloc[1::2].copy()

# One thread per GPU replica; each call scores its half on its own device.
with ThreadPoolExecutor(max_workers=2) as executor:
    results = executor.map(inference, (sub_1, sub_2), (model_0, model_1), (device_0, device_1))

tta_result_df = pd.concat(list(results), axis=0)
tta_result_df = tta_result_df.sort_values('id')

proba = tta_result_df[["winner_model_a", "winner_model_b", "winner_tie"]].values


# Columns are written back unchanged (no A/B swap). If this run used swapped
# responses (TTA), the saved file stays in swapped orientation and must be
# un-swapped by whoever reads it.
tta_result_df.loc[:, "winner_model_a"] = proba[:, 0]
tta_result_df.loc[:, "winner_model_b"] = proba[:, 1]
tta_result_df.loc[:, "winner_tie"] = proba[:, 2]
submission_df = tta_result_df[["id", 'winner_model_a', 'winner_model_b', 'winner_tie']]
submission_df = submission_df.sort_values('id')
submission_df.to_csv('submission_llama_tta.csv', index=False)

print(f"elapsed time: {time.time() - st}")

In [None]:
%%writefile gemma2b_inference.py

import time
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
import json
import torch
import sklearn
import numpy as np
import pandas as pd
from transformers import  GemmaTokenizerFast, BitsAndBytesConfig, Gemma2ForCausalLM
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from peft import PeftModel

# The script places one model replica on each of cuda:0 and cuda:1.
assert torch.cuda.device_count() == 2

# Run configuration. These attributes have no type annotations, so @dataclass
# treats them as plain class attributes, not init fields.
@dataclass
class Config:
    gemma_dir = '/kaggle/input/gemma-2/transformers/gemma-2-2b-it/1'  # base model / tokenizer
    lora_dir = '/kaggle/input/exp58-gemma2b-9143/checkpoint-5400'  # LoRA adapter checkpoint
    max_length = 2200  # token limit for the conversation part of each input
    batch_size = 4
    device = torch.device("cuda")    # NOTE(review): not read anywhere in this script
    tta = True  # test time augmentation. <prompt>-<model-b's response>-<model-a's response>
    spread_max_length = False  # whether to apply max_length//3 on each input or max_length on the concatenated input; NOTE(review): not read in this script

cfg = Config()

# Load & pre-process Data 


from tqdm import tqdm

test = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/test.csv')

# With 10 or fewer test rows, substitute the first 100 rows of train.csv
# (presumably so a local/debug run exercises the pipeline on more data).
if len(test)<= 10:
    test = pd.read_csv('/kaggle/input/lmsys-chatbot-arena/train.csv', nrows=100)

    
# TTA: swap the two responses so the model sees them in reversed order; the
# predicted A/B probabilities are swapped back before the CSV is written.
if cfg.tta:
    test['response_a'], test['response_b'] = test['response_b'], test['response_a']
    
def process(input_str):
    """Decode one JSON-encoded cell and return the resulting Python object."""
    decoded = json.loads(input_str)
    return decoded
original_length = len(test)  # number of conversations; re-checked after chunking
# Decode the JSON-encoded cells (presumably one list entry per turn).
test.loc[:, 'prompt'] = test['prompt'].apply(process)
test.loc[:, 'response_a'] = test['response_a'].apply(process)
test.loc[:, 'response_b'] = test['response_b'].apply(process)

# One row per turn. Exploding several columns at once requires each row's
# lists to have matching lengths.
test = test.explode(['prompt','response_a','response_b']).reset_index(drop=True)
# Nulls (including NaN that explode yields for empty lists) become the string
# 'None', and so do empty response strings.
test = test.fillna('None')
test['response_a'] = test['response_a'].apply(lambda x: 'None' if len(x) == 0 else x)
test['response_b'] = test['response_b'].apply(lambda x: 'None' if len(x) == 0 else x)

def get_text_length(text):
    '''
    不用空格分隔的文本, text length = len
    不用空格分隔的一般tokenizer后长度类似，所以还可以缩小
    空格分隔的，len(text.split(" "))
    '''
    length1 = len(text)
    length2 = len(text.split(" "))
    #远超过
    if length1 >= length2 * 30 and length1>= 300:
        return length1 * 0.75
    return length2
    
def prompt_3(data, max_length, if_train):
    """Pack each conversation's turns into text chunks of bounded estimated length.

    Every input row is one turn, rendered as::

        #Prompt
        <prompt>

        #Response
        ##Model A
        <response_a>

        ##Model B
        <response_b>

    Rows are walked from last to first. The first turn seen for an ``id``
    starts a chunk. Each earlier turn of that ``id`` is prepended to the
    current chunk, separated by a blank line, while the running length
    (estimated with ``get_text_length``) stays within ``max_length``. When it
    no longer fits, a new chunk is started for the same ``id`` (carrying the
    same label when ``if_train``).

    A chunk whose starting turn alone exceeds ``max_length`` is flagged with
    ``over_max_length = 1`` and keeps that turn's raw prompt and responses in
    the ``overflow_*`` columns, so each part can be truncated separately later.

    Args:
        data: frame with ``id``, ``prompt``, ``response_a``, ``response_b``
            (plus ``label`` when ``if_train``). It is not modified.
            NOTE(review): rows of one ``id`` are assumed to be contiguous, since
            merging always targets the most recent chunk (true after ``explode``).
        max_length: chunk budget in ``get_text_length`` units.
        if_train: whether to read and carry over the ``label`` column.

    Returns:
        One row per chunk, in the original (un-reversed) row order, with
        columns ``id``, ``prompt_response``, [``label``], ``over_max_length``,
        ``overflow_prompt``, ``overflow_response_a``, ``overflow_response_b``.
    """
    # Build the per-turn text on a copy so the caller's frame is not mutated.
    data = data.assign(
        prompt_response="#Prompt\n" + data['prompt'] + "\n\n" + "#Response\n" + "##Model A\n" + data['response_a'] + "\n\n" + "##Model B\n" + data['response_b']
    )
    data = data.iloc[::-1].reset_index(drop=True)  # walk turns from last to first

    prompt_response = []
    ids = []
    labels = []
    over_max_length = []  # 1 if the chunk's starting turn alone exceeds max_length
    overflow_prompt = []  # raw sections, kept only for over-length chunks
    overflow_response_a = []
    overflow_response_b = []
    # Set mirror of `ids` for O(1) membership tests; scanning the growing
    # `ids` list for every row made this quadratic in the number of turns.
    seen_ids = set()

    def start_chunk(text, conv_id, label, prompt, response_a, response_b):
        """Append a new chunk for ``conv_id`` and return its estimated length."""
        prompt_response.append(text)
        ids.append(conv_id)
        if if_train:
            labels.append(label)
        length = get_text_length(text)
        overflows = length > max_length
        over_max_length.append(1 if overflows else 0)
        overflow_prompt.append(prompt if overflows else None)
        overflow_response_a.append(response_a if overflows else None)
        overflow_response_b.append(response_b if overflows else None)
        return length

    text_length = 0
    for _, row in tqdm(data.iterrows(), total=len(data)):
        text = row['prompt_response']
        conv_id = row['id']
        label = row['label'] if if_train else None

        if conv_id not in seen_ids:
            # First (i.e. last in time) turn of this conversation.
            seen_ids.add(conv_id)
            text_length = start_chunk(text, conv_id, label, row['prompt'], row['response_a'], row['response_b'])
            continue

        text_length += get_text_length(text)
        if text_length <= max_length:
            # The earlier turn still fits: prepend it to the current chunk.
            prompt_response[-1] = text + "\n\n" + prompt_response[-1]
            over_max_length[-1] = 0
            overflow_prompt[-1] = None
            overflow_response_a[-1] = None
            overflow_response_b[-1] = None
        else:
            # Budget exhausted: start a new chunk for the same conversation.
            text_length = start_chunk(text, conv_id, label, row['prompt'], row['response_a'], row['response_b'])

    # Column order matches the historical output of each mode.
    columns = {'id': ids, 'prompt_response': prompt_response}
    if if_train:
        columns['label'] = labels
        columns['overflow_prompt'] = overflow_prompt
        columns['over_max_length'] = over_max_length
    else:
        columns['over_max_length'] = over_max_length
        columns['overflow_prompt'] = overflow_prompt
    columns['overflow_response_a'] = overflow_response_a
    columns['overflow_response_b'] = overflow_response_b

    # Undo the initial reversal so chunks come back in original row order.
    return pd.DataFrame(columns).iloc[::-1].reset_index(drop=True)

# Chunk budget is 75% of the token limit, measured with the get_text_length heuristic.
test = prompt_3(test, cfg.max_length * 0.75, False)
# prompt_3 builds chunks from the last turn backwards and then reverses them,
# so keep='last' retains, per id, the chunk containing the final turn.
test = test.drop_duplicates(subset=['id'], keep='last').reset_index(drop=True)
# Exactly one row per original conversation must remain.
assert len(test) == original_length

# tokenize

def tokenize(
    tokenizer, data
):
    """Wrap each row of *data* in the chat template and tokenize the results.

    Each row's conversation tokens are placed between a fixed user
    instruction and the options/answer-turn suffix:

        <start_of_turn>user  ...instruction...
        <conversation tokens>
        ###options  A. Model A / B. Model B / C. Tie  <end_of_turn>
        <start_of_turn>model

    Rows flagged ``over_max_length`` are rebuilt from their raw
    ``overflow_prompt`` / ``overflow_response_*`` fields. If those together
    exceed ``cfg.max_length`` tokens, each section is cut to the budget
    returned by ``adjust``. Other rows use ``prompt_response``, truncated to
    ``cfg.max_length`` tokens.

    The assembled ids are decoded back to text (without their first id), and
    all strings are re-tokenized in a single batch call.

    Args:
        tokenizer: tokenizer used for both encoding and decoding.
        data: output of ``prompt_3`` with a 0..n-1 index (rows are read with
            ``data.loc[i]``).

    Returns:
        ``(input_ids, attention_mask)`` from ``tokenizer(prompts)``.
    """
    # The template pieces do not depend on the row, so tokenize them once
    # instead of four extra tokenizer calls per row.
    templete_part1 = "<start_of_turn>user\nHere are two question-answering dialogues. Compare two model performance on answering question, determine which is better.\n\n"
    templete_part1_input_ids = tokenizer(text=templete_part1, add_special_tokens=True, padding=False)['input_ids']

    templete_part2 = "\n###options\nA. Model A\nB. Model B\nC. Tie\n<end_of_turn>\n"
    # [1:] drops the first id that add_special_tokens=True prepends (presumably BOS).
    templete_part2_input_ids = tokenizer(text=templete_part2, add_special_tokens=True, padding=False)['input_ids'][1:]

    templete_part3 = "<start_of_turn>model\n"
    templete_part3_input_ids = tokenizer(text=templete_part3, add_special_tokens=True, padding=False)['input_ids'][1:]

    # Separator appended after each section in the truncated branch.
    templete_part4_input_ids = tokenizer(text="\n\n", add_special_tokens=False, padding=False)['input_ids']

    prompts = []
    for i in tqdm(range(len(data))):
        now_data = data.loc[i]

        if now_data['over_max_length']:
            prompt = "#Prompt\n" + now_data['overflow_prompt']
            r_a = "#Response\n" + "##Model A\n" + now_data['overflow_response_a']
            r_b = "##Model B\n" + now_data['overflow_response_b']

            prompt_ids = tokenizer(text=prompt, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_a_input_ids = tokenizer(text=r_a, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_b_input_ids = tokenizer(text=r_b, add_special_tokens=False, truncation=False, padding=False)['input_ids']

            if len(prompt_ids) + len(model_a_input_ids) + len(model_b_input_ids) <= cfg.max_length:
                # NOTE(review): unlike the truncated branch below, no "\n\n"
                # separator is inserted between the sections here. Left as-is
                # because the adapter was presumably trained on inputs built the
                # same way — confirm before changing.
                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids
            else:
                # Split the budget via `adjust` (prompt 300, each response 950 by
                # default in this script): responses take priority and any unused
                # budget flows back to the prompt.
                length = [len(prompt_ids), len(model_a_input_ids), len(model_b_input_ids)]
                print(f"before {len(prompt_ids) + len(model_a_input_ids) + len(model_b_input_ids)}")
                print(f"before {length}")
                prompt_max_length, a_max_length, b_max_length = adjust(length)

                prompt_ids = prompt_ids[:prompt_max_length] + templete_part4_input_ids
                model_a_input_ids = model_a_input_ids[:a_max_length] + templete_part4_input_ids
                model_b_input_ids = model_b_input_ids[:b_max_length] + templete_part4_input_ids

                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids
                print(f"after {[prompt_max_length, a_max_length, b_max_length]}")
                print(f"after {len(prompt_response_ids)}")

        else:
            prompt_response = now_data['prompt_response']
            prompt_response_ids = tokenizer(text=prompt_response, add_special_tokens=True, truncation=True, max_length=cfg.max_length, padding=False)['input_ids'][1:]

        input_ids = templete_part1_input_ids + prompt_response_ids + templete_part2_input_ids + templete_part3_input_ids
        # Drop the first id before decoding; the batch tokenizer call below
        # re-adds special tokens.
        input_text = tokenizer.decode(input_ids[1:], skip_special_tokens=False)
        if i == 0:
            print(input_text)
        prompts.append(input_text)

    tokenized = tokenizer(prompts)
    return tokenized.input_ids, tokenized.attention_mask

def adjust_values(A, B, a_space, b_space, ex_space):
    """Share a token budget between two responses of lengths ``A`` and ``B``.

    Each response has its own allowance (``a_space`` / ``b_space``) plus access
    to a shared pool ``ex_space``. Allowance a response does not use is added
    to the pool, and the pool is then spent on whichever response(s) exceed
    their allowance.

    Args:
        A, B: current token counts of the two responses.
        a_space, b_space: per-response allowances.
        ex_space: spare tokens usable by either response.

    Returns:
        ``(A, B, ex_space)``: the length budgets for the two responses (never
        above what they need) and the pool left over afterwards.
    """
    # Case 1: both fit within their allowances; all slack joins the pool.
    # (`<=` so a response sitting exactly at its allowance is handled here
    # instead of falling through unbudgeted.)
    if A <= a_space and B <= b_space:
        ex_space += (a_space - A) + (b_space - B)
        return A, B, ex_space

    # Case 2: both exceed their allowances.
    if A > a_space and B > b_space:
        total_extra_needed = (A - a_space) + (B - b_space)
        if total_extra_needed > ex_space:
            # NOTE(review): the pool is split evenly even if one response needs
            # less than half of it; that response's unused share is not passed
            # to the other. Kept to match the existing allocation.
            A = int(a_space + ex_space / 2)
            B = int(b_space + ex_space / 2)
            ex_space = 0
        else:
            ex_space -= total_extra_needed
        return A, B, ex_space

    # Case 3: exactly one response exceeds its allowance. The other's slack
    # (possibly zero) joins the pool, which is then spent on the long one.
    if A > a_space:
        extra_needed = A - a_space
        ex_space += b_space - B
        if ex_space >= extra_needed:
            ex_space -= extra_needed
        else:
            A = a_space + ex_space
            ex_space = 0
    else:
        extra_needed = B - b_space
        ex_space += a_space - A
        if ex_space >= extra_needed:
            ex_space -= extra_needed
        else:
            B = b_space + ex_space
            ex_space = 0
    return A, B, ex_space
    

def adjust(current_lengths, prompt_length_space=300, response_length_space=950):
    """Split a token budget between a prompt and its two responses.

    ``current_lengths`` holds the token counts ``(prompt, response_a,
    response_b)`` at positions 0, 1 and 2. The prompt is allowed
    ``prompt_length_space`` tokens and each response ``response_length_space``.
    Budget the prompt leaves unused is offered to the responses through
    ``adjust_values``, and any budget still unspent afterwards goes back to
    the prompt.

    Returns:
        ``(prompt_budget, response_a_budget, response_b_budget)``, used as
        slice lengths for the three sections.
    """
    prompt_len = current_lengths[0]
    resp_a_len = current_lengths[1]
    resp_b_len = current_lengths[2]

    # Unused prompt allowance becomes spare capacity for the responses.
    spare = max(0, prompt_length_space - prompt_len)
    resp_a_len, resp_b_len, spare = adjust_values(
        resp_a_len, resp_b_len, response_length_space, response_length_space, spare
    )

    # Cap the prompt at its allowance, then give it whatever the responses left.
    prompt_budget = min(prompt_len, prompt_length_space) + spare
    return prompt_budget, resp_a_len, resp_b_len


tokenizer = GemmaTokenizerFast.from_pretrained(cfg.gemma_dir)
# tokenizer.add_eos_token = True
# Left padding keeps each row's last real token at position -1, which is the
# only position CustomGemma2ForCausalLM scores.
tokenizer.padding_side = "left"

data = pd.DataFrame()
data["id"] = test["id"]
data["input_ids"], data["attention_mask"] = tokenize(tokenizer, test)
# Token count per row, used to sort batches by length.
data["length"] = data["input_ids"].apply(len)


# Spot-check the first fully formatted input.
print(tokenizer.decode(data["input_ids"][0]))

from transformers.cache_utils import Cache, DynamicCache, StaticCache
from typing import List, Optional, Tuple, Union
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast
)

class CustomGemma2ForCausalLM(Gemma2ForCausalLM):
    """Gemma2 causal LM that only computes logits for the final sequence position.

    ``forward`` takes the decoder output at index ``-1`` of the sequence axis
    (``outputs[0][:, -1]``), projects it through ``lm_head`` and applies
    Gemma2's final-logit soft-capping, so ``logits`` has one row per example
    instead of one row per token. This relies on the batch being left-padded
    (this script sets ``tokenizer.padding_side = "left"``).

    Because only a single position is scored, a shifted next-token loss
    cannot be computed; passing ``labels`` raises ``NotImplementedError``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and return float32 logits for the last position only.

        Arguments mirror ``Gemma2ForCausalLM.forward``; ``labels`` must be ``None``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss=None``, or the tuple
            ``(logits,) + outputs[1:]`` when ``return_dict`` is false.

        Raises:
            NotImplementedError: if ``labels`` is given.
        """
        if labels is not None:
            # The previous shifted-loss code used an undefined CrossEntropyLoss
            # and sliced ``logits[..., :-1, :]`` on logits whose sequence axis
            # had already been indexed away, so it could never succeed.
            raise NotImplementedError(
                "CustomGemma2ForCausalLM only computes logits for the last "
                "position; passing `labels` is not supported."
            )

        if self.training and self.config._attn_implementation != "eager":
            # No module-level `logger` exists in this script (the previous code
            # raised NameError here), so fetch the transformers logger locally.
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        # Keep only the final position of every sequence.
        hidden_states = outputs[0][:, -1]
        logits = self.lm_head(hidden_states)
        if self.config.final_logit_softcapping is not None:
            # Soft-cap: cap * tanh(logits / cap).
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        logits = logits.float()

        if not return_dict:
            return (logits,) + outputs[1:]

        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
# Load base model on GPU 0 (4-bit NF4 weights with double quantization, fp16 compute).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
device_0 = torch.device('cuda:0')
model_0 = CustomGemma2ForCausalLM.from_pretrained(
    cfg.gemma_dir,
    device_map=device_0,
    use_cache=False,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

# Load a second full replica on GPU 1 so the data can be scored in two parallel halves.

device_1 = torch.device('cuda:1')
model_1 = CustomGemma2ForCausalLM.from_pretrained(
    cfg.gemma_dir,
    device_map=device_1,
    use_cache=False,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

# Apply the same LoRA adapter on top of each replica.
model_0 = PeftModel.from_pretrained(model_0, cfg.lora_dir)
model_1 = PeftModel.from_pretrained(model_1, cfg.lora_dir)

model_0.eval()
model_1.eval()

@torch.no_grad()
@torch.cuda.amp.autocast()
def inference_test(df, model, device, batch_size=cfg.batch_size):
    """Debug helper: run *model* over *df* in batches and return the last raw output.

    Every batch is padded to its longest row, moved to *device* and run
    through *model*. Only the final batch's output object is returned; earlier
    outputs are discarded.

    Returns:
        The model output for the last batch, or ``None`` if *df* is empty
        (previously an empty frame raised ``UnboundLocalError``).
    """
    outputs = None
    for start_idx in range(0, len(df), batch_size):
        end_idx = min(start_idx + batch_size, len(df))
        tmp = df.iloc[start_idx:end_idx]
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": tmp["input_ids"].to_list(), "attention_mask": tmp["attention_mask"].to_list()},
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )

        outputs = model(**inputs.to(device))

    return outputs

# Vocabulary ids of the answer letters. Special tokens are added and the first
# id (presumably BOS) is sliced off, leaving only each letter's own id(s).
A_TOKEN_IDS, B_TOKEN_IDS, C_TOKEN_IDS = (
    tokenizer(letter, add_special_tokens=True, truncation=True, max_length=1024)['input_ids'][1:]
    for letter in ('A', 'B', 'C')
)

@torch.no_grad()
@torch.cuda.amp.autocast()
def inference(df, model, device, batch_size=cfg.batch_size):
    """Attach A-wins / B-wins / tie probabilities to every row of *df*.

    Rows are scored ``batch_size`` at a time. Each chunk is padded to its
    longest member and moved to *device*. The model's logits at the A/B/C
    answer-token ids are divided by 1.03 (a temperature) and soft-maxed over
    the three choices.

    Returns:
        *df* itself, with ``winner_model_a``, ``winner_model_b`` and
        ``winner_tie`` columns assigned.
    """
    choice_ids = A_TOKEN_IDS + B_TOKEN_IDS + C_TOKEN_IDS
    # Insertion order (a, b, tie) matches the column order of the softmax output.
    collected = {"winner_model_a": [], "winner_model_b": [], "winner_tie": []}

    for offset in range(0, len(df), batch_size):
        chunk = df.iloc[offset:min(offset + batch_size, len(df))]
        padded = pad_without_fast_tokenizer_warning(
            tokenizer,
            {
                "input_ids": chunk["input_ids"].to_list(),
                "attention_mask": chunk["attention_mask"].to_list(),
            },
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        logits = model(**padded.to(device)).logits.cpu()
        scores = (logits[:, choice_ids] / 1.03).softmax(-1)
        for position, column in enumerate(collected):
            collected[column].extend(scores[:, position].tolist())

    for column, values in collected.items():
        df[column] = values
    return df


st = time.time()

# sort by input length to fully leverage dynamic padding
data = data.sort_values("length", ascending=False)

# Score only the longest half of the conversations with this model; the
# ensemble step later blends gemma-2b only for the ids present in its output.
data = data[:int(len(data) * 0.5)]
# the total #tokens in sub_1 and sub_2 should be more or less the same
sub_1 = data.iloc[0::2].copy()
sub_2 = data.iloc[1::2].copy()

# One thread per GPU replica; each call scores its half on its own device.
with ThreadPoolExecutor(max_workers=2) as executor:
    results = executor.map(inference, (sub_1, sub_2), (model_0, model_1), (device_0, device_1))

result_df = pd.concat(list(results), axis=0)
proba = result_df[["winner_model_a", "winner_model_b", "winner_tie"]].values

print(f"elapsed time: {time.time() - st}")

# With TTA the responses were swapped at load time, so swap the A/B
# probabilities back; the tie probability is unaffected.
if cfg.tta:
    result_df.loc[:, "winner_model_a"] = proba[:, 1]
    result_df.loc[:, "winner_model_b"] = proba[:, 0]
    result_df.loc[:, "winner_tie"] = proba[:, 2]
else:
    result_df.loc[:, "winner_model_a"] = proba[:, 0]
    result_df.loc[:, "winner_model_b"] = proba[:, 1]
    result_df.loc[:, "winner_tie"] = proba[:, 2]
submission_df = result_df[["id", 'winner_model_a', 'winner_model_b', 'winner_tie']]
submission_df = submission_df.sort_values('id')
submission_df.to_csv('submission_gemma2b.csv', index=False)

In [None]:
!python llama_inference.py

In [None]:
!python gemma_inference.py

In [None]:
!python gemma2b_inference.py

In [None]:
# Debug cell; only prints 1.
print(1)

In [None]:
import pandas as pd
# Per-model submissions written by the three inference scripts above.
llama_result = pd.read_csv('/kaggle/working/submission_llama_tta.csv')
gemma_result = pd.read_csv('/kaggle/working/submission_gemma.csv')
gemma2b_result = pd.read_csv('/kaggle/working/submission_gemma2b.csv')

# The llama script writes its A/B columns without swapping them back;
# presumably it ran with swapped responses (TTA, per the file name), so swap
# here to restore the original orientation — confirm against its Config.
llama_result['winner_model_a'], llama_result['winner_model_b'] = llama_result['winner_model_b'], llama_result['winner_model_a']
# Ids scored by gemma-2b (the longest half of the conversations).
tta_ids = list(gemma2b_result.id.values)

In [None]:
# Preview the llama predictions for the ids that gemma-2b also scored.
llama_result[llama_result.id.isin(tta_ids)].reset_index(drop=True)

In [None]:
# Preview the gemma-2b predictions.
gemma2b_result

In [None]:
# Preview the gemma predictions.
gemma_result

In [None]:
# Align every model's predictions by conversation id rather than by row
# position, so the blend cannot silently pair different conversations if a
# submission file is ordered differently. A missing id raises KeyError.
pred_cols = ["winner_model_a", "winner_model_b", "winner_tie"]
llama_by_id = llama_result.set_index("id")
gemma_by_id = gemma_result.set_index("id")
# Non-TTA ids in gemma_result's row order (the same rows final_b is built from).
non_tta_ids = gemma_result.loc[~gemma_result.id.isin(tta_ids), "id"].tolist()

# Ids scored by all three models, in gemma2b_result's row order.
gemma2b = gemma2b_result[pred_cols].values
llama_a = llama_by_id.loc[tta_ids, pred_cols].values
gemma_a = gemma_by_id.loc[tta_ids, pred_cols].values

# Remaining ids, scored by llama and gemma only.
llama_b = llama_by_id.loc[non_tta_ids, pred_cols].values
gemma_b = gemma_by_id.loc[non_tta_ids, pred_cols].values

In [None]:
# The three prediction sets for the TTA ids must have matching shapes.
print(gemma_a.shape, llama_a.shape, gemma2b.shape)
# Weighted blends. Each weight set sums to 1, so rows remain probability distributions.
# TTA ids (scored by all three models):
proba_a = gemma_a * 0.55 + llama_a * 0.4 + gemma2b * 0.05
# Remaining ids (not scored by gemma-2b):
proba_b = gemma_b * 0.55 + llama_b * 0.45
print(proba_a.shape, proba_b.shape)

In [None]:
%%time
import copy
# Start from copies so the loaded result frames are left untouched.
final_a = copy.deepcopy(gemma2b_result)
final_b = copy.deepcopy(gemma_result[~gemma_result.id.isin(tta_ids)])

# proba_a / proba_b are plain arrays written back by position: their row order
# must match final_a (gemma2b_result order) and final_b (gemma_result order,
# non-TTA ids only).
final_a.loc[:, "winner_model_a"] = proba_a[:, 0]
final_a.loc[:, "winner_model_b"] = proba_a[:, 1]
final_a.loc[:, "winner_tie"] = proba_a[:, 2]


final_b.loc[:, "winner_model_a"] = proba_b[:, 0]
final_b.loc[:, "winner_model_b"] = proba_b[:, 1]
final_b.loc[:, "winner_tie"] = proba_b[:, 2]

# Recombine both subsets into one id-sorted submission.
submission_df = pd.concat([final_a, final_b]).sort_values("id").reset_index(drop=True)[["id", 'winner_model_a', 'winner_model_b', 'winner_tie']]
submission_df.to_csv('submission.csv', index=False)
display(submission_df)