In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import precision_score, recall_score
from livelossplot import PlotLosses
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
import os
import tempfile
import ray
from ray import tune, air
from ray.tune.schedulers import ASHAScheduler
from functools import partial
from ray.train import Checkpoint
import warnings
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device is {device}!")

local_dir = os.path.abspath("./ray_tune")

Device is cuda!


In [2]:
df = pd.read_csv('./preprocessed.csv', low_memory=False) 
df

Unnamed: 0,dns_qtype,dns_rcode,dns_query,dst_ip_bytes,src_pkts,label,type,conn_state-OTH,conn_state-REJ,conn_state-RSTO,...,service-ssl,dns_AA--1,dns_AA-F,dns_AA-T,dns_RA--1,dns_RA-F,dns_RA-T,dns_RD--1,dns_RD-F,dns_RD-T
0,0,0,2,3786,6,0,normal,0,0,0,...,0,1,0,0,1,0,0,1,0,0
1,12,0,4765,172,1,0,normal,0,0,0,...,0,0,0,1,0,0,1,0,1,0
2,12,3,6112,800,1,0,normal,0,0,0,...,0,0,1,0,0,1,0,0,1,0
3,43,0,5259,525,1,0,normal,0,0,0,...,0,0,1,0,0,1,0,0,1,0
4,43,0,2790,408,1,0,normal,0,0,0,...,0,0,1,0,0,1,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46559,1,0,12284,158,1,0,normal,0,0,0,...,0,0,0,1,0,1,0,0,1,0
46560,0,0,2,8462,10,1,ddos,0,0,0,...,1,1,0,0,1,0,0,1,0,0
46561,12,0,3631,375,1,0,normal,0,0,0,...,0,0,1,0,0,1,0,0,1,0
46562,28,0,9338,712,1,0,normal,0,0,0,...,0,0,1,0,0,1,0,0,1,0


In [3]:
X = df.drop(['label','type'], axis=1)
X.head()

Unnamed: 0,dns_qtype,dns_rcode,dns_query,dst_ip_bytes,src_pkts,conn_state-OTH,conn_state-REJ,conn_state-RSTO,conn_state-RSTOS0,conn_state-RSTR,...,service-ssl,dns_AA--1,dns_AA-F,dns_AA-T,dns_RA--1,dns_RA-F,dns_RA-T,dns_RD--1,dns_RD-F,dns_RD-T
0,0,0,2,3786,6,0,0,0,0,0,...,0,1,0,0,1,0,0,1,0,0
1,12,0,4765,172,1,0,0,0,0,0,...,0,0,0,1,0,0,1,0,1,0
2,12,3,6112,800,1,0,0,0,0,0,...,0,0,1,0,0,1,0,0,1,0
3,43,0,5259,525,1,0,0,0,0,0,...,0,0,1,0,0,1,0,0,1,0
4,43,0,2790,408,1,0,0,0,0,0,...,0,0,1,0,0,1,0,0,1,0


In [4]:
y = df.filter(items=['label'])
y.head()

Unnamed: 0,label
0,0
1,0
2,0
3,0
4,0


In [5]:
min_max_scaler = MinMaxScaler()
X =  pd.DataFrame(min_max_scaler.fit_transform(X))
X.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,33,34,35,36,37,38,39,40,41,42
0,0.0,0.0,0.000141,4.4e-05,2.4e-05,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0
1,0.047059,0.0,0.336797,2e-06,4e-06,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0
2,0.047059,0.6,0.432005,9e-06,4e-06,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
3,0.168627,0.0,0.371713,6e-06,4e-06,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0
4,0.168627,0.0,0.197201,5e-06,4e-06,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0


In [6]:
x = torch.Tensor(X.values) 
gt = torch.Tensor(y.values)

In [7]:
gt.size()

torch.Size([46564, 1])

In [9]:
x.size()

torch.Size([46564, 43])

In [10]:
#torch.manual_seed(42)

# Create a TensorDataset from your data and labels
dataset = TensorDataset(x, gt)

# Calculate the number of samples for training and testing datasets
num_samples = len(dataset)
train_size = int(num_samples * 0.7)
test_size = num_samples - train_size

