In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import pickle
import os
import joblib
from tensorflow.keras.models import load_model
import tensorflow as tf
import matplotlib.pyplot as plt
tf.random.set_seed(42)

In [None]:
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Dense, Concatenate, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Lambda

In [None]:
def individual_to_params(individual):
    """Map a 10-element individual onto DecisionTreeClassifier keyword names.

    Parameters
    ----------
    individual : iterable
        Exactly ten values, ordered as: criterion, splitter, max_depth,
        min_samples_split, min_samples_leaf, min_weight_fraction_leaf,
        max_features, max_leaf_nodes, min_impurity_decrease, ccp_alpha.

    Returns
    -------
    dict
        Hyper-parameter name -> value, in the order listed above.

    Raises
    ------
    ValueError
        If ``individual`` does not contain exactly ten values.
    """
    hyperparameter_names = (
        "criterion",
        "splitter",
        "max_depth",
        "min_samples_split",
        "min_samples_leaf",
        "min_weight_fraction_leaf",
        "max_features",
        "max_leaf_nodes",
        "min_impurity_decrease",
        "ccp_alpha",
    )
    values = list(individual)
    # zip() would silently truncate, so enforce the arity explicitly.
    if len(values) != len(hyperparameter_names):
        raise ValueError(
            f"expected {len(hyperparameter_names)} values to unpack, got {len(values)}"
        )
    return dict(zip(hyperparameter_names, values))

In [None]:
def load_and_preprocess(filepath):
    """Read a flow CSV and split it into features and label.

    The first CSV column is used as the index. Only the nine selected flow
    features plus ``Label`` are kept, in a fixed order with ``Label`` last.

    Parameters
    ----------
    filepath : str or path-like
        Location of the CSV file.

    Returns
    -------
    tuple
        ``(X, y, df)`` where ``df`` is the column-restricted frame, ``X`` is
        every column of ``df`` except the last, and ``y`` is the last column
        (``Label``).
    """
    selected_columns = [
        'SrcWin', 'sHops', 'sTtl', 'dTtl', 'SrcBytes',
        'DstBytes', 'Dur', 'TotBytes', 'Rate', 'Label',
    ]
    df = pd.read_csv(filepath, index_col=[0])[selected_columns]
    print(df.shape)
    print("loading data")
    # Label is the last selected column, so positional slicing separates it.
    X, y = df.iloc[:, :-1], df.iloc[:, -1]
    return X, y, df


In [None]:
# Load the ISOT botnet CSV from the relative ./data/ directory.
data_path='./data/'
single_file = os.path.join(data_path, 'isot_botnet.csv')
X, y,df = load_and_preprocess(single_file)

# Split positional row indices (not X itself) 70/30 with a fixed seed; the
# feature frames are then recovered with .iloc on those indices, while
# y_train / y_test come back from train_test_split directly.
train_indices, test_indices, y_train, y_test = train_test_split(np.arange(len(X)), y, test_size=0.3, random_state=42,shuffle=True)
train_df = X.iloc[train_indices]
test_df = X.iloc[test_indices]

In [None]:
# Standardize the nine flow features. The scaler is fit on the training rows
# only and then re-used, unchanged, to transform the test rows.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(train_df[['SrcWin', 'sHops', 'sTtl', 'dTtl', 'SrcBytes', 'DstBytes', 'Dur', 'TotBytes', 'Rate']])
X_test_scaled = scaler.transform(test_df[['SrcWin', 'sHops', 'sTtl', 'dTtl', 'SrcBytes', 'DstBytes', 'Dur', 'TotBytes', 'Rate']])

In [None]:
# Directory prefix (relative, with trailing slash) that the saved model file
# names below are concatenated onto.
model_path='./optimization/information_feature_selection/'

