## SOH estimation within each individual dataset

### fig3

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape,get_max

def plot_results(estimates, ture, abs_errors, dataset, save_path):
    """Save a 45 mm x 45 mm scatter of estimated vs. true SOH for one dataset.

    Points are coloured by their absolute error through a fixed 0-0.1
    normalisation, so every panel uses the same colour mapping.

    Args:
        estimates: Estimated SOH values (plotted on the y axis).
        ture: True SOH values (plotted on the x axis). The misspelt name is
            kept so existing callers are unaffected.
        abs_errors: Absolute error per point, used as the colour value.
        dataset: Panel title.
        save_path: Destination image file (written at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side_in = 45 / 25.4  # millimetres -> inches
    fig, ax = plt.subplots(figsize=(side_in, side_in))

    blues = ['#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8']
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap', blues)
    error_norm = plt.Normalize(0, 0.1)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=error_norm, s=0.25)

    # Shared square range: data extent padded by 0.01, upper edge capped at 1.02.
    lo = min(min(estimates), min(ture)) - 0.01
    hi = min(max(max(estimates), max(ture)) + 0.01, 1.02)
    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)  # y = x reference line
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.set_xlabel('True SOH')
    ax.set_ylabel('Estimation')
    ax.set_title(dataset, fontsize=8)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def plot_rmses(df, save_path):
    """Save a horizontal violin plot of per-run RMSE for each dataset.

    Each violin is overlaid with a red tick at the group mean and a dashed
    black bar with caps spanning mean +/- one standard deviation (pandas
    ``std``). A legend is drawn only when ``save_path`` contains ``'(b)'``.

    Args:
        df: DataFrame with a ``'Dataset'`` label column and an ``'rmse'``
            column (the x axis is labelled as a percentage).
        save_path: Destination image file.
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 8})
    fig, ax = plt.subplots(figsize=(50 / 25.4, 85 / 25.4))
    warm_to_cool = ['#f46d43', '#fdae61', '#fee090', '#abd9e9', '#74add1', '#4575b4']
    sns.violinplot(x='rmse', y='Dataset', data=df, density_norm='width', inner=None, orient='h',
                   hue='Dataset', palette=warm_to_cool[::-1], linewidth=0)

    by_dataset = df.groupby('Dataset')['rmse']
    means = by_dataset.mean()
    stds = by_dataset.std()
    for row, name in enumerate(df['Dataset'].unique()):
        centre, spread = means[name], stds[name]
        # Only the first row's artists carry legend labels, so each entry appears once.
        mean_label = {'label': 'Mean'} if row == 0 else {}
        std_label = {'label': 'Mean ± Std'} if row == 0 else {}
        ax.plot([centre, centre], [row - 0.2, row + 0.2], '-', color='red', lw=0.5, **mean_label)
        ax.plot([centre - spread, centre + spread], [row, row], '--', color='black', lw=0.5)
        ax.plot([centre - spread, centre - spread], [row - 0.1, row + 0.1], '-', color='black', lw=0.5, **std_label)
        ax.plot([centre + spread, centre + spread], [row - 0.1, row + 0.1], '-', color='black', lw=0.5)

    ax.set_xlim([0, max(df['rmse']) + 1.5])
    ax.set_xlabel(r'RMSE (\%)')
    ax.set_ylabel('Dataset')
    if '(b)' in save_path:
        ax.legend(fontsize=7, loc='lower right')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def intra_cross_results():
    """Draw the fig3(a) intra-condition and fig3(c) cross-condition scatter panels.

    For each of Dataset 1-6 the estimates are loaded with ``get_results`` from
    ``results/soh_individual_dataset_results/<dataset>`` and one panel per
    condition is written to ``figs/fig3``.
    """
    for num in range(1, 7):
        name = f'Dataset {num}'
        path = os.path.join('results/soh_individual_dataset_results', name)
        intra_est, intra_true, cross_est, cross_true = get_results(path, name)
        intra_err = np.abs(intra_est - intra_true)
        cross_err = np.abs(cross_est - cross_true)
        panels = (('a', intra_est, intra_true, intra_err),
                  ('c', cross_est, cross_true, cross_err))
        for letter, est, true, err in panels:
            plot_results(est, true, err, name, f'figs/fig3/fig3({letter})_{name}.jpg')

def intra_cross_rmses():
    """Draw the fig3(b)/(d) RMSE violin plots (intra- and cross-condition).

    RMSEs for Dataset 6 down to Dataset 1 are loaded with ``get_rmse`` and
    labelled by the dataset number only (``'6'`` ... ``'1'``) before plotting.
    """
    intra_labels, intra_vals = [], []
    cross_labels, cross_vals = [], []
    for num in range(6, 0, -1):  # Dataset 6 first, Dataset 1 last
        name = f'Dataset {num}'
        label = name[-1]
        intra, cross = get_rmse(os.path.join('results/soh_individual_dataset_results', name), name)
        intra_labels.append(np.array([label] * len(intra)))
        intra_vals.append(intra)
        cross_labels.append(np.array([label] * len(cross)))
        cross_vals.append(cross)

    intra_frame = pd.DataFrame({'Dataset': np.concatenate(intra_labels),
                                'rmse': np.concatenate(intra_vals)})
    cross_frame = pd.DataFrame({'Dataset': np.concatenate(cross_labels),
                                'rmse': np.concatenate(cross_vals)})
    plot_rmses(intra_frame, 'figs/fig3/fig3(b).jpg')
    plot_rmses(cross_frame, 'figs/fig3/fig3(d).jpg')

def intra_cross_max():
    """Print, per dataset, the largest intra- and cross-condition MAX error returned by ``get_max``."""
    for num in range(1, 7):
        name = f'Dataset {num}'
        intra_maxs, cross_maxs = get_max(os.path.join('results/soh_individual_dataset_results', name), name)
        print(name, max(intra_maxs), max(cross_maxs))
    

def error_ratio_results(threshold=0.03, include_intra=False):
    """Print, per dataset, the percentage of samples whose absolute error is below ``threshold``.

    Estimates for Dataset 1-6 are loaded with ``get_results`` from
    ``results/soh_individual_dataset_results/<dataset>``.

    Args:
        threshold: Absolute-error cut-off; defaults to 0.03, the value that
            was previously hard-coded.
        include_intra: Also print the intra-condition ratio. That ratio used
            to be computed and then discarded, so it is opt-in and the default
            output (cross-condition only) is unchanged.
    """
    for num in range(1, 7):
        name = f'Dataset {num}'
        path = os.path.join('results/soh_individual_dataset_results', name)
        intra_est, intra_true, cross_est, cross_true = get_results(path, name)
        conditions = [('intra', intra_est, intra_true)] if include_intra else []
        conditions.append(('cross', cross_est, cross_true))
        for cond, est, true in conditions:
            # Mean of a boolean mask == fraction of samples under the cut-off.
            ratio = np.mean(np.abs(est - true) < threshold) * 100
            print(f"Dataset: {name}, Ratio of {cond}_condition_abserrors < {threshold}: {ratio:.2f}")

# Create the output directory first; the plotting helpers below save into it.
# exist_ok=True makes this idempotent without a separate exists() check.
path_dir = 'figs/fig3'
os.makedirs(path_dir, exist_ok=True)
intra_cross_results()
intra_cross_rmses()
intra_cross_max()
error_ratio_results()

## SOH estimation across datasets

### fig4a

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_cross_di(df, save_path):
    """Save a small violin plot comparing intra- and cross-condition RMSE.

    Each violin gets a red mean tick and a dashed mean ± std whisker with
    caps (pandas ``std``). Axis labels and x ticks are hidden, and the y
    range is fixed to 0-7.

    Args:
        df: DataFrame with ``'condition'`` and ``'rmse'`` columns.
        save_path: Destination image file (written at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(29 / 25.4, 25.4 / 25.4))
    sns.violinplot(x='condition', y='rmse', data=df, density_norm='width', inner=None,
                   hue='condition', palette=['#4575b4', '#f46d43'], linewidth=0)
    ax.set_xlabel('')
    ax.set_ylabel('')
    grouped = df.groupby('condition')['rmse']
    means, stds = grouped.mean(), grouped.std()
    for pos, cond in enumerate(df['condition'].unique()):
        mu, sd = means[cond], stds[cond]
        is_first = pos == 0
        ax.plot([pos - 0.2, pos + 0.2], [mu, mu], '-', color='red', lw=0.5,
                label='Mean' if is_first else "")
        ax.plot([pos, pos], [mu - sd, mu + sd], '--', color='black', lw=0.5,
                label='Mean ± Std' if is_first else "")
        for cap_y in (mu - sd, mu + sd):
            ax.plot([pos - 0.1, pos + 0.1], [cap_y, cap_y], '-', color='black', lw=0.5)
    ax.set_ylim([0, 7])
    ax.set_xticks([])
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def cross_dataset_errors():
    """Draw one fig4a violin plot for every ordered (target, source) dataset pair.

    RMSEs of the model transferred from ``source`` to ``target`` are read with
    ``get_rmse`` from
    ``results/soh_cross_laboratory_datasets_results/<target>/transfer from <source>``.
    Pairs where source == target are skipped.
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    pairs = [(tgt, src) for tgt in names for src in names if tgt != src]
    for tgt, src in pairs:
        path = os.path.join('results/soh_cross_laboratory_datasets_results', tgt, f'transfer from {src}')
        intra_rmses, cross_rmses = get_rmse(path, tgt)
        rmse = np.concatenate([intra_rmses, cross_rmses])
        condition = np.concatenate([np.array(['intra'] * len(intra_rmses)),
                                    np.array(['cross'] * len(cross_rmses))])
        plot_cross_di(pd.DataFrame({'condition': condition, 'rmse': rmse}),
                      f'figs/fig4a/fig4a_{tgt}_from_{src}.jpg')

