In [1]:
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model, AutoTokenizer, AutoModel, AutoConfig, Phi3Config
from data_provider.data_split_recorder import split_recorder
import json
import torch
import pickle
from torch import nn
from Prompts.Mapping_helper import Mapping_helper
import os
import numpy as np
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = '2'

## 统计label prompt的最大长度

In [2]:
def read_cell_data_according_to_prefix(data_root_path, file_name):
    '''
    Read the battery data and eol according to the file_name
    The dataset is indicated by the prefix of the file_name
    '''
    prefix = file_name.split('_')[0]
    
    if prefix == 'MICH':
        with open(f'{data_root_path}/total_MICH_labels.json') as f:
            life_labels = json.load(f)
    elif prefix.startswith('Tongji'):
        with open(f'{data_root_path}/Tongji_labels.json') as f:
            life_labels = json.load(f)
    else:
        with open(f'{data_root_path}/{prefix}_labels.json') as f:
            life_labels = json.load(f)
    if file_name in life_labels:
        eol = life_labels[file_name]
    else:
        eol = None
    return eol

In [6]:
dataset_name = 'HUST'
data_root_path = '/data/trf/python_works/Battery-LLM/dataset'

if dataset_name == 'Tongji':
    train_files = split_recorder.Tongji_train_files
    val_files = split_recorder.Tongji_val_files
    test_files = split_recorder.Tongji_test_files
elif dataset_name == 'HUST':
    train_files = split_recorder.HUST_train_files
    val_files = split_recorder.HUST_val_files
    test_files = split_recorder.HUST_test_files
elif dataset_name == 'MATR':
    train_files = split_recorder.MATR_train_files
    val_files = split_recorder.MATR_val_files
    test_files = split_recorder.MATR_test_files
elif dataset_name == 'SNL':
    train_files = split_recorder.SNL_train_files
    val_files = split_recorder.SNL_val_files
    test_files = split_recorder.SNL_test_files
elif dataset_name == 'MICH':
    train_files = split_recorder.MICH_train_files
    val_files = split_recorder.MICH_val_files
    test_files = split_recorder.MICH_test_files
elif dataset_name == 'MICH_EXP':
    train_files = split_recorder.MICH_EXP_train_files
    val_files = split_recorder.MICH_EXP_val_files
    test_files = split_recorder.MICH_EXP_test_files
elif dataset_name == 'UL_PUR':
    train_files = split_recorder.UL_PUR_train_files
    val_files = split_recorder.UL_PUR_val_files
    test_files = split_recorder.UL_PUR_test_files
elif dataset_name == 'RWTH':
    train_files = split_recorder.RWTH_train_files
    val_files = split_recorder.RWTH_val_files
    test_files = split_recorder.RWTH_test_files
elif dataset_name == 'MIX':
    train_files = split_recorder.MIX_train_files
    val_files = split_recorder.MIX_val_files 
    test_files = split_recorder.MIX_test_files
elif dataset_name == 'HNEI':
    train_files = split_recorder.HNEI_train_files
    val_files = split_recorder.HNEI_val_files
    test_files = split_recorder.HNEI_test_files
elif dataset_name == 'CALCE':
    train_files = split_recorder.CALCE_train_files
    val_files = split_recorder.CALCE_val_files
    test_files = split_recorder.CALCE_test_files
elif dataset_name == 'Stanford':
    train_files = split_recorder.Stanford_train_files
    val_files = split_recorder.Stanford_val_files
    test_files = split_recorder.Stanford_test_files
elif dataset_name == 'ISU_ILCC':
    train_files = split_recorder.ISU_ILCC_train_files
    val_files = split_recorder.ISU_ILCC_val_files
    test_files = split_recorder.ISU_ILCC_test_files
elif dataset_name == 'MIX_small':
    train_files = split_recorder.MIX_small_train_files
    val_files = split_recorder.MIX_small_val_files 
    test_files = split_recorder.MIX_small_test_files
