In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import fbeta_score

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class ExpertModel(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_units, dropout_rate):
        super(ExpertModel, self).__init__()
        layers = []
        for units in hidden_units:
            layers.append(nn.Linear(input_dim, units))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            input_dim = units
        layers.append(nn.Linear(input_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class GateModel(nn.Module):
    def __init__(self, input_dim, num_experts, hidden_units, dropout_rate):
        super(GateModel, self).__init__()
        layers = []
        for units in hidden_units:
            layers.append(nn.Linear(input_dim, units))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            input_dim = units
        layers.append(nn.Linear(input_dim, num_experts))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return F.softmax(self.model(x), dim=1)


class MixtureOfExperts(pl.LightningModule):
    def __init__(self, input_dim, output_dim, num_experts, expert_hidden_units, gate_hidden_units, num_active_experts, dropout_rate, learning_rate=1e-3):
        super(MixtureOfExperts, self).__init__()
        self.save_hyperparameters()

        self.experts = nn.ModuleList([ExpertModel(input_dim, output_dim, expert_hidden_units, dropout_rate) for _ in range(num_experts)])
        self.gate = GateModel(input_dim, num_experts, gate_hidden_units, dropout_rate)
        self.num_active_experts = num_active_experts
        self.expert_usage_count = torch.zeros(num_experts, dtype=torch.float32)

        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        gate_output = self.gate(x)

        expert_usage_count_adjusted = self.expert_usage_count + 1e-10
        importance_scores = gate_output / expert_usage_count_adjusted

        top_n_expert_indices = torch.argsort(importance_scores, dim=1, descending=True)[:, :self.num_active_experts]
        selected_expert_indices = top_n_expert_indices.view(-1)

        self.expert_usage_count += torch.bincount(selected_expert_indices, minlength=len(self.experts)).float()

        mask = torch.sum(F.one_hot(top_n_expert_indices, num_classes=len(self.experts)), dim=1)
        masked_gate_output = gate_output * mask
        normalized_gate_output = masked_gate_output / (torch.sum(masked_gate_output, dim=1, keepdim=True) + 1e-7)

        masked_expert_outputs = torch.stack([expert_outputs[:, i] * normalized_gate_output[:, i].unsqueeze(1)
                                              for i in range(len(self.experts))], dim=1)
        final_output = torch.sum(masked_expert_outputs, dim=1)

        return final_output

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        preds = torch.argmax(y_hat, dim=1)
        f2_score = fbeta_score(y.cpu().numpy(), preds.cpu().numpy(), beta=2, average='macro')
        self.log('val_f2', f2_score, prog_bar=True, sync_dist=True)
        return {'val_f2': f2_score}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.1, patience=5, verbose=True
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_f2',
                'interval': 'epoch',
                'frequency': 1
            }
        }

    def on_fit_start(self):
        self.expert_usage_count = self.expert_usage_count.to(self.device)

class ExpertUsageLogger(pl.Callback):
    def __init__(self, moe_model):
        super(ExpertUsageLogger, self).__init__()
        self.moe_model = moe_model
        self.expert_usage_history = []

    def on_train_epoch_end(self, trainer, pl_module):
        usage_count = self.moe_model.expert_usage_count.clone().cpu().numpy()
        self.expert_usage_history.append(usage_count)

    def plot_expert_usage(self):
        import matplotlib.pyplot as plt
        usage_history = torch.tensor(self.expert_usage_history)
        plt.figure(figsize=(10, 6))
        for i in range(usage_history.shape[1]):
            plt.plot(usage_history[:, i], label=f'Expert {i}')
        plt.xlabel('Epoch')
        plt.ylabel('Expert Usage Count')
        plt.title('Expert Usage Over Epochs')
        plt.legend(loc='upper left')
        plt.show()


In [3]:
train_data_path = 'CIC_IoMT_2024_WiFi_MQTT_train.parquet'
test_data_path = 'CIC_IoMT_2024_WiFi_MQTT_test.parquet'
usage_ratio=0.2

In [4]:
df_train = pd.read_parquet(train_data_path)
df_test = pd.read_parquet(test_data_path)