def cd_errors_tabel():
    """Write ``figs/fig4a/fig4a_tabel.csv`` with cross-dataset RMSE as "mean ± std".

    Each target dataset gets two rows, ``'<target> intra'`` and
    ``'<target> cross'``, and each source dataset gets one column. A cell
    summarises the RMSEs of the model transferred from the column's source to
    the row's target, using ``np.mean`` and ``np.std``. Diagonal cells
    (source == target) are never filled and keep the placeholder ``0``.
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    table = {'col_targes': [f'{name} {cond}' for name in names for cond in ('intra', 'cross')]}
    for name in names:
        table[name] = [0] * 12

    def summary(vals):
        return f'{np.mean(vals):.3f} ± {np.std(vals):.3f}'

    for pair_idx, target in enumerate(names):
        for source in names:
            if source == target:
                continue
            path = os.path.join('results/soh_cross_laboratory_datasets_results', target, f'transfer from {source}')
            intra_rmses, cross_rmses = get_rmse(path, target)
            table[source][2 * pair_idx] = summary(intra_rmses)
            table[source][2 * pair_idx + 1] = summary(cross_rmses)
    pd.DataFrame(table).to_csv('figs/fig4a/fig4a_tabel.csv', index=False)
# Output directory for the fig4a violin plots and summary table;
# exist_ok=True makes creation idempotent without a separate exists() check.
path_dir = 'figs/fig4a'
os.makedirs(path_dir, exist_ok=True)
cross_dataset_errors()
cd_errors_tabel()

### fig4b

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape,get_mae,get_medae,get_max

def plot_di(df, save_path):
    """Save a small violin plot comparing intra- and cross-condition RMSE.

    Each violin gets a red mean tick and a dashed mean ± std whisker with
    caps (pandas ``std``). Axis labels and x ticks are hidden, and the y
    range is fixed to 0-7.

    Args:
        df: DataFrame with ``'condition'`` and ``'rmse'`` columns.
        save_path: Destination image file (written at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(29 / 25.4, 25.4 / 25.4))
    sns.violinplot(x='condition', y='rmse', data=df, density_norm='width', inner=None,
                   hue='condition', palette=['#4575b4', '#f46d43'], linewidth=0)
    ax.set_xlabel('')
    ax.set_ylabel('')
    grouped = df.groupby('condition')['rmse']
    means, stds = grouped.mean(), grouped.std()
    for pos, cond in enumerate(df['condition'].unique()):
        mu, sd = means[cond], stds[cond]
        is_first = pos == 0
        ax.plot([pos - 0.2, pos + 0.2], [mu, mu], '-', color='red', lw=0.5,
                label='Mean' if is_first else "")
        ax.plot([pos, pos], [mu - sd, mu + sd], '--', color='black', lw=0.5,
                label='Mean ± Std' if is_first else "")
        for cap_y in (mu - sd, mu + sd):
            ax.plot([pos - 0.1, pos + 0.1], [cap_y, cap_y], '-', color='black', lw=0.5)
    ax.set_xticks([])
    ax.set_ylim([0, 7])
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def dataset_errors():
    """Draw one fig4b violin plot (intra vs. cross RMSE) per target dataset.

    RMSEs come from ``get_rmse`` on
    ``results/soh_pretraining_ev_data_results/<dataset>`` for Dataset 1-6.
    """
    for num in range(1, 7):
        name = f'Dataset {num}'
        intra_rmses, cross_rmses = get_rmse(os.path.join('results/soh_pretraining_ev_data_results', name), name)
        rmse = np.concatenate([intra_rmses, cross_rmses])
        condition = np.concatenate([np.array(['intra'] * len(intra_rmses)),
                                    np.array(['cross'] * len(cross_rmses))])
        plot_di(pd.DataFrame({'condition': condition, 'rmse': rmse}), f'figs/fig4b/fig4b_{name}.jpg')

def errors_tabel():
    """Write ``figs/fig4b/fig4b_tabel.csv`` with mean ± std error metrics per dataset.

    For Dataset 1-6, RMSE/MAPE/MAE/MedAE/MAX are loaded from
    ``results/soh_pretraining_ev_data_results/<dataset>`` for both the intra-
    and cross-condition runs and formatted as ``"mean ± std"`` (``np.std``).
    The largest intra/cross MAX value of each dataset is also printed.
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    getters = {'RMSE': get_rmse, 'MAPE': get_mape, 'MAE': get_mae,
               'MedAE': get_medae, 'MAX': get_max}
    table = {'dataset': names}
    for cond in ('intra', 'cross'):
        for metric in getters:
            table[f'{cond}_{metric}'] = [0] * 6

    def summary(vals):
        return f'{np.mean(vals):.3f} ± {np.std(vals):.3f}'

    for row, name in enumerate(names):
        path = os.path.join('results/soh_pretraining_ev_data_results', name)
        per_metric = {}
        for metric, getter in getters.items():
            intra_vals, cross_vals = getter(path, name)
            per_metric[metric] = (intra_vals, cross_vals)
        for metric, (intra_vals, cross_vals) in per_metric.items():
            table[f'intra_{metric}'][row] = summary(intra_vals)
            table[f'cross_{metric}'][row] = summary(cross_vals)
        intra_maxs, cross_maxs = per_metric['MAX']
        print(name, max(intra_maxs), max(cross_maxs))
    pd.DataFrame(table).to_csv('figs/fig4b/fig4b_tabel.csv', index=False)
# Output directory for the fig4b violin plots and summary table;
# exist_ok=True makes creation idempotent without a separate exists() check.
path_dir = 'figs/fig4b'
os.makedirs(path_dir, exist_ok=True)
dataset_errors()
errors_tabel()

### fig4c

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_results(estimates, ture, abs_errors, save_path):
    """Save an 83 mm x 50 mm scatter of estimated vs. true SOH with an error colour bar.

    Args:
        estimates: Estimated SOH values (plotted on the y axis).
        ture: True SOH values (plotted on the x axis). The misspelt name is
            kept so existing callers are unaffected.
        abs_errors: Absolute error per point, coloured on a fixed 0-0.1 scale.
        save_path: Destination image file (written at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(83 / 25.4, 50 / 25.4))
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list(
        'custom_cmap', ['#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8'])
    error_norm = plt.Normalize(0, 0.1)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=error_norm, s=0.25)
    # Stand-alone mappable so the colour bar spans the fixed 0-0.1 range.
    plt.colorbar(plt.cm.ScalarMappable(norm=error_norm, cmap=error_cmap), ax=ax, label='Absolute Error')

    # Shared square range: data extent padded by 0.01, upper edge capped at 1.02.
    lo = min(min(estimates), min(ture)) - 0.01
    hi = min(max(max(estimates), max(ture)) + 0.01, 1.02)
    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)  # y = x reference line
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel('True SOH')
    ax.set_ylabel('Estimation')
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def dataset_errors():
    """Plot WSL field-data SOH estimates pooled over Experiment1..Experiment10 (fig4c).

    Each ``test_results.csv`` under ``results/soh_estimation_on_field_data/WSL``
    contributes the rows where both ``field_data_est`` and ``field_data_true``
    are present. The largest absolute error is printed and the scatter is
    saved under ``figs/fig4c``.
    """
    path = 'results/soh_estimation_on_field_data/WSL'
    true_parts, est_parts = [], []
    for exp in range(1, 11):
        df = pd.read_csv(os.path.join(path, f'Experiment{exp}', 'test_results.csv'))
        est = df['field_data_est'].values
        true = df['field_data_true'].values
        # Mask both columns jointly so every kept estimate stays paired with
        # its own label. Masking each column separately misaligns the pairs
        # (or breaks the subtraction) unless the NaNs sit in exactly the same rows.
        keep = ~np.isnan(est) & ~np.isnan(true)
        est_parts.append(est[keep])
        true_parts.append(true[keep])
    true_soh = np.concatenate(true_parts)
    est_soh = np.concatenate(est_parts)
    abs_errors = np.abs(true_soh - est_soh)
    print(max(abs_errors))
    # NOTE(review): unlike the other figures this path has no extension, so
    # matplotlib falls back to rcParams['savefig.format'] - confirm intended.
    plot_results(est_soh, true_soh, abs_errors, 'figs/fig4c/fig4c_field_data_results')

def errors_tabel(methods=('WSL',)):
    """Write ``figs/fig4c/fig4c_tabel.csv`` with mean ± std field-data errors per method.

    For every method, the first row of ``eval_metrics.csv`` in each of
    ``Experiment1``..``Experiment10`` under
    ``results/soh_estimation_on_field_data/<method>`` is read. Each metric is
    scaled by 100 and summarised as ``"mean ± std"`` (``np.std``, ddof=0).

    Args:
        methods: Method sub-directories to tabulate, one output row each.
            Defaults to the single ``'WSL'`` method that was hard-coded before.
    """
    methods = list(methods)
    metric_names = ['RMSE', 'MAPE', 'MAE', 'MedAE', 'MAX']
    # Placeholder columns are sized from ``methods``. They used to be fixed at
    # one entry, which raised IndexError as soon as a second method was added.
    table = {'method': methods}
    table.update({m: [0] * len(methods) for m in metric_names})
    for row, method in enumerate(methods):
        path = os.path.join('results/soh_estimation_on_field_data', method)
        samples = {m: [] for m in metric_names}
        for exp in range(1, 11):
            er = pd.read_csv(os.path.join(path, f'Experiment{exp}', 'eval_metrics.csv'))
            for m in metric_names:
                samples[m].append(er.loc[0, m] * 100)
        for m in metric_names:
            table[m][row] = f'{np.mean(samples[m]):.3f} ± {np.std(samples[m]):.3f}'
    pd.DataFrame(table).to_csv('figs/fig4c/fig4c_tabel.csv', index=False)
# Output directory for the fig4c scatter and summary table;
# exist_ok=True makes creation idempotent without a separate exists() check.
path_dir = 'figs/fig4c'
os.makedirs(path_dir, exist_ok=True)
dataset_errors()
errors_tabel()

