In [27]:
import numpy as np
import pandas as pd
import os
import shutil
import logging 
from sklearn.model_selection import ShuffleSplit

from data_process import get_dataset_splits
from utils.evaluation_utils import load_data_from_file, write_results_to_file

In [28]:
import tensorflow as tf
print(tf.__version__)

2.13.1


In [29]:
from tensorflow.keras import *
#打印时间分割线
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts%(24*60*60)

    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
    minite = tf.cast((today_ts%3600)//60,tf.int32)
    second = tf.cast(tf.floor(today_ts%60),tf.int32)
    
    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}",m))==1:
            return(tf.strings.format("0{}",m))
        else:
            return(tf.strings.format("{}",m))
    
    timestring = tf.strings.join([timeformat(hour),timeformat(minite),
                timeformat(second)],separator = ":")
    tf.print("=========="*8,end = "")
    tf.print(timestring)

将所有代码转化为单线程，而非调用函数

# Data processing

In [30]:
dataset = load_data_from_file("../fullfeature_fillmean_1000.txt")
# 数据类型转换
for key in dataset.keys():
    if key!='sequence_length':
        dataset[key] = dataset[key].astype(np.float32)
    print(key)
    print(dataset[key].shape)
    print(dataset[key].dtype)

previous_covariates
(3000, 160, 25)
float32
previous_treatments
(3000, 160, 3)
float32
covariates
(3000, 161, 25)
float32
treatments
(3000, 161, 3)
float32
sequence_length
(3000,)
int64
outcomes
(3000, 161, 1)
float32


In [31]:
shuffle_split = ShuffleSplit(n_splits=1, test_size=0.1, random_state=10)
train_index, test_index = next(shuffle_split.split(dataset['covariates'][:, :, 0]))
shuffle_split = ShuffleSplit(n_splits=1, test_size=0.11, random_state=10)
train_index, val_index = next(shuffle_split.split(dataset['covariates'][train_index, :, 0]))
dataset_map = get_dataset_splits(dataset, train_index, val_index, test_index, use_predicted_confounders=False)

# RMSN port

train_rmsn → rnn_fit → train (RNNModel-->training:get_training_graph-->validation:get_prediction_graph) → model_rnn

## train_rmsn

In [32]:
# def train_rmsn(dataset_map, model_name, b_use_predicted_confounders):###########################
# model_name = 'tf2_try'
# MODEL_ROOT = os.path.join('results', model_name)
MODEL_ROOT = 'results/rmsn_result_test_use_confounders_False'

# if not os.path.exists(MODEL_ROOT):
#     os.mkdir(MODEL_ROOT)
#     print("Directory ", MODEL_ROOT, " Created ")
# else:
#     # Need to delete previously saved model.
#     shutil.rmtree(MODEL_ROOT)
#     os.mkdir(MODEL_ROOT)
#     print("Directory ", MODEL_ROOT, " Created ")

# rnn_fit参数设置
networks_to_train='propensity_networks'
# networks_to_train='encoder'
b_use_predicted_confounders=False

    # rnn_fit(dataset_map=dataset_map, networks_to_train='propensity_networks', MODEL_ROOT=MODEL_ROOT,
    #         b_use_predicted_confounders=b_use_predicted_confounders)

#     propensity_generation(dataset_map=dataset_map, MODEL_ROOT=MODEL_ROOT,
#                           b_use_predicted_confounders=b_use_predicted_confounders)

#     rnn_fit(networks_to_train='encoder', dataset_map=dataset_map, MODEL_ROOT=MODEL_ROOT,
#             b_use_predicted_confounders=b_use_predicted_confounders)

#     rmsn_mse = rnn_test(dataset_map=dataset_map, MODEL_ROOT=MODEL_ROOT,
#                         b_use_predicted_confounders=b_use_predicted_confounders)

#     rmse = np.sqrt(np.mean(rmsn_mse)) * 100
    # return rmse


## rnn_fit

### 1.1 基础参数设置

In [33]:
specifications = {
     'rnn_propensity_weighted': (0.1, 4, 100, 64, 0.01, 1.0),
     'treatment_rnn_action_inputs_only': (0.1, 3, 100, 128, 0.01, 2.0),
     'treatment_rnn': (0.1, 4, 100, 64, 0.01, 1.0),}
# #####################################################################################
# def rnn_fit(dataset_map, networks_to_train, MODEL_ROOT, b_use_predicted_confounders,
#             b_use_oracle_confounders=False, b_remove_x1=False):
    
# Get the correct networks to train
if networks_to_train == "propensity_networks":
    logging.info("Training propensity networks")
    # net_names = ['treatment_rnn_action_inputs_only']
    net_names = ['treatment_rnn']

elif networks_to_train == "encoder":
    logging.info("Training R-MSN encoder")
    net_names = ["rnn_propensity_weighted"]

elif networks_to_train == "user_defined":
    logging.info("Training user defined network")
    raise NotImplementedError("Specify network to use!")

else:
    raise ValueError("Unrecognised network type")

    logging.info("Running hyperparameter optimisation")

# Experiment name
expt_name = "treatment_effects"

# Possible networks to use along with their activation functions
# change hidden layer of rnn_propensity_weighted to tanh
activation_map = {'rnn_propensity_weighted': ("tanh", 'linear'),
                  'rnn_propensity_weighted_logistic': ("elu", 'linear'),
                  'rnn_model': ("elu", 'linear'),
                  'treatment_rnn': ("tanh", 'sigmoid'),
                  'treatment_rnn_action_inputs_only': ("tanh", 'sigmoid')
                  }

    
# Start Running hyperparam opt
opt_params = {}
# for net_name in net_names:
net_name = net_names[0]
# Re-run hyperparameter optimisation if parameters are not specified, otherwise train with defined params（如果需要超参优化则跑3次，否则跑一次就行）
max_hyperparam_runs = 3 if net_name not in specifications else 1

# Pull datasets
b_predict_actions = "treatment_rnn" in net_name
use_truncated_bptt = net_name != "rnn_model_bptt" # whether to train with truncated backpropagation through time
b_propensity_weight = "rnn_propensity_weighted" in net_name
b_use_actions_only = "rnn_action_inputs_only" in net_name

### 1.2 gpu设置

In [34]:
# Setup tensorflow
# 检测 GPU 设备
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # set TensorFlow to use all GPU
        tf.config.set_visible_devices(gpus, 'GPU')
        for gpu in gpus:
            # set GPU memery growth
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Using GPU with memory growth")
    except RuntimeError as e:
        # Changing device settings after the program is running may cause errors
        print(e)
