In [None]:
import torch
from tqdm import tqdm
import torch.optim as optim
import os
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter
from benchmark.metrics import MetricsCollection
import torch.nn as nn
from evaluate import LogitsEvaluator, EmbeddingEvaluator
import copy
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from dataset import CustomBatchSamplerPillID, PillImages
from benchmark.pillid_datasets import SingleImgPillID, BalancedBatchSamplerPillID, SiamesePillID
import utils
from benchmark.models.multihead_model import MultiheadModel
from benchmark.models.embedding_model import EmbeddingModel
from benchmark.models.losses import MultiheadLoss
from benchmark.metric_utils import HardNegativePairSelector, RandomNegativeTripletSelector
from train import Trainer
import pandas as pd


In [None]:
# Load the image table and the fold information consumed by utils.split_data below.
all_imgs_df, fold_indicies = utils.load_data()

# The label set is defined by the reference images only (not by every image):
# images whose label has no reference counterpart are dropped.
ref_df = all_imgs_df[all_imgs_df.is_ref].reset_index(drop=True)
unique_classes = ref_df["label"].unique()
all_imgs_df = all_imgs_df.loc[all_imgs_df["label"].isin(unique_classes)].reset_index(drop=True)
n_classes = len(unique_classes)

# A single encoder, fitted on the reference labels, is shared by every dataset.
label_encoder = LabelEncoder().fit(unique_classes)

partitions = utils.split_data(all_imgs_df, fold_indicies)
datasets = utils.get_datasets(partitions, ref_df, 'label', False, label_encoder=label_encoder)

# One loader per dataset returned by utils.get_datasets, each batched by the
# custom sampler from dataset.py.
dataloaders = {
    split: DataLoader(
        dataset,
        batch_sampler=CustomBatchSamplerPillID(
            dataset.df,
            32,
            labelcol='label',
            min_classes=5,
            min_per_class=2,
            keep_remainders=True,
            batch_size_mode=None,
            debug=False,
        ),
    )
    for split, dataset in datasets.items()
}

# Sequential (unshuffled) pass over validation + reference images for evaluation.
eval_dataset = PillImages(
    pd.concat([partitions['val'], ref_df]),
    "eval",
    labelcol="label",
    label_encoder=label_encoder,
)
dataloaders["eval"] = DataLoader(eval_dataset, batch_size=32, shuffle=False)