# Combine train and test data
df_combined = pd.concat([df_train, df_test])

display(df_train.nunique())
df_train.info()
# Perform stratified sampling
df_sampled, _ = train_test_split(df_combined, train_size=usage_ratio, stratify=df_combined['label'], random_state=42)

# Split back into train and test based on the original indices
df_train: pd.DataFrame = df_sampled[df_sampled.index.isin(df_train.index)]
df_test: pd.DataFrame = df_sampled[df_sampled.index.isin(df_test.index)]
numeric_columns = df_train.select_dtypes(include=[np.number]).columns
df_train[numeric_columns] = df_train[numeric_columns].astype(np.float32)
df_test[numeric_columns] = df_test[numeric_columns].astype(np.float32)

Header_Length       822146
Protocol Type         1441
Duration              2684
Rate               4174510
Srate              4174510
Drate                    1
fin_flag_number        229
syn_flag_number        525
rst_flag_number        501
psh_flag_number        428
ack_flag_number        619
ece_flag_number          8
cwr_flag_number          6
ack_count              563
syn_count              986
fin_count             2605
rst_count            11333
HTTP                    33
HTTPS                  257
DNS                     62
Telnet                   6
SMTP                     7
SSH                     14
IRC                      7
TCP                    253
UDP                    282
DHCP                    33
ARP                     72
ICMP                   227
IGMP                    13
IPv                     72
LLC                     72
Tot sum               6502
Min                   5117
Max                   5287
AVG                   5291
Std                  13274
T

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7160831 entries, 0 to 7160830
Data columns (total 46 columns):
 #   Column           Dtype   
---  ------           -----   
 0   Header_Length    float32 
 1   Protocol Type    float16 
 2   Duration         float16 
 3   Rate             float32 
 4   Srate            float32 
 5   Drate            int8    
 6   fin_flag_number  float16 
 7   syn_flag_number  float16 
 8   rst_flag_number  float16 
 9   psh_flag_number  float16 
 10  ack_flag_number  float16 
 11  ece_flag_number  float16 
 12  cwr_flag_number  float16 
 13  ack_count        float16 
 14  syn_count        float16 
 15  fin_count        float16 
 16  rst_count        float16 
 17  HTTP             float16 
 18  HTTPS            float16 
 19  DNS              float16 
 20  Telnet           float16 
 21  SMTP             float16 
 22  SSH              float16 
 23  IRC              float16 
 24  TCP              float16 
 25  UDP              float16 
 26  DHCP          

In [5]:
df_train.info()
df_train.nunique()

<class 'pandas.core.frame.DataFrame'>
Index: 1755002 entries, 5278661 to 3488986
Data columns (total 46 columns):
 #   Column           Dtype  
---  ------           -----  
 0   Header_Length    float32
 1   Protocol Type    float32
 2   Duration         float32
 3   Rate             float32
 4   Srate            float32
 5   Drate            float32
 6   fin_flag_number  float32
 7   syn_flag_number  float32
 8   rst_flag_number  float32
 9   psh_flag_number  float32
 10  ack_flag_number  float32
 11  ece_flag_number  float32
 12  cwr_flag_number  float32
 13  ack_count        float32
 14  syn_count        float32
 15  fin_count        float32
 16  rst_count        float32
 17  HTTP             float32
 18  HTTPS            float32
 19  DNS              float32
 20  Telnet           float32
 21  SMTP             float32
 22  SSH              float32
 23  IRC              float32
 24  TCP              float32
 25  UDP              float32
 26  DHCP             float32
 27  ARP        

Header_Length       282568
Protocol Type          965
Duration              2115
Rate               1131368
Srate              1131368
Drate                    1
fin_flag_number        125
syn_flag_number        228
rst_flag_number        219
psh_flag_number        188
ack_flag_number        250
ece_flag_number          5
cwr_flag_number          4
ack_count              331
syn_count              578
fin_count             1835
rst_count            10928
HTTP                    25
HTTPS                  142
DNS                     37
Telnet                   3
SMTP                     4
SSH                     10
IRC                      4
TCP                    143
UDP                    152
DHCP                    17
ARP                     64
ICMP                   135
IGMP                    12
IPv                     64
LLC                     64
Tot sum               6057
Min                   4304
Max                   5274
AVG                   5240
Std                  12737
T

