In [7]:
import os
import time
import pickle
import math
import json
from tqdm import tqdm
from copy import deepcopy

import random
import numpy as np


from datetime import datetime

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn

from sklearn.metrics import roc_auc_score,mean_squared_error,mean_absolute_error, r2_score

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, Sampler, TensorDataset
from torch.utils.data.sampler import RandomSampler
    

import warnings
warnings.filterwarnings('ignore')

In [2]:
class Trainer():
    """Training / validation / evaluation driver for a multi-output RUL + SOH model.

    The model handed to `train` and `test` must return a 3-tuple
    ``(pred_final_rul, pred_soh_sequence, pred_rul_sequence)``:

    * ``pred_final_rul``    -- one RUL value per sample; flattened to ``(B,)`` here.
    * ``pred_soh_sequence`` -- ``(B, K)``; column ``i`` is compared with label column ``i + 1``.
    * ``pred_rul_sequence`` -- per-step RUL predictions, or ``None`` to skip the
      consistency regulariser.

    Labels ``y`` are laid out as ``y[:, 0]`` = RUL target, ``y[:, 1:]`` = SOH targets.
    """

    def __init__(self, lr, n_epochs, device, patience, lamda, alpha, model_name, 
                 gamma_rul_consistency=0.0, epsilon_drop=1.0):
        """
        Args:
            lr (float): Learning rate
            n_epochs (int): The number of training epoch
            device: 'cuda' or 'cpu'
            patience (int): How long to wait after last time validation loss improved.
            lamda (float): The weight of the main RUL loss
            alpha (List[float]): The weights of Capacity (SOH) loss, one per SOH column
            model_name (str): The model save path (prefix; '_best.pt' / '_end.pt' are appended)
            gamma_rul_consistency (float): Weight for the RUL sequential consistency loss.
                                           Set to 0 to disable.
            epsilon_drop (float): Forwarded to compute_rul_sequential_consistency_loss.
        """
        self.lr = lr
        self.n_epochs = n_epochs
        self.device = device
        self.patience = patience
        self.model_name = model_name
        self.lamda = lamda
        self.alpha = alpha
        self.gamma_rul_consistency = gamma_rul_consistency
        self.epsilon_drop = epsilon_drop

    def train(self, train_loader, valid_loader, model, load_model):
        """Optimise `model` with Adam on the weighted RUL + SOH (+ consistency) loss.

        Args:
            train_loader: Iterable of ``(x, y)`` training batches.
            valid_loader: Iterable of ``(x, y)`` validation batches.
            model (nn.Module): Model returning the 3-tuple described on the class.
            load_model (bool): If True, reload ``{model_name}_best.pt`` after training;
                otherwise save the final weights to ``{model_name}_end.pt``.

        Returns:
            tuple: ``(model, train_rul_mse_per_epoch, valid_rul_mse_per_epoch,
            total_loss_per_epoch, consistency_loss_per_epoch,
            weighted_rul_loss_per_epoch, soh_loss_per_epoch)``.
            The validation value is the mean of the per-batch RUL MSEs.
        """
        device = self.device
        model = model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=self.lr)
        model_name = self.model_name
        lamda = self.lamda
        alpha = self.alpha

        loss_fn = nn.MSELoss()  # shared by the RUL and SOH terms
        # EarlyStopping is defined outside this section; NOTE(review): it presumably
        # writes a checkpoint to the path it is given when the metric improves -- confirm.
        early_stopping = EarlyStopping(self.patience, verbose=True)

        train_loss_main_rul_metric = []   # per-epoch RUL MSE on the training set
        valid_loss_metric = []            # per-epoch RUL MSE on the validation set
        total_combined_loss_log = []      # per-epoch mean of the optimised loss
        total_combined_loss_log_reg = []  # ... consistency term only
        total_combined_loss_log_rul = []  # ... weighted RUL term only
        total_combined_loss_log_soh = []  # ... weighted SOH term only

        for epoch in range(self.n_epochs):
            model.train()
            y_true_rul_epoch, y_pred_rul_epoch = [], []
            batch_losses, batch_losses_reg, batch_losses_rul, batch_losses_soh = [], [], [], []

            for x, y in train_loader:
                optimizer.zero_grad()

                x = x.to(device)
                y = y.to(device)

                pred_final_rul, pred_soh_sequence, pred_rul_sequence = model(x)
                # reshape(-1) always yields (B,); squeeze() would collapse a
                # single-sample batch to a 0-d tensor and break torch.cat below.
                pred_final_rul = pred_final_rul.reshape(-1)

                # 1. Main RUL loss on the single final prediction.
                loss_main_rul = lamda * loss_fn(pred_final_rul, y[:, 0])

                # 2. SOH loss: one weighted MSE per predicted SOH column.
                loss_main_soh = torch.tensor(0.0, device=device)
                for i in range(pred_soh_sequence.shape[1]):
                    loss_main_soh = loss_main_soh + loss_fn(pred_soh_sequence[:, i], y[:, i + 1]) * alpha[i]

                # 3. Optional RUL sequential-consistency regulariser.
                loss_rul_consistency = torch.tensor(0.0, device=device)
                if self.gamma_rul_consistency > 0 and pred_rul_sequence is not None:
                    loss_rul_consistency = self.gamma_rul_consistency * \
                        compute_rul_sequential_consistency_loss(
                            pred_rul_sequence,
                            self.epsilon_drop,
                            device
                        )

                current_batch_total_loss = loss_main_rul + loss_main_soh + loss_rul_consistency

                current_batch_total_loss.backward()
                optimizer.step()
                batch_losses.append(current_batch_total_loss.cpu().detach().numpy())
                batch_losses_reg.append(loss_rul_consistency.cpu().detach().numpy())
                batch_losses_rul.append(loss_main_rul.cpu().detach().numpy())
                batch_losses_soh.append(loss_main_soh.cpu().detach().numpy())

                # Detach so the autograd graph of every batch is not kept alive
                # until the end of the epoch.
                y_pred_rul_epoch.append(pred_final_rul.detach())
                y_true_rul_epoch.append(y[:, 0])

            y_true_rul_epoch_cat = torch.cat(y_true_rul_epoch, axis=0)
            y_pred_rul_epoch_cat = torch.cat(y_pred_rul_epoch, axis=0)

            epoch_main_rul_mse = mean_squared_error(
                y_true_rul_epoch_cat.cpu().detach().numpy(),
                y_pred_rul_epoch_cat.cpu().detach().numpy()
            )
            train_loss_main_rul_metric.append(epoch_main_rul_mse)

            avg_epoch_total_loss = np.mean(batch_losses)
            avg_epoch_total_loss_reg = np.mean(batch_losses_reg)
            avg_epoch_total_loss_rul = np.mean(batch_losses_rul)
            avg_epoch_total_loss_soh = np.mean(batch_losses_soh)

            total_combined_loss_log.append(avg_epoch_total_loss)
            total_combined_loss_log_reg.append(avg_epoch_total_loss_reg)
            total_combined_loss_log_rul.append(avg_epoch_total_loss_rul)
            total_combined_loss_log_soh.append(avg_epoch_total_loss_soh)

            # ---- Validation: unweighted RUL MSE, averaged over batches ----
            model.eval()
            current_valid_losses_for_early_stopping = []
            with torch.no_grad():
                for x_val, y_val in valid_loader:
                    x_val = x_val.to(device)
                    y_val = y_val.to(device)

                    val_final_rul, _, _ = model(x_val)
                    val_loss_batch = loss_fn(val_final_rul.reshape(-1), y_val[:, 0])
                    current_valid_losses_for_early_stopping.append(val_loss_batch.cpu().numpy())

            epoch_val_loss = np.mean(current_valid_losses_for_early_stopping)
            valid_loss_metric.append(epoch_val_loss)

            # Progress is only printed every 100th epoch of long runs.
            if self.n_epochs > 100 and epoch % 100 == 0 and epoch != 0:
                print(f'Epoch: {epoch}')
                print(f'-- Train MainRUL MSE: {epoch_main_rul_mse:.4f}', 
                      f'-- Valid MainRUL MSE: {epoch_val_loss:.4f}',
                      f'-- Avg Total Optimized Loss (Train): {avg_epoch_total_loss:.10f}',
                      f'-- || Reg loss (Train): {avg_epoch_total_loss_reg:.10f}',
                      f'-- || rul loss (Train): {avg_epoch_total_loss_rul:.10f}',
                      f'-- || soh loss (Train): {avg_epoch_total_loss_soh:.10f}')

            # Early stopping must run on every epoch regardless of n_epochs;
            # it used to be nested under the `n_epochs > 100` logging guard, so
            # short runs never early-stopped and `load_model=True` found no
            # '_best.pt' produced by the current run.
            early_stopping(epoch_val_loss, model, f'{model_name}_best.pt')
            if early_stopping.early_stop:
                print("Early stopping")
                break

        if load_model:
            print(f"Loading best model from {model_name}_best.pt")
            # map_location lets a checkpoint written on another device be restored here.
            model.load_state_dict(torch.load(f'{model_name}_best.pt', map_location=device))
        else:
            torch.save(model.state_dict(), f'{model_name}_end.pt')

        return model, train_loss_main_rul_metric, valid_loss_metric, total_combined_loss_log,total_combined_loss_log_reg,total_combined_loss_log_rul,total_combined_loss_log_soh

    def test(self, test_loader, model):
        """Run `model` over `test_loader` and collect RUL / SOH predictions.

        Args:
            test_loader: Iterable of ``(x, y)`` batches.
            model (nn.Module): Model returning the 3-tuple described on the class.

        Returns:
            tuple: ``(y_true_rul, y_pred_rul, rul_mse, soh_true, soh_pred)`` where the
            first, second, fourth and fifth items are tensors concatenated over all
            batches (still on ``self.device``) and ``rul_mse`` is the RUL mean
            squared error over the whole loader.
        """
        device = self.device
        model = model.to(device)

        y_true_rul_list, y_pred_rul_list = [], []
        soh_true_list, soh_pred_list = [], []

        model.eval()
        with torch.no_grad():
            for x_test, y_test in test_loader:
                x_test = x_test.to(device)
                y_test = y_test.to(device)

                # The per-step RUL sequence is not needed for the reported metrics.
                pred_final_rul, pred_soh_sequence, _ = model(x_test)

                # reshape(-1) keeps a 1-d tensor even for a single-sample batch,
                # which torch.cat requires.
                y_pred_rul_list.append(pred_final_rul.reshape(-1))
                y_true_rul_list.append(y_test[:, 0])

                soh_pred_list.append(pred_soh_sequence)
                soh_true_list.append(y_test[:, 1:])

        y_true_rul_cat = torch.cat(y_true_rul_list, axis=0)
        y_pred_rul_cat = torch.cat(y_pred_rul_list, axis=0)
        soh_true_cat = torch.cat(soh_true_list, axis=0)
        soh_pred_cat = torch.cat(soh_pred_list, axis=0)

        mse_loss_rul = mean_squared_error(
            y_true_rul_cat.cpu().detach().numpy(),
            y_pred_rul_cat.cpu().detach().numpy()
        )
        return y_true_rul_cat, y_pred_rul_cat, mse_loss_rul, soh_true_cat, soh_pred_cat