### figS6,figS7

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_cd_results(estimates, ture, abs_errors, save_path):
    """Save a 29 mm square, label-free scatter of estimated vs. true SOH.

    Used for the cross-dataset transfer panels. Points are coloured by
    absolute error on a fixed 0-0.1 scale, and no colour bar is drawn.

    Args:
        estimates: Estimated SOH values (plotted on the y axis).
        ture: True SOH values (plotted on the x axis). The misspelt name is
            kept so existing callers are unaffected.
        abs_errors: Absolute error per point, used as the colour value.
        save_path: Destination image file (written at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side_in = 29 / 25.4  # millimetres -> inches
    fig, ax = plt.subplots(figsize=(side_in, side_in))
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list(
        'custom_cmap', ['#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8'])
    error_norm = plt.Normalize(0, 0.1)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=error_norm, s=0.25)

    # Shared square range: data extent padded by 0.01, upper edge capped at 1.02.
    lo = min(min(estimates), min(ture)) - 0.01
    hi = min(max(max(estimates), max(ture)) + 0.01, 1.02)
    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)  # y = x reference line
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def cross_dataset_results():
    """Draw figS6 (intra-condition) and figS7 (cross-condition) panels for each transfer pair.

    For every ordered (target, source) pair of Dataset 1-6 with
    source != target, estimates are loaded with ``get_results`` from
    ``results/soh_cross_laboratory_datasets_results/<target>/transfer from <source>``.
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    pairs = [(tgt, src) for tgt in names for src in names if tgt != src]
    for tgt, src in pairs:
        path = os.path.join('results/soh_cross_laboratory_datasets_results', tgt, f'transfer from {src}')
        intra_est, intra_true, cross_est, cross_true = get_results(path, tgt)
        intra_err = np.abs(intra_est - intra_true)
        cross_err = np.abs(cross_est - cross_true)
        stem = f'{tgt}_from_{src}.jpg'
        plot_cd_results(intra_est, intra_true, intra_err, f'figs/figS6/figS6_{stem}')
        plot_cd_results(cross_est, cross_true, cross_err, f'figs/figS7/figS7_{stem}')

# figS6 (intra-condition) and figS7 (cross-condition) panels go to separate
# directories. exist_ok=True makes creation idempotent without exists().
for path_dir in ('figs/figS6', 'figs/figS7'):
    os.makedirs(path_dir, exist_ok=True)
cross_dataset_results()

## Comparison with existing methods

### fig5a

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
import numpy as np
import pandas as pd
import os
from load_results import get_rmse,get_mape,get_mae,get_medae,get_max

