## SOH estimation within each individual dataset

### fig3

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_results(estimates,ture,abs_errors,dataset,save_path):
    """Scatter estimated vs. true SOH for one dataset and save the figure.

    Args:
        estimates: estimated SOH values, plotted on the y axis.
        ture: true SOH values, plotted on the x axis.
        abs_errors: per-point absolute errors; they set the point colour on a
            fixed 0-0.1 scale.
        dataset: dataset name, used as the panel title.
        save_path: output image path (written at 600 dpi).

    Both axes share the same limits: the data range padded by 0.01, with the
    upper limit capped at 1.02. A red dashed y = x line marks perfect estimates.
    """
    plt.style.use(['science', 'ieee'])
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side = 45/25.4
    fig, ax = plt.subplots(figsize=(side, side))

    error_colors = ['#313695','#4575b4','#74add1','#abd9e9', '#e0f3f8']
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap', error_colors)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=plt.Normalize(0, 0.1), s=0.25)

    lo = min(min(estimates), min(ture)) - 0.01
    hi = max(max(estimates), max(ture)) + 0.01
    hi = 1.02 if hi > 1.02 else hi

    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)
    ax.set(xlim=(lo, hi), ylim=(lo, hi), xlabel='True SOH', ylabel='Estimation')
    ax.set_title(dataset, fontsize=8)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def plot_rmses(df,save_path):
    """Draw horizontal RMSE violins per dataset with mean ± std markers and save them.

    Args:
        df: DataFrame with a ``'Dataset'`` column (one violin row per value)
            and a numeric ``'rmse'`` column (RMSE in %).
        save_path: output image path (written at 600 dpi). The legend is only
            drawn when the path contains ``'(b)'``.
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 8})
    fig, ax = plt.subplots(figsize=(50/25.4, 85/25.4))
    colors = ['#f46d43','#fdae61','#fee090','#abd9e9','#74add1','#4575b4']
    sns.violinplot(x='rmse', y='Dataset', data=df, scale='width', inner=None, orient='h',
                   palette=colors[::-1], linewidth=0, ax=ax)
    means = df.groupby('Dataset')['rmse'].mean()
    stds = df.groupby('Dataset')['rmse'].std()
    # Row i of the violin plot corresponds to the i-th dataset in order of appearance.
    for i, dataset in enumerate(df['Dataset'].unique()):
        mean = means[dataset]
        std = stds[dataset]
        # Attach legend labels to the first row only so each entry appears once.
        first = i == 0
        ax.plot([mean, mean], [i - 0.2, i + 0.2], '-', color='red', lw=0.5,
                label='Mean' if first else "")
        ax.plot([mean - std, mean + std], [i, i], '--', color='black', lw=0.5)
        ax.plot([mean - std, mean - std], [i - 0.1, i + 0.1], '-', color='black', lw=0.5,
                label='Mean ± Std' if first else "")
        ax.plot([mean + std, mean + std], [i - 0.1, i + 0.1], '-', color='black', lw=0.5)
    ax.set_xlim([0, max(df['rmse']) + 1.5])
    # Raw string: in a plain literal '\%' is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the rendered text is unchanged.
    ax.set_xlabel(r'RMSE (\%)')
    ax.set_ylabel('Dataset')
    if '(b)' in save_path:
        ax.legend(fontsize=7, loc='lower right')
    plt.tight_layout()
    # Save and release the figure like the other plotting helpers; previously the
    # figure was never written to save_path and stayed open after every call.
    plt.savefig(save_path, dpi=600)
    plt.close(fig)


def intra_cross_results():
    """Render the fig. 3(a)/(c) scatter panels for each individual dataset.

    Intra-condition results are written to ``figs/fig3/fig3(a)_<dataset>.jpg``
    and cross-condition results to ``figs/fig3/fig3(c)_<dataset>.jpg``.
    """
    result_root = 'results/soh_individual_dataset_results'
    for name in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']:
        intra_est, intra_true, cross_est, cross_true = get_results(os.path.join(result_root, name), name)
        panels = (
            ('fig3(a)_', intra_est, intra_true),
            ('fig3(c)_', cross_est, cross_true),
        )
        for prefix, est, true in panels:
            plot_results(est, true, np.abs(est - true), name, 'figs/fig3/' + prefix + name + '.jpg')

def intra_cross_rmses():
    """Collect per-dataset RMSEs and draw fig. 3(b) (intra) and fig. 3(d) (cross).

    Datasets are visited in reverse order (Dataset 6 first). Each RMSE row is
    tagged with the last character of its dataset name (e.g. ``'6'``).
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    tags = {'intra': [], 'cross': []}
    values = {'intra': [], 'cross': []}
    for name in reversed(names):
        intra, cross = get_rmse(os.path.join('results/soh_individual_dataset_results', name), name)
        for kind, rmses in (('intra', intra), ('cross', cross)):
            tags[kind].append(np.array([name[-1]] * len(rmses)))
            values[kind].append(rmses)
    frames = {
        kind: pd.DataFrame({'Dataset': np.concatenate(tags[kind]), 'rmse': np.concatenate(values[kind])})
        for kind in ('intra', 'cross')
    }
    plot_rmses(frames['intra'], 'figs/fig3/fig3(b).jpg')
    plot_rmses(frames['cross'], 'figs/fig3/fig3(d).jpg')

def error_ratio_results():
    """Print, per dataset, the percentage of cross-condition estimates with absolute error below 0.03.

    The same percentage is computed for the intra-condition estimates but is
    currently not reported.
    """
    threshold = 0.03

    def share_below(errors):
        # Percentage of entries strictly below the threshold.
        return (np.sum(errors < threshold) / len(errors)) * 100

    for name in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']:
        result_dir = os.path.join('results/soh_individual_dataset_results', name)
        intra_est, intra_true, cross_est, cross_true = get_results(result_dir, name)
        intra_share = share_below(np.abs(intra_est - intra_true))  # computed, not printed
        cross_share = share_below(np.abs(cross_est - cross_true))
        print(f"Dataset: {name}, Ratio of soh_individual_dataset_abserrors < 0.03: {cross_share:.2f}")

# Output directory for fig. 3. makedirs(exist_ok=True) is idempotent and avoids
# the race between an exists() check and the actual creation.
path_dir = 'figs/fig3'
os.makedirs(path_dir, exist_ok=True)
# intra_cross_results()
# intra_cross_rmses()
# error_ratio_results()

## SOH estimation across datasets

### fig4a

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_cross_di(df,save_path):
    """Plot intra- vs cross-condition RMSE violins for one transfer pair and save it.

    Args:
        df: DataFrame with a ``'condition'`` column and a numeric ``'rmse'`` column.
        save_path: output image path (written at 600 dpi).

    Each violin gets a red mean tick and a dashed black mean ± std bar with caps;
    std is the pandas groupby std. The y axis is fixed to [0, 7].
    """
    def mm(length):
        return length / 25.4

    plt.style.use(['science', 'ieee'])
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(mm(29), mm(28)))
    sns.violinplot(x='condition', y='rmse', data=df, scale='width', inner=None,
                   palette=['#4575b4', '#f46d43'], linewidth=0, ax=ax)
    ax.set_xlabel('')
    ax.set_ylabel('')

    per_condition = df.groupby('condition')['rmse']
    means, stds = per_condition.mean(), per_condition.std()
    for pos, cond in enumerate(df['condition'].unique()):
        mu, sd = means[cond], stds[cond]
        first = pos == 0
        ax.plot([pos - 0.2, pos + 0.2], [mu, mu], '-', color='red', lw=0.5, label='Mean' if first else "")
        ax.plot([pos, pos], [mu - sd, mu + sd], '--', color='black', lw=0.5, label='Mean ± Std' if first else "")
        for cap in (mu - sd, mu + sd):
            ax.plot([pos - 0.1, pos + 0.1], [cap, cap], '-', color='black', lw=0.5)

    ax.set_ylim([0, 7])
    ax.set_xticks([])
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def cross_dataset_errors():
    """Draw fig. 4a violins for every ordered (target, source) pair of distinct datasets.

    RMSEs come from
    ``results/soh_cross_laboratory_datasets_results/<target>/transfer from <source>``
    and each pair is saved to ``figs/fig4a/fig4a_<target>_from_<source>.jpg``.
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    pairs = [(target, source) for target in names for source in names if target != source]
    for target_di, source_di in pairs:
        path = os.path.join('results/soh_cross_laboratory_datasets_results', target_di, f'transfer from {source_di}')
        intra, cross = get_rmse(path, target_di)
        labels = np.concatenate([np.array(['intra'] * len(intra)), np.array(['cross'] * len(cross))])
        frame = pd.DataFrame({'condition': labels, 'rmse': np.concatenate([intra, cross])})
        plot_cross_di(frame, f'figs/fig4a/fig4a_{target_di}_from_{source_di}.jpg')

def cd_errors_tabel():
    """Write the fig. 4a summary table of cross-dataset RMSEs to CSV.

    Rows are ``'<target> intra'`` / ``'<target> cross'`` for each target dataset
    and columns are source datasets. Off-diagonal cells hold ``'mean ± std'``
    (3 decimals). Diagonal cells (target == source) keep the placeholder ``0``.

    NOTE(review): ``np.std`` is the population std (ddof=0). The violin overlays
    in ``plot_cross_di`` use the pandas std (ddof=1), so the table and figure
    spreads differ slightly — confirm which is intended.
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']

    def summarize(values):
        return f'{np.mean(values):.3f} ± {np.std(values):.3f}'

    table = {'col_targes': [f'{name} {kind}' for name in names for kind in ('intra', 'cross')]}
    table.update({name: [0] * (2 * len(names)) for name in names})

    for row, target_di in enumerate(names):
        for source_di in names:
            if source_di == target_di:
                continue
            path = os.path.join('results/soh_cross_laboratory_datasets_results', target_di, f'transfer from {source_di}')
            intra, cross = get_rmse(path, target_di)
            column = table[source_di]
            column[2 * row] = summarize(intra)
            column[2 * row + 1] = summarize(cross)

    pd.DataFrame(table).to_csv('figs/fig4a/fig4a_tabel.csv', index=False)