In [None]:
# Load the previously saved classifier and score it on the scaled test split.
# NOTE(review): joblib.load unpickles arbitrary objects - only load files you trust.
rf_clf=joblib.load(model_path+'best_rf_exploration_400_indisot_botnet.pkl')
predictions = rf_clf.predict(X_test_scaled)
# NOTE(review): `accuracy` is not read again in this cell; the metric is recomputed in the print below.
accuracy = accuracy_score(y_test, predictions)
# Report the classifier's test metrics. `predictions` is overwritten by the
# neural-network cell further down, so these values are RF-only at this point.
print("Accuracy: ", accuracy_score(y_test, predictions))
print("Precision: ", precision_score(y_test, predictions))
print("Recall: ", recall_score(y_test, predictions))
print("F1 score: ", f1_score(y_test, predictions))
print("Confusion Matrix: \n", confusion_matrix(y_test, predictions))

In [None]:
# Per-feature bounds, taken from the TRAINING split only, in two spaces:
#   * original units  -> column min/max of train_df
#   * standardized    -> column min/max of X_train_scaled
# The column order below must match the order used when fitting the scaler.
_bound_columns = ['SrcWin', 'sHops', 'sTtl', 'dTtl', 'SrcBytes', 'DstBytes', 'Dur', 'TotBytes', 'Rate']

# Unscaled extrema, one entry per column.
_raw_minima = [train_df[column].min() for column in _bound_columns]
_raw_maxima = [train_df[column].max() for column in _bound_columns]

(SrcWin_min, sHops_min, sTtl_min, dTtl_min, SrcBytes_min,
 DstBytes_min, Dur_min, TotBytes_min, Rate_min) = _raw_minima
(SrcWin_max, sHops_max, sTtl_max, dTtl_max, SrcBytes_max,
 DstBytes_max, Dur_max, TotBytes_max, Rate_max) = _raw_maxima

# Scaled extrema, computed column-wise over the standardized training matrix.
min_values = X_train_scaled.min(axis=0)
max_values = X_train_scaled.max(axis=0)

(SrcWin_scaled_min, sHops_scaled_min, sTtl_scaled_min, dTtl_scaled_min, SrcBytes_scaled_min,
 DstBytes_scaled_min, Dur_scaled_min, TotBytes_scaled_min, Rate_scaled_min) = min_values
(SrcWin_scaled_max, sHops_scaled_max, sTtl_scaled_max, dTtl_scaled_max, SrcBytes_scaled_max,
 DstBytes_scaled_max, Dur_scaled_max, TotBytes_scaled_max, Rate_scaled_max) = max_values

# Lookup table: feature name -> unscaled and scaled min/max.
feature_bounds = {
    column: {'min': raw_lo, 'max': raw_hi, 'scaled_min': scaled_lo, 'scaled_max': scaled_hi}
    for column, raw_lo, raw_hi, scaled_lo, scaled_hi in zip(
        _bound_columns, _raw_minima, _raw_maxima, min_values, max_values
    )
}



In [None]:
# Load the saved Keras network (HDF5 file) that is later passed to WGAN as
# the substitute detector.
neural_net=load_model(model_path+'optimized_nn_full_training_500isot_botnet.h5')

In [None]:
# Raw network outputs for the standardized test features; this replaces the
# RF predictions previously stored under the same name.
predictions = neural_net.predict(X_test_scaled)

In [None]:
# Turn each output row into a hard label by rounding its first element.
# NOTE(review): this cell is not idempotent - re-running it without re-running
# the predict cell would index into already-rounded scalars.
predictions = [round(x[0]) for x in predictions] 
# Report the network's test metrics.
print("Accuracy: ", accuracy_score(y_test, predictions))
print("Precision: ", precision_score(y_test, predictions))
print("Recall: ", recall_score(y_test, predictions))
print("F1 score: ", f1_score(y_test, predictions))
print("Confusion Matrix: \n", confusion_matrix(y_test, predictions))

In [None]:
# Baseline RF miss rate in percent: 2611 / (2611 + 31315), rounded to 2 places.
# NOTE(review): the counts are hard-coded - presumably the false-negative and
# true-positive cells of the RF confusion matrix printed above; confirm they
# still match after any re-run with different data or models.
or_miss_rate_rf=np.round((2611/(2611+31315)*100),2)
print(or_miss_rate_rf)

In [None]:
# Baseline NN miss rate in percent: 4706 / (4706 + 29220), rounded to 2 places.
# NOTE(review): the counts are hard-coded - presumably the false-negative and
# true-positive cells of the NN confusion matrix printed above; confirm they
# still match after any re-run with different data or models.
or_miss_rate_nn=np.round((4706/(4706+29220)*100),2)
print(or_miss_rate_nn)