def plot_comparison_rmses(df, save_path):
    """Save a grouped bar chart of RMSE per dataset and method (fig5a).

    Bars show the mean RMSE with standard-deviation error bars
    (``errorbar='sd'``). The y range is fixed to 0-6. Only a ``save_path``
    containing ``'intra'`` keeps a legend.

    Args:
        df: DataFrame with ``'Dataset'``, ``'method'`` and ``'rmse'`` columns.
        save_path: Destination image file (written at 600 dpi).
    """
    # Note the order: 'science' is applied after 'ieee' here.
    for style_name in ('ieee', 'science'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(160 / 25.4, 45 / 25.4))
    blues = ['#deebf7', '#c6dbef', '#9ecae1', '#6baed6', '#4292c6', '#2171b5', '#08519c']
    sns.barplot(x='Dataset', y='rmse', hue='method', data=df, palette=blues,
                errorbar='sd', linestyle='-', capsize=0.2, ax=ax,
                err_kws={'linewidth': 0.3, 'color': 'black'})
    ax.set_ylim([0, 6])
    ax.set_xlabel('')
    ax.set_ylabel(r'RMSE (\%)')
    if 'intra' not in save_path:
        ax.legend().remove()
    else:
        ax.legend(loc='upper left', ncol=3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def comparison_rmses():
    """Draw the fig5a intra- and cross-condition RMSE comparison bar charts.

    Baseline methods are read from
    ``results/comparison_methods_with_limited_labels_results/<method>/<dataset>``.
    ``'WSL'`` is read from ``results/soh_individual_dataset_results/<dataset>``.
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    methods = ['SVR', 'RF', 'GPR', 'CNN', 'Benchmark', 'SSL', 'WSL']
    columns = {cond: {'Dataset': [], 'method': [], 'rmse': []} for cond in ('intra', 'cross')}
    for name in names:
        for method in methods:
            if method == 'WSL':
                path = os.path.join('results/soh_individual_dataset_results', name)
            else:
                path = os.path.join('results/comparison_methods_with_limited_labels_results', method, name)
            intra_vals, cross_vals = get_rmse(path, name)
            for cond, vals in (('intra', intra_vals), ('cross', cross_vals)):
                columns[cond]['Dataset'].append(np.array([name] * len(vals)))
                columns[cond]['method'].append(np.array([method] * len(vals)))
                columns[cond]['rmse'].append(vals)

    frames = {cond: pd.DataFrame({key: np.concatenate(parts) for key, parts in cols.items()})
              for cond, cols in columns.items()}
    plot_comparison_rmses(frames['intra'], 'figs/fig5/fig5(a)_intra.jpg')
    plot_comparison_rmses(frames['cross'], 'figs/fig5/fig5(a)_cross.jpg')


def comparison_tabel():
    """Write ``figs/fig5/fig5(a)_tabel.csv``: mean ± std metrics for every dataset/method pair.

    RMSE/MAPE/MAE/MedAE/MAX are loaded for both the intra- and cross-condition
    runs. The results directory depends on the method (``'WSL'`` uses the
    individual-dataset results), and each value is formatted as
    ``"mean ± std"`` (``np.std``).
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    methods = ['SVR', 'RF', 'GPR', 'CNN', 'Benchmark', 'SSL', 'WSL']
    getters = {'RMSE': get_rmse, 'MAPE': get_mape, 'MAE': get_mae,
               'MedAE': get_medae, 'MAX': get_max}
    table = {'dataset': [], 'method': []}
    for cond in ('intra', 'cross'):
        for metric in getters:
            table[f'{cond}_{metric}'] = []

    def summary(vals):
        return f'{np.mean(vals):.3f} ± {np.std(vals):.3f}'

    for name in names:
        for method in methods:
            if method == 'WSL':
                path = os.path.join('results/soh_individual_dataset_results', name)
            else:
                path = os.path.join('results/comparison_methods_with_limited_labels_results', method, name)
            per_metric = {}
            for metric, getter in getters.items():
                intra_vals, cross_vals = getter(path, name)
                per_metric[metric] = (intra_vals, cross_vals)
            table['dataset'].append(name)
            table['method'].append(method)
            for metric, (intra_vals, cross_vals) in per_metric.items():
                table[f'intra_{metric}'].append(summary(intra_vals))
                table[f'cross_{metric}'].append(summary(cross_vals))
    pd.DataFrame(table).to_csv('figs/fig5/fig5(a)_tabel.csv', index=False)

# Output directory for the fig5a bar charts and table;
# exist_ok=True makes creation idempotent without a separate exists() check.
path_dir = 'figs/fig5'
os.makedirs(path_dir, exist_ok=True)

comparison_rmses()
comparison_tabel()

### figS8-S13

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def plot_comparison_results(estimates, ture, abs_errors, dataset, save_path):
    """Save a 45 mm x 45 mm estimated-vs-true SOH scatter for one comparison method.

    Args:
        estimates: Estimated SOH values (plotted on the y axis).
        ture: True SOH values (plotted on the x axis). The misspelt name is
            kept so existing callers are unaffected.
        abs_errors: Absolute error per point, coloured on a fixed 0-0.1 scale.
        dataset: Panel title.
        save_path: Destination image file (written at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side_in = 45 / 25.4  # millimetres -> inches
    fig, ax = plt.subplots(figsize=(side_in, side_in))
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list(
        'custom_cmap', ['#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8'])
    error_norm = plt.Normalize(0, 0.1)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=error_norm, s=0.25)

    # Shared square range: data extent padded by 0.01, upper edge capped at 1.02.
    lo = min(min(estimates), min(ture)) - 0.01
    hi = min(max(max(estimates), max(ture)) + 0.01, 1.02)
    ax.plot([lo, hi], [lo, hi], '--', color='red', linewidth=1)  # y = x reference line
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel('True SOH')
    ax.set_ylabel('Estimation')
    ax.set_title(dataset, fontsize=8)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def comparison_results():
    """Draw the figS8-figS13 scatter panels, one figure directory per comparison method.

    For each method, estimates for Dataset 1-6 are read with ``get_results``
    from ``results/comparison_methods_with_limited_labels_results/<method>/<dataset>``.
    An intra- and a cross-condition panel are saved into
    ``figs/<figS#>_<method>``, which is created if missing.
    """
    names = [f'Dataset {k}' for k in range(1, 7)]
    methods = ['SVR', 'RF', 'GPR', 'CNN', 'Benchmark', 'SSL']
    fig_ids = ['figS8', 'figS9', 'figS10', 'figS11', 'figS12', 'figS13']
    for method, fig_id in zip(methods, fig_ids):
        out_dir = os.path.join('figs', f'{fig_id}_{method}')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        for name in names:
            path = os.path.join('results/comparison_methods_with_limited_labels_results', method, name)
            intra_est, intra_true, cross_est, cross_true = get_results(path, name)
            intra_err = np.abs(intra_est - intra_true)
            cross_err = np.abs(cross_est - cross_true)
            # NOTE(review): file names have no separator, e.g. 'intraDataset 1.jpg'.
            plot_comparison_results(intra_est, intra_true, intra_err, name,
                                    os.path.join(out_dir, f'intra{name}.jpg'))
            plot_comparison_results(cross_est, cross_true, cross_err, name,
                                    os.path.join(out_dir, f'cross{name}.jpg'))

# Render the figS8-figS13 scatter panels; comparison_results creates one output directory per method.
comparison_results()

### fig5b

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots
import numpy as np
import pandas as pd
import os

def plot_comparison_rmses(df, save_path):
    """Save a narrow grouped bar chart of per-method error on the field dataset (fig5b).

    Bars show the mean of the ``'rmse'`` column with standard-deviation error
    bars (``errorbar='sd'``). The y range is fixed to 0-10. Only a
    ``save_path`` containing ``'intra'`` keeps a legend.

    Args:
        df: DataFrame with ``'Dataset'``, ``'method'`` and ``'rmse'`` columns.
        save_path: Destination image file (written at 600 dpi).
    """
    # Note the order: 'science' is applied after 'ieee' here.
    for style_name in ('ieee', 'science'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(36 / 25.4, 45 / 25.4))
    blues = ['#deebf7', '#c6dbef', '#9ecae1', '#6baed6', '#4292c6', '#2171b5', '#08519c']
    sns.barplot(x='Dataset', y='rmse', hue='method', data=df, palette=blues,
                errorbar='sd', linestyle='-', capsize=0.2, ax=ax,
                err_kws={'linewidth': 0.3, 'color': 'black'})
    ax.set_ylim([0, 10])
    ax.set_xlabel('')
    ax.set_ylabel(r'RMSE (\%)')
    if 'intra' not in save_path:
        ax.legend().remove()
    else:
        ax.legend(loc='upper left', ncol=3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def comparison_rmses_d8():
    """Draw the fig5b bar chart comparing field-data (Dataset 8) RMSE across methods.

    For each method, the ``RMSE`` entry of ``eval_metrics.csv`` in
    ``Experiment1``..``Experiment10`` under
    ``results/soh_estimation_on_field_data/<method>`` is scaled by 100 and
    passed to ``plot_comparison_rmses``.
    """
    methods = ['SVR', 'RF', 'GPR', 'CNN', 'Benchmark', 'SSL', 'WSL']
    meths = []
    rmses = []
    for methi in methods:
        path = os.path.join('results/soh_estimation_on_field_data', methi)
        rmsei = []
        for j in range(1, 11):
            metrics = pd.read_csv(os.path.join(path, f'Experiment{j}', 'eval_metrics.csv'))
            # Read the RMSE column. The plot's y-label ("RMSE (%)") and every
            # name here refer to RMSE, but this previously read 'MAPE'.
            rmsei.append(metrics.loc[0, 'RMSE'] * 100)
        rmses.append(rmsei)
        meths.append(np.array([methi] * len(rmsei)))
    rmses = np.concatenate(rmses)
    meths = np.concatenate(meths)
    grops = np.array(['Dataset 8'] * len(rmses))
    df = pd.DataFrame({'Dataset': grops, 'method': meths, 'rmse': rmses})
    plot_comparison_rmses(df, 'figs/fig5/fig5(b)_dataset_8.jpg')


def errors_tabel(methods=('SVR', 'RF', 'GPR', 'CNN', 'Benchmark', 'SSL', 'WSL')):
    """Write ``figs/fig5/fig5(b)_tabel.csv`` with mean ± std field-data errors per method.

    For every method, the first row of ``eval_metrics.csv`` in each of
    ``Experiment1``..``Experiment10`` under
    ``results/soh_estimation_on_field_data/<method>`` is read. Each metric is
    scaled by 100 and summarised as ``"mean ± std"`` (``np.std``, ddof=0).

    Args:
        methods: Method sub-directories to tabulate, one output row each.
            Defaults to the seven methods that were hard-coded before.
    """
    methods = list(methods)
    metric_names = ['RMSE', 'MAPE', 'MAE', 'MedAE', 'MAX']
    # Placeholder columns are sized from ``methods``, not a hard-coded 7, so
    # the method list can change without an IndexError or stale zero rows.
    table = {'method': methods}
    table.update({m: [0] * len(methods) for m in metric_names})
    for row, method in enumerate(methods):
        path = os.path.join('results/soh_estimation_on_field_data', method)
        samples = {m: [] for m in metric_names}
        for exp in range(1, 11):
            er = pd.read_csv(os.path.join(path, f'Experiment{exp}', 'eval_metrics.csv'))
            for m in metric_names:
                samples[m].append(er.loc[0, m] * 100)
        for m in metric_names:
            table[m][row] = f'{np.mean(samples[m]):.3f} ± {np.std(samples[m]):.3f}'
    pd.DataFrame(table).to_csv('figs/fig5/fig5(b)_tabel.csv', index=False)

# Output directory for the fig5b bar chart and table;
# exist_ok=True makes creation idempotent without a separate exists() check.
path_dir = 'figs/fig5'
os.makedirs(path_dir, exist_ok=True)

comparison_rmses_d8()
errors_tabel()

### figS14

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_results(estimates,ture,abs_errors,meth,save_path):
    """Scatter estimated SOH against true SOH for one method and save the figure.

    Points are coloured by ``abs_errors`` on a fixed 0-0.1 scale, and a dashed
    red y = x reference line is drawn. Both axes span the data with a 0.01
    margin. The upper limit is capped at 1.02.

    Args:
        estimates: estimated SOH values.
        ture: true SOH values (parameter name kept for caller compatibility).
        abs_errors: per-point absolute errors used for the colour mapping.
        meth: method name used as the plot title.
        save_path: output image path (saved at 600 dpi).
    """
    for style_name in ('science', 'ieee'):
        plt.style.use(style_name)
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    side = 45/25.4  # 45 mm expressed in inches
    fig, ax = plt.subplots(figsize=(side, side))
    palette = ['#313695','#4575b4','#74add1','#abd9e9', '#e0f3f8']
    error_cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap', palette)
    error_norm = plt.Normalize(0, 0.1)
    ax.scatter(ture, estimates, c=abs_errors, cmap=error_cmap, norm=error_norm, s=0.25)

    lo = min(min(estimates), min(ture)) - 0.01
    hi = min(max(max(estimates), max(ture)) + 0.01, 1.02)
    diagonal = [lo, hi]
    ax.plot(diagonal, diagonal, '--', color='red', linewidth=1)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.set_xlabel('True SOH')
    ax.set_ylabel('Estimation')
    ax.set_title(meth, fontsize=8)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def comparison_errors_d8():
    """Plot pooled true-vs-estimated SOH on field data (Dataset 8) for each baseline method (figS14).

    For each method, the ``field_data_est`` / ``field_data_true`` columns of
    ``results/soh_estimation_on_field_data/<method>/Experiment<k>/test_results.csv``
    (k = 1..10) are pooled across experiments. One scatter plot per method is
    written to ``figs/figS14/figS14_<method>.jpg``.
    """
    methods = ['SVR','RF','GPR','CNN','Benchmark','SSL']
    for methk in methods:
        path = os.path.join('results/soh_estimation_on_field_data', methk)
        true_soh, est_soh = [], []
        for i in range(10):
            df = pd.read_csv(path+'/Experiment'+str(i+1)+'/test_results.csv')
            esti = df['field_data_est'].to_numpy()
            truei = df['field_data_true'].to_numpy()
            # Drop a row if EITHER value is missing, so each estimate stays
            # paired with its own ground truth. Filtering the two columns
            # independently misaligns the pairs whenever their NaN positions differ.
            valid = ~np.isnan(esti) & ~np.isnan(truei)
            true_soh.append(truei[valid])
            est_soh.append(esti[valid])
        true_soh = np.concatenate(true_soh)
        est_soh = np.concatenate(est_soh)
        abs_errors = np.abs(true_soh-est_soh)
        plot_results(est_soh,true_soh,abs_errors,methk,'figs/figS14/figS14_'+methk+'.jpg')

path_dir = 'figs/figS14'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(path_dir, exist_ok=True)
comparison_errors_d8()

### fig5c, figS15

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_comparison_single_label(df,num_condition,save_path):
    """Grouped bar chart of per-condition RMSE for each method (fig5c / figS15).

    The first condition (x = 0) is shaded grey and separated from the rest by
    a dashed line. The remaining conditions are shaded light blue. Bars show
    mean RMSE with ±1 SD error bars. A legend is drawn only when ``save_path``
    contains 'Dataset 3' (upper left) or 'Dataset 6' (upper right).

    Args:
        df: DataFrame with ``condition``, ``method`` and ``rmse`` columns.
        num_condition: number of conditions; sets the figure width.
        save_path: output image path (saved at 600 dpi).
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    width_mm = 10 + 16*num_condition
    fig, ax = plt.subplots(figsize=(width_mm/25.4, 45/25.4))
    n_groups = len(df['condition'].unique())
    ax.axvline(x=0.5, color='black', linestyle='--', linewidth=0.25)
    ax.axvspan(-0.5, 0.5, facecolor='lightgray', alpha=0.3)
    ax.axvspan(0.5, n_groups - 0.5, facecolor='lightblue', alpha=0.3)
    bar_colors = ['#fee391','#fec44f','#fe9929','#ec7014','#cc4c02','#8c2d04']
    sns.barplot(data=df, x='condition', y='rmse', hue='method', palette=bar_colors,
                errorbar='sd', linestyle='-', capsize=0.2, ax=ax,
                err_kws={'linewidth': 0.3, 'color': 'black'})
    ax.set_xlabel('')
    ax.set_ylabel(r'RMSE (\%)')
    ax.set_ylim(0, None)

    legend_loc = None
    if 'Dataset 3' in save_path:
        legend_loc = 'upper left'
    elif 'Dataset 6' in save_path:
        legend_loc = 'upper right'
    if legend_loc is None:
        ax.legend().remove()
    else:
        ax.legend(loc=legend_loc, ncol=2)

    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()


def get_rmse_single(path,test_cells):
    """Return RMSE values (multiplied by 100) for the requested test cells.

    Reads ``<path>/eval_metrics.csv``. Its first column is matched against
    ``test_cells``, and the ``RMSE`` of every matching row is returned in
    file order.
    """
    metrics = pd.read_csv(path+'/eval_metrics.csv')
    is_requested = metrics.iloc[:, 0].isin(test_cells).to_numpy()
    selected = metrics.loc[is_requested, 'RMSE'] * 100
    return list(selected)

# Test cells per dataset, grouped by operating condition. The j-th inner list
# belongs to the j-th code in _SINGLE_LABEL_CONDITION_CODES[dataset].
_SINGLE_LABEL_TEST_CELLS = {
    'Dataset 1': [['25-'+str(i+1) for i in range(2,8)],['45-'+str(i+1) for i in range(2,6)]],
    'Dataset 2': [['1C-'+str(i+1) for i in range(5,10)],['2C-4'],['3C-5','3C-6','3C-7','3C-9','3C-10']],
    'Dataset 3': [['0-CC-1','0-CC-3'],['10-CC-2','10-CC-3'],['25-CC-2','25-CC-3'],['40-CC-2','40-CC-3']],
    'Dataset 4': [['CY25-05_1-#12','CY25-05_1-#18','CY25-05_1-#19'],['CY35-05_1-#1'],['CY45-05_1-#21']+['CY45-05_1-#'+str(i+1) for i in range(23,28)]],
    'Dataset 5': [['2C-5','2C-8'],['3C-'+str(i+1) for i in range(6,15)],['4C-6']],
    'Dataset 6': [['25_1b_100','25_1c_100','25_1d_100'],['25_0.5a_100'],['25_2b_100'],['25_3a_100','25_3c_100','25_3d_100'],
                  ['35_1b_100','35_1c_100','35_1d_100'],['35_2a_100']]
}
# Condition number shown on the x axis / table header for each group above.
_SINGLE_LABEL_CONDITION_CODES = {
    'Dataset 1': [1,2],
    'Dataset 2': [5,6,7],
    'Dataset 3': [8,9,10,11],
    'Dataset 4': [16,17,18],
    'Dataset 5': [19,20,21],
    'Dataset 6': [23,22,24,25,26,27]
}
_SINGLE_LABEL_METHODS = ['SVR','RF','CNN','Benchmark','WSL(enough labels)','WSL(limited labels)']
_SINGLE_LABEL_DATASETS = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']
# Experiment folder used as the 'WSL(limited labels)' result, one per dataset (same order).
_SINGLE_LABEL_WSL_FILES = ['Experiment1(25-1)','Experiment1(1C-4)','Experiment1(0-CC-2)','Experiment1(CY25-05_1-#14)','Experiment1(2C-6)','Experiment2(25_1a_100)']


def _single_label_result_path(dataset, dataset_idx, meth):
    """Return the results directory (containing eval_metrics.csv) for ``meth`` on ``dataset``."""
    if meth == 'WSL(enough labels)':
        return f'results/comparison_methods_with_enough_labels_results/{dataset}/WSL_el'
    if meth == 'WSL(limited labels)':
        return f'results/soh_individual_dataset_results/{dataset}/{_SINGLE_LABEL_WSL_FILES[dataset_idx]}'
    return os.path.join('results/comparison_methods_with_enough_labels_results', dataset, meth)


def _iter_single_label_rmses(dataset_idx, dataset):
    """Yield ``(condition_code, method, rmses)`` for every condition of ``dataset`` and every method, condition-major."""
    for j, condition in enumerate(_SINGLE_LABEL_CONDITION_CODES[dataset]):
        cells = _SINGLE_LABEL_TEST_CELLS[dataset][j]
        for meth in _SINGLE_LABEL_METHODS:
            path = _single_label_result_path(dataset, dataset_idx, meth)
            yield condition, meth, get_rmse_single(path, cells)


def comparison_single_label():
    """Plot per-condition RMSE bars of all methods for Datasets 1-6.

    Datasets 2 and 3 are saved as ``figs/fig5/fig5(c)_<dataset>.jpg``; the
    others are saved as ``figs/figS15/<dataset>.jpg``.
    """
    for i, di in enumerate(_SINGLE_LABEL_DATASETS):
        grops, meths, rmses = [], [], []
        for conditionj, meth, ci_rmses in _iter_single_label_rmses(i, di):
            grops.append(np.array([rf'\#{conditionj}']*len(ci_rmses)))
            meths.append(np.array([meth]*len(ci_rmses)))
            rmses.append(ci_rmses)
        df = pd.DataFrame({'condition': np.concatenate(grops),
                           'method': np.concatenate(meths),
                           'rmse': np.concatenate(rmses)})
        num_conditions = len(_SINGLE_LABEL_CONDITION_CODES[di])
        if di in ['Dataset 2','Dataset 3']:
            plot_comparison_single_label(df,num_conditions,f'figs/fig5/fig5(c)_{di}.jpg')
        else:
            plot_comparison_single_label(df,num_conditions,f'figs/figS15/{di}.jpg')


def single_label_errors():
    """Write a table of mean per-condition RMSE (3 decimals) per method for each dataset.

    Output: ``figs/fig5/fig5(c)_<dataset>_tabel.csv`` with one ``method`` column
    and one ``#<condition>`` column per condition.
    """
    for i, di in enumerate(_SINGLE_LABEL_DATASETS):
        table = {'method': _SINGLE_LABEL_METHODS}
        for conditionj, meth, ci_rmses in _iter_single_label_rmses(i, di):
            column = table.setdefault(f'#{conditionj}', ['']*len(_SINGLE_LABEL_METHODS))
            column[_SINGLE_LABEL_METHODS.index(meth)] = f'{np.mean(ci_rmses):.3f}'
        pd.DataFrame(table).to_csv(f'figs/fig5/fig5(c)_{di}_tabel.csv',index=False)

# comparison_single_label() writes into figs/fig5 and figs/figS15, and
# single_label_errors() writes into figs/fig5. Create both directories so this
# cell also runs on its own.
for path_dir in ('figs/fig5', 'figs/figS15'):
    os.makedirs(path_dir, exist_ok=True)
comparison_single_label()
single_label_errors()

### fig5d

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape,get_mae,get_medae,get_max

def plot_comparison_tl(df,save_path):
    """Grouped bar chart of TL vs WSL RMSE per condition type (fig5d).

    Args:
        df: DataFrame with ``condition``, ``method`` and ``rmse`` columns.
        save_path: output image path (saved at 600 dpi). Its text selects the
            legend (shown only for 'Dataset 1') and the y ticks (0-3 for
            'Dataset 5', otherwise 0-2).
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(32/25.4, 45/25.4))
    error_bar_style = {'linewidth': 0.3, 'color': 'black'}
    sns.barplot(data=df, x='condition', y='rmse', hue='method',
                palette=['#4575b4','#ec7014'], errorbar='sd', linestyle='-',
                capsize=0.2, ax=ax, err_kws=error_bar_style)
    ax.set_xlabel('')
    ax.set_ylabel(r'RMSE (\%)')

    show_legend = 'Dataset 1' in save_path
    if show_legend:
        ax.legend(loc='upper left', ncol=1)
    else:
        ax.legend().remove()

    y_ticks = [0,1,2,3] if 'Dataset 5' in save_path else [0,1,2]
    ax.set_yticks(y_ticks)

    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def comparison_tl():
    """Plot intra/cross-condition RMSE of TL vs WSL, transferring from Dataset 6, for Datasets 1-5 (fig5d).

    Results are read from
    ``results/comparison_methods_TL_results/<method>/<target> transfer from Dataset 6``.
    """
    for target_di in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5']:
        cond_labels, meth_labels, rmse_parts = [], [], []
        for meth in ('TL', 'WSL'):
            result_dir = os.path.join('results/comparison_methods_TL_results', meth,
                                      f'{target_di} transfer from Dataset 6')
            intra, cross = get_rmse(result_dir, target_di)
            n_intra, n_cross = len(intra), len(cross)
            cond_labels.extend([np.array(['intra']*n_intra), np.array(['cross']*n_cross)])
            meth_labels.append(np.array([meth]*(n_intra+n_cross)))
            rmse_parts.extend([intra, cross])
        frame = pd.DataFrame({
            'condition': np.concatenate(cond_labels),
            'method': np.concatenate(meth_labels),
            'rmse': np.concatenate(rmse_parts),
        })
        plot_comparison_tl(frame, f'figs/fig5/fig5(d)_{target_di}.jpg')
    
def _mean_std_str(values):
    """Format ``values`` as ``'mean ± std'`` with three decimals (``np.std``, ddof=0)."""
    return f'{np.mean(values):.3f} ± {np.std(values):.3f}'


def tl_errors():
    """Write the fig5(d) table comparing TL and WSL, both transferred from Dataset 6.

    There is one row per (target dataset, method). For each error metric, the
    intra- and cross-condition values returned by the matching ``load_results``
    getter are summarised as ``'mean ± std'``. Columns are
    ``dataset, method, intra_RMSE ... intra_MAX, cross_RMSE ... cross_MAX``.
    Output: ``figs/fig5/fig5(d)_tabel.csv``.
    """
    target_datasets = ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5']
    methods = ['TL','WSL']
    # Metric name -> getter returning (intra_condition_values, cross_condition_values).
    metric_getters = [('RMSE', get_rmse), ('MAPE', get_mape), ('MAE', get_mae),
                      ('MedAE', get_medae), ('MAX', get_max)]
    columns = ['dataset', 'method'] + [f'{condition}_{name}'
                                       for condition in ('intra', 'cross')
                                       for name, _ in metric_getters]
    rows = []
    for di in target_datasets:
        for methi in methods:
            path = os.path.join('results/comparison_methods_TL_results',methi, f'{di} transfer from Dataset 6')
            row = {'dataset': di, 'method': methi}
            for name, getter in metric_getters:
                intra_values, cross_values = getter(path, di)
                row[f'intra_{name}'] = _mean_std_str(intra_values)
                row[f'cross_{name}'] = _mean_std_str(cross_values)
            rows.append(row)
    pd.DataFrame(rows, columns=columns).to_csv('figs/fig5/fig5(d)_tabel.csv', index=False)

# Both functions write into figs/fig5; create it so this cell also runs on its own.
os.makedirs('figs/fig5', exist_ok=True)
comparison_tl()
tl_errors()

## Knowledge learned by the DNN

### fig6g

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_mmd_soh(xs,wsl_y,benchmark_y,y_label,save_path):
    """Stem plot comparing WSL and Benchmark values per target condition (fig6g / fig6h).

    WSL values are drawn as blue circle stems 0.1 left of each tick, and
    Benchmark values as orange square stems 0.1 right of it.

    Args:
        xs: tick labels, one per target condition.
        wsl_y: WSL values aligned with ``xs``.
        benchmark_y: Benchmark values aligned with ``xs``.
        y_label: y-axis label.
        save_path: output image path.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(80/25.4, 45/25.4))
    x = np.arange(len(xs))
    x1 = x-0.1
    x2 = x+0.1
    markerline1, stemlines1, baseline1 = ax.stem(x1, wsl_y, linefmt='#4575b4', markerfmt='o', basefmt=' ', label='WSL')
    markerline2, stemlines2, baseline2 = ax.stem(x2, benchmark_y, linefmt='#ec7014', markerfmt='s', basefmt=' ', label='Benchmark')

    plt.setp(stemlines1, linewidth=0.5)  # stem line width
    plt.setp(stemlines2, linewidth=0.5)

    plt.setp(markerline1, markersize=1.5)  # marker size
    plt.setp(markerline2, markersize=1.5)

    # Dashed separators between groups of target conditions. These positions
    # are fixed, not derived from ``xs``; they fall after #2, #7, #11, #18 and
    # #21 in the 15-condition list passed by get_mmd_soh.
    xv_position = [0.5,2.5,5.5,7.5,9.5]
    for xv in xv_position:
        ax.axvline(x=xv,color='black',linestyle='--',linewidth=0.25)

    ax.set_xlabel('Target condition')
    ax.set_ylabel(y_label)
    ax.set_xticks(x,xs)
    ax.set_ylim(bottom=0)
    # Half a tick of margin on each side. This was hard-coded to (-0.5, 14.5),
    # which only fits exactly 15 conditions.
    ax.set_xlim(-0.5, len(xs) - 0.5)
    ax.tick_params(axis='x', which='minor', bottom=False,top=False,length=1)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    

