In [1]:
import torch
from tqdm import tqdm
import torch.optim as optim
import os
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter
from benchmark.metrics import MetricsCollection
import torch.nn as nn
from evaluate import LogitsEvaluator, EmbeddingEvaluator
import copy
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from dataset import CustomBatchSamplerPillID, PillImages
from benchmark.pillid_datasets import SingleImgPillID, BalancedBatchSamplerPillID, SiamesePillID
import utils
from benchmark.models.multihead_model import MultiheadModel
from benchmark.models.embedding_model import EmbeddingModel
from benchmark.models.losses import MultiheadLoss
from benchmark.metric_utils import HardNegativePairSelector, RandomNegativeTripletSelector
from train import Trainer
import pandas as pd


In [2]:
all_imgs_df, fold_indicies = utils.load_data()
ref_df = all_imgs_df[all_imgs_df.is_ref].reset_index(drop=True)
# unique_classes = all_imgs_df['label'].unique()
unique_classes = ref_df["label"].unique()
all_imgs_df = all_imgs_df[all_imgs_df["label"].isin(unique_classes)].reset_index(drop=True) 
n_classes = len(unique_classes)
label_encoder = LabelEncoder()
label_encoder.fit(unique_classes)
partitions = utils.split_data(all_imgs_df, fold_indicies)
datasets = utils.get_datasets(partitions, ref_df, 'label', False, label_encoder=label_encoder)
dataloaders = {}
for k,v in datasets.items():
    dataloaders[k] = DataLoader(v, batch_sampler=CustomBatchSamplerPillID(v.df, 32, labelcol='label', min_classes=5, min_per_class=3, keep_remainders=True, batch_size_mode='max', debug=False))
eval_dataset = PillImages(pd.concat([partitions['val'], ref_df]), "eval", labelcol="label", label_encoder=label_encoder)
dataloaders["eval"] = DataLoader(eval_dataset, batch_size=32, shuffle=False)

In [None]:
# dataloaders={}
# train_df = pd.concat([partitions["train"], ref_df])
# val_df = pd.concat([partitions["val"], ref_df])
# labelcol="label"
# # train_dataset = SingleImgPillID(train_df, label_encoder, train=True, labelcol=labelcol)
# # val_dataset = SingleImgPillID(val_df, label_encoder, train=False, labelcol=labelcol)
# train_dataset = PillImages(train_df, "train", labelcol=labelcol, label_encoder=label_encoder)
# val_dataset = PillImages(val_df, "val", labelcol=labelcol, label_encoder=label_encoder)
# dataloaders["train"] = DataLoader(train_dataset, batch_sampler=BalancedBatchSamplerPillID(train_df, batch_size=32, labelcol=labelcol))
# dataloaders["val"] = DataLoader(val_dataset, batch_sampler=BalancedBatchSamplerPillID(val_df, batch_size=32, labelcol=labelcol))
# dataloaders["eval"] = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [None]:
# generated with copilot
def clear_directory(directory):
    """
    Recursively deletes all files and subdirectories in the specified directory using os.walk.
    """
    for root, dirs, files in os.walk(directory, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))


In [None]:
# with batch sampler from dataset.py
torch.mps.empty_cache()
log_file_path = "./benchmark_training_logs"
writer = SummaryWriter(log_file_path)
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(device)
appearance_network = 'resnet50'
pooling = 'GAvP'
dropout = 0.0
embedding_dim = 2048
ce_w = 1.0
arcface_w = 0.1
contrastive_w = 1.0 
triplet_w = 1.0
focal_w = 0.0
loss_weights = {'ce': ce_w, 'arcface': arcface_w, 'contrastive': contrastive_w, 'triplet': triplet_w, 'focal': focal_w}
focal_gamma = 0.0
metric_margin = 1.0
train_with_side_labels = True
use_ref_labels = True
clip_grads = True
path = "./"
criterion = MultiheadLoss(len(label_encoder.classes_),
            metric_margin, HardNegativePairSelector(),
            metric_margin, RandomNegativeTripletSelector(metric_margin),
            use_cosine=False,
            weights=loss_weights,
            focal_gamma=focal_gamma,
            use_side_labels=train_with_side_labels)
