In [None]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    :param seed: integer seed applied to every random number generator
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    # Current CUDA device first, then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
# Bundle the train / validation / test loaders into a single tuple.
# NOTE(review): these three loaders are not defined in the visible cells and
# `dataloaders` is not referenced by the visible code — confirm it is still needed.
dataloaders = (
    train_dataloader,
    val_dataloader,
    test_dataloader,
)

In [None]:
def set_model(args):
    """Initialize the linear model, the supervised loss and the optional EMA model.

    The models are linear heads that later receive the features extracted by
    the embedding model.

    :param args: training arguments; reads ``dataset``, ``n_classes`` and ``eval_ema``
    :return: tuple
        - model: initialized model (train mode, moved to CUDA)
        - criteria_x: supervised loss function (cross entropy, moved to CUDA)
        - ema_model: initialized EMA model (eval mode, parameters copied from
          ``model`` and frozen), or ``None`` if ``args.eval_ema`` is falsy
    :raises SystemExit: with a non-zero exit status if ``args.dataset`` is not
        one of the supported datasets
    """
    # Input feature size of the linear head for each supported dataset.
    feature_dims = {'cifar100': 1280, 'nih': 512}
    feature_dim = feature_dims.get(args.dataset.lower())
    if feature_dim is None:
        # A bare sys.exit() exits with status 0 ("success"); passing the message
        # prints it to stderr and exits with status 1 so callers/shells see the failure.
        sys.exit(f'Dataset {args.dataset} not defined')

    model = LinearNN(num_classes=args.n_classes, feature_dim=feature_dim, proj=True)
    model.train()
    model.cuda()

    if args.eval_ema:
        ema_model = LinearNN(num_classes=args.n_classes, feature_dim=feature_dim, proj=True)
        for param_q, param_k in zip(model.parameters(), ema_model.parameters()):
            param_k.data.copy_(param_q.detach().data)  # initialize EMA weights from the model
            param_k.requires_grad = False  # EMA weights are not updated by gradients
        ema_model.cuda()
        ema_model.eval()
    else:
        ema_model = None

    criteria_x = nn.CrossEntropyLoss().cuda()
    return model, criteria_x, ema_model

In [None]:
def _default_args():
    """Return the hard-coded training configuration as an ``argparse.Namespace``.

    ``main`` reads options as attributes (``args.seed``, ``args.n_imgs_per_epoch``)
    and logs ``args._get_kwargs()``, which both require a Namespace. The option
    names use underscores so they are valid attribute names.
    """
    from argparse import Namespace

    return Namespace(
        root='',                    # dataset directory
        wresnet_k=2,                # width factor of wide resnet
        wresnet_n=28,               # depth of wide resnet
        dataset='nih',
        n_classes=2,                # number of classes in dataset
        n_labeled=40,               # number of labeled samples for training
        n_epoches=10,               # number of training epochs
        batchsize=16,               # train batch size of labeled samples
        mu=7,                       # factor of train batch size of unlabeled samples
        n_imgs_per_epoch=32768,     # number of training images for each epoch
        eval_ema=True,              # whether to use ema model for evaluation
        ema_m=0.999,                # presumably the EMA momentum read during training — confirm
        lam_u=1.,                   # coefficient of unlabeled loss
        lr=0.03,                    # learning rate for training
        weight_decay=5e-4,          # weight decay
        momentum=0.9,               # momentum for optimizer
        seed=1,                     # seed for random behaviors, no seed if negative
        temperature=0.2,            # softmax temperature
        low_dim=64,                 # feature dimension of the memory-bank queue
        lam_c=1,                    # coefficient of contrastive loss
        contrast_th=0.8,            # pseudo label graph threshold
        thr=0.95,                   # pseudo label threshold
        alpha=0.9,                  # NOTE(review): meaning not visible here — confirm
        queue_batch=5,              # number of batches stored in memory bank
        exp_dir='EmbeddingCM_bin',  # experiment id
        ex_strength=4295342357,     # strength of the expert (used as the NIH expert's ID)
    )


def _save_predictions(predictions, args, logger):
    """Dump *predictions* as JSON into the local and the sibling L2D label folders.

    Targets ``./artificial_expert_labels/`` and
    ``<parent>/Learning-to-Defer-Algs/artificial_expert_labels/``, where
    ``<parent>`` is the cwd with the trailing ``Embedding-Semi-Supervised``
    folder name stripped. Missing directories are created.
    """
    pred_file = (f'{args.exp_dir}_{args.dataset.lower()}_expert{args.ex_strength}.'
                 f'{args.seed}@{args.n_labeled}_predictions.json')
    target_dirs = ['artificial_expert_labels']

    repo_name = 'Embedding-Semi-Supervised'
    cwd = os.getcwd()
    if cwd.endswith(repo_name):
        target_dirs.append(os.path.join(cwd[:-len(repo_name)], 'Learning-to-Defer-Algs',
                                        'artificial_expert_labels'))
    else:
        # Slicing the repo name off an unrelated cwd would produce a bogus path.
        logger.warning(f'cwd {cwd} does not end with {repo_name}; '
                       f'skipping copy to Learning-to-Defer-Algs')

    for target_dir in target_dirs:
        os.makedirs(target_dir, exist_ok=True)
        with open(os.path.join(target_dir, pred_file), 'w') as f:
            json.dump(predictions, f)


