# README
This notebook computes the embeddings of the domain-knowledge prompts (DKP) of the battery cells; these embeddings are used to pretrain PBT.

In [1]:
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.cm as cm
import matplotlib as mpl
import pickle
import transformers
from matplotlib.lines import Line2D
import seaborn as sns
import json
import torch
from data_provider.data_split_recorder import split_recorder
from Prompts.Mapping_helper import Mapping_helper
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model, AutoTokenizer, AutoModel, AutoConfig, Phi3Config

In [2]:
def set_ax_linewidth(ax, bw=1.5):
    """Give all four spines of `ax` the same line width.

    Args:
        ax: a matplotlib Axes (anything with a ``spines`` mapping).
        bw: line width applied to the bottom, left, top and right spines.
    """
    for side in ('bottom', 'left', 'top', 'right'):
        ax.spines[side].set_linewidth(bw)

def set_ax_font_size(ax, fontsize=10):
    """Set the tick-label font size on both axes of `ax`.

    Args:
        ax: a matplotlib Axes.
        fontsize: label size passed to ``tick_params`` for the y axis, then the x axis.
    """
    for axis_name in ('y', 'x'):
        ax.tick_params(axis=axis_name, labelsize=fontsize)

def set_draft(the_plt, other_ax=''):
    """Blank tick labels and axis labels for a "draft" figure.

    The current Axes is obtained from ``the_plt.gca()``. Its labels are now cleared on
    that same Axes instead of through the global ``plt.xlabel``/``plt.ylabel``. The global
    calls could hit a different Axes when ``the_plt`` is a Figure that is not the pyplot
    current figure.

    Args:
        the_plt: object exposing ``gca()`` (e.g. ``matplotlib.pyplot`` or a Figure).
        other_ax: optional second Axes (e.g. a twin axis) to blank as well. A falsy
            value (the default ``''``, or ``None``) skips it.
    """
    current_ax = the_plt.gca()
    targets = [current_ax, other_ax] if other_ax else [current_ax]
    for target in targets:
        target.axes.xaxis.set_ticklabels([])
        target.axes.yaxis.set_ticklabels([])
        target.set_ylabel('')
        target.set_xlabel('')

def set_draft_fig(fig):
    """Blank tick labels and axis labels on every Axes of `fig`.

    Args:
        fig: a matplotlib Figure (anything with an ``axes`` sequence).
    """
    for axes_obj in fig.axes:
        for axis in (axes_obj.axes.xaxis, axes_obj.axes.yaxis):
            axis.set_ticklabels([])
        axes_obj.set_ylabel('')
        axes_obj.set_xlabel('')

In [3]:
# Which split of the mixed dataset to embed: one of MIX_large, MIX_all, MIX_all2024, MIX_all42.
target_dataset = 'MIX_all'
if target_dataset == 'MIX_large':
    train_files = split_recorder.MIX_large_train_files
elif target_dataset == 'MIX_all':
    train_files = split_recorder.MIX_all_train_files
elif target_dataset == 'MIX_all2024':
    train_files = split_recorder.MIX_all_2024_train_files
elif target_dataset == 'MIX_all42':
    train_files = split_recorder.MIX_all_42_train_files
else:
    # Fail here rather than with a NameError on `cell_names` in a later cell.
    raise ValueError(f'Unknown target_dataset: {target_dataset!r}')
# Cell names are the training file names with the '.pkl' extension stripped.
cell_names = [file_name.split('.pkl')[0] for file_name in train_files]
cell_names

