In [1]:
# Dataset partition suffix: selects the `gtd<partition>` dataset passed to train.py and
# names the per-partition files under results/ (result log and confusion-matrix PNG).
partition = 300

In [2]:
import sys
import warnings
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

from train import main

# Hyperparameter grid: every combination is trained for 200 epochs and scored by the
# "Best Accuracy" value that train.main() logs to the per-partition result file.
# NOTE(review): train.py reports its evaluations as "Test set" (see output below); if
# "Best Accuracy" is a test-set score, selecting on it leaks the test set into model
# selection and the reported best accuracy is optimistic — consider a validation split.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]

# Single result file shared by every run of this partition; whether train.py appends
# to it or overwrites it is not visible here — TODO confirm in train.py.
result_file = f"results/result_gtd{partition}"


def _latest_best_accuracy(path):
    """Return the accuracy on the LAST "Best Accuracy" line of ``path``, or None.

    Only the most recent line is used, so lines left over from earlier runs
    (previous grid points or previous notebook sessions) are never credited to
    the configuration that was just trained.
    """
    latest = None
    with open(path, "r") as f:
        for line in f:
            if "Best Accuracy" in line:
                # Value is taken as the third whitespace-separated token.
                latest = float(line.split()[2])
    return latest


best_score = float("-inf")
best_config = {}
search_results = []  # one {'n_tree', 'tree_depth', 'batch_size', 'accuracy'} dict per run

# main() takes no arguments, so options are passed by faking a command line in
# sys.argv; remember the kernel's own argv so it can be restored afterwards.
_saved_argv = sys.argv
try:
    for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
        print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")
        sys.argv = [
            'train.py',
            '-dataset', f'gtd{partition}',
            '-n_class', '30',
            '-gpuid', '0',
            '-n_tree', str(n_tree),
            '-tree_depth', str(tree_depth),
            '-batch_size', str(batch_size),
            '-epochs', '200',
            '-jointly_training'
        ]
        main()  # returned (preds, targets, labels) are not needed during the search

        acc = _latest_best_accuracy(result_file)
        if acc is None:
            # Best effort: one unparseable run should not abort the whole search.
            warnings.warn(f"No 'Best Accuracy' line found in {result_file}; skipping this configuration.")
            continue

        config = {
            'n_tree': n_tree,
            'tree_depth': tree_depth,
            'batch_size': batch_size
        }
        search_results.append({**config, 'accuracy': acc})
        if acc > best_score:
            best_score = acc
            best_config = config
finally:
    sys.argv = _saved_argv

if not best_config:
    raise RuntimeError(f"Grid search produced no parseable 'Best Accuracy' results in {result_file}")

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd300 dataset

Test set: Average loss: 3.3160, Accuracy: 595/2700 (0.2204)


Test set: Average loss: 3.2545, Accuracy: 760/2700 (0.2815)


Test set: Average loss: 3.2069, Accuracy: 820/2700 (0.3037)


Test set: Average loss: 3.1649, Accuracy: 924/2700 (0.3422)


Test set: Average loss: 3.1263, Accuracy: 1013/2700 (0.3752)


Test set: Average loss: 3.0901, Accuracy: 1063/2700 (0.3937)


Test set: Average loss: 3.0552, Accuracy: 1085/2700 (0.4019)


Test set: Average loss: 3.0211, Accuracy: 1090/2700 (0.4037)


Test set: Average loss: 2.9879, Accuracy: 1112/2700 (0.4119)


Test set: Average loss: 2.9561, Accuracy: 1131/2700 (0.4189)


Test set: Average loss: 2.9251, Accuracy: 1184/2700 (0.4385)


Test set: Average loss: 2.8953, Accuracy: 1216/2700 (0.4504)


Test set: Average loss: 2.8667, Accuracy: 1277/2700 (0.4730)


Test set: Average loss: 2.8381, Accuracy: 1325/2700 (0.4907)