In [None]:
# Keep only the test rows that are labelled 1 AND that the current
# `predictions` (the rounded NN outputs at this point) also mark as 1.
predicted_positive = np.array(predictions) == 1
labelled_positive = np.array(y_test) == 1
malware_pred_index = np.flatnonzero(predicted_positive & labelled_positive)
X_test_malware = X_test_scaled[malware_pred_index]

In [None]:
def build_generator(latent_dim, feature_count,data_min,data_max):
    """Build the generator network.

    Architecture: Dense(128, relu) -> Dense(feature_count, tanh) -> affine
    rescale. The final Lambda maps the tanh range [-1, 1] linearly onto
    [data_min, data_max].

    Parameters
    ----------
    latent_dim : int
        Size of the noise vector fed to the first Dense layer.
    feature_count : int
        Number of output units.
    data_min, data_max : scalar
        Lower / upper end of the range the outputs are rescaled to.

    Returns
    -------
    tf.keras.Sequential
        The (uncompiled) generator model.
    """
    # Affine map taking tanh's [-1, 1] to [data_min, data_max].
    half_range = (data_max - data_min) / 2
    midpoint = (data_max + data_min) / 2

    generator = tf.keras.Sequential()
    generator.add(Dense(128, activation='relu', input_dim=latent_dim))
    generator.add(Dense(feature_count, activation='tanh'))
    generator.add(Lambda(lambda activations: activations * half_range + midpoint))
    return generator

def build_critic(input_shape):
    """Build the critic network: Flatten -> Dense(128, relu) -> Dense(1).

    The single output unit has no activation, so the critic emits an
    unbounded score.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample (without the batch dimension).

    Returns
    -------
    tf.keras.Sequential
        The (uncompiled) critic model.
    """
    critic = tf.keras.Sequential()
    critic.add(tf.keras.layers.Flatten(input_shape=input_shape))
    critic.add(Dense(128, activation='relu'))
    critic.add(Dense(1))
    return critic