In [None]:
# generated with copilot
def clear_directory(directory):
    """
    Recursively delete everything inside ``directory``, keeping the directory itself.

    The tree is walked bottom-up so that every subdirectory is already empty by
    the time it is removed. Symbolic links are unlinked rather than followed:
    the link is deleted, but whatever it points to is left untouched.

    Args:
        directory: Path of the directory to empty. If it does not exist,
            ``os.walk`` yields nothing and the call is a no-op.

    Raises:
        OSError: If a file or directory cannot be removed (e.g. permissions).
    """
    for root, dirs, files in os.walk(directory, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            entry = os.path.join(root, name)
            # os.walk reports symlinks to directories under `dirs`, but
            # os.rmdir refuses to remove a symlink, so unlink those instead.
            if os.path.islink(entry):
                os.remove(entry)
            else:
                os.rmdir(entry)


In [None]:
# with batch sampler from dataset.py
# Experiment 1: builds the loss, model, optimizer, scheduler and Trainer on top
# of the `dataloaders` dict defined in the data-loading cell.

# Pick the device first; fall back to CPU when MPS is unavailable.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
if device.type == 'mps':
    # Only touch the MPS allocator when the backend actually exists; calling
    # this unconditionally fails on machines without MPS.
    torch.mps.empty_cache()
print(device)

# TensorBoard output directory for this run.
log_file_path = "./benchmark_training_logs"
writer = SummaryWriter(log_file_path)

# --- embedding model hyperparameters (forwarded to EmbeddingModel) ---
appearance_network = 'resnet50'
pooling = 'GAvP'
dropout = 0.0
embedding_dim = 2048

# --- per-term loss weights (forwarded to MultiheadLoss via `weights`) ---
ce_w = 1.0
arcface_w = 0.1
contrastive_w = 1.0
triplet_w = 1.0
focal_w = 0.0
loss_weights = {'ce': ce_w, 'arcface': arcface_w, 'contrastive': contrastive_w, 'triplet': triplet_w, 'focal': focal_w}
focal_gamma = 0.0
metric_margin = 1.0

# --- training switches (forwarded to MultiheadLoss / MultiheadModel / Trainer) ---
train_with_side_labels = False
train_with_ref_labels = False
clip_grads = True
simulate_pairs = False
shift_labels = False
path = "./"

# The same margin is passed twice positionally and once to the triplet selector;
# presumably the contrastive and triplet margins - confirm against MultiheadLoss.
criterion = MultiheadLoss(
    len(label_encoder.classes_),
    metric_margin, HardNegativePairSelector(),
    metric_margin, RandomNegativeTripletSelector(metric_margin),
    use_cosine=False,
    weights=loss_weights,
    focal_gamma=focal_gamma,
    use_side_labels=train_with_side_labels,
)

E_model = EmbeddingModel(
    network=appearance_network,
    pooling=pooling,
    dropout_p=dropout,
    cont_dims=embedding_dim,
    pretrained=True,
)
model = MultiheadModel(E_model, n_classes, train_with_side_labels=train_with_side_labels).to(device)

opt = optim.Adam(model.parameters(), lr=1e-4)
# Scale the learning rate by 0.1 after 5 epochs without a decrease in the
# monitored quantity (mode='min').
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=5)

trainer = Trainer(
    device=device,
    model=model,
    dataloaders=dataloaders,
    clip_gradients=clip_grads,
    optimizer=opt,
    lr_scheduler=lr_scheduler,
    criterion=criterion,
    writer=writer,
    eval_update_type="logit",
    metric_type="euclidean",
    simulate_pairs=simulate_pairs,
    shift_labels=shift_labels,
    train_with_ref_labels=train_with_ref_labels,
    plot_metrics_names=["acc_1", "acc_5", "loss", "micro_ap", "map", "mrr"],
    path=path,
)

In [None]:
# Start from clean logs and checkpoints for experiment 1.
# SummaryWriter opens its event file on construction, so wiping the log
# directory while the writer is open would leave this run logging to a deleted
# file. Close it first; a closed SummaryWriter reopens a fresh event file on
# its next add_* call.
# NOTE(review): assumes Trainer logs only through the writer's public add_*
# methods - confirm against train.py.
writer.close()
# Use the writer's own directory rather than the `log_file_path` global, which
# a later cell rebinds to a different run's directory.
clear_directory(writer.log_dir)
# Presumably where this trainer saves checkpoints given path="./" - confirm.
clear_directory("./checkpoints")
trainer.train(num_epochs=10, checkpoint=3, earlystop_patience=5)

In [None]:
# with batch sampler from benchmark.pillid_dataset.py (https://github.com/usuyama/ePillID-benchmark)
# Experiment 2: same loss/model/optimizer configuration as the first experiment,
# but train/val batches come from BalancedBatchSamplerPillID.

# Rebinds `dataloaders` to a fresh dict, so a Trainer built earlier keeps the
# dict it was given.
dataloaders = {}
# Reference images are appended to both the train and val partitions.
train_df = pd.concat([partitions["train"], ref_df])
val_df = pd.concat([partitions["val"], ref_df])
labelcol = "label"
train_dataset = PillImages(train_df, "train", labelcol=labelcol, label_encoder=label_encoder)
val_dataset = PillImages(val_df, "val", labelcol=labelcol, label_encoder=label_encoder)
dataloaders["train"] = DataLoader(train_dataset, batch_sampler=BalancedBatchSamplerPillID(train_df, batch_size=32, labelcol=labelcol))
dataloaders["val"] = DataLoader(val_dataset, batch_sampler=BalancedBatchSamplerPillID(val_df, batch_size=32, labelcol=labelcol))
# Evaluation iterates the validation dataset sequentially, without the balanced sampler.
dataloaders["eval"] = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Pick the device first; fall back to CPU when MPS is unavailable.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
if device.type == 'mps':
    # Only touch the MPS allocator when the backend actually exists; calling
    # this unconditionally fails on machines without MPS.
    torch.mps.empty_cache()