def load_data():
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    return train_dataset, test_dataset

In [11]:
class Sparsemax(nn.Module):
    def __init__(self, dim=None):
        super(Sparsemax, self).__init__()
        self.dim = -1 if dim is None else dim

    def forward(self, input):
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)
        
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range = torch.arange(start=1, end=number_of_logits + 1, device=device,step=1, dtype=input.dtype).view(1, -1)
        range = range.expand_as(zs)

        bound = 1 + range * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range, dim, keepdim=True)[0]
        zs_sparse = is_gt * zs
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)
        self.output = torch.max(torch.zeros_like(input), input - taus)
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)
        return output
    def backward(self, grad_output):
        dim = 1
        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
        return self.grad_input

In [12]:
# https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279
# https://www.kaggle.com/code/samratthapa/tabnet-implementation?scriptVersionId=46472520


class GBN(nn.Module):
    def __init__(self,inp,vbs=128,momentum=0.01):
        super().__init__()
        self.bn = nn.BatchNorm1d(inp,momentum=momentum)
        self.vbs = vbs
    def forward(self,x):
        chunk = torch.chunk(x,max(1,x.size(0)//self.vbs),0)
        res = [self.bn(y) for y in chunk ]
        return torch.cat(res,0)

class GLU(nn.Module):
    def __init__(self,inp_dim,out_dim,fc=None,vbs=128):
        super().__init__()
        if fc:
            self.fc = fc
        else:
            self.fc = nn.Linear(inp_dim,out_dim*2)
        self.bn = GBN(out_dim*2,vbs=vbs) 
        self.od = out_dim
    def forward(self,x):
        x = self.bn(self.fc(x))
        return x[:,:self.od]*torch.sigmoid(x[:,self.od:])
    
class FeatureTransformer(nn.Module):
    def __init__(self,inp_dim,out_dim,shared,n_ind,vbs=128):
        super().__init__()
        first = True
        self.shared = nn.ModuleList()
        if shared:
            self.shared.append(GLU(inp_dim,out_dim,shared[0],vbs=vbs))
            first= False    
            for fc in shared[1:]:
                self.shared.append(GLU(out_dim,out_dim,fc,vbs=vbs))
        else:
            self.shared = None
        self.independ = nn.ModuleList()
        if first:
            self.independ.append(GLU(inp,out_dim,vbs=vbs))
        for x in range(first, n_ind):
            self.independ.append(GLU(out_dim,out_dim,vbs=vbs))
        self.scale = torch.sqrt(torch.tensor([.5],device=device))
    def forward(self,x):
        if self.shared:
            x = self.shared[0](x)
            for glu in self.shared[1:]:
                x = torch.add(x, glu(x))
                x = x*self.scale
                
        for glu in self.independ:
            x = torch.add(x, glu(x))
            x = x*self.scale
        return x

class AttentionTransformer(nn.Module):
    def __init__(self,inp_dim,out_dim,relax,vbs=128):
        super().__init__()
        self.fc = nn.Linear(inp_dim,out_dim)
        self.bn = GBN(out_dim,vbs=vbs)
#         self.smax = Sparsemax()
        self.r = torch.tensor([relax],device=device)
    def forward(self,a,priors):
        a = self.bn(self.fc(a))
        mask = torch.sigmoid(a*priors)
        priors =priors*(self.r-mask)
        return mask

class DecisionStep(nn.Module):
    def __init__(self,inp_dim,n_d,n_a,shared,n_ind,relax,vbs=128):
        super().__init__()
        self.fea_tran = FeatureTransformer(inp_dim,n_d+n_a,shared,n_ind,vbs)
        self.atten_tran = AttentionTransformer(n_a,inp_dim,relax,vbs)
    def forward(self,x,a,priors):
        mask = self.atten_tran(a,priors)
        loss = ((-1)*mask*torch.log(mask+1e-10)).mean()
        x = self.fea_tran(x*mask)
        return x,loss
class TabNet(nn.Module):
    def __init__(self, inp_dim, final_out_dim=1, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=3, relax=1.2, vbs=128, lambda_sparse=0.0001):
        super().__init__()
        self.lambda_sparse = lambda_sparse  # Adding sparse regularization weight
        
        if n_shared > 0:
            self.shared = nn.ModuleList([nn.Linear(inp_dim, 2 * (n_d + n_a))])
            for _ in range(1, n_shared):
                self.shared.append(nn.Linear(n_d + n_a, 2 * (n_d + n_a)))
        else:
            self.shared = None
        
        self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs=vbs)
        self.steps = nn.ModuleList()
        for _ in range(n_steps - 1):
            self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
        
        self.fc = nn.Linear(n_d, final_out_dim)
        self.bn = nn.BatchNorm1d(inp_dim)
        self.n_d = n_d
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, x):
        x = self.bn(x)
        x_a = self.first_step(x)[:, self.n_d:]
        sparse_loss = torch.tensor(0.).to(self.device)
        out = torch.zeros(x.size(0), self.n_d).to(self.device)
        priors = torch.ones(x.shape).to(self.device)
        for step in self.steps:
            x_te, l = step(x, x_a, priors)
            out += F.relu(x_te[:, :self.n_d])
            x_a = x_te[:, self.n_d:]
            sparse_loss += l
        
        # Final activation function changed to sigmoid for binary classification
        out = torch.sigmoid(self.fc(out))
        # Add the scaled sparse_loss to the main loss outside of this method during optimization
        return out, sparse_loss * self.lambda_sparse
    
