In [1]:
# GTD partition suffix: every later cell builds names like f"gtd{partition}"
# (dataset flag, results file, figure filename) from this value.
partition: int = 300

In [2]:
import sys
from train import main
from itertools import product  
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


In [3]:

# Grid search over forest size, tree depth and batch size. train.main() is
# driven through sys.argv, and each run's score is read back from the results
# file it writes.
# NOTE(review): if "Best Accuracy" is measured on the test split, choosing
# hyperparameters by it leaks test information — consider a validation split.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]

result_file = f"results/result_gtd{partition}"


def _latest_best_accuracy(path):
    """Return the accuracy from the LAST "Best Accuracy" line in ``path``.

    The value is taken as the third whitespace-separated token of the line
    (presumably ``Best Accuracy: <float>`` — confirm against train.py).
    Only the last matching line is used. If train.py appends to this file
    rather than overwriting it (not visible from this notebook), older lines
    from previous runs are therefore never credited to the configuration that
    just finished.

    Returns None when the file contains no matching line.
    """
    latest = None
    with open(path, "r") as f:
        for line in f:
            if "Best Accuracy" in line:
                latest = float(line.split()[2])
    return latest


# Start at -inf rather than 0 so any parsed score selects a configuration and
# best_config is never left empty once at least one run reported a score.
best_score = float("-inf")
best_config = {}

saved_argv = sys.argv  # restored afterwards so the kernel's argv isn't left clobbered
try:
    for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
        print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")
        sys.argv = [
            'train.py',
            '-dataset', f'gtd{partition}',
            '-n_class', '30',
            '-gpuid', '0',
            '-n_tree', str(n_tree),
            '-tree_depth', str(tree_depth),
            '-batch_size', str(batch_size),
            '-epochs', '200',
            '-verbose', '0',
            '-jointly_training'
        ]
        # Return values are unused during the search. Not binding them to `_`
        # also avoids shadowing IPython's last-output variable.
        main()

        acc = _latest_best_accuracy(result_file)
        if acc is None:
            print(f"No 'Best Accuracy' line found in {result_file}; skipping this configuration")
            continue
        if acc > best_score:
            best_score = acc
            best_config = {
                'n_tree': n_tree,
                'tree_depth': tree_depth,
                'batch_size': batch_size
            }
finally:
    sys.argv = saved_argv

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd300 dataset


Training Epochs:  32%|███▎      | 65/200 [01:16<02:39,  1.18s/it]

Early stopping at epoch 66

Running: n_tree=5, tree_depth=3, batch_size=512
Use gtd300 dataset



Training Epochs:  48%|████▊     | 97/200 [01:02<01:06,  1.54it/s]

Early stopping at epoch 98

Running: n_tree=5, tree_depth=3, batch_size=1000
Use gtd300 dataset



Training Epochs:  18%|█▊        | 37/200 [00:16<01:12,  2.24it/s]

Early stopping at epoch 38






Running: n_tree=5, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  34%|███▎      | 67/200 [01:27<02:53,  1.31s/it]

Early stopping at epoch 68

Running: n_tree=5, tree_depth=5, batch_size=512
Use gtd300 dataset



Training Epochs:  43%|████▎     | 86/200 [01:09<01:31,  1.25it/s]

Early stopping at epoch 87






Running: n_tree=5, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  58%|█████▊    | 117/200 [01:06<00:47,  1.75it/s]

Early stopping at epoch 118






Running: n_tree=5, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  46%|████▌     | 91/200 [03:02<03:38,  2.00s/it]

Early stopping at epoch 92

Running: n_tree=5, tree_depth=10, batch_size=512
Use gtd300 dataset



Training Epochs:  26%|██▌       | 51/200 [00:55<02:42,  1.09s/it]

Early stopping at epoch 52

Running: n_tree=5, tree_depth=10, batch_size=1000
Use gtd300 dataset



Training Epochs:  74%|███████▍  | 149/200 [01:49<00:37,  1.36it/s]

Early stopping at epoch 150

Running: n_tree=10, tree_depth=3, batch_size=256
Use gtd300 dataset



Training Epochs:  26%|██▌       | 52/200 [01:52<05:19,  2.16s/it]