print(device)

# TensorBoard output directory for this run.
log_file_path = "./benchmark_training_logs2"
writer2 = SummaryWriter(log_file_path)

# --- embedding model hyperparameters (forwarded to EmbeddingModel) ---
appearance_network = 'resnet50'
pooling = 'GAvP'
dropout = 0.0
embedding_dim = 2048

# --- per-term loss weights (forwarded to MultiheadLoss via `weights`) ---
ce_w = 1.0
arcface_w = 0.1
contrastive_w = 1.0
triplet_w = 1.0
focal_w = 0.0
loss_weights = {'ce': ce_w, 'arcface': arcface_w, 'contrastive': contrastive_w, 'triplet': triplet_w, 'focal': focal_w}
focal_gamma = 0.0
metric_margin = 1.0

# --- training switches (forwarded to MultiheadLoss / MultiheadModel / Trainer) ---
train_with_side_labels = False
train_with_ref_labels = False
clip_grads = True
shift_labels = False
simulate_pairs = False
path = "./m2"

# The same margin is passed twice positionally and once to the triplet selector;
# presumably the contrastive and triplet margins - confirm against MultiheadLoss.
criterion2 = MultiheadLoss(
    len(label_encoder.classes_),
    metric_margin, HardNegativePairSelector(),
    metric_margin, RandomNegativeTripletSelector(metric_margin),
    use_cosine=False,
    weights=loss_weights,
    focal_gamma=focal_gamma,
    use_side_labels=train_with_side_labels,
)

E_model2 = EmbeddingModel(
    network=appearance_network,
    pooling=pooling,
    dropout_p=dropout,
    cont_dims=embedding_dim,
    pretrained=True,
)
model2 = MultiheadModel(E_model2, n_classes, train_with_side_labels=train_with_side_labels).to(device)

opt2 = optim.Adam(model2.parameters(), lr=1e-4)
# Scale the learning rate by 0.1 after 5 epochs without a decrease in the
# monitored quantity (mode='min').
lr_scheduler2 = optim.lr_scheduler.ReduceLROnPlateau(opt2, mode='min', factor=0.1, patience=5)

trainer2 = Trainer(
    device=device,
    model=model2,
    dataloaders=dataloaders,
    clip_gradients=clip_grads,
    optimizer=opt2,
    lr_scheduler=lr_scheduler2,
    criterion=criterion2,
    writer=writer2,
    eval_update_type="logit",
    metric_type="euclidean",
    simulate_pairs=simulate_pairs,
    shift_labels=shift_labels,
    train_with_ref_labels=train_with_ref_labels,
    plot_metrics_names=["acc_1", "acc_5", "loss", "micro_ap", "map", "mrr"],
    path=path,
)

In [None]:
# Start from clean logs and checkpoints for experiment 2.
# SummaryWriter opens its event file on construction, so wiping the log
# directory while the writer is open would leave this run logging to a deleted
# file. Close it first; a closed SummaryWriter reopens a fresh event file on
# its next add_* call.
# NOTE(review): assumes Trainer logs only through the writer's public add_*
# methods - confirm against train.py.
writer2.close()
# Use the writer's own directory rather than the shared `log_file_path` global,
# so this cell always clears the logs belonging to writer2.
clear_directory(writer2.log_dir)
# Presumably where this trainer saves checkpoints given path="./m2" - confirm.
clear_directory("./m2/checkpoints")
trainer2.train(num_epochs=20, checkpoint=3, earlystop_patience=5)