In [4]:
# Dataset partition id; the dataset name passed to train.py is f"gtd{partition}".
partition: int = 300

In [5]:
import sys
from train import main
from itertools import product  
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


In [6]:

# Grid search over forest size, tree depth and batch size. Each configuration
# is trained by train.main(), which reads its options from sys.argv; the score
# of each run is then read back from the result file for this partition.
n_tree_values = [5, 10, 15, 20]
tree_depth_values = [3, 5, 10]
batch_size_values = [256, 512, 1000]


def _read_last_best_accuracy(result_file):
    """Return the accuracy on the LAST "Best Accuracy" line of ``result_file``.

    Only the last matching line is used so that, if the file accumulates lines
    across runs (e.g. left over from an earlier session), older values are not
    credited to the configuration that was just trained.

    Args:
        result_file: path of the text result file written by train.main().

    Returns:
        float or None: the parsed accuracy, or None when no such line exists.
    """
    last_acc = None
    with open(result_file, "r") as f:
        for line in f:
            if "Best Accuracy" in line:
                # Value is taken from the third whitespace-separated token,
                # presumably "Best Accuracy: <value>" -- confirm against train.py.
                last_acc = float(line.split()[2])
    return last_acc


result_file = f"results/result_gtd{partition}"
best_score = float("-inf")
best_config = {}

# train.main() parses sys.argv; keep the kernel's own argv so it can be restored.
_argv_before_search = sys.argv
try:
    for n_tree, tree_depth, batch_size in product(n_tree_values, tree_depth_values, batch_size_values):
        print(f"\nRunning: n_tree={n_tree}, tree_depth={tree_depth}, batch_size={batch_size}")
        sys.argv = [
            'train.py',
            '-dataset', f'gtd{partition}',
            '-n_class', '30',
            '-gpuid', '0',
            '-n_tree', str(n_tree),
            '-tree_depth', str(tree_depth),
            '-batch_size', str(batch_size),
            '-epochs', '200',
            '-verbose', '0',
            '-jointly_training'
        ]

        main()  # return values are not needed; the score is read from result_file

        acc = _read_last_best_accuracy(result_file)
        if acc is None:
            print(f"Warning: no 'Best Accuracy' line in {result_file}; skipping this configuration")
            continue
        if acc > best_score:
            best_score = acc
            best_config = {
                'n_tree': n_tree,
                'tree_depth': tree_depth,
                'batch_size': batch_size
            }
finally:
    sys.argv = _argv_before_search

# Fail here with a clear message rather than with a KeyError in the retraining cell.
if not best_config:
    raise RuntimeError(f"No accuracy could be read from {result_file}; no configuration was selected.")

print("\nBest hyperparameter configuration:")
print(best_config)
print(f"Best accuracy: {best_score}")



Running: n_tree=5, tree_depth=3, batch_size=256
Use gtd300 dataset


Training Epochs:   0%|          | 1/200 [00:01<03:38,  1.10s/it]

Training Epochs:  30%|██▉       | 59/200 [01:10<02:47,  1.19s/it]

Early stopping at epoch 60

Running: n_tree=5, tree_depth=3, batch_size=512
Use gtd300 dataset



Training Epochs:  12%|█▎        | 25/200 [00:17<02:02,  1.42it/s]

Early stopping at epoch 26






Running: n_tree=5, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  68%|██████▊   | 137/200 [00:55<00:25,  2.46it/s]

Early stopping at epoch 138

Running: n_tree=5, tree_depth=5, batch_size=256
Use gtd300 dataset



Training Epochs:  22%|██▏       | 44/200 [00:57<03:23,  1.31s/it]

Early stopping at epoch 45

Running: n_tree=5, tree_depth=5, batch_size=512
Use gtd300 dataset



Training Epochs:  36%|███▌      | 71/200 [00:54<01:38,  1.31it/s]

Early stopping at epoch 72






Running: n_tree=5, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  50%|████▉     | 99/200 [00:48<00:49,  2.04it/s]

Early stopping at epoch 100






Running: n_tree=5, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  24%|██▍       | 48/200 [01:43<05:28,  2.16s/it]

Early stopping at epoch 49

Running: n_tree=5, tree_depth=10, batch_size=512
Use gtd300 dataset



Training Epochs:  14%|█▍        | 29/200 [00:33<03:14,  1.14s/it]

