In [62]:
# Import modules
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, Subset
import optuna
np.random.seed(87)

In [63]:
def set_device():
    """
    Set the device to GPU if available, otherwise use CPU.

    Returns:
        device (torch.device): The device to use for training.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        print(f'Using device: {device}')
        print(f'GPU: {torch.cuda.get_device_name(0)}')
    else:
        print(f'Using device: {device}')

    return device

In [64]:
def norm_data(name):
    """
    Load the data from the csv file and normalize the data.

    Args:
        name (str): The name of the csv file.

    Returns:
        ndf (pd.DataFrame): The normalized data.
        exp_num_list (list): List of experiment numbers in order.
    """
    # raw data
    df = pd.read_csv(name) 

    # normalized data
    ndf = pd.DataFrame() 

    # the range of min-max normalization for each feature
    range_mm={
        'V': {'min':df['V'].min()*0.8, 'max': df['V'].max()*1.2},
        'E': {'min':df['E'].min()*0.8, 'max': df['E'].max()*1.2},
        'VF': {'min':df['VF'].min()*0.8, 'max': df['VF'].max()*1.2},
        'VA': {'min':df['VA'].min()*0.8, 'max': df['VA'].max()*1.2},
        'VB': {'min':df['VB'].min()*0.8, 'max': df['VB'].max()*1.2},
        'CFLA': {'min':0, 'max': df['CFLA'].max()*1.2},
        'CALA': {'min':0, 'max': df['CALA'].max()*1.2},
        'CBLA': {'min':0, 'max': df['CBLA'].max()*1.2},
        'CFK': {'min':0, 'max': df['CFK'].max()*1.2},
        'CAK': {'min':0, 'max': df['CAK'].max()*1.2},
        'CBK': {'min':0, 'max': df['CBK'].max()*1.2},
        'I': {'min':0, 'max': df['I'].max()*1.2},
    }
    
    # add experiment number and time
    ndf['exp'] = df['exp']; ndf['t'] = df['t'] 

    # min-max normalization
    for col in ['V', 'E', 'VF', 'VA', 'VB', 'CFLA', 'CALA', 'CBLA', 'CFK', 'CAK', 'CBK', 'I']: # min-max normalization
        if col in range_mm:
            ndf[col] = (df[col] - range_mm[col]['min'])/(range_mm[col]['max'] - range_mm[col]['min'])
        else:
            ndf[col] = df[col]

    # Get the unique experiment numbers in order
    exp_num_list = sorted(ndf['exp'].unique())

    return ndf, exp_num_list

In [65]:
def seq_data_const(ndf):
    """
    Set the data sequences.

    Args:
        ndf (pd.DataFrame): The normalized data.

    Returns:
        sequences (list): The sequences of the data.
    """
    sequences = []
    feature_cols = ['V', 'E', 'VF', 'VA', 'VB', 'CFLA', 'CALA', 'CBLA', 'CFK', 'CAK', 'CBK', 'I']
    
    # get the sequences of the data for each experiment
    for exp in ndf['exp'].unique():
        exp_data = ndf[ndf['exp'] == exp].sort_values(by='t')
        sequences.append(exp_data[feature_cols].values)
    
    return sequences

In [66]:
def padded_sequences(sequences):
    """
    Pad the sequences.

    Args:
        sequences (list): The sequences of the data.

    Returns:
        padded_sequences (torch.Tensor): The padded sequences.
    """
    max_seq_len = max([len(seq) for seq in sequences])
    seq_len = [len(seq) for seq in sequences]
    padded_sequences = pad_sequence([torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=-1)
    
    return padded_sequences, seq_len, max_seq_len

In [67]:
def gen_dataset(pad_seq, seq_len):
    """
    Generate the dataset.

    Args:
        pad_seq (torch.Tensor): The padded sequences.
        seq_len (list): The length of the sequences.

    Returns:
        dataset (torch.utils.data.Dataset): The dataset.
    """
    input_tensor = pad_seq.float()
    seq_len_tensor = torch.tensor(seq_len)
    dataset = TensorDataset(input_tensor, seq_len_tensor)
    return dataset

In [68]:
def dataloaders(dataset, exp_num_list, batch_size=4):
    """
    Split the dataset into train/val/test with 8:1:1 ratio
    
    Args:
        dataset: TensorDataset
        exp_num_list: list of experiment numbers
        batch_size: batch size
        random_state: random seed
        
    Returns:
        tuple: (train_loader, val_loader, test_loader)
    """

    
    # required train experiment numbers
    required_train_exps = [1, 3, 5, 6, 11, 15, 17, 19, 20, 40, 41, 42]
    
    # all experiment numbers
    all_exps = exp_num_list
    total_exps = len(all_exps)
    
    # batch_size
    batch_size = math.ceil(len(dataset)/10)

    # 8:1:1 ratio
    train_count = int(total_exps * 0.8)
    val_count = math.ceil(total_exps * 0.1)
    
    # remaining experiments
    remaining_exps = [exp for exp in all_exps if exp not in required_train_exps]
    
    # number of experiments to add to train
    additional_train_needed = train_count - len(required_train_exps)
    
    if additional_train_needed < 0:
        raise ValueError("The number of required train experiments is greater than the total train set. Please adjust required_train_exps.")
    
    # shuffle remaining experiments
    np.random.shuffle(remaining_exps)
    
    # split remaining experiments into train, val, test
    train_exps = required_train_exps + remaining_exps[:additional_train_needed]
    val_exps = remaining_exps[additional_train_needed:additional_train_needed + val_count]
    test_exps = remaining_exps[additional_train_needed + val_count:]
    
    print(f"Actual split:")
    print(f"  Train: {sorted(train_exps)} ({len(train_exps)} experiments)")
    print(f"  Val: {sorted(val_exps)} ({len(val_exps)} experiments)")  
    print(f"  Test: {sorted(test_exps)} ({len(test_exps)} experiments)")
    
    # find indices of each experiment (exp_num_list and dataset have the same order)
    train_indices = []
    val_indices = []
    test_indices = []
    
    for idx, exp in enumerate(all_exps):
        if exp in train_exps:
            train_indices.append(idx)
        elif exp in val_exps:
            val_indices.append(idx)
        elif exp in test_exps:
            test_indices.append(idx)
    
    # split dataset into train, val, test
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)
    
    # create DataLoader
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)
    
    print(f"\nCompleted DataLoader creation:")
    print(f"  Train: {len(train_subset) if train_subset else 0} sequences")
    print(f"  Val: {len(val_subset) if val_subset else 0} sequences")
    print(f"  Test: {len(test_subset) if test_subset else 0} sequences")
    
    return train_loader, val_loader, test_loader

In [69]:
class SequentialStateExtractor(nn.Module):
    """
    The module based on LSTM to extract hidden dynamics from the sequential pattern of BMED.
    The hidden state of each step accumulates the information of all previous steps.

    Args:
        input_nodes (int): The number of input nodes.
        hidden_nodes (int): The number of hidden nodes.
        num_layers (int): The number of layers.
        dropout (float): The dropout rate.
    
    Output:
        hidden_states: [batch_size, seq_len, hidden_nodes] - hidden state of each step
    """
    def __init__(self, input_nodes, hidden_nodes, num_layers, dropout=0.2):
        super().__init__()
        self.hidden_nodes = hidden_nodes
        self.num_layers = num_layers

        # LSTM Layer
        self.lstm = nn.LSTM(input_nodes, hidden_nodes, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)

        self.layer_norm = nn.LayerNorm(hidden_nodes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, seq_len):
        """
        Extract the hidden state of each step from the sequential pattern of BMED.

        Args:
            x: [batch_size, seq_len, input_nodes] - state sequence of BMED system
            seq_len [batch_size] - length of each sequence

        Returns:
            hidden_states: [batch_size, seq_len, hidden_nodes] - hidden state of each step
        """
        # check the input shape
        if x.size(0) != seq_len.size(0):
            raise ValueError(f"Batch size mismatch: input {x.size(0)} vs seq_len {seq_len.size(0)}")
        
        # Move the seq_len to CPU and transfer to integer
        seq_len_cpu = seq_len.detach().cpu().long()

        # check the length of sequence
        if (seq_len_cpu <= 0).any():
            invalid_lengths = seq_len_cpu[seq_len_cpu <= 0]
            raise ValueError(f'Invalid sequence lengths detected: {invalid_lengths.tolist()}. All sequence lengths mut be positive')
        
        # pack the padded sequence
        packed_input = pack_padded_sequence(x, seq_len_cpu, batch_first=True, enforce_sorted=False)
        packed_output, (hidden, cell) = self.lstm(packed_input)

        # re-pad the sequence
        lstm_out, output_lengths = pad_packed_sequence(packed_output, batch_first=True, total_length=x.size(1))

        # Normalization and dropout
        normed_output = self.layer_norm(lstm_out)
        return self.dropout(normed_output)

In [70]:
class PhysicalChangeDecoder(nn.Module):
    """
    The module based on MLP to decode the hidden state to the physical change.

    Args:
        hidden_nodes (int): The number of hidden nodes.
        output_nodes (int): The number of output nodes.
        num_layers (int): The number of layers.
        num_nodes (int): The number of nodes in the hidden layers.
        dropout (float): The dropout rate.
    
    Output:
        physical_changes: [batch_size, seq_len, output_nodes] - [dVA, dVB, dNALA, dNAK, dNBK, nI]
    """
    def __init__(self, hidden_nodes, output_nodes, num_layers=2, num_nodes=None, dropout=0.3):
        super().__init__()

        if num_nodes is None:
            num_nodes = hidden_nodes

        self.layers = nn.ModuleList()

        # input layer: hidden_nodes -> num_nodes
        self.layers.append(nn.Linear(hidden_nodes, num_nodes))
        self.layers.append(nn.LayerNorm(num_nodes))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.Dropout(dropout))

        # hidden layers: num_nodes -> num_nodes
        for i in range(num_layers - 1):
            self.layers.append(nn.Linear(num_nodes, num_nodes))
            self.layers.append(nn.LayerNorm(num_nodes))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))

        # output layer: num_nodes -> output_nodes
        self.layers.append(nn.Linear(num_nodes, output_nodes))

    def forward(self, hidden_states):
        """
        Decode the hidden state to the physical change.

        Args:
            hidden_states: [batch_size, seq_len, hidden_nodes] - hidden state of each step

        Returns:
            physical_changes: [batch_size, seq_len, output_nodes] - [dVA, dVB, dNALA, dNAK, dNBK, nI]
        """
        x = hidden_states
        for layer in self.layers:
            x = layer(x)
        return x

In [71]:
class PhysicsConstraintLayer(nn.Module):
    """
    The module based on MLP to apply the physical constraints to the physical changes.

    Output:
        new_state: [batch_size, seq_len, 12] - new state
    """
    def __init__(self, eps=1e-1):
        super().__init__()
        self.eps = eps # prevent division by zero

    def forward(self, physical_changes, current_state):
        """
        Apply the physical constraints to the physical changes.

        Args:
            physical_changes: [batch_size, seq_len, 7] - physical changes
            current_state: [batch_size, seq_len, 12] - current state

        Returns:
            new_state: [batch_size, seq_len, 12] - new state
        """
        # check the input shape
        if physical_changes.dim() != current_state.dim():
            raise ValueError(f"Dimension mismatch: physical_changes {physical_changes.shape} vs current_state {current_state.shape}")
        
        if current_state.size(-1) != 12:
            raise ValueError(f"Expected 12 state features, got {current_state.size(-1)}")
        
        if physical_changes.size(-1) != 7:
            raise ValueError(f"Expected 7 physical changes, got {physical_changes.size(-1)}")

        
        # extract the current state variables (keep the dimension)
        V = current_state[..., 0:1]     # Voltage (fixed)
        E = current_state[..., 1:2]     # External electrolyte concentration (fixed)
        VF = current_state[..., 2:3]    # Feed volume
        VA = current_state[..., 3:4]    # Acid volume
        VB = current_state[..., 4:5]    # Base volume
        CFLA = current_state[..., 5:6]  # LA concentration in Feed tank
        CALA = current_state[..., 6:7]  # LA concentration in Acid tank
        CBLA = current_state[..., 7:8]  # LA concentration in Base tank
        CFK = current_state[..., 8:9]   # K concentration in Feed tank
        CAK = current_state[..., 9:10]  # K concentration in Acid tank
        CBK = current_state[..., 10:11] # K concentration in Base tank
        I = current_state[..., 11:12]   # Current

        # calculate the mole of ion species
        NFLA = CFLA * VF; NALA = CALA * VA; NBLA = CBLA * VB
        NFK = CFK * VF; NAK = CAK * VA; NBK = CBK * VB

        # calculate the physical changes
        dVA = physical_changes[..., 0:1]    # Acid tank volume change (bidirectional)
        dVB = physical_changes[..., 1:2]    # Base tank volume change (bidirectional)
        dNALA = physical_changes[..., 2:3]  # LA change in Acid tank (unidirectional)
        dNBLA = physical_changes[..., 3:4]  # LA change in Base tank (unidirectional)
        dNAK = physical_changes[..., 4:5]   # K change in Acid tank (unidirectional)
        dNBK = physical_changes[..., 5:6]   # K change in Base tank (unidirectional)
        nI = physical_changes[..., 6:7]     # New current value
        
        # calculate the new volume
        nVF = VF - dVA - dVB  # New Feed tank volume 
        nVA = VA + dVA        # New Acid tank volume
        nVB = VB + dVB        # New Base tank volume

        # limit the ion species changes (unidirectional flow only)
        dNALA = torch.clamp(dNALA, min=0)
        dNBLA = torch.clamp(dNBLA, min=0)
        dNAK = torch.clamp(dNAK, min=0)
        dNBK = torch.clamp(dNBK, min=0)

        # calculate the new mole of ion species
        nNFLA = NFLA - dNALA - dNBLA  # New LA mole in Feed tank
        nNALA = NALA + dNALA         # New LA mole in Acid tank
        nNBLA = NBLA + dNBLA         # New LA mole in Base tank
        nNFK = NFK - dNAK - dNBK     # New K mole in Feed tank
        nNAK = NAK + dNAK            # New K mole in Acid tank
        nNBK = NBK + dNBK            # New K mole in Base tank

        # limit the physical changes
        nVF = torch.clamp(nVF, min=self.eps)
        nVA = torch.clamp(nVA, min=self.eps)
        nVB = torch.clamp(nVB, min=self.eps)
        nNFLA = torch.clamp(nNFLA, min=0)
        nNALA = torch.clamp(nNALA, min=0)
        nNBLA = torch.clamp(nNBLA, min=0)
        nNFK = torch.clamp(nNFK, min=0)
        nNAK = torch.clamp(nNAK, min=0)
        nNBK = torch.clamp(nNBK, min=0)
        nI = torch.clamp(nI, min=0)

        # calculate the new concentration
        nCFLA = nNFLA / nVF  # New LA concentration in Feed tank
        nCALA = nNALA / nVA  # New LA concentration in Acid tank
        nCBLA = nNBLA / nVB  # New LA concentration in Base tank
        nCFK = nNFK / nVF    # New K concentration in Feed tank
        nCAK = nNAK / nVA    # New K concentration in Acid tank
        nCBK = nNBK / nVB    # New K concentration in Base tank

        # assemble the new state
        new_state = torch.cat([
            V, E,  # fixed: voltage, external electrolyte concentration
            nVF, nVA, nVB,  # new volume
            nCFLA, nCALA, nCBLA,  # new LA concentration
            nCFK, nCAK, nCBK,     # new K concentration
            nI  # new current
        ], dim=-1)

        return new_state

In [72]:
class BMEDAutoregressiveModel(nn.Module):
    """
    The autoregressive model to predict the state of BMED system.
    """
    def __init__(self, state_extractor_params, decoder_params):
        super().__init__()
        self.state_extractor = SequentialStateExtractor(**state_extractor_params)
        self.physical_decoder = PhysicalChangeDecoder(**decoder_params)
        self.physics_constraint = PhysicsConstraintLayer()

    def forward(self, current_state, seq_lengths):
        """
        Predict the next step from the all previous steps.

        Args:
            current_state: [batch_size, seq_len, 12] - current state
            seq_lengths: [batch_size] - length of each sequence

        Returns:
            new_state: [batch_size, seq_len, 12] - new state
        """
        # Extract the hidden state of each step using LSTM
        hidden_states = self.state_extractor(current_state, seq_lengths)
        # Decode the hidden state to the physical change
        physical_changes = self.physical_decoder(hidden_states)
        # Calculate the new state using physical constraints
        new_state = self.physics_constraint(physical_changes, current_state)

        return new_state

In [73]:
def masked_mse_loss(pred, target, seq_len):
    """
    Calculate the masked MSE loss for the autoregressive model.

    Args:
        pred: [batch_size, seq_len, 12] - predicted state
        target: [batch_size, seq_len, 12] - target state
        seq_len: [batch_size] - length of each sequence

    Returns:
        avg_loss: average loss excluding the masked parts
    """
    # check the input shape
    if pred.shape != target.shape:
        raise ValueError(f"Shape mismatch: predictions {pred.shape} vs targets {target.shape}")

    if pred.size(0) != seq_len.size(0):
        raise ValueError(f"Batch size mismatch: predictions {pred.size(0)} vs sequence lengths {seq_len.size(0)}")
    
    batch_size, max_len, features = pred.shape

    # Move seq_len to CPU to be compatible with arange.
    seq_len_cpu = seq_len.detach().cpu().long()

    # Validation check on sequence lengths
    if (seq_len_cpu <= 0).any():
        invalid_lengths = seq_len_cpu[seq_len_cpu <= 0]
        raise ValueError(f'Invalid sequence lengths detected: {invalid_lengths.tolist()}. All sequence lengths must be positive.')

    # Check if any sequence length exceeds max_len
    if (seq_len_cpu > max_len).any():
        invalid_lengths = seq_len_cpu[seq_len_cpu > max_len]
        raise ValueError(f'Sequence lengths exceed max_len: {invalid_lengths.tolist()} > {max_len}')

    # Generate mask as long as the sequence length
    mask = torch.arange(max_len, device='cpu')[None, :] < seq_len_cpu[:, None]
    mask = mask.float().to(pred.device)

    # Calculate the MSE of each feature
    loss = F.mse_loss(pred, target, reduction='none')

    # Apply the mask to exclude the masked parts
    masked_loss_sum = (loss * mask.unsqueeze(-1)).sum()
    valid_elements = mask.sum() * features

    if valid_elements == 0:
        raise ValueError('No valid elements found after masking. Check sequence lengths and data.')
    
    masked_loss = masked_loss_sum / valid_elements

    return masked_loss

In [74]:
def free_running_data(input_seq, seq_len):
    """
    Prepare the data for free running .

    Args:
        input_seq: [batch_size, seq_len, 12] - input sequences
        seq_lengths: [batch_size] - length of each sequence

    Returns:
        init: [t0] initial state
        targets: [t1, t2, ..., t_n] next states
        target_seq_len: length of each target sequence
    """
    # initial state
    init = input_seq[:, 0, :]
    # target states
    targets = input_seq[:, 1:, :]
    # length of each target sequence
    if (seq_len - 1 < 1).any():
        invalid_lengths = seq_len[seq_len - 1 < 1]
        raise ValueError(f'The length of target sequence cannot be less than 1. Wrong seq_len: {invalid_lengths.tolist()}')
    target_seq_len = seq_len - 1

    return init, targets, target_seq_len


In [75]:
def free_running_prediction(model, init, targets_shape, device, mode='eval'):
    """
    Free running prediction using only initial state with different modes.
    
    Args:
        model: BMEDAutoregressiveModel
        initial_state: [batch_size, 12] - initial state
        targets_shape: tuple - shape of targets to match (batch_size, seq_len, features)
        device: computation device
        mode: 'eval' (evaluation), 'train' (training), 'simulation' (pure inference)
        
    Returns:
        predictions: [batch_size, targets_seq_len, 12] - predicted sequence
    """
    # Set model mode
    if mode == 'train':
        model.train()
        context_manager = torch.enable_grad()
    elif mode in ['eval', 'simulation']:
        model.eval()
        context_manager = torch.no_grad()
    else:
        raise ValueError(f"Invalid mode: {mode}. Choose from 'train', 'eval', 'simulation'")
    
    batch_size = init.size(0)
    num_steps = targets_shape[1]  # Use the actual targets sequence length
    
    # Initialize predictions with initial state
    pred = [init.unsqueeze(1)]  # [batch_size, 1, 12]
    current_state = init.unsqueeze(1)  # [batch_size, 1, 12]
    
    with context_manager:
        for step in range(num_steps):
            # Predict next state using current sequence
            seq_len = torch.full((batch_size,), current_state.size(1), device=device)
            next_state = model(current_state, seq_len)
            
            # Take the last predicted state
            next_step = next_state[:, -1:, :]  # [batch_size, 1, 12]
            pred.append(next_step)
            
            # Update current state sequence
            current_state = torch.cat([current_state, next_step], dim=1)
    
    # Return all predictions except the initial state
    return torch.cat(pred[1:], dim=1)  # [batch_size, num_steps, 12]

In [76]:
def train_epoch_free_running(model, train_loader, optimizer, device):
    """
    Train the model using free running approach for one epoch.
    
    Args:
        model: BMEDAutoregressiveModel
        train_loader: training data loader
        optimizer: optimizer
        device: computation device
        
    Returns:
        float: average training loss for the epoch
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    
    for batch_idx, (input_seq, seq_len) in enumerate(train_loader):
        input_seq = input_seq.to(device)
        seq_len = seq_len.to(device)
        
        # Clear gradients
        optimizer.zero_grad()
        
        # Prepare free running data
        init, targets, target_seq_len = free_running_data(input_seq, seq_len)
        
        # Free running prediction in train mode
        pred = free_running_prediction(
            model, init, targets.shape, device, mode='train'
        )
        
        # Calculate masked loss
        loss = masked_mse_loss(pred, targets, target_seq_len)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.000001)
        
        # Update parameters
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches if num_batches > 0 else 0.0


