In [None]:
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
import random


# Reproducibility: one fixed seed feeds every random number generator in use.
seed = 43
random.seed(seed)                 # Python's built-in random module
np.random.seed(seed)              # NumPy global generator
torch.manual_seed(seed)           # PyTorch CPU generator
torch.cuda.manual_seed(seed)      # PyTorch generator of the current GPU
torch.cuda.manual_seed_all(seed)  # PyTorch generators of all GPUs (multi-GPU)

# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DATA_DIR = ""  # dataset location (image root)
DATA_SPLIT = ""  # location of the dataset split files
device = "cuda"

class CUB(Dataset):
    """Image dataset described by a ``<setname>.csv`` split file.

    The first line of the CSV is treated as a header and skipped.  Every
    other non-blank line is ``<image name>,<class id>[,...]``; the image name
    is joined onto ``args.data_dir`` and the class id is mapped to an integer
    label in order of first appearance.

    Args:
        setname: ``'train'``, ``'val'`` or ``'test'``.  ``'train'`` uses
            random-resized-crop + horizontal-flip augmentation, the other two
            a plain resize.
        args: object exposing ``data_dir`` (image root) and ``data_split``
            (directory containing the CSV files).

    Raises:
        ValueError: if ``setname`` is not one of the three supported splits
            (previously this left ``self.transform`` undefined and only
            failed later, on the first ``__getitem__``).
    """

    # Side length, in pixels, of the square image handed to the network.
    IMAGE_SIZE = 84
    # Per-channel statistics on the 0-255 scale; divided by 255 before use.
    PIXEL_MEAN = (125.3, 123.0, 113.9)
    PIXEL_STD = (63.0, 62.1, 66.7)

    def __init__(self, setname, args):
        if setname not in ('train', 'val', 'test'):
            raise ValueError(
                "setname must be 'train', 'val' or 'test', got %r" % (setname,))

        IMAGE_PATH = os.path.join(args.data_dir)
        SPLIT_PATH = os.path.join(args.data_split)
        csv_path = osp.join(SPLIT_PATH, setname + '.csv')

        # data: path of every image, label: integer label of every image,
        # wnids: class ids in order of first appearance (index == label).
        self.data, self.label, self.wnids = self._read_split(csv_path, IMAGE_PATH)
        self.num_class = len(self.wnids)

        normalize = transforms.Normalize(
            np.array([x / 255.0 for x in self.PIXEL_MEAN]),
            np.array([x / 255.0 for x in self.PIXEL_STD]))
        if setname == 'train':
            geometry = [transforms.RandomResizedCrop(self.IMAGE_SIZE),
                        transforms.RandomHorizontalFlip()]
        else:
            geometry = [transforms.Resize([self.IMAGE_SIZE, self.IMAGE_SIZE])]
        self.transform = transforms.Compose(
            geometry + [transforms.ToTensor(), normalize])

    @staticmethod
    def _read_split(csv_path, image_root):
        """Parse a split file into ``(paths, labels, class_ids)``.

        A label is the index of the row's class id in ``class_ids``, so rows
        of one class always share a label even when classes are interleaved
        in the file (the previous running counter assigned the label of the
        most recently *introduced* class instead).
        """
        paths, labels, class_ids = [], [], []
        class_index = {}  # class id -> label, O(1) lookup
        with open(csv_path, 'r') as split_file:  # closed even on error
            next(split_file, None)  # skip the header line
            for raw_line in split_file:
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank lines
                fields = line.split(',')
                name, wnid = fields[0], fields[1]
                if wnid not in class_index:
                    class_index[wnid] = len(class_ids)
                    class_ids.append(wnid)
                paths.append(osp.join(image_root, name))
                labels.append(class_index[wnid])
        return paths, labels, class_ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(transformed RGB image, integer label)`` for sample ``i``."""
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label

class CategoriesSampler():
    """Batch sampler that draws ``n_batch`` class-balanced episodes.

    Each episode picks ``n_cls`` random classes and ``n_per`` random samples
    of every picked class.  The yielded 1-D index tensor is ordered
    sample-major: first one sample of each picked class, then the second
    sample of each class, and so on.
    """

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch  # number of episodes per pass over the sampler
        self.n_cls = n_cls
        self.n_per = n_per

        labels = np.array(label)  # label of every sample in the dataset
        num_classes = max(labels) + 1
        # m_ind[c] holds the dataset indices of all samples with label c.
        self.m_ind = [
            torch.from_numpy(np.argwhere(labels == cls).reshape(-1))
            for cls in range(num_classes)
        ]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            # Random subset of classes for this episode.
            chosen = torch.randperm(len(self.m_ind))[:self.n_cls]
            per_class = []
            for cls in chosen:
                members = self.m_ind[cls]
                # Random subset of this class's samples.
                picks = torch.randperm(len(members))[:self.n_per]
                per_class.append(members[picks])
            # (n_cls, n_per) -> (n_per, n_cls) -> flat, i.e. classes interleaved.
            yield torch.stack(per_class).t().reshape(-1)


In [None]:
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_channels, out_channels):
    """Return a bias-free 3x3 convolution with padding 1 (default stride)."""
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        bias=False,
    )
    return layer

def conv1x1(in_channels, out_channels):
    """Return a bias-free 1x1 (pointwise) convolution."""
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        bias=False,
    )
    return layer

class Conv_block(nn.Module):
    """Bias-free 2-D convolution followed by ``LeakyReLU(0.2)``.

    ``kernel``, ``stride``, ``padding`` and ``groups`` are forwarded to
    ``nn.Conv2d`` unchanged.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_c,
            out_c,
            kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.relu(self.conv(x))