Early stopping at epoch 30

Running: n_tree=5, tree_depth=10, batch_size=1000
Use gtd300 dataset



Training Epochs:  22%|██▏       | 43/200 [00:30<01:49,  1.43it/s]

Early stopping at epoch 44

Running: n_tree=10, tree_depth=3, batch_size=256
Use gtd300 dataset



Training Epochs:  20%|██        | 40/200 [01:18<05:12,  1.95s/it]

Early stopping at epoch 41






Running: n_tree=10, tree_depth=3, batch_size=512
Use gtd300 dataset


Training Epochs:  20%|██        | 41/200 [00:43<02:49,  1.06s/it]

Early stopping at epoch 42






Running: n_tree=10, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  68%|██████▊   | 137/200 [01:26<00:39,  1.58it/s]

Early stopping at epoch 138






Running: n_tree=10, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  16%|█▋        | 33/200 [01:11<06:00,  2.16s/it]

Early stopping at epoch 34






Running: n_tree=10, tree_depth=5, batch_size=512
Use gtd300 dataset


Training Epochs:  74%|███████▍  | 149/200 [02:58<01:01,  1.20s/it]

Early stopping at epoch 150






Running: n_tree=10, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  66%|██████▌   | 131/200 [01:33<00:49,  1.40it/s]

Early stopping at epoch 132






Running: n_tree=10, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  40%|████      | 80/200 [03:45<05:37,  2.82s/it]

Early stopping at epoch 81






Running: n_tree=10, tree_depth=10, batch_size=512
Use gtd300 dataset


Training Epochs:  14%|█▍        | 29/200 [00:48<04:43,  1.66s/it]

Early stopping at epoch 30

Running: n_tree=10, tree_depth=10, batch_size=1000
Use gtd300 dataset



Training Epochs:  26%|██▋       | 53/200 [00:53<02:27,  1.00s/it]

Early stopping at epoch 54






Running: n_tree=15, tree_depth=3, batch_size=256
Use gtd300 dataset


Training Epochs:  34%|███▍      | 68/200 [02:43<05:17,  2.40s/it]

Early stopping at epoch 69






Running: n_tree=15, tree_depth=3, batch_size=512
Use gtd300 dataset


Training Epochs:  42%|████▏     | 83/200 [01:48<02:32,  1.30s/it]

Early stopping at epoch 84






Running: n_tree=15, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  58%|█████▊    | 116/200 [01:23<01:00,  1.38it/s]

Early stopping at epoch 117






Running: n_tree=15, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  52%|█████▎    | 105/200 [04:04<03:41,  2.33s/it]

Early stopping at epoch 106






Running: n_tree=15, tree_depth=5, batch_size=512
Use gtd300 dataset


Training Epochs:  28%|██▊       | 57/200 [01:26<03:35,  1.51s/it]

Early stopping at epoch 58






Running: n_tree=15, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  76%|███████▌  | 152/200 [01:59<00:37,  1.27it/s]

Early stopping at epoch 153






Running: n_tree=15, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  34%|███▎      | 67/200 [03:08<06:14,  2.82s/it]

Early stopping at epoch 68






Running: n_tree=15, tree_depth=10, batch_size=512
Use gtd300 dataset


Training Epochs:  68%|██████▊   | 136/200 [04:05<01:55,  1.80s/it]

Early stopping at epoch 137






Running: n_tree=15, tree_depth=10, batch_size=1000
Use gtd300 dataset


Training Epochs:  58%|█████▊    | 116/200 [02:12<01:36,  1.14s/it]

Early stopping at epoch 117






Running: n_tree=20, tree_depth=3, batch_size=256
Use gtd300 dataset


Training Epochs:  38%|███▊      | 76/200 [03:05<05:02,  2.44s/it]

Early stopping at epoch 77






Running: n_tree=20, tree_depth=3, batch_size=512
Use gtd300 dataset


Training Epochs:  48%|████▊     | 96/200 [02:26<02:38,  1.52s/it]

Early stopping at epoch 97






Running: n_tree=20, tree_depth=3, batch_size=1000
Use gtd300 dataset


Training Epochs:  52%|█████▎    | 105/200 [01:29<01:20,  1.18it/s]

Early stopping at epoch 106