In [6]:
df_train['DHCP'].unique()

array([0.        , 0.01000214, 0.60009766, 0.19995117, 0.02000427,
       0.39990234, 0.09997559, 0.41992188, 0.2199707 , 0.30004883,
       0.36010742, 0.5600586 , 0.23999023, 0.4399414 , 0.13000488,
       0.02999878, 0.35009766], dtype=float32)

In [7]:
numerical_columns = [col for col in df_train.columns if col not in ['label', 'Drate']]

target_train = df_train['label']
df_train = df_train.drop(columns=['label', 'Drate'])
target_test = df_test['label']
df_test = df_test.drop(columns=['label', 'Drate'])

In [8]:
print(f"Number of missing values: {df_train.isna().sum().sum()}")

Number of missing values: 0


In [9]:
display(df_train.describe())

Unnamed: 0,Header_Length,Protocol Type,Duration,Rate,Srate,fin_flag_number,syn_flag_number,rst_flag_number,psh_flag_number,ack_flag_number,...,AVG,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight
count,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,...,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0,1755002.0
mean,29883.37,8.046584,64.63663,15732.94,15732.94,0.005101282,0.1571194,0.03962176,0.02219948,0.09567751,...,60.53767,6.038682,60.53738,84678376.0,9.498865,10.43581,8.529323,2367.965,0.09065043,141.4742
std,281461.6,6.304724,7.837749,40000.8,40000.8,0.03395753,0.3367704,0.1395269,0.09655753,0.2522283,...,87.83871,38.02954,87.58988,17817836.0,0.8414046,3.150808,53.76392,19802.74,0.2328664,21.66214
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,0.0,1.0,9.164062,0.0,0.0,0.0,1.0
25%,2.31,1.049805,64.0,6.4289,6.4289,0.0,0.0,0.0,0.0,0.0,...,42.09375,0.0,42.25,84679176.0,9.5,9.171875,0.0,0.0,0.0,141.5
50%,108.0,6.0,64.0,133.1385,133.1385,0.0,0.0,0.0,0.0,0.0,...,50.0,0.0,50.0,84696416.0,9.5,10.0,0.0,0.0,0.0,141.5
75%,19400.38,17.0,64.0,19763.98,19763.98,0.0,0.0,0.0,0.0,0.0,...,54.0,0.0,54.0,84696904.0,9.5,10.39062,0.0,0.0,0.0,141.5
max,9892476.0,17.0,255.0,2097152.0,2097152.0,1.0,1.0,1.0,1.0,1.0,...,1514.0,720.5,1514.0,169470848.0,13.5,55.03125,1019.5,519757.5,1.0,244.625


In [10]:
from sklearn.preprocessing import OrdinalEncoder


encoder = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
target_train_encoded = encoder.fit_transform(target_train.values.reshape(-1, 1))
target_test_encoded = encoder.transform(target_test.values.reshape(-1, 1))

In [11]:
target_train_encoded.shape, target_test_encoded.shape

((1755002, 1), (645668, 1))

In [12]:
z_scores = np.abs(stats.zscore(df_train[numerical_columns].astype(np.float64)))

outlier_mask = np.any(z_scores > 4, axis=1)

# Filter out rows with outliers
df_train = df_train[~outlier_mask]
target_train_encoded = target_train_encoded[~outlier_mask]

print(f"{outlier_mask.sum()} out of {len(outlier_mask)} samples were filtered out as outliers.")
print(f"Number of missing values: {df_train.isna().sum().sum()}")

185096 out of 1755002 samples were filtered out as outliers.
Number of missing values: 0


In [13]:
df_train.sample(15)

