In [2]:
# --- Standard library ---
import os
import sys
import time

# Prepend the parent of the current working directory to sys.path so the
# project-local `shared` package imported below can be resolved from here.
repo_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, repo_root)

# --- Third-party ---
import pandas as pd  # NOTE(review): not used in the cells shown
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split  # NOTE(review): not used in the cells shown; splitting uses calc_train_test_indexes
from sklearn.tree import DecisionTreeClassifier

# --- Project-local (found via the sys.path entry above) ---
from shared.data_loader import load_dataset_deepcorr
from shared.data_processing import generate_flow_pairs
from shared.train_test_split import calc_train_test_indexes

In [3]:
#################### Parameters ####################
# Root directory of the DeepCorr dataset. Set the DEEPCORR_DATASET_PATH
# environment variable to point elsewhere instead of editing this
# machine-specific default.
dataset_path = os.environ.get(
    "DEEPCORR_DATASET_PATH",
    "/home/yagnihotri/datasets/deepcorr_original_dataset",
)
# Forwarded to load_dataset_deepcorr; presumably keeps only flows with at
# least 300 packets — confirm in shared.data_loader.
load_only_flows_with_min_300 = True
# Fraction of flows assigned to the training split (calc_train_test_indexes).
train_ratio = 0.8
# Forwarded to generate_flow_pairs; presumably the per-flow length used for
# each pair — confirm in shared.data_processing.
flow_size = 300
# Forwarded to generate_flow_pairs; presumably the number of non-correlated
# pairs generated per true correlating pair (the printed set sizes equal
# flows * (1 + negative_samples)).
negative_samples = 50

#################### Paths ####################
# Forwarded to generate_flow_pairs; presumably where it writes artifacts — confirm.
run_folder_path = "./"

In [4]:
# Load the DeepCorr dataset. The loader reports that it reads pickle files;
# unpickling can execute arbitrary code, so only point dataset_path at
# trusted data. The boolean flag presumably restricts loading to flows with
# at least 300 packets — confirm in shared.data_loader.
deepcorr_dataset = load_dataset_deepcorr(
    dataset_path,
    load_only_flows_with_min_300,
)

Loading dataset from pickle files...


Loading progress: 100%|█████████████████████████████████| 10/10 [00:13<00:00,  1.38s/it]

Dataset length:  7324





In [5]:
# Partition the flows into train/test index sets according to train_ratio.
train_indexes, test_indexes = calc_train_test_indexes(deepcorr_dataset, train_ratio)

# Build the flow-pair feature arrays and labels for both splits. Presumably
# each true correlating pair is accompanied by `negative_samples`
# non-correlated pairs (the printed sizes match flows * (1 + negative_samples)).
(
    l2s,
    labels,
    l2s_test,
    labels_test,
) = generate_flow_pairs(
    deepcorr_dataset,
    train_indexes,
    test_indexes,
    flow_size,
    run_folder_path,
    negative_samples,
)


Splitting dataset into training and testing sets...
Length of true correlating flow pairs for TRAINING:  5859 flow pairs
Length of true correlating flow pairs for TESTING:  1465 flow pairs
Generating flow pairs...


100%|██████████████████████████████████████████████| 5859/5859 [00:41<00:00, 142.33it/s]
100%|██████████████████████████████████████████████| 1465/1465 [00:09<00:00, 150.34it/s]


In [6]:
# Number of training pairs (rows of the l2s array).
n_train_pairs = l2s.shape[0]
print("Training set size: ", n_train_pairs)

Training set size:  298809


In [7]:
# scikit-learn estimators expect a 2-D (n_samples, n_features) matrix, so
# collapse every trailing axis of each sample into one feature vector.
# reshape returns a view (no copy) when the memory layout allows it.
l2s_flattened = l2s.reshape((l2s.shape[0], -1))
print(l2s_flattened.shape)