else:
    # if no GPU，using CPU
    print("No GPU found, using CPU")

Using GPU with memory growth


In [9]:
gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

### 1.3 模型参数配置

In [35]:
# Start hyperparamter optimisation
hyperparam_count = 0
# 删掉超参数优化（随机选取超参数）部分
spec = specifications[net_name]
logging.info("Using specifications for {}: {}".format(net_name, spec))
dropout_rate = spec[0]
memory_multiplier = spec[1]
num_epochs = spec[2]
minibatch_size = spec[3]
learning_rate = spec[4]
max_norm = spec[5]
hidden_activation, output_activation = activation_map[net_name]

model_folder = os.path.join(MODEL_ROOT, net_name)
            
# hyperparam_opt = train(net_name, expt_name,
#                       training_processed, validation_processed, test_processed,
#                       dropout_rate, memory_multiplier, num_epochs,
#                       minibatch_size, learning_rate, max_norm,
#                       use_truncated_bptt,
#                       num_features, num_outputs, model_folder,
#                       hidden_activation, output_activation,
#                       config,
#                       "hyperparam opt: {} of {}".format(hyperparam_count,
#                                                         max_hyperparam_runs))

#     hyperparam_count = len(hyperparam_opt.columns)
#     if hyperparam_count >= max_hyperparam_runs:
#         opt_params[net_name] = hyperparam_opt.T
#         break

# logging.info("Done")
# logging.info(hyperparam_opt.T)

# # Flag optimal params
# logging.info(opt_params)

#### 数据处理函数

In [36]:
def get_processed_data(raw_sim_data,
                       b_predict_actions,
                       b_use_actions_only,
                       b_use_predicted_confounders,
                       b_use_oracle_confounders,
                       b_remove_x1,
                       keep_first_point=False):
    """
    Create formatted data to train both propensity networks and seq2seq architecture

    :param raw_sim_data: Data from simulation
    :param scaling_params: means/standard deviations to normalise the data to
    :param b_predict_actions: flag to package data for propensity network to forecast actions
    :param b_use_actions_only:  flag to package data with only action inputs and not covariates
    :param b_predict_censoring: flag to package data to predict censoring locations
    :return: processed data to train specific network
    """
    horizon = 1
    offset = 1

    # Continuous values

    # Binary application
    treatments = raw_sim_data['treatments']
    covariates = raw_sim_data['covariates']
    dataset_outputs = raw_sim_data['outcomes']
    sequence_lengths = raw_sim_data['sequence_length']
    
    if b_use_predicted_confounders:
        predicted_confounders = raw_sim_data['predicted_confounders']

    if b_use_oracle_confounders:
        predicted_confounders = raw_sim_data['confounders']

    num_treatments = treatments.shape[-1]

    # Parcelling INPUTS
    if b_predict_actions:
        if b_use_actions_only:
            inputs = treatments
            inputs = inputs[:, :-offset, :]

            actions = inputs.copy()

        else:
            # Uses current covariate, to remove confounding effects between action and current value
            if (b_use_predicted_confounders):
                print ("Using predicted confounders")
                inputs = np.concatenate([covariates[:, 1:, ], predicted_confounders[:, 1:, ], treatments[:, :-1, ]],
                                        axis=2)
            else:
                inputs = np.concatenate([covariates[:, 1:,], treatments[:, :-1, ]], axis=2)

            actions = inputs[:, :, -num_treatments:].copy()


    else:
        if (b_use_predicted_confounders):
            inputs = np.concatenate([covariates, predicted_confounders, treatments], axis=2)
        else:
            inputs = np.concatenate([covariates, treatments], axis=2)
        
        if not keep_first_point:
            inputs = inputs[:, 1:, :]

        actions = inputs[:, :, -num_treatments:].copy()


    # Parcelling OUTPUTS
    if b_predict_actions:
        outputs = treatments
        outputs = outputs[:, 1:, :]

    else:
        if keep_first_point:
            outputs = dataset_outputs
        else:
            outputs = dataset_outputs[:, 1:, :]


    # Set array alignment
    sequence_lengths = np.array([i - 1 for i in sequence_lengths]) # everything shortens by 1

    # Remove any trajectories that are too short
    inputs = inputs[sequence_lengths > 0, :, :]
    outputs = outputs[sequence_lengths > 0, :, :]
    sequence_lengths = sequence_lengths[sequence_lengths > 0]
    actions = actions[sequence_lengths > 0, :, :]

    # Add active entires
    active_entries = np.zeros(outputs.shape, dtype=np.float32)

    for i in range(sequence_lengths.shape[0]):
        sequence_length = int(sequence_lengths[i])

        if not b_predict_actions:
            for k in range(horizon):
                #include the censoring point too, but ignore future shifts that don't exist
                active_entries[i, :sequence_length-k, k] = 1
        else:
            active_entries[i, :sequence_length, :] = 1

    return {'outputs': outputs,  # already scaled
            'scaled_inputs': inputs,
            'scaled_outputs': outputs,
            'actions': actions,
            'sequence_lengths': sequence_lengths,
            'active_entries': active_entries
            }


def convert_to_tf_dataset(dataset_map, minibatch_size):
    key_map = {'inputs': dataset_map['scaled_inputs'],
               'outputs': dataset_map['scaled_outputs'],
               'active_entries': dataset_map['active_entries'],
               'sequence_lengths': dataset_map['sequence_lengths']}

    if 'propensity_weights' in dataset_map:
        key_map['propensity_weights'] = dataset_map['propensity_weights']

    if 'initial_states' in dataset_map:
        key_map['initial_states'] = dataset_map['initial_states']

    #from_tensor_slices:切片; shuffle:随机打乱; batch:批次组合; prefetch:提前准备（预取）数据
    tf_dataset = tf.data.Dataset.from_tensor_slices(key_map)\
                .shuffle(buffer_size = 1000).batch(minibatch_size) \
                .prefetch(tf.data.experimental.AUTOTUNE)

    return tf_dataset

### 1.4 数据准备

In [37]:
training_data = dataset_map['training_data']
validation_data = dataset_map['validation_data']
test_data = dataset_map['test_data']

 # Extract only relevant trajs and shift data