class TabNetWithEmbed(nn.Module):
    def __init__(self,inp_dim,final_out_dim,n_d=64,n_a=64,n_shared=2,n_ind=2,n_steps=5,relax=1.2,vbs=128):
        super().__init__()
        self.tabnet = TabNet(inp_dim,final_out_dim,n_d,n_a,n_shared,n_ind,n_steps,relax,vbs)
        self.cat_embed = []
        self.emb1 = nn.Embedding(2,1)
        self.emb3 = nn.Embedding(3,1)
        self.cat_embed.append(self.emb1)
        self.cat_embed.append(self.emb3)
        
    def forward(self,catv,contv):
        catv = catv.to(device)
        contv = contv.to(device)
        embeddings = [embed(catv[:,idx]) for embed,idx in zip(self.cat_embed,range(catv.size(1)))]
        catv = torch.cat(embeddings,1)
        x = torch.cat((catv,contv),1).contiguous()
        x,l = self.tabnet(x)
        return torch.sigmoid(x),l


In [13]:
def train_tabnet(config):
    net = TabNet(43, n_d=config["n_d"], n_a=config["n_a"], n_shared=config["n_shared"], n_ind=config["n_ind"], n_steps=config["n_steps"], relax=config["relax"], vbs=config["vbs"], lambda_sparse=config["lambda_sparse"]).to(device)

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"])

    if "restore" in config:
        checkpoint_path = config["restore"]
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    trainset, _ = load_data()

    test_abs = int(len(trainset) * 0.7)
    train_subset, val_subset = torch.utils.data.random_split(
        trainset, [test_abs, len(trainset) - test_abs])
    
    # Create DataLoaders
    trainloader = DataLoader(train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=0)
    valloader = DataLoader(val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=0)

    for epoch in range(100):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, sparse_loss = net(inputs)
            loss = criterion(outputs, labels) + 1e-4 * sparse_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_steps += 1

        val_loss = 0.0
        val_steps = 0
        total_correct = 0
        total_samples = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs, sparse_loss = net(inputs)
                preds = outputs.round()
                total_correct += preds.eq(labels).sum().item()
                total_samples += labels.size(0)
                loss = criterion(outputs, labels) + 1e-4 * sparse_loss
                val_loss += loss.cpu().numpy()
                val_steps += 1

        # Calculate overall validation accuracy
        accuracy = (total_correct / total_samples) * 100
                
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint_path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save({
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, checkpoint_path)

            # Use ray.train.report to report metrics and checkpoint
            ray.train.report({
                "loss": val_loss,
                "accuracy": accuracy
            }, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir))
    
    print("Finished Training")

