In [1]:
# GTD partition id: used to build the dataset name (`gtd{partition}`) passed to
# train.py and the names of the result / confusion-matrix files under results/.
partition = 478

In [2]:
import sys
from train import main
from itertools import product  
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Hyperparameter grid: every combination of these values is trained once.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]

best_score = 0
best_config = {}

# The results file name depends only on the partition, so compute it once.
result_file = f"results/result_gtd{partition}"


def _read_latest_best_accuracy(path):
    """Return the accuracy from the last "Best Accuracy" line in ``path``.

    Only the last matching line is used, so that lines left over from earlier
    runs (a stale file, or a file that is appended to) are never credited to
    the configuration that has just been trained.

    Returns None when the file contains no "Best Accuracy" line.
    Raises whatever ``open`` raises if the file is missing, and
    ValueError/IndexError if the matching line is not of the form
    ``Best Accuracy: <float>`` (third whitespace-separated token is the value).
    """
    latest_line = None
    with open(path, "r") as f:
        for line in f:
            if "Best Accuracy" in line:
                latest_line = line
    if latest_line is None:
        return None
    return float(latest_line.split()[2])


# NOTE(review): the training log prints "Test set: ... Accuracy", so the score
# used for selection looks like test-set accuracy. Selecting hyperparameters on
# it would make the final reported accuracy optimistic -- confirm against
# train.py and consider a validation split.

# train.main() takes its options from sys.argv, so the kernel's own argv is
# saved here and put back once the search is over (even if a run fails).
original_argv = sys.argv
try:
    for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
        print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")
        sys.argv = [
            'train.py',
            '-dataset', f'gtd{partition}',
            '-n_class', '30',
            '-gpuid', '0',
            '-n_tree', str(n_tree),
            '-tree_depth', str(tree_depth),
            '-batch_size', str(batch_size),
            '-epochs', '200',
            '-jointly_training'
        ]

        # The returned values are not needed during the search; the score is
        # read back from the results file instead.
        main()

        acc = _read_latest_best_accuracy(result_file)
        if acc is None:
            print(f"Warning: no 'Best Accuracy' line found in {result_file}; skipping this configuration")
            continue

        if acc > best_score:
            best_score = acc
            best_config = {
                'n_tree': n_tree,
                'tree_depth': tree_depth,
                'batch_size': batch_size
            }
finally:
    sys.argv = original_argv

# Fail here with a clear message rather than with a KeyError in a later cell
# that indexes best_config.
if not best_config:
    raise RuntimeError(
        f"No configuration produced a 'Best Accuracy' above 0 in {result_file}; "
        "cannot select hyperparameters."
    )

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd478 dataset

Test set: Average loss: 3.2818, Accuracy: 1009/4320 (0.2336)


Test set: Average loss: 3.2047, Accuracy: 1225/4320 (0.2836)


Test set: Average loss: 3.1387, Accuracy: 1302/4320 (0.3014)


Test set: Average loss: 3.0770, Accuracy: 1395/4320 (0.3229)


Test set: Average loss: 3.0187, Accuracy: 1476/4320 (0.3417)


Test set: Average loss: 2.9621, Accuracy: 1748/4320 (0.4046)


Test set: Average loss: 2.9073, Accuracy: 1814/4320 (0.4199)


Test set: Average loss: 2.8539, Accuracy: 1873/4320 (0.4336)


Test set: Average loss: 2.8031, Accuracy: 2098/4320 (0.4856)


Test set: Average loss: 2.7541, Accuracy: 2110/4320 (0.4884)


Test set: Average loss: 2.7072, Accuracy: 2126/4320 (0.4921)


Test set: Average loss: 2.6618, Accuracy: 2150/4320 (0.4977)


Test set: Average loss: 2.6193, Accuracy: 2149/4320 (0.4975)


Test set: Average loss: 2.5785, Accuracy: 2159/4320 (0.4998)


Test set: Average loss: 2.5392, Accuracy: 2165/43