Test set: Average loss: 2.8102, Accuracy: 1375/2700 (

In [3]:
# Retrain the winning grid-search configuration from scratch to get predictions for the
# confusion matrix. NOTE(review): no random seed is set in this notebook, so this run's
# accuracy may differ from the one measured during the search.
if not best_config:
    raise RuntimeError("best_config is empty; run the grid-search cell successfully first.")

# main() takes no arguments, so options are passed by faking a command line in sys.argv;
# the kernel's own argv is restored afterwards.
_saved_argv = sys.argv
sys.argv = [
    'train.py',
    '-dataset', f'gtd{partition}',
    '-n_class', '30',
    '-gpuid', '0',
    '-n_tree', str(best_config['n_tree']),
    '-tree_depth', str(best_config['tree_depth']),
    '-batch_size', str(best_config['batch_size']),
    '-epochs', '200',
    '-jointly_training'
]
try:
    # preds/targets are presumably integer class indices and labels the class names
    # in index order — TODO confirm against train.main.
    preds, targets, labels = main()
finally:
    sys.argv = _saved_argv


Use gtd300 dataset

Test set: Average loss: 3.3219, Accuracy: 915/2700 (0.3389)


Test set: Average loss: 3.2341, Accuracy: 1026/2700 (0.3800)


Test set: Average loss: 3.1579, Accuracy: 1133/2700 (0.4196)


Test set: Average loss: 3.0983, Accuracy: 1253/2700 (0.4641)


Test set: Average loss: 3.0486, Accuracy: 1342/2700 (0.4970)


Test set: Average loss: 3.0050, Accuracy: 1492/2700 (0.5526)


Test set: Average loss: 2.9653, Accuracy: 1561/2700 (0.5781)


Test set: Average loss: 2.9281, Accuracy: 1685/2700 (0.6241)


Test set: Average loss: 2.8930, Accuracy: 1764/2700 (0.6533)


Test set: Average loss: 2.8589, Accuracy: 1819/2700 (0.6737)


Test set: Average loss: 2.8257, Accuracy: 1868/2700 (0.6919)


Test set: Average loss: 2.7932, Accuracy: 1916/2700 (0.7096)


Test set: Average loss: 2.7617, Accuracy: 1939/2700 (0.7181)


Test set: Average loss: 2.7304, Accuracy: 1968/2700 (0.7289)


Test set: Average loss: 2.7002, Accuracy: 1987/2700 (0.7359)


Test set: Average loss: 2.6708, Accu

In [4]:
def plot_confusion_matrix(y_true, y_pred, labels, partition, save_dir="results"):
    """Save a row-normalized confusion-matrix heatmap for one dataset partition.

    Args:
        y_true: True class indices; assumed to be integers in ``range(len(labels))``.
        y_pred: Predicted class indices, same encoding as ``y_true``.
        labels: Class display names in index order; used as the tick labels.
        partition: Partition identifier used in the title and output filename.
        save_dir: Directory the PNG is written to; created if it does not exist.
            Defaults to ``"results"``.

    Cell (i, j) is the fraction of true class ``i`` samples predicted as ``j``.
    Rows of classes with no true samples are all zeros rather than NaN.
    The figure is written to
    ``<save_dir>/confusion_matrix_partition_gtd<partition>.png`` and closed, not shown.
    """
    # normalize="true" divides each row by its sum and maps the 0/0 rows of absent
    # classes to 0; the former manual division yielded NaN plus a RuntimeWarning.
    cm_normalized = confusion_matrix(
        y_true, y_pred, labels=range(len(labels)), normalize="true"
    )

    fig, ax = plt.subplots(figsize=(18, 16))
    sns.heatmap(cm_normalized,
                annot=True,
                fmt=".2f",
                xticklabels=labels,
                yticklabels=labels,
                cmap="viridis",
                square=True,
                linewidths=0.5,
                cbar_kws={"shrink": 0.8},
                ax=ax)

    ax.set_title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
    ax.set_xlabel("Predicted Label", fontsize=14)
    ax.set_ylabel("True Label", fontsize=14)
    ax.tick_params(axis="x", labelrotation=90)
    ax.tick_params(axis="y", labelrotation=0)
    fig.tight_layout()

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_path = out_dir / f"confusion_matrix_partition_gtd{partition}.png"
    fig.savefig(save_path, dpi=300)
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [5]:
# Heatmap of the retrained best model's predictions; keyword arguments make the
# true-vs-predicted ordering explicit.
plot_confusion_matrix(
    y_true=targets,
    y_pred=preds,
    labels=labels,
    partition=partition,
)

Saved confusion matrix for partition gtd300 to results/confusion_matrix_partition_gtd300.png