def get_mmd_soh():
    """Draw fig6g (MMD) and fig6h (mean SOH RMSE) stem plots from CSVs in figs/fig6.

    Reads ``figs/fig6/mmd_result_to_fc1.csv`` (columns ``mmd_of_wsl``,
    ``mmd_of_benchmark``) and ``figs/fig6/soh_errors.csv`` (columns
    ``rmse_of_wsl``, ``rmse_of_benchmark``).
    """
    condition_ids = [2,6,7,9,10,11,17,18,20,21,23,24,25,26,27]
    tick_labels = [rf'\#{cid}' for cid in condition_ids]

    mmd = pd.read_csv('figs/fig6/mmd_result_to_fc1.csv')
    plot_mmd_soh(tick_labels, mmd['mmd_of_wsl'].values, mmd['mmd_of_benchmark'].values,
                 'MMD', 'figs/fig6/fig6g_mmd.jpg')

    soh_errors = pd.read_csv('figs/fig6/soh_errors.csv')
    plot_mmd_soh(tick_labels, soh_errors['rmse_of_wsl'].values, soh_errors['rmse_of_benchmark'].values,
                 r'Mean RMSE (\%)', 'figs/fig6/fig6h_soh_rmse.jpg')

# Draw fig6g (MMD) and fig6h (SOH RMSE) from the CSVs read out of figs/fig6.
get_mmd_soh()