Unnamed: 0,Header_Length,Protocol Type,Duration,Rate,Srate,fin_flag_number,syn_flag_number,rst_flag_number,psh_flag_number,ack_flag_number,...,AVG,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight
6192364,0.0,1.0,64.0,165.400284,165.400284,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,84697032.0,9.5,9.164062,0.0,0.0,0.0,141.5
7117584,54.0,6.0,64.0,68.792351,68.792351,0.0,0.0,0.0,0.0,0.0,...,54.0,0.0,54.0,84696112.0,9.5,10.390625,0.0,0.0,0.0,141.5
621097,100.120003,6.0,64.0,5.615251,5.615251,0.0,1.0,0.0,0.0,0.219971,...,54.96875,0.580566,54.875,84674744.0,9.5,10.484375,0.808105,0.917855,0.379883,141.5
595882,0.0,1.0,64.0,25.294065,25.294065,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,84697072.0,9.5,9.164062,0.0,0.0,0.0,141.5
5978704,8825.0,17.0,64.0,142711.765625,142711.765625,0.0,0.0,0.0,0.0,0.0,...,50.0,0.0,50.0,84675256.0,9.5,10.0,0.0,0.0,0.0,141.5
6292793,0.0,1.0,64.0,31.164721,31.164721,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,84697048.0,9.5,9.164062,0.0,0.0,0.0,141.5
6920754,21680.5,17.0,64.0,13038.194336,13038.194336,0.0,0.0,0.0,0.0,0.0,...,50.0,0.0,50.0,84696736.0,9.5,10.0,0.0,0.0,0.0,141.5
846606,0.0,1.0,64.0,234.213593,234.213593,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,84697024.0,9.5,9.164062,0.0,0.0,0.0,141.5
6415335,453.559998,1.160156,64.0,154.241364,154.241364,0.0,0.0,0.0,0.0,0.0,...,42.5625,1.955078,42.5625,84679320.0,9.5,9.226562,2.769531,29.984301,0.130005,141.5
1711708,0.0,1.0,64.0,24672.376953,24672.376953,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,84697064.0,9.5,9.164062,0.0,0.0,0.0,141.5


In [14]:
mean = df_train.mean()
std = df_train.std()
display(mean)
display(std)

Header_Length      8.671483e+03
Protocol Type      8.133665e+00
Duration           6.429772e+01
Rate               1.382555e+04
Srate              1.382555e+04
fin_flag_number    2.179058e-04
syn_flag_number    1.580998e-01
rst_flag_number    2.131322e-02
psh_flag_number    1.313877e-03
ack_flag_number    3.916546e-02
ece_flag_number    0.000000e+00
cwr_flag_number    0.000000e+00
ack_count          1.003523e-03
syn_count          2.447479e-01
fin_count          2.861828e-02
rst_count          6.779028e-01
HTTP               8.658214e-06
HTTPS              8.020909e-04
DNS                2.159189e-05
Telnet             0.000000e+00
SMTP               0.000000e+00
SSH                2.421044e-07
IRC                0.000000e+00
TCP                3.660790e-01
UDP                3.314583e-01
DHCP               0.000000e+00
ARP                4.128990e-05
ICMP               3.024321e-01
IGMP               0.000000e+00
IPv                9.999593e-01
LLC                9.999593e-01
Tot sum 

Header_Length       14451.265625
Protocol Type           6.540149
Duration                2.056136
Rate                25310.832031
Srate               25310.832031
fin_flag_number         0.004668
syn_flag_number         0.344913
rst_flag_number         0.084044
psh_flag_number         0.016479
ack_flag_number         0.145640
ece_flag_number         0.000000
cwr_flag_number         0.000000
ack_count               0.022014
syn_count               0.546714
fin_count               0.109391
rst_count              11.508127
HTTP                    0.000392
HTTPS                   0.008465
DNS                     0.000537
Telnet                  0.000000
SMTP                    0.000000
SSH                     0.000049
IRC                     0.000000
TCP                     0.480650
UDP                     0.468974
DHCP                    0.000000
ARP                     0.000879
ICMP                    0.457853
IGMP                    0.000000
IPv                     0.000871
LLC       

In [15]:
mean = mean + 1e-5
std = std + 1e-5
df_train = ((df_train - mean) / std).dropna(axis=1)
df_test = ((df_test - mean) / std).dropna(axis=1)

print(f"Number of missing values: {df_train.isna().sum().sum()}")
print(f"Number of missing values: {df_test.isna().sum().sum()}")