E_model = EmbeddingModel(network=appearance_network, pooling=pooling, dropout_p=dropout, cont_dims=embedding_dim, pretrained=True)
model = MultiheadModel(E_model, n_classes, train_with_side_labels=train_with_side_labels).to(device)
opt = optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=5)
trainer = Trainer(device=device, model=model, dataloaders=dataloaders, clip_gradients=clip_grads, optimizer=opt, lr_scheduler=lr_scheduler, criterion=criterion, writer=writer, eval_update_type="logit", metric_type="euclidean", simulate_pairs=False, shift_labels=False, path=path)

mps
treat front/back as different classes (first half: front, second half: back), n_classes=9804


In [5]:
clear_directory(log_file_path)
clear_directory("./checkpoints")
trainer.train(num_epochs=10, checkpoint=3, earlystop_patience=5)

Epoch 0
Running train loop...


112it [00:35,  3.19it/s]                         


Running val loop...


38it [00:06,  5.59it/s]                        


Loading eval data...


100%|██████████| 330/330 [00:33<00:00,  9.75it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.0009479571523367143
logit_acc_5=0.0033178500331785005
logit_micro_ap=0.0002330153089054328
logit_mrr=0.003795272459241446
emb_acc_1=0.0912751677852349
emb_map_1=0.0912751677852349
emb_acc_5=0.22281879194630871
emb_map_5=0.13662192393736017
emb_micro_ap=0.012400867865038535
emb_mrr=0.1616293514560103
Saving model to path ./checkpoints/epoch_0.pth
Best checkpoint: 0, Best value: 0.0002330153089054328
Best Checkpoint index: 0
Epoch 1
Running train loop...


114it [00:35,  3.20it/s]                         


Running val loop...


38it [00:07,  5.41it/s]                        


Epoch 2
Running train loop...


114it [00:36,  3.15it/s]                         


Running val loop...


38it [00:06,  5.57it/s]                        


Epoch 3
Running train loop...


113it [00:36,  3.13it/s]                         


Running val loop...


38it [00:07,  5.40it/s]                        


Loading eval data...


100%|██████████| 330/330 [00:35<00:00,  9.36it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.0019907100199071004
logit_acc_5=0.007583657218693715
logit_micro_ap=0.00027589559021883026
logit_mrr=0.006601959764174629
emb_acc_1=0.1691275167785235
emb_map_1=0.1691275167785235
emb_acc_5=0.3919463087248322
emb_map_5=0.2513646532438479
emb_micro_ap=0.036682290230576536
emb_mrr=0.28164239895440685
Saving model to path ./checkpoints/epoch_3.pth
Best checkpoint: 1, Best value: 0.00027589559021883026
Best Checkpoint index: 1
Epoch 4
Running train loop...


114it [00:36,  3.16it/s]                         


Running val loop...


39it [00:07,  5.36it/s]                        


Epoch 5
Running train loop...


114it [00:35,  3.21it/s]                         


Running val loop...


38it [00:06,  5.58it/s]                        


Epoch 6
Running train loop...


114it [00:35,  3.20it/s]                         


Running val loop...


38it [00:06,  5.48it/s]                        


Loading eval data...


100%|██████████| 330/330 [00:34<00:00,  9.68it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.003697032894113186
logit_acc_5=0.013745378708882358
logit_micro_ap=0.00040234713152583997
logit_mrr=0.011014833083077585
emb_acc_1=0.2738255033557047
emb_map_1=0.2738255033557047
emb_acc_5=0.5288590604026846
emb_map_5=0.3615212527964206
emb_micro_ap=0.04912300947990809
emb_mrr=0.38768190827714244
Saving model to path ./checkpoints/epoch_6.pth
Best checkpoint: 2, Best value: 0.00040234713152583997
Best Checkpoint index: 2
Epoch 7
Running train loop...


114it [00:36,  3.15it/s]                         


Running val loop...


38it [00:07,  5.32it/s]                        


Epoch 8
Running train loop...


113it [00:35,  3.17it/s]                         


Running val loop...


38it [00:07,  5.34it/s]                        


Epoch 9
Running train loop...


113it [00:35,  3.18it/s]                         


Running val loop...


38it [00:07,  5.34it/s]                        


Loading eval data...


100%|██████████| 330/330 [00:36<00:00,  9.14it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.006825291496824344
logit_acc_5=0.02559484311309129
logit_micro_ap=0.0009902656984576029
logit_mrr=0.017259668950086793
emb_acc_1=0.33557046979865773
emb_map_1=0.33557046979865773
emb_acc_5=0.6040268456375839
emb_map_5=0.4340268456375839
emb_micro_ap=0.08791018876763936
emb_mrr=0.46110607511858703
Saving model to path ./checkpoints/epoch_9.pth
Best checkpoint: 3, Best value: 0.0009902656984576029
Best Checkpoint index: 3
Loading eval data...


100%|██████████| 330/330 [00:35<00:00,  9.29it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.006825291496824344
logit_acc_5=0.02559484311309129
logit_micro_ap=0.0009902656984576029
logit_mrr=0.017259668950086793
emb_acc_1=0.33557046979865773
emb_map_1=0.33557046979865773
emb_acc_5=0.6040268456375839
emb_map_5=0.4340268456375839
emb_micro_ap=0.08791018876763936
emb_mrr=0.46110607511858703
Saving model to path ./checkpoints/epoch_8.pth


'./checkpoints/epoch_9.pth'

In [None]:
# with batch sampler from benchmark.pillid_dataset.py (https://github.com/usuyama/ePillID-benchmark)
dataloaders={}
train_df = pd.concat([partitions["train"], ref_df])
val_df = pd.concat([partitions["val"], ref_df])
labelcol="label"
train_dataset = PillImages(train_df, "train", labelcol=labelcol, label_encoder=label_encoder)
val_dataset = PillImages(val_df, "val", labelcol=labelcol, label_encoder=label_encoder)
dataloaders["train"] = DataLoader(train_dataset, batch_sampler=BalancedBatchSamplerPillID(train_df, batch_size=32, labelcol=labelcol))
dataloaders["val"] = DataLoader(val_dataset, batch_sampler=BalancedBatchSamplerPillID(val_df, batch_size=32, labelcol=labelcol))
dataloaders["eval"] = DataLoader(val_dataset, batch_size=32, shuffle=False)
torch.mps.empty_cache()
log_file_path = "./benchmark_training_logs2"
writer2 = SummaryWriter(log_file_path)
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(device)
appearance_network = 'resnet50'
pooling = 'GAvP'
dropout = 0.0
embedding_dim = 2048
ce_w = 1.0
arcface_w = 0.1
contrastive_w = 1.0 
triplet_w = 1.0
focal_w = 0.0
loss_weights = {'ce': ce_w, 'arcface': arcface_w, 'contrastive': contrastive_w, 'triplet': triplet_w, 'focal': focal_w}
focal_gamma = 0.0
metric_margin = 1.0
train_with_side_labels = True
use_ref_labels = True
clip_grads = True
path = "./m2"
criterion2 = MultiheadLoss(len(label_encoder.classes_),
            metric_margin, HardNegativePairSelector(),
            metric_margin, RandomNegativeTripletSelector(metric_margin),
            use_cosine=False,
            weights=loss_weights,
            focal_gamma=focal_gamma,
            use_side_labels=train_with_side_labels)
E_model2 = EmbeddingModel(network=appearance_network, pooling=pooling, dropout_p=dropout, cont_dims=embedding_dim, pretrained=True)
model2 = MultiheadModel(E_model2, n_classes, train_with_side_labels=train_with_side_labels).to(device)
opt2 = optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler2 = optim.lr_scheduler.ReduceLROnPlateau(opt2, mode='min', factor=0.1, patience=5)
trainer2 = Trainer(device=device, model=model2, dataloaders=dataloaders, clip_gradients=clip_grads, optimizer=opt2, lr_scheduler=lr_scheduler2, criterion=criterion2, writer=writer2, eval_update_type="logit", metric_type="euclidean", simulate_pairs=False, shift_labels=False, path=path)

mps
treat front/back as different classes (first half: front, second half: back), n_classes=9804


In [None]:
clear_directory(log_file_path)
clear_directory("./m2/checkpoints")
trainer2.train(num_epochs=10, checkpoint=3, earlystop_patience=5)

Epoch 0
Running train loop...


 34%|███▍      | 128/376 [00:38<01:14,  3.33it/s]


Running val loop...


 13%|█▎        | 43/329 [00:08<00:53,  5.35it/s]


Loading eval data...


100%|██████████| 330/330 [00:34<00:00,  9.58it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=9.479571523367143e-05
logit_acc_5=0.0007583657218693715
logit_micro_ap=0.00020253183326028972
logit_mrr=0.0017087015970376214
emb_acc_1=0.025503355704697986
emb_map_1=0.025503355704697986
emb_acc_5=0.0953020134228188
emb_map_5=0.04984340044742729
emb_micro_ap=0.002453326789231751
emb_mrr=0.06285456676357905
Saving model to path ./m2/checkpoints/epoch_0.pth
Best checkpoint: 0, Best value: 0.00020253183326028972
Best Checkpoint index: 0
Epoch 1
Running train loop...


 34%|███▍      | 129/376 [00:39<01:15,  3.28it/s]


Running val loop...


 13%|█▎        | 43/329 [00:07<00:53,  5.39it/s]


Epoch 2
Running train loop...


 34%|███▍      | 129/376 [00:38<01:14,  3.31it/s]


Running val loop...


 13%|█▎        | 43/329 [00:07<00:52,  5.40it/s]


Epoch 3
Running train loop...


 35%|███▍      | 130/376 [00:41<01:18,  3.13it/s]


Running val loop...


 13%|█▎        | 42/329 [00:07<00:52,  5.47it/s]


Loading eval data...


100%|██████████| 330/330 [00:35<00:00,  9.27it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.0
logit_acc_5=0.0006635700066357001
logit_micro_ap=0.00020339404053726102
logit_mrr=0.0016175682656183073
emb_acc_1=0.028187919463087248
emb_map_1=0.028187919463087248
emb_acc_5=0.09395973154362416
emb_map_5=0.049351230425055924
emb_micro_ap=0.0027192369931886198
emb_mrr=0.06275921403425387
Saving model to path ./m2/checkpoints/epoch_3.pth
Best checkpoint: 1, Best value: 0.00020339404053726102
Best Checkpoint index: 1
Epoch 4
Running train loop...


 34%|███▍      | 129/376 [00:38<01:13,  3.34it/s]


Running val loop...


 13%|█▎        | 43/329 [00:07<00:52,  5.42it/s]


Epoch 5
Running train loop...


 34%|███▍      | 129/376 [00:38<01:13,  3.35it/s]


Running val loop...


 13%|█▎        | 43/329 [00:08<00:53,  5.36it/s]


Epoch 6
Running train loop...


 34%|███▍      | 128/376 [00:38<01:14,  3.33it/s]


Running val loop...


 13%|█▎        | 42/329 [00:07<00:53,  5.41it/s]


Loading eval data...


100%|██████████| 330/330 [00:34<00:00,  9.69it/s]


Starting logit eval...
Starting emb eval...
logit_acc_1=0.0
logit_acc_5=0.0010427528675703858
logit_micro_ap=0.00020254200861197244
logit_mrr=0.0017039006441134474
emb_acc_1=0.024161073825503355
emb_map_1=0.024161073825503355
emb_acc_5=0.08993288590604027
emb_map_5=0.04733780760626398
emb_micro_ap=0.0028657554254951194
emb_mrr=0.06054814444814157
Saving model to path ./m2/checkpoints/epoch_6.pth
Best Checkpoint index: 1
Epoch 7
Running train loop...


 12%|█▏        | 46/376 [00:13<01:38,  3.34it/s]