class ResBlock(nn.Module):
    """Residual block: three 3x3 conv + BatchNorm stages plus a skip branch.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every convolution of the block.
        downsample: module applied to the block input to build the skip
            branch; its output is added to the main branch.
        is_pooling: when True, a 2x2 stride-2 max-pool follows the final
            activation.
        drop_rate: dropout probability applied last; ``0`` disables dropout.
    """

    def __init__(self, in_channels, out_channels, downsample, is_pooling=True, drop_rate=0.0):
        super(ResBlock, self).__init__()
        self.is_pooling = is_pooling
        self.conv1 = conv3x3(in_channels, out_channels)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = conv3x3(out_channels, out_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.drop_rate = drop_rate
        # Built once here rather than on every forward call; the layer has no
        # parameters, so state_dict keys are unaffected.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) if is_pooling else None

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Skip connection is added before the last activation.
        out += self.downsample(x)
        out = self.relu(out)

        if self.is_pooling:
            out = self.maxpool(out)
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
        return out


class ResNet12(nn.Module):
    """Four stacked ``ResBlock`` stages operating on 3-channel images.

    Args:
        channels: sequence whose first four entries are the output widths of
            the four stages.

    Attributes:
        feature_dim / out_dims: channel count of the final feature map,
            i.e. ``channels[3]`` (``feature_dim`` used to be hard-coded to
            640, which was only right for the default widths).
    """

    def __init__(self, channels):
        super(ResNet12, self).__init__()
        self.feature_dim = channels[3]
        self.inplanes = 3  # running input width, advanced by _make_layer

        self.layer1 = self._make_layer(channels[0])
        self.layer2 = self._make_layer(channels[1])
        self.layer3 = self._make_layer(channels[2])
        self.layer4 = self._make_layer(channels[3])

        self.out_dims = channels[3]

    def _make_layer(self, planes):
        """Build one stage mapping ``self.inplanes`` -> ``planes`` channels."""
        # 1x1 conv + BN projects the input so it can be added to the
        # ``planes``-channel main branch.
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes),
            nn.BatchNorm2d(planes),
        )
        block = ResBlock(self.inplanes, planes, downsample)
        self.inplanes = planes
        return block

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

class ConV4(nn.Module):
    """Four-stage convolutional encoder for 3-channel images.

    Every stage is conv3x3 -> BatchNorm -> LeakyReLU(0.2) -> 2x2 max-pool.

    Args:
        channels: sequence whose first four entries are the stage widths.

    Attributes:
        feature_dim: channel count of the final feature map, ``channels[3]``
            (it used to be hard-coded to 64, which was only right for the
            default widths).
    """

    def __init__(self, channels):
        super(ConV4, self).__init__()
        self.in_channels = 3
        self.feature_dim = channels[3]
        self.layer1 = self._stage(self.in_channels, channels[0])
        self.layer2 = self._stage(channels[0], channels[1])
        self.layer3 = self._stage(channels[1], channels[2])
        self.layer4 = self._stage(channels[2], channels[3])

    @staticmethod
    def _stage(in_c, out_c):
        """Return one stage; the module order fixes the state_dict keys."""
        return nn.Sequential(
            conv3x3(in_c, out_c),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d(kernel_size=2, stride=2))

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x


def resnet12():
    """Build a ResNet12 with stage widths 64-160-320-640.

    The final feature map has 640 channels; each of the four stages halves
    the resolution, so an 84x84 input ends up 5x5.
    """
    stage_widths = [64, 160, 320, 640]
    return ResNet12(stage_widths)

def conv4():
    """Build a ConV4 with 64 channels in each of its four stages.

    The final feature map has 64 channels; each stage max-pools by 2, so an
    84x84 input ends up 5x5.
    """
    stage_widths = [64, 64, 64, 64]
    return ConV4(stage_widths)