# Create the fig. 4a output directory (idempotent, no exists/create race), then render.
path_dir = 'figs/fig4a'
os.makedirs(path_dir, exist_ok=True)
cross_dataset_errors()
cd_errors_tabel()

### fig4b

In [5]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_di(df,save_path):
    """Plot intra- vs cross-condition RMSE violins for one target dataset and save it.

    Args:
        df: DataFrame with a ``'condition'`` column and a numeric ``'rmse'`` column.
        save_path: output image path (written at 600 dpi).

    Each violin gets a red mean tick and a dashed black mean ± std bar with caps;
    std is the pandas groupby std. The y axis is fixed to [0, 7].
    """
    def mm(length):
        return length / 25.4

    plt.style.use(['science', 'ieee'])
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(mm(29), mm(29)))
    sns.violinplot(x='condition', y='rmse', data=df, scale='width', inner=None,
                   palette=['#4575b4', '#f46d43'], linewidth=0, ax=ax)
    ax.set_xlabel('')
    ax.set_ylabel('')

    per_condition = df.groupby('condition')['rmse']
    means, stds = per_condition.mean(), per_condition.std()
    for pos, cond in enumerate(df['condition'].unique()):
        mu, sd = means[cond], stds[cond]
        first = pos == 0
        ax.plot([pos - 0.2, pos + 0.2], [mu, mu], '-', color='red', lw=0.5, label='Mean' if first else "")
        ax.plot([pos, pos], [mu - sd, mu + sd], '--', color='black', lw=0.5, label='Mean ± Std' if first else "")
        for cap in (mu - sd, mu + sd):
            ax.plot([pos - 0.1, pos + 0.1], [cap, cap], '-', color='black', lw=0.5)

    ax.set_xticks([])
    ax.set_ylim([0, 7])
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def dataset_errors():
    """Draw fig. 4b violins (intra vs cross RMSE) for each target dataset.

    RMSEs are read from ``results/soh_pretraining_ev_data_results/<dataset>``
    and saved to ``figs/fig4b/fig4b_<dataset>.jpg``.
    """
    for name in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']:
        intra, cross = get_rmse(os.path.join('results/soh_pretraining_ev_data_results', name), name)
        labels = np.concatenate([np.array(['intra'] * len(intra)), np.array(['cross'] * len(cross))])
        frame = pd.DataFrame({'condition': labels, 'rmse': np.concatenate([intra, cross])})
        plot_di(frame, f'figs/fig4b/fig4b_{name}.jpg')

def errors_tabel():
    """Write the fig. 4b table of RMSE and MAPE (``'mean ± std'``, 3 decimals) per dataset.

    Columns: dataset, intra_RMSE, intra_MAPE, cross_RMSE, cross_MAPE. The std
    is ``np.std`` (population std).
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']

    def summarize(values):
        return f'{np.mean(values):.3f} ± {np.std(values):.3f}'

    columns = {'dataset': names, 'intra_RMSE': [], 'intra_MAPE': [], 'cross_RMSE': [], 'cross_MAPE': []}
    for name in names:
        path = os.path.join('results/soh_pretraining_ev_data_results', name)
        intra_rmse, cross_rmse = get_rmse(path, name)
        intra_mape, cross_mape = get_mape(path, name)
        columns['intra_RMSE'].append(summarize(intra_rmse))
        columns['cross_RMSE'].append(summarize(cross_rmse))
        columns['intra_MAPE'].append(summarize(intra_mape))
        columns['cross_MAPE'].append(summarize(cross_mape))
    pd.DataFrame(columns).to_csv('figs/fig4b/fig4b_tabel.csv', index=False)
# Create the fig. 4b output directory (idempotent, no exists/create race), then render.
path_dir = 'figs/fig4b'
os.makedirs(path_dir, exist_ok=True)
dataset_errors()
errors_tabel()

### figS5,figS6

In [4]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_cd_results(estimates,ture,abs_errors,save_path):
    """Scatter estimated vs. true SOH for one cross-dataset transfer and save it.

    Args:
        estimates: estimated SOH values (y axis).
        ture: true SOH values (x axis).
        abs_errors: per-point absolute errors, coloured on a fixed 0-0.1 scale.
        save_path: output image path (written at 600 dpi).

    The axes are unlabelled and share limits: the data range padded by 0.01,
    with the upper limit capped at 1.02. A red dashed y = x line is drawn.
    """
    plt.style.use(['science', 'ieee'])
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side = 29/25.4
    fig, ax = plt.subplots(figsize=(side, side))

    error_colors = ['#313695','#4575b4','#74add1','#abd9e9', '#e0f3f8']
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap', error_colors)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=plt.Normalize(0, 0.1), s=0.25)

    lo = min(min(estimates), min(ture)) - 0.01
    hi = max(max(estimates), max(ture)) + 0.01
    hi = 1.02 if hi > 1.02 else hi

    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)
    ax.set(xlim=(lo, hi), ylim=(lo, hi), xlabel='', ylabel='')
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def cross_dataset_results():
    """Draw per-pair estimated-vs-true scatter plots for every cross-dataset transfer.

    Intra-condition panels go to ``figs/figS5`` and cross-condition panels to
    ``figs/figS6``, named ``<fig>_<target>_from_<source>.jpg``.
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    pairs = [(target, source) for target in names for source in names if target != source]
    for target_di, source_di in pairs:
        path = os.path.join('results/soh_cross_laboratory_datasets_results', target_di, f'transfer from {source_di}')
        intra_est, intra_true, cross_est, cross_true = get_results(path, target_di)
        panels = (
            ('figS5', intra_est, intra_true),
            ('figS6', cross_est, cross_true),
        )
        for fig_id, est, true in panels:
            plot_cd_results(est, true, np.abs(est - true), f'figs/{fig_id}/{fig_id}_{target_di}_from_{source_di}.jpg')

# Create both output directories (idempotent, no exists/create race), then render.
for path_dir in ('figs/figS5', 'figs/figS6'):
    os.makedirs(path_dir, exist_ok=True)
cross_dataset_results()

### fig5,figS7-S11 (Knowledge learned by the DNN)

In [None]:
import numpy as np
import torch
import os
import warnings
import seaborn as sns
import scienceplots
import umap
from models.CNN_BiLSTM import CNN_BiLSTM
from data_process.data_load import load_test_data
from matplotlib import pyplot as plt
warnings.filterwarnings('ignore')

def extract_features(model, X):
    """Run each sample through ``model`` stage by stage and collect the intermediate activations.

    Each sample is converted to a float tensor and viewed as ``(1, 50, 1)``, so
    every element of ``X`` must hold exactly 50 values. Gradients are disabled.

    Returns:
        A list of four NumPy arrays, one row per sample, holding the flattened
        outputs of: the input, the CNN stage (``cnn1`` -> ``cnn2``), the recurrent
        stage (``lstm1_1`` -> ``lstm1_2`` -> ``lstm2_1`` -> ``lstm2_2``) and the
        ``flt`` -> ``fc1`` stage.
    """
    stages = ([], [], [], [])  # input, cnn, recurrent, fully-connected
    with torch.no_grad():
        for sample in X:
            h = torch.tensor(sample, dtype=torch.float).view(1, 50, 1)
            stages[0].append(h.reshape(1, -1))

            # The convolutions expect channels before length, hence the permutes.
            h = model.cnn2(model.cnn1(h.permute(0, 2, 1))).permute(0, 2, 1)
            stages[1].append(h.reshape(1, -1))

            h, _ = model.lstm1_1(h)
            h = model.lstm1_2(h)
            h, _ = model.lstm2_1(h)
            h = model.lstm2_2(h)
            stages[2].append(h.reshape(1, -1))

            h = model.fc1(model.flt(h))
            stages[3].append(h.reshape(1, -1))
    return [torch.cat(chunks).numpy() for chunks in stages]