In [None]:
# feature_min and feature_max are in original (unscaled) space
def adjust_feature_dependencies(scaler,perturbed_samples, feature_index,feature_name,feature_min,feature_max):
    """Clip the perturbed feature and re-derive the features that depend on it.

    The samples are mapped back to original units with
    ``scaler.inverse_transform``, the column at ``feature_index`` is clipped to
    ``[feature_min, feature_max]``, dependent columns are recomputed according
    to ``feature_name``, and the result is re-standardized with
    ``scaler.transform``.

    Column positions used below (they follow the order the scaler was fit on
    in this notebook): 1=sHops, 2=sTtl, 4=SrcBytes, 5=DstBytes, 6=Dur,
    7=TotBytes, 8=Rate.

    Parameters
    ----------
    scaler : object
        Provides ``inverse_transform`` and ``transform`` (e.g. a fitted
        StandardScaler).
    perturbed_samples : array-like, shape (n_samples, n_features)
        Samples in standardized space.
    feature_index : int
        Column of the feature that was perturbed.
    feature_name : str
        Name of that feature; selects which dependency rule is applied.
        Names without a rule (e.g. "SrcWin") are only clipped.
    feature_min, feature_max : scalar
        Clipping bounds for the perturbed feature, in original units.

    Returns
    -------
    numpy.ndarray
        Whatever ``scaler.transform`` returns for the adjusted samples. The
        round trip through the scaler happens outside TensorFlow, so no
        gradient flows through this function.
    """
    INITIAL_TTL = 255
    # Guards every division below so a zero Rate/Dur cannot produce inf/NaN,
    # which would otherwise poison scaler.transform and the downstream models.
    EPSILON = 1e-10

    perturbed_samples_original = scaler.inverse_transform(perturbed_samples)
    perturbed_samples_original[:, feature_index] = np.clip(
        perturbed_samples_original[:, feature_index], feature_min, feature_max
    )

    if feature_name == "Dur":
        # NOTE(review): the pre-perturbation duration is not available here
        # (column 6 already holds the perturbed value), so this factor is
        # ~1 and Rate is effectively left unchanged. Kept as-is; a real fix
        # needs the caller to supply the original samples.
        original_duration = perturbed_samples_original[:, 6]
        rate_change_factor = original_duration / (perturbed_samples_original[:, 6] + EPSILON)
        perturbed_samples_original[:, 8] *= rate_change_factor
    elif feature_name in ("SrcBytes", "DstBytes"):
        # TotBytes = SrcBytes + DstBytes
        total_bytes = perturbed_samples_original[:, 4] + perturbed_samples_original[:, 5]
        perturbed_samples_original[:, 7] = total_bytes
        # Adjust Duration to keep Rate constant.
        original_rate = perturbed_samples_original[:, 8]
        perturbed_samples_original[:, 6] = total_bytes / (original_rate + EPSILON)
    elif feature_name == "TotBytes":
        # Keep DstBytes fixed and absorb the change in SrcBytes
        # (SrcBytes = TotBytes - DstBytes). Re-deriving DstBytes afterwards
        # would just reproduce its current value, so it is not reassigned.
        # NOTE(review): SrcBytes can go negative when TotBytes < DstBytes.
        perturbed_samples_original[:, 4] = perturbed_samples_original[:, 7] - perturbed_samples_original[:, 5]
        # Adjust Duration to keep Rate constant.
        original_rate = perturbed_samples_original[:, 8]
        perturbed_samples_original[:, 6] = perturbed_samples_original[:, 7] / (original_rate + EPSILON)
    elif feature_name == "sHops":
        # sTtl derived from sHops.
        perturbed_samples_original[:, 2] = INITIAL_TTL - perturbed_samples_original[:, 1]
    elif feature_name in ["sTtl", "dTtl"]:
        # sHops derived from sTtl.
        # NOTE(review): for "dTtl" this still rewrites sHops from sTtl (there
        # is no dHops column in this feature set) - confirm that is intended.
        perturbed_samples_original[:, 1] = INITIAL_TTL - perturbed_samples_original[:, 2]
    elif feature_name == "Rate":
        # Duration derived from TotBytes and Rate.
        perturbed_samples_original[:, 6] = perturbed_samples_original[:, 7] / (perturbed_samples_original[:, 8] + EPSILON)

    # Back to standardized space.
    perturbed_samples = scaler.transform(perturbed_samples_original)

    return perturbed_samples