Number of missing values: 0
Number of missing values: 0


In [16]:
corr_matrix = np.corrcoef(df_train, rowvar=False)
upper_triangle_indices = np.triu_indices_from(corr_matrix, k=1)
correlated_pairs = [(i, j) for i, j in zip(*upper_triangle_indices) if np.abs(corr_matrix[i, j]) >= 0.8]
cols_train, cols_test = df_train.columns, df_test.columns
correlated_features = set(j for _, j in correlated_pairs)
df_train = np.delete(df_train, list(correlated_features), axis=1)
df_test = np.delete(df_test, list(correlated_features), axis=1)
df_train = pd.DataFrame(df_train, columns=cols_train.drop(cols_train[list(correlated_features)]))
df_test = pd.DataFrame(df_test, columns=cols_test.drop(cols_test[list(correlated_features)]))

In [17]:
cols_train

Index(['Header_Length', 'Protocol Type', 'Duration', 'Rate', 'Srate',
       'fin_flag_number', 'syn_flag_number', 'rst_flag_number',
       'psh_flag_number', 'ack_flag_number', 'ece_flag_number',
       'cwr_flag_number', 'ack_count', 'syn_count', 'fin_count', 'rst_count',
       'HTTP', 'HTTPS', 'DNS', 'Telnet', 'SMTP', 'SSH', 'IRC', 'TCP', 'UDP',
       'DHCP', 'ARP', 'ICMP', 'IGMP', 'IPv', 'LLC', 'Tot sum', 'Min', 'Max',
       'AVG', 'Std', 'Tot size', 'IAT', 'Number', 'Magnitue', 'Radius',
       'Covariance', 'Variance', 'Weight'],
      dtype='object')

In [18]:
df_train

Unnamed: 0,Header_Length,Protocol Type,Duration,Rate,fin_flag_number,syn_flag_number,rst_flag_number,psh_flag_number,ack_flag_number,ece_flag_number,...,IRC,TCP,DHCP,ARP,ICMP,IGMP,Tot sum,IAT,Number,Variance
0,0.504455,1.355674,-0.144800,-0.032098,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,-0.660552,-1.0,-0.005337,0.008236,0.018562,-0.340259
1,-0.592524,-0.326242,-0.144800,-0.545966,-0.048715,1.541884,3.435111,-0.080287,2.272168,-1.0,...,-1.0,1.318834,-1.0,-0.0577,-0.660552,-1.0,0.487779,-0.028553,0.018562,-0.041834
2,0.850342,1.355674,-0.144800,0.249141,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,-0.660552,-1.0,-0.005337,0.008249,0.018562,-0.340259
3,1.986055,1.355674,-0.144800,-0.095234,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,-0.660552,-1.0,-0.005337,0.008276,0.018562,-0.340259
4,0.947427,1.135878,8.214285,0.517582,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,-0.463928,-1.0,0.218041,0.007956,0.018562,2.246288
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1569901,-0.592577,-0.326242,-0.144800,-0.546201,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,1.318834,-1.0,-0.0577,-0.660552,-1.0,0.348695,-0.034793,0.018562,-0.340259
1569902,-0.600050,-1.090750,-0.144800,-0.546073,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,1.523509,-1.0,-0.713401,0.015115,0.018562,-0.340259
1569903,2.093589,1.355674,-0.144800,0.136104,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,-0.660552,-1.0,-0.005337,0.007903,0.018562,-0.340259
1569904,-0.600050,-1.090750,0.797497,-0.546073,-0.048715,-0.458391,-0.253684,-0.080287,-0.268969,-1.0,...,-1.0,-0.761638,-1.0,-0.0577,1.523509,-1.0,-0.713401,0.008435,0.018562,-0.340259


In [19]:
n_components = 15
name_cols = [f'PC{i}' for i in range(1, n_components + 1)]
pca = PCA(n_components=n_components)
pca.fit(df_train)
reduced_train = pca.transform(df_train)
reduced_test = pca.transform(df_test)
reduced_train = pd.DataFrame(reduced_train, columns=name_cols)
reduced_test = pd.DataFrame(reduced_test, columns=name_cols)
display(reduced_train)
display(reduced_test)