b_use_oracle_confounders = False; b_remove_x1 = False
training_processed = get_processed_data(training_data, b_predict_actions,
                                             b_use_actions_only, b_use_predicted_confounders,
                                             b_use_oracle_confounders, b_remove_x1)
validation_processed = get_processed_data(validation_data, b_predict_actions,
                                               b_use_actions_only, b_use_predicted_confounders,
                                               b_use_oracle_confounders, b_remove_x1)
test_processed = get_processed_data(test_data, b_predict_actions,
                                         b_use_actions_only, b_use_predicted_confounders,
                                         b_use_oracle_confounders, b_remove_x1)

num_features = training_processed['scaled_inputs'].shape[-1]
num_outputs = training_processed['scaled_outputs'].shape[-1]



In [38]:
# Load propensity weights if they exist
if b_propensity_weight:
    if net_name == 'rnn_propensity_weighted_den_only':
        # use un-stabilised IPTWs generated by propensity networks
        propensity_weights = np.load(os.path.join(MODEL_ROOT, "propensity_scores_den_only.npy"))
    elif net_name == "rnn_propensity_weighted_logistic":
        # Use logistic regression weights
        propensity_weights = np.load(os.path.join(MODEL_ROOT, "propensity_scores.npy"))
        tmp = np.load(os.path.join(MODEL_ROOT, "propensity_scores_logistic.npy"))
        propensity_weights = tmp[:propensity_weights.shape[0], :, :]
    else:
        # use stabilised IPTWs generated by propensity networks
        propensity_weights = np.load(os.path.join(MODEL_ROOT, "propensity_scores.npy"))

    logging.info("Net name = {}. Mean-adjusting!".format(net_name))

    propensity_weights /= propensity_weights.mean()

    training_processed['propensity_weights'] = np.array(propensity_weights, dtype='float32')
    

In [39]:
# transform to tensorflow format
tf_data_train = convert_to_tf_dataset(training_processed, minibatch_size)
tf_data_valid = convert_to_tf_dataset(validation_processed, minibatch_size)
tf_data_test = convert_to_tf_dataset(test_processed, minibatch_size)

In [19]:
tf_data_train