In [None]:
class WGAN(tf.keras.Model):
    """Critic/generator pair that perturbs one feature of detected malware flows.

    Each training step replaces the masked feature of the real batch with the
    generator's output, clips it, repairs dependent features via
    ``adjust_feature_dependencies``, trains the critic to separate real from
    perturbed rows, and trains the generator against the critic and the
    substitute detector. At the iterations listed in ``checkpoints`` the
    perturbed rows that the substitute detector scores below 0.5 are recorded
    together with the real rows they were derived from.

    ``train_step`` takes extra positional arguments and converts tensors to
    NumPy, so it must be called directly in eager mode rather than through
    ``Model.fit``.
    """

    # Baseline counts used to turn "number of evading samples" into a miss
    # rate. NOTE(review): presumably the false-negative / true-positive cells
    # of the NN and RF confusion matrices on the test split - confirm and
    # override on the instance if the data or models change.
    NN_BASELINE_MISSED = 4706
    NN_BASELINE_DETECTED = 29220
    RF_BASELINE_MISSED = 2611
    RF_BASELINE_DETECTED = 31315

    # Default directory for the pickle and figures written by save_checkpoint_data.
    DEFAULT_OUTPUT_DIR = "D:\\Network-Revisit\\output_gan_isot"

    def __init__(self,real_data,substitute_detector,rf,critic, generator, latent_dim,checkpoints):
        """Store the models and prepare one result record per checkpoint.

        Parameters
        ----------
        real_data : array-like
            Full set of real samples; kept on ``self.full_dataset``.
        substitute_detector : callable
            Model whose output below 0.5 counts as a successful evasion.
        rf : object
            Second detector exposing ``predict``; class 0 counts as a miss.
        critic, generator : tf.keras.Model
            The two adversarial networks.
        latent_dim : int
            Width of the generator's noise input.
        checkpoints : iterable of int
            Iteration numbers at which samples are collected and evaluated.
        """
        super(WGAN, self).__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.substitute_detector=substitute_detector
        self.rf=rf
        self.checkpoints=checkpoints
        # 'original_samples' holds, row for row, the real sample each entry of
        # 'adversarial_samples' was derived from (needed for the L2 distance).
        self.checkpoint_data= {checkpoint: {'adversarial_samples': [],
                                             'original_samples': [],
                                             'nn_miss_rates': 0,
                                             'rf_miss_rates': 0,
                                             'l2_distances': 0,
                                             'unsuccessful_indices': [],
                                             'successful_indices': []} for checkpoint in checkpoints}
        self.full_dataset=real_data

    def compile(self, c_optimizer, g_optimizer, c_loss_fn, g_loss_fn):
        """Attach optimizers and loss callables.

        ``c_loss_fn`` is called as ``c_loss_fn(real_pred, fake_pred)`` and
        ``g_loss_fn`` as ``g_loss_fn(fake_pred, substitute_pred)``.
        """
        super(WGAN, self).compile()
        self.c_optimizer = c_optimizer
        self.g_optimizer = g_optimizer
        self.c_loss_fn = c_loss_fn
        self.g_loss_fn = g_loss_fn

    def _perturb_batch(self, real_data, fake_data, feature_mask, feature_name, feature_index,
                       scaler, scaled_min, scaled_max, ori_min, ori_max):
        """Splice generated values into the masked feature(s) and repair dependencies."""
        batch_rows = tf.shape(real_data)[0]

        # Broadcast the per-feature mask to [batch, num_features].
        mask = tf.cast(feature_mask, tf.float32)
        mask = tf.tile(tf.expand_dims(mask, 0), [batch_rows, 1])

        # Real values everywhere except where the mask selects generator output.
        modified_data = real_data * (1 - mask) + fake_data * mask
        modified_data = tf.cast(modified_data, tf.float32)

        # Clip only the targeted column to its standardized training range.
        feature_mask_single = tf.one_hot(indices=feature_index, depth=tf.shape(modified_data)[1])
        clipped_data = tf.clip_by_value(modified_data, scaled_min, scaled_max)
        modified_data = modified_data * (1 - feature_mask_single) + clipped_data * feature_mask_single

        return adjust_feature_dependencies(scaler, modified_data, feature_index, feature_name, ori_min, ori_max)

    def train_step(self,real_data,batch_size,feature_mask,feature_name,feature_index,scaler,scaled_min,scaled_max,ori_min,ori_max,current_iteration):
        """Run one critic update followed by one generator update.

        Parameters
        ----------
        real_data : tensor-like, shape (batch, num_features)
            Standardized real samples.
        batch_size : int
            Nominal batch size. Retained for call compatibility; the actual
            row count is read from ``real_data`` so a short final batch works.
        feature_mask : tensor-like, shape (num_features,)
            1 for features taken from the generator, 0 for features kept real.
        feature_name, feature_index : str, int
            Name and column of the targeted feature.
        scaler : object
            Scaler forwarded to ``adjust_feature_dependencies``.
        scaled_min, scaled_max : scalar
            Clip bounds for the targeted feature in standardized space.
        ori_min, ori_max : scalar
            Clip bounds for the targeted feature in original units.
        current_iteration : int
            Iteration counter; samples are recorded when it is a checkpoint.

        Returns
        -------
        dict
            ``{"c_loss": ..., "g_loss": ...}``.
        """
        current_batch_size = tf.shape(real_data)[0]

        # ---- critic update -------------------------------------------------
        random_latent_vectors = tf.random.normal(shape=(current_batch_size, self.latent_dim))
        fake_data = self.generator(random_latent_vectors)
        real_data = tf.cast(real_data, tf.float32)

        modified_data = self._perturb_batch(real_data, fake_data, feature_mask, feature_name, feature_index,
                                            scaler, scaled_min, scaled_max, ori_min, ori_max)
        # adjust_feature_dependencies goes through the scaler and may hand back
        # a NumPy array; normalise dtype before feeding the critic.
        critic_fake_input = tf.cast(modified_data, tf.float32)

        with tf.GradientTape() as tape:
            # Score real and perturbed rows separately so the loss receives
            # (real_pred, fake_pred) - the signature the loss function expects.
            real_pred = self.critic(real_data)
            perturbed_pred = self.critic(critic_fake_input)
            c_loss = self.c_loss_fn(real_pred, perturbed_pred)
        grads = tape.gradient(c_loss, self.critic.trainable_weights)
        self.c_optimizer.apply_gradients(zip(grads, self.critic.trainable_weights))

        # ---- generator update ----------------------------------------------
        random_latent_vectors = tf.random.normal(shape=(current_batch_size, self.latent_dim))

        with tf.GradientTape() as tape:
            fake_data = self.generator(random_latent_vectors)

            modified_data = self._perturb_batch(real_data, fake_data, feature_mask, feature_name, feature_index,
                                                scaler, scaled_min, scaled_max, ori_min, ori_max)

            # NOTE(review): modified_data has passed through the scaler outside
            # TensorFlow, so substitute_pred carries no gradient back to the
            # generator; only the critic term on fake_data does.
            fake_pred = self.critic(fake_data)
            substitute_pred = self.substitute_detector(modified_data)

            if current_iteration in self.checkpoints:
                successful_idx=np.where(substitute_pred < 0.5)[0]
                record = self.checkpoint_data[current_iteration]
                record['adversarial_samples'].extend(np.asarray(modified_data)[successful_idx])
                # Keep the matching real rows so distances compare like with like.
                record['original_samples'].extend(np.asarray(real_data)[successful_idx])

            g_loss = self.g_loss_fn(fake_pred,substitute_pred)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {"c_loss": c_loss, "g_loss": g_loss}

    def post_batch_evaluation(self, current_iteration):
        """Evaluate aggregated data after all batches are processed for a checkpoint."""
        if current_iteration in self.checkpoints:
            record = self.checkpoint_data[current_iteration]
            aggregated_samples = np.array(record['adversarial_samples'])
            aggregated_originals = np.array(record['original_samples'])
            print(f'Evaluation iteration{current_iteration}')
            nn_miss_rate, rf_miss_rate, l2_distance,sucessful_index, unsuccessful_indices = self.evaluate_checkpoint(aggregated_samples,aggregated_originals, self.substitute_detector, self.rf)
            record['successful_indices']=sucessful_index
            record['nn_miss_rates'] = nn_miss_rate
            record['rf_miss_rates'] = rf_miss_rate
            record['l2_distances'] = l2_distance
            record['unsuccessful_indices'] = unsuccessful_indices

    def evaluate_checkpoint(self, successful_adversarial_samples, real_data, neural_net, rf_model):
        """Compute miss rates and mean L2 perturbation for collected samples.

        Parameters
        ----------
        successful_adversarial_samples : array-like, shape (n, num_features)
            Perturbed samples collected at a checkpoint (may be empty).
        real_data : array-like
            Real samples aligned row-for-row with
            ``successful_adversarial_samples``.
        neural_net : callable
            Detector re-applied to the samples; outputs below 0.5 are
            counted as successful.
        rf_model : object
            Detector exposing ``predict``; predictions equal to 0 count as misses.

        Returns
        -------
        tuple
            ``(nn_misclassification_rate, rf_misclassification_rate,
            l2_distance, successful_idx, unsuccessful_indices)``. Rates are
            percentages rounded to two decimals; index arrays refer to rows of
            ``successful_adversarial_samples`` and are empty when no samples
            were supplied.
        """
        nn_total = self.NN_BASELINE_MISSED + self.NN_BASELINE_DETECTED
        rf_total = self.RF_BASELINE_MISSED + self.RF_BASELINE_DETECTED
        sample_count = len(successful_adversarial_samples)

        if sample_count > 0:
            adversarial = np.asarray(successful_adversarial_samples)
            substitute_pred = neural_net(adversarial)
            successful_idx=np.where(substitute_pred < 0.5)[0]

            # NN miss rate: baseline misses plus every collected sample.
            nn_misclassification_rate = np.round(((sample_count + self.NN_BASELINE_MISSED) / nn_total * 100), 2)
            print(nn_misclassification_rate)

            # Mean L2 distance between each successful sample and its own original.
            if len(successful_idx) > 0:
                real_samples_for_successful = np.asarray(real_data)[successful_idx]
                l2_distance = np.linalg.norm(real_samples_for_successful - adversarial[successful_idx], axis=1).mean()
            else:
                l2_distance = 0
            print(f'L2 dist{l2_distance}')

            unsuccessful_indices = np.setdiff1d(np.arange(sample_count), successful_idx)
            print(f'NN Misclassication Rate{nn_misclassification_rate}')

            # NOTE(review): samples are selected from the NN's detections; any
            # that the RF already missed are counted on top of its baseline.
            rf_predictions = rf_model.predict(adversarial)
            rf_malware_count = np.count_nonzero(rf_predictions == 0)
        else:
            nn_misclassification_rate = np.round(((0 + self.NN_BASELINE_MISSED) / nn_total * 100), 2)
            print(nn_misclassification_rate)
            l2_distance = 0
            print(f'L2 dist{l2_distance}')
            # Empty arrays keep the return types uniform across both branches.
            successful_idx = np.array([], dtype=int)
            unsuccessful_indices = np.array([], dtype=int)
            rf_malware_count = 0

        rf_misclassification_rate = np.round(((rf_malware_count + self.RF_BASELINE_MISSED) / rf_total * 100), 2)
        print(f'RF Misclassificaiotn Rate{rf_misclassification_rate}')

        return nn_misclassification_rate, rf_misclassification_rate, l2_distance,successful_idx, unsuccessful_indices

    def save_checkpoint_data(self,feature,output_dir=None):
        """Pickle the checkpoint records and plot the metrics per checkpoint.

        Writes ``{feature}_data.pkl``, ``{feature}_NN_RF_MR_vs_iteration.png``
        and ``{feature}_L2_vs_Misclassification_Rate.png`` into ``output_dir``
        and shows both figures.

        Parameters
        ----------
        feature : str
            Feature name used as the file-name prefix.
        output_dir : str, optional
            Target directory; defaults to ``DEFAULT_OUTPUT_DIR``. It is
            created if it does not exist.

        Returns
        -------
        dict
            ``self.checkpoint_data``.
        """
        if output_dir is None:
            output_dir = self.DEFAULT_OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)

        file_path = os.path.join(output_dir, f"{feature}_data.pkl")

        with open(file_path, "wb") as file:
            pickle.dump(self.checkpoint_data, file)

        checkpoints = sorted(self.checkpoint_data.keys())
        nn_miss_rates = [self.checkpoint_data[checkpoint]['nn_miss_rates'] for checkpoint in checkpoints]
        rf_miss_rates = [self.checkpoint_data[checkpoint]['rf_miss_rates'] for checkpoint in checkpoints]
        l2_distances = [self.checkpoint_data[checkpoint]['l2_distances'] for checkpoint in checkpoints]

        # Plot for Misclassification Rates
        plt.figure(figsize=(10, 7))
        plt.plot(checkpoints, nn_miss_rates, '-o', label='NN Miss Rate')
        plt.plot(checkpoints, rf_miss_rates, '-s', color='red', label='RF Miss Rate')
        plt.xlabel('iterations')
        plt.ylabel('Rate')
        plt.title('NN-MR(%) & RF-MR(%) vs. iteration')
        plt.legend()
        plt.grid(True, which="both", ls="--", c='0.7')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{feature}_NN_RF_MR_vs_iteration.png"), dpi=100)
        plt.show()

        # Plot for L2 Distances, with the NN miss rate on a secondary axis
        plt.figure(figsize=(10, 7))
        ax1 = plt.gca()
        ax1.plot(checkpoints, l2_distances, '-o', color='tab:red', label='L2 Distance')
        ax1.set_xlabel('Iterations')
        ax1.set_ylabel('L2 Distance', color='tab:red')
        ax1.tick_params(axis='y', labelcolor='tab:red')
        ax1.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)

        ax2 = ax1.twinx()
        ax2.plot(checkpoints, nn_miss_rates, '-s', color='tab:blue', label='NN Miss Rate')
        ax2.set_ylabel('Misclassification Rate', color='tab:blue')
        ax2.tick_params(axis='y', labelcolor='tab:blue')

        ax1.legend(loc='upper left')
        ax2.legend(loc='upper right')
        plt.title('Trade-off: Perturbation Magnitude vs Misclassification Rate')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{feature}_L2_vs_Misclassification_Rate.png"), dpi=100)
        plt.show()

        return self.checkpoint_data