Early stopping at epoch 53






Running: n_tree=10, tree_depth=3, batch_size=512
Use gtd300 dataset


Training Epochs:  42%|████▏     | 83/200 [01:42<02:23,  1.23s/it]

Early stopping at epoch 84






Running: n_tree=10, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  88%|████████▊ | 176/200 [02:07<00:17,  1.38it/s]

Early stopping at epoch 177






Running: n_tree=10, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  38%|███▊      | 75/200 [02:56<04:54,  2.36s/it]

Early stopping at epoch 76






Running: n_tree=10, tree_depth=5, batch_size=512
Use gtd300 dataset


Training Epochs:  46%|████▌     | 91/200 [02:01<02:25,  1.34s/it]

Early stopping at epoch 92






Running: n_tree=10, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  42%|████▏     | 84/200 [01:04<01:28,  1.30it/s]

Early stopping at epoch 85






Running: n_tree=10, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  42%|████▏     | 83/200 [03:49<05:23,  2.77s/it]

Early stopping at epoch 84

Running: n_tree=10, tree_depth=10, batch_size=512
Use gtd300 dataset



Training Epochs:  57%|█████▋    | 114/200 [03:10<02:23,  1.67s/it]

Early stopping at epoch 115

Running: n_tree=10, tree_depth=10, batch_size=1000
Use gtd300 dataset



Training Epochs:  64%|██████▍   | 128/200 [02:13<01:15,  1.04s/it]

Early stopping at epoch 129

Running: n_tree=15, tree_depth=3, batch_size=256
Use gtd300 dataset



Training Epochs:  28%|██▊       | 57/200 [02:06<05:16,  2.21s/it]

Early stopping at epoch 58






Running: n_tree=15, tree_depth=3, batch_size=512
Use gtd300 dataset


Training Epochs:  66%|██████▋   | 133/200 [02:29<01:15,  1.12s/it]

Early stopping at epoch 134






Running: n_tree=15, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  53%|█████▎    | 106/200 [01:15<01:07,  1.40it/s]

Early stopping at epoch 107






Running: n_tree=15, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  22%|██▏       | 43/200 [01:44<06:22,  2.44s/it]

Early stopping at epoch 44






Running: n_tree=15, tree_depth=5, batch_size=512
Use gtd300 dataset


Training Epochs:  28%|██▊       | 55/200 [01:17<03:24,  1.41s/it]

Early stopping at epoch 56






Running: n_tree=15, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  44%|████▎     | 87/200 [01:08<01:28,  1.27it/s]

Early stopping at epoch 88






Running: n_tree=15, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  39%|███▉      | 78/200 [03:44<05:50,  2.87s/it]

Early stopping at epoch 79

Running: n_tree=15, tree_depth=10, batch_size=512
Use gtd300 dataset



Training Epochs:  36%|███▋      | 73/200 [02:18<04:00,  1.89s/it]

Early stopping at epoch 74






Running: n_tree=15, tree_depth=10, batch_size=1000
Use gtd300 dataset


Training Epochs:  32%|███▎      | 65/200 [01:27<03:00,  1.34s/it]

Early stopping at epoch 66

Running: n_tree=20, tree_depth=3, batch_size=256
Use gtd300 dataset



Training Epochs:  56%|█████▌    | 112/200 [05:17<04:09,  2.83s/it]

Early stopping at epoch 113






Running: n_tree=20, tree_depth=3, batch_size=512
Use gtd300 dataset


Training Epochs:  48%|████▊     | 95/200 [02:42<02:59,  1.71s/it]

Early stopping at epoch 96






Running: n_tree=20, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  68%|██████▊   | 137/200 [02:04<00:57,  1.10it/s]

Early stopping at epoch 138






Running: n_tree=20, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  36%|███▌      | 71/200 [03:12<05:50,  2.72s/it]

Early stopping at epoch 72






Running: n_tree=20, tree_depth=5, batch_size=512
Use gtd300 dataset


Training Epochs:  72%|███████▏  | 143/200 [03:45<01:29,  1.58s/it]

Early stopping at epoch 144