def resnet12_wide():
    """Build a ResNet12 with stage widths 64-160-320-640.

    NOTE(review): these are exactly the widths ``resnet12()`` uses, so the
    two factories currently build the same network - confirm which widths
    the "wide" variant was meant to have.
    """
    stage_widths = [64, 160, 320, 640]
    return ResNet12(stage_widths)

In [None]:
import torch.nn as nn

class LinearClassifier(nn.Module):
    """Single fully connected layer (with bias) mapping features to logits."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.L = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        logits = self.L(x)
        return logits

class CosineClassifier(nn.Module):
    """Cosine-similarity classifier: ``gain * cos(x, class weight)``.

    Every forward pass overwrites the stored weight rows with their
    L2-normalised version.  The overwrite goes through ``.data``, so the
    renormalisation itself is not recorded by autograd.
    """

    def __init__(self, in_dim, out_dim, gain):
        super().__init__()
        self.gain = gain
        self.L = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        weight = self.L.weight
        # One unit-length weight vector per class.
        weight.data = nn.functional.normalize(weight.data, dim=1)
        # Unit-length inputs (normalize defaults to dim=1).
        unit_x = nn.functional.normalize(x)
        cosine = self.L(unit_x)
        return self.gain * cosine

class FewShotModel(nn.Module):
    """Encoder backbone plus a classification head used for pre-training.

    Args:
        args: configuration object.  Read here: ``classifier`` ("Linear" or
            "Cosine"), ``num_class``, ``backbone`` ("ConV4" or "ResNet12",
            only when no ``pretrain_model`` is given) and ``feature_pyramid``
            (``None``, an iterable of pooling sizes, or a comma-separated
            string of integer sizes such as ``"2,3"``).
        pretrain_model: optional model whose ``encoder`` children are reused
            as this model's encoder.

    Raises:
        ValueError: if ``args.classifier`` is not supported (previously the
            model was built without a ``classifier`` attribute and failed
            only on the first forward pass).
    """

    def __init__(self, args, pretrain_model=None):
        super(FewShotModel, self).__init__()
        if args.classifier not in ("Linear", "Cosine"):
            raise ValueError(
                "args.classifier must be 'Linear' or 'Cosine', got %r"
                % (args.classifier,))
        self.args = args
        self.class_info = None
        if pretrain_model:
            self.encoder = nn.Sequential(
                *list(pretrain_model.encoder.children())
            )
            self.feature_dim = pretrain_model.encoder.feature_dim
        else:
            assert args.backbone in ["ConV4", "ResNet12"]
            if args.backbone == "ConV4":
                self.encoder = conv4()
            else:
                self.encoder = resnet12()
            self.feature_dim = self.encoder.feature_dim

        if args.classifier == "Linear":
            self.classifier = LinearClassifier(self.feature_dim, args.num_class)
        else:
            # The literal 10 is the gain that scales the cosine logits.
            self.classifier = CosineClassifier(self.feature_dim, args.num_class, 10)

    def forward(self, input, bbox=None):
        """Return class logits for ``input``; ``bbox`` is accepted but unused."""
        return self.pretrain_classify_forward(input)

    def pretrain_classify_forward(self, x):
        """Encode, global-average-pool and classify a batch of images."""
        x = self.encode(x, dense=False).squeeze(-1).squeeze(-1)
        return self.classifier(x)

    def encode(self, x, dense=False):
        """Run the encoder on images or on batches of image patches.

        * 5-D input ``(data, patch, C, H, W)``: every patch is encoded and
          average-pooled to 1x1; the result has shape
          ``(data, channels, patch, 1)``.
        * 4-D input with ``dense`` false: globally average-pooled features of
          shape ``(N, channels, 1, 1)``.
        * 4-D input with ``dense`` true: the raw feature map, or the feature
          pyramid when ``args.feature_pyramid`` is set.
        """
        if x.dim() == 5:  # batch of image patches
            num_data, num_patch = x.shape[:2]
            x = x.reshape(-1, x.shape[2], x.shape[3], x.shape[4])
            x = F.adaptive_avg_pool2d(self.encoder(x), 1)
            x = x.reshape(num_data, num_patch, x.shape[1], x.shape[2], x.shape[3])
            return x.permute(0, 2, 1, 3, 4).squeeze(-1)

        x = self.encoder(x)
        if not dense:
            return F.adaptive_avg_pool2d(x, 1)
        if self.args.feature_pyramid is not None:
            x = self.build_feature_pyramid(x)  # multi-scale features
        return x

    def _pyramid_sizes(self):
        """Return ``args.feature_pyramid`` as a list of pooling sizes.

        The command-line option is declared as a string, so a value such as
        ``"2,3"`` is split on commas and converted to integers; any other
        iterable is used as given.
        """
        spec = self.args.feature_pyramid
        if isinstance(spec, str):
            return [int(token) for token in spec.split(',') if token.strip()]
        return list(spec)

    def build_feature_pyramid(self, feature):
        """Concatenate pooled copies of ``feature`` with ``feature`` itself.

        Each level is adaptively average-pooled to its size and flattened to
        ``(N, C, 1, size_h * size_w)``; the un-pooled map is appended last and
        all levels are concatenated along the final dimension.
        """
        batch, channels = feature.shape[0], feature.shape[1]
        levels = [
            F.adaptive_avg_pool2d(feature, size).view(batch, channels, 1, -1)
            for size in self._pyramid_sizes()
        ]
        levels.append(feature.view(batch, channels, 1, -1))
        return torch.cat(levels, dim=-1)

In [None]:
import random
import math
import scipy as sp
import scipy.stats

def bbox2grid(bbox, w=21, h=21):
    """Return the grid cells covered by a list of normalised boxes.

    Each box is ``(x_min, y_min, x_max, y_max)``.  Coordinates are multiplied
    by the grid width ``w`` / height ``h``; the minimum corner is rounded
    down and the maximum corner rounded up.  The result is a list of unique
    ``(row, column)`` tuples in no particular order.
    """
    cells = set()
    for x_min, y_min, x_max, y_max in bbox:  # one entry per object
        col_start = math.floor(x_min * w)
        row_start = math.floor(y_min * h)
        col_stop = math.ceil(x_max * w)
        row_stop = math.ceil(y_max * h)
        for col in range(col_start, col_stop):
            for row in range(row_start, row_stop):
                cells.add((row, col))
    return list(cells)

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) with ``seed``.

    Also switches cuDNN to deterministic mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``filename`` with ``torch.save``."""
    torch.save(obj=state, f=filename)