In [4]:
import torch
import torch.nn.functional as F

def compute_rul_sequential_consistency_loss(rul_sequence_pred, epsilon_drop=1.0, device='cuda'):
    """
    Penalise predicted RUL sequences that increase from one step to the next.

    For each pair of consecutive steps the penalty is
    ``relu(RUL_{t+1} - RUL_t)``: zero whenever the prediction does not rise,
    and equal to the size of the rise otherwise. The result is the mean of
    that penalty over every sample and every consecutive pair.

    Args:
        rul_sequence_pred (torch.Tensor | None): Predicted RUL sequence, indexed as
            ``(batch, step)``.
        epsilon_drop (float): Accepted for interface compatibility but NOT used by
            the current formula -- no minimum drop is enforced, only
            non-increase. NOTE(review): an earlier description promised
            ``RUL_t - RUL_{t+1} >= epsilon_drop``; confirm which is intended
            before wiring the margin in, since it changes the training objective.
        device (str): Device for the zero returned when `rul_sequence_pred` is None.
            Ignored when a tensor is supplied (the tensor's own device is used).

    Returns:
        torch.Tensor: Scalar loss. A constant zero (no gradient) when the input is
        None or has fewer than two steps.
    """
    if rul_sequence_pred is None:
        return torch.tensor(0.0, device=device, requires_grad=False)

    if rul_sequence_pred.shape[1] < 2:
        # Nothing to difference. Build the zero on the input's own device so the
        # default device='cuda' cannot crash a CPU-only run.
        return torch.tensor(0.0, device=rul_sequence_pred.device, requires_grad=False)

    # step_change[:, t] = RUL_{t+1} - RUL_t; positive entries are increases.
    step_change = rul_sequence_pred[:, 1:] - rul_sequence_pred[:, :-1]
    increase_penalty = F.relu(step_change)

    return increase_penalty.mean()