def main():
    """Train the linear head on embedding features and export expert predictions.

    Builds the model (and EMA copy), the embedding model, the expert and the
    data loaders. It resumes from the checkpoint in the logging output directory
    if present, then trains for ``args.n_epoches`` epochs. After each epoch it
    evaluates, logs metrics to TensorBoard and saves ``ckp.latest``. Finally it
    writes the predictions to JSON.

    :raises ValueError: if the dataset is neither a CIFAR nor the NIH dataset
    """
    args = _default_args()

    # Set up the logger
    logger, output_dir = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    tb_logger = SummaryWriter(output_dir)

    # Seed init (a negative seed disables seeding)
    if args.seed >= 0:
        set_seed(args.seed)

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 2048 with the defaults
    n_iters_all = n_iters_per_epoch * args.n_epoches  # 2048 * 10 with the defaults

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")

    # Create the model
    model, criteria_x, ema_model = set_model(args)
    # Load the trained embedding model
    emb_model = EmbeddingModel(os.getcwd(), args.dataset)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dataset = args.dataset.lower()
    if 'cifar' in dataset:
        expert = CIFAR100Expert(20, int(args.ex_strength), 1, 0, 123)
        dltrain_x, dltrain_u = cifar.get_train_loader(
            args.dataset, expert, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled, root=args.root,
            method='comatch')
        dlval = cifar.get_val_loader(args.dataset, expert, batch_size=64, num_workers=2)
    elif 'nih' in dataset:
        # Create the expert with its ID
        expert = NIHExpert(int(args.ex_strength), 2)
        dltrain_x, dltrain_u = nih.get_train_loader(
            expert, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled, method='comatch')
        dlval = nih.get_val_loader(expert, batch_size=64, num_workers=2)
    else:
        raise ValueError(f'Dataset {args.dataset} not defined')

    # Parameters whose name contains 'bn' are excluded from weight decay.
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [
        {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
                            momentum=args.momentum, nesterov=True)

    lr_schdlr = WarmupCosineLrScheduler(optim, n_iters_all, warmup_iter=0)

    model, ema_model, optim, lr_schdlr, start_epoch, metrics, prob_list, queue = \
        load_from_checkpoint(output_dir, model, ema_model, optim, lr_schdlr)

    # Memory bank: restored from the checkpoint if available, otherwise zero-initialized.
    args.queue_size = args.queue_batch * (args.mu + 1) * args.batchsize
    if queue is not None:
        queue_feats = queue['queue_feats']
        queue_probs = queue['queue_probs']
        queue_ptr = queue['queue_ptr']
    else:
        queue_feats = torch.zeros(args.queue_size, args.low_dim).cuda()
        queue_probs = torch.zeros(args.queue_size, args.n_classes).cuda()
        queue_ptr = 0

    train_args = dict(
        model=model,
        ema_model=ema_model,
        emb_model=emb_model,
        prob_list=prob_list,
        criteria_x=criteria_x,
        optim=optim,
        lr_schdlr=lr_schdlr,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        args=args,
        n_iters=n_iters_per_epoch,
        logger=logger
    )

    best_acc = -1
    best_epoch = 0
    if metrics is not None:
        best_acc = metrics['best_acc']
        best_epoch = metrics['best_epoch']

    logger.info('-----------start training--------------')
    try:
        for epoch in range(start_epoch, args.n_epoches):
            loss_x, loss_u, loss_c, mask_mean, num_pos, guess_label_acc, queue_feats, queue_probs, queue_ptr, prob_list = \
                train_one_epoch(epoch, **train_args, queue_feats=queue_feats, queue_probs=queue_probs, queue_ptr=queue_ptr)

            top1, ema_top1 = evaluate(model, ema_model, emb_model, dlval)

            tb_logger.add_scalar('loss_x', loss_x, epoch)
            tb_logger.add_scalar('loss_u', loss_u, epoch)
            tb_logger.add_scalar('loss_c', loss_c, epoch)
            tb_logger.add_scalar('guess_label_acc', guess_label_acc, epoch)
            tb_logger.add_scalar('test_acc', top1, epoch)
            tb_logger.add_scalar('test_ema_acc', ema_top1, epoch)
            tb_logger.add_scalar('mask', mask_mean, epoch)
            tb_logger.add_scalar('num_pos', num_pos, epoch)

            if best_acc < top1:
                best_acc = top1
                best_epoch = epoch

            logger.info("Epoch {}. Acc: {:.4f}. Ema-Acc: {:.4f}. best_acc: {:.4f} in epoch{}".
                        format(epoch, top1, ema_top1, best_acc, best_epoch))

            save_obj = {
                'model': model.state_dict(),
                # ema_model is None when args.eval_ema is falsy (see set_model).
                'ema_model': ema_model.state_dict() if ema_model is not None else None,
                'optimizer': optim.state_dict(),
                'lr_scheduler': lr_schdlr.state_dict(),
                'prob_list': prob_list,
                'queue': {'queue_feats': queue_feats, 'queue_probs': queue_probs, 'queue_ptr': queue_ptr},
                'metrics': {'best_acc': best_acc, 'best_epoch': best_epoch},
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(output_dir, 'ckp.latest'))
    finally:
        # Flush pending TensorBoard events even if training fails.
        tb_logger.close()

    # Final evaluation; its returned accuracies are not used.
    evaluate(model, ema_model, emb_model, dlval)

    logger.info("***** Generate Predictions *****")
    if 'cifar' in dataset:
        predictions = predict_cifar(model, ema_model, emb_model, dltrain_x, dltrain_u, dlval)
    else:  # 'nih' — any other dataset was rejected above
        predictions = predict_nih(model, ema_model, emb_model, dltrain_x, dltrain_u, dlval)

    _save_predictions(predictions, args, logger)