['UL-PUR_N10-NA7_18650_NCA_23C_0-100_0.5-0.5C_g',
 'UL-PUR_N15-NA10_18650_NCA_23C_0-100_0.5-0.5C_j',
 'RWTH_016',
 'RWTH_045',
 'RWTH_009',
 'RWTH_039',
 'RWTH_046',
 'RWTH_019',
 'RWTH_037',
 'RWTH_013',
 'RWTH_003',
 'RWTH_044',
 'RWTH_026',
 'RWTH_006',
 'RWTH_031',
 'RWTH_036',
 'RWTH_048',
 'RWTH_033',
 'RWTH_021',
 'RWTH_012',
 'RWTH_034',
 'RWTH_018',
 'RWTH_022',
 'RWTH_030',
 'RWTH_028',
 'RWTH_011',
 'RWTH_040',
 'RWTH_041',
 'RWTH_042',
 'RWTH_025',
 'RWTH_047',
 'RWTH_004',
 'HUST_1-6',
 'HUST_2-2',
 'HUST_1-3',
 'HUST_6-3',
 'HUST_1-2',
 'HUST_3-7',
 'HUST_3-2',
 'HUST_10-6',
 'HUST_3-6',
 'HUST_5-1',
 'HUST_10-5',
 'HUST_6-2',
 'HUST_6-1',
 'HUST_8-1',
 'HUST_10-7',
 'HUST_1-4',
 'HUST_5-4',
 'HUST_1-5',
 'HUST_6-6',
 'HUST_5-6',
 'HUST_6-4',
 'HUST_9-2',
 'HUST_10-4',
 'HUST_5-3',
 'HUST_7-7',
 'HUST_3-1',
 'HUST_4-1',
 'HUST_4-4',
 'HUST_4-6',
 'HUST_8-8',
 'HUST_2-4',
 'HUST_9-8',
 'HUST_9-5',
 'HUST_3-3',
 'HUST_1-7',
 'HUST_4-5',
 'HUST_9-6',
 'HUST_1-1',
 'HUST_4-3'

## Tokenization and embedding

In [4]:
def create_causal_mask(B, seq_len):
    '''
    Build a batch of lower-triangular (causal) attention masks.

    Args:
        B: batch size.
        seq_len: sequence length L.

    Returns:
        causal mask: float tensor of shape [B, L, L]. Entry [b, i, j] is 1 when j <= i
        (attention allowed) and 0 where it is masked. The batch dimension is an
        expanded view of a single [L, L] matrix, not a copy.
    '''
    lower_triangle = torch.ones(seq_len, seq_len).tril()
    return lower_triangle[None].expand(B, seq_len, seq_len)

def last_token_pool(last_hidden_states, attention_mask):
    """Select, for each sequence, the hidden state at its last non-padding token.

    Args:
        last_hidden_states: tensor indexed as [batch, position, ...]
            (presumably [batch, seq, hidden]).
        attention_mask: [batch, seq] mask, assumed 0/1 with 1 marking real tokens.

    Returns:
        Tensor of shape [batch, ...] holding one hidden state per sequence.
    """
    # With a 0/1 mask, a final column that sums to the batch size means every row ends with
    # a real token. The batch is then left-padded and the last position is the answer.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]
    # Right padding: the last real token sits at (number of real tokens - 1).
    last_real_index = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_real_index]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap `query` in the instruction format fed to the Qwen3 embedding model.

    The result is 'Instruct: <task_description>', a newline, then 'Query:<query>'
    (there is no space after 'Query:').
    """
    template = 'Instruct: {}\nQuery:{}'
    return template.format(task_description, query)

In [5]:
# Load the tokenizer and the language model used to embed the prompts.
# Candidate checkpoints (set LLM_path to one of them):
#   '/data/LLMs/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659'
#   '/data/LLMs/models--Qwen--Qwen3-Embedding-8B/snapshots/a3d38e32b9c835d5b3d0d0a3ef3c133bbea92539'
#   '/data/LLMs/models--Qwen--Qwen3-Embedding-0.6B/snapshots/744169034862c8eec56628663995004342e4e449'
#   'Qwen/Qwen3-Embedding-0.6B'
LLM_path = '/data/LLMs/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659'
llama_config = AutoConfig.from_pretrained(LLM_path)
if 'Qwen3-Embedding-0.6B' in LLM_path:
    # The small embedding model is loaded without quantization and moved to the GPU.
    language_model = AutoModel.from_pretrained(LLM_path).cuda()
else:
    # Other checkpoints are loaded 4-bit quantized. Passing `load_in_4bit=True` directly to
    # from_pretrained is deprecated (see the warning this cell used to print); the supported
    # form is a BitsAndBytesConfig passed as `quantization_config`.
    language_model = AutoModel.from_pretrained(
        LLM_path,
        trust_remote_code=True,
        local_files_only=True,
        config=llama_config,
        quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True),
    )
if 'Llama' in LLM_path:
    # NOTE(review): confirm '<|endoftext|>' exists in this tokenizer's vocabulary. The pad
    # token only matters for batched inputs, and get_features_from_cellNames tokenizes
    # one prompt at a time.
    tokenizer = AutoTokenizer.from_pretrained(
        LLM_path,
        trust_remote_code=True,
        local_files_only=True,
        pad_token='<|endoftext|>',
    )
    tokenizer.padding_side = 'right'
else:
    # Qwen3 embedding models are pooled on the last token, so pad on the left.
    tokenizer = AutoTokenizer.from_pretrained(LLM_path, padding_side='left')


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
def read_cell_data_according_to_prefix(file_name, root_path):
    """Look up the life label of one cell.

    The text before the first '_' in `file_name` selects the label JSON inside
    ``{root_path}/Life labels``:
      * 'MICH'               -> total_MICH_labels.json
      * starts with 'Tongji' -> Tongji_labels.json (key looked up with '--' replaced by '-#')
      * anything else        -> {prefix}_labels.json

    Args:
        file_name: cell file name, e.g. 'HUST_1-6.pkl'.
        root_path: dataset root directory.

    Returns:
        The value stored under the (possibly rewritten) file name, or None if the
        label file has no such key.
    """
    dataset_prefix = file_name.split('_')[0]
    labels_dir = f'{root_path}/Life labels'
    if dataset_prefix == 'MICH':
        label_file = f'{labels_dir}/total_MICH_labels.json'
    elif dataset_prefix.startswith('Tongji'):
        file_name = file_name.replace('--', '-#')
        label_file = f'{labels_dir}/Tongji_labels.json'
    else:
        label_file = f'{labels_dir}/{dataset_prefix}_labels.json'

    with open(label_file) as handle:
        labels_by_file = json.load(handle)
    return labels_by_file[file_name] if file_name in labels_by_file else None

def _build_protocol_prompt(cell_name):
    """Return the raw (not chat-templated) domain-knowledge prompt for one cell.

    A task description (90% capacity threshold for CALB cells, 80% otherwise) is
    followed by the protocol description produced by Mapping_helper.
    """
    soh_threshold = 90 if 'CALB' in cell_name else 80
    bg_prompt = (
        "Task description: "
        f"The target is the number of cycles until the battery's discharge capacity reaches {soh_threshold}% of its nominal capacity. "
        "The discharge capacity is calculated under the described operating condition. "
        "Please directly output the target of the battery based on the provided data. "
    )
    helper = Mapping_helper(prompt_type='PROTOCOL', cell_name=cell_name)
    return bg_prompt + helper.do_mapping()


def _embed_with_llama(prompt):
    """Embed `prompt` with the Llama model; return the last position's final hidden state as a (1, d) array."""
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    templated = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )
    encoded = tokenizer(templated, return_tensors="pt")
    # The first token is dropped. This is presumably the duplicate BOS the tokenizer adds in
    # front of an already-templated string (TODO confirm).
    # Inputs are moved to the model's device so the embedding lookup does not mix devices.
    input_ids = encoded['input_ids'][:, 1:].to(language_model.device)
    inputs_embeds = language_model.get_input_embeddings()(input_ids)  # [1, L', d_llm]
    hidden_states = language_model(inputs_embeds=inputs_embeds).last_hidden_state
    return hidden_states[:, -1, :].detach().cpu().numpy().reshape(1, -1)


