In [None]:
from dataset import *
from train import *
from utils import torch_fix_seed

In [None]:
# Global random seed, fixed once before any dataset, DataLoader or model is built below.
# torch_fix_seed comes from the project's utils module (presumably it seeds Python/NumPy/torch RNGs — confirm there).
SEED = 19981303
torch_fix_seed(SEED)

In [None]:
# Candidate values handed to train_moco_return_metrics_top_k as `k_candidates` in the training cell.
top_k_candidates = (20, 30, 40)
# Third argument of k_fold_train_validation_split (presumably the number of folds — confirm in dataset module).
k = 7
# NOTE(review): temps, momentums and k_value are not read in this cell or in the training cell below;
# the names suggest MoCo temperature / momentum / a size setting — confirm before relying on them.
temps, momentums, k_value = 0.7, 0.99, 64

def extract_query_encoder_state_dict(checkpoint_state_dict):
    """Return the query-encoder backbone weights from a checkpoint ``state_dict``.

    Keeps only entries that have an ``encoder_q`` path segment and strips everything
    up to and including that segment. Any leading wrapper prefix is therefore
    tolerated, e.g. ``encoder_q.conv1.weight`` and ``module.encoder_q.conv1.weight``
    both become ``conv1.weight``. Entries belonging to the ``fc`` head are dropped,
    so the result matches a backbone whose ``fc`` has been removed.

    Args:
        checkpoint_state_dict: mapping of parameter/buffer names to values.

    Returns:
        A new dict with the stripped names; the input is not modified.
    """
    backbone_weights = {}
    for key, value in checkpoint_state_dict.items():
        segments = key.split('.')
        if 'encoder_q' not in segments:
            continue  # key-encoder weights, queues, etc. are not needed
        local_segments = segments[segments.index('encoder_q') + 1:]
        if not local_segments or local_segments[0] == 'fc':
            continue  # projection head is replaced by a fresh layer
        backbone_weights['.'.join(local_segments)] = value
    return backbone_weights


def get_self_pretrain_model(index=1000, checkpoint_dir='./checkpoints'):
    """Build a single-channel ResNet-50 initialised from a self-supervised checkpoint.

    Loads ``{checkpoint_dir}/CEM_ALL_CHECK_{index}_Epoch.pth`` and copies its
    ``encoder_q`` weights (without the ``fc`` head) into a fresh ResNet-50 whose
    first convolution takes one input channel. It then attaches a new, randomly
    initialised ``Linear(fc.in_features, 512)`` head.

    Args:
        index: number embedded in the checkpoint filename (``..._{index}_Epoch.pth``).
        checkpoint_dir: directory containing the checkpoint files.

    Returns:
        The initialised ``torchvision`` ResNet-50 module (on CPU).

    Raises:
        RuntimeError: from the strict ``load_state_dict`` if the checkpoint's
            backbone keys do not match the model.
    """
    base_encoder = torchvision.models.resnet50(weights=None)
    # Single-channel input: replace the 3-channel stem with a 1-channel conv of the same geometry.
    base_encoder.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    origin_dim_mlp = base_encoder.fc.in_features
    # Remove the classifier so the strict load below expects no fc.* keys.
    base_encoder.fc = None

    checkpoint_path = f'{checkpoint_dir}/CEM_ALL_CHECK_{index}_Epoch.pth'
    # map_location='cpu': the backbone lives on CPU here, so a GPU-saved checkpoint
    # must not require (or allocate on) the GPU it was saved from.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    base_encoder.load_state_dict(extract_query_encoder_state_dict(checkpoint['state_dict']))

    base_encoder.fc = torch.nn.Linear(origin_dim_mlp, 512)
    return base_encoder

In [None]:
# Hyper-parameter grid: every combination of these values is trained with k-fold validation.
param_grid = {'pretrain_model': ['self_pretrained']}
train_metrics = HistoryRecorder(['Train Loss', 'Train Acc', 'Val Loss', 'Val Acc'], list(param_grid.keys()))

# One tuple per combination, values ordered like param_grid's keys.
parameters = list(itertools.product(*param_grid.values()))

BASE_LR = 0.02           # LR of the 'fc' head group (and of the whole model in the fallback branch)
BACKBONE_LR_SCALE = 0.5  # non-'fc' encoder_q parameters train at half of BASE_LR

# Loop-invariant: the device and the loss are constructed identically for every fold.
device = torch.device('cuda:0')
criterion = nn.CrossEntropyLoss().cuda(device)

for parameter in parameters:

    ### custom part to get parameters
    pretrain_model = parameter[0]
    ### END

    for images in k_fold_train_validation_split(ORIGINAL_IMAGE, TARGET_IMAGE, k):
        train_dataset = SingleChannelNDIDatasetContrastiveLearningWithAug(images, False)
        val_dataset = SingleChannelNDIDatasetContrastiveLearningWithAug(images, True)
        train_iter = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
        # Whole validation split in a single batch.
        val_iter = DataLoader(val_dataset, batch_size=len(val_dataset))

        model = get_model(pretrain_model)
        if pretrain_model in ['self_pretrained', 'CEM']:
            # Split encoder_q parameters in one pass: names containing 'fc' form the head group.
            base_params, fc_params = [], []
            for name, param in model.encoder_q.named_parameters():
                (fc_params if 'fc' in name else base_params).append(param)
            params = [{'params': base_params, 'lr': BASE_LR * BACKBONE_LR_SCALE},
                      {'params': fc_params, 'lr': BASE_LR}]
        else:
            params = [{'params': model.parameters(), 'lr': BASE_LR}]

        optimizer = torch.optim.SGD(params=params, momentum=0.9, weight_decay=1e-4)
        # NOTE(review): T_max=100 while 20 is passed to the trainer below — if 20 is the epoch
        # count and the scheduler is stepped once per epoch, only the first fifth of the cosine
        # schedule is used. Confirm this is intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        start_time = time.time()
        metrics = train_moco_return_metrics_top_k(model, train_iter, val_iter, criterion, optimizer, 20, device,
                                                  tested_parameter=parameter, k_candidates=top_k_candidates, scheduler=scheduler)
        end_time = time.time()