def validate_epoch_free_running(model, val_loader, device):
    """
    Validate the model using free running approach for one epoch.
    
    Args:
        model: BMEDAutoregressiveModel
        val_loader: validation data loader
        device: computation device
        
    Returns:
        float: average validation loss for the epoch
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    
    with torch.no_grad():
        for batch_idx, (input_seq, seq_len) in enumerate(val_loader):
            input_seq = input_seq.to(device)
            seq_len = seq_len.to(device)
            
            # Prepare free running data
            init, targets, target_seq_len = free_running_data(input_seq, seq_len)
            
            # Free running prediction in eval mode
            pred = free_running_prediction(
                model, init, targets.shape, device, mode='eval'
            )
            
            # Calculate masked loss
            loss = masked_mse_loss(pred, targets, target_seq_len)
            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / num_batches if num_batches > 0 else 0.0

In [77]:
# def train_free_running_model(model, train_loader, val_loader, optimizer, scheduler, device, 
#                              num_epochs=200, patience=20, min_epochs=10):
def train_free_running_model(model, train_loader, val_loader, optimizer, device, 
                             num_epochs=200, patience=20, min_epochs=10):
    """
    Complete training loop for free running model.
    
    Args:
        model: BMEDAutoregressiveModel
        train_loader: training data loader
        val_loader: validation data loader
        optimizer: optimizer
        scheduler: learning rate scheduler
        device: computation device
        num_epochs: maximum number of epochs
        patience: early stopping patience
        min_epochs: minimum epochs before early stopping
        
    Returns:
        dict: training history and best model state
    """
    best_val_loss = float('inf')
    best_train_loss = float('inf')
    best_total_loss = float('inf')
    patience_counter = 0
    train_history = []
    val_history = []
    best_model_state = None
    
    print(f"Starting Free Running Training...")
    print(f"Device: {device}")
    print(f"Max Epochs: {num_epochs}, Patience: {patience}")
    print("=" * 60)
    
    for epoch in range(num_epochs):
        # Training
        train_loss = train_epoch_free_running(model, train_loader, optimizer, device)
        
        # Validation
        val_loss = validate_epoch_free_running(model, val_loader, device)
        
        # Calculate total loss
        total_loss = train_loss + val_loss
        
        # Learning rate scheduling
        # if scheduler:
        #     scheduler.step(val_loss)
        
        # Record history
        train_history.append(train_loss)
        val_history.append(val_loss)
        
        # Early stopping check based on total loss
        if total_loss < best_total_loss:
            best_total_loss = total_loss
            best_val_loss = val_loss
            best_train_loss = train_loss
            patience_counter = 0
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
        
        # Print progress
        if (epoch + 1) % 10 == 0 or epoch < 10:
            print(f"Epoch {epoch+1:3d}/{num_epochs}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}, Total Loss = {total_loss:.6f} | Best: Train = {best_train_loss:.6f}, Val = {best_val_loss:.6f}, Total = {best_total_loss:.6f}")
        
        # Early stopping
        if epoch >= min_epochs and patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    print("=" * 60)
    print(f"Training completed!")
    print(f"Best train loss: {best_train_loss:.6f}")
    print(f"Best val loss: {best_val_loss:.6f}")
    print(f"Best total loss: {best_total_loss:.6f}")
    
    return {
        'train_history': train_history,
        'val_history': val_history,
        'best_train_loss': best_train_loss,
        'best_val_loss': best_val_loss,
        'best_total_loss': best_total_loss,
        'best_model_state': best_model_state,
        'final_epoch': epoch + 1
    }

In [78]:
# Load data and create dataloaders
print("Loading and preprocessing data...")
ndf, exp_num_list = norm_data('BMED_DATA_AG.csv')
sequences = seq_data_const(ndf)
padded_seq, seq_len, max_seq_len = padded_sequences(sequences)
dataset = gen_dataset(padded_seq, seq_len)

print(f"Dataset created with {len(dataset)} experiments")
print(f"Max sequence length: {max_seq_len}")
print(f"Experiment numbers: {sorted(exp_num_list)}")

# Create train/val/test dataloaders with stratified split
train_loader, val_loader, test_loader = dataloaders(dataset, exp_num_list, batch_size=4)

Loading and preprocessing data...
Dataset created with 39 experiments
Max sequence length: 37
Experiment numbers: [np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38)]
Actual split:
  Train: [1, np.int64(2), 3, np.int64(4), 5, 6, np.int64(7), np.int64(8), np.int64(9), 11, np.int64(12), np.int64(13), 15, np.int64(16), 17, np.int64(18), 19, 20, np.int64(21), np.int64(25), np.int64(26), np.int64(27), np.int64(29), np.int64(30), np.int64(32), np.int64(33), np.int64(35), np.int64(37), 40, 41, 42] (31 experiments)
  V

In [79]:
# Initialize model and training setup
device = set_device()

study = optuna.load_study(study_name="bmed_autoregressive_optimization", storage="sqlite:///bmed_optuna_study.db")
best_params = study.best_params

# Model parameters
state_extractor_params = {
    'input_nodes': 12,
    'hidden_nodes': best_params['hidden_size'],
    #'num_layers': best_params['num_layers'],
    'num_layers': 2,
    'dropout': best_params['extractor_dropout']
}

decoder_params = {
    'hidden_nodes': best_params['hidden_size'],
    'output_nodes': 7,  # [dVA, dVB, dNALA, dNBLA, dNAK, dNBK, nI]
    #'num_layers': best_params['decoder_layers'],
    'num_layers': 2,
    'num_nodes': best_params['decoder_nodes'],
    'dropout': best_params['decoder_dropout']
}

# Initialize model
model = BMEDAutoregressiveModel(state_extractor_params, decoder_params)
model = model.to(device)

# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=best_params['learning_rate'], weight_decay=best_params['weight_decay'])
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")
print(f"Model on device: {next(model.parameters()).device}")

Using device: cuda
GPU: NVIDIA GeForce RTX 4080 SUPER
Model initialized with 655015 parameters
Model on device: cuda:0


In [80]:
# Start free running training
print("\n🚀 Starting Free Running Training...")
training_results = train_free_running_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    #scheduler=scheduler,
    device=device,
    num_epochs=10000,
    patience=1000,
    min_epochs=1000
)

# Load best model
if training_results['best_model_state'] is not None:
    model.load_state_dict(training_results['best_model_state'])
    print("✅ Best model loaded!")
else:
    print("⚠️ No best model found, using current state")


🚀 Starting Free Running Training...
Starting Free Running Training...
Device: cuda
Max Epochs: 10000, Patience: 1000


Epoch   1/10000: Train Loss = 90.282280, Val Loss = 1.800746, Total Loss = 92.083026 | Best: Train = 90.282280, Val = 1.800746, Total = 92.083026
Epoch   2/10000: Train Loss = 1.104173, Val Loss = 0.205180, Total Loss = 1.309352 | Best: Train = 1.104173, Val = 0.205180, Total = 1.309352
Epoch   3/10000: Train Loss = 0.523071, Val Loss = 0.182905, Total Loss = 0.705976 | Best: Train = 0.523071, Val = 0.182905, Total = 0.705976
Epoch   4/10000: Train Loss = 0.326244, Val Loss = 0.013367, Total Loss = 0.339611 | Best: Train = 0.326244, Val = 0.013367, Total = 0.339611
Epoch   5/10000: Train Loss = 0.242681, Val Loss = 0.044002, Total Loss = 0.286683 | Best: Train = 0.242681, Val = 0.044002, Total = 0.286683
Epoch   6/10000: Train Loss = 0.206244, Val Loss = 0.021210, Total Loss = 0.227454 | Best: Train = 0.206244, Val = 0.021210, Total = 0.227454
Epoch   7/10000: Train Loss = 0.170558, Val Loss = 0.014499, Total Loss = 0.185058 | Best: Train = 0.170558, Val = 0.014499, Total = 0.185058
Ep

KeyboardInterrupt: 

In [81]:
# best model 저장
import torch

best_model_path = "bmed_batch_best_model.pth"
torch.save(model.state_dict(), best_model_path)
print(f"✅ Best model이 '{best_model_path}' 파일에 저장되었습니다.")

✅ Best model이 'bmed_batch_best_model.pth' 파일에 저장되었습니다.