<_PrefetchDataset element_spec={'inputs': TensorSpec(shape=(None, 160, 28), dtype=tf.float32, name=None), 'outputs': TensorSpec(shape=(None, 160, 3), dtype=tf.float32, name=None), 'active_entries': TensorSpec(shape=(None, 160, 3), dtype=tf.float32, name=None), 'sequence_lengths': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>

## core routine

### def_train修改
1. 去除session会话
2. 各步骤拆分与重构

#### 2.1 参数指定

In [None]:
# def train(net_name,
#           expt_name,
#           training_dataset, validation_dataset, test_dataset,
#           dropout_rate,
#           memory_multiplier,
#           num_epochs,
#           minibatch_size,
#           learning_rate,
#           max_norm,
#           use_truncated_bptt,
#           num_features,
#           num_outputs,
#           model_folder,
#           hidden_activation,
#           output_activation,
#           tf_config,
#           additonal_info="",
#           b_use_state_initialisation=False,
#           b_use_seq2seq_feedback=False,
#           b_use_seq2seq_training_mode=False,
#           adapter_multiplier=0,
#           b_use_memory_adapter=False,
#           verbose=True):

In [40]:
min_epochs = 1
hidden_layer_size = int(memory_multiplier * num_features)

b_use_state_initialisation = False
if b_use_state_initialisation:
    full_state_size = int(training_dataset['initial_states'].shape[-1])
    adapter_size = adapter_multiplier * full_state_size
else:
    adapter_size = 0
    
model_parameters = {'net_name': net_name,
                    'experiment_name': expt_name,
                    'training_dataset': tf_data_train,
                    'validation_dataset': tf_data_valid,
                    'test_dataset': tf_data_test,
                    'dropout_rate': dropout_rate,
                    'input_size': num_features,
                    'output_size': num_outputs,
                    'hidden_layer_size': hidden_layer_size,
                    'num_epochs': 10, # for test
                    'minibatch_size': minibatch_size,
                    'learning_rate': learning_rate,
                    'max_norm': max_norm,
                    'model_folder': model_folder,
                    'hidden_activation': hidden_activation,
                    'output_activation': output_activation,
                    'backprop_length': 60,  # backprop over 60 timesteps for truncated backpropagation through time
                    'softmax_size': 0, #not used in this paper, but allows for categorical actions
                    'performance_metric': 'xentropy' if output_activation == 'sigmoid' else 'mse'}

#### 2.2 定义模型

In [41]:
def create_model(params):
    
    # Data params
    #training_data = None if 'training_dataset' not in params else params['training_dataset']
    #validation_data = None if 'validation_dataset' not in params else params['validation_dataset']
    #test_data = None if 'test_dataset' not in params else params['test_dataset']
    input_size = params['input_size']
    output_size = params['output_size']

    # Network params
    net_name = params['net_name']
    softmax_size = params['softmax_size']
    dropout_rate = params['dropout_rate']
    hidden_layer_size = params['hidden_layer_size']
    memory_activation_type = params['hidden_activation']
    output_activation_type = params['output_activation']
    #initial_states = None
    # input layer
    inputs = layers.Input(shape=(None,input_size), dtype=tf.float32)
    # define initial states 
    initial_h =layers.Input(shape=(hidden_layer_size,), dtype=tf.float32, name='initial_h')
    initial_c =layers.Input(shape=(hidden_layer_size,), dtype=tf.float32, name='initial_c')

    # LSTM layer
    lstm, state_h, state_c = layers.LSTM(hidden_layer_size, activation=memory_activation_type, 
                       return_sequences=True, return_state=True, dropout=dropout_rate)(inputs, initial_state=[initial_h, initial_c])

    # flattened_lstm = layers.Flatten()(lstm)

    # Seq2Seq(if need)
    use_seq2seq_feedback = False
    if use_seq2seq_feedback:
        logits = lstm
    else:
        # linear output layer
        logits = layers.Dense(output_size)(lstm)

    # Softmax
    if softmax_size != 0:
        logits_reshaped = layers.Reshape((-1, output_size))(logits)
        core_outputs, softmax_outputs = tf.split(logits_reshaped, [output_size - softmax_size, softmax_size], axis=-1)
        core_activated = layers.Activation(output_activation_type)(core_outputs)
        softmax_activated = layers.Softmax(axis=-1)(softmax_outputs)
        outputs = layers.Concatenate(axis=-1)([core_activated, softmax_activated])
    else:
        outputs = layers.Activation(output_activation_type)(logits)

    # construct model
    model = models.Model(inputs=[inputs, initial_h, initial_c], outputs=[outputs, state_h, state_c], name=net_name)
    return model

In [42]:
# 创建模型
tf.keras.backend.clear_session()
model = create_model(model_parameters)

model.summary()

Model: "treatment_rnn"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None, 28)]           0         []                            
                                                                                                  
 initial_h (InputLayer)      [(None, 112)]                0         []                            
                                                                                                  
 initial_c (InputLayer)      [(None, 112)]                0         []                            
                                                                                                  
 lstm (LSTM)                 [(None, None, 112),          63168     ['input_1[0][0]',             
                              (None, 112),                           'initial_h[0][0]'

In [45]:
model_parameters['input_size']

28

In [None]:
    """
    Common training routine to all RNN models_without_confounders - seq2seq + standard
    """

    min_epochs = 1
    # Setup default hidden layer size
    hidden_layer_size = int(memory_multiplier * num_features)

    if b_use_state_initialisation:

        full_state_size = int(training_dataset['initial_states'].shape[-1])

        adapter_size = adapter_multiplier * full_state_size

    else:
        adapter_size = 0

        # Training simulation
    model_parameters = {'net_name': net_name,
                        'experiment_name': expt_name,
                        'training_dataset': tf_data_train,
                        'validation_dataset': tf_data_valid,
                        'test_dataset': tf_data_test,
                        'dropout_rate': dropout_rate,
                        'input_size': num_features,
                        'output_size': num_outputs,
                        'hidden_layer_size': hidden_layer_size,
                        'num_epochs': num_epochs,
                        'minibatch_size': minibatch_size,
                        'learning_rate': learning_rate,
                        'max_norm': max_norm,
                        'model_folder': model_folder,
                        'hidden_activation': hidden_activation,
                        'output_activation': output_activation,
                        'backprop_length': 60,  # backprop over 60 timesteps for truncated backpropagation through time
                        'softmax_size': 0, #not used in this paper, but allows for categorical actions
                        'performance_metric': 'xentropy' if output_activation == 'sigmoid' else 'mse',
                        'use_seq2seq_feedback': b_use_seq2seq_feedback,
                        'use_seq2seq_training_mode': b_use_seq2seq_training_mode,
                        'use_memory_adapter': b_use_memory_adapter,
                        'memory_adapter_size': adapter_size}

    # Get the right model
    model = RnnModel(model_parameters)
    serialisation_name = model.serialisation_name

    if helpers.hyperparameter_result_exists(model_folder, net_name, serialisation_name):
        logging.warning("Combination found: skipping {}".format(serialisation_name))
        return helpers.load_hyperparameter_results(model_folder, net_name)

    training_handles = model.get_training_graph(use_truncated_bptt=use_truncated_bptt,
                                                b_use_state_initialisation=b_use_state_initialisation)
    validation_handles = model.get_prediction_graph(use_validation_set=True, with_dropout=False,
                                                    b_use_state_initialisation=b_use_state_initialisation)

    # Start optimising
    num_minibatches = int(np.ceil(training_dataset['scaled_inputs'].shape[0] / model_parameters['minibatch_size']))

    i = 1
    epoch_count = 1
    step_count = 1
    min_loss = np.inf
    with sess.as_default():

        sess.run(tf.global_variables_initializer())

        optimisation_summary = pd.Series([])

        while True:
            #for step_count in tqdm(range(num_minibatches), desc=f"Epoch {epoch_count}"):
            try:
                # loss, _ = sess.run([training_handles['loss'],
                #                     training_handles['optimiser']])
                loss, _, numerator = sess.run([training_handles['loss'],
                                    training_handles['optimiser'],
                                    training_handles['numerator']])

                # rango added - tensorflow debugger 2023.10.20###################
                # if sess.should_stop():
                #     break  # NaN or Inf occurred.
                # ###############################################################

                # Flog output
                if (verbose == True):
                    logging.info("Epoch {} | iteration = {} of {}, loss = {} | loss_numerator = {} | net = {} | info = {}".format(
                        epoch_count,
                        step_count,
                        num_minibatches,
                        loss,
                        numerator,
                        model.net_name,
                        additonal_info))

                if step_count == num_minibatches:

                    # Reinit datasets
                    sess.run(validation_handles['initializer'])

                    means = []
                    UBs = []
                    LBs = []
                    while True:
                        try:
                            mean, upper_bound, lower_bound = sess.run([validation_handles['mean'],
                                                                       validation_handles['upper_bound'],
                                                                       validation_handles['lower_bound']])

                            means.append(mean)
                            UBs.append(upper_bound)
                            LBs.append(lower_bound)
                        except tf.errors.OutOfRangeError:
                            break

                    means = np.concatenate(means, axis=0)

                    """
                    means = np.concatenate(means, axis=0)*training_dataset['output_stds'] \
                            + training_dataset['output_means']
                    UBs = np.concatenate(UBs, axis=0)*training_dataset['output_stds'] \
                          + training_dataset['output_means']
                    LBs = np.concatenate(LBs, axis=0)*training_dataset['output_stds'] \
                          + training_dataset['output_means']
                    """


                    active_entries = validation_dataset['active_entries']
                    output = validation_dataset['outputs']

                    if model_parameters['performance_metric'] == "mse":
                        validation_loss = np.sum((means - output)**2 * active_entries) / np.sum(active_entries)
                        #logging.info("Epoch {} Detection| Means= {} | Output = {}".format(epoch_count, means, output))

                    elif model_parameters['performance_metric'] == "xentropy":
                        _, _,features_size = output.shape
                        partition_idx = features_size

                        # Do binary first
                        validation_loss = np.sum((output[:, :, :partition_idx] * -np.log(means[:, :, :partition_idx] + 1e-8)
                                                 + (1 - output[:, :, :partition_idx]) * -np.log(1 - means[:, :, :partition_idx] + 1e-8))
                                                 * active_entries[:, :, :partition_idx]) \
                                          / np.sum(active_entries[:, :, :partition_idx])

                    optimisation_summary[epoch_count] = validation_loss

                    # Compute validation loss
                    if (verbose == True):
                        logging.info("Epoch {} Summary| Validation loss = {} | net = {} | info = {}".format(
                            epoch_count,
                            validation_loss,
                            model.net_name,
                            additonal_info))

                    if np.isnan(validation_loss):
                        logging.warning("NAN Loss found, terminating routine")
                        break

                    # Save model and loss trajectories
                    if validation_loss < min_loss and epoch_count > min_epochs:
                        cp_name = serialisation_name + "_optimal"
                        helpers.save_network(sess, model_folder, cp_name, optimisation_summary)
                        min_loss = validation_loss

                    # Update
                    epoch_count += 1
                    step_count = 0

                step_count += 1
                i += 1

            except tf.errors.OutOfRangeError:
                break

        # Save final
        cp_name = serialisation_name + "_final"
        helpers.save_network(sess, model_folder, cp_name, optimisation_summary)
        helpers.add_hyperparameter_results(optimisation_summary, model_folder, net_name, serialisation_name)

        hyperparam_df = helpers.load_hyperparameter_results(model_folder, net_name)

        logging.info("Terminated at iteration {}".format(i))

    return hyperparam_df

### 2.3 训练模型

#### 自定义损失函数

In [21]:
class CustomLoss(losses.Loss):
    def __init__(self, params, name="custom_loss"):
        super().__init__(name=name)
        self.performance_metric = params['performance_metric']
        # self.weights = params['weights']
        # self.active_entries = params['active_entries']

    def train_call(self, y_true, y_pred, active_entries, weights):
        if self.performance_metric == "mse":
            loss = tf.reduce_sum(tf.square(y_true - y_pred) * active_entries * weights) \
                   / tf.reduce_sum(active_entries)
        elif self.performance_metric == "xentropy":
            loss = tf.reduce_sum((y_true * -tf.math.log(y_pred + 1e-8) +
                                  (1 - y_true) * -tf.math.log(1 - y_pred + 1e-8))
                                  * active_entries * weights) / tf.reduce_sum(active_entries)
        else:
            raise ValueError("Unknown performance metric {}".format(self.performance_metric))
        return loss
    
    def valid_call(self, y_true, y_pred, active_entries):
        if self.performance_metric == "mse":
            loss = tf.reduce_sum(tf.square(y_true - y_pred) * active_entries ) \
                   / tf.reduce_sum(active_entries)
        elif self.performance_metric == "xentropy":
            loss = tf.reduce_sum((y_true * -tf.math.log(y_pred + 1e-8) +
                                  (1 - y_true) * -tf.math.log(1 - y_pred + 1e-8))
                                  * active_entries) / tf.reduce_sum(active_entries)
        else:
            raise ValueError("Unknown performance metric {}".format(self.performance_metric))
        return loss
    
    def get_config(self):
        config = super().get_config()
        config.update({"performance_metric": self.performance_metric})
        return config

#### 训练步骤和验证步骤

In [32]:
optimizer = optimizers.Adam(learning_rate=model_parameters['learning_rate'])
loss_func = CustomLoss(model_parameters)

train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.MeanSquaredError(name='train_mse')

valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.MeanSquaredError(name='valid_mse')

@tf.function
def train_step(model, inputs, outputs, active_entries, weights):
    if weights is None:
        weights = tf.constant(1.0)
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_func.train_call(outputs, predictions, active_entries, weights)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(outputs, predictions)

@tf.function
def valid_step(model, inputs, outputs, active_entries):
    predictions = model(inputs, training=False)
    loss = loss_func.valid_call(outputs, predictions, active_entries)
    valid_loss.update_state(loss)
    valid_metric.update_state(outputs, predictions)

#### 训练模型函数

In [33]:
def train_model(model, ds_train, ds_valid, epochs):
    # optimisation_summary = pd.Series([])
    for epoch in tf.range(1, epochs+1):
        
        for data in ds_train:
            weights = data['propensity_weights'] if 'propensity_weights' in data else None
            train_step(model, data['inputs'], data['outputs'], data['active_entries'], weights)

        for data in ds_valid:
            weights = data['propensity_weights'] if 'propensity_weights' in data else None
            valid_step(model, data['inputs'], data['outputs'], data['active_entries'])
        # optimisation_summary[epoch] = valid_loss

        # 同样的日志和状态重置操作
        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'
        
        if epoch%1 ==0:
            printbar()
            tf.print(tf.strings.format(logs,
            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))
            tf.print("")
            
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

In [38]:
class TrainModule(tf.Module):
    def __init__(self, params, name=None):
        super(TrainModule, self).__init__(name=name)
        with self.name_scope:  #相当于with tf.name_scope("demo_module")
            self.epochs = params['num_epochs']
            self.ds_train = params['training_dataset']
            self.ds_valid = params['validation_dataset']
            self.ds_test = params['test_dataset']
            self.optimizer = optimizers.Adam(learning_rate=params['learning_rate'])
            self.loss_func = CustomLoss(params)

            self.train_loss = metrics.Mean(name='train_loss')
            self.train_metric = metrics.MeanSquaredError(name='train_mse')

            self.valid_loss = metrics.Mean(name='valid_loss')
            self.valid_metric = metrics.MeanSquaredError(name='valid_mse')
    
    @tf.function
    def train_step(self, model, inputs, outputs, active_entries, weights):
        if weights is None:
            weights = tf.constant(1.0)
        with tf.GradientTape() as tape:
            predictions = model(inputs, training=True)
            loss = self.loss_func.train_call(outputs, predictions, active_entries, weights)
        gradients = tape.gradient(loss, model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        self.train_loss.update_state(loss)
        self.train_metric.update_state(outputs, predictions)

    @tf.function
    def valid_step(self, model, inputs, outputs, active_entries):
        predictions = model(inputs, training=False)
        loss = self.loss_func.valid_call(outputs, predictions, active_entries)
        self.valid_loss.update_state(loss)
        self.valid_metric.update_state(outputs, predictions)
    
    def train_model(self, model):
        # optimisation_summary = pd.Series([])
        for epoch in tf.range(1, self.epochs+1):

            for data in self.ds_train:
                weights = data['propensity_weights'] if 'propensity_weights' in data else None
                self.train_step(model, data['inputs'], data['outputs'], data['active_entries'], weights)

            for data in self.ds_valid:
                weights = data['propensity_weights'] if 'propensity_weights' in data else None
                self.valid_step(model, data['inputs'], data['outputs'], data['active_entries'])
            # optimisation_summary[epoch] = valid_loss

            # 同样的日志和状态重置操作
            logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

            if epoch%1 ==0:
                printbar()
                tf.print(tf.strings.format(logs,
                (epoch,self.train_loss.result(),self.train_metric.result(),self.valid_loss.result(),self.valid_metric.result())))
                tf.print("")

            self.train_loss.reset_states()
            self.valid_loss.reset_states()
            self.train_metric.reset_states()
            self.valid_metric.reset_states()

In [39]:
Train = TrainModule(model_parameters)
Train.train_model(model)

Epoch=1,Loss:0.499527872,Accuracy:0.550320148,Valid Loss:0.516167939,Valid Accuracy:0.517898262

Epoch=2,Loss:0.475318611,Accuracy:0.527793229,Valid Loss:0.506939232,Valid Accuracy:0.508185446

Epoch=3,Loss:0.465685338,Accuracy:0.526379406,Valid Loss:0.510615706,Valid Accuracy:0.511348367

Epoch=4,Loss:0.45916155,Accuracy:0.529349089,Valid Loss:0.516534448,Valid Accuracy:0.510599077

Epoch=5,Loss:0.454828203,Accuracy:0.528462529,Valid Loss:0.516691566,Valid Accuracy:0.520721555

Epoch=6,Loss:0.482320428,Accuracy:0.559840798,Valid Loss:0.525589466,Valid Accuracy:0.525901377

Epoch=7,Loss:0.468005776,Accuracy:0.529510081,Valid Loss:0.542681396,Valid Accuracy:0.534254

Epoch=8,Loss:0.457480669,Accuracy:0.533512414,Valid Loss:0.51867938,Valid Accuracy:0.512872338

Epoch=9,Loss:0.450703472,Accuracy:0.534296751,Valid Loss:0.513790131,Valid Accuracy:0.515044034

Epoch=10,Loss:0.443599194,Accuracy:0.533094049,Valid Loss:0.527673125,Valid Accuracy:0.528042316



In [32]:
net_names = ['model1','model2']
for name in net_names:
    tf.keras.backend.clear_session()
    model = create_model(model_parameters, name)
    # train_model(model, tf_data_train, tf_data_valid, 10)
    Train = TrainModule(model_parameters)
    Train.train_model(model)

Epoch=1,Loss:-1.04218924,Accuracy:0.315503478,Valid Loss:-1.6271013,Valid Accuracy:0.256319433

Epoch=2,Loss:-1.8213017,Accuracy:0.25058189,Valid Loss:-1.92948687,Valid Accuracy:0.243127644

Epoch=3,Loss:-1.99545646,Accuracy:0.245146066,Valid Loss:-2.00789499,Valid Accuracy:0.237991899

Epoch=4,Loss:-2.0445447,Accuracy:0.242646903,Valid Loss:-2.07682967,Valid Accuracy:0.236074388

Epoch=5,Loss:-2.07054353,Accuracy:0.241571859,Valid Loss:-2.07036352,Valid Accuracy:0.23433511

Epoch=6,Loss:-2.08658314,Accuracy:0.24051398,Valid Loss:-2.02451658,Valid Accuracy:0.234643

Epoch=7,Loss:-2.08822966,Accuracy:0.239788935,Valid Loss:-2.04530382,Valid Accuracy:0.233269155

Epoch=8,Loss:-2.10269213,Accuracy:0.239260048,Valid Loss:-2.09449148,Valid Accuracy:0.231957987

Epoch=9,Loss:-2.10704803,Accuracy:0.238550946,Valid Loss:-2.07969904,Valid Accuracy:0.231935427

Epoch=10,Loss:-2.12130713,Accuracy:0.238038853,Valid Loss:-2.11661029,Valid Accuracy:0.231317282

Epoch=1,Loss:-1.21397805,Accuracy:0.29

### 2.4 保存和加载模型

In [44]:
path = 'results/tf2_try/savedmodel_try'
model.save(path, save_format = 'tf')
print('export saved model.')





INFO:tensorflow:Assets written to: results/tf2_try/savedmodel_try/assets


INFO:tensorflow:Assets written to: results/tf2_try/savedmodel_try/assets


export saved model.


In [45]:
model_loaded = models.load_model(path, compile=False)

### 2.5 使用模型

In [38]:
# 自定义模型评估函数
def evaluate_model(model, ds_test):
    total_loss = 0
    total_metric = 0
    num_batches = 0

    # 遍历数据集中的每个批次
    for data in ds_test:
        valid_step(model, data['inputs'], data['outputs'], data['active_entries'])
        total_loss += valid_loss.result().numpy()
        total_metric += valid_metric.result().numpy()
        num_batches += 1

        # 重置状态
        valid_loss.reset_states()
        valid_metric.reset_states()

    # 计算整个数据集的平均损失和指标
    avg_loss = total_loss / num_batches
    avg_metric = total_metric / num_batches
    return avg_loss, avg_metric

In [46]:
loss, accuracy = evaluate_model(model, tf_data_test)
print(f"Test Loss: {loss}, Test Accuracy: {accuracy}")
input_data_for_prediction = tf_data_test.map(lambda x: x['inputs'])
predictions = model_loaded.predict(input_data_for_prediction)

Test Loss: -2.254194164276123, Test Accuracy: 0.22184113562107086


In [47]:
predictions.shape

(300, 160, 3)

## Propensitry Generation
def propensity_generation(dataset_map, MODEL_ROOT, b_use_predicted_confounders, b_use_all_data=False,
                          b_use_oracle_confounders=False, b_remove_x1=False):

In [15]:
import rmsn.libs.model_process as model_process

In [16]:
# 参数设置
action_inputs_only = model_process.load_optimal_parameters(net_name='treatment_rnn_action_inputs_only', MODEL_ROOT=MODEL_ROOT)
action_w_trajectory_inputs = model_process.load_optimal_parameters(net_name='treatment_rnn', MODEL_ROOT=MODEL_ROOT)

# Generate propensity weights for validation data as well - used for MSM which is calibrated on train + valid data
b_with_validation = False
# Generate non-stabilised IPTWs (default false)
b_denominator_only = False
b_use_predicted_confounders = True
b_use_all_data=False
b_use_oracle_confounders=False
b_remove_x1=False

# Config + activation functions
activation_map = {'rnn_propensity_weighted': ("elu", 'linear'),
                  'rnn_model': ("elu", 'linear'),
                  'rnn_model_bptt': ("elu", 'linear'),
                  'treatment_rnn': ("tanh", 'sigmoid'),
                  'treatment_rnn_action_inputs_only': ("tanh", 'sigmoid'),
                  'treatment_rnn_softmax': ("tanh", 'sigmoid'),
                  'treatment_rnn_action_inputs_only_softmax': ("tanh", 'sigmoid'),
                  }

configs = {'action_num': action_inputs_only,
           'action_den': action_w_trajectory_inputs}

results/rmsn_result_test_use_confounders_True/treatment_rnn_action_inputs_only/treatment_rnn_action_inputs_only.csv
results/rmsn_result_test_use_confounders_True/treatment_rnn/treatment_rnn.csv


In [17]:
if b_use_all_data:
    training_data = dataset_map
    validation_data = dataset_map
    test_data = None
else:
    training_data = dataset_map['training_data']
    validation_data = dataset_map['validation_data']
    test_data = dataset_map['test_data']
    
if b_with_validation:
    for k in training_data:
        training_data[k] = np.concatenate([training_data[k], validation_data[k]])

In [25]:
# Functions
def get_predictions(config):

    net_name = config[0]
    serialisation_name = config[-1]

    hidden_activation, output_activation = activation_map[net_name]

    # Pull datasets
    b_predict_actions = "treatment_rnn" in net_name
    b_use_actions_only = "rnn_action_inputs_only" in net_name

    # Extract only relevant trajs and shift data
    training_processed = get_processed_data(training_data, b_predict_actions, b_use_actions_only,
                                                 b_use_predicted_confounders, b_use_oracle_confounders, b_remove_x1)
    validation_processed = get_processed_data(validation_data, b_predict_actions,
                                                   b_use_actions_only,
                                                   b_use_predicted_confounders, b_use_oracle_confounders, b_remove_x1)
    # rango added 23.10.24
    # if b_with_test:
    #     test_processed = core.get_processed_data(test_data, b_predict_actions, b_use_actions_only,
    #                                              b_use_predicted_confounders, b_use_oracle_confounders, b_remove_x1)

    num_features = training_processed['scaled_inputs'].shape[-1]  # 4 if not b_use_actions_only else 3
    num_outputs = training_processed['scaled_outputs'].shape[-1]  # 1 if not b_predict_actions else 3  # 5


    # Unpack remaining variables
    dropout_rate = config[1]
    memory_multiplier = config[2] / num_features
    num_epochs = config[3]
    minibatch_size = config[4]
    learning_rate = config[5]
    max_norm = config[6]
    tf_data_train = convert_to_tf_dataset(training_processed, minibatch_size)
    tf_data_valid = convert_to_tf_dataset(validation_processed, minibatch_size)

    model_folder = os.path.join(MODEL_ROOT, net_name)
    model = model_process.load_model(model_folder, serialisation_name)

    # predictition
    outputs = training_processed['scaled_outputs']
    results = model_predict(model, tf_data_train)
    predictions = results['mean_pred']
    #means, outputs, _, _ = test(training_processed, validation_processed, training_processed, tf_config,
    #                            net_name, expt_name, dropout_rate, num_features, num_outputs,
    #                            memory_multiplier, num_epochs, minibatch_size, learning_rate, max_norm,
    #                            hidden_activation, output_activation, model_folder)

    return predictions, outputs

def get_weights(probs, targets):
    w = probs*targets + (1-probs) * (1-targets)
    return w.prod(axis=2)


def get_weights_from_config(config):
    net_name = config[0]

    probs, targets = get_predictions(config)

    return get_weights(probs, targets)

def get_probabilities_from_config(config):
    net_name = config[0]

    probs, targets = get_predictions(config)

    return probs

In [38]:
from tqdm import tqdm
def model_predict(model, dataset, pred_times=100):
    # Initialize lists to store final statistics for all chunks
    all_means = []
    all_upper_bounds = []
    all_lower_bounds = []
    logs = 'Predicting ' + model.name

    for data_chunk in tqdm(dataset, desc=logs):
        chunk_predictions = []

        # Predict the current chunk multiple times
        for _ in range(pred_times):
            prediction = model.predict(data_chunk['inputs'], verbose=0)
            chunk_predictions.append(prediction)

        # Convert list of predictions to a numpy array for statistical computation
        chunk_predictions = np.array(chunk_predictions)

        # Calculate mean, upper bound, and lower bound for the current chunk
        mean_estimate = np.mean(chunk_predictions, axis=0)
        upper_bound = np.percentile(chunk_predictions, 95, axis=0)
        lower_bound = np.percentile(chunk_predictions, 5, axis=0)

        # Append the statistics of the current chunk to their respective lists
        all_means.append(mean_estimate)
        all_upper_bounds.append(upper_bound)
        all_lower_bounds.append(lower_bound)

    # Optional: Convert lists to numpy arrays if further processing is needed
    all_means = np.concatenate(all_means, axis=0) if all_means else np.array([])
    all_upper_bounds = np.concatenate(all_upper_bounds, axis=0) if all_upper_bounds else np.array([])
    all_lower_bounds = np.concatenate(all_lower_bounds, axis=0) if all_lower_bounds else np.array([])

    # At this point, you can either return the raw statistics for each chunk,
    # or aggregate them in some way depending on your application's needs.
    # The following returns the list of statistics for all chunks.
    return {
        'mean_pred': all_means,
        'upper_bound': all_upper_bounds,
        'lower_bound': all_lower_bounds
    }

In [39]:
# Action with trajs
weights = {k: get_weights_from_config(configs[k]) for k in configs}

den = weights['action_den']
num = weights['action_num']

propensity_weights = 1.0/den if b_denominator_only else num/den

Predicting treatment_rnn_action_inputs_only:  12%|█▏        | 18/151 [02:32<18:46,  8.47s/it]


KeyboardInterrupt: 

## Mirrored Strategy Test

### 0. 初始化策略

In [15]:
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: %d' % strategy.num_replicas_in_sync)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


Number of devices: 2


### 1. 设置输入流水线

In [18]:
# transform to tensorflow format
global_batch_size = minibatch_size * strategy.num_replicas_in_sync
tf_data_train = convert_to_tf_dataset(training_processed, global_batch_size)
dist_tf_data_train = strategy.experimental_distribute_dataset(tf_data_train)

2024-03-05 15:03:47.622023: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:786] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_4"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
input: "Placeholder/_2"
input: "Placeholder/_3"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
      type: DT_FLOAT
      type: DT_INT64
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 2403
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\025TensorSliceDataset:22"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 160
        }
        dim {
          size: 3
        }
      }
      shape {
        dim {
          size: 160
        }
        dim {
          size: 28
       

In [26]:
def inspect_dataset(batch):
    # 这里的内容根据你的数据集结构调整
    # 例如，打印出批次的形状或一些关键数据
    for key, value in batch.items():
        # 打印出每个键对应的值的形状
        print(f"{key}: {value.shape}")

for dist_batch in dist_tf_data_train:
    strategy.run(inspect_dataset, args=(dist_batch,))





inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)




inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)




inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)




inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)




inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs: (64, 160, 3)
active_entries: (64, 160, 3)
sequence_lengths: (64,)
inputs: (64, 160, 28)
outputs:

In [20]:
net_name

'treatment_rnn'

### 2. 定义损失函数

In [21]:
with strategy.scope():
    # test loss function ###################################
    mse_loss_object = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    
    def compute_mse_loss(labels, predictions):
        # 计算每个样本的MSE损失
        per_example_loss = mse_loss_object(labels, predictions)
        # 计算所有样本的平均MSE损失，并根据全局批量大小进行调整
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
    
    # custom loss function ##################################
    class CustomLoss(losses.Loss):
        def __init__(self, performance_metric, num_gpus, global_batch_size, name="custom_loss"):
            super().__init__(name=name) #reduction=losses.Reduction.NONE
            self.performance_metric = performance_metric
            self.num_gpus = num_gpus
            self.global_batch_size = global_batch_size
            # self.weights = params['weights']
            # self.active_entries = params['active_entries']

        def train_call(self, y_true, y_pred, active_entries, weights):
            if self.performance_metric == "mse":
                loss = tf.reduce_sum(tf.square(y_true - y_pred) * active_entries * weights) \
                       / tf.reduce_sum(active_entries)
                # per_example_loss = (tf.square(y_true - y_pred) * active_entries * weights) \
                #                     / tf.reduce_sum(active_entries)
            elif self.performance_metric == "xentropy":
                loss = tf.reduce_sum((y_true * -tf.math.log(y_pred + 1e-8) +
                                       (1 - y_true) * -tf.math.log(1 - y_pred + 1e-8))
                                       * active_entries * weights) / tf.reduce_sum(active_entries)
                # per_example_loss = ((y_true * -tf.math.log(y_pred + 1e-8) + \
                #                    (1 - y_true) * -tf.math.log(1 - y_pred + 1e-8)) * active_entries * weights) / tf.reduce_sum(active_entries)

            else:
                raise ValueError("Unknown performance metric {}".format(self.performance_metric))

            # 将总和除以gpu数，获得全局平均损失
            return loss * (1./self.num_gpus)
            # return tf.nn.compute_average_loss(per_example_loss, global_batch_size=self.global_batch_size)

        def valid_call(self, y_true, y_pred):
            if self.performance_metric == "mse":
               #loss = tf.reduce_sum(tf.square(y_true - y_pred) * active_entries ) \
               #        / tf.reduce_sum(active_entries)
                loss = tf.square(y_true - y_pred)

            elif self.performance_metric == "xentropy":
                loss = (y_true * -tf.math.log(y_pred + 1e-8) +
                       (1 - y_true) * -tf.math.log(1 - y_pred + 1e-8))

            else:
                raise ValueError("Unknown performance metric {}".format(self.performance_metric))

            return loss

        def get_config(self):
            config = super().get_config()
            config.update({"performance_metric": self.performance_metric, "global_batch_size": self.global_batch_size})
            return config
    
    # loss_func = CustomLoss(model_parameters['performance_metric'], strategy.num_replicas_in_sync, global_batch_size)