## Robustness and sensitivity analysis

### fig7a, figS21 (Impact of pre-training data volume)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def pre_samples_errors():
    """Plot intra/cross-condition RMSE against the fraction of pre-training data (fig7a / figS21).

    For fractions 0.1-0.9, RMSEs are loaded from
    ``results/pretraining_samples_effect_results/<dataset>/pre_rate=<r>``. The
    full-data point, plotted at x = 1, comes from
    ``results/soh_individual_dataset_results/<dataset>``. Mean ± std error bars
    are drawn. Dataset 1 is saved as fig7a; the other datasets are titled and
    saved as figS21.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})

    rate_labels = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,'all']
    x_positions = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]
    errorbar_style = dict(linewidth=0.5, markersize=1, capsize=1,
                          elinewidth=0.4, capthick=0.4)
    for di in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']:
        intra_mean, intra_sd, cross_mean, cross_sd = [], [], [], []
        for rate in rate_labels:
            if rate == 'all':
                result_dir = os.path.join('results/soh_individual_dataset_results', di)
            else:
                result_dir = os.path.join('results/pretraining_samples_effect_results', di, f'pre_rate={str(rate)}')
            intra, cross = get_rmse(result_dir, di)
            intra_mean.append(np.mean(intra))
            cross_mean.append(np.mean(cross))
            intra_sd.append(np.std(intra))
            cross_sd.append(np.std(cross))

        fig, ax = plt.subplots(figsize=(80/25.4, 35/25.4))
        plt.tick_params(axis='x', which='minor', bottom=False, top=False)
        ax.errorbar(x_positions, intra_mean, yerr=intra_sd, fmt='-o', color='#4575b4',
                    label='Intra-condition', **errorbar_style)
        ax.errorbar(x_positions, cross_mean, yerr=cross_sd, fmt='-s', color='#f46d43',
                    label='Cross-condition', **errorbar_style)
        ax.set_xlabel('')
        ax.set_ylim(bottom=0)
        ax.set_xticks(x_positions)
        ax.set_ylabel(r'RMSE (\%)')

        plt.tight_layout()
        if di == 'Dataset 1':
            plt.savefig(f'figs/fig7/fig7a_{di}.jpg', dpi=600)
        else:
            ax.set_title(di)
            plt.savefig(f'figs/figS21/figS21_{di}.jpg', dpi=600)
        plt.close()


def print_pre_samples_errors(di):
    """Print mean ± std intra/cross-condition RMSE for each pre-training data fraction of ``di``.

    The full-data row is labelled with rate 1 and read from
    ``results/soh_individual_dataset_results/<di>``.
    """
    for rate in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,'all']:
        if rate == 'all':
            label = 1
            result_dir = os.path.join('results/soh_individual_dataset_results', di)
        else:
            label = rate
            result_dir = os.path.join('results/pretraining_samples_effect_results', di, f'pre_rate={str(rate)}')
        intra, cross = get_rmse(result_dir, di)
        print(f'{di}_{label}: {np.mean(intra):.3f} ± {np.std(intra):.3f} ; {np.mean(cross):.3f} ± {np.std(cross):.3f}')
            

# pre_samples_errors() saves fig7a into figs/fig7 and the other datasets into figs/figS21.
for path_dir in ('figs/fig7', 'figs/figS21'):
    os.makedirs(path_dir, exist_ok=True)
pre_samples_errors()
print_pre_samples_errors('Dataset 1')


### fig7b, figS22 (Impact of fine-tuning data volume)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape

def ft_samples_errors():
    """Plot intra/cross-condition RMSE against the number of fine-tuning samples (fig7b / figS22).

    Sample counts 2-11 plus 'all' (plotted at x = 12) are read from
    ``results/fine_tuning_samples_effect_results/<dataset>/ft_num=<n>``. The
    exception is n = 6, which is read from
    ``results/soh_individual_dataset_results/<dataset>``. Dataset 1 is saved as
    fig7b; the other datasets are titled and saved as figS22.
    """
    plt.style.use('ieee')
    plt.style.use('science')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})

    sample_counts = list(range(2, 12)) + ['all']
    x_positions = list(range(2, 12)) + [12]
    errorbar_style = dict(linewidth=0.5, markersize=1, capsize=1,
                          elinewidth=0.4, capthick=0.4)
    for di in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']:
        intra_mean, intra_sd, cross_mean, cross_sd = [], [], [], []
        for count in sample_counts:
            if count == 6:
                result_dir = os.path.join('results/soh_individual_dataset_results', di)
            else:
                result_dir = os.path.join('results/fine_tuning_samples_effect_results', di, f'ft_num={str(count)}')
            intra, cross = get_rmse(result_dir, di)
            intra_mean.append(np.mean(intra))
            cross_mean.append(np.mean(cross))
            intra_sd.append(np.std(intra))
            cross_sd.append(np.std(cross))

        fig, ax = plt.subplots(figsize=(80/25.4, 35/25.4))
        plt.tick_params(axis='x', which='minor', bottom=False, top=False)
        ax.errorbar(x_positions, intra_mean, yerr=intra_sd, fmt='--o', color='#4575b4',
                    label='Intra-condition', **errorbar_style)
        ax.errorbar(x_positions, cross_mean, yerr=cross_sd, fmt='--s', color='#f46d43',
                    label='Cross-condition', **errorbar_style)
        ax.set_xlabel('')
        ax.set_ylim(bottom=0)
        ax.set_xticks(x_positions, sample_counts)
        ax.set_ylabel(r'RMSE (\%)')

        plt.tight_layout()
        if di == 'Dataset 1':
            plt.savefig(f'figs/fig7/fig7b_{di}.jpg', dpi=600)
        else:
            ax.set_title(di)
            plt.savefig(f'figs/figS22/figS22_{di}.jpg', dpi=600)
        plt.close()
   

def print_ft_samples_errors(di):
    """Print mean ± std intra/cross-condition RMSE for each fine-tuning sample count of ``di``.

    Covers sample counts 2-11 and 'all', the same set ``ft_samples_errors``
    plots. The previous list stopped at 10 and silently skipped n = 11.
    n = 6 is read from ``results/soh_individual_dataset_results/<di>``.
    """
    ft_data_nums = [2+i for i in range(10)]+['all']
    for ft_num in ft_data_nums:
        path = os.path.join('results/fine_tuning_samples_effect_results', di, f'ft_num={str(ft_num)}')
        if ft_num == 6:
            path = os.path.join('results/soh_individual_dataset_results', di)
        intra_condition_rmses,cross_condition_rmses = get_rmse(path,di)
        print(f'{di}_{ft_num}: {np.mean(intra_condition_rmses):.3f} ± {np.std(intra_condition_rmses):.3f} ; {np.mean(cross_condition_rmses):.3f} ± {np.std(cross_condition_rmses):.3f}')

# ft_samples_errors() saves fig7b into figs/fig7 and the other datasets into figs/figS22.
for path_dir in ('figs/fig7', 'figs/figS22'):
    os.makedirs(path_dir, exist_ok=True)
ft_samples_errors()
print_ft_samples_errors('Dataset 1')

### fig7c, figS23 (Impact of stochastic fine-tuning sample selection)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape,get_mae,get_medae,get_max

def plot_di(df,save_path):
    """Violin plot of intra/cross-condition RMSE for each random fine-tuning sample draw (fig7c / figS23).

    On each violin, a red horizontal segment marks the mean and a dashed black
    segment with caps marks mean ± 1 SD.

    Args:
        df: DataFrame with ``random`` (draw id), ``condition`` ('intra' /
            'cross') and ``rmse`` columns. It is not modified.
        save_path: output image path (saved at 600 dpi).
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(80/25.4, 35/25.4))
    colors = ['#4575b4','#f46d43']
    # Work on a copy so the categorical cast does not leak into the caller's DataFrame.
    df = df.copy()
    df['condition'] = pd.Categorical(df['condition'], categories=['intra', 'cross'], ordered=True)
    sns.violinplot(x='random', y='rmse', data=df, density_norm='width', inner=None, hue='condition', palette=colors, linewidth=0, legend=False, ax=ax)
    ax.set_xlabel('')
    ax.set_ylabel('')
    # Overlay mean and ±1 SD markers on each violin without changing x positions.
    grouped = df.groupby(['random','condition'], observed=False)['rmse']
    means = grouped.mean()
    stds = grouped.std()
    # NOTE(review): pairs ax.collections[k] with the k-th (random, condition)
    # group. This assumes seaborn adds one collection per violin in that same
    # order; confirm for the seaborn version in use.
    for violin_idx, artist in enumerate(ax.collections):
        if violin_idx >= len(means):
            break
        mean = means.iloc[violin_idx]
        std = stds.iloc[violin_idx]
        # x centre of the violin, taken as the mean of its outline vertices.
        paths = artist.get_paths()
        if not paths:
            continue
        verts = paths[0].vertices
        x_center = np.mean(verts[:, 0])
        # Mean line
        ax.plot([x_center - 0.15, x_center + 0.15], [mean, mean], '-', color='red', lw=0.5)
        # Standard-deviation bar with caps
        ax.plot([x_center, x_center], [mean - std, mean + std], '--', color='black', lw=0.5)
        ax.plot([x_center - 0.1, x_center + 0.1], [mean - std, mean - std], '-', color='black', lw=0.5)
        ax.plot([x_center - 0.1, x_center + 0.1], [mean + std, mean + std], '-', color='black', lw=0.5)
    random_ids = sorted(df['random'].unique())
    ax.set_xticks(range(len(random_ids)))
    ax.set_xticklabels(random_ids)
    # Derived from the number of draws. This was hard-coded to [-0.5, 9.5] (10 draws).
    ax.set_xlim([-0.5, len(random_ids) - 0.5])
    ax.set_ylim(0, None)
    ax.set_ylabel(r'RMSE (\%)')
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()