def plot_embeddings(embeddings, labels,op_conditions,save_path):
    """Scatter a 2-D embedding, coloured by SOH and marker-coded by operating condition.

    Only every 10th point of each condition is drawn.

    Args:
        embeddings: array whose first two columns are plotted as dimensions 1 and 2.
        labels: per-point SOH values used for the ``coolwarm`` hue.
        op_conditions: per-point operating-condition codes. Each unique code
            gets its own marker; the six markers repeat if there are more codes.
        save_path: output image path (written at 600 dpi).
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(55/25.4, 55/25.4))
    marks = ['o', '*', 's', '^', 'D', 'x']
    labels = np.asarray(labels)
    op_conditions = np.asarray(op_conditions)
    # One colour scale shared by all conditions. Without an explicit hue_norm, each
    # sns.scatterplot call rescales 'coolwarm' to its own subset's SOH range, so the
    # same SOH value would get different colours for different conditions.
    hue_norm = (labels.min(), labels.max()) if labels.size else None
    for i, domain in enumerate(np.unique(op_conditions)):
        idx = op_conditions == domain
        sns.scatterplot(x=embeddings[idx, 0][::10], y=embeddings[idx, 1][::10],
                        hue=labels[idx][::10], palette='coolwarm', hue_norm=hue_norm,
                        marker=marks[i % len(marks)],
                        edgecolor='black', s=10, linewidth=0.2, ax=ax)
    ax.set_xlabel('Dimension 1', fontsize=6)
    ax.set_ylabel('Dimension 2', fontsize=6)
    # Drop the per-point hue legend that seaborn adds.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close(fig)

datasets = ['Dataset 1', 'Dataset 2', 'Dataset 3','Dataset 4', 'Dataset 5', 'Dataset 6']
# Operating-condition codes per dataset; the j-th code labels the j-th group of test_cells[dataset].
condition_code = {
        'Dataset 1': [1,2],
        'Dataset 2': [5,6,7],
        'Dataset 3': [8,9,10,11],
        'Dataset 4': [16,17,18],
        'Dataset 5': [19,20,21],
        'Dataset 6': [22,23,24,25,26,27]
    }
# Experiment folder of the fine-tuned model for each dataset (same order as `datasets`).
ft_cells = ['Experiment1(25-1)','Experiment1(1C-4)','Experiment1(0-CC-2)','Experiment1(CY25-05_1-#14)','Experiment1(2C-6)','Experiment1(25_0.5b_100)']
# Test cells per dataset, grouped by operating condition (aligned with condition_code).
test_cells = {
    'Dataset 1':[['25-2'],['45-1','45-2']],
    'Dataset 2':[['1C-5','1C-6'],['2C-4','2C-5'],['3C-4','3C-5']],
    'Dataset 3':[['0-CC-1','0-CC-3'],['10-CC-1','10-CC-2'],['25-CC-1','25-CC-2'],['40-CC-1','40-CC-2']],
    'Dataset 4':[['CY25-05_1-#12','CY25-05_1-#16'],['CY35-05_1-#1','CY35-05_1-#2'],['CY45-05_1-#21','CY45-05_1-#22']],
    'Dataset 5':[['2C-5','2C-7'],['3C-5','3C-6'],['4C-5','4C-6']],
    'Dataset 6':[['25_0.5a_100'],['25_1a_100','25_1b_100'],['25_2a_100','25_2b_100'],['25_3a_100','25_3b_100'],['35_1a_100','35_1b_100'],['35_2a_100','35_2b_100']],
}

# Output folder for each dataset's embedding plots (same order as `datasets`).
fig_name = ['figS7', 'figS8', 'figS9', 'figS10', 'figS11', 'fig5']

# Build checkpoint paths with os.path.join so they resolve on every OS; a literal
# backslash separator only works on Windows (and '\s' is an invalid escape).
model_root = os.path.join('results', 'soh_pretraining_ev_data_results')

# Feature extraction process of the pre-trained DNN.
# The pre-trained model is the same for every dataset, so it is loaded once.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict copies the weights into the CPU model either way.
pre_path = os.path.join(model_root, 'pre_trained_model.pth')
pre_model = CNN_BiLSTM()
pre_model.load_state_dict(torch.load(pre_path, map_location='cpu'))
pre_model.eval()

for i, di in enumerate(datasets):
    path_dir = 'figs/'+fig_name[i]
    os.makedirs(path_dir, exist_ok=True)

    # Stack all test cells of this dataset and tag every sample with its condition code.
    x, sohs, conditions = [], [], []
    for code, cells in zip(condition_code[di], test_cells[di]):
        for cell in cells:
            data = load_test_data(di, cell)
            x.append(np.array(data['X']))
            sohs.append(np.array(data['Y']))
            conditions.append(np.array([code] * len(data['Y'])))
    x = np.concatenate(x, axis=0)
    sohs = np.concatenate(sohs, axis=0)
    conditions = np.concatenate(conditions, axis=0)

    umap_model = umap.UMAP(
        n_neighbors=150,
        min_dist=0.1,
        metric='euclidean',
        random_state=42
    )

    # Embed each of the four feature stages (input, CNN, recurrent, FC) of the pre-trained DNN.
    for j, feats in enumerate(extract_features(pre_model, x), start=1):
        plot_embeddings(umap_model.fit_transform(feats), sohs, conditions, path_dir + f'/{di}_pre_{j}d.jpg')

    # Feature extraction process of the fine-tuned DNN.
    ft_path = os.path.join(model_root, di, ft_cells[i], 'from_Dataset 7_ft_model.pth')
    ft_model = CNN_BiLSTM()
    ft_model.load_state_dict(torch.load(ft_path, map_location='cpu'))
    ft_model.eval()
    for j, feats in enumerate(extract_features(ft_model, x), start=1):
        plot_embeddings(umap_model.fit_transform(feats), sohs, conditions, path_dir + f'/{di}_ft_{j}d.jpg')

## Comparison with existing methods

### fig6 a

In [5]:
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
import numpy as np
import pandas as pd
import os
from load_results import get_rmse,get_mape

def plot_comparison_rmses(df,save_path):
    """Draw a grouped bar chart of RMSE (mean with ± sd error bars) per dataset and method, and save it.

    Args:
        df: DataFrame with ``'Dataset'``, ``'method'`` and ``'rmse'`` columns.
        save_path: output image path (written at 600 dpi). A legend (upper left,
            3 columns) is only drawn when the path contains ``'intra'``.
    """
    # NOTE(review): 'ieee' is applied before 'science' here, the reverse of the
    # other plotting helpers in this notebook. Later styles override earlier ones,
    # so confirm this order is intentional.
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(160/25.4, 45/25.4))
    colors = ['#deebf7','#c6dbef','#9ecae1','#6baed6','#4292c6','#2171b5','#08519c']
    sns.barplot(x='Dataset', y='rmse', hue='method', data=df, palette=colors,
                errorbar='sd',
                linestyle='-',
                errwidth=0.3,
                capsize=.05,
                errcolor='black',
                ax = ax
                )
    ax.set_ylim([0,6])
    ax.set_xlabel('')
    # Raw string: in a plain literal '\%' is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the rendered text is unchanged.
    ax.set_ylabel(r'RMSE (\%)')
    if 'intra' in save_path:
        ax.legend(loc='upper left', ncol=3)
    else:
        ax.legend().remove()
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()

def comparison_rmses():
    """Collect per-cell RMSEs of every method on every dataset and draw fig. 6(a).

    Baseline results are read from
    ``results/comparison_methods_with_limited_labels_results/<method>/<dataset>``.
    'WSL' results are read from ``results/soh_individual_dataset_results/<dataset>``.
    Intra- and cross-condition RMSEs are plotted to separate files.
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    methods = ['SVR','RF','GPR','CNN','AE','Benchmark','WSL']
    # kind -> (dataset tags, method tags, rmse chunks)
    collected = {'intra': ([], [], []), 'cross': ([], [], [])}
    for di in names:
        for methi in methods:
            if methi == 'WSL':
                path = os.path.join('results/soh_individual_dataset_results', di)
            else:
                path = os.path.join('results/comparison_methods_with_limited_labels_results', methi, di)
            intra, cross = get_rmse(path, di)
            for kind, rmses in (('intra', intra), ('cross', cross)):
                ds_tags, meth_tags, chunks = collected[kind]
                ds_tags.append(np.array([di] * len(rmses)))
                meth_tags.append(np.array([methi] * len(rmses)))
                chunks.append(rmses)

    frames = {
        kind: pd.DataFrame({
            'Dataset': np.concatenate(ds_tags),
            'method': np.concatenate(meth_tags),
            'rmse': np.concatenate(chunks),
        })
        for kind, (ds_tags, meth_tags, chunks) in collected.items()
    }
    plot_comparison_rmses(frames['intra'], 'figs/fig6/fig6(a)_intra.jpg')
    plot_comparison_rmses(frames['cross'], 'figs/fig6/fig6(a)_cross.jpg')