# Flatten the test pairs the same way so the feature columns line up.
l2s_test_flattened = l2s_test.reshape((l2s_test.shape[0], -1))
print(l2s_test_flattened.shape)

(298809, 2400)
(74715, 2400)


In [8]:

# Quick cost/accuracy probe: fit on a slice of the training pairs, predict on
# a slice of the test pairs, and extrapolate the timings to the full arrays
# before committing to the full run.
subset_fraction = 0.1  # share of train/test rows used for the probe

# NOTE(review): these are prefix slices, not random samples. If
# generate_flow_pairs emits pairs in a structured order (e.g. grouped by
# flow or by label), the slice's class mix may differ from the full set —
# confirm the ordering before trusting the subset accuracy.
subset_size = int(subset_fraction * l2s_flattened.shape[0])
subset_l2s = l2s_flattened[:subset_size]
subset_labels = labels[:subset_size]

# perf_counter is monotonic and high-resolution; time.time() follows the wall
# clock and can jump if the system time is adjusted mid-measurement.
start_time = time.perf_counter()
# Seeded: scikit-learn permutes features at every split, so an unseeded tree
# (and therefore the accuracy) can differ from run to run.
clf = DecisionTreeClassifier(random_state=42)
clf.fit(subset_l2s, subset_labels)
training_time_subset = time.perf_counter() - start_time

subset_size_test = int(subset_fraction * l2s_test_flattened.shape[0])
subset_l2s_test = l2s_test_flattened[:subset_size_test]
subset_labels_test = labels_test[:subset_size_test]

start_time = time.perf_counter()
subset_y_pred = clf.predict(subset_l2s_test)
testing_time_subset = time.perf_counter() - start_time
accuracy = accuracy_score(subset_labels_test, subset_y_pred)

# Extrapolate using the actual size ratios rather than a hard-coded 10, so
# the estimate stays consistent with subset_fraction and the int() truncation.
train_scale = l2s_flattened.shape[0] / subset_size
test_scale = l2s_test_flattened.shape[0] / subset_size_test
# Prediction cost is per sample, so linear scaling fits the testing estimate.
# Tree fitting grows roughly as n_samples * log(n_samples), so the linear
# training estimate should be read as a lower bound.
total_training_estimate = training_time_subset * train_scale
total_testing_estimate = testing_time_subset * test_scale
total_estimate = total_training_estimate + total_testing_estimate

# NOTE(review): assuming the `negative_samples` extra pairs per true pair are
# negatives, ~50/51 of rows share one label, so a constant predictor already
# scores ~0.98 accuracy — compare against that baseline.
print(f"Estimated total training time: {total_training_estimate:.2f} seconds")
print(f"Estimated total testing time: {total_testing_estimate:.2f} seconds")
print(f"Estimated total time (training + testing): {total_estimate:.2f} seconds")
print(f"Accuracy on test subset: {accuracy:.2f}")


In [None]:
# Full run: fit a decision tree on every training pair, then predict on every
# test pair, timing both phases.
# Seeded: scikit-learn permutes features at every split, so an unseeded tree
# (and therefore the accuracy) can differ from run to run.
clf = DecisionTreeClassifier(random_state=42)

# perf_counter is the monotonic, high-resolution clock meant for measuring
# durations; time.time() can jump if the wall clock is adjusted.
start_time = time.perf_counter()
clf.fit(l2s_flattened, labels)
training_time = time.perf_counter() - start_time
print(f'Training completed in {training_time:.2f} seconds')

start_time = time.perf_counter()
y_pred = clf.predict(l2s_test_flattened)
testing_time = time.perf_counter() - start_time
print(f'Testing completed in {testing_time:.2f} seconds')

# NOTE(review): the printed test size equals 1465 * (1 + negative_samples).
# If those extra pairs are all negatives, a constant "not correlated"
# prediction already reaches ~50/51 ≈ 0.98 accuracy, so judge this number
# against that baseline (or add per-class metrics).
accuracy = accuracy_score(labels_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')