Running: n_tree=20, tree_depth=5, batch_size=256
Use gtd300 dataset


Training Epochs:  38%|███▊      | 75/200 [03:47<06:18,  3.03s/it]

Early stopping at epoch 76






Running: n_tree=20, tree_depth=5, batch_size=512
Use gtd300 dataset


Training Epochs:  61%|██████    | 122/200 [03:23<02:09,  1.66s/it]

Early stopping at epoch 123






Running: n_tree=20, tree_depth=5, batch_size=1000
Use gtd300 dataset


Training Epochs:  67%|██████▋   | 134/200 [02:11<01:04,  1.02it/s]

Early stopping at epoch 135






Running: n_tree=20, tree_depth=10, batch_size=256
Use gtd300 dataset


Training Epochs:  44%|████▍     | 88/200 [05:12<06:38,  3.56s/it]

Early stopping at epoch 89






Running: n_tree=20, tree_depth=10, batch_size=512
Use gtd300 dataset


Training Epochs:  27%|██▋       | 54/200 [02:00<05:24,  2.23s/it]

Early stopping at epoch 55






Running: n_tree=20, tree_depth=10, batch_size=1000
Use gtd300 dataset


Training Epochs:  48%|████▊     | 96/200 [02:04<02:15,  1.30s/it]

Early stopping at epoch 97






Best hyperparameter configuration:
{'n_tree': 20, 'tree_depth': 10, 'batch_size': 256}
Best accuracy: 0.863704


In [7]:
# Retrain with the best configuration found by the grid search, with a larger
# epoch budget and verbose output.
if not best_config:
    raise RuntimeError("best_config is empty - run the hyperparameter search cell first.")

# train.main() parses sys.argv; keep the kernel's own argv so it can be restored.
_argv_before_final = sys.argv
sys.argv = [
    'train.py',
    '-dataset', f'gtd{partition}',
    '-n_class', '30',
    '-gpuid', '0',
    '-n_tree', str(best_config['n_tree']),
    '-tree_depth', str(best_config['tree_depth']),
    '-batch_size', str(best_config['batch_size']),
    '-epochs', '1000',
    '-verbose', '1',
    '-jointly_training'
]
try:
    best_model, preds, targets, labels, epoch_logs = main()
finally:
    sys.argv = _argv_before_final


Use gtd300 dataset


Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]



Training Epochs:   0%|          | 1/1000 [00:04<1:07:22,  4.05s/it]



Training Epochs:   0%|          | 2/1000 [00:07<1:04:31,  3.88s/it]



Training Epochs:   0%|          | 3/1000 [00:10<58:50,  3.54s/it]  



Training Epochs:   0%|          | 4/1000 [00:14<59:08,  3.56s/it]



Training Epochs:   0%|          | 5/1000 [00:18<58:47,  3.55s/it]



Training Epochs:   1%|          | 6/1000 [00:21<58:50,  3.55s/it]



Training Epochs:   1%|          | 7/1000 [00:24<57:30,  3.47s/it]



Training Epochs:   1%|          | 8/1000 [00:28<57:44,  3.49s/it]



Training Epochs:   1%|          | 9/1000 [00:32<58:27,  3.54s/it]



Training Epochs:   1%|          | 10/1000 [00:35<58:18,  3.53s/it]



Training Epochs:   1%|          | 11/1000 [00:39<58:30,  3.55s/it]



Training Epochs:   1%|          | 12/1000 [00:42<59:02,  3.59s/it]



Training Epochs:   1%|▏         | 13/1000 [00:46<58:49,  3.58s/it]



Training Epochs:   1%|▏         | 14/1000 [00:49<58:37,  3.57s/it]



Training Epochs:   2%|▏         | 15/1000 [00:53<58:49,  3.58s/it]



Training Epochs:   2%|▏         | 16/1000 [00:57<59:56,  3.66s/it]



Training Epochs:   2%|▏         | 17/1000 [01:01<1:00:44,  3.71s/it]



Training Epochs:   2%|▏         | 18/1000 [01:04<56:50,  3.47s/it]  



Training Epochs:   2%|▏         | 19/1000 [01:07<58:08,  3.56s/it]



Training Epochs:   2%|▏         | 20/1000 [01:11<57:56,  3.55s/it]



Training Epochs:   2%|▏         | 21/1000 [01:15<59:16,  3.63s/it]