# 1-Data preparation

In [5]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, Sampler, TensorDataset
from torch.utils.data.sampler import RandomSampler

class BidirectionalLSTM(nn.Module):
    """Single-layer LSTM followed by a per-time-step linear projection.

    Despite the class name, the LSTM is created with ``bidirectional=False``,
    so the projection consumes ``nHidden`` features per step.

    Args:
        nIn (int): The number of input unit
        nHidden (int): The number of hidden unit
        nOut (int): The number of output unit
        dropout: If truthy, a ``Dropout(p=0.5)`` is applied to the LSTM output
            before the projection; otherwise the value is stored as-is and no
            dropout is applied.
    """

    def __init__(self, nIn, nHidden, nOut, dropout):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=False, batch_first=True)
        self.embedding = nn.Linear(nHidden, nOut)
        self.dropout = nn.Dropout(p=0.5) if dropout else dropout

    def forward(self, input):
        """Map ``(batch, steps, nIn)`` to ``(batch, steps, nOut)``."""
        hidden_states, _ = self.rnn(input)
        batch, steps, width = hidden_states.shape

        # Fold time into the batch axis so the projection sees 2-d input.
        flat = hidden_states.reshape(batch * steps, width)
        if self.dropout:
            flat = self.dropout(flat)

        projected = self.embedding(flat)
        return projected.reshape(batch, steps, -1)