In [3]:
# Retrain once with the winning hyperparameters. train.main() takes its options
# from sys.argv, so the command line is rebuilt here: fixed options first, then
# the three tuned flags pulled from best_config in the order train.py was given
# them during the search, then the remaining fixed options.
sys.argv = [
    'train.py',
    '-dataset', f'gtd{partition}',
    '-n_class', '30',
    '-gpuid', '0',
    *(
        token
        for key in ('n_tree', 'tree_depth', 'batch_size')
        for token in (f'-{key}', str(best_config[key]))
    ),
    '-epochs', '200',
    '-jointly_training',
]

# Keep the outputs of this run; they feed the confusion-matrix plot below.
preds, targets, labels = main()


Use gtd478 dataset

Test set: Average loss: 3.1671, Accuracy: 1791/4320 (0.4146)


Test set: Average loss: 3.0192, Accuracy: 2364/4320 (0.5472)


Test set: Average loss: 2.9108, Accuracy: 2612/4320 (0.6046)


Test set: Average loss: 2.8152, Accuracy: 2858/4320 (0.6616)


Test set: Average loss: 2.7254, Accuracy: 3064/4320 (0.7093)


Test set: Average loss: 2.6414, Accuracy: 3135/4320 (0.7257)


Test set: Average loss: 2.5618, Accuracy: 3221/4320 (0.7456)


Test set: Average loss: 2.4854, Accuracy: 3244/4320 (0.7509)


Test set: Average loss: 2.4135, Accuracy: 3325/4320 (0.7697)


Test set: Average loss: 2.3455, Accuracy: 3326/4320 (0.7699)


Test set: Average loss: 2.2805, Accuracy: 3325/4320 (0.7697)


Test set: Average loss: 2.2183, Accuracy: 3352/4320 (0.7759)


Test set: Average loss: 2.1573, Accuracy: 3359/4320 (0.7775)


Test set: Average loss: 2.1002, Accuracy: 3363/4320 (0.7785)


Test set: Average loss: 2.0463, Accuracy: 3383/4320 (0.7831)


Test set: Average loss: 1.9949, Acc

In [4]:
def plot_confusion_matrix(y_true, y_pred, labels, partition):
    """Save a row-normalized confusion-matrix heatmap for one partition.

    Parameters
    ----------
    y_true, y_pred :
        True and predicted class ids. The matrix is built with
        ``labels=range(len(labels))``, so class ``i`` is expected to be encoded
        as the integer ``i``.
    labels :
        Sequence of class names; its length fixes the matrix size and its
        entries are used as the tick labels on both axes.
    partition :
        Partition id, used in the figure title and in the output file name.

    Each row is divided by its total so that a cell shows the fraction of the
    true class predicted as each label. Rows for classes that never occur in
    ``y_true`` are left as all zeros instead of becoming NaN.

    The figure is written to
    ``results/confusion_matrix_partition_gtd{partition}.png`` (the directory is
    created if missing) and then closed; nothing is returned.
    """
    import os

    cm = confusion_matrix(y_true, y_pred, labels=range(len(labels)))

    # A class with no true samples has a row total of 0; dividing by it would
    # yield NaN cells and a RuntimeWarning. Dividing by 1 keeps the row at 0.
    row_totals = cm.sum(axis=1, keepdims=True)
    row_totals[row_totals == 0] = 1
    cm_normalized = cm.astype('float') / row_totals

    plt.figure(figsize=(18, 16))
    sns.heatmap(cm_normalized,
                annot=True,
                fmt=".2f",
                xticklabels=labels,
                yticklabels=labels,
                cmap="viridis",
                square=True,
                linewidths=0.5,
                cbar_kws={"shrink": 0.8})

    plt.title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("True Label", fontsize=14)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()

    save_path = f"results/confusion_matrix_partition_gtd{partition}.png"
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300)
    # Close the figure so it is neither shown inline nor kept in memory.
    plt.close()

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [5]:
# Plot the run made with the best configuration: `targets` is passed as y_true
# and `preds` as y_pred; the figure is saved under results/.
plot_confusion_matrix(targets, preds, labels, partition)

Saved confusion matrix for partition gtd478 to results/confusion_matrix_partition_gtd478.png