Training Epochs:   2%|▏         | 22/1000 [01:18<57:30,  3.53s/it]



Training Epochs:   2%|▏         | 23/1000 [01:21<56:47,  3.49s/it]



Training Epochs:   2%|▏         | 24/1000 [01:25<55:49,  3.43s/it]



Training Epochs:   2%|▎         | 25/1000 [01:28<55:00,  3.39s/it]



Training Epochs:   3%|▎         | 26/1000 [01:31<54:44,  3.37s/it]



Training Epochs:   3%|▎         | 27/1000 [01:35<56:15,  3.47s/it]



Training Epochs:   3%|▎         | 28/1000 [01:38<53:57,  3.33s/it]



Training Epochs:   3%|▎         | 29/1000 [01:42<54:18,  3.36s/it]



Training Epochs:   3%|▎         | 30/1000 [01:45<53:49,  3.33s/it]



Training Epochs:   3%|▎         | 31/1000 [01:48<53:17,  3.30s/it]



Training Epochs:   3%|▎         | 32/1000 [01:51<53:12,  3.30s/it]



Training Epochs:   3%|▎         | 33/1000 [01:55<55:05,  3.42s/it]



Training Epochs:   3%|▎         | 34/1000 [01:58<54:02,  3.36s/it]



Training Epochs:   4%|▎         | 35/1000 [02:02<54:11,  3.37s/it]



Training Epochs:   4%|▎         | 36/1000 [02:05<56:08,  3.49s/it]



Training Epochs:   4%|▎         | 37/1000 [02:09<55:20,  3.45s/it]



Training Epochs:   4%|▍         | 38/1000 [02:12<55:25,  3.46s/it]



Training Epochs:   4%|▍         | 39/1000 [02:16<54:47,  3.42s/it]



Training Epochs:   4%|▍         | 40/1000 [02:19<55:17,  3.46s/it]



Training Epochs:   4%|▍         | 41/1000 [02:23<55:31,  3.47s/it]



Training Epochs:   4%|▍         | 42/1000 [02:26<55:15,  3.46s/it]



Training Epochs:   4%|▍         | 43/1000 [02:30<55:34,  3.48s/it]



Training Epochs:   4%|▍         | 44/1000 [02:33<55:14,  3.47s/it]



Training Epochs:   4%|▍         | 45/1000 [02:36<53:51,  3.38s/it]



Training Epochs:   5%|▍         | 46/1000 [02:40<55:44,  3.51s/it]



Training Epochs:   5%|▍         | 47/1000 [02:44<56:58,  3.59s/it]



Training Epochs:   5%|▍         | 48/1000 [02:48<58:52,  3.71s/it]



Training Epochs:   5%|▍         | 49/1000 [02:51<58:36,  3.70s/it]



Training Epochs:   5%|▌         | 50/1000 [02:55<59:38,  3.77s/it]



Training Epochs:   5%|▌         | 51/1000 [02:59<58:24,  3.69s/it]



Training Epochs:   5%|▌         | 52/1000 [03:02<57:24,  3.63s/it]



Training Epochs:   5%|▌         | 53/1000 [03:06<58:33,  3.71s/it]



Training Epochs:   5%|▌         | 54/1000 [03:10<57:40,  3.66s/it]



Training Epochs:   6%|▌         | 55/1000 [03:13<57:05,  3.62s/it]



Training Epochs:   6%|▌         | 56/1000 [03:17<55:08,  3.50s/it]



Training Epochs:   6%|▌         | 57/1000 [03:20<54:29,  3.47s/it]



Training Epochs:   6%|▌         | 58/1000 [03:23<54:31,  3.47s/it]



Training Epochs:   6%|▌         | 59/1000 [03:27<54:28,  3.47s/it]



Training Epochs:   6%|▌         | 60/1000 [03:30<54:00,  3.45s/it]



Training Epochs:   6%|▌         | 61/1000 [03:34<54:18,  3.47s/it]



Training Epochs:   6%|▌         | 62/1000 [03:37<53:26,  3.42s/it]



Training Epochs:   6%|▋         | 63/1000 [03:40<53:07,  3.40s/it]



Training Epochs:   6%|▋         | 64/1000 [03:44<52:46,  3.38s/it]



Training Epochs:   6%|▋         | 65/1000 [03:47<53:52,  3.46s/it]



