In [None]:
# HuiduRep Trainer

In [None]:
# Re-import edited project modules (model/, train/, utils/) automatically before
# each cell executes, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import torch
import torch.optim as optim
import numpy as np
import os

from model.CMAES import CMAES
from train.CMAES_trainer import CMAESTrainer
from utils.CMAES_utils import initialize_weights, load_model
from utils.scheduler_utils import get_param_groups, WarmupCosineLR, LARS

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook touches so a Restart & Run All is reproducible.
# Python's `random` drives the unit subsampling in the dataset cell; numpy and
# torch are seeded as well for the model / trainer code.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)
print(device)

In [None]:
# Build the HuiduRep model (a CMAES instance) and load the pretrained weights.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with — confirm against load_model's expectations.
huidu = CMAES(T=0.2,
              mask_ratio=0.0,
              use_embedding=True,
              n_heads=4,
              m=0.99,
              use_avg_pool=True,
              K=4096,
              embedding_dim=64,
              ff_dim=128,
              num_layers=3,
              dropout=0.1,
              moco_v3=True)
huidu = load_model(huidu, './resources/checkpoint/HuiduRep.pt')
# NOTE(review): `device` from the setup cell is never passed to the model or the
# trainer in this notebook — confirm CMAESTrainer moves model and data itself.

# Sanity-check the loaded model: list each parameter's name and shape.
# named_parameters() yields (name, tensor) pairs, in that order.
for name, param in huidu.named_parameters():
    print(name, tuple(param.shape))

In [None]:
# Data placeholders: load the training/evaluation arrays here before running
# this cell (the dataset is about 5GB).
train_labels = None
test_labels = None
train_dataset = None
test_dataset = None
# Unit IDs handed to the trainer as `test_labels` in the trainer cell.
target = [11, 13, 16, 69, 84, 89, 277, 267, 332, 343]

# Fail fast with a readable message instead of an opaque torch dtype error.
if train_dataset is None or train_labels is None:
    raise ValueError("train_dataset / train_labels are placeholders - load the data first")

# Keep a random quarter of the distinct training units (at least one, so the
# training set can never end up empty). Uses Python's `random`, seeded above.
train_units = np.unique(train_labels)
n_keep = max(1, len(train_units) // 4)
samples = random.sample(range(len(train_units)), n_keep)
train_units = train_units[samples]

# Restrict the training split to rows whose label is one of the kept units.
# The mask indexes train_dataset/train_labels, so it must be built from
# train_labels (building it from test_labels selected unrelated rows).
mask = torch.isin(torch.as_tensor(train_labels), torch.as_tensor(train_units))
indices = torch.nonzero(mask, as_tuple=False).squeeze(dim=1)
train_dataset = train_dataset[indices]
train_labels = train_labels[indices]
print(train_dataset.shape)

In [None]:
# Optimisation hyper-parameters.
learning_rate = 4e-4
weight_decay = 1e-4
epochs = 32
batch_size = 1024

# The scheduler and trainer are configured in optimizer steps
# (epochs x STEPS_PER_EPOCH), which suggests the scheduler is stepped per batch.
# NOTE(review): STEPS_PER_EPOCH is hard-coded; it could instead be derived as
# len(train_dataset) // batch_size — confirm 16 matches the subsampled data.
STEPS_PER_EPOCH = 16
WARMUP_EPOCHS = 4
total_steps = epochs * STEPS_PER_EPOCH

optimizer = optim.AdamW(huidu.parameters(), weight_decay=weight_decay, lr=learning_rate)
scheduler = WarmupCosineLR(optimizer,
                           warmup_epochs=WARMUP_EPOCHS * STEPS_PER_EPOCH,
                           max_epochs=total_steps,
                           warmup_start_lr=1e-5,
                           eta_min=5e-6)
trainer = CMAESTrainer(huidu,
                       optimizer=optimizer,
                       train_dataset=train_dataset,
                       val_dataset=test_dataset,
                       train_labels=train_labels,
                       val_labels=test_labels,
                       test_labels=target,
                       eval_epochs=1,
                       batch_size=batch_size,
                       epochs=epochs,
                       tensorboard=True,
                       tensorboard_steps=10,
                       # NOTE(review): looks truncated — perhaps meant
                       # './resources/checkpoint'; confirm before a long run.
                       save_path='./resources/chec',
                       scheduler=scheduler,
                       total_steps=total_steps,
                       valid=False)

In [None]:
# Run training with the optimizer, scheduler and data configured above.
trainer.train()