In [None]:
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os

from fvcore.nn import FlopCountAnalysis
from pathlib import Path
from collections import OrderedDict

from timm.data.mixup import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import ModelEma
from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner

from datasets import build_dataset, build_dataset_self_define
from engine_for_finetuning import train_one_epoch, evaluate
from utils import NativeScalerWithGradNormCount as NativeScaler
import utils
from scipy import interpolate
# import modeling_finetune
import modeling_branchViT
import modeling_finetune_MOE

# Expose only GPU 0 to this process. This runs at import time, before main()
# builds torch.device(args.device); any CUDA_VISIBLE_DEVICES value already
# present in the environment is overwritten.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
def get_args(argv=None):
    """Parse the command-line options for the FLOP-counting script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebook cells).

    Returns:
        argparse.Namespace with ``model``, ``device``, ``seed``,
        ``nb_classes`` and ``drop``.

    Unrecognised arguments are ignored rather than treated as an error.
    When this code runs inside a Jupyter kernel, ``sys.argv`` typically
    carries the kernel's own ``-f <connection-file>`` flag, which would
    otherwise make ``parse_args`` exit with "unrecognized arguments".
    """
    parser = argparse.ArgumentParser('MAE fine-tuning and evaluation script for image classification', add_help=False)

    # Model parameters
    parser.add_argument('--model', default='moe_mlp_vit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--nb_classes', default=3, type=int)
    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')

    # parse_known_args tolerates host-injected flags (e.g. Jupyter's -f).
    args, _unknown = parser.parse_known_args(argv)
    return args

def main(args):
    """Build a timm model and print its FLOP count for one 224x224 image.

    Args:
        args: Namespace from ``get_args()``. Reads ``device``, ``seed``,
            ``nb_classes`` and ``drop``; ``args.model`` is overwritten below.

    Prints the seed, then the model name, the raw total reported by fvcore's
    ``FlopCountAnalysis`` for a single (1, 3, 224, 224) input, and that total
    divided by 1e9.
    """
    device = torch.device(args.device)
    # NOTE(review): this hard-coded override discards whatever --model was
    # passed on the command line (including its default). Delete this line to
    # make the CLI flag take effect.
    args.model = "vit_base_patch16_224"

    # Fix the seed for reproducibility, offset by utils.get_rank()
    # (presumably the distributed process rank).
    seed = args.seed + utils.get_rank()
    print("seed:", seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_block_rate=None,
    )
    model.to(device)
    # Trace in inference mode so train-only behaviour (dropout, stochastic
    # depth) does not end up in the traced graph.
    model.eval()

    # Only the input shape matters for FLOPs. torch.zeros gives a defined
    # tensor allocated directly on the target device, unlike the legacy
    # torch.FloatTensor(...) constructor (uninitialised CPU memory + copy).
    dummy_input = torch.zeros(1, 3, 224, 224, device=device)
    with torch.no_grad():
        flops = FlopCountAnalysis(model, dummy_input)
        total_flops = flops.total()
    gflops = total_flops / 1e9
    print("Model:", args.model, f", FLOPs: {total_flops}", f", GFLOPs: {gflops}")

if __name__ == '__main__':
    # Script entry point: parse the CLI options and pass them straight to main().
    main(get_args())