In [1]:
# Partition identifier; interpolated into the dataset name (f'gtd{partition}')
# and into the result-file and confusion-matrix paths used in later cells.
partition = 200

In [2]:
import sys
from train import main
from itertools import product  
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Grid of hyperparameter values; every combination is trained once.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]

# Best accuracy seen so far and the configuration that produced it.
# best_config stays empty if no run reports an accuracy above 0.
best_score = 0
best_config = {}

# File the accuracy is read back from after each run. The path does not depend
# on the hyperparameters, so every run of this partition shares it.
# NOTE(review): presumably written by train.main() — confirm its exact format.
result_file = f"results/result_gtd{partition}"

for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
    print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")

    # main() is called without arguments; its options are supplied by
    # emulating a command line through sys.argv.
    sys.argv = [
        'train.py',
        '-dataset', f'gtd{partition}',
        '-n_class', '30',
        '-gpuid', '0',
        '-n_tree', str(n_tree),
        '-tree_depth', str(tree_depth),
        '-batch_size', str(batch_size),
        '-epochs', '200',
        '-jointly_training'
    ]

    # The returned values are not needed during the search. Calling main()
    # without unpacking into `_` also avoids shadowing IPython's last-output
    # variable in the notebook namespace.
    main()

    with open(result_file, "r") as f:
        accuracy_lines = [line for line in f if "Best Accuracy" in line]

    if not accuracy_lines:
        print(f"Warning: no 'Best Accuracy' line found in {result_file}; skipping this configuration")
        continue

    # Use only the last matching line. Because the file is shared by all runs,
    # earlier lines (from previous configurations or a previous session, if the
    # file is appended to) must not be credited to the current configuration.
    # The accuracy is taken as the third whitespace-separated token.
    acc = float(accuracy_lines[-1].split()[2])
    if acc > best_score:
        best_score = acc
        best_config = {
            'n_tree': n_tree,
            'tree_depth': tree_depth,
            'batch_size': batch_size
        }

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd200 dataset

Test set: Average loss: 3.3684, Accuracy: 178/1800 (0.0989)


Test set: Average loss: 3.3465, Accuracy: 189/1800 (0.1050)


Test set: Average loss: 3.3287, Accuracy: 199/1800 (0.1106)


Test set: Average loss: 3.3133, Accuracy: 208/1800 (0.1156)


Test set: Average loss: 3.2993, Accuracy: 229/1800 (0.1272)


Test set: Average loss: 3.2854, Accuracy: 239/1800 (0.1328)


Test set: Average loss: 3.2719, Accuracy: 262/1800 (0.1456)


Test set: Average loss: 3.2587, Accuracy: 296/1800 (0.1644)


Test set: Average loss: 3.2460, Accuracy: 312/1800 (0.1733)


Test set: Average loss: 3.2338, Accuracy: 336/1800 (0.1867)


Test set: Average loss: 3.2217, Accuracy: 342/1800 (0.1900)


Test set: Average loss: 3.2098, Accuracy: 359/1800 (0.1994)


Test set: Average loss: 3.1982, Accuracy: 372/1800 (0.2067)


Test set: Average loss: 3.1868, Accuracy: 382/1800 (0.2122)


Test set: Average loss: 3.1753, Accuracy: 388/1800 (0.2156)


T

In [3]:
# Retrain once with the best configuration found by the grid search and keep
# the values main() returns for the confusion-matrix plot.
tuned_options = [
    token
    for option in ('n_tree', 'tree_depth', 'batch_size')
    for token in (f'-{option}', str(best_config[option]))
]

# Same argument order as the search runs: fixed options, tuned options, then
# the training length and the joint-training flag.
sys.argv = (
    ['train.py', '-dataset', f'gtd{partition}', '-n_class', '30', '-gpuid', '0']
    + tuned_options
    + ['-epochs', '200', '-jointly_training']
)

preds, targets, labels = main()


Use gtd200 dataset

Test set: Average loss: 3.3463, Accuracy: 346/1800 (0.1922)


Test set: Average loss: 3.2934, Accuracy: 348/1800 (0.1933)


Test set: Average loss: 3.2525, Accuracy: 394/1800 (0.2189)


Test set: Average loss: 3.2200, Accuracy: 446/1800 (0.2478)


Test set: Average loss: 3.1921, Accuracy: 498/1800 (0.2767)


Test set: Average loss: 3.1671, Accuracy: 545/1800 (0.3028)


Test set: Average loss: 3.1443, Accuracy: 576/1800 (0.3200)


Test set: Average loss: 3.1231, Accuracy: 591/1800 (0.3283)


Test set: Average loss: 3.1034, Accuracy: 609/1800 (0.3383)


Test set: Average loss: 3.0848, Accuracy: 600/1800 (0.3333)


Test set: Average loss: 3.0664, Accuracy: 619/1800 (0.3439)


Test set: Average loss: 3.0488, Accuracy: 619/1800 (0.3439)


Test set: Average loss: 3.0304, Accuracy: 632/1800 (0.3511)


Test set: Average loss: 3.0145, Accuracy: 634/1800 (0.3522)


Test set: Average loss: 2.9974, Accuracy: 634/1800 (0.3522)


Test set: Average loss: 2.9827, Accuracy: 637/1800

In [4]:
def plot_confusion_matrix(y_true, y_pred, labels, partition):
    """Save a row-normalized confusion-matrix heatmap as a PNG.

    Args:
        y_true: True class indices, expected to lie in ``range(len(labels))``.
        y_pred: Predicted class indices, same encoding as ``y_true``.
        labels: Class names used as tick labels; its length fixes the number
            of classes shown.
        partition: Partition identifier, used in the title and file name.

    The figure is written to
    ``results/confusion_matrix_partition_gtd{partition}.png`` and then closed;
    nothing is returned.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    # Normalize each row by the number of true samples of that class. A class
    # with no true samples has a row total of 0; dividing by it would fill the
    # row with NaN, so such totals are replaced by 1 and the row stays all-zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    safe_totals = row_totals.copy()
    safe_totals[safe_totals == 0] = 1
    cm_normalized = cm.astype('float') / safe_totals

    plt.figure(figsize=(18, 16))
    sns.heatmap(cm_normalized,
                annot=True,
                fmt=".2f",
                xticklabels=labels,
                yticklabels=labels,
                cmap="viridis",
                square=True,
                linewidths=0.5,
                cbar_kws={"shrink": 0.8})

    plt.title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("True Label", fontsize=14)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()

    save_path = f"results/confusion_matrix_partition_gtd{partition}.png"
    plt.savefig(save_path, dpi=300)
    # Close the figure so it is neither rendered inline nor kept in memory.
    plt.close()

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [5]:
# Plot the normalized confusion matrix for the final run; targets, preds and
# labels are the values returned by main() in the retraining cell.
plot_confusion_matrix(targets, preds, labels, partition)

Saved confusion matrix for partition gtd200 to results/confusion_matrix_partition_gtd200.png