In [None]:
def critic_loss(real_pred, fake_pred):
    """Return mean(fake_pred) - mean(real_pred).

    Minimising this value pushes the critic's mean score up on ``real_pred``
    and down on ``fake_pred``.

    Parameters
    ----------
    real_pred : tensor
        Critic scores for real samples.
    fake_pred : tensor
        Critic scores for generated / perturbed samples.
    """
    mean_fake_score = tf.reduce_mean(fake_pred)
    mean_real_score = tf.reduce_mean(real_pred)
    return mean_fake_score - mean_real_score

def generator_loss(fake_pred, substitute_detector_pred):
    """Return -mean(fake_pred) + mean(substitute_detector_pred).

    Parameters
    ----------
    fake_pred : tensor
        Critic scores for generated samples; the negated mean rewards the
        generator for high critic scores.
    substitute_detector_pred : tensor
        Substitute-detector outputs for the perturbed samples; adding their
        mean penalises outputs near the malware class.
    """
    adversarial_term = -tf.reduce_mean(fake_pred)
    evasion_term = tf.reduce_mean(substitute_detector_pred)
    return adversarial_term + evasion_term


In [None]:
# ---- Training configuration --------------------------------------------------
batch_size = 2000  # rows per tf.data batch
epochs = 2001  # iterations 0..2000, so the last checkpoint (2000) is reached