class Averager:
    """Running weighted average.

    Attributes:
        val: the value passed to the most recent ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight seen since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def count_acc(logits, label):
    """Return top-1 accuracy as a Python float.

    Args:
        logits: score tensor; the prediction is the argmax over dimension 1.
        label: ground-truth class indices, comparable to that argmax.

    The comparison is cast with ``.float()``, which stays on the tensors'
    own device.  The previous ``torch.cuda.FloatTensor`` cast moved CPU
    tensors to the GPU whenever CUDA was merely available.
    """
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()
    
def ostu(data):
    """Pick a threshold that maximises between-class variance (Otsu-style).

    The sorted values are split after every position; for each split the
    score ``p1 * p2 * (m1 - m2) ** 2`` is computed from the class proportions
    and class means, and the last value of the lower class of the best split
    is returned.

    Args:
        data: 1-D tensor of values.

    Returns:
        The chosen threshold (a 0-d tensor taken from the sorted data), or
        the integer ``0`` when no split has a positive score - which includes
        empty and single-element inputs.
    """
    num = len(data)
    sort_data, _ = torch.sort(data)
    max_var = 0
    thr = 0
    s1 = 0                # running sum of the lower class
    s2 = torch.sum(data)  # running sum of the upper class
    # The final position is skipped: its upper class would be empty, which
    # used to cause a division by zero and a NaN score that could never win.
    for i, th in enumerate(sort_data[:-1]):
        p1 = (i + 1) / num
        p2 = 1 - p1
        s1 = s1 + th
        s2 = s2 - th
        m1 = s1 / (i + 1)
        m2 = s2 / (num - i - 1)
        var = p1 * p2 * (m1 - m2) ** 2
        if var > max_var:
            thr = th
            max_var = var
    return thr

def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` of a Student-t confidence interval.

    Args:
        data: sequence of numeric samples.
        confidence: two-sided confidence level, 0.95 by default.

    Returns:
        ``m`` - the sample mean, and ``h`` - the standard error of the mean
        multiplied by the t quantile at ``(1 + confidence) / 2`` with
        ``n - 1`` degrees of freedom, so the interval is ``m +/- h``.
    """
    a = np.asarray([np.asarray(x, dtype=float) for x in data])
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Public ppf instead of the private t._ppf used before.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)

    return m, h

In [None]:
# Method 1: every support image is clustered on its own
import random
from scipy.spatial.distance import euclidean