def random_effect():
    """Violin plots of intra/cross-condition RMSE over 10 random fine-tuning sample draws (fig7c / figS23).

    Results are read from ``results/random_six_samples_effect/<dataset>/random_<k>``
    for k = 1..10. Dataset 1 is saved as fig7c; the other datasets are saved
    as figS23.
    """
    for di in ['Dataset 1','Dataset 2','Dataset 3','Dataset 4','Dataset 5','Dataset 6']:
        cond_parts, draw_parts, rmse_parts = [], [], []
        for draw in range(1, 11):
            result_dir = os.path.join('results/random_six_samples_effect', di, f'random_{draw}')
            intra, cross = get_rmse(result_dir, di)
            cond_parts += [np.array(['intra']*len(intra)), np.array(['cross']*len(cross))]
            draw_parts.append(np.array([draw]*(len(intra)+len(cross))))
            rmse_parts += [intra, cross]
        frame = pd.DataFrame({
            'condition': np.concatenate(cond_parts),
            'random': np.concatenate(draw_parts),
            'rmse': np.concatenate(rmse_parts),
        })
        target = f'figs/fig7/fig7c_{di}.jpg' if di == 'Dataset 1' else f'figs/figS23/figS23_{di}.jpg'
        plot_di(frame, target)
    
def print_random_samples_errors(di):
    """For each of the 10 random fine-tuning draws of ``di``, print the largest intra- and cross-condition MAX value."""
    for draw in range(1, 11):
        result_dir = os.path.join('results/random_six_samples_effect', di, f'random_{draw}')
        # RMSEs are loaded but only used by the disabled summary line below.
        intra_rmse, cross_rmse = get_rmse(result_dir, di)
        # print(f'{di}_random{draw}: {np.mean(intra_rmse):.3f} ± {np.std(intra_rmse):.3f} ; {np.mean(cross_rmse):.3f} ± {np.std(cross_rmse):.3f}')
        intra_max, cross_max = get_max(result_dir, di)
        print(f'{di}_random{draw}_max: {max(intra_max):.3f}; {max(cross_max):.3f}')

# random_effect() saves fig7c into figs/fig7 and the other datasets into figs/figS23.
for path_dir in ('figs/fig7', 'figs/figS23'):
    os.makedirs(path_dir, exist_ok=True)
# random_effect()
print_random_samples_errors('Dataset 1')

### fig7d (Sensitivity to input voltage window)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape,get_mae,get_medae,get_max

def plot_rmses(df,save_path):
    """Annotated heatmap of mean RMSE over voltage windows (fig7d).

    Args:
        df: DataFrame with ``x``, ``y`` and ``rmse`` columns, pivoted so rows
            are ``y`` and columns are ``x``. In this file, ``window_effect``
            passes window start voltages as ``x`` and end voltages as ``y``.
        save_path: output image path.
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(110/25.4, 85/25.4))
    pivot_table = df.pivot(index='y', columns='x', values='rmse')
    cmap = plt.cm.colors.LinearSegmentedColormap.from_list('custom_cmap', ['#313695','#4575b4','#74add1','#abd9e9', '#e0f3f8']+['#fee391','#fec44f','#fe9929','#ec7014','#cc4c02','#8c2d04'])
    sns.heatmap(pivot_table, annot=True,fmt='.2f', cmap=cmap, linewidths=.5,ax=ax,annot_kws={"size": 4},cbar=False)
    # The heatmap's mesh is the first collection on ``ax``. Taking it from
    # ax.collections avoids depending on the ordering of Axes.get_children().
    mesh = ax.collections[0]
    cbar = plt.colorbar(mesh, ax=ax, shrink=1, aspect=30)
    cbar.ax.tick_params(width=0.25)
    cbar.ax.tick_params(which='minor',width=0.25)
    cbar.outline.set_linewidth(0.25)
    # Show thin left and bottom spines.
    for spine in ['left','bottom']:
        ax.spines[spine].set_visible(True)
        ax.spines[spine].set_linewidth(0.25)

    # Hide the top and right spines.
    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)

    ax.tick_params(top=False,right=False,width=0.25,length=2)
    ax.tick_params(which='minor', bottom=False, left=False,top=False,right=False)
    ax.set_ylabel(None)
    ax.set_xlabel(None)
    # Heatmaps draw row 0 at the top; invert so y increases upward.
    ax.invert_yaxis()
    plt.savefig(save_path)
    plt.close()

def _iter_voltage_windows():
    """Yield ``(l_vol, r_vol, path)`` for every voltage window in the Dataset 3 sweep.

    Start voltages are 14 evenly spaced values from 3.3 to 3.95 (rounded to 2
    decimals). For the i-th start voltage, end voltages are ``14 - i`` evenly
    spaced values from ``start + 0.1`` to 4.05 (rounded to 2 decimals). The
    3.6-3.8 window is read from ``results/soh_individual_dataset_results/Dataset 3``
    instead of the sweep folder.
    """
    start_vols = np.round(np.linspace(3.3, 3.95, 14), 2)
    for i, l_vol in enumerate(start_vols):
        for r_raw in np.linspace(l_vol+0.1, 4.05, 14-i):
            r_vol = round(r_raw, 2)
            path = f'results/voltage_window_effect/V({l_vol}-{r_vol})'
            if l_vol == 3.6 and r_vol == 3.8:
                path = 'results/soh_individual_dataset_results/Dataset 3'
            yield l_vol, r_vol, path


def window_effect():
    """Draw intra- and cross-condition mean-RMSE heatmaps over voltage windows for Dataset 3 (fig7d)."""
    dataset = 'Dataset 3'
    x, y, intra_rmses, cross_rmses = [], [], [], []
    for l_vol, r_vol, path in _iter_voltage_windows():
        in_rmse, cr_rmse = get_rmse(path, dataset)
        x.append(l_vol)
        y.append(r_vol)
        intra_rmses.append(np.mean(in_rmse))
        cross_rmses.append(np.mean(cr_rmse))

    intra_df = pd.DataFrame({'x':x,'y':y,'rmse':intra_rmses})
    cross_df = pd.DataFrame({'x':x,'y':y,'rmse':cross_rmses})
    plot_rmses(intra_df,'figs/fig7/fig7d_dataset3_intra.jpg')
    plot_rmses(cross_df,'figs/fig7/fig7d_dataset3_cross.jpg')


def window_effect_errors():
    """Print the number of voltage windows, then how many have mean intra- and cross-condition RMSE <= 2.5."""
    dataset = 'Dataset 3'
    intra_rmses, cross_rmses = [], []
    for _, _, path in _iter_voltage_windows():
        in_rmse, cr_rmse = get_rmse(path, dataset)
        intra_rmses.append(np.mean(in_rmse))
        cross_rmses.append(np.mean(cr_rmse))
    intra_rmses = np.array(intra_rmses)
    cross_rmses = np.array(cross_rmses)
    print(len(intra_rmses))
    print(len(intra_rmses[intra_rmses<=2.5]))
    print(len(cross_rmses[cross_rmses<=2.5]))
    

# window_effect() draws the fig7d heatmaps into figs/fig7; it is left disabled here.
# window_effect()
# Print the window count and how many windows reach mean RMSE <= 2.5 (intra, then cross).
window_effect_errors()

### fig7e, figS24 (Impact of fine-tuning strategy)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os
from load_results import get_results,get_rmse,get_mape,get_mae,get_medae,get_max

def plot_ft(df,save_path):
    """Violin plot of intra/cross-condition RMSE for each fine-tuning strategy (fig7e / figS24).

    On each violin, a red horizontal segment marks the mean and a dashed black
    segment with caps marks mean ± 1 SD.

    Args:
        df: DataFrame with ``strategy``, ``condition`` ('intra' / 'cross') and
            ``rmse`` columns. It is not modified.
        save_path: output image path (saved at 600 dpi).
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    fig, ax = plt.subplots(figsize=(60/25.4, 40/25.4))
    colors = ['#4575b4','#f46d43']
    strategy_order = ['Freeze None', 'Freeze Conv only', 'Freeze FC only', 'Unfreeze Recurrent', 'Unfreeze last fc']
    # Work on a copy so the categorical casts do not leak into the caller's DataFrame.
    df = df.copy()
    df['strategy'] = pd.Categorical(df['strategy'], categories=strategy_order, ordered=True)
    df['condition'] = pd.Categorical(df['condition'], categories=['intra', 'cross'], ordered=True)
    sns.violinplot(x='strategy', y='rmse', data=df, density_norm='width', inner=None, hue='condition', palette=colors, linewidth=0, legend=False, ax=ax, order=strategy_order)
    ax.set_xlabel('')
    ax.set_ylabel('')
    # Overlay mean and ±1 SD markers on each violin without changing x positions.
    grouped = df.groupby(['strategy','condition'], observed=False)['rmse']
    means = grouped.mean()
    stds = grouped.std()
    # NOTE(review): pairs ax.collections[k] with the k-th (strategy, condition)
    # group. This assumes seaborn adds one collection per violin in that same
    # order; confirm for the seaborn version in use.
    for violin_idx, artist in enumerate(ax.collections):
        if violin_idx >= len(means):
            break
        mean = means.iloc[violin_idx]
        std = stds.iloc[violin_idx]
        # x centre of the violin, taken as the mean of its outline vertices.
        paths = artist.get_paths()
        if not paths:
            continue
        verts = paths[0].vertices
        x_center = np.mean(verts[:, 0])
        # Mean line
        ax.plot([x_center - 0.15, x_center + 0.15], [mean, mean], '-', color='red', lw=0.5)
        # Standard-deviation bar with caps
        ax.plot([x_center, x_center], [mean - std, mean + std], '--', color='black', lw=0.5)
        ax.plot([x_center - 0.1, x_center + 0.1], [mean - std, mean - std], '-', color='black', lw=0.5)
        ax.plot([x_center - 0.1, x_center + 0.1], [mean + std, mean + std], '-', color='black', lw=0.5)

    ax.set_xticks(range(len(strategy_order)))
    ax.set_xticklabels(strategy_order, rotation=40)

    # Derived from the number of strategies. This was hard-coded to [-0.5, 4.5].
    ax.set_xlim([-0.5, len(strategy_order) - 0.5])
    ax.set_ylim(0, None)
    ax.set_ylabel(r'RMSE (\%)')
    plt.tight_layout()
    plt.savefig(save_path,dpi=600)
    plt.close()

