In [2]:
import os
import random
import pandas as pd
import numpy as np
import mxnet as mx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
from torch.utils.data import Dataset, DataLoader
from pytorch_metric_learning import losses
from einops import rearrange, repeat
import optuna
from optuna.trial import TrialState
from tqdm import tqdm

In [3]:
def file_to_embed(embeds, file):
    emb = []
    for f in file:
        emb.append(embeds[f][0])
    return torch.stack(emb)

In [4]:
MIN_NUM_PATCHES = 16

In [5]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=1)

In [6]:
class AdienceDataset(Dataset):
    def __init__(self, annot_file, img_dir):
        self.img_lbls = pd.read_csv(annot_file, header=None)
        self.img_dir = img_dir
    
    def __len__(self):
        return len(self.img_lbls)
    
    def __getitem__(self, idx):
        img_file = self.img_lbls.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, img_file)
        image = mx.image.imread(img_path)
        if image.shape[1] != 112:
            image = mx.image.resize_short(image, 112)
        image = mx.nd.transpose(image, axes=(2,0,1))
        image = torch.tensor(image.asnumpy()).type(torch.FloatTensor)
        label = self.img_lbls.iloc[idx, 1]

        return image, label, img_file

In [7]:
train_data = AdienceDataset("../train.csv", "../cropped_Adience/")
val_data = AdienceDataset("../val.csv", "../cropped_Adience/")

