In [6]:
# Numeric suffix of the dataset partition. Later cells interpolate it into the
# dataset name (f'gtd{partition}'), the result-file path and the plot title/filename.
partition = 478

In [7]:
import sys
from train import main
from itertools import product  
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Hyperparameter grid: every combination of the three lists is trained once.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]

# Best accuracy seen so far and the configuration that produced it.
best_score = 0
best_config = {}

# main() takes no arguments, so its options are passed by overwriting sys.argv
# (presumably train.py parses it -- confirm there). Remember the kernel's own
# argv so it can be restored once the search is over, even on failure.
_kernel_argv = sys.argv

try:
    for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
        print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")
        sys.argv = [
            'train.py',
            '-dataset', f'gtd{partition}',
            '-n_class', '30',
            '-gpuid', '0',
            '-n_tree', str(n_tree),
            '-tree_depth', str(tree_depth),
            '-batch_size', str(batch_size),
            '-epochs', '200',
            '-jointly_training'
        ]

        # The returned values are not needed during the search; only the
        # accuracy written to the result file is used.
        _, _, _ = main()

        result_file = f"results/result_gtd{partition}"
        with open(result_file, "r") as f:
            accuracy_lines = [line for line in f if "Best Accuracy" in line]

        if not accuracy_lines:
            print(f"Warning: no 'Best Accuracy' line in {result_file}; skipping this configuration")
            continue

        # Use only the LAST matching line. If the result file accumulates
        # entries across runs, earlier lines belong to previous configurations
        # and must not be credited to the current one; if the file is rewritten
        # on every run, the last line is still the current run's score.
        # The score is expected to be the third whitespace-separated token.
        acc = float(accuracy_lines[-1].split()[2])
        if acc > best_score:
            best_score = acc
            best_config = {
                'n_tree': n_tree,
                'tree_depth': tree_depth,
                'batch_size': batch_size
            }
finally:
    sys.argv = _kernel_argv

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd478 dataset

Test set: Average loss: 3.4056, Accuracy: 252/4320 (0.0583)


Test set: Average loss: 3.4020, Accuracy: 226/4320 (0.0523)


Test set: Average loss: 3.4050, Accuracy: 265/4320 (0.0613)


Test set: Average loss: 3.3937, Accuracy: 254/4320 (0.0588)


Test set: Average loss: 3.4095, Accuracy: 252/4320 (0.0583)


Test set: Average loss: 3.3998, Accuracy: 256/4320 (0.0593)


Test set: Average loss: 3.3945, Accuracy: 232/4320 (0.0537)


Test set: Average loss: 3.3864, Accuracy: 258/4320 (0.0597)


Test set: Average loss: 3.3845, Accuracy: 249/4320 (0.0576)


Test set: Average loss: 3.3764, Accuracy: 250/4320 (0.0579)


Test set: Average loss: 3.3678, Accuracy: 244/4320 (0.0565)


Test set: Average loss: 3.3603, Accuracy: 243/4320 (0.0562)


Test set: Average loss: 3.3536, Accuracy: 251/4320 (0.0581)


Test set: Average loss: 3.3496, Accuracy: 226/4320 (0.0523)


Test set: Average loss: 3.3410, Accuracy: 214/4320 (0.0495)


T

In [8]:
# Retrain once with the winning grid-search configuration and keep the
# values returned by main() for the confusion-matrix cells below.
dataset_name = f'gtd{partition}'

# Expand the tuned options into "-flag value" pairs, in the same order
# the training script is invoked with elsewhere in this notebook.
tuned_options = [
    token
    for option in ('n_tree', 'tree_depth', 'batch_size')
    for token in (f'-{option}', str(best_config[option]))
]

sys.argv = (
    ['train.py', '-dataset', dataset_name, '-n_class', '30', '-gpuid', '0']
    + tuned_options
    + ['-epochs', '200', '-jointly_training']
)

preds, targets, labels = main()


Use gtd478 dataset

Test set: Average loss: 3.3844, Accuracy: 170/4320 (0.0394)