Running: n_tree=20, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  48%|████▊     | 97/200 [01:32<01:38,  1.05it/s]

Early stopping at epoch 98






Running: n_tree=20, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  45%|████▌     | 90/200 [05:34<06:49,  3.72s/it]

Early stopping at epoch 91






Running: n_tree=20, tree_depth=10, batch_size=512
Use gtd300 dataset


Training Epochs:  48%|████▊     | 95/200 [03:29<03:51,  2.21s/it]

Early stopping at epoch 96






Running: n_tree=20, tree_depth=10, batch_size=1000
Use gtd300 dataset


Training Epochs:  52%|█████▎    | 105/200 [02:34<02:19,  1.47s/it]

Early stopping at epoch 106






Best hyperparameter configuration:
{'n_tree': 10, 'tree_depth': 10, 'batch_size': 1000}
Best accuracy: 0.452593


In [4]:
# Retrain from scratch with the best grid-search configuration and a larger
# epoch budget; train.py's early stopping decides when training actually ends.
if not best_config:
    raise RuntimeError(
        "best_config is empty: the grid search recorded no score, so there is "
        "no configuration to retrain with."
    )

FINAL_EPOCHS = 1000

saved_argv = sys.argv  # restored after training so the kernel's argv isn't left clobbered
sys.argv = [
    'train.py',
    '-dataset', f'gtd{partition}',
    '-n_class', '30',
    '-gpuid', '0',
    '-n_tree', str(best_config['n_tree']),
    '-tree_depth', str(best_config['tree_depth']),
    '-batch_size', str(best_config['batch_size']),
    '-epochs', str(FINAL_EPOCHS),
    '-verbose', '1',
    '-jointly_training'
]
try:
    best_model, preds, targets, labels, epoch_logs = main()
finally:
    sys.argv = saved_argv


Use gtd300 dataset


Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]



Training Epochs:   0%|          | 1/1000 [00:00<15:21,  1.08it/s]



Training Epochs:   0%|          | 2/1000 [00:01<15:15,  1.09it/s]



Training Epochs:   0%|          | 3/1000 [00:02<14:42,  1.13it/s]



Training Epochs:   0%|          | 4/1000 [00:03<14:17,  1.16it/s]



Training Epochs:   0%|          | 5/1000 [00:04<15:05,  1.10it/s]



Training Epochs:   1%|          | 6/1000 [00:05<17:24,  1.05s/it]



Training Epochs:   1%|          | 7/1000 [00:06<16:16,  1.02it/s]



Training Epochs:   1%|          | 8/1000 [00:07<15:16,  1.08it/s]



Training Epochs:   1%|          | 9/1000 [00:08<16:06,  1.03it/s]



Training Epochs:   1%|          | 10/1000 [00:09<16:38,  1.01s/it]



Training Epochs:   1%|          | 11/1000 [00:10<16:21,  1.01it/s]



Training Epochs:   1%|          | 12/1000 [00:11<15:55,  1.03it/s]



Training Epochs:   1%|▏         | 13/1000 [00:12<16:27,  1.00s/it]



Training Epochs:   1%|▏         | 14/1000 [00:13<16:02,  1.02it/s]



Training Epochs:   2%|▏         | 15/1000 [00:14<15:18,  1.07it/s]



Training Epochs:   2%|▏         | 16/1000 [00:15<15:32,  1.06it/s]



Training Epochs:   2%|▏         | 17/1000 [00:16<15:08,  1.08it/s]



Training Epochs:   2%|▏         | 18/1000 [00:17<15:49,  1.03it/s]



Training Epochs:   2%|▏         | 19/1000 [00:18<15:31,  1.05it/s]



Training Epochs:   2%|▏         | 20/1000 [00:19<15:56,  1.02it/s]



Training Epochs:   2%|▏         | 21/1000 [00:20<15:32,  1.05it/s]



Training Epochs:   2%|▏         | 22/1000 [00:21<15:33,  1.05it/s]



Training Epochs:   2%|▏         | 23/1000 [00:21<14:48,  1.10it/s]



Training Epochs:   2%|▏         | 24/1000 [00:22<14:19,  1.14it/s]