def comparison_tabel():
    """Write the fig. 6(a) table: RMSE and MAPE (``'mean±std'``) per dataset and method.

    Columns: dataset, method, intra_RMSE, intra_MAPE, cross_RMSE, cross_MAPE.
    The result paths are the same as in ``comparison_rmses``.
    """
    names = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    methods = ['SVR','RF','GPR','CNN','AE','Benchmark','WSL']
    columns = ['dataset', 'method'] + [f'{cond}_{err}' for cond in ('intra', 'cross') for err in ('RMSE', 'MAPE')]

    def summarize(values):
        return f'{np.mean(values):.3f}±{np.std(values):.3f}'

    rows = []
    for di in names:
        for methi in methods:
            if methi == 'WSL':
                path = os.path.join('results/soh_individual_dataset_results', di)
            else:
                path = os.path.join('results/comparison_methods_with_limited_labels_results', methi, di)
            intra_rmse, cross_rmse = get_rmse(path, di)
            intra_mape, cross_mape = get_mape(path, di)
            rows.append([di, methi,
                         summarize(intra_rmse), summarize(intra_mape),
                         summarize(cross_rmse), summarize(cross_mape)])
    pd.DataFrame(rows, columns=columns).to_csv('figs/fig6/fig6(a)_tabel.csv', index=False)

# Create the fig. 6 output directory (idempotent, no exists/create race), then render.
path_dir = 'figs/fig6'
os.makedirs(path_dir, exist_ok=True)

comparison_rmses()
comparison_tabel()

### figS12-S17

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_comparison_results(estimates,ture,abs_errors,dataset,save_path):
    """Scatter a baseline method's estimated vs. true SOH for one dataset and save it.

    Args:
        estimates: estimated SOH values (y axis).
        ture: true SOH values (x axis).
        abs_errors: per-point absolute errors, coloured on a fixed 0-0.1 scale.
        dataset: dataset name, used as the panel title.
        save_path: output image path (written at 600 dpi).

    Both axes share limits: the data range padded by 0.01, with the upper limit
    capped at 1.02. A red dashed y = x line marks perfect estimates.
    """
    plt.style.use(['science', 'ieee'])
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side = 45/25.4
    fig, ax = plt.subplots(figsize=(side, side))

    error_colors = ['#313695','#4575b4','#74add1','#abd9e9', '#e0f3f8']
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap', error_colors)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=plt.Normalize(0, 0.1), s=0.25)

    lo = min(min(estimates), min(ture)) - 0.01
    hi = max(max(estimates), max(ture)) + 0.01
    hi = 1.02 if hi > 1.02 else hi

    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)
    ax.set(xlim=(lo, hi), ylim=(lo, hi), xlabel='True SOH', ylabel='Estimation')
    ax.set_title(dataset, fontsize=8)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def comparison_results():
    """Draw estimated-vs-true SOH scatter plots for every baseline method (figs. S12-S17).

    Each method gets the folder ``figs/<fig>_<method>``, which holds
    ``intra<dataset>.jpg`` and ``cross<dataset>.jpg`` for every dataset.
    """
    datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    methods = ['SVR','RF','GPR','CNN','AE','Benchmark']
    fig_paths = ['figS12','figS13','figS14','figS15','figS16','figS17']
    for methi, folder in zip(methods, fig_paths):
        fig_ph = os.path.join('figs', folder + '_' + methi)
        if not os.path.exists(fig_ph):
            os.makedirs(fig_ph)

        for di in datasets:
            path = os.path.join('results/comparison_methods_with_limited_labels_results', methi, di)
            intra_est, intra_true, cross_est, cross_true = get_results(path, di)
            panels = (('intra', intra_est, intra_true), ('cross', cross_est, cross_true))
            for prefix, est, true in panels:
                plot_comparison_results(est, true, np.abs(est - true), di, os.path.join(fig_ph, prefix + di + '.jpg'))

# Render the per-method scatter figures; comparison_results() creates its own output folders.
comparison_results()

### fig6b, figS18

In [3]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_comparison_single_label(df,num_condition,save_path):
    """Bar-plot per-condition RMSEs of every method and save the figure.

    Args:
        df: DataFrame with ``'condition'``, ``'method'`` and ``'rmse'`` columns.
        num_condition: number of operating conditions; the figure width is
            ``10 + 16 * num_condition`` mm.
        save_path: output image path (written at 600 dpi). Its parent directory
            is created if missing. A legend is drawn only when the path contains
            'Dataset 3' (upper left) or 'Dataset 6' (upper right).

    The first category column is shaded light grey and the remaining ones light
    blue, separated by a dashed vertical line at x = 0.5.
    """
    # NOTE(review): 'ieee' is applied before 'science' here, the reverse of the
    # other plotting helpers in this notebook. Later styles override earlier ones,
    # so confirm this order is intentional.
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=((10+16*num_condition)/25.4, 45/25.4))
    ax.axvline(x=0.5, color='black', linestyle='--',linewidth=0.25)
    ax.axvspan(-0.5, 0.5, facecolor='lightgray', alpha=0.3)
    ax.axvspan(0.5, len(df['condition'].unique()) - 0.5, facecolor='lightblue', alpha=0.3)
    colors = ['#fff7bc','#fee391','#fec44f','#fe9929','#ec7014','#cc4c02','#8c2d04']
    sns.barplot(x='condition', y='rmse', hue='method', data=df, palette=colors,
                errorbar='sd',
                linestyle='-',
                errwidth=0.3,
                capsize=.05,
                errcolor='black',
                ax = ax
                )
    ax.set_xlabel('')
    # Raw string: in a plain literal '\%' is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+); the rendered text is unchanged.
    ax.set_ylabel(r'RMSE (\%)')
    ax.set_ylim(0, None)
    if 'Dataset 3' in save_path:
        ax.legend(loc='upper left', ncol=2)
    elif 'Dataset 6' in save_path:
        ax.legend(loc='upper right', ncol=2)
    else:
        ax.legend().remove()
    plt.tight_layout()
    # Callers write to folders (e.g. figs/figS18) that may not exist yet.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path,dpi=600)
    plt.close()


def get_rmse_single(path,test_cells):
    """Return the RMSEs (in %) of the requested test cells from ``<path>/eval_metrics.csv``.

    The first CSV column identifies the test cell. Rows whose cell is in
    ``test_cells`` contribute ``RMSE * 100``, in file order.
    """
    metrics = pd.read_csv(path+'/eval_metrics.csv')
    return [
        rmse * 100
        for cell, rmse in zip(metrics.iloc[:, 0], metrics['RMSE'])
        if cell in test_cells
    ]

def comparison_single_label():
    """Plot per-condition RMSEs of all methods for each dataset (fig. 6b / fig. S18).

    For each dataset, RMSEs of the listed test cells are grouped by operating
    condition (x axis) and method (hue). Datasets 1-3 are saved as
    ``figs/fig6/fig6(b)_<dataset>.jpg`` and Datasets 4-6 as ``figs/figS18/<dataset>.jpg``.
    """
    test_cells = {
        'Dataset 1': [['25-'+str(i+1) for i in range(4,8)],['45-'+str(i+1) for i in range(6)]],
        'Dataset 2': [['1C-8','1C-9','1C-10'],['2C-4','2C-5'],['3C-'+str(i+1) for i in range(3,10)]],
        'Dataset 3': [['0-CC-3'],['10-CC-1','10-CC-2','10-CC-3'],['25-CC-1','25-CC-2','25-CC-3'],['40-CC-1','40-CC-2','40-CC-3']],
        'Dataset 4': [['CY25-05_1-#18','CY25-05_1-#19'],['CY35-05_1-#1','CY35-05_1-#2'],['CY45-05_1-#'+str(i+1) for i in range(20,28)]],
        'Dataset 5': [['2C-8'],['3C-'+str(i+1) for i in range(4,15)],['4C-5','4C-6']],
        'Dataset 6': [['25_1d_100'],['25_0.5a_100','25_0.5b_100'],['25_2a_100','25_2b_100'],['25_3a_100','25_3b_100','25_3c_100','25_3d_100'],
                      ['35_1a_100','35_1b_100','35_1c_100','35_1d_100'],['35_2a_100','35_2b_100']]
    }
    condition_code = {
        'Dataset 1': [1,2],
        'Dataset 2': [5,6,7],
        'Dataset 3': [8,9,10,11],
        'Dataset 4': [16,17,18],
        'Dataset 5': [19,20,21],
        'Dataset 6': [23,22,24,25,26,27]
    }
    methods = ['SVR','RF','GPR','CNN','Benchmark','WSL(enough labels)','WSL(limited labels)']
    datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    wsl_files = ['Experiment1(25-1)','Experiment2(1C-5)','Experiment1(0-CC-2)','Experiment2(CY25-05_1-#16)','Experiment2(2C-7)','Experiment2(25_1a_100)']

    def result_dir(meth, di, wsl_file):
        # The two WSL variants live in their own result trees; baselines share one layout.
        if meth == 'WSL(enough labels)':
            return f'results/comparison_methods_with_enough_labels_results/single_condition_label/{di}/WSL_el'
        if meth == 'WSL(limited labels)':
            return f'results/soh_individual_dataset_results/{di}/{wsl_file}'
        return os.path.join('results/comparison_methods_with_enough_labels_results/single_condition_label', di, meth)

    for di, wsl_file in zip(datasets, wsl_files):
        grops, meths, rmses = [], [], []
        for conditionj, cells in zip(condition_code[di], test_cells[di]):
            for meth in methods:
                ci_rmses = get_rmse_single(result_dir(meth, di, wsl_file), cells)
                grops.append(np.array([rf'\#{conditionj}'] * len(ci_rmses)))
                meths.append(np.array([meth] * len(ci_rmses)))
                rmses.append(ci_rmses)
        frame = pd.DataFrame({
            'condition': np.concatenate(grops),
            'method': np.concatenate(meths),
            'rmse': np.concatenate(rmses),
        })
        if di in ['Dataset 1','Dataset 2','Dataset 3']:
            save_path = f'figs/fig6/fig6(b)_{di}.jpg'
        else:
            save_path = f'figs/figS18/{di}.jpg'
        plot_comparison_single_label(frame, len(condition_code[di]), save_path)


