# In [None]:
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import datetime
import json
import os
import numpy as np
import six
import tensorflow as tf
from tensorflow.python import debug as tf_debug

from fewshot.configs.config_factory import get_config
from fewshot.configs.mini_imagenet_config import *
from fewshot.configs.omniglot_config import *
from fewshot.configs.tiered_imagenet_config import *
from fewshot.data.data_factory import get_concurrent_iterator
from fewshot.data.data_factory import get_dataset
from fewshot.data.episode import Episode
from fewshot.data.mini_imagenet import MiniImageNetDataset
from fewshot.data.omniglot import OmniglotDataset
from fewshot.data.tiered_imagenet import TieredImageNetDataset
from fewshot.models.basic_model import BasicModel
from fewshot.models.kmeans_refine_mask_model import KMeansRefineMaskModel
from fewshot.models.kmeans_refine_model import KMeansRefineModel
from fewshot.models.kmeans_refine_radius_model import KMeansRefineRadiusModel
from fewshot.models.basic_model_VAT import BasicModelVAT
from fewshot.models.measure import batch_apk
from fewshot.models.model_factory import get_model
from fewshot.utils import logger
from fewshot.utils.experiment_logger import ExperimentLogger
from fewshot.utils.lr_schedule import FixedLearnRateScheduler
from tqdm import tqdm


# Command-line flags for this experiment.
# Each row is (definer, flag name, default value, help string). Rows are
# registered in the order listed, with positional arguments only.
flags = tf.flags
for _define_flag, _flag_name, _flag_default, _flag_help in (
    (flags.DEFINE_bool, "eval", False, "Whether to only run evaluation"),
    (flags.DEFINE_bool, "use_test", False, "Use the test set or not"),
    (flags.DEFINE_float, "learn_rate", None, "Start learning rate"),
    (flags.DEFINE_integer, "nclasses_eval", 5, "Number of classes for testing"),
    (flags.DEFINE_integer, "nclasses_train", 5, "Number of classes for training"),
    (flags.DEFINE_integer, "nshot", 1, "nshot"),
    (flags.DEFINE_integer, "num_eval_episode", 600, "Number of evaluation episodes"),
    (flags.DEFINE_integer, "num_test", -1, "Number of test images per episode"),
    (flags.DEFINE_integer, "num_unlabel", 5, "Number of unlabeled for training"),
    (flags.DEFINE_integer, "steps_per_summary", 100, "Number of steps between summary ops"),
    (flags.DEFINE_integer, "seed", 0, "Random seed"),
    (flags.DEFINE_string, "dataset", "omniglot", "Dataset name"),
    (flags.DEFINE_string, "model", "basic", "Model name"),
    (flags.DEFINE_string, "pretrain", None, "Model pretrain path"),
    (flags.DEFINE_string, "results", "./results", "Checkpoint save path"),
):
    _define_flag(_flag_name, _flag_default, _flag_help)
# Drop the loop temporaries so they do not linger in the module namespace.
del _define_flag, _flag_name, _flag_default, _flag_help

# Flag-value container; values are read as attributes (e.g. FLAGS.dataset).
FLAGS = flags.FLAGS

# In [None]:
# Logger for the status messages below. Only the `logger` module is imported
# at the top of the file, so the instance must be bound before first use.
log = logger.get()

# The two ImageNet datasets fall back to num_test = 5 when --num_test is left
# at its -1 default; every other case passes the flag value through as given.
if FLAGS.num_test == -1 and (FLAGS.dataset == "tiered-imagenet" or
                             FLAGS.dataset == 'mini-imagenet'):
    num_test = 5
else:
    num_test = FLAGS.num_test

config = get_config(FLAGS.dataset, FLAGS.model)
nclasses_train = FLAGS.nclasses_train
nclasses_eval = FLAGS.nclasses_eval

# Training always uses the `train` split. The other split is `test` only
# when --use_test is set, and `val` otherwise.
train_split_name = 'train'
if FLAGS.use_test:
    log.info('Using the test set')
    test_split_name = 'test'
else:
    log.info('Not using the test set, using val')
    test_split_name = 'val'

log.info('Use split `{}` for training'.format(train_split_name))

# 90-degree augmentation is turned off for any dataset whose name contains
# 'mini-imagenet' or 'tiered-imagenet', and turned on for everything else.
if 'mini-imagenet' in FLAGS.dataset or 'tiered-imagenet' in FLAGS.dataset:
    _aug_90 = False
else:
    _aug_90 = True

nshot = FLAGS.nshot
# Training-split dataset. Episode shuffling is disabled and --seed is passed
# through to the dataset factory.
meta_train_dataset = get_dataset(
    FLAGS.dataset,
    train_split_name,
    nclasses_train,
    nshot,
    num_test=num_test,
    aug_90=_aug_90,
    num_unlabel=FLAGS.num_unlabel,
    shuffle_episode=False,
    seed=FLAGS.seed)