def _embed_with_qwen3(prompt):
    """Embed `prompt` with a Qwen3 embedding model using last-token pooling; return a (1, d) array."""
    batch = tokenizer(
        [get_detailed_instruct('classification', prompt)],
        padding=True,
        truncation=True,
        max_length=8192,
        return_tensors="pt",
    ).to(language_model.device)
    outputs = language_model(**batch)
    pooled = last_token_pool(outputs.last_hidden_state, batch['attention_mask'])
    return pooled.detach().cpu().numpy().reshape(1, -1)


def get_features_from_cellNames(cell_names, dataset_root='/data/trf/python_works/BatteryLife/dataset'):
    """Embed the domain-knowledge prompt of every cell that has a life label.

    Each cell's prompt is built by `_build_protocol_prompt` and embedded with the global
    `language_model`/`tokenizer`:
      * Llama (LLM_path contains 'Llama'): chat-templated prompt, last-position hidden state.
      * Qwen3 (LLM_path contains 'Qwen3'): instruction-wrapped prompt, last-token pooling.
    Cells whose label lookup returns None are printed and skipped. All forward passes run
    under torch.no_grad(), because only the embeddings are needed.

    Args:
        cell_names: cell identifiers without the '.pkl' extension.
        dataset_root: directory that contains the 'Life labels' folder.

    Returns:
        (cellName_prompt, total_labels): a dict mapping cell_name to a (1, d) numpy
        embedding, and the list of labels of the embedded cells in insertion order.

    Raises:
        Exception: if LLM_path names neither a Llama nor a Qwen3 model.
    """
    if 'Llama' in LLM_path:
        embed = _embed_with_llama
    elif 'Qwen3' in LLM_path:
        embed = _embed_with_qwen3
    else:
        raise Exception(f'{LLM_path} is not supported here')

    cellName_prompt = {}
    total_labels = []
    with torch.no_grad():
        for cell_name in cell_names:
            prompt = _build_protocol_prompt(cell_name)
            eol = read_cell_data_according_to_prefix(cell_name + '.pkl', dataset_root)
            if eol is None:
                print(cell_name + '.pkl')
                continue
            total_labels.append(eol)
            cellName_prompt[cell_name] = embed(prompt)
    return cellName_prompt, total_labels