In [14]:
def test_best_model(best_result):
    best_trained_model = TabNet(43, n_d=best_result.config["n_d"], n_a=best_result.config["n_a"], n_shared=best_result.config["n_shared"], n_ind=best_result.config["n_ind"], n_steps=best_result.config["n_steps"], relax=best_result.config["relax"], vbs=best_result.config["vbs"], lambda_sparse=best_result.config["lambda_sparse"]).to(device)
    
    print(f'> Number of parameters {len(torch.nn.utils.parameters_to_vector(best_trained_model.parameters()))}')

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(checkpoint['model_state_dict'])

    _, testset = load_data()
    
    testloader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, loss = best_trained_model(inputs)
            predicted = outputs.round()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Best trial test set accuracy: {accuracy}%")

In [None]:
def main(num_samples=10, max_num_epochs=100, gpus_per_trial=1):
    config = {
        "n_d": tune.randint(2, 128 + 1),
        "n_a": tune.randint(2, 128 + 1),
        "n_shared": tune.randint(1, 3 + 1),
        "n_ind": tune.randint(1, 3 + 1),
        "n_steps": tune.randint(1, 3 + 1),
        "relax": tune.uniform(0.1, 1.2),
        "vbs": tune.choice([16, 32, 64, 128]),
        "lambda_sparse": tune.uniform(0, 0.00001),
        "lr": tune.loguniform(1e-5, 1e-1),
        "batch_size": tune.choice([64, 128, 256, 512, 1024]),
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    
    ray.shutdown()
    ray.init(log_to_driver=False)
    
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_tabnet),
            resources={"cpu": 64, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config,
        run_config=air.RunConfig(
            local_dir=local_dir,
            name="tabnet",
        )
    )
    results = tuner.fit()
    
    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    test_best_model(best_result)
    
main(num_samples=10, max_num_epochs=100, gpus_per_trial=1)

0,1
Current time:,2024-02-22 19:36:19
Running for:,00:05:42.82
Memory:,144.4/1132.4 GiB

Trial name,status,loc,batch_size,lambda_sparse,lr,n_a,n_d,n_ind,n_shared,n_steps,relax,vbs,iter,total time (s),loss,accuracy
train_tabnet_da096_00006,RUNNING,10.234.204.81:172508,512,9.80383e-06,0.0245868,59,23,1,2,2,1.18977,32,32.0,50.3718,2.00027,95.3369
train_tabnet_da096_00007,PENDING,,1024,7.46545e-06,0.0934418,62,43,1,2,1,0.648997,128,,,,
train_tabnet_da096_00008,PENDING,,64,8.74576e-06,4.07895e-05,37,11,2,2,2,0.82598,64,,,,
train_tabnet_da096_00009,PENDING,,128,7.85676e-06,0.00998986,36,17,1,1,3,0.706382,128,,,,
train_tabnet_da096_00000,TERMINATED,10.234.204.81:172508,512,4.43176e-07,0.066866,3,3,3,3,3,0.21678,64,100.0,261.662,2.14865,94.6723
train_tabnet_da096_00001,TERMINATED,10.234.204.81:172508,64,7.36724e-06,0.0278469,60,52,2,3,2,1.10613,16,1.0,9.22821,19.439,93.844
train_tabnet_da096_00002,TERMINATED,10.234.204.81:172508,128,8.62619e-06,0.0427846,13,51,2,1,1,0.755378,32,1.0,1.43128,29.8393,86.9619
train_tabnet_da096_00003,TERMINATED,10.234.204.81:172508,64,9.8738e-06,0.0719789,9,18,2,3,1,0.889047,16,1.0,3.26494,59.5541,86.8596
train_tabnet_da096_00004,TERMINATED,10.234.204.81:172508,1024,7.6654e-06,0.0776692,6,20,2,3,1,0.502163,64,2.0,1.34081,3.90351,86.7369
train_tabnet_da096_00005,TERMINATED,10.234.204.81:172508,512,8.48037e-06,6.23585e-05,56,45,1,2,1,0.558263,128,2.0,1.22889,12.8306,86.9107




In [15]:
y = df.filter(items=['type'])
y['type'].nunique()

10