Unnamed: 0,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,PC12,PC13,PC14,PC15
0,-0.723982,1.405642,-0.509183,-0.043489,-0.453891,0.017190,-0.045733,0.017456,0.019915,0.054392,-0.470635,-0.216196,-0.012250,-0.084373,-0.001162
1,3.298144,-1.425085,-1.520087,-0.024649,-0.111459,0.825453,-0.316069,-0.604433,-0.368913,-1.495695,-0.013148,-0.637633,-0.093273,0.901276,0.023981
2,-0.785003,1.657681,-0.506929,-0.050599,-0.518908,0.079913,-0.053267,-0.029689,-0.000764,-0.039560,-0.289287,-0.179887,-0.011793,-0.104393,-0.001320
3,-0.802038,2.211745,-0.546086,-0.053113,-0.505308,0.063077,-0.056795,-0.051315,-0.048846,-0.215653,-0.828292,-0.296404,-0.022592,-0.271130,-0.000658
4,0.544230,2.852851,1.881443,0.427168,2.744118,5.206410,-0.559283,0.089575,0.541004,1.320186,-1.332039,4.079731,0.281829,1.283006,-0.026053
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1569901,0.343384,-0.512552,-0.687967,0.027071,0.303698,-0.851152,0.029134,0.315244,0.179364,0.798662,-0.150879,0.100868,0.025759,0.678842,0.036834
1569902,-1.247781,-1.478646,1.264406,0.045499,0.215390,-0.159233,-0.059426,-0.042416,-0.043150,-0.242519,-0.386797,-0.291047,-0.020665,-0.172150,0.014284
1569903,-0.841816,2.321149,-0.540101,-0.057954,-0.551259,0.107866,-0.061544,-0.081089,-0.057524,-0.259889,-0.643144,-0.258358,-0.021054,-0.265889,-0.000574
1569904,-1.212118,-1.348980,1.454567,0.096419,0.600753,0.317209,-0.123615,-0.013282,0.024876,-0.074642,-0.526200,0.211604,0.017073,0.029616,0.014714


Unnamed: 0,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,PC12,PC13,PC14,PC15
0,-0.785003,1.657681,-0.506929,-0.050599,-0.518908,0.079913,-0.053267,-0.029689,-0.000764,-0.039560,-0.289287,-0.179887,-0.011793,-0.104393,-0.001320
1,-0.802038,2.211745,-0.546086,-0.053113,-0.505308,0.063077,-0.056795,-0.051315,-0.048846,-0.215653,-0.828292,-0.296404,-0.022592,-0.271130,-0.000658
2,-0.726011,1.198838,-0.491408,-0.043707,-0.470147,0.034556,-0.045349,0.019654,0.038050,0.117939,-0.202292,-0.158644,-0.007374,-0.013347,-0.001263
3,-0.772734,1.457680,-0.495099,-0.049000,-0.515639,0.077905,-0.051312,-0.017597,0.016456,0.025595,-0.143721,-0.148087,-0.008512,-0.050619,-0.001514
4,23.547853,12.327210,23.732811,279.533905,-44.643639,-22.102262,-0.986734,0.446772,-3.180127,-13.028477,-4.967493,14.304883,0.846526,7.171854,83.742699
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
645663,-0.760433,1.548975,-0.507242,-0.047967,-0.493170,0.055173,-0.050211,-0.010580,0.008195,0.000508,-0.353491,-0.192613,-0.011813,-0.093923,-0.001027
645664,3.475626,-1.446949,-1.398737,-0.003074,-0.173237,1.566787,-0.289812,-0.684234,-0.367597,-1.465663,-0.155180,-1.214700,-0.137583,2.014173,0.004905
645665,-1.247781,-1.478646,1.264406,0.045499,0.215390,-0.159233,-0.059426,-0.042416,-0.043150,-0.242519,-0.386797,-0.291047,-0.020665,-0.172150,0.014284
645666,-0.841816,2.321149,-0.540101,-0.057954,-0.551259,0.107866,-0.061544,-0.081089,-0.057524,-0.259889,-0.643144,-0.258358,-0.021054,-0.265889,-0.000574


