In [1]:
# Partition identifier: interpolated into the dataset name (`gtd{partition}`)
# passed to train.py and into the file names read from / written to `results/`.
partition = 100

In [2]:
import sys
from train import main
from itertools import product  
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Hyperparameter grid: every combination of the three lists below is trained once.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]

best_score = 0
best_config = {}

# The result path depends only on the partition, so it is built once up front.
result_file = f"results/result_gtd{partition}"

# NOTE(review): the training log prints "Test set" accuracy. If the
# "Best Accuracy" value parsed below is measured on that same split, this
# search tunes hyperparameters on test data -- confirm, and consider a
# held-out validation split for model selection.

for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
    print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")
    # Options are handed to train.main() through sys.argv
    # (presumably parsed there as command-line flags -- verify in train.py).
    sys.argv = [
        'train.py',
        '-dataset', f'gtd{partition}',
        '-n_class', '30',
        '-gpuid', '0',
        '-n_tree', str(n_tree),
        '-tree_depth', str(tree_depth),
        '-batch_size', str(batch_size),
        '-epochs', '200',
        '-jointly_training'
    ]

    # The return values are not needed during the search. They are deliberately
    # not bound to `_`, which would shadow IPython's last-output variable.
    main()

    # Collect every "Best Accuracy" value in the result file (the third
    # whitespace-separated token of such a line). Assumes the file holds the
    # results of the run that just finished.
    with open(result_file, "r") as f:
        run_scores = [float(line.split()[2]) for line in f if "Best Accuracy" in line]

    if not run_scores:
        print(f"Warning: no 'Best Accuracy' line found in {result_file} for this run")
        continue

    acc = max(run_scores)
    # `not best_config` lets the first parsed run be recorded even when its
    # accuracy is 0.0, so best_config is never left empty after a valid run.
    if not best_config or acc > best_score:
        best_score = acc
        best_config = {
            'n_tree': n_tree,
            'tree_depth': tree_depth,
            'batch_size': batch_size
        }

# Fail here with a clear message instead of letting a later cell hit a
# KeyError on an empty best_config.
if not best_config:
    raise RuntimeError(
        f"Grid search finished but no 'Best Accuracy' line was ever found in {result_file}"
    )

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd100 dataset

Test set: Average loss: 3.3963, Accuracy: 44/900 (0.0489)


Test set: Average loss: 3.3888, Accuracy: 71/900 (0.0789)


Test set: Average loss: 3.3782, Accuracy: 99/900 (0.1100)


Test set: Average loss: 3.3642, Accuracy: 116/900 (0.1289)


Test set: Average loss: 3.3497, Accuracy: 145/900 (0.1611)


Test set: Average loss: 3.3362, Accuracy: 150/900 (0.1667)


Test set: Average loss: 3.3235, Accuracy: 170/900 (0.1889)


Test set: Average loss: 3.3106, Accuracy: 190/900 (0.2111)


Test set: Average loss: 3.2977, Accuracy: 212/900 (0.2356)


Test set: Average loss: 3.2850, Accuracy: 235/900 (0.2611)


Test set: Average loss: 3.2723, Accuracy: 261/900 (0.2900)


Test set: Average loss: 3.2602, Accuracy: 280/900 (0.3111)


Test set: Average loss: 3.2485, Accuracy: 286/900 (0.3178)


Test set: Average loss: 3.2357, Accuracy: 308/900 (0.3422)


Test set: Average loss: 3.2240, Accuracy: 314/900 (0.3489)


Test set: Average l

In [3]:
# Retrain once with the configuration selected by the grid search and keep
# main()'s three return values for the confusion-matrix cell.
tuned_flags = []
for key in ('n_tree', 'tree_depth', 'batch_size'):
    # Each tuned hyperparameter becomes a "-<name> <value>" pair, in this order.
    tuned_flags.extend([f'-{key}', str(best_config[key])])

sys.argv = [
    'train.py',
    '-dataset', f'gtd{partition}',
    '-n_class', '30',
    '-gpuid', '0',
    *tuned_flags,
    '-epochs', '200',
    '-jointly_training',
]

preds, targets, labels = main()


Use gtd100 dataset

Test set: Average loss: 3.3996, Accuracy: 82/900 (0.0911)


Test set: Average loss: 3.3962, Accuracy: 137/900 (0.1522)


Test set: Average loss: 3.3883, Accuracy: 165/900 (0.1833)


Test set: Average loss: 3.3728, Accuracy: 162/900 (0.1800)


Test set: Average loss: 3.3486, Accuracy: 126/900 (0.1400)


Test set: Average loss: 3.3229, Accuracy: 122/900 (0.1356)


Test set: Average loss: 3.2984, Accuracy: 142/900 (0.1578)


Test set: Average loss: 3.2750, Accuracy: 165/900 (0.1833)


Test set: Average loss: 3.2511, Accuracy: 184/900 (0.2044)


Test set: Average loss: 3.2276, Accuracy: 190/900 (0.2111)


Test set: Average loss: 3.2052, Accuracy: 209/900 (0.2322)


Test set: Average loss: 3.1857, Accuracy: 224/900 (0.2489)


Test set: Average loss: 3.1674, Accuracy: 246/900 (0.2733)


Test set: Average loss: 3.1507, Accuracy: 261/900 (0.2900)


Test set: Average loss: 3.1349, Accuracy: 277/900 (0.3078)


Test set: Average loss: 3.1198, Accuracy: 281/900 (0.3122)


Test 

In [4]:
def plot_confusion_matrix(y_true, y_pred, labels, partition):
    """Save a row-normalized confusion-matrix heatmap for one partition.

    Args:
        y_true: Ground-truth classes. Assumed to be integer indices in
            ``range(len(labels))`` -- TODO confirm against ``train.main``.
        y_pred: Predicted classes, in the same encoding as ``y_true``.
        labels: Display names for the classes; their count fixes the matrix
            size and they are used as tick labels on both axes.
        partition: Partition identifier used in the title and the file name.

    Side effects:
        Writes ``results/confusion_matrix_partition_gtd{partition}.png`` and
        prints the path it was saved to.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(len(labels)))

    # Normalize each row by its number of true samples. A class that never
    # appears in y_true has a zero row sum; dividing by 1 instead keeps that
    # row at 0.0 rather than producing NaN and a divide-by-zero warning.
    row_totals = cm.sum(axis=1, keepdims=True)
    row_totals[row_totals == 0] = 1
    cm_normalized = cm.astype('float') / row_totals

    # Explicit figure/axes so nothing depends on pyplot's "current figure".
    fig, ax = plt.subplots(figsize=(18, 16))
    try:
        sns.heatmap(cm_normalized,
                    annot=True,
                    fmt=".2f",
                    xticklabels=labels,
                    yticklabels=labels,
                    cmap="viridis",
                    square=True,
                    linewidths=0.5,
                    cbar_kws={"shrink": 0.8},
                    ax=ax)

        ax.set_title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
        ax.set_xlabel("Predicted Label", fontsize=14)
        ax.set_ylabel("True Label", fontsize=14)
        ax.tick_params(axis="x", labelrotation=90)
        ax.tick_params(axis="y", labelrotation=0)
        fig.tight_layout()

        save_path = f"results/confusion_matrix_partition_gtd{partition}.png"
        fig.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [5]:
# Render and save the confusion matrix for the retrained model; keyword
# arguments make the true/predicted order explicit.
plot_confusion_matrix(
    y_true=targets,
    y_pred=preds,
    labels=labels,
    partition=partition,
)

Saved confusion matrix for partition gtd100 to results/confusion_matrix_partition_gtd100.png