def kmeans(data, k, iterations, restarts=30, tol=0.01):
    """Cluster ``data`` into ``k`` groups with randomly restarted k-means.

    Args:
        data: array-like of points, one point per row.
        k: number of clusters; the initial centroids are ``k`` points drawn
            with ``random.sample``.
        iterations: maximum number of assign/update rounds per restart.
        restarts: number of random initialisations to try.
        tol: a restart stops early once no centroid moves farther than this.

    Returns:
        A list of ``k`` lists holding the row indices (Python ints) of each
        cluster, taken from the restart with the smallest total
        point-to-centroid distance.  A cluster's list may be empty.

    Fixes over the previous version: the updated centroids are held in a
    fresh array every round (the old list was aliased to the current
    centroids, so the convergence test always passed on the second round),
    and the best cost starts at infinity instead of 100000 (data with larger
    distances used to return an empty result).
    """
    points = np.asarray(data, dtype=float)
    points = points.reshape(len(points), -1)
    row_ids = np.arange(len(points))

    best_cost = float("inf")
    best_members = []
    for _ in range(restarts):
        centroids = np.stack(random.sample(list(points), k))
        members = [[] for _ in range(k)]
        cost = 0.0
        for _ in range(iterations):
            # Assign every point to its nearest centroid.
            dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
            nearest = dists.argmin(axis=1)
            cost = float(dists[row_ids, nearest].sum())
            members = [np.flatnonzero(nearest == c).tolist() for c in range(k)]

            # Recompute the centroids; an empty cluster keeps its old one.
            updated = centroids.copy()
            for c, idx in enumerate(members):
                if idx:
                    updated[c] = points[idx].mean(axis=0)

            moved = np.linalg.norm(updated - centroids, axis=1) > tol
            if not moved.any():
                break
            centroids = updated
        if cost < best_cost:
            best_cost = cost
            best_members = members
    return best_members
        

def _cluster_prototypes(feat):
    """Return two unit-length weighted means of the rows of ``feat``.

    ``feat`` is a ``(positions, channels)`` tensor.  Its rows are split into
    two clusters with ``kmeans``; prototype ``m`` weights the rows of cluster
    ``m`` by 0.6 and the rows of the other cluster by 0.4, normalises the
    weights to sum to one, and L2-normalises the weighted mean.
    """
    clusters = kmeans(feat.detach().cpu().numpy(), 2, 25)
    prototypes = []
    for m in range(2):
        # zeros, not empty: a row missing from both clusters gets weight 0.
        weights = torch.zeros(feat.shape[0], dtype=feat.dtype, device=feat.device)
        strong = torch.as_tensor(clusters[m], dtype=torch.long, device=feat.device)
        weak = torch.as_tensor(clusters[1 - m], dtype=torch.long, device=feat.device)
        weights[strong] = 0.6
        weights[weak] = 0.4
        weights = weights / torch.sum(weights)
        mean = torch.matmul(weights.unsqueeze(0), feat)
        prototypes.append(F.normalize(mean, dim=-1).squeeze(0))
    return torch.stack(prototypes)