elif dataset_name == 'MIX_large':
    train_files = split_recorder.MIX_large_train_files
    val_files = split_recorder.MIX_large_val_files 
    test_files = split_recorder.MIX_large_test_files
elif dataset_name == 'MIX_c':
    train_files = split_recorder.MIX_train_files_complete
    val_files = split_recorder.MIX_val_files_complete 
    test_files = split_recorder.MIX_test_files_complete  
elif dataset_name == 'MIX_c_woISU':
    train_files = split_recorder.MIX_train_files_complete_woISU
    val_files = split_recorder.MIX_val_files_complete_woISU 
    test_files = split_recorder.MIX_test_files_complete_woISU  
else:
    raise Exception('Not implemented')

# '/data/LLMs/models--Qwen--Qwen2.5-7B-Instruct/snapshots/bb46c15ee4bb56c5b63245ef50fd7637234d6f75'
# /data/LLMs/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659
# /data/LLMs/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b
# /data/LLMs/llama2-hf-7b
LLM_path = '/data/LLMs/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659'
tokenizer = AutoTokenizer.from_pretrained(LLM_path)
label_prompt_lengths = []
total_results = []
for i, file_name in enumerate(train_files):
    eol = read_cell_data_according_to_prefix(data_root_path, file_name)
    if not eol:
        continue
    if i == 0:
        label_prompt = f'0'
    else:
        label_prompt = '1'
    if 'Instruct' in LLM_path:
        messages = [
            {"role": "system", "content": "You are an expert in predicting battery cycle life."},
            {"role": "user", "content": label_prompt}
        ]
        label_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
    else:
        label_prompt = '<|begin_of_text|>' + label_prompt + '<|end_of_text|>'
    # print(f"{label_prompt}")
    res = tokenizer(label_prompt, return_tensors="pt")
    input_ids, attention_mask = res['input_ids'][0], res['attention_mask'][0]
    print(len(input_ids), input_ids)
    total_results.append(res)
    length = len(input_ids)
    label_prompt_lengths.append(length)
    if i == 1:
        break

# # print(max(label_prompt_lengths))
# end_of_the_prompt = ' Here is description about the cycling records of the battery. You should combine information about the battery specification and operating condition as well as cycling records to do the prediction.'
# # end_of_the_prompt = ' Here is description about the cycling records of the battery. You should combine information about the battery specification and operating condition as well as cycling records to do the prediction.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'+\
# #                     'The cycle life is'
# print(end_of_the_prompt)
# res = tokenizer(end_of_the_prompt, return_tensors="pt")['input_ids'][0]
# print(len(res), res)

43 tensor([128000, 128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,
          2696,     25,   6790,    220,   2366,     18,    198,  15724,   2696,
            25,    220,   1627,  10263,    220,   2366,     19,    271,   2675,
           527,    459,   6335,    304,  52997,  11863,  11008,   2324,     13,
        128009, 128006,    882, 128007,    271,     15, 128009])
43 tensor([128000, 128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,
          2696,     25,   6790,    220,   2366,     18,    198,  15724,   2696,
            25,    220,   1627,  10263,    220,   2366,     19,    271,   2675,
           527,    459,   6335,    304,  52997,  11863,  11008,   2324,     13,
        128009, 128006,    882, 128007,    271,     16, 128009])


In [6]:
if 'Instruct' in LLM_path:
    # end_of_the_prompt = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
    # end_of_the_prompt = '<|start_header_id|>assistant<|end_header_id|>\n\n'
    end_of_the_prompt = 'Predict battery cycle life'
else:
    # Llama
    end_of_the_prompt = '<|end_of_text|>'

res = tokenizer(end_of_the_prompt, return_tensors="pt", truncation=True)
print(len(res['input_ids'][0]),res['input_ids'][0])

5 tensor([128000,  54644,  11863,  11008,   2324])


## 检查label prompt之间的语义相似度