Training Epochs:   2%|▎         | 25/1000 [00:23<14:22,  1.13it/s]



Training Epochs:   3%|▎         | 26/1000 [00:24<15:14,  1.06it/s]



Training Epochs:   3%|▎         | 27/1000 [00:25<15:14,  1.06it/s]



Training Epochs:   3%|▎         | 28/1000 [00:26<16:04,  1.01it/s]



Training Epochs:   3%|▎         | 29/1000 [00:27<14:53,  1.09it/s]



Training Epochs:   3%|▎         | 30/1000 [00:28<14:24,  1.12it/s]



Training Epochs:   3%|▎         | 31/1000 [00:28<13:33,  1.19it/s]



Training Epochs:   3%|▎         | 32/1000 [00:29<13:40,  1.18it/s]



Training Epochs:   3%|▎         | 33/1000 [00:30<13:08,  1.23it/s]



Training Epochs:   3%|▎         | 34/1000 [00:31<13:30,  1.19it/s]



Training Epochs:   4%|▎         | 35/1000 [00:32<13:44,  1.17it/s]



Training Epochs:   4%|▎         | 36/1000 [00:33<13:38,  1.18it/s]



Training Epochs:   4%|▎         | 37/1000 [00:33<13:04,  1.23it/s]



Training Epochs:   4%|▍         | 38/1000 [00:34<13:32,  1.18it/s]



Training Epochs:   4%|▍         | 39/1000 [00:35<13:50,  1.16it/s]



Training Epochs:   4%|▍         | 40/1000 [00:36<12:38,  1.27it/s]



Training Epochs:   4%|▍         | 41/1000 [00:37<12:47,  1.25it/s]



Training Epochs:   4%|▍         | 42/1000 [00:38<13:34,  1.18it/s]



Training Epochs:   4%|▍         | 43/1000 [00:38<12:43,  1.25it/s]



Training Epochs:   4%|▍         | 44/1000 [00:39<11:59,  1.33it/s]



Training Epochs:   4%|▍         | 45/1000 [00:40<12:50,  1.24it/s]



Training Epochs:   5%|▍         | 46/1000 [00:41<13:20,  1.19it/s]



Training Epochs:   5%|▍         | 47/1000 [00:41<12:15,  1.30it/s]



Training Epochs:   5%|▍         | 48/1000 [00:42<12:18,  1.29it/s]



Training Epochs:   5%|▍         | 49/1000 [00:43<12:34,  1.26it/s]



Training Epochs:   5%|▌         | 50/1000 [00:44<11:55,  1.33it/s]



Training Epochs:   5%|▌         | 51/1000 [00:45<12:32,  1.26it/s]



Training Epochs:   5%|▌         | 52/1000 [00:45<12:15,  1.29it/s]



Training Epochs:   5%|▌         | 53/1000 [00:46<11:50,  1.33it/s]



Training Epochs:   5%|▌         | 54/1000 [00:47<12:51,  1.23it/s]



Training Epochs:   6%|▌         | 55/1000 [00:48<12:50,  1.23it/s]



Training Epochs:   6%|▌         | 56/1000 [00:49<12:37,  1.25it/s]



Training Epochs:   6%|▌         | 57/1000 [00:49<11:55,  1.32it/s]



Training Epochs:   6%|▌         | 58/1000 [00:50<12:53,  1.22it/s]



Training Epochs:   6%|▌         | 59/1000 [00:51<13:45,  1.14it/s]



Training Epochs:   6%|▌         | 60/1000 [00:52<13:56,  1.12it/s]



Training Epochs:   6%|▌         | 61/1000 [00:53<13:36,  1.15it/s]



Training Epochs:   6%|▌         | 62/1000 [00:54<15:09,  1.03it/s]



Training Epochs:   6%|▋         | 63/1000 [00:55<14:42,  1.06it/s]



Training Epochs:   6%|▋         | 64/1000 [00:56<13:58,  1.12it/s]



Training Epochs:   6%|▋         | 65/1000 [00:57<14:11,  1.10it/s]



Training Epochs:   7%|▋         | 66/1000 [00:58<13:43,  1.13it/s]