def few_shot_eval(query, support, way=5):
    """Score every query against the support set with clustered prototypes.

    Each image's feature map is flattened to ``(positions, channels)`` and
    reduced to two prototypes.  The score of a (query, support image) pair is
    the largest cosine similarity over the four prototype pairings, floored
    at 0.

    Args:
        query: ``(num_query, channels, ...)`` feature tensor.
        support: ``(num_support, channels, ...)`` feature tensor with support
            image ``j`` belonging to class ``j % way``.
        way: number of classes in the episode.

    Returns:
        ``(num_query, num_support)`` scores, or ``(num_query, way)`` per-class
        sums when there are more support images than classes.  The result
        lives on the device of ``query``.

    Raises:
        ValueError: if there are more support images than classes but their
            number is not a multiple of ``way``.
    """
    support_feat = support.reshape(support.shape[0], support.shape[1], -1).permute(0, 2, 1)
    query_feat = query.reshape(query.shape[0], query.shape[1], -1).permute(0, 2, 1)
    shot_num, query_num = support_feat.shape[0], query_feat.shape[0]

    # Support prototypes do not depend on the query, so build them once.
    support_protos = torch.stack([_cluster_prototypes(f) for f in support_feat])
    if query_num == 0:
        logits = query_feat.new_zeros((0, shot_num))
    else:
        query_protos = torch.stack([_cluster_prototypes(f) for f in query_feat])
        # sims[q, s, m, n]: cosine of query prototype m and support prototype n.
        sims = torch.einsum('qmc,snc->qsmn', query_protos, support_protos)
        logits = sims.flatten(start_dim=2).max(dim=-1).values.clamp(min=0)

    if shot_num > way:
        if shot_num % way:
            raise ValueError(
                "number of support images (%d) is not a multiple of way (%d)"
                % (shot_num, way))
        # Support image j has class j % way: sum the shots of each class.
        logits = logits.reshape(query_num, shot_num // way, way).sum(dim=1)
    return logits

In [None]:
# Method 2: the support images of one class are concatenated before clustering
import random
from scipy.spatial.distance import euclidean


def kmeans(data, k, iterations, restarts=50, tol=0.01):
    """Cluster ``data`` into ``k`` groups with randomly restarted k-means.

    Args:
        data: array-like of points, one point per row.
        k: number of clusters; the initial centroids are ``k`` points drawn
            with ``random.sample``.
        iterations: maximum number of assign/update rounds per restart.
        restarts: number of random initialisations to try.
        tol: a restart stops early once no centroid moves farther than this.

    Returns:
        A list of ``k`` lists holding the row indices (Python ints) of each
        cluster, taken from the restart with the smallest total
        point-to-centroid distance.  A cluster's list may be empty.

    Fixes over the previous version: the updated centroids are held in a
    fresh array every round (the old list was aliased to the current
    centroids, so the convergence test always passed on the second round),
    and the best cost starts at infinity instead of 100000 (data with larger
    distances used to return an empty result).
    """
    points = np.asarray(data, dtype=float)
    points = points.reshape(len(points), -1)
    row_ids = np.arange(len(points))

    best_cost = float("inf")
    best_members = []
    for _ in range(restarts):
        centroids = np.stack(random.sample(list(points), k))
        members = [[] for _ in range(k)]
        cost = 0.0
        for _ in range(iterations):
            # Assign every point to its nearest centroid.
            dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
            nearest = dists.argmin(axis=1)
            cost = float(dists[row_ids, nearest].sum())
            members = [np.flatnonzero(nearest == c).tolist() for c in range(k)]

            # Recompute the centroids; an empty cluster keeps its old one.
            updated = centroids.copy()
            for c, idx in enumerate(members):
                if idx:
                    updated[c] = points[idx].mean(axis=0)

            moved = np.linalg.norm(updated - centroids, axis=1) > tol
            if not moved.any():
                break
            centroids = updated
        if cost < best_cost:
            best_cost = cost
            best_members = members
    return best_members
        

def _cluster_prototypes(feat):
    """Return two unit-length weighted means of the rows of ``feat``.

    ``feat`` is a ``(positions, channels)`` tensor.  Its rows are split into
    two clusters with ``kmeans``; prototype ``m`` weights the rows of cluster
    ``m`` by 0.6 and the rows of the other cluster by 0.4, normalises the
    weights to sum to one, and L2-normalises the weighted mean.
    """
    clusters = kmeans(feat.detach().cpu().numpy(), 2, 25)
    prototypes = []
    for m in range(2):
        # zeros, not empty: a row missing from both clusters gets weight 0.
        weights = torch.zeros(feat.shape[0], dtype=feat.dtype, device=feat.device)
        strong = torch.as_tensor(clusters[m], dtype=torch.long, device=feat.device)
        weak = torch.as_tensor(clusters[1 - m], dtype=torch.long, device=feat.device)
        weights[strong] = 0.6
        weights[weak] = 0.4
        weights = weights / torch.sum(weights)
        mean = torch.matmul(weights.unsqueeze(0), feat)
        prototypes.append(F.normalize(mean, dim=-1).squeeze(0))
    return torch.stack(prototypes)


def few_shot_eval(query, support, way=5):
    """Score every query against per-class pooled support features.

    The flattened feature rows of all support images of one class are
    concatenated and reduced to two prototypes; every query image is reduced
    to two prototypes as well.  The score of a (query, class) pair is the
    largest cosine similarity over the four prototype pairings, floored at 0.

    The number of shots, of spatial positions and of channels is taken from
    the tensors.  The previous version assumed 5-way 2-shot, 25 positions and
    64 channels: it failed for other channel counts and ignored shots beyond
    the second.

    Args:
        query: ``(num_query, channels, ...)`` feature tensor.
        support: ``(num_support, channels, ...)`` feature tensor with support
            image ``j`` belonging to class ``j % way``.
        way: number of classes in the episode.

    Returns:
        ``(num_query, way)`` scores on the device of ``query``.

    Raises:
        ValueError: if the number of support images is not a multiple of
            ``way``.
    """
    support_feat = support.reshape(support.shape[0], support.shape[1], -1).permute(0, 2, 1)
    n_support, n_pos, n_chan = support_feat.shape
    if n_support % way:
        raise ValueError(
            "number of support images (%d) is not a multiple of way (%d)"
            % (n_support, way))
    shots = n_support // way
    # (shots, way, P, C) -> (way, shots * P, C): merge the shots of a class.
    class_feat = (support_feat.reshape(shots, way, n_pos, n_chan)
                  .permute(1, 0, 2, 3)
                  .reshape(way, shots * n_pos, n_chan))

    query_feat = query.reshape(query.shape[0], query.shape[1], -1).permute(0, 2, 1)
    query_num = query_feat.shape[0]

    # Class prototypes do not depend on the query, so build them once.
    class_protos = torch.stack([_cluster_prototypes(f) for f in class_feat])
    if query_num == 0:
        return query_feat.new_zeros((0, way))
    query_protos = torch.stack([_cluster_prototypes(f) for f in query_feat])

    # sims[q, c, m, n]: cosine of query prototype m and class prototype n.
    sims = torch.einsum('qmc,snc->qsmn', query_protos, class_protos)
    return sims.flatten(start_dim=2).max(dim=-1).values.clamp(min=0)

In [None]:
import argparse
import os
import time
from datetime import datetime
import tqdm
import torchvision
import matplotlib.pyplot as plt
# Use the GPU when one is present; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


parser = argparse.ArgumentParser()
# about dataset and network
parser.add_argument('-dataset', type=str, default='miniimagenet', choices=['miniimagenet', 'cub','tieredimagenet','fc100','tieredimagenet_yao','cifar_fs'])
parser.add_argument('-data_dir', type=str, default=DATA_DIR, help='dataset path')
parser.add_argument('-data_split', type=str, default=DATA_SPLIT)
# about pre-training
parser.add_argument('-num_class', type=int, default=100)
parser.add_argument('-max_epoch', type=int, default=200)
parser.add_argument('-lr', type=float, default=0.1)
parser.add_argument('-step_size', type=int, default=30)
parser.add_argument('-gamma', type=float, default=0.2)
parser.add_argument('-bs', type=int, default=128)
parser.add_argument('-backbone', type=str, default="ResNet12", choices=["ConV4", "ResNet12"])
parser.add_argument('-classifier', type=str, default="Cosine", choices=["Linear", "Cosine"])

# about validation
parser.add_argument('-set', type=str, default='val', choices=['val', 'test'], help='the set for validation')
parser.add_argument('-way', type=int, default=5)
parser.add_argument('-shot', type=int, default=5)
parser.add_argument('-query', type=int, default=15)
parser.add_argument('-temperature', type=float, default=12.5)
parser.add_argument('-metric', type=str, default='cosine')
parser.add_argument('-simi_metric', type=str, default='cosine')
parser.add_argument('-num_episode', type=int, default=400)
# The default used to be the string 'no', which is truthy, so the flag was
# effectively always on.  With False the "fixed validation set" branch runs
# unless the flag is passed, and that branch keeps every validation episode
# in memory.
parser.add_argument('-random_val_task', action='store_true', default=False, help='random samples tasks for validation in each epoch')
parser.add_argument('-feature_pyramid', type=str, default=None)
# about training
parser.add_argument('-seed', type=int, default=1)
parser.add_argument('-resume_dir', type=str, default=None)
parser.add_argument('-print_freq', type=int, default=10)
# An empty argument list is parsed (notebook use), so every option keeps its default.
args = parser.parse_args(args=[])

# set_seed(args.seed)

dataset_name = args.dataset

if not args.resume_dir:
    # Fresh run: outputs go to "<dataset>-<backbone>-<YYYY-MM-DD>".
    now = datetime.now()
    date_string = now.strftime("%Y-%m-%d")
    train_input_info = dataset_name + '-' + args.backbone + '-' + date_string

    args.save_path = train_input_info
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(args.save_path, exist_ok=True)
else:
    # Resumed run: keep writing into the directory that was given.
    args.save_path = args.resume_dir


# Text log of the run, opened in append mode so earlier content is kept.
txt_save_path = os.path.join(args.save_path, 'opt_resutls.txt')
F_txt = open(txt_save_path, 'a+')


# NOTE(review): this rebinds the name ``Dataset`` (imported from
# torch.utils.data at the top of the file) to the CUB class.
Dataset = CUB

# Training: ordinary shuffled mini-batches.
trainset = Dataset('train', args)
train_loader = DataLoader(trainset, batch_size=args.bs, shuffle=True,
                          num_workers=0, pin_memory=True)

# Validation: episodes of `way` classes with `shot + query` images per class.
valset = Dataset('val', args)
val_sampler = CategoriesSampler(valset.label, args.num_episode, args.way,
                                args.shot + args.query)
val_loader = DataLoader(valset, batch_sampler=val_sampler,
                        num_workers=0, pin_memory=True)

# Test: sampled the same way as validation.
testset = Dataset('test', args)
test_sampler = CategoriesSampler(testset.label, args.num_episode, args.way,
                                 args.shot + args.query)
test_loader = DataLoader(testset, batch_sampler=test_sampler,
                         num_workers=8, pin_memory=True)

if not args.random_val_task:
    # Draw the validation episodes once and reuse them in every epoch.
    print('fix val set for all epochs')
    val_loader = list(val_loader)


model = FewShotModel(args).to(device)


# Labels of the query images of one episode: 0..way-1 repeated `query` times
# (0 1 2 3 4 0 1 2 3 4 ... for 5-way).  torch.arange already yields int64;
# the former int8 intermediate would overflow for way > 127.
label = torch.arange(args.way).repeat(args.query).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001)