Training Epochs:   7%|▋         | 66/1000 [03:51<53:10,  3.42s/it]



Training Epochs:   7%|▋         | 67/1000 [03:54<53:00,  3.41s/it]



Training Epochs:   7%|▋         | 68/1000 [03:58<53:15,  3.43s/it]



Training Epochs:   7%|▋         | 69/1000 [04:01<51:34,  3.32s/it]



Training Epochs:   7%|▋         | 70/1000 [04:04<49:42,  3.21s/it]



Training Epochs:   7%|▋         | 71/1000 [04:07<48:11,  3.11s/it]



Training Epochs:   7%|▋         | 72/1000 [04:10<48:09,  3.11s/it]



Training Epochs:   7%|▋         | 73/1000 [04:13<47:43,  3.09s/it]



Training Epochs:   7%|▋         | 74/1000 [04:16<48:17,  3.13s/it]



Training Epochs:   8%|▊         | 75/1000 [04:19<49:04,  3.18s/it]



Training Epochs:   8%|▊         | 76/1000 [04:23<49:57,  3.24s/it]



Training Epochs:   8%|▊         | 77/1000 [04:26<51:06,  3.32s/it]



Training Epochs:   8%|▊         | 78/1000 [04:30<52:54,  3.44s/it]



Training Epochs:   8%|▊         | 79/1000 [04:33<53:42,  3.50s/it]



Training Epochs:   8%|▊         | 80/1000 [04:37<52:49,  3.45s/it]



Training Epochs:   8%|▊         | 81/1000 [04:40<51:42,  3.38s/it]



Training Epochs:   8%|▊         | 82/1000 [04:43<51:06,  3.34s/it]



Training Epochs:   8%|▊         | 83/1000 [04:47<51:27,  3.37s/it]



Training Epochs:   8%|▊         | 84/1000 [04:50<52:55,  3.47s/it]



Training Epochs:   8%|▊         | 85/1000 [04:54<53:36,  3.52s/it]



Training Epochs:   9%|▊         | 86/1000 [04:58<54:31,  3.58s/it]



Training Epochs:   9%|▊         | 87/1000 [05:02<57:35,  3.78s/it]



Training Epochs:   9%|▉         | 88/1000 [05:05<55:23,  3.64s/it]



Training Epochs:   9%|▉         | 89/1000 [05:09<54:55,  3.62s/it]



Training Epochs:   9%|▉         | 90/1000 [05:13<56:01,  3.69s/it]



Training Epochs:   9%|▉         | 91/1000 [05:16<54:16,  3.58s/it]



Training Epochs:   9%|▉         | 92/1000 [05:20<56:22,  3.73s/it]



Training Epochs:   9%|▉         | 93/1000 [05:23<54:30,  3.61s/it]



Training Epochs:   9%|▉         | 94/1000 [05:27<51:57,  3.44s/it]



Training Epochs:  10%|▉         | 95/1000 [05:30<51:23,  3.41s/it]



Training Epochs:  10%|▉         | 96/1000 [05:33<51:36,  3.43s/it]



Training Epochs:  10%|▉         | 97/1000 [05:37<52:26,  3.48s/it]



Training Epochs:  10%|▉         | 98/1000 [05:40<52:06,  3.47s/it]



Training Epochs:  10%|▉         | 99/1000 [05:44<51:32,  3.43s/it]



Training Epochs:  10%|█         | 100/1000 [05:47<49:30,  3.30s/it]



Training Epochs:  10%|█         | 101/1000 [05:50<48:25,  3.23s/it]



Training Epochs:  10%|█         | 102/1000 [05:53<48:32,  3.24s/it]



Training Epochs:  10%|█         | 103/1000 [05:56<47:46,  3.20s/it]



Training Epochs:  10%|█         | 104/1000 [06:00<49:19,  3.30s/it]



Training Epochs:  10%|█         | 105/1000 [06:03<50:17,  3.37s/it]



Training Epochs:  11%|█         | 106/1000 [06:07<51:12,  3.44s/it]



Training Epochs:  11%|█         | 107/1000 [06:10<48:59,  3.29s/it]



Training Epochs:  11%|█         | 108/1000 [06:13<48:03,  3.23s/it]



Training Epochs:  11%|█         | 109/1000 [06:16<48:09,  3.24s/it]