Training Epochs:   7%|▋         | 67/1000 [00:58<12:30,  1.24it/s]



Training Epochs:   7%|▋         | 68/1000 [00:59<12:29,  1.24it/s]



Training Epochs:   7%|▋         | 69/1000 [01:00<12:50,  1.21it/s]



Training Epochs:   7%|▋         | 70/1000 [01:01<13:21,  1.16it/s]



Training Epochs:   7%|▋         | 71/1000 [01:02<13:17,  1.17it/s]



Training Epochs:   7%|▋         | 72/1000 [01:03<14:02,  1.10it/s]



Training Epochs:   7%|▋         | 73/1000 [01:03<12:53,  1.20it/s]



Training Epochs:   7%|▋         | 74/1000 [01:04<12:58,  1.19it/s]



Training Epochs:   8%|▊         | 75/1000 [01:05<13:14,  1.16it/s]



Training Epochs:   8%|▊         | 76/1000 [01:06<13:15,  1.16it/s]



Training Epochs:   8%|▊         | 77/1000 [01:07<12:21,  1.24it/s]



Training Epochs:   8%|▊         | 78/1000 [01:08<13:11,  1.17it/s]



Training Epochs:   8%|▊         | 79/1000 [01:09<13:11,  1.16it/s]



Training Epochs:   8%|▊         | 80/1000 [01:09<12:03,  1.27it/s]



Training Epochs:   8%|▊         | 81/1000 [01:10<12:12,  1.26it/s]



Training Epochs:   8%|▊         | 82/1000 [01:11<12:08,  1.26it/s]



Training Epochs:   8%|▊         | 83/1000 [01:12<12:20,  1.24it/s]



Training Epochs:   8%|▊         | 84/1000 [01:12<12:08,  1.26it/s]



Training Epochs:   8%|▊         | 85/1000 [01:13<11:45,  1.30it/s]



Training Epochs:   9%|▊         | 86/1000 [01:14<11:38,  1.31it/s]



Training Epochs:   9%|▊         | 87/1000 [01:15<12:22,  1.23it/s]



Training Epochs:   9%|▉         | 88/1000 [01:16<13:16,  1.14it/s]



Training Epochs:   9%|▉         | 89/1000 [01:17<12:42,  1.20it/s]



Training Epochs:   9%|▉         | 90/1000 [01:18<13:20,  1.14it/s]



Training Epochs:   9%|▉         | 91/1000 [01:18<13:39,  1.11it/s]



Training Epochs:   9%|▉         | 92/1000 [01:19<12:29,  1.21it/s]



Training Epochs:   9%|▉         | 93/1000 [01:20<12:18,  1.23it/s]



Training Epochs:   9%|▉         | 94/1000 [01:21<13:16,  1.14it/s]



Training Epochs:  10%|▉         | 95/1000 [01:22<12:40,  1.19it/s]



Training Epochs:  10%|▉         | 96/1000 [01:22<12:01,  1.25it/s]



Training Epochs:  10%|▉         | 97/1000 [01:23<12:00,  1.25it/s]



Training Epochs:  10%|▉         | 98/1000 [01:24<11:58,  1.26it/s]



Training Epochs:  10%|▉         | 99/1000 [01:25<12:39,  1.19it/s]



Training Epochs:  10%|█         | 100/1000 [01:26<12:48,  1.17it/s]



Training Epochs:  10%|█         | 101/1000 [01:27<12:43,  1.18it/s]



Training Epochs:  10%|█         | 102/1000 [01:27<11:43,  1.28it/s]



Training Epochs:  10%|█         | 103/1000 [01:28<11:28,  1.30it/s]



Training Epochs:  10%|█         | 104/1000 [01:29<12:05,  1.24it/s]



Training Epochs:  10%|█         | 105/1000 [01:30<12:34,  1.19it/s]



Training Epochs:  11%|█         | 106/1000 [01:31<12:23,  1.20it/s]



Training Epochs:  11%|█         | 107/1000 [01:31<11:36,  1.28it/s]



Training Epochs:  11%|█         | 108/1000 [01:32<11:47,  1.26it/s]



Training Epochs:  11%|█         | 109/1000 [01:33<11:43,  1.27it/s]