### 3. 定义衡量指标以跟踪损失和准确性

In [22]:
with strategy.scope():
    train_metric = metrics.MeanSquaredError(name='train_mse')
    valid_loss = metrics.Mean(name='valid_loss')
    valid_metric = metrics.MeanSquaredError(name='valid_mse')

### 4. 训练循环

In [25]:
# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
    model = create_model(model_parameters)
    model.summary()
    optimizer = tf.keras.optimizers.Adam()

Model: "treatment_rnn"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None, 28)]           0         []                            
                                                                                                  
 initial_h (InputLayer)      [(None, 112)]                0         []                            
                                                                                                  
 initial_c (InputLayer)      [(None, 112)]                0         []                            
                                                                                                  
 lstm (LSTM)                 [(None, None, 112),          63168     ['input_1[0][0]',             
                              (None, 112),                           'initial_h[0][0]'

In [24]:
def train_step(data): #, chunk_sizes
    inputs = data['inputs']
    outputs = data['outputs']
    active_entries = data['active_entries']
    weights = data['propensity_weights'] if 'propensity_weights' in data else tf.constant(1.0)

    with tf.GradientTape() as tape:
       
        batch_size = tf.shape(inputs)[0]
        initial_state = tf.zeros([batch_size, hidden_layer_size], dtype=tf.float32)
        predictions,_,_ = model([inputs,initial_state, initial_state], training=True)
        # Compute loss
        # loss = loss_func.train_call(outputs, predictions, active_entries, weights)
        loss = compute_mse_loss(outputs, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    # Clip gradients
    gradients, _ = tf.clip_by_global_norm(gradients, clip_norm = max_norm)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    #self.train_loss.update_state(loss)
    train_metric.update_state(outputs, predictions)

    return loss

In [None]:
@tf.function
def distributed_train_step(data): #, chunk_sizes
    per_replica_losses = strategy.run(train_step, args=(data,)) #, chunk_sizes
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for epoch in range(num_epochs):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for x in tf_data_train:
        total_loss += distributed_train_step(x)
        num_batches += 1
        train_loss = total_loss / num_batches

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
    print(template.format(epoch + 1, train_loss,
                         train_metric.result() * 100, valid_loss.result(),
                         valid_metric.result() * 100))

    valid_loss.reset_states()
    train_metric.reset_states()
    valid_metric.reset_states()

INFO:tensorflow:Collective all_reduce tensors: 5 all_reduces, num_devices = 2, group_size = 2, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 5 all_reduces, num_devices = 2, group_size = 2, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Collective all_reduce tensors: 5 all_reduces, num_devices = 2, group_size = 2, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 5 all_reduces, num_devices = 2, group_size = 2, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2024-03-05 14:28:27.584663: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
2024-03-05 14:28:27.610959: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-03-05 14:28:27.621997: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