# Multiply the learning rate by `gamma` every `step_size` scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

for param_group in optimizer.param_groups:
    print(param_group['lr'])

# Record the configuration and the architecture on screen and in the log file.
print(args)
print(args, file=F_txt)
print(model)
print(model, file=F_txt)


global_count = 0  # optimisation steps taken so far, across epochs

best_prec1 = 0.  # best average validation accuracy seen so far

result_list = [args.save_path]

# Progress-line templates, shared by the console and the log file.
_TRAIN_FMT = ('Eposide-({0}): [{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')
_VAL_FMT = ('Eposide-({0}): [{1}/{2}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')


def _checkpoint_state(epoch):
    """Collect everything that is stored in a checkpoint for ``epoch``."""
    return {
        'epoch_index': epoch,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict()
    }


try:
    for epoch in range(1, args.max_epoch + 1):
        losses = Averager()
        vq_losses = Averager()           # not updated in this loop
        top1 = Averager()
        val_losses = Averager()
        prec1_val = Averager()
        val_emd_losses = Averager()      # not updated in this loop
        prec1_emd_val = Averager()       # not updated in this loop
        data_time = Averager()
        batch_time = Averager()
        val_batch_time = Averager()
        val_emd_batch_time = Averager()  # not updated in this loop

        start_time = time.time()
        model.train()
        train_loader_length = len(train_loader)
        end = time.time()
        print('===================================== Training on the train set =====================================')
        print('===================================== Training on the train set =====================================', file=F_txt)
        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            data, train_label = [_.to(device) for _ in batch]
            data_time.update(time.time() - end)

            logits = model(data)
            loss = F.cross_entropy(logits, train_label)
            acc = count_acc(logits, train_label)
            total_loss = loss
            # Weight by the real batch size: the last batch may be smaller
            # than args.bs.
            n_samples = data.size(0)
            losses.update(total_loss.item(), n_samples)
            top1.update(acc, n_samples)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                message = _TRAIN_FMT.format(
                    epoch, i, train_loader_length, batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1)
                print(message)
                print(message, file=F_txt)

        print('===================================== Validation on the val set =====================================')
        print('===================================== validation on the val set =====================================',
              file=F_txt)

        model.eval()

        val_print_freq = 10
        k = args.way * args.shot  # the first k images of an episode are the support set
        end = time.time()  # restart the timer so the first episode is timed on its own
        for i, batch in enumerate(val_loader, 1):
            data = batch[0].to(device)
            data_shot, data_query = data[:k], data[k:]
            # episode evaluation, no gradients needed
            with torch.no_grad():
                feat_s = model.encode(data_shot, dense=True)
                feat_q = model.encode(data_query, dense=True)
                logits = few_shot_eval(feat_q, feat_s)
                loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            # Store a plain float, not a tensor, and weight by the number of
            # query images in the episode.
            n_query = label.size(0)
            val_losses.update(loss.item(), n_query)
            prec1_val.update(acc, n_query)
            val_batch_time.update(time.time() - end)
            end = time.time()
            if i % val_print_freq == 0:
                message = _VAL_FMT.format(
                    epoch, i, args.num_episode, batch_time=val_batch_time,
                    loss=val_losses, top1=prec1_val)
                print(message)
                print(message, file=F_txt)

        # Best model so far, judged by average validation accuracy.
        if prec1_val.avg > best_prec1:
            best_prec1 = prec1_val.avg
            save_checkpoint(_checkpoint_state(epoch),
                            os.path.join(args.save_path, 'model_best.pth.tar'))

        # Periodic snapshot every 10 epochs.
        if epoch % 10 == 0:
            filename = os.path.join(args.save_path, 'epoch_%d.pth.tar' % epoch)
            save_checkpoint(_checkpoint_state(epoch), filename)

        # Always keep the latest epoch.
        save_checkpoint(_checkpoint_state(epoch),
                        os.path.join(args.save_path, 'tmp.pth.tar'))

        lr_scheduler.step()
finally:
    # Close the log file even when training is interrupted.
    F_txt.close()