def single_label_errors():
    """Export per-condition mean RMSE tables for the single-condition-label comparison.

    For every dataset and every working condition, the RMSEs (percent, via
    ``get_rmse_single``) of that condition's test cells are averaged per method.
    One CSV per dataset is written to ``figs/fig6/fig6(b)_<dataset>_tabel.csv``
    with a ``method`` column and one ``#<condition>`` column per condition,
    values formatted with three decimals.
    """
    test_cells = {
        'Dataset 1': [['25-'+str(i+1) for i in range(4,8)],['45-'+str(i+1) for i in range(6)]],
        'Dataset 2': [['1C-8','1C-9','1C-10'],['2C-4','2C-5'],['3C-'+str(i+1) for i in range(3,10)]],
        'Dataset 3': [['0-CC-3'],['10-CC-1','10-CC-2','10-CC-3'],['25-CC-1','25-CC-2','25-CC-3'],['40-CC-1','40-CC-2','40-CC-3']],
        'Dataset 4': [['CY25-05_1-#18','CY25-05_1-#19'],['CY35-05_1-#1','CY35-05_1-#2'],['CY45-05_1-#'+str(i+1) for i in range(20,28)]],
        'Dataset 5': [['2C-8'],['3C-'+str(i+1) for i in range(4,15)],['4C-5','4C-6']],
        'Dataset 6': [['25_1d_100'],['25_0.5a_100','25_0.5b_100'],['25_2a_100','25_2b_100'],['25_3a_100','25_3b_100','25_3c_100','25_3d_100'],
                      ['35_1a_100','35_1b_100','35_1c_100','35_1d_100'],['35_2a_100','35_2b_100']]
    }
    # Codes pair positionally with the groups in test_cells. The first Dataset 6
    # group here is a 25_1* cell, hence 23 is listed before 22.
    condition_code = {
        'Dataset 1': [1,2],
        'Dataset 2': [5,6,7],
        'Dataset 3': [8,9,10,11],
        'Dataset 4': [16,17,18],
        'Dataset 5': [19,20,21],
        'Dataset 6': [23,22,24,25,26,27]
    }
    methods = ['SVR','RF','GPR','CNN','Benchmark','WSL(enough labels)','WSL(limited labels)']
    datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    wsl_files = ['Experiment1(25-1)','Experiment2(1C-5)','Experiment1(0-CC-2)','Experiment2(CY25-05_1-#16)','Experiment2(2C-7)','Experiment2(25_1a_100)']
    base_dir = 'results/comparison_methods_with_enough_labels_results/single_condition_label'

    def _result_dir(dataset, method, wsl_file):
        # The WSL variants use their own folders; every other method is stored
        # under base_dir/<dataset>/<method>.
        if method == 'WSL(enough labels)':
            return f'{base_dir}/{dataset}/WSL_el'
        if method == 'WSL(limited labels)':
            return f'results/soh_individual_dataset_results/{dataset}/{wsl_file}'
        return os.path.join(base_dir, dataset, method)

    for di, wsl_file in zip(datasets, wsl_files):
        table = {'method': methods}
        for code, cells in zip(condition_code[di], test_cells[di]):
            # One formatted mean RMSE per method, in the order of `methods`.
            table[f'#{code}'] = [
                f'{np.mean(get_rmse_single(_result_dir(di, meth, wsl_file), cells)):.3f}'
                for meth in methods
            ]
        pd.DataFrame(table).to_csv(f'figs/fig6/fig6(b)_{di}_tabel.csv', index=False)

# comparison_single_label and single_label_errors write into both folders.
# exist_ok avoids the check-then-create race; path_dir ends as 'figs/figS18'.
for path_dir in ('figs/fig6', 'figs/figS18'):
    os.makedirs(path_dir, exist_ok=True)
comparison_single_label()
single_label_errors()

### fig6c, figS19

In [4]:
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
import numpy as np
import pandas as pd
import os