In [16]:
label_encoder = LabelEncoder()

y_int = label_encoder.fit_transform(y)

  y = column_or_1d(y, warn=True)


In [17]:
x = torch.Tensor(X.values) 
gt = torch.Tensor(y_int).long()

In [18]:
#torch.manual_seed(42)

# Create a TensorDataset from your data and labels
dataset = TensorDataset(x, gt)

# Calculate the number of samples for training and testing datasets
num_samples = len(dataset)
train_size = int(num_samples * 0.7)
test_size = num_samples - train_size

In [19]:
class TabNetWithEmbed(nn.Module):
    def __init__(self,inp_dim,final_out_dim,n_d=64,n_a=64,n_shared=2,n_ind=2,n_steps=5,relax=1.2,vbs=128):
        super().__init__()
        self.tabnet = TabNet(inp_dim,final_out_dim,n_d,n_a,n_shared,n_ind,n_steps,relax,vbs)
        self.cat_embed = []
        self.emb1 = nn.Embedding(2,1)
        self.emb3 = nn.Embedding(3,1)
        self.cat_embed.append(self.emb1)
        self.cat_embed.append(self.emb3)
        
    def forward(self,catv,contv):
        catv = catv.to(device)
        contv = contv.to(device)
        embeddings = [embed(catv[:,idx]) for embed,idx in zip(self.cat_embed,range(catv.size(1)))]
        catv = torch.cat(embeddings,1)
        x = torch.cat((catv,contv),1).contiguous()
        x,l = self.tabnet(x)
        return torch.softmax(x, dim=1),l

In [20]:
def train_tabnet(config):
    net = TabNet(43, 10, n_d=config["n_d"], n_a=config["n_a"], n_shared=config["n_shared"], n_ind=config["n_ind"], n_steps=config["n_steps"], relax=config["relax"], vbs=config["vbs"], lambda_sparse=config["lambda_sparse"]).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"])

    if "restore" in config:
        checkpoint_path = config["restore"]
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    trainset, _ = load_data()

    test_abs = int(len(trainset) * 0.7)
    train_subset, val_subset = torch.utils.data.random_split(
        trainset, [test_abs, len(trainset) - test_abs])
    
    # Create DataLoaders
    trainloader = DataLoader(train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=0)
    valloader = DataLoader(val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=0)

    for epoch in range(100):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, sparse_loss = net(inputs)
            loss = criterion(outputs, labels) + 1e-4 * sparse_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_steps += 1

        val_loss = 0.0
        val_steps = 0
        total_correct = 0
        total_samples = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs, sparse_loss = net(inputs)
                _, argmax = torch.max(outputs, dim=1)
                total_correct += argmax.eq(labels).sum().item()
                total_samples += labels.size(0)
                loss = criterion(outputs, labels) + 1e-4 * sparse_loss
                val_loss += loss.cpu().numpy()
                val_steps += 1

        # Calculate overall validation accuracy
        accuracy = (total_correct / total_samples) * 100
                
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint_path = os.path.join(temp_checkpoint_dir, "checkpoint.pt")
            torch.save({
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, checkpoint_path)

            # Use ray.train.report to report metrics and checkpoint
            ray.train.report({
                "loss": val_loss,
                "accuracy": accuracy
            }, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir))
    
    print("Finished Training")

In [21]:
def test_best_model(best_result):
    best_trained_model = TabNet(43, 10, n_d=best_result.config["n_d"], n_a=best_result.config["n_a"], n_shared=best_result.config["n_shared"], n_ind=best_result.config["n_ind"], n_steps=best_result.config["n_steps"], relax=best_result.config["relax"], vbs=best_result.config["vbs"], lambda_sparse=best_result.config["lambda_sparse"]).to(device)
    
    print(f'> Number of parameters {len(torch.nn.utils.parameters_to_vector(best_trained_model.parameters()))}')

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(checkpoint['model_state_dict'])

    _, testset = load_data()
    
    testloader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, sparse_loss = best_trained_model(inputs)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Best trial test set accuracy: {accuracy}%")