In [7]:
def make_LLM_inputs(eol, tokenizer):
    label_prompt = f'The battery cycle life is {eol}.'
    if 'Instruct' in LLM_path:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": label_prompt}
        ]
        label_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        label_prompt = '<|begin_of_text|>' + label_prompt + '<|end_of_text|>'
    res = tokenizer(label_prompt, return_tensors="pt")
    end_cut_off = - (len(res['input_ids'][0]) - 1)
    return res['input_ids'][:,end_cut_off:], res['attention_mask'][:,end_cut_off:]

In [8]:
llm = AutoModel.from_pretrained(
            LLM_path,
            # 'huggyllama/llama-7b',
            trust_remote_code=True,
            local_files_only=True,
            load_in_4bit=True
        )
euclidean_dist = nn.PairwiseDistance(p=2)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [9]:

life1 = 1500
life2 = 1400

input_ids, attention_mask = make_LLM_inputs(life1, tokenizer)
embedding1 = llm(input_ids, attention_mask).last_hidden_state
embedding1 = embedding1[:,-1]
D = embedding1.shape[-1]

input_ids, attention_mask = make_LLM_inputs(life2, tokenizer)
embedding2 = llm(input_ids, attention_mask).last_hidden_state
embedding2 = embedding2[:,-1]

eu_dist = euclidean_dist(embedding1, embedding2).detach().float() / np.sqrt(D)
print(eu_dist)



We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


tensor([0.3809])


In [10]:
lives = [1,100,200,300,400,450,500,700,900,1100,1300,1500,1800,2000,2200,2400,2600,2500,2800,3000]
total_distances = []
for index, life1 in tqdm(enumerate(lives)):
    for life2 in lives[index+1:]:
        input_ids, attention_mask = make_LLM_inputs(life1, tokenizer)
        embedding1 = llm(input_ids, attention_mask).last_hidden_state
        embedding1 = embedding1[:,-1]
        D = embedding1.shape[-1]

        input_ids, attention_mask = make_LLM_inputs(life2, tokenizer)
        embedding2 = llm(input_ids, attention_mask).last_hidden_state
        embedding2 = embedding2[:,-1]

        eu_dist = euclidean_dist(embedding1, embedding2).detach().float() / np.sqrt(D)
        total_distances.append(eu_dist)
        print(eu_dist)

print(f'min {min(total_distances)} | max {max(total_distances)} | mean {np.mean(total_distances)} | median {np.median(total_distances)} | std {np.std(total_distances)}')

0it [00:00, ?it/s]

tensor([1.0908])
tensor([1.2715])
tensor([1.1768])
tensor([1.1533])
tensor([1.1523])
tensor([1.2402])
tensor([1.1191])
tensor([1.1787])
tensor([1.2461])
tensor([1.2725])
tensor([1.2988])
tensor([1.2520])
tensor([1.2812])
tensor([1.3096])
tensor([1.2959])
tensor([1.2607])
tensor([1.2979])
tensor([1.2725])


1it [00:05,  5.18s/it]

tensor([1.2646])
tensor([0.5781])
tensor([0.6240])
tensor([0.6392])
tensor([0.8203])
tensor([0.6396])
tensor([0.6260])
tensor([0.5991])
tensor([0.7944])
tensor([0.8511])
tensor([0.8442])
tensor([0.7739])
tensor([0.7090])
tensor([0.9268])
tensor([0.8423])
tensor([0.9092])
tensor([0.8066])
tensor([0.8120])


2it [00:10,  5.02s/it]

tensor([0.7837])
tensor([0.4019])
tensor([0.4587])
tensor([0.6895])
tensor([0.3706])
tensor([0.5024])
tensor([0.4885])
tensor([0.6309])
tensor([0.6489])
tensor([0.6108])
tensor([0.5688])
tensor([0.4214])
tensor([0.7148])
tensor([0.6123])
tensor([0.7275])
tensor([0.5610])
tensor([0.6089])