def plot_comparison_full_label(df,num_condition,save_path):
    """Draw a grouped RMSE bar chart (one bar per method within each condition) and save it.

    Args:
        df: DataFrame with ``condition``, ``method`` and ``rmse`` columns. Bars use
            seaborn's default (mean) estimator; error bars are one standard deviation.
        num_condition: number of conditions on the x axis. The figure is
            10 mm + 16 mm per condition wide and 45 mm tall.
        save_path: output image path. Its text also picks the legend layout:
            'Dataset 3' -> upper-left two-column legend with y ticks 0..3,
            'Dataset 6' -> upper-right two-column legend, otherwise no legend.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    # Sizes are given in millimetres and converted to inches.
    fig, ax = plt.subplots(figsize=((10+16*num_condition)/25.4, 45/25.4))
    colors = ['#fff7bc','#fee391','#fec44f','#fe9929','#ec7014','#cc4c02','#8c2d04']
    sns.barplot(x='condition', y='rmse', hue='method', data=df, palette=colors,
                errorbar='sd',
                linestyle='-',
                errwidth=0.3,
                capsize=.05,
                errcolor='black',
                ax = ax
                )
    ax.set_xlabel('')
    # Raw string: '\%' is an invalid escape in a plain literal (SyntaxWarning on
    # Python 3.12+). The label text still contains the backslash, presumably to
    # escape '%' for LaTeX text rendering.
    ax.set_ylabel(r'RMSE (\%)')
    ax.set_ylim(0, None)
    if 'Dataset 3' in save_path:
        ax.legend(loc='upper left', ncol=2)
        ax.set_yticks([0,1,2,3])
    elif 'Dataset 6' in save_path:
        ax.legend(loc='upper right', ncol=2)
    else:
        ax.legend().remove()
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()


def get_rmse_single(path,test_cells):
    """Return the RMSEs (in percent) of the requested cells from ``<path>/eval_metrics.csv``.

    Args:
        path: directory containing ``eval_metrics.csv``. The CSV's first column
            holds cell identifiers and its ``RMSE`` column the metric.
        test_cells: list-like of cell identifiers to keep.

    Returns:
        list of ``RMSE * 100`` values for the matching rows, in CSV row order.
    """
    df = pd.read_csv(os.path.join(path, 'eval_metrics.csv'))
    # Vectorised row selection instead of a Python loop over df.iloc[j].
    selected = df.iloc[:, 0].isin(test_cells)
    return (df.loc[selected, 'RMSE'] * 100).tolist()

def comparison_full_label():
    """Plot grouped RMSE bars of all comparison methods for the full-condition-label setting.

    For each dataset, the test-cell RMSEs (percent, from ``get_rmse_single``) of
    every method are gathered per working condition and passed to
    ``plot_comparison_full_label``. Datasets 1-3 are saved as
    ``figs/fig6/fig6(c)_<dataset>.jpg``; Datasets 4-6 as ``figs/figS19/<dataset>.jpg``.
    """
    test_cells = {
        'Dataset 1': [['25-'+str(i+1) for i in range(2,8)],['45-'+str(i+1) for i in range(2,6)]],
        'Dataset 2': [['1C-'+str(i+1) for i in range(5,10)],['2C-5'],['3C-'+str(i+1) for i in range(5,10)]],
        'Dataset 3': [['0-CC-1','0-CC-3'],['10-CC-2','10-CC-3'],['25-CC-2','25-CC-3'],['40-CC-2','40-CC-3']],
        'Dataset 4': [['CY25-05_1-#18','CY25-05_1-#19'],['CY35-05_1-#2'],['CY45-05_1-#'+str(i+1) for i in range(23,28)]],
        'Dataset 5': [['2C-6','2C-8'],['3C-'+str(i+1) for i in range(6,15)],['4C-6']],
        'Dataset 6': [['25_0.5b_100'],['25_1b_100','25_1c_100','25_1d_100'],['25_2b_100'],['25_3b_100','25_3c_100','25_3d_100'],
                      ['35_1b_100','35_1c_100','35_1d_100'],['35_2b_100']]
    }
    # Codes pair positionally with the groups in test_cells.
    condition_code = {
        'Dataset 1': [1,2],
        'Dataset 2': [5,6,7],
        'Dataset 3': [8,9,10,11],
        'Dataset 4': [16,17,18],
        'Dataset 5': [19,20,21],
        'Dataset 6': [22,23,24,25,26,27]
    }
    methods = ['SVR','RF','GPR','CNN','Benchmark','WSL(enough labels)','WSL(limited labels)']
    datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    wsl_files = ['Experiment1(25-1)','Experiment2(1C-5)','Experiment1(0-CC-2)','Experiment2(CY25-05_1-#16)','Experiment2(2C-7)','Experiment2(25_1a_100)']
    base_dir = 'results/comparison_methods_with_enough_labels_results/full_condition_label'

    def _result_dir(dataset, method, wsl_file):
        # The WSL variants use their own folders; every other method is stored
        # under base_dir/<dataset>/<method>.
        if method == 'WSL(enough labels)':
            return f'{base_dir}/{dataset}/WSL_el'
        if method == 'WSL(limited labels)':
            return f'results/soh_individual_dataset_results/{dataset}/{wsl_file}'
        return os.path.join(base_dir, dataset, method)

    for di, wsl_file in zip(datasets, wsl_files):
        cond_parts, meth_parts, rmse_parts = [], [], []
        for code, cells in zip(condition_code[di], test_cells[di]):
            # '\#' keeps a backslash before '#', presumably for LaTeX tick labels.
            label = rf'\#{code}'
            for meth in methods:
                scores = get_rmse_single(_result_dir(di, meth, wsl_file), cells)
                count = len(scores)
                cond_parts.append(np.array([label] * count))
                meth_parts.append(np.array([meth] * count))
                rmse_parts.append(scores)
        frame = pd.DataFrame({
            'condition': np.concatenate(cond_parts),
            'method': np.concatenate(meth_parts),
            'rmse': np.concatenate(rmse_parts),
        })
        if di in ('Dataset 1', 'Dataset 2', 'Dataset 3'):
            out_path = f'figs/fig6/fig6(c)_{di}.jpg'
        else:
            out_path = f'figs/figS19/{di}.jpg'
        plot_comparison_full_label(frame, len(condition_code[di]), out_path)

def full_label_errors():
    """Export per-condition mean RMSE tables for the full-condition-label comparison.

    For every dataset and every working condition, the RMSEs (percent, via
    ``get_rmse_single``) of that condition's test cells are averaged per method.
    One CSV per dataset is written to ``figs/fig6/fig6(c)_<dataset>_tabel.csv``
    with a ``method`` column and one ``#<condition>`` column per condition,
    values formatted with three decimals.
    """
    test_cells = {
        'Dataset 1': [['25-'+str(i+1) for i in range(2,8)],['45-'+str(i+1) for i in range(2,6)]],
        'Dataset 2': [['1C-'+str(i+1) for i in range(5,10)],['2C-5'],['3C-'+str(i+1) for i in range(5,10)]],
        'Dataset 3': [['0-CC-1','0-CC-3'],['10-CC-2','10-CC-3'],['25-CC-2','25-CC-3'],['40-CC-2','40-CC-3']],
        'Dataset 4': [['CY25-05_1-#18','CY25-05_1-#19'],['CY35-05_1-#2'],['CY45-05_1-#'+str(i+1) for i in range(23,28)]],
        'Dataset 5': [['2C-6','2C-8'],['3C-'+str(i+1) for i in range(6,15)],['4C-6']],
        'Dataset 6': [['25_0.5b_100'],['25_1b_100','25_1c_100','25_1d_100'],['25_2b_100'],['25_3b_100','25_3c_100','25_3d_100'],
                      ['35_1b_100','35_1c_100','35_1d_100'],['35_2b_100']]
    }
    # Codes pair positionally with the groups in test_cells.
    condition_code = {
        'Dataset 1': [1,2],
        'Dataset 2': [5,6,7],
        'Dataset 3': [8,9,10,11],
        'Dataset 4': [16,17,18],
        'Dataset 5': [19,20,21],
        'Dataset 6': [22,23,24,25,26,27]
    }
    methods = ['SVR','RF','GPR','CNN','Benchmark','WSL(enough labels)','WSL(limited labels)']
    datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    wsl_files = ['Experiment1(25-1)','Experiment2(1C-5)','Experiment1(0-CC-2)','Experiment2(CY25-05_1-#16)','Experiment2(2C-7)','Experiment2(25_1a_100)']
    base_dir = 'results/comparison_methods_with_enough_labels_results/full_condition_label'

    def _result_dir(dataset, method, wsl_file):
        # The WSL variants use their own folders; every other method is stored
        # under base_dir/<dataset>/<method>.
        if method == 'WSL(enough labels)':
            return f'{base_dir}/{dataset}/WSL_el'
        if method == 'WSL(limited labels)':
            return f'results/soh_individual_dataset_results/{dataset}/{wsl_file}'
        return os.path.join(base_dir, dataset, method)

    for di, wsl_file in zip(datasets, wsl_files):
        table = {'method': methods}
        for code, cells in zip(condition_code[di], test_cells[di]):
            # One formatted mean RMSE per method, in the order of `methods`.
            table[f'#{code}'] = [
                f'{np.mean(get_rmse_single(_result_dir(di, meth, wsl_file), cells)):.3f}'
                for meth in methods
            ]
        pd.DataFrame(table).to_csv(f'figs/fig6/fig6(c)_{di}_tabel.csv', index=False)

# comparison_full_label and full_label_errors write into both folders.
# exist_ok avoids the check-then-create race; path_dir ends as 'figs/figS19'.
for path_dir in ('figs/fig6', 'figs/figS19'):
    os.makedirs(path_dir, exist_ok=True)
comparison_full_label()
full_label_errors()

### fig6d

In [5]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_comparison_benchmark_tl(df,save_path):
    """Draw intra/cross-condition RMSE bars for Benchmark_TL vs WSL and save the figure.

    Args:
        df: DataFrame with ``condition``, ``method`` and ``rmse`` columns. Bars use
            seaborn's default (mean) estimator; error bars are one standard deviation.
        save_path: output image path. Only paths containing 'Dataset 1' keep a
            legend. Paths containing 'Dataset 5' get y ticks 0..3, all others 0..2.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    # 32 mm x 45 mm, converted to inches.
    fig, ax = plt.subplots(figsize=(32/25.4, 45/25.4))
    colors = ['#4575b4','#ec7014']
    sns.barplot(x='condition', y='rmse', hue='method', data=df, palette=colors,
                errorbar='sd',
                linestyle='-',
                errwidth=0.3,
                capsize=.05,
                errcolor='black',
                ax = ax
                )
    ax.set_xlabel('')
    # Raw string: '\%' is an invalid escape in a plain literal (SyntaxWarning on
    # Python 3.12+). The label text still contains the backslash, presumably to
    # escape '%' for LaTeX text rendering.
    ax.set_ylabel(r'RMSE (\%)')
    if 'Dataset 1' in save_path:
        ax.legend(loc='upper left', ncol=1)
    else:
        ax.legend().remove()
    if 'Dataset 5' in save_path:
        ax.set_yticks([0,1,2,3])
    else:
        ax.set_yticks([0,1,2])
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()

def comparison_benchmark_tl():
    """Plot Benchmark_TL vs WSL RMSEs for transfer from Dataset 6 to Datasets 1-5.

    For each target dataset, intra- and cross-condition RMSEs of both methods are
    read with ``get_rmse`` and drawn by ``plot_comparison_benchmark_tl`` into
    ``figs/fig6/fig6(d)_<dataset>.jpg``.
    """
    target_datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5']
    methods = ['Benchmark_TL','WSL']
    for target_di in target_datasets:
        condition_parts, method_parts, rmse_parts = [], [], []
        for meth in methods:
            result_dir = os.path.join('results/comparison_methods_benchmark_tl_results', meth, f'{target_di} transfer from Dataset 6')
            intra, cross = get_rmse(result_dir, target_di)
            # Intra-condition values first, then cross-condition values.
            for label, values in (('intra', intra), ('cross', cross)):
                condition_parts.append(np.array([label] * len(values)))
                rmse_parts.append(values)
            method_parts.append(np.array([meth] * (len(intra) + len(cross))))
        frame = pd.DataFrame({
            'condition': np.concatenate(condition_parts),
            'method': np.concatenate(method_parts),
            'rmse': np.concatenate(rmse_parts),
        })
        plot_comparison_benchmark_tl(frame, f'figs/fig6/fig6(d)_{target_di}.jpg')
    
def benchmark_tl_errors():
    """Export mean±std RMSE/MAPE of Benchmark_TL and WSL for transfer from Dataset 6.

    One row per (target dataset, method) pair, labelled '<dataset> <method>', is
    written to ``figs/fig6/fig6(d)_tabel.csv``. Intra- and cross-condition
    metrics from ``get_rmse``/``get_mape`` are formatted as ``mean±std`` with
    three decimals.
    """
    target_datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5']
    methods = ['Benchmark_TL','WSL']

    def _mean_std(values):
        # Format a sample as 'mean±std' (population std, np.std default).
        return f'{np.mean(values):.3f}±{np.std(values):.3f}'

    df = {'dataset-method':[],
            'intra_RMSE':[],
            'intra_MAPE':[],
            'cross_RMSE':[],
            'cross_MAPE':[]
            }
    for target_di in target_datasets:
        for meth in methods:
            path = os.path.join('results/comparison_methods_benchmark_tl_results',meth, f'{target_di} transfer from Dataset 6')
            intra_condition_rmses,cross_condition_rmses = get_rmse(path,target_di)
            intra_condition_mapes,cross_condition_mapes = get_mape(path,target_di)
            # Append complete rows so columns stay aligned for any number of methods.
            df['dataset-method'].append(f'{target_di} {meth}')
            df['intra_RMSE'].append(_mean_std(intra_condition_rmses))
            df['intra_MAPE'].append(_mean_std(intra_condition_mapes))
            df['cross_RMSE'].append(_mean_std(cross_condition_rmses))
            df['cross_MAPE'].append(_mean_std(cross_condition_mapes))
    pd.DataFrame(df).to_csv('figs/fig6/fig6(d)_tabel.csv',index=False)