class CRNN(nn.Module):
    """Convolutional front-end + two stacked LSTM blocks with RUL and SOH heads.

    The CNN uses ``(3, 1)`` kernels with stride ``(2, 1)`` and ``(2, 1)``
    max-pooling, so only the height axis of the input is reduced; the width
    axis is carried through untouched and becomes the time axis of the RNN.

    Modes (select how the final RUL / SOH outputs are produced):

    * ``"rul_last"``     -- final RUL from the RNN features of the last step;
      SOH sequence from the per-step SOH head.
    * ``"a_all"`` / ``"rul_all"`` -- final RUL from a linear layer over the
      per-step SOH sequence (requires exactly 10 steps).
    * ``"soh_a_l_last"`` -- SOH of the last step is repeated over all steps,
      and the final RUL is a linear layer over that repeated sequence
      (requires exactly 10 steps).
    """

    _VALID_MODES = ("rul_last", "a_all", "rul_all", "soh_a_l_last")

    def __init__(self, ni, nc, no, nh, n_rnn=2, leakyRelu=False,sigmoid = False, mode="a_all"):
        """
        Args:
            ni (int): The number of input unit (not used by the layers built here)
            nc (int): The number of original channel
            no (int): The number of output unit (feature size per step leaving the RNN)
            nh (int): The number of hidden unit
            n_rnn (int): Accepted for compatibility; two RNN blocks are always built.
            leakyRelu (bool): Use LeakyReLU(0.2) instead of ReLU after each conv.
            sigmoid (bool): Stored on the instance; not used in forward.
            mode (str): One of "rul_last", "a_all", "rul_all", "soh_a_l_last".

        Raises:
            ValueError: If `mode` is not one of the supported values (previously
                this surfaced later as an AttributeError inside forward).
        """
        super(CRNN, self).__init__()
        if mode not in self._VALID_MODES:
            raise ValueError(
                "Unknown mode {0!r}; expected one of {1}".format(mode, self._VALID_MODES))
        self.mode=mode

        conv_channels = [8, 16, 64]  # output channels of conv0..conv2
        seq_len = 10                 # number of time steps the sequence-level RUL layer expects

        # Module names (conv{i}/relu{i}/pooling{i}) are kept so existing
        # checkpoints keep loading.
        cnn = nn.Sequential()
        in_channels = nc
        for i, out_channels in enumerate(conv_channels):
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(in_channels, out_channels, (3, 1), (2, 1), (0, 0)))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
            cnn.add_module('pooling{0}'.format(i), nn.MaxPool2d((2, 1), (2, 1)))
            in_channels = out_channels

        self.sigmoid = sigmoid
        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(conv_channels[-1], nh, nh, False),
            BidirectionalLSTM(nh, nh, no, False),)

        # Heads read the RNN output, whose feature size is `no` (these were
        # hard-coded to 64, which only worked when no == 64). Creation order is
        # kept as before so seeded initialisation is unchanged.
        if self.mode=="rul_last":
            self.rul = nn.Linear(no, 1)
            self.soh = nn.Linear(no, 1)
        if self.mode in ["a_all","rul_all"]:
            self.rul = nn.Linear(seq_len, 1)
            self.soh = nn.Linear(no, 1)

        self.rul_sequential_head = nn.Linear(no, 1)  # per-step RUL for the consistency loss

        if self.mode=="soh_a_l_last":
            self.soh = nn.Linear(no, 1)
            self.rul = nn.Linear(seq_len, 1)

    def forward(self, input):
        """
        Input shape: [b, c, h, w]
        Output:
            pred_final_rul     [b, 1]
            soh_pred_sequence  [b, T]
            pred_rul_sequence  [b, T]
        where T is the width left after the CNN (the input width, since the
        convolutions and poolings only act on the height axis).

        Raises:
            ValueError: If the CNN does not reduce the height axis to 1.
        """
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        if h != 1:
            raise ValueError(
                "CNN output height must be 1 to form a sequence, got {0}".format(h))
        conv = conv.squeeze(2)        # [b, c, w]
        conv = conv.permute(0, 2, 1)  # [b, w, c] -> time-major features per step
        output = self.rnn(conv)       # [b, T, no]

        pred_rul_sequence = self.rul_sequential_head(output).squeeze(-1)  # [b, T]
        soh_pred_sequence = self.soh(output).squeeze(-1)                  # [b, T]

        # Indexing the last step already gives [b, no]; an extra squeeze() here
        # would also drop the batch axis when b == 1.
        last_step = output[:, -1, :]

        if self.mode in ["rul_last"]:
            pred_final_rul = self.rul(last_step)  # [b, 1]
        else:
            if self.mode in ["soh_a_l_last"]:
                last_soh = self.soh(last_step)
                last_soh = last_soh.view(output.shape[0], 1)  # make sure it is [b, 1]
                # repeat along the second (time) dimension -> [b, T]
                soh_pred_sequence = last_soh.repeat(1, output.shape[1])
            pred_final_rul = self.rul(soh_pred_sequence)  # [b, 1]

        return pred_final_rul, soh_pred_sequence, pred_rul_sequence