# Resolve the validation/test cell names before running any embedding. An unknown
# target_dataset then fails immediately instead of after the expensive training pass.
if target_dataset == 'MIX_large':
    val_files, test_files = split_recorder.MIX_large_val_files, split_recorder.MIX_large_test_files
elif target_dataset == 'MIX_all':
    val_files, test_files = split_recorder.MIX_all_val_files, split_recorder.MIX_all_test_files
elif target_dataset == 'MIX_all2024':
    val_files, test_files = split_recorder.MIX_all_2024_val_files, split_recorder.MIX_all_2024_test_files
elif target_dataset == 'MIX_all42':
    val_files, test_files = split_recorder.MIX_all_42_val_files, split_recorder.MIX_all_42_test_files
else:
    raise ValueError(f'Unknown target_dataset: {target_dataset!r}')
val_cell_names = [file_name.split('.pkl')[0] for file_name in val_files]
test_cell_names = [file_name.split('.pkl')[0] for file_name in test_files]

# Embed the training, validation and testing cells. Cells without a life label are printed and skipped.
cellName_prompt, total_labels = get_features_from_cellNames(cell_names)
val_cellName_prompt, val_total_labels = get_features_from_cellNames(val_cell_names)
test_cellName_prompt, test_total_labels = get_features_from_cellNames(test_cell_names)





MICH_18H_pouch_NMC_45C_50-100_0.2-1.5C.pkl
MICH_17C_pouch_NMC_-5C_50-100_0.2-1.5C.pkl
MICH_15H_pouch_NMC_45C_50-100_0.2-0.2C.pkl
MICH_13R_pouch_NMC_25C_50-100_0.2-0.2C.pkl
MICH_14C_pouch_NMC_-5C_50-100_0.2-0.2C.pkl
ZN-coin_403-1_20231209225922_01_4.pkl
ZN-coin_410-1_20231209232559_09_1.pkl
ZN-coin_418-1_20231209234141_11_1.pkl
ZN-coin_402-1_20231209225636_01_1.pkl
MICH_16R_pouch_NMC_25C_50-100_0.2-1.5C.pkl
ZN-coin_428-2_20231212185058_01_4.pkl


## Save the results

In [None]:
## Export the domain-knowledge prompt embeddings of the samples
# NOTE(review): the output file names do not include `target_dataset`, so embedding another
# split (e.g. MIX_large) overwrites these files. Confirm this is intended.
save_path = '/data/trf/python_works/BatteryLife/dataset/'

# Substring of LLM_path -> tag used in the output file names. Checked in order; the first match wins.
MODEL_NAME_TAGS = (
    ('Llama', 'Llama'),
    ('Qwen3-Embedding-4B', 'Qwen3_4B'),
    ('Qwen3-Embedding-8B', 'Qwen3_8B'),
    ('Qwen3-Embedding-0.6B', 'Qwen3_0.6B'),
)
name_comment = next((tag for key, tag in MODEL_NAME_TAGS if key in LLM_path), None)
if name_comment is None:
    # Fail loudly instead of silently skipping the export of the computed embeddings.
    raise ValueError(f'No output name configured for LLM_path={LLM_path!r}')

for split_name, split_embeddings in (
    ('training', cellName_prompt),
    ('validation', val_cellName_prompt),
    ('testing', test_cellName_prompt),
):
    with open(f'{save_path}{split_name}_DKP_embed_all_{name_comment}.pkl', 'wb') as f:
        pickle.dump(split_embeddings, f)

: 