def ft_effect():
    """Collect RMSEs for each fine-tuning strategy per dataset and plot them.

    For every dataset, intra- and cross-condition RMSEs are loaded for each
    strategy via ``get_rmse``. They are stacked into a long-format DataFrame
    (columns 'condition', 'strategy', 'rmse') and passed to ``plot_ft``.
    Dataset 1 is written to ``figs/fig7/`` and the others to ``figs/figS24/``.
    Output directories are created if missing.
    """
    datasets = ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5', 'Dataset 6']
    strategys = ['Freeze None', 'Freeze Conv only', 'Freeze FC only', 'Unfreeze Recurrent', 'Unfreeze last fc']
    for di in datasets:
        condition_labels, strategy_labels, rmses = [], [], []
        for sti in strategys:
            if sti == 'Unfreeze Recurrent':
                # Results for this strategy are read from the individual-dataset
                # results folder, not from ft_module_effect.
                path = os.path.join('results/soh_individual_dataset_results', di)
            else:
                path = os.path.join('results/ft_module_effect', di, sti)
            intra_condition_rmses, cross_condition_rmses = get_rmse(path, di)
            n_intra, n_cross = len(intra_condition_rmses), len(cross_condition_rmses)
            # Plain string lists avoid mixing float-typed empty arrays with
            # string arrays when one of the result lists is empty.
            condition_labels += ['intra'] * n_intra + ['cross'] * n_cross
            strategy_labels += [sti] * (n_intra + n_cross)
            rmses.append(intra_condition_rmses)
            rmses.append(cross_condition_rmses)
        df = pd.DataFrame({'condition': condition_labels,
                           'strategy': strategy_labels,
                           'rmse': np.concatenate(rmses)})
        if di == 'Dataset 1':
            out_path = f'figs/fig7/fig7e_{di}.jpg'
        else:
            out_path = f'figs/figS24/figS24_{di}.jpg'
        # figs/fig7 is not created by this section, so ensure the target exists.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plot_ft(df, out_path)

# Output folder for the supplementary fine-tuning figures. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
path_dir = 'figs/figS24'
os.makedirs(path_dir, exist_ok=True)
ft_effect()

# Weak label generation

### fig8

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import seaborn as sns
import numpy as np
import pandas as pd
import os

def plot_weak_label(sohs, weak_labels,cd_code,dataset,save_path):
    """Scatter weak labels against SOH, one colour per condition code.

    Args:
        sohs: Array of SOH values (x axis). It is indexed with a boolean mask
            built from ``cd_code``, so it must be a numpy-style array.
        weak_labels: Array of weak-label values aligned with ``sohs`` (y axis).
        cd_code: Array of condition codes aligned with ``sohs``. Each unique
            code becomes one coloured series and one legend entry ("#<code>").
        dataset: Title shown on the axes.
        save_path: Output image path passed to ``plt.savefig``.
    """
    plt.style.use('science')
    plt.style.use('ieee')
    plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 6})
    _, ax = plt.subplots(figsize=(55/25.4, 50/25.4))
    # Hand-picked palettes keyed by the number of distinct conditions.
    palettes = {
        1: ['#4575b4'],
        2: ['#4575b4', '#f46d43'],
        3: ['#4575b4', '#fdae61', '#f46d43'],
        4: ['#4575b4', '#74add1', '#fdae61', '#f46d43'],
        5: ['#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43'],
        6: ['#4575b4', '#74add1', '#abd9e9', '#fee090', '#fdae61', '#f46d43'],
    }
    cds = np.unique(cd_code)
    colors = palettes.get(len(cds))
    if colors is None:
        # No hand-picked palette for this count: sample a blue-to-red RdYlBu
        # ramp evenly instead of failing on a missing palette.
        colors = [tuple(rgba) for rgba in plt.cm.RdYlBu_r(np.linspace(0, 1, len(cds)))]
    for color, cd in zip(colors, cds):
        idx = cd_code == cd
        # r'\#' escapes '#' for the LaTeX text rendering used by the styles.
        ax.scatter(sohs[idx], weak_labels[idx], color=color, s=0.25, label=r'\#' + str(cd))
    ax.set_xlabel('SOH')
    ax.set_ylabel('Weak label')
    ax.set_title(dataset)
    ax.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def weal_label():
    """Plot normalised weak labels against SOH for each dataset's cells.

    For every dataset, each listed cell is loaded from
    ``data/<dataset>/<cell>.npz``. Its 'capacity' and 'weak_label' arrays are
    divided by their first values, and every entry is tagged with the cell's
    condition code. One scatter figure per dataset is saved to
    ``figs/fig8/fig8_<dataset>.jpg`` via ``plot_weak_label``.

    The function name (including the 'weal' spelling) is kept unchanged for
    existing callers.
    """
    datasets = ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5', 'Dataset 6']
    # condition_code[d][j] is the condition code assigned to the cells listed
    # in ft_cells[d][j].
    condition_code = {
            'Dataset 1': [1,2],
            'Dataset 2': [5,6,7],
            'Dataset 3': [8,9,10,11],
            'Dataset 4': [16,17,18],
            'Dataset 5': [19,20,21],
            'Dataset 6': [22,23,24,25,26,27]
        }
    ft_cells = {
        'Dataset 1':[['25-1','25-2'],['45-1','45-2']],
        'Dataset 2':[['1C-4','1C-5'],['2C-5'],['3C-4','3C-8']],
        'Dataset 3':[['0-CC-2'],['10-CC-1'],['25-CC-2'],['40-CC-1']],
        'Dataset 4':[['CY25-05_1-#14','CY25-05_1-#16'],['CY35-05_1-#2'],['CY45-05_1-#22','CY45-05_1-#23']],
        'Dataset 5':[['2C-6','2C-7'],['3C-5','3C-6'],['4C-5']],
        'Dataset 6':[['25_0.5b_100'],['25_1a_100'],['25_2a_100'],['25_3b_100'],['35_1a_100'],['35_2b_100']],
    }

    for di in datasets:
        sohs, weak_labels, cd_code = [], [], []
        for condition, cells in zip(condition_code[di], ft_cells[di]):
            for cell in cells:
                # The object returned by np.load for an .npz archive keeps the
                # file open; the context manager closes it once the arrays
                # have been read.
                with np.load(f'data/{di}/{cell}.npz') as data:
                    capacity = data['capacity']
                    weak_label = data['weak_label']
                # Normalise both series by their first value.
                sohs.append(capacity / capacity[0])
                weak_labels.append(weak_label / weak_label[0])
                cd_code.append([condition] * len(capacity))
        out_path = f'figs/fig8/fig8_{di}.jpg'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plot_weak_label(np.concatenate(sohs), np.concatenate(weak_labels),
                        np.concatenate(cd_code), di, out_path)

# Output folder for the weak-label figures. exist_ok avoids the check-then-create
# race of a separate os.path.exists test.
save_path = 'figs/fig8'
os.makedirs(save_path, exist_ok=True)
weal_label()