In [22]:
def main(num_samples=10, max_num_epochs=100, gpus_per_trial=1):
    config = {
        "n_d": tune.randint(2, 128 + 1),
        "n_a": tune.randint(2, 128 + 1),
        "n_shared": tune.randint(1, 3 + 1),
        "n_ind": tune.randint(1, 3 + 1),
        "n_steps": tune.randint(1, 3 + 1),
        "relax": tune.uniform(0.1, 1.2),
        "vbs": tune.choice([16, 32, 64, 128]),
        "lambda_sparse": tune.uniform(0, 0.00001),
        "lr": tune.loguniform(1e-5, 1e-1),
        "batch_size": tune.choice([64, 128, 256, 512, 1024]),
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    
    ray.shutdown()
    ray.init(log_to_driver=False)
    
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_tabnet),
            resources={"cpu": 64, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config,
        run_config=air.RunConfig(
            local_dir=local_dir,
            name="tabnet",
        )
    )
    results = tuner.fit()
    
    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    test_best_model(best_result)
    
main(num_samples=20, max_num_epochs=100, gpus_per_trial=1)

0,1
Current time:,2024-03-19 20:03:56
Running for:,00:16:44.12
Memory:,95.3/1132.4 GiB

Trial name,status,loc,batch_size,lambda_sparse,lr,n_a,n_d,n_ind,n_shared,n_steps,relax,vbs,iter,total time (s),loss,accuracy
train_tabnet_7ac47_00000,TERMINATED,10.234.204.81:3971489,512,7.18292e-06,0.014565,47,27,2,3,3,0.107751,128,100,142.629,30.3911,91.6556
train_tabnet_7ac47_00001,TERMINATED,10.234.204.81:3971489,64,7.30056e-06,0.000599778,102,76,1,1,3,0.660844,16,1,4.76008,235.14,91.7578
train_tabnet_7ac47_00002,TERMINATED,10.234.204.81:3971489,1024,3.20049e-06,0.00251295,37,111,1,2,3,0.851933,32,100,163.207,15.1579,92.443
train_tabnet_7ac47_00003,TERMINATED,10.234.204.81:3971489,64,3.46313e-06,0.000230229,9,40,2,3,1,0.63953,128,1,1.32998,350.917,1.59526
train_tabnet_7ac47_00004,TERMINATED,10.234.204.81:3971489,1024,1.17652e-06,2.8585e-05,57,119,3,2,3,1.03872,64,100,163.289,15.3126,92.0851
train_tabnet_7ac47_00005,TERMINATED,10.234.204.81:3971489,64,9.80669e-06,0.0455735,25,76,2,1,3,0.809862,16,1,5.85169,236.595,86.6039
train_tabnet_7ac47_00006,TERMINATED,10.234.204.81:3971489,1024,4.76241e-06,6.10794e-05,98,22,2,3,2,1.06731,64,2,2.35386,22.1703,12.0667
train_tabnet_7ac47_00007,TERMINATED,10.234.204.81:3971489,64,7.09905e-06,0.00671051,29,58,1,2,2,0.628315,128,1,3.08177,235.394,89.682
train_tabnet_7ac47_00008,TERMINATED,10.234.204.81:3971489,64,8.05662e-06,0.000233435,87,96,2,1,1,1.08061,16,1,1.35065,349.948,6.58554
train_tabnet_7ac47_00009,TERMINATED,10.234.204.81:3971489,64,2.13959e-06,0.000792404,63,84,2,1,1,0.26863,32,1,1.1567,336.925,86.6858


2024-03-19 20:03:56,488	INFO tune.py:1154 -- Total run time: 1004.34 seconds (1004.10 seconds for the tuning loop).


Best trial config: {'n_d': 111, 'n_a': 37, 'n_shared': 2, 'n_ind': 1, 'n_steps': 3, 'relax': 0.8519332861691487, 'vbs': 32, 'lambda_sparse': 3.2004898522373206e-06, 'lr': 0.0025129532017995728, 'batch_size': 1024}
Best trial final validation loss: 15.157926559448242
Best trial final validation accuracy: 92.44299008078536
> Number of parameters 199414
Best trial test set accuracy: 91.45311381531855%