3it [00:14,  4.85s/it]

tensor([0.5591])
tensor([0.3032])
tensor([0.5068])
tensor([0.3035])
tensor([0.3789])
tensor([0.4321])
tensor([0.5483])
tensor([0.5483])
tensor([0.4661])
tensor([0.4751])
tensor([0.4351])
tensor([0.5938])
tensor([0.4966])
tensor([0.5806])
tensor([0.4419])
tensor([0.4836])


4it [00:19,  4.66s/it]

tensor([0.3386])
tensor([0.4766])
tensor([0.3613])
tensor([0.3569])
tensor([0.4382])
tensor([0.5601])
tensor([0.5752])
tensor([0.5400])
tensor([0.5273])
tensor([0.4851])
tensor([0.6270])
tensor([0.5332])
tensor([0.6011])
tensor([0.4939])
tensor([0.5063])


5it [00:23,  4.49s/it]

tensor([0.4480])
tensor([0.5664])
tensor([0.4692])
tensor([0.5361])
tensor([0.5977])
tensor([0.6206])
tensor([0.6284])
tensor([0.5947])
tensor([0.6987])
tensor([0.6514])
tensor([0.6196])
tensor([0.6304])
tensor([0.5972])
tensor([0.5532])


6it [00:27,  4.27s/it]

tensor([0.6001])
tensor([0.4119])
tensor([0.4329])
tensor([0.5405])
tensor([0.5522])
tensor([0.4688])
tensor([0.5176])
tensor([0.3914])
tensor([0.6147])
tensor([0.5254])
tensor([0.6270])
tensor([0.4167])
tensor([0.5273])


7it [00:30,  4.04s/it]

tensor([0.3977])
tensor([0.3894])
tensor([0.5703])
tensor([0.6021])
tensor([0.5913])
tensor([0.5601])
tensor([0.5415])
tensor([0.6660])
tensor([0.5835])
tensor([0.6401])
tensor([0.5391])
tensor([0.5234])


8it [00:34,  3.81s/it]

tensor([0.5068])
tensor([0.5708])
tensor([0.6216])
tensor([0.6045])
tensor([0.5332])
tensor([0.5474])
tensor([0.7012])
tensor([0.5957])
tensor([0.6772])
tensor([0.5664])
tensor([0.5674])


9it [00:37,  3.56s/it]

tensor([0.5356])
tensor([0.4358])
tensor([0.4895])
tensor([0.4934])
tensor([0.5410])
tensor([0.4897])
tensor([0.5117])
tensor([0.5269])
tensor([0.4841])
tensor([0.4944])


10it [00:39,  3.31s/it]

tensor([0.5098])
tensor([0.4248])
tensor([0.5054])
tensor([0.5347])
tensor([0.4990])
tensor([0.4597])
tensor([0.4158])
tensor([0.4390])
tensor([0.4773])


11it [00:42,  3.09s/it]

tensor([0.4736])
tensor([0.4248])
tensor([0.3860])
tensor([0.4482])
tensor([0.3633])
tensor([0.4700])
tensor([0.3027])
tensor([0.4607])


12it [00:44,  2.82s/it]

tensor([0.2866])
tensor([0.4500])
tensor([0.5229])
tensor([0.4343])
tensor([0.5308])
tensor([0.4585])
tensor([0.4480])


13it [00:46,  2.55s/it]

tensor([0.4314])
tensor([0.5449])
tensor([0.4421])
tensor([0.5703])
tensor([0.3645])
tensor([0.4968])


14it [00:48,  2.28s/it]

tensor([0.3337])
tensor([0.4119])
tensor([0.4509])
tensor([0.4387])
tensor([0.4761])


15it [00:49,  2.01s/it]

tensor([0.4783])
tensor([0.4092])
tensor([0.3157])
tensor([0.4001])


16it [00:50,  1.74s/it]

tensor([0.3623])
tensor([0.4207])
tensor([0.4214])