Training Epochs:  11%|█         | 110/1000 [06:19<46:20,  3.12s/it]



Training Epochs:  11%|█         | 111/1000 [06:22<47:55,  3.23s/it]



Training Epochs:  11%|█         | 112/1000 [06:26<47:48,  3.23s/it]



Training Epochs:  11%|█▏        | 113/1000 [06:29<47:18,  3.20s/it]



Training Epochs:  11%|█▏        | 114/1000 [06:32<47:09,  3.19s/it]



Training Epochs:  12%|█▏        | 115/1000 [06:35<45:56,  3.11s/it]



Training Epochs:  12%|█▏        | 116/1000 [06:38<45:36,  3.10s/it]



Training Epochs:  12%|█▏        | 117/1000 [06:41<46:46,  3.18s/it]



Training Epochs:  12%|█▏        | 118/1000 [06:45<46:55,  3.19s/it]



Training Epochs:  12%|█▏        | 119/1000 [06:48<48:58,  3.34s/it]



Training Epochs:  12%|█▏        | 120/1000 [06:52<50:36,  3.45s/it]



Training Epochs:  12%|█▏        | 121/1000 [06:56<52:12,  3.56s/it]



Training Epochs:  12%|█▏        | 122/1000 [06:59<52:46,  3.61s/it]



Training Epochs:  12%|█▏        | 123/1000 [07:03<51:29,  3.52s/it]



Training Epochs:  12%|█▏        | 124/1000 [07:07<52:57,  3.63s/it]



Training Epochs:  12%|█▎        | 125/1000 [07:10<52:17,  3.59s/it]



Training Epochs:  13%|█▎        | 126/1000 [07:14<53:10,  3.65s/it]



Training Epochs:  13%|█▎        | 127/1000 [07:18<53:03,  3.65s/it]



Training Epochs:  13%|█▎        | 128/1000 [07:21<52:31,  3.61s/it]



Training Epochs:  13%|█▎        | 129/1000 [07:25<52:39,  3.63s/it]



Training Epochs:  13%|█▎        | 130/1000 [07:29<54:13,  3.74s/it]



Training Epochs:  13%|█▎        | 130/1000 [07:33<50:33,  3.49s/it]

Early stopping at epoch 131





In [8]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 of the final model's predictions.
print(classification_report(y_true=targets, y_pred=preds))

              precision    recall  f1-score   support

           0       0.89      1.00      0.94        90
           1       0.98      1.00      0.99        90
           2       0.65      0.90      0.76        90
           3       0.75      0.88      0.81        90
           4       0.99      0.97      0.98        90
           5       1.00      1.00      1.00        90
           6       0.94      0.80      0.86        90
           7       0.91      0.94      0.93        90
           8       1.00      1.00      1.00        90
           9       1.00      1.00      1.00        90
          10       0.62      0.98      0.76        90
          11       0.82      0.94      0.88        90
          12       0.84      0.73      0.78        90
          13       0.99      1.00      0.99        90
          14       0.79      0.41      0.54        90
          15       0.82      0.96      0.88        90
          16       0.93      1.00      0.96        90
          17       1.00    

In [9]:
def plot_confusion_matrix(y_true, y_pred, labels, partition):
    cm = confusion_matrix(y_true, y_pred, labels=range(len(labels)))
    cm_normalized = cm.astype('float') / cm.sum(axis=1, keepdims=True)

    plt.figure(figsize=(18, 16))
    sns.heatmap(cm_normalized,
                annot=True,
                fmt=".2f",
                xticklabels=labels,
                yticklabels=labels,
                cmap="viridis",
                square=True,
                linewidths=0.5,
                cbar_kws={"shrink": 0.8})

    plt.title(f"Normalized Confusion Matrix (Partition gtd{partition})", fontsize=18)
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("True Label", fontsize=14)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()

    save_path = f"results/confusion_matrix_partition_gtd{partition}.png"
    plt.savefig(save_path, dpi=300)
    plt.close()

    print(f"Saved confusion matrix for partition gtd{partition} to {save_path}")



In [10]:
plot_confusion_matrix(y_true=targets, y_pred=preds, labels=labels, partition=partition)

Saved confusion matrix for partition gtd300 to results/confusion_matrix_partition_gtd300.png