# Features attacked one at a time; the order matches the scaler's column order.
feature_list=['SrcWin', 'sHops', 'sTtl', 'dTtl', 'SrcBytes', 'DstBytes', 'Dur', 'TotBytes', 'Rate']

# Train an independent generator/critic pair for every targeted feature.
for feature_name in feature_list:

    checkpoints = [5,100,750,1000,2000]
    latent_dim = 100
    feature_count = 9  # nine standardized flow features per sample

    # Generator outputs are rescaled to the global value range of the
    # standardized malware subset (all features pooled).
    min_value = np.min(X_test_malware)
    max_value = np.max(X_test_malware)

    generator = build_generator(latent_dim, feature_count,min_value,max_value)
    critic = build_critic((feature_count,))

    wgan = WGAN(
        X_test_malware,
        neural_net,
        rf_clf,
        critic=critic,
        generator=generator,
        latent_dim=latent_dim,
        checkpoints=checkpoints,
    )
    wgan.compile(
        c_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        c_loss_fn=critic_loss,
        g_loss_fn=generator_loss,
    )

    dataset = tf.data.Dataset.from_tensor_slices(X_test_malware).batch(batch_size)

    # One-hot mask selecting only the column of the targeted feature.
    feature_index = feature_list.index(feature_name)
    feature_mask = tf.constant(
        [1 if i == feature_index else 0 for i in range(len(feature_list))],
        dtype=tf.float32,
    )

    current_iteration=0
    for epoch in range(epochs):
        if (epoch+1)%200==0:
            print(f"Epoch {epoch+1}/{epochs}")
        for batch_data in dataset:
            wgan.train_step(
                batch_data,
                batch_size,
                feature_mask,
                feature_name,
                feature_index,
                scaler,
                feature_bounds[feature_name]['scaled_min'],
                feature_bounds[feature_name]['scaled_max'],
                feature_bounds[feature_name]['min'],
                feature_bounds[feature_name]['max'],
                current_iteration,
            )
        # Evaluate once per checkpoint, after every batch has been seen.
        if current_iteration in checkpoints:
            wgan.post_batch_evaluation(current_iteration)
        current_iteration+=1

    wgan.save_checkpoint_data(feature_name)