In [8]:
class CosFace(nn.Module):
    r"""Implement of CosFace (https://arxiv.org/pdf/1801.09414.pdf):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
                       if device_id=None, it will be trained on CPU without model parallel.
        s: norm of input feature
        m: margin
        cos(theta)-m
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.35):
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m
        print("self.device_id", self.device_id)
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------

        if self.device_id == None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size())
        if self.device_id != None:
            one_hot = one_hot.cuda(self.device_id[0])
        # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot

        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + (
                    (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features = ' + str(self.in_features) \
               + ', out_features = ' + str(self.out_features) \
               + ', s = ' + str(self.s) \
               + ', m = ' + str(self.m) + ')'

In [9]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

In [10]:
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

In [11]:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

In [12]:
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max
        #embed()
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)

        return out

In [13]:
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            #embed()
            x = ff(x)
        return x

In [14]:
class ViT_face(nn.Module):
    def __init__(self, *, loss_type, GPU_ID, num_class, image_size, patch_size, dim, depth, heads, mlp_dim, pool = 'mean', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
        )
        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")
        else:
            if self.loss_type == 'CosFace':
                self.loss = CosFace(in_features=dim, out_features=num_class, device_id=self.GPU_ID)

    def forward(self, img, label=None, mask=None):
        p = self.patch_size
        
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x, mask)

        y = x[:, 0]
        z = x[:, 1:].mean(dim = 1)

        y = self.to_latent(y)
        emb_y = self.mlp_head(y)
        z = self.to_latent(z)
        emb_z = self.mlp_head(z)
        emb = torch.cat((emb_y, emb_z), dim=1)
        if label is not None:
            x = self.loss(emb, label)
            return x, emb
        else:
            return emb

In [15]:
class ViT_plus(nn.Module):
    def __init__(self):
        super(ViT_plus, self).__init__()
        
        self.fc1 = nn.Linear(in_features=1024, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=2)
        
    def forward(self, x):
        x = self.fc1(x)
        x_cosface = x
        x_classification = self.fc2(x)
        
        return x_cosface, x_classification

In [16]:
model = ViT_face(
            image_size=112,
            patch_size=8,
            loss_type='CosFace',
            GPU_ID= [device],
            num_class=93431,
            dim=512,
            depth=20,
            heads=8,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1
        ).to(device)
model.load_state_dict(
    torch.load("../Backbone_VIT_Epoch_2_Batch_20000_Time_2021-01-12-16-48_checkpoint.pth", map_location=device)
)

self.device_id [device(type='cuda', index=1)]


<All keys matched successfully>

In [17]:
for param in model.parameters():
    param.requires_grad = False

In [18]:
embeds = {}
model.eval()

with torch.no_grad():
    for img, label, file in train_data:
        img = img.to(device)
        embeds[file] = model(torch.unsqueeze(img, 0))

    for img, label, file in val_data:
        img = img.to(device)
        embeds[file] = model(torch.unsqueeze(img, 0))

In [23]:
best_accu = 0.9446246027946472
def objective(trial):
    model_xtr = ViT_plus().to(device)
    
    loss_lr = trial.suggest_float("loss_learning_rate", 1e-4, 1e-2, log=True)
    arc_margin = losses.ArcFaceLoss(2, 1024).to(device)
    loss_optimizer = opt.AdamW(arc_margin.parameters(), lr=loss_lr)
    
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    wd = trial.suggest_float('weight_decay', 1e-4, 1e-2, log=True)
    eps = trial.suggest_float("epsilon", 1e-9, 1e-7, log=True)
    optimizer = opt.AdamW(model_xtr.parameters(), lr=lr, eps=eps, weight_decay=wd)
    
    criterion = nn.CrossEntropyLoss()
    
    batch_size = trial.suggest_int('batch_size', 50, 300)
    num_epochs = trial.suggest_int('epochs', 10, 100)
    
    print("Learning rate for Loss: "+ str(loss_lr))
    print("Learning rate: "+ str(lr))
    print("Weight decay: "+ str(wd))
    print("Epsilon: "+ str(eps))
    print("Batch size: "+ str(batch_size))
    print("Number of epochs: "+ str(num_epochs))
    
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
        
        # training loop
        running_loss = []
        running_accu = []
        
        model_xtr.train()
        for img, label, file in tqdm(train_loader, desc="Training", leave=False):
            img, label = img.to(device), label.to(device)

            x = file_to_embed(embeds, file)
            
            optimizer.zero_grad()
            embed, output = model_xtr(x)
            
            pred = torch.argmax(output, 1)
            accuracy = torch.eq(pred, label).sum() / len(img)

            class_loss = criterion(output, label)
            arc_loss = arc_margin(embed, label)
            loss = class_loss + arc_loss
            loss.backward()
            loss_optimizer.step()
            optimizer.step()

            running_accu.append(accuracy.cpu().detach().numpy())
            running_loss.append(loss.cpu().detach().numpy())
        print("Epoch: {}/{} - Loss: {:.4f} - Accuracy: {:.4f}".format(epoch+1, num_epochs, np.mean(running_loss), np.mean(running_accu)))
        
        # validation loop
        val_loss = []
        val_accu = []

        model_xtr.eval()
        with torch.no_grad():
            for img, label, file in tqdm(val_loader):
                img, label = img.to(device), label.to(device)
                
                x = file_to_embed(embeds, file)
                
                embed, output = model_xtr(x)
                
                pred = torch.argmax(output, 1)
                accuracy = torch.eq(pred, label).sum() / len(img)
                
                class_loss = criterion(output, label)
                arc_loss = arc_margin(embed, label)
                loss = class_loss + arc_loss
                
                val_accu.append(accuracy.cpu().detach().numpy())
                val_loss.append(loss.cpu().detach().numpy())
        val_accu = np.mean(val_accu)
        val_loss = np.mean(val_loss)
        print("Val Loss: {:.4f} - Val Accuracy: {:.4f}".format(val_loss, val_accu))
        
        trial.report(val_accu, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    
    global best_accu
    if val_accu > best_accu:
        best_accu = val_accu
        print("Saving best model...")
        torch.save(model_xtr.state_dict(), "../vit_8-8_arcface_mean.pt")
            
    return val_accu

In [24]:
study = optuna.create_study(direction='maximize',
                            study_name='arcface-8-8-mean-vit-study',
                            storage='sqlite:///study1.db',
                            load_if_exists=True)
study.optimize(objective, n_trials=10)

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

# Display the study statistics
print("\nStudy statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

[32m[I 2023-12-07 20:51:59,530][0m Using an existing study with name 'arcface-8-8-mean-vit-study' instead of creating a new one.[0m


Learning rate for Loss: 0.0009233611225501734
Learning rate: 0.03828681316753371
Weight decay: 0.0024370304518440133
Epsilon: 3.433742523312084e-09
Batch size: 56
Number of epochs: 100


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]
Training:   0%|          | 0/250 [00:00<?, ?it/s][A
Training:   0%|          | 1/250 [00:00<02:06,  1.97it/s][A
Training:   2%|▏         | 5/250 [00:00<00:24,  9.92it/s][A
Training:   4%|▎         | 9/250 [00:00<00:15, 15.36it/s][A
Training:   5%|▍         | 12/250 [00:00<00:13, 17.13it/s][A
Training:   6%|▌         | 15/250 [00:01<00:12, 19.34it/s][A
Training:   7%|▋         | 18/250 [00:01<00:11, 20.74it/s][A
Training:   9%|▉         | 22/250 [00:01<00:09, 23.06it/s][A
Training:  10%|█         | 26/250 [00:01<00:08, 25.12it/s][A
Training:  12%|█▏        | 30/250 [00:01<00:08, 26.05it/s][A
Training:  14%|█▎        | 34/250 [00:01<00:07, 27.00it/s][A
Training:  15%|█▌        | 38/250 [00:01<00:07, 27.55it/s][A
Training:  17%|█▋        | 42/250 [00:01<00:07, 27.12it/s][A
Training:  18%|█▊        | 46/250 [00:02<00:07, 27.01it/s][A
Training:  20%|██        | 50/250 [00:02<00:07, 27.27it/s][A
Training:  22%|██▏       | 54/250 

Epoch: 1/100 - Loss: 24.7942 - Accuracy: 0.8443



  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:00<00:11,  2.64it/s][A
 16%|█▌        | 5/32 [00:00<00:02, 10.87it/s][A
 28%|██▊       | 9/32 [00:00<00:01, 16.24it/s][A
 41%|████      | 13/32 [00:00<00:00, 19.75it/s][A
 53%|█████▎    | 17/32 [00:00<00:00, 22.11it/s][A
 66%|██████▌   | 21/32 [00:01<00:00, 23.38it/s][A
 78%|███████▊  | 25/32 [00:01<00:00, 24.25it/s][A
100%|██████████| 32/32 [00:01<00:00, 21.20it/s][A
Epochs:   0%|          | 0/100 [00:11<?, ?it/s]
[32m[I 2023-12-07 20:52:11,376][0m Trial 10 pruned. [0m


Val Loss: 21.7157 - Val Accuracy: 0.8927
Learning rate for Loss: 0.00820579089904839
Learning rate: 0.00015364919241506037
Weight decay: 0.0002767023750603174
Epsilon: 1.810503395532833e-09
Batch size: 234
Number of epochs: 34


Epochs:   0%|          | 0/34 [00:00<?, ?it/s]
Training:   0%|          | 0/60 [00:00<?, ?it/s][A
Training:   2%|▏         | 1/60 [00:00<00:50,  1.16it/s][A
Training:   8%|▊         | 5/60 [00:01<00:14,  3.93it/s][A
Training:  15%|█▌        | 9/60 [00:02<00:09,  5.19it/s][A
Training:  22%|██▏       | 13/60 [00:02<00:08,  5.81it/s][A
Training:  28%|██▊       | 17/60 [00:03<00:07,  6.11it/s][A
Training:  35%|███▌      | 21/60 [00:03<00:06,  6.45it/s][A
Training:  42%|████▏     | 25/60 [00:04<00:05,  6.62it/s][A
Training:  48%|████▊     | 29/60 [00:04<00:04,  6.81it/s][A
Training:  55%|█████▌    | 33/60 [00:05<00:03,  6.84it/s][A
Training:  62%|██████▏   | 37/60 [00:06<00:03,  6.95it/s][A
Training:  68%|██████▊   | 41/60 [00:06<00:02,  7.01it/s][A
Training:  75%|███████▌  | 45/60 [00:07<00:02,  7.05it/s][A
Training:  82%|████████▏ | 49/60 [00:07<00:01,  7.08it/s][A
Training:  88%|████████▊ | 53/60 [00:08<00:00,  7.10it/s][A
Training:  95%|█████████▌| 57/60 [00:08<00:00,  7.

Epoch: 1/34 - Loss: 14.3772 - Accuracy: 0.7995



  0%|          | 0/8 [00:00<?, ?it/s][A
 12%|█▎        | 1/8 [00:00<00:06,  1.12it/s][A
 50%|█████     | 4/8 [00:01<00:00,  5.05it/s][A
100%|██████████| 8/8 [00:01<00:00,  4.92it/s][A
Epochs:   0%|          | 0/34 [00:10<?, ?it/s]
[32m[I 2023-12-07 20:52:22,263][0m Trial 11 pruned. [0m


Val Loss: 6.6593 - Val Accuracy: 0.9095
Learning rate for Loss: 0.002010369672178752
Learning rate: 0.008272134907077715
Weight decay: 0.0018539875985396422
Epsilon: 1.0393375964254468e-09
Batch size: 226
Number of epochs: 42


Epochs:   0%|          | 0/42 [00:00<?, ?it/s]
Training:   0%|          | 0/62 [00:00<?, ?it/s][A
Training:   2%|▏         | 1/62 [00:00<00:50,  1.20it/s][A
Training:   8%|▊         | 5/62 [00:01<00:13,  4.15it/s][A
Training:  15%|█▍        | 9/62 [00:01<00:09,  5.36it/s][A
Training:  21%|██        | 13/62 [00:02<00:08,  6.00it/s][A
Training:  27%|██▋       | 17/62 [00:03<00:06,  6.47it/s][A
Training:  34%|███▍      | 21/62 [00:03<00:06,  6.76it/s][A
Training:  40%|████      | 25/62 [00:04<00:05,  6.93it/s][A
Training:  47%|████▋     | 29/62 [00:04<00:04,  7.09it/s][A
Training:  53%|█████▎    | 33/62 [00:05<00:04,  7.15it/s][A
Training:  60%|█████▉    | 37/62 [00:05<00:03,  7.25it/s][A
Training:  66%|██████▌   | 41/62 [00:06<00:02,  7.19it/s][A
Training:  73%|███████▎  | 45/62 [00:06<00:02,  7.18it/s][A
Training:  79%|███████▉  | 49/62 [00:07<00:01,  7.23it/s][A
Training:  85%|████████▌ | 53/62 [00:07<00:01,  7.31it/s][A
Training:  92%|█████████▏| 57/62 [00:08<00:00,  7.

Epoch: 1/42 - Loss: 7.6307 - Accuracy: 0.8309



  0%|          | 0/8 [00:00<?, ?it/s][A
 12%|█▎        | 1/8 [00:00<00:06,  1.14it/s][A
100%|██████████| 8/8 [00:01<00:00,  5.01it/s][A
Epochs:   2%|▏         | 1/42 [00:10<07:24, 10.84s/it]

Val Loss: 3.2179 - Val Accuracy: 0.9212



Training:   0%|          | 0/62 [00:00<?, ?it/s][A
Training:   2%|▏         | 1/62 [00:00<00:50,  1.21it/s][A
Training:   8%|▊         | 5/62 [00:01<00:13,  4.20it/s][A
Training:  15%|█▍        | 9/62 [00:01<00:09,  5.40it/s][A
Training:  21%|██        | 13/62 [00:02<00:08,  6.07it/s][A
Training:  27%|██▋       | 17/62 [00:03<00:06,  6.50it/s][A
Training:  34%|███▍      | 21/62 [00:03<00:06,  6.78it/s][A
Training:  40%|████      | 25/62 [00:04<00:05,  6.97it/s][A
Training:  47%|████▋     | 29/62 [00:04<00:04,  7.08it/s][A
Training:  53%|█████▎    | 33/62 [00:05<00:04,  7.17it/s][A
Training:  60%|█████▉    | 37/62 [00:05<00:03,  7.18it/s][A
Training:  66%|██████▌   | 41/62 [00:06<00:02,  7.14it/s][A
Training:  73%|███████▎  | 45/62 [00:06<00:02,  7.16it/s][A
Training:  79%|███████▉  | 49/62 [00:07<00:01,  7.27it/s][A
Training:  85%|████████▌ | 53/62 [00:07<00:01,  7.31it/s][A
Training:  92%|█████████▏| 57/62 [00:08<00:00,  7.32it/s][A
Training:  98%|█████████▊| 61/62 [0

Epoch: 2/42 - Loss: 2.6093 - Accuracy: 0.9289



  0%|          | 0/8 [00:00<?, ?it/s][A
 12%|█▎        | 1/8 [00:00<00:06,  1.14it/s][A
100%|██████████| 8/8 [00:01<00:00,  5.03it/s][A
Epochs:   2%|▏         | 1/42 [00:21<14:47, 21.65s/it]
[32m[I 2023-12-07 20:52:44,162][0m Trial 12 pruned. [0m


Val Loss: 2.5575 - Val Accuracy: 0.9086
Learning rate for Loss: 0.0005254730934824125
Learning rate: 0.0001343483811328187
Weight decay: 0.0002718073491237901
Epsilon: 4.49837148852805e-09
Batch size: 62
Number of epochs: 81


Epochs:   0%|          | 0/81 [00:00<?, ?it/s]
Training:   0%|          | 0/226 [00:00<?, ?it/s][A
Training:   0%|          | 1/226 [00:00<01:29,  2.52it/s][A
Training:   2%|▏         | 5/226 [00:00<00:20, 11.01it/s][A
Training:   4%|▍         | 9/226 [00:00<00:13, 16.32it/s][A
Training:   6%|▌         | 13/226 [00:00<00:10, 19.61it/s][A
Training:   8%|▊         | 17/226 [00:00<00:09, 21.72it/s][A
Training:   9%|▉         | 21/226 [00:01<00:08, 22.92it/s][A
Training:  11%|█         | 25/226 [00:01<00:08, 24.00it/s][A
Training:  13%|█▎        | 29/226 [00:01<00:07, 24.94it/s][A
Training:  15%|█▍        | 33/226 [00:01<00:07, 25.52it/s][A
Training:  16%|█▋        | 37/226 [00:01<00:07, 25.52it/s][A
Training:  18%|█▊        | 41/226 [00:01<00:07, 25.75it/s][A
Training:  20%|█▉        | 45/226 [00:02<00:06, 26.40it/s][A
Training:  22%|██▏       | 49/226 [00:02<00:06, 26.33it/s][A
Training:  23%|██▎       | 53/226 [00:02<00:06, 25.41it/s][A
Training:  25%|██▌       | 57/226 [

Epoch: 1/81 - Loss: 9.8352 - Accuracy: 0.8604



  0%|          | 0/29 [00:00<?, ?it/s][A
  3%|▎         | 1/29 [00:00<00:11,  2.46it/s][A
 17%|█▋        | 5/29 [00:00<00:02, 10.32it/s][A
 31%|███       | 9/29 [00:00<00:01, 14.91it/s][A
 45%|████▍     | 13/29 [00:00<00:00, 17.68it/s][A
 59%|█████▊    | 17/29 [00:01<00:00, 19.62it/s][A
 72%|███████▏  | 21/29 [00:01<00:00, 20.78it/s][A
100%|██████████| 29/29 [00:01<00:00, 18.98it/s][A
Epochs:   0%|          | 0/81 [00:10<?, ?it/s]
[32m[I 2023-12-07 20:52:54,995][0m Trial 13 pruned. [0m


Val Loss: 4.8549 - Val Accuracy: 0.9050
Learning rate for Loss: 0.004725291999976461
Learning rate: 0.005358794491244823
Weight decay: 0.0003556283031120871
Epsilon: 5.672565099667101e-09
Batch size: 207
Number of epochs: 45


Epochs:   0%|          | 0/45 [00:00<?, ?it/s]
Training:   0%|          | 0/68 [00:00<?, ?it/s][A
Training:   1%|▏         | 1/68 [00:00<00:50,  1.32it/s][A
Training:   7%|▋         | 5/68 [00:01<00:13,  4.56it/s][A
Training:  13%|█▎        | 9/68 [00:01<00:09,  5.99it/s][A
Training:  19%|█▉        | 13/68 [00:02<00:08,  6.68it/s][A
Training:  25%|██▌       | 17/68 [00:02<00:07,  7.13it/s][A
Training:  31%|███       | 21/68 [00:03<00:06,  7.36it/s][A
Training:  37%|███▋      | 25/68 [00:03<00:05,  7.53it/s][A
Training:  43%|████▎     | 29/68 [00:04<00:05,  7.60it/s][A
Training:  49%|████▊     | 33/68 [00:04<00:04,  7.71it/s][A
Training:  54%|█████▍    | 37/68 [00:05<00:04,  7.71it/s][A
Training:  60%|██████    | 41/68 [00:05<00:03,  7.77it/s][A
Training:  66%|██████▌   | 45/68 [00:06<00:02,  7.81it/s][A
Training:  71%|███████   | 48/68 [00:06<00:02,  9.53it/s][A
Training:  74%|███████▎  | 50/68 [00:06<00:02,  7.88it/s][A
Training:  78%|███████▊  | 53/68 [00:07<00:02,  7.

Epoch: 1/45 - Loss: 5.1322 - Accuracy: 0.8617



  0%|          | 0/9 [00:00<?, ?it/s][A
 11%|█         | 1/9 [00:00<00:06,  1.20it/s][A
 56%|█████▌    | 5/9 [00:01<00:00,  4.14it/s][A
100%|██████████| 9/9 [00:01<00:00,  5.37it/s][A
Epochs:   0%|          | 0/45 [00:10<?, ?it/s]
[32m[I 2023-12-07 20:53:06,051][0m Trial 14 pruned. [0m


Val Loss: 2.6232 - Val Accuracy: 0.9161
Learning rate for Loss: 0.0018172682948593003
Learning rate: 0.05545058398688582
Weight decay: 0.001401980381031568
Epsilon: 8.834252536741359e-08
Batch size: 257
Number of epochs: 41


Epochs:   0%|          | 0/41 [00:00<?, ?it/s]
Training:   0%|          | 0/55 [00:00<?, ?it/s][A
Training:   2%|▏         | 1/55 [00:00<00:50,  1.08it/s][A
Training:   9%|▉         | 5/55 [00:01<00:13,  3.68it/s][A
Training:  16%|█▋        | 9/55 [00:02<00:09,  4.75it/s][A
Training:  24%|██▎       | 13/55 [00:02<00:07,  5.38it/s][A
Training:  31%|███       | 17/55 [00:03<00:06,  5.68it/s][A
Training:  38%|███▊      | 21/55 [00:04<00:05,  5.90it/s][A
Training:  45%|████▌     | 25/55 [00:04<00:04,  6.03it/s][A
Training:  53%|█████▎    | 29/55 [00:05<00:04,  6.16it/s][A
Training:  60%|██████    | 33/55 [00:05<00:03,  6.25it/s][A
Training:  67%|██████▋   | 37/55 [00:06<00:02,  6.30it/s][A
Training:  75%|███████▍  | 41/55 [00:07<00:02,  6.29it/s][A
Training:  82%|████████▏ | 45/55 [00:07<00:01,  6.35it/s][A
Training:  89%|████████▉ | 49/55 [00:08<00:00,  6.33it/s][A
Training:  96%|█████████▋| 53/55 [00:09<00:00,  6.31it/s][A
                                                  

Epoch: 1/41 - Loss: 124.0211 - Accuracy: 0.8041



  0%|          | 0/7 [00:00<?, ?it/s][A
 14%|█▍        | 1/7 [00:00<00:05,  1.01it/s][A
100%|██████████| 7/7 [00:01<00:00,  3.89it/s][A
Epochs:   0%|          | 0/41 [00:11<?, ?it/s]
[32m[I 2023-12-07 20:53:17,354][0m Trial 15 pruned. [0m


Val Loss: 10.7497 - Val Accuracy: 0.9063
Learning rate for Loss: 0.0005216276745841489
Learning rate: 5.50166763437879e-05
Weight decay: 0.00017626042980704932
Epsilon: 2.3430079328995787e-08
Batch size: 93
Number of epochs: 19


Epochs:   0%|          | 0/19 [00:00<?, ?it/s]
Training:   0%|          | 0/151 [00:00<?, ?it/s][A
Training:   1%|          | 1/151 [00:00<01:06,  2.26it/s][A
Training:   3%|▎         | 5/151 [00:00<00:16,  8.77it/s][A
Training:   6%|▌         | 9/151 [00:00<00:11, 12.13it/s][A
Training:   9%|▊         | 13/151 [00:01<00:09, 14.11it/s][A
Training:  11%|█▏        | 17/151 [00:01<00:08, 15.28it/s][A
Training:  14%|█▍        | 21/151 [00:01<00:08, 15.74it/s][A
Training:  17%|█▋        | 25/151 [00:01<00:07, 16.28it/s][A
Training:  19%|█▉        | 29/151 [00:02<00:07, 16.82it/s][A
Training:  22%|██▏       | 33/151 [00:02<00:06, 17.25it/s][A
Training:  25%|██▍       | 37/151 [00:02<00:06, 17.40it/s][A
Training:  27%|██▋       | 41/151 [00:02<00:06, 17.26it/s][A
Training:  30%|██▉       | 45/151 [00:02<00:06, 17.44it/s][A
Training:  32%|███▏      | 49/151 [00:03<00:05, 17.41it/s][A
Training:  35%|███▌      | 53/151 [00:03<00:05, 17.45it/s][A
Training:  38%|███▊      | 57/151 [

Epoch: 1/19 - Loss: 17.7219 - Accuracy: 0.7609



  0%|          | 0/19 [00:00<?, ?it/s][A
  5%|▌         | 1/19 [00:00<00:08,  2.03it/s][A
 26%|██▋       | 5/19 [00:00<00:01,  7.87it/s][A
 47%|████▋     | 9/19 [00:00<00:00, 10.99it/s][A
 68%|██████▊   | 13/19 [00:01<00:00, 12.41it/s][A
100%|██████████| 19/19 [00:01<00:00, 12.07it/s][A
Epochs:   0%|          | 0/19 [00:10<?, ?it/s]
[32m[I 2023-12-07 20:53:28,154][0m Trial 16 pruned. [0m


Val Loss: 10.1930 - Val Accuracy: 0.9003
Learning rate for Loss: 0.0002705105262705772
Learning rate: 0.0034362578330984555
Weight decay: 0.00037861412975971495
Epsilon: 2.400120132265128e-09
Batch size: 191
Number of epochs: 73


Epochs:   0%|          | 0/73 [00:00<?, ?it/s]
Training:   0%|          | 0/74 [00:00<?, ?it/s][A
Training:   1%|▏         | 1/74 [00:00<00:53,  1.37it/s][A
Training:   7%|▋         | 5/74 [00:01<00:14,  4.85it/s][A
Training:  12%|█▏        | 9/74 [00:01<00:10,  6.31it/s][A
Training:  18%|█▊        | 13/74 [00:02<00:08,  7.12it/s][A
Training:  23%|██▎       | 17/74 [00:02<00:07,  7.65it/s][A
Training:  28%|██▊       | 21/74 [00:03<00:06,  7.96it/s][A
Training:  34%|███▍      | 25/74 [00:03<00:05,  8.20it/s][A
Training:  39%|███▉      | 29/74 [00:03<00:05,  8.32it/s][A
Training:  45%|████▍     | 33/74 [00:04<00:04,  8.46it/s][A
Training:  50%|█████     | 37/74 [00:04<00:04,  8.41it/s][A
Training:  55%|█████▌    | 41/74 [00:05<00:03,  8.47it/s][A
Training:  61%|██████    | 45/74 [00:05<00:03,  8.48it/s][A
Training:  66%|██████▌   | 49/74 [00:06<00:02,  8.61it/s][A
Training:  72%|███████▏  | 53/74 [00:06<00:02,  8.49it/s][A
Training:  77%|███████▋  | 57/74 [00:07<00:01,  8.

Epoch: 1/73 - Loss: 6.0735 - Accuracy: 0.8649



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:06,  1.34it/s][A
 50%|█████     | 5/10 [00:01<00:01,  4.48it/s][A
100%|██████████| 10/10 [00:01<00:00,  5.34it/s][A
Epochs:   0%|          | 0/73 [00:11<?, ?it/s]
[32m[I 2023-12-07 20:53:39,560][0m Trial 17 pruned. [0m


Val Loss: 3.5009 - Val Accuracy: 0.9158
Learning rate for Loss: 0.0008856753006887652
Learning rate: 0.017677771848393038
Weight decay: 0.0036388531317124856
Epsilon: 6.620230122167153e-09
Batch size: 151
Number of epochs: 45


Epochs:   0%|          | 0/45 [00:00<?, ?it/s]
Training:   0%|          | 0/93 [00:00<?, ?it/s][A
Training:   1%|          | 1/93 [00:00<00:56,  1.63it/s][A
Training:   5%|▌         | 5/93 [00:00<00:14,  5.92it/s][A
Training:  10%|▉         | 9/93 [00:01<00:10,  7.88it/s][A
Training:  14%|█▍        | 13/93 [00:01<00:08,  8.95it/s][A
Training:  18%|█▊        | 17/93 [00:02<00:07,  9.65it/s][A
Training:  23%|██▎       | 21/93 [00:02<00:07, 10.10it/s][A
Training:  27%|██▋       | 25/93 [00:02<00:06, 10.40it/s][A
Training:  31%|███       | 29/93 [00:03<00:06, 10.59it/s][A
Training:  35%|███▌      | 33/93 [00:03<00:05, 10.71it/s][A
Training:  40%|███▉      | 37/93 [00:03<00:05, 10.73it/s][A
Training:  44%|████▍     | 41/93 [00:04<00:04, 10.79it/s][A
Training:  48%|████▊     | 45/93 [00:04<00:04, 10.73it/s][A
Training:  53%|█████▎    | 49/93 [00:05<00:04, 10.79it/s][A
Training:  57%|█████▋    | 53/93 [00:05<00:03, 10.82it/s][A
Training:  61%|██████▏   | 57/93 [00:05<00:03, 10.

Epoch: 1/45 - Loss: 11.5889 - Accuracy: 0.8312



  0%|          | 0/12 [00:00<?, ?it/s][A
  8%|▊         | 1/12 [00:00<00:07,  1.56it/s][A
 42%|████▏     | 5/12 [00:01<00:01,  5.55it/s][A
100%|██████████| 12/12 [00:01<00:00,  7.78it/s][A
Epochs:   0%|          | 0/45 [00:10<?, ?it/s]
[32m[I 2023-12-07 20:53:50,318][0m Trial 18 pruned. [0m


Val Loss: 3.8049 - Val Accuracy: 0.9197
Learning rate for Loss: 0.0019908238676840553
Learning rate: 0.0005623733773938865
Weight decay: 0.0011554051156482828
Epsilon: 1.8467391861919478e-09
Batch size: 262
Number of epochs: 90


Epochs:   0%|          | 0/90 [00:00<?, ?it/s]
Training:   0%|          | 0/54 [00:00<?, ?it/s][A
Training:   2%|▏         | 1/54 [00:00<00:50,  1.06it/s][A
Training:   9%|▉         | 5/54 [00:01<00:13,  3.66it/s][A
Training:  17%|█▋        | 9/54 [00:02<00:09,  4.75it/s][A
Training:  24%|██▍       | 13/54 [00:02<00:07,  5.31it/s][A
Training:  31%|███▏      | 17/54 [00:03<00:06,  5.66it/s][A
Training:  39%|███▉      | 21/54 [00:04<00:05,  5.83it/s][A
Training:  46%|████▋     | 25/54 [00:04<00:04,  5.96it/s][A
Training:  54%|█████▎    | 29/54 [00:05<00:04,  6.09it/s][A
Training:  61%|██████    | 33/54 [00:06<00:03,  6.13it/s][A
Training:  69%|██████▊   | 37/54 [00:06<00:02,  6.15it/s][A
Training:  76%|███████▌  | 41/54 [00:07<00:02,  6.20it/s][A
Training:  83%|████████▎ | 45/54 [00:07<00:01,  6.22it/s][A
Training:  91%|█████████ | 49/54 [00:08<00:00,  6.26it/s][A
Training:  98%|█████████▊| 53/54 [00:09<00:00,  6.25it/s][A
                                                  

Epoch: 1/90 - Loss: 9.1214 - Accuracy: 0.8553



  0%|          | 0/7 [00:00<?, ?it/s][A
 14%|█▍        | 1/7 [00:00<00:05,  1.01it/s][A
100%|██████████| 7/7 [00:01<00:00,  3.90it/s][A
Epochs:   0%|          | 0/90 [00:11<?, ?it/s]
[32m[I 2023-12-07 20:54:01,727][0m Trial 19 pruned. [0m


Val Loss: 3.9975 - Val Accuracy: 0.9189

Study statistics: 
  Number of finished trials:  20
  Number of pruned trials:  13
  Number of complete trials:  7


In [25]:
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

Best trial:
  Value:  0.9446246027946472
  Params: 
    batch_size: 127
    epochs: 35
    epsilon: 3.601242341054396e-08
    learning_rate: 0.001556997542805148
    loss_learning_rate: 0.0011138215118992293
    weight_decay: 0.00021625482614416372


In [None]:
# ViT P8-S8 ArcFace Mean

Best trial:
Value:  0.9446246027946472
Params: 
batch_size: 127
epochs: 35
epsilon: 3.601242341054396e-08
learning_rate: 0.001556997542805148
loss_learning_rate: 0.0011138215118992293
weight_decay: 0.00021625482614416372