# Both functions write into figs/fig6, which this cell does not otherwise create.
os.makedirs('figs/fig6', exist_ok=True)
comparison_benchmark_tl()
benchmark_tl_errors()

## Effect of the pre-training and fine-tuning data sizes

### fig7a-b, figS20a-b

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape
from matplotlib.colors import LinearSegmentedColormap

def plot_pre_rate(df,dataset,save_path):
    """Plot mean RMSE (±1 std) per pre-training data rate as a bar chart and save it.

    Args:
        df: DataFrame with ``pre_rate`` (category labels) and ``rmse`` columns.
        dataset: figure title.
        save_path: output image path. Paths containing '(a)' use a blue palette,
            all others an orange one.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    # 48 mm x 38 mm, converted to inches.
    fig, ax = plt.subplots(figsize=(48/25.4, 38/25.4))
    plt.tick_params(axis='x', which='minor', bottom=False,top=False)
    if '(a)' in save_path:
        anchors = ['#deebf7','#c6dbef','#9ecae1','#6baed6','#4292c6','#2171b5','#08519c'][::-1]
    else:
        anchors = ['#fff7bc','#fee391','#fec44f','#fe9929','#ec7014','#cc4c02','#8c2d04'][::-1]
    # Interpolate the 7 anchor colours into the 10 bar colours used as the palette.
    cs = LinearSegmentedColormap.from_list('cs', anchors, N=10)
    sns.barplot(x='pre_rate', y='rmse', data=df, palette=cs(np.linspace(0, 1, 10)),
                errorbar='sd',
                linestyle='-',
                errwidth=0.3,
                capsize=0.3,
                errcolor='black',
                ax = ax
                )
    ax.set_title(dataset)
    # Fixed limits frame 10 bar positions (0..9).
    ax.set_xlim([-0.8,9.8])
    ax.set_xlabel('')
    # Raw string: '\%' is an invalid escape in a plain literal (SyntaxWarning on
    # Python 3.12+). The label text still contains the backslash, presumably to
    # escape '%' for LaTeX text rendering.
    ax.set_ylabel(r'RMSE (\%)')
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()

def pre_samples_errors():
    """Plot intra- and cross-condition RMSE against the pre-training data rate.

    For every dataset, RMSEs for rates 0.1..0.9 are read from
    ``results/pretraining_samples_effect_results/<dataset>/pre_rate=<rate>``. The
    'all' entry is read from ``results/soh_individual_dataset_results/<dataset>``
    and labelled '1'. Datasets 1-2 go to ``figs/fig7``, the others to
    ``figs/figS20``.
    """
    target_datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    pre_data_rates = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,'all']
    for di in target_datasets:
        intra_parts, intra_labels = [], []
        cross_parts, cross_labels = [], []
        for pre_rate in pre_data_rates:
            if pre_rate == 'all':
                pre_rate = 1
                result_dir = os.path.join('results/soh_individual_dataset_results', di)
            else:
                result_dir = os.path.join('results/pretraining_samples_effect_results', di, f'pre_rate={str(pre_rate)}')
            intra, cross = get_rmse(result_dir, di)
            tag = str(pre_rate)
            intra_parts.append(intra)
            intra_labels.append(np.array([tag] * len(intra)))
            cross_parts.append(cross)
            cross_labels.append(np.array([tag] * len(cross)))
        intra_df = pd.DataFrame({'pre_rate': np.concatenate(intra_labels), 'rmse': np.concatenate(intra_parts)})
        cross_df = pd.DataFrame({'pre_rate': np.concatenate(cross_labels), 'rmse': np.concatenate(cross_parts)})
        if di in ('Dataset 1', 'Dataset 2'):
            intra_path = f'figs/fig7/fig7(a)_{di}.jpg'
            cross_path = f'figs/fig7/fig7(b)_{di}.jpg'
        else:
            intra_path = f'figs/figS20/figS20(a)_{di}.jpg'
            cross_path = f'figs/figS20/figS20(b)_{di}.jpg'
        plot_pre_rate(intra_df, di, intra_path)
        plot_pre_rate(cross_df, di, cross_path)

def print_pre_samples_errors(di):
    """Print mean ± std of intra- and cross-condition RMSE for each pre-training rate of ``di``.

    The 'all' rate is read from ``results/soh_individual_dataset_results/<di>``
    and printed as rate 1.
    """
    for rate in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,'all']:
        if rate == 'all':
            rate = 1
            result_dir = os.path.join('results/soh_individual_dataset_results', di)
        else:
            result_dir = os.path.join('results/pretraining_samples_effect_results', di, f'pre_rate={str(rate)}')
        intra_condition_rmses, cross_condition_rmses = get_rmse(result_dir, di)
        pre_rate = rate
        print(f'{di}_{pre_rate}: {np.mean(intra_condition_rmses):.3f} ± {np.std(intra_condition_rmses):.3f} ; {np.mean(cross_condition_rmses):.3f} ± {np.std(cross_condition_rmses):.3f}')
            

# pre_samples_errors writes into both folders. exist_ok avoids the
# check-then-create race; path_dir ends as 'figs/figS20'.
for path_dir in ('figs/fig7', 'figs/figS20'):
    os.makedirs(path_dir, exist_ok=True)
pre_samples_errors()
# print_pre_samples_errors('Dataset 1')


### fig7c-d, figS20c-d

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape
from matplotlib.colors import LinearSegmentedColormap

def plot_ft_num(df,dataset,save_path):
    """Plot mean RMSE (±1 std) per number of fine-tuning samples as a bar chart and save it.

    Args:
        df: DataFrame with ``ft_num`` (category labels) and ``rmse`` columns.
        dataset: figure title.
        save_path: output image path. Paths containing '(c)' use a blue palette,
            all others an orange one.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    # 48 mm x 38 mm, converted to inches.
    fig, ax = plt.subplots(figsize=(48/25.4, 38/25.4))
    plt.tick_params(axis='x', which='minor', bottom=False,top=False)
    if '(c)' in save_path:
        anchors = ['#deebf7','#c6dbef','#9ecae1','#6baed6','#4292c6','#2171b5','#08519c'][::-1]
    else:
        anchors = ['#fff7bc','#fee391','#fec44f','#fe9929','#ec7014','#cc4c02','#8c2d04'][::-1]
    # Interpolate the 7 anchor colours into the 10 bar colours used as the palette.
    cs = LinearSegmentedColormap.from_list('cs', anchors, N=10)
    sns.barplot(x='ft_num', y='rmse', data=df, palette=cs(np.linspace(0, 1, 10)),
                errorbar='sd',
                linestyle='-',
                errwidth=0.3,
                capsize=0.3,
                errcolor='black',
                ax = ax
                )
    ax.set_title(dataset)
    # Fixed limits frame 10 bar positions (0..9).
    ax.set_xlim([-0.8,9.8])
    ax.set_xlabel('')
    # Raw string: '\%' is an invalid escape in a plain literal (SyntaxWarning on
    # Python 3.12+). The label text still contains the backslash, presumably to
    # escape '%' for LaTeX text rendering.
    ax.set_ylabel(r'RMSE (\%)')
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()

def ft_samples_errors():
    """Plot intra- and cross-condition RMSE against the number of fine-tuning samples.

    For every dataset, RMSEs for ft_num = 2..10 and 'all' are read from
    ``results/fine_tuning_samples_effect_results/<dataset>/ft_num=<n>``. The
    exception is ft_num == 6, which is read from
    ``results/soh_individual_dataset_results/<dataset>``. Datasets 1-2 go to
    ``figs/fig7``, the others to ``figs/figS20``.
    """
    target_datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
    ft_data_nums = list(range(2, 11)) + ['all']
    for di in target_datasets:
        intra_parts, intra_labels = [], []
        cross_parts, cross_labels = [], []
        for ft_num in ft_data_nums:
            # NOTE(review): ft_num == 6 reuses the main per-dataset results,
            # presumably because those runs fine-tuned on 6 samples. Confirm.
            if ft_num == 6:
                result_dir = os.path.join('results/soh_individual_dataset_results', di)
            else:
                result_dir = os.path.join('results/fine_tuning_samples_effect_results', di, f'ft_num={str(ft_num)}')
            intra, cross = get_rmse(result_dir, di)
            tag = str(ft_num)
            intra_parts.append(intra)
            intra_labels.append(np.array([tag] * len(intra)))
            cross_parts.append(cross)
            cross_labels.append(np.array([tag] * len(cross)))
        intra_df = pd.DataFrame({'ft_num': np.concatenate(intra_labels), 'rmse': np.concatenate(intra_parts)})
        cross_df = pd.DataFrame({'ft_num': np.concatenate(cross_labels), 'rmse': np.concatenate(cross_parts)})
        if di in ('Dataset 1', 'Dataset 2'):
            intra_path = f'figs/fig7/fig7(c)_{di}.jpg'
            cross_path = f'figs/fig7/fig7(d)_{di}.jpg'
        else:
            intra_path = f'figs/figS20/figS20(c)_{di}.jpg'
            cross_path = f'figs/figS20/figS20(d)_{di}.jpg'
        plot_ft_num(intra_df, di, intra_path)
        plot_ft_num(cross_df, di, cross_path)