In [20]:
X, y = reduced_train, target_train_encoded

In [21]:
X_test, y_test = reduced_test, target_test_encoded

In [22]:
import optuna
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset

def objective(trial, X_train, y_train, X_val, y_val, input_dim, output_dim):
    gate_hidden_units_options = {
        "16": [16], 
        "32": [32], 
        "64": [64], 
        "32_16": [32, 16]
    }
    
    chosen_gate_hidden_units_str = trial.suggest_categorical('gate_hidden_units', list(gate_hidden_units_options.keys()))
    chosen_gate_hidden_units = gate_hidden_units_options[chosen_gate_hidden_units_str]

    dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.5)
    
    # Instantiate the model
    model = MixtureOfExperts(
        input_dim=input_dim,
        output_dim=output_dim,
        num_experts=output_dim,  # Number of experts equals the number of classes
        expert_hidden_units=[32, 64, 32],
        gate_hidden_units=chosen_gate_hidden_units,
        num_active_experts=3,
        dropout_rate=dropout_rate
    )
    
    # Initialize the expert usage logger
    expert_usage_logger = ExpertUsageLogger(model)

    # Initialize the PyTorch Lightning trainer
    logger = TensorBoardLogger("logs", name="MoE_experimental")
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    checkpoint_callback = ModelCheckpoint(monitor='val_f2', mode='max') 

    trainer = pl.Trainer(
        max_epochs=300,
        logger=logger,
        callbacks=[lr_monitor, checkpoint_callback, expert_usage_logger],
        accelerator='gpu',
    )
    
    # Create PyTorch DataLoaders
    train_loader = DataLoader(TensorDataset(torch.tensor(X_train.values, device='cuda'), torch.tensor(y_train, device='cuda')), batch_size=8192, shuffle=True)
    val_loader = DataLoader(TensorDataset(torch.tensor(X_val.values, device='cuda'), torch.tensor(y_val, device='cuda')), batch_size=8192)
    
    # Train the model
    trainer.fit(model, train_loader, val_loader)
    
    # Evaluate the model on the validation set
    val_f2 = trainer.callback_metrics["val_f2"].item()

    return val_f2

# Run the Optuna optimization
def tune_model(X_train, y_train, X_val, y_val, input_dim, output_dim, n_trials=20):
    gate_hidden_units_options = {
        "16": [16], 
        "32": [32], 
        "64": [64], 
        "32_16": [32, 16]
    }
    study = optuna.create_study(direction="maximize")
    
    study.optimize(lambda trial: objective(trial, X_train, y_train, X_val, y_val, input_dim, output_dim), 
                   n_trials=n_trials)
    
    print(f"Best Hyperparameters: {study.best_params}")
    
    # Optionally, retrain the model with the best hyperparameters and return it
    best_params = study.best_params
    best_gate_hidden_units = gate_hidden_units_options[best_params['gate_hidden_units']]
    
    best_model = MixtureOfExperts(
        input_dim=input_dim,
        output_dim=output_dim,
        num_experts=output_dim, 
        expert_hidden_units=[32, 64, 32],
        gate_hidden_units=best_gate_hidden_units,
        num_active_experts=3,
        dropout_rate=best_params['dropout_rate']
    )
    
    # Initialize the expert usage logger
    expert_usage_logger = ExpertUsageLogger(best_model)

    # Train the model with the best hyperparameters
    trainer = pl.Trainer(
        max_epochs=50,
        callbacks=[expert_usage_logger],
        accelerator='gpu'
    )
    
    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=4096, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=4096)
    
    trainer.fit(best_model, train_loader, val_loader)
    
    expert_usage_logger.plot_expert_usage()
    
    return best_model, study.best_params

In [23]:
tune_model(X, y, X_test, y_test, X.shape[1], 1)

[I 2024-09-03 14:36:39,738] A new study created in memory with name: no-name-cea7c07c-b864-4065-8ea8-075c74559952
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2024-09-03 14:36:40.377971: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-03 14:36:40.398217: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one ha

Epoch 0:  81%|████████  | 155/192 [00:12<00:02, 12.70it/s, v_num=8]        