In [8]:

# Use the GPU when one is present, otherwise fall back to CPU so the cell
# still runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pretrained checkpoint to evaluate.
pretrain_model_path = 'model/wx_inner/Big_model_RUL_laststate_reg_0.1_best.pt'

# Only `device` is read by Trainer.test(); the training hyper-parameters are
# left as placeholders because no training happens in this cell.
trainer = Trainer(lr = "", n_epochs = "",device = device, patience = "",
                  lamda = "", alpha = "", model_name="")

# Random placeholder batch: features (N, 10, 100, 4) and labels (N, 11),
# i.e. 1 RUL column followed by 10 SOH columns as the Trainer expects.
tmp_fea, tmp_lbl = np.random.rand(13, 10, 100, 4), np.random.rand(13, 11)
test_fea_ = tmp_fea.copy()
test_lbl_ = tmp_lbl.copy()
# (N, 10, 100, 4) -> (N, 4, 100, 10): channels first, the 10-long axis last.
test_fea_ = test_fea_.transpose(0,3,2,1)
testset = TensorDataset(torch.tensor(test_fea_, dtype=torch.float32),
                        torch.tensor(test_lbl_, dtype=torch.float32))
test_loader = DataLoader(testset,batch_size=len(testset), shuffle=False, drop_last = False)

# The mode selects the forward pass and must match the checkpoint being loaded.
model = CRNN(100,4,64,64, mode="rul_last")
model = model.to(device)
# map_location lets a checkpoint saved on a GPU be loaded on the selected device.
model.load_state_dict(torch.load(pretrain_model_path, map_location=device))
y_true, y_pred, _, soh_true, soh_pred = trainer.test(test_loader, model)

rul_pred=(y_pred.cpu().detach().numpy())
soh_pred= soh_pred.cpu().detach().numpy()[:,-1] #last SoH