17it [00:51,  1.46s/it]

tensor([0.4673])
tensor([0.4055])


18it [00:52,  1.19s/it]

tensor([0.2815])


20it [00:52,  2.62s/it]

tensor([0.4089])
min tensor([0.2815]) | max tensor([1.3096]) | mean 0.6000012755393982 | median 0.535888671875 | std 0.24207854270935059





## 统计domain knowledge prompt的最大长度

In [11]:
def generate_basic_prompt(cell_name):
    '''
    Generate the basic prompt that describes battery specifications and working conditions
    '''
    bg_prompt = (
                f"Task description: You are an expert in predicting battery cycle life. " 
                f"The cycle life is the number of cycles until the battery's discharge capacity reaches 80% of its nominal capacity. "
                f"The discharge capacity is calculated under the described operating condition. "
                f"Please directly output the cycle life of the battery based on the provided data. "
                )

    
    helper = Mapping_helper(prompt_type='PROTOCOL', cell_name=cell_name)
    prompt = helper.do_mapping()
    prompt = bg_prompt + prompt
    # prompt = bg_prompt
    return prompt

In [12]:
dg_prompt_lengths = []
for i in range(2):

    if i == 0:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": 'I love you'}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
    else:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": 'No'}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    res = tokenizer(prompt, return_tensors="pt")
    input_ids, attention_mask = res['input_ids'][0][1:], res['attention_mask'][0][1:]
    length = len(input_ids)
    dg_prompt_lengths.append(length)
    print(input_ids[:36])
print(max(dg_prompt_lengths))

tensor([128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
            25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
           220,   1627,  10263,    220,   2366,     19,    271,   2675,    527,
           264,  11190,  18328,     13, 128009, 128006,    882, 128007,    271])
tensor([128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
            25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
           220,   1627,  10263,    220,   2366,     19,    271,   2675,    527,
           264,  11190,  18328,     13, 128009, 128006,    882, 128007,    271])
40


In [13]:
dg_prompt_lengths = []
for file_name in train_files:
    eol = read_cell_data_according_to_prefix(data_root_path, file_name)
    if not eol:
        continue
    cell_name = file_name.split('.pkl')[0]
    prompt = generate_basic_prompt(cell_name)
    if 'Instruct' in LLM_path:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
    else:
        prompt = '<|begin_of_text|>' + prompt + '<|end_of_text|>'
    res = tokenizer(prompt, return_tensors="pt")
    input_ids, attention_mask = res['input_ids'][0], res['attention_mask'][0]
    length = len(input_ids)
    dg_prompt_lengths.append(length)
    print(input_ids)
print(max(dg_prompt_lengths))

tensor([128000, 128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,
          2696,     25,   6790,    220,   2366,     18,    198,  15724,   2696,
            25,    220,   1627,  10263,    220,   2366,     19,    271,   2675,
           527,    264,  11190,  18328,     13, 128009, 128006,    882, 128007,
           271,   6396,   4096,     25,   1472,    527,    459,   6335,    304,
         52997,  11863,  11008,   2324,     13,    578,  11008,   2324,    374,
           279,   1396,    315,  25492,   3156,    279,  11863,    596,  32643,
          8824,  25501,    220,   1490,      4,    315,   1202,  47855,   8824,
            13,    578,  32643,   8824,    374,  16997,   1234,    279,   7633,
         10565,   3044,     13,   5321,   6089,   2612,    279,  11008,   2324,
           315,    279,  11863,   3196,    389,    279,   3984,    828,     13,
         34712,  26185,     25,    578,    828,   4131,    505,    264,  57907,
         80846,  11863,    304,    264, 

## play

In [None]:
end_of_the_prompt = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
res = tokenizer(end_of_the_prompt, return_tensors="pt")
end_input_ids, end_attn_mask = res['input_ids'][0], res['attention_mask'][0]
len(end_input_ids)

5

: 