Training Epochs:  11%|█         | 110/1000 [01:34<11:35,  1.28it/s]



Training Epochs:  11%|█         | 111/1000 [01:34<10:47,  1.37it/s]



Training Epochs:  11%|█         | 112/1000 [01:35<11:00,  1.35it/s]



Training Epochs:  11%|█▏        | 113/1000 [01:36<11:58,  1.24it/s]



Training Epochs:  11%|█▏        | 114/1000 [01:37<11:11,  1.32it/s]



Training Epochs:  12%|█▏        | 115/1000 [01:37<10:03,  1.47it/s]



Training Epochs:  12%|█▏        | 116/1000 [01:38<10:23,  1.42it/s]



Training Epochs:  12%|█▏        | 117/1000 [01:39<11:58,  1.23it/s]



Training Epochs:  12%|█▏        | 117/1000 [01:40<12:36,  1.17it/s]

Early stopping at epoch 118





In [5]:
# Per-class precision / recall / F1 for the predictions returned by the final
# training run.
from sklearn.metrics import classification_report

report_text = classification_report(targets, preds)
print(report_text)

              precision    recall  f1-score   support

           0       0.25      0.51      0.33        90
           1       0.35      0.87      0.49        90
           2       0.34      0.49      0.40        90
           3       0.32      0.32      0.32        90
           4       0.26      0.12      0.17        90
           5       0.46      0.74      0.57        90
           6       0.37      0.17      0.23        90
           7       0.47      0.58      0.52        90
           8       0.53      0.83      0.65        90
           9       0.38      0.48      0.43        90
          10       0.40      0.74      0.52        90
          11       0.54      0.78      0.64        90
          12       0.44      0.20      0.27        90
          13       0.52      0.81      0.63        90
          14       0.32      0.27      0.29        90
          15       0.30      0.08      0.12        90
          16       0.51      0.43      0.47        90
          17       0.62    

In [6]:
def plot_confusion_matrix(y_true, y_pred, labels, partition, out_dir="results"):
    """Save a row-normalized confusion-matrix heatmap for one GTD partition.

    Parameters
    ----------
    y_true, y_pred : sequence
        True and predicted class indices. They are matched against
        ``range(len(labels))``, so they are assumed to be integer-encoded
        0..len(labels)-1 — TODO confirm against what ``train.main`` returns.
    labels : sequence
        Tick labels for the classes, in class-index order.
    partition : int or str
        Partition identifier used in the title and the output filename.
    out_dir : str, optional
        Directory the PNG is written to (default ``"results"``). It is
        created if it does not exist.

    The figure is saved to disk and closed; it is not displayed inline.
    """
    import os

    cm = confusion_matrix(y_true, y_pred, labels=range(len(labels)))

    # Row-normalize so each cell is the fraction of that true class's samples.
    # A class with no true samples would divide by zero and render as NaN;
    # dividing its (all-zero) row by 1 keeps it at 0 instead.
    row_sums = cm.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1
    cm_normalized = cm.astype(float) / row_sums

    fig, ax = plt.subplots(figsize=(18, 16))
    sns.heatmap(cm_normalized,
                annot=True,
                fmt=".2f",
                xticklabels=labels,
                yticklabels=labels,
                cmap="viridis",
                square=True,
                linewidths=0.5,
                cbar_kws={"shrink": 0.8},
                ax=ax)

    ax.set_title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
    ax.set_xlabel("Predicted Label", fontsize=14)
    ax.set_ylabel("True Label", fontsize=14)
    ax.tick_params(axis="x", labelrotation=90)
    ax.tick_params(axis="y", labelrotation=0)
    fig.tight_layout()

    os.makedirs(out_dir, exist_ok=True)
    save_path = f"{out_dir}/confusion_matrix_partition_gtd{partition}.png"
    fig.savefig(save_path, dpi=300)
    # Close this specific figure (not just "the current one") to free memory.
    plt.close(fig)

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [7]:
plot_confusion_matrix(y_true=targets, y_pred=preds, labels=labels, partition=partition)

Saved confusion matrix for partition gtd300 to results/confusion_matrix_partition_gtd300.png