Test set: Average loss: 3.3671, Accuracy: 281/4320 (0.0650)


Test set: Average loss: 3.3545, Accuracy: 301/4320 (0.0697)


Test set: Average loss: 3.3476, Accuracy: 308/4320 (0.0713)


Test set: Average loss: 3.3395, Accuracy: 379/4320 (0.0877)


Test set: Average loss: 3.3294, Accuracy: 456/4320 (0.1056)




Test set: Average loss: 3.3209, Accuracy: 495/4320 (0.1146)


Test set: Average loss: 3.3166, Accuracy: 470/4320 (0.1088)


Test set: Average loss: 3.3042, Accuracy: 498/4320 (0.1153)


Test set: Average loss: 3.2974, Accuracy: 485/4320 (0.1123)


Test set: Average loss: 3.2881, Accuracy: 482/4320 (0.1116)


Test set: Average loss: 3.2796, Accuracy: 487/4320 (0.1127)


Test set: Average loss: 3.2705, Accuracy: 497/4320 (0.1150)


Test set: Average loss: 3.2609, Accuracy: 499/4320 (0.1155)


Test set: Average loss: 3.2564, Accuracy: 499/4320 (0.1155)


Test set: Average loss: 3.2428, Accuracy: 505/4320 (0.1169)


Test set: Average loss: 3.2383, Accuracy: 488/4320 (0.1130)


Test set: Average loss: 3.2337, Accuracy: 505/4320 (0.1169)


Test set: Average loss: 3.2292, Accuracy: 488/4320 (0.1130)


Test set: Average loss: 3.2240, Accuracy: 496/4320 (0.1148)


Test set: Average loss: 3.2121, Accuracy: 485/4320 (0.1123)


Test set: Average loss: 3.2019, Accuracy: 511/4320 (0.1183)


Test se

In [9]:
def plot_confusion_matrix(y_true, y_pred, labels, partition):
    """Plot a row-normalized confusion matrix as a heatmap and save it as a PNG.

    Parameters
    ----------
    y_true, y_pred :
        True and predicted class indices, passed straight to
        ``confusion_matrix``. Classes are taken to be ``0 .. len(labels) - 1``.
    labels :
        Sequence of class names used as tick labels; its length fixes the
        number of rows/columns of the matrix.
    partition :
        Partition identifier interpolated into the title and the output
        filename ``results/confusion_matrix_partition_gtd{partition}.png``.

    Each row is divided by its sum, so a cell holds the fraction of samples of
    that true class assigned to each predicted class. Rows for classes with no
    true samples are left as all zeros.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(len(labels)))

    # A class that never occurs in y_true has a zero row sum; dividing by it
    # would fill the row with NaN (and emit a RuntimeWarning). Dividing such
    # rows by 1 instead keeps them at 0.
    row_totals = cm.sum(axis=1, keepdims=True)
    row_totals[row_totals == 0] = 1
    cm_normalized = cm.astype('float') / row_totals

    fig = plt.figure(figsize=(18, 16))
    try:
        sns.heatmap(cm_normalized,
                    annot=True,
                    fmt=".2f",
                    xticklabels=labels,
                    yticklabels=labels,
                    cmap="viridis",
                    square=True,
                    linewidths=0.5,
                    cbar_kws={"shrink": 0.8})

        plt.title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
        plt.xlabel("Predicted Label", fontsize=14)
        plt.ylabel("True Label", fontsize=14)
        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        plt.tight_layout()

        save_path = f"results/confusion_matrix_partition_gtd{partition}.png"
        plt.savefig(save_path, dpi=300)
    finally:
        # Release the figure even if drawing or saving failed, so repeated
        # calls do not accumulate open figures in the kernel.
        plt.close(fig)

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [10]:
# Render and save the confusion matrix for the retrained model's outputs.
plot_confusion_matrix(
    y_true=targets,
    y_pred=preds,
    labels=labels,
    partition=partition,
)

Saved confusion matrix for partition gtd478 to results/confusion_matrix_partition_gtd478.png