def print_ft_samples_errors(di):
    """Print mean ± std of intra- and cross-condition RMSE per fine-tuning sample count of ``di``.

    ft_num == 6 is read from ``results/soh_individual_dataset_results/<di>``;
    every other count from ``results/fine_tuning_samples_effect_results``.
    """
    for ft_num in list(range(2, 11)) + ['all']:
        if ft_num == 6:
            result_dir = os.path.join('results/soh_individual_dataset_results', di)
        else:
            result_dir = os.path.join('results/fine_tuning_samples_effect_results', di, f'ft_num={str(ft_num)}')
        intra_condition_rmses, cross_condition_rmses = get_rmse(result_dir, di)
        print(f'{di}_{ft_num}: {np.mean(intra_condition_rmses):.3f} ± {np.std(intra_condition_rmses):.3f} ; {np.mean(cross_condition_rmses):.3f} ± {np.std(cross_condition_rmses):.3f}')

# ft_samples_errors writes into figs/fig7 and figs/figS20. Make sure they exist
# even when this cell runs on its own.
for path_dir in ('figs/fig7', 'figs/figS20'):
    os.makedirs(path_dir, exist_ok=True)
ft_samples_errors()
# print_ft_samples_errors('Dataset 1')

# Weak label generation

### fig8a

In [5]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_rs():
    """Draw a one-row heatmap of |correlation coefficient| per start voltage for Datasets 1-5.

    For each ``data/dataset{k}_rs.npz`` (k = 1..5), the absolute values of ``rs``
    are shown as an annotated single-row heatmap with ``start_vols`` as x tick
    labels. The largest value is outlined in red. Figures are saved to
    ``figs/fig8/fig8(a)_dataset{k}.jpg``.
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    save_path = 'figs/fig8'
    os.makedirs(save_path, exist_ok=True)
    cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap',['#deebf7','#c6dbef','#9ecae1','#6baed6','#4292c6','#2171b5','#08519c'])
    norm = plt.Normalize(0.8, 1)
    for i in range(5):
        # The NpzFile keeps the archive open; the context manager closes it once
        # both arrays have been read into memory.
        with np.load(f"data/dataset{i+1}_rs.npz") as data: # start_vols, rs (correlation coefficient)
            rs = np.abs(data['rs'].reshape(1,-1))
            start_vols = data['start_vols']
        # Dataset 1 gets a narrower figure (width in mm).
        length = 58.5 if i == 0 else 67.5
        fig, ax = plt.subplots(figsize=(length/25.4, 11.5/25.4))
        ax.tick_params(top=False, bottom=False)
        ax.tick_params(which='minor', bottom=False, top=False)
        max_y, max_x = np.unravel_index(np.argmax(rs, axis=None), rs.shape)
        sns.heatmap(rs, cmap=cmap, cbar=False, ax=ax, norm=norm, linewidths=0.2, linecolor='black',
                    annot=True, fmt=".3f", annot_kws={"size": 4, "color": 'white'})
        # Values below the 0.8 norm floor sit on the lightest colour, so switch
        # their annotation from white to black.
        for text in ax.texts:
            if float(text.get_text()) < 0.8:
                text.set_color('black')
        # Heatmap cell (x, y) spans [x, x+1] x [y, y+1]; outline the single maximum.
        ax.add_patch(plt.Rectangle((max_x, max_y), 1, 1, fill=False, edgecolor='red', lw=0.5))
        ax.set_yticks([])
        ax.set_xticks([k + 0.5 for k in range(len(start_vols))], start_vols)
        plt.tight_layout()
        plt.savefig(f"{save_path}/fig8(a)_dataset{i+1}.jpg", dpi=600)
        plt.close()

plot_rs()

### fig8b-g

In [6]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_weak_label(sohs, weak_labels,cd_code,save_path):
    """Scatter weak labels against SOH, coloured by working condition, and save the figure.

    Args:
        sohs: SOH values for the x axis (assumed a numpy array, since it is
            boolean-indexed).
        weak_labels: weak-label values for the y axis, aligned with ``sohs``.
        cd_code: numpy array of condition codes aligned with ``sohs``. Each
            distinct code gets one colour and legend entry, in ``np.unique``
            (sorted) order.
        save_path: output image path.
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    colors = ['#f46d43','#fdae61','#fee090','#abd9e9','#74add1','#4575b4'][::-1]
    # 40 mm x 30 mm, converted to inches.
    fig, ax = plt.subplots(figsize=(40/25.4, 30/25.4))
    for i, cd in enumerate(np.unique(cd_code)):
        idx = cd_code == cd
        # Cycle the 6-colour palette so more than six conditions do not raise IndexError.
        ax.scatter(sohs[idx], weak_labels[idx], color=colors[i % len(colors)], s=0.25, label=cd)
    # ax.set_xlabel('SOH')
    # ax.set_ylabel('Weak label')
    ax.legend(fontsize=5)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def _load_normalized_labels(data_dir, prefixes, codes, extra_filter=None):
    """Collect first-value-normalised capacity and weak-label series for matching files.

    Args:
        data_dir: directory of per-cell ``.npz`` files holding ``capacity`` and
            ``weak_label`` arrays.
        prefixes: file-name prefixes, one per working condition.
        codes: condition code recorded for every sample of the paired prefix.
        extra_filter: optional predicate on the file name; files for which it
            returns False are skipped.

    Returns:
        (sohs, weak_labels, cd_code) as concatenated numpy arrays. Each cell's
        capacity and weak label are divided by their own first value.
    """
    # Sorted so the drawing order is reproducible (os.listdir order is arbitrary).
    files = sorted(os.listdir(data_dir))
    sohs, weak_labels, cd_code = [], [], []
    for prefix, code in zip(prefixes, codes):
        for fi in files:
            if not fi.startswith(prefix):
                continue
            if extra_filter is not None and not extra_filter(fi):
                continue
            # Close the NpzFile once both arrays are read into memory.
            with np.load(os.path.join(data_dir, fi)) as data:
                capacity = data['capacity']
                weak_label = data['weak_label']
            sohs.append(capacity/capacity[0])
            weak_labels.append(weak_label/weak_label[0])
            cd_code.append([code]*len(capacity))
    return np.concatenate(sohs), np.concatenate(weak_labels), np.concatenate(cd_code)


def weal_label():
    """Plot normalised weak label vs SOH for every dataset (figure 8, panels b-g).

    Datasets 1-5 are saved as ``figs/fig8/fig8(b..f) Dataset k.jpg``. Dataset 6
    (only files whose name contains '100') is saved as
    ``figs/fig8/fig8(g) Dataset 6.jpg``.
    """
    # File-name prefixes per working condition.
    condition_df = {
        'Dataset 1':['25','45'],
        'Dataset 2':['1C','2C','3C'],
        'Dataset 3':['0-CC','10-C','25-C','40-C'],
        'Dataset 4':['CY25','CY35','CY45'],
        'Dataset 5':['2C','3C','4C'],
    }
    # Legend label for each prefix, in the same order.
    condition_code = {
        'Dataset 1':['25-3-4','45-3-4'],
        'Dataset 2':['25-1-1','25-1-2','25-1-3'],
        'Dataset 3':['0-1/3-1/3','10-1/3-1/3','25-1/3-1/3','40-1/3-1/3'],
        'Dataset 4':['25-0.5-1','35-0.5-1','45-0.5-1'],
        'Dataset 5':['25-1-2','25-1-3','25-1-4'],
    }
    figs = ['(b)','(c)','(d)','(e)','(f)']
    for i in range(5):
        di = f'Dataset {i+1}'
        sohs, weal_labels, cd_code = _load_normalized_labels(f'data/{di}', condition_df[di], condition_code[di])
        plot_weak_label(sohs,weal_labels,cd_code,f'figs/fig8/fig8{figs[i]} {di}.jpg')

    condition_d6 = ['25_0','25_1','25_2','25_3','35_1','35_2']
    condition_code_d6 = ['25-0.5-0.5','25-0.5-1','25-0.5-2','25-0.5-3','35-0.5-1','35-0.5-2']
    sohs, weal_labels, cd_code = _load_normalized_labels(
        'data/Dataset 6', condition_d6, condition_code_d6,
        extra_filter=lambda name: '100' in name)
    # Dataset 6 is panel (g), after panels (b)-(f) for Datasets 1-5.
    plot_weak_label(sohs,weal_labels,cd_code,'figs/fig8/fig8(g) Dataset 6.jpg')

save_path = 'figs/fig8'
# exist_ok avoids the check-then-create race.
os.makedirs(save_path, exist_ok=True)
weal_label()