# In [23]:  (Jupyter cell marker)
import os.path as osp
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
import numpy as np
import pandas as pd
import os
import csv
import random
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.nn import init
import torch.optim as optim
import time
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import math
import torch.distributed as dist
from torch.multiprocessing import Process

############################################# 21

parser = argparse.ArgumentParser()
#  Few-shot parameters  #
parser.add_argument('--data_name', default='miniImageNet', help='miniImageNet| tieredImageNet')
parser.add_argument('--method_name', default='KL', help=' Wass | Wass_CMS | KL | KL_CMS | ADM ')
parser.add_argument('--mode', default='train', help='train|val|test')
parser.add_argument('--outf', default='./results/')
parser.add_argument('--workers', type=int, default=0)
parser.add_argument('--way_num', type=int, default=5, help='the number of way/class')
parser.add_argument('--shot_num', type=int, default=1, help='the number of shot')
parser.add_argument('--query_num', type=int, default=15, help='the number of queries')
parser.add_argument('--train_num', type=int, default=10, help='pretrain number, default=10')
#  Optimisation parameters  #
parser.add_argument('--epochs', type=int, default=50, help='the total number of training epoch')
parser.add_argument('--start_epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
# Help strings below previously advertised defaults (0.005 / 100) that did not
# match the actual `default=` values.
parser.add_argument('--lr', type=float, default=0.1, help='learning rate, default=0.1')
parser.add_argument('--lr2', type=float, default=100, help='second learning rate, default=100')
# NOTE: store_true combined with default=True means this flag is always True;
# it cannot be switched off from the command line.
parser.add_argument('--adam', action='store_true', default=True, help='use adam optimizer')
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--print_freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')
# Presumably absorbs the `-f <kernel file>` argument a Jupyter kernel receives.
parser.add_argument('-f', type=str, default="读取额外的参数")
# NOTE(review): type=bool turns any non-empty string (including "False") into True.
parser.add_argument('--freeze-layers', type=bool, default=False)
# Do not change this argument; the launcher assigns the device automatically.
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
# Number of processes to start (processes, not threads); on a single machine
# this is the number of GPUs used.
parser.add_argument('--world-size', default=4, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

# args=[] ignores sys.argv so the cell can run inside a notebook.
opt = parser.parse_args(args=[])

# Derive this process's rank / GPU from the launcher's environment variables.
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    # torchrun / torch.distributed.launch style launch.
    opt.rank = int(os.environ["RANK"])
    opt.world_size = int(os.environ['WORLD_SIZE'])
    opt.gpu = int(os.environ['LOCAL_RANK'])
    opt.distributed = True
elif 'SLURM_PROCID' in os.environ:
    opt.rank = int(os.environ['SLURM_PROCID'])
    # Use the SLURM task count when available; fall back to --world-size.
    opt.world_size = int(os.environ.get('SLURM_NTASKS', opt.world_size))
    opt.gpu = opt.rank % torch.cuda.device_count()
    opt.distributed = True
else:
    print('Not using distributed mode')
    opt.distributed = False

# The rest of this script (DDP wrappers, dist.barrier, MoCo's batch-shuffle
# broadcast) needs an initialised process group.  `opt.distributed` used to be
# forced to True here, so a non-distributed launch died below with an opaque
# AttributeError on `opt.gpu`; fail with an explicit message instead.
if not opt.distributed:
    raise RuntimeError(
        'This script must be launched in distributed mode (e.g. with torchrun); '
        'RANK/WORLD_SIZE or SLURM_PROCID were not found in the environment.')

torch.cuda.set_device(opt.gpu)
opt.dist_backend = 'nccl'  # communication backend; NCCL is recommended for NVIDIA GPUs
print('| distributed init (rank {}): {}'.format(
    opt.rank, opt.dist_url), flush=True)
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
                        world_size=opt.world_size, rank=opt.rank)
dist.barrier()

# 'cuda' resolves to the current device selected by set_device above.
device = torch.device(opt.device)

################################################################

data_dir = ""

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call draws a blur radius uniformly from ``[sigma[0], sigma[1]]`` and
    applies PIL's ``ImageFilter.GaussianBlur`` with that radius.

    Args:
        sigma: ``(min, max)`` range for the blur radius.  It is copied into a
            tuple so the default is not a shared mutable list, and later edits
            to a list passed by the caller cannot change this transform.
    """

    def __init__(self, sigma=(.1, 2.)):
        self.sigma = tuple(sigma)

    def __call__(self, x):
        """Return a blurred copy of PIL image *x*."""
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=sigma))


# Contrastive (MoCo) view: random resized crop to 224 (20-100% of the area),
# random grayscale, random Gaussian blur and horizontal flip, conversion to a
# tensor normalised with the ImageNet mean/std, then a final resize of that
# normalised tensor down to 84x84.
mocoAug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
    # Colour jitter is currently disabled for this view:
#     transforms.RandomApply([
#         transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)  # not strengthened
#     ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Resize runs on the normalised tensor (needs tensor support in Resize).
    transforms.Resize((84, 84)),
])


# Supervised view: fixed resize to 224 (no cropping), random grayscale and
# horizontal flip, ImageNet normalisation, then resize to 84x84.
supervisedAug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((84, 84)),
])

# Deterministic evaluation transform (no random ops): used by
# MetaDataloader.createExamplesTensorData for episode images.
trans_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((84, 84)),
])

def RGB_loader(path):
    """Load the image at *path* and return it as a 3-channel RGB PIL image.

    The source file is opened in a ``with`` block, so its handle is released
    as soon as ``convert`` has produced the RGB copy.  Previously it stayed
    open until the image object was garbage-collected.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


def load_data(data_path):
    """Index an image folder laid out as ``data_path/<class_name>/<image>``.

    Args:
        data_path: root directory whose immediate sub-directories are classes.

    Returns:
        ``(data_list, data_dict, class_list)``:
        * data_list  -- list of ``(image_path, class_name)`` tuples;
        * data_dict  -- class name -> list of that class's image paths;
        * class_list -- ``data_dict.keys()``.

    Entries are visited in sorted order, so results are deterministic across
    filesystems.  Stray files directly under *data_path* are skipped (they used
    to make ``os.listdir`` raise), and sub-directories inside a class folder
    are no longer recorded as images.
    """
    data_dict = {}
    data_list = []
    for folder_name in sorted(os.listdir(data_path)):
        folder_path = os.path.join(data_path, folder_name)
        if not os.path.isdir(folder_path):
            continue
        data_dict[folder_name] = class_paths = []
        for filename in sorted(os.listdir(folder_path)):
            image_path = os.path.join(folder_path, filename)
            if not os.path.isfile(image_path):
                continue
            data_list.append((image_path, folder_name))
            class_paths.append(image_path)
    return data_list, data_dict, data_dict.keys()

class FewShotDataSet(Dataset):
    """Image dataset indexed from a split CSV of ``filename,label`` rows.

    ``__getitem__`` returns ``(views, label)``.  ``views`` stacks three
    augmentations of the same image along dim 0: two ``mocoAug`` views
    (contrastive query/key) and one ``supervisedAug`` view.  ``label`` is
    the class's index in the sorted class list.

    Attributes:
        data_list: list of ``(image_path, class_name)``.
        data_dict: class name -> list of image paths.
        class_list: sorted class names.
        label2Int: class name -> integer label.
        num_cats: number of classes.
    """

    def __init__(self, data_dir, phase='train', loader=RGB_loader):
        super(FewShotDataSet, self).__init__()
        # NOTE(review): data_dir is not used; IMAGE_PATH and the per-phase CSV
        # paths below are placeholders that must be filled in.
        IMAGE_PATH = ""
        self.loader = loader
        if phase == 'train':
            csv_path = ""
        elif phase == 'val':
            csv_path = ""
        else:
            csv_path = ""

        data_list = []
        data_dict = {}

        # `with` closes the split file (it was previously left open).
        # [1:] drops the CSV header row.
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]

        self.wnids = []

        for line in lines:
            if not line:
                # Blank lines used to raise IndexError on context[1].
                continue
            context = line.split(',')
            name = context[0]
            wnid = context[1]
            path = osp.join(IMAGE_PATH, name)
            data_list.append((path, wnid))
            data_dict.setdefault(wnid, []).append(path)

        self.data_list = data_list
        self.data_dict = data_dict
        self.class_list = sorted(data_dict.keys())
        self.label2Int = {item: idx for idx, item in enumerate(self.class_list)}
        self.num_cats = len(self.class_list)

    def __getitem__(self, item):
        img_path, class_name = self.data_list[item]
        label = self.label2Int[class_name]
        img = self.loader(img_path)
        # Two independent contrastive views followed by one supervised view.
        views = torch.stack([mocoAug(img), mocoAug(img), supervisedAug(img)], dim=0)
        return views, label

    def __len__(self):
        return len(self.data_list)

def get_dataloader(opt, mode):
    """Build the episodic loader for split *mode*.

    Args:
        opt: options object providing way_num / shot_num / query_num / workers.
        mode: one of ``'train'``, ``'val'``, ``'test'``.

    Returns:
        A ``MetaDataloader`` over ``FewShotDataSet(data_dir, phase=mode)``.

    Raises:
        ValueError: if *mode* is not a known split.  It is checked before the
            dataset is built, so a bad mode no longer triggers dataset file
            I/O first.
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError('Mode ought to be in [train, val, test]')
    dataset = FewShotDataSet(data_dir, phase=mode)
    return MetaDataloader(dataset, opt, mode)

class MetaDataloader(object):
    """Sample N-way K-shot episodes from a FewShotDataSet-like object.

    ``dataset`` must expose ``data_dict`` (class -> list of image paths),
    ``class_list`` (a sequence of class keys) and ``loader`` (path -> image).
    Calling the object with an integer ``index`` returns one episode
    ``(Xt, Yt, Xe, Ye)``: query images/labels, then support images/labels.
    Labels are episode-local class indices ``0 .. way_num-1``.
    """

    def __init__(self, dataset, opt, mode):
        self.dataset = dataset
        self.loader = dataset.loader

        self.way_num = opt.way_num
        self.shot_num = opt.shot_num
        self.query_num = opt.query_num
        self.num_workers = int(opt.workers)
        self.shuffle = mode == 'train'
        # RNG used for sampling.  get_iterator() temporarily swaps in a
        # Random seeded with the episode index; direct calls to the sampling
        # methods use the module-level `random` functions as before.
        self._rng = random

    def sampleImageIdsFrom(self, cat_id, sample_size=1):
        """Sample *sample_size* distinct image paths of class *cat_id* (without replacement)."""
        assert (cat_id in self.dataset.data_dict)
        assert (len(self.dataset.data_dict[cat_id]) >= sample_size)
        return self._rng.sample(self.dataset.data_dict[cat_id], sample_size)

    def sampleCategories(self, sample_size=1):
        """Sample *sample_size* distinct classes from ``dataset.class_list``."""
        class_list = self.dataset.class_list
        assert (len(class_list) >= sample_size)
        return self._rng.sample(class_list, sample_size)

    def sampleSupQuery(self, categories, query_num, shot_num):
        """Draw query and support ``(path, episode_label)`` pairs for each category.

        For each category, ``query_num + shot_num`` distinct images are drawn;
        the first ``query_num`` become queries and the rest support examples.
        """
        if len(categories) == 0:
            return [], []
        nCategories = len(categories)
        Query_imgs = []
        Support_imgs = []

        for idx, category in enumerate(categories):
            img_ids = self.sampleImageIdsFrom(category, sample_size=(query_num + shot_num))
            Query_imgs += [(img_id, idx) for img_id in img_ids[:query_num]]
            Support_imgs += [(img_id, idx) for img_id in img_ids[query_num:]]

        assert (len(Query_imgs) == nCategories * query_num)
        assert (len(Support_imgs) == nCategories * shot_num)

        return Query_imgs, Support_imgs

    def sampleEpisode(self):
        """Sample ``way_num`` classes and their query/support image lists."""
        categories = self.sampleCategories(self.way_num)
        return self.sampleSupQuery(categories, self.query_num, self.shot_num)

    def createExamplesTensorData(self, examples):
        """Load and ``trans_val``-transform the images; return stacked images and a label tensor."""
        images = torch.stack(
            [trans_val(self.loader(img_name)) for img_name, _ in examples], dim=0)
        labels = torch.tensor([label for _, label in examples])
        return images, labels

    def load_function(self, iter_idx):
        """Sample one episode and materialise it as ``(Xt, Yt, Xe, Ye)`` tensors."""
        Query_imgs, Support_imgs = self.sampleEpisode()
        Xt, Yt = self.createExamplesTensorData(Query_imgs)
        Xe, Ye = self.createExamplesTensorData(Support_imgs)
        return Xt, Yt, Xe, Ye

    def get_iterator(self, index):
        """Return the episode determined by *index*.

        The same index always yields the same episode: ``random.Random(s)``
        produces the same stream as ``random.seed(s)``.  A private generator
        is used instead of reseeding the global ``random``/``numpy`` RNGs,
        which previously reset the state that training augmentations draw
        from every time an episode was sampled.
        """
        self._rng = random.Random(index)
        try:
            return self.load_function(index)
        finally:
            self._rng = random

    def __call__(self, index):
        return self.get_iterator(index)

    def __len__(self):
        # NOTE(review): episodes are generated on demand, so there is no real
        # length; returning 0 also makes `bool(loader)` False.
        return 0
    
    
#######################################################################################

def conv3x3(in_channels, out_channels):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (stride 1, so H/W are preserved)."""
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        padding=(1, 1),
        bias=False,
    )
    return layer

def conv1x1(in_channels, out_channels):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``, used for residual projections."""
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        bias=False,
    )
    return layer

class Conv_block(nn.Module):
    """Bias-free ``Conv2d`` followed by ``LeakyReLU(0.2)``.

    Args:
        in_c, out_c: input / output channels.
        kernel, stride, padding, groups: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = nn.Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.relu(self.conv(x))

class ResBlock(nn.Module):
    """ResNet-12 stage: three 3x3 conv+BN layers, a projected shortcut, then 2x2 max-pooling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every conv in the block.
        downsample: module applied to the input to form the residual branch
            (it must map ``in_channels`` to ``out_channels``).
    """

    def __init__(self, in_channels, out_channels, downsample):
        super(ResBlock, self).__init__()
        # Registration order is kept stable so parameter order (and therefore
        # saved optimizer state) matches existing checkpoints.
        self.conv1 = conv3x3(in_channels, out_channels)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = conv3x3(out_channels, out_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.downsample = downsample

    def forward(self, x):
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        # Activation is applied after adding the shortcut, then spatial size is halved.
        h = h + self.downsample(x)
        return self.maxpool(self.relu(h))


class ResNet12(nn.Module):
    """Four ``ResBlock`` stages; stage *i* outputs ``channels[i]`` feature maps.

    Each stage halves the spatial resolution (2x2 max-pool inside ResBlock).
    ``out_dims`` records the channel count of the final stage.
    """

    def __init__(self, channels):
        super(ResNet12, self).__init__()

        # Channel count fed into the next stage; starts at 3 (RGB input).
        self.inplanes = 3

        stages = [self._make_layer(channels[i]) for i in range(4)]
        self.layer1, self.layer2, self.layer3, self.layer4 = stages

        self.out_dims = channels[3]

    def _make_layer(self, planes):
        """Build one stage mapping ``self.inplanes`` -> *planes* and advance ``self.inplanes``."""
        shortcut = nn.Sequential(
            conv1x1(self.inplanes, planes),
            nn.BatchNorm2d(planes),
        )
        stage = ResBlock(self.inplanes, planes, shortcut)
        self.inplanes = planes
        return stage

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

def resnet12():
    """ResNet-12 with stage widths 64-128-256-512."""
    widths = [64, 128, 256, 512]
    return ResNet12(widths)

def resnet12_wide():
    """Wide ResNet-12 with stage widths 64-160-320-640 (final feature maps have 640 channels)."""
    widths = [64, 160, 320, 640]
    return ResNet12(widths)



##########################################################################################


_NORMAL_INIT_CONV_TYPES = (
    nn.Conv1d, nn.Conv2d, nn.Conv3d,
    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
)


def weights_init_normal(m):
    """Re-initialise one module in place; intended for ``model.apply(weights_init_normal)``.

    * Convolutions (1d/2d/3d, including transposed) and ``nn.Linear``:
      weight ~ N(0, 0.02); biases are left untouched.
    * ``nn.BatchNorm2d`` with affine parameters: weight ~ N(1, 0.02), bias = 0.
    * Every other module is ignored.

    Modules are matched with ``isinstance`` rather than by substring of the
    class name.  The old ``'Conv'`` / ``'Linear'`` substring test also
    matched wrapper modules such as ``Conv_block`` and ``distLinear``, which
    have no ``weight`` attribute, so ``model.apply`` raised AttributeError.
    """
    if isinstance(m, _NORMAL_INIT_CONV_TYPES) or isinstance(m, nn.Linear):
        init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # affine=False BatchNorm has weight/bias set to None; skip it.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def get_model(pre_train=False, model_dir=None, num_class=80, dim=128):
    """Build a ``featureAugNet`` and optionally load pretrained weights.

    Args:
        pre_train: if True, load ``torch.load(model_dir)['state_dict']``.
        model_dir: checkpoint path; required when *pre_train* is True.
        num_class: number of classes for the supervised classifiers.
        dim: classifier input size passed to ``featureAugNet``.
            NOTE(review): the backbone features look 640-d (resnet12_wide and
            BaseEncoder's ``Linear(640, 128)``), while this default is 128.
            Confirm callers pass ``dim=640``.

    Raises:
        ValueError: if *pre_train* is True but *model_dir* is None.
    """
    model = featureAugNet(num_class, dim)
    #     model.apply(weights_init_normal)
    if pre_train:
        if model_dir is None:
            raise ValueError('pre_train=True requires model_dir')
        # The freshly built model lives on the CPU; mapping the checkpoint
        # there avoids materialising it on the GPU it was saved from (which
        # every DDP rank would otherwise do, all on the same device).
        checkpoint = torch.load(model_dir, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    return model


class BaseEncoder(nn.Module):
    """Wide ResNet-12 backbone, 5x5 average pooling and a 640->128 LeakyReLU projection.

    ``forward(x)`` returns ``(pooled, projected)``: the flattened pooled
    backbone features and their projection through ``linear`` + ``act``.
    """

    def __init__(self):
        super(BaseEncoder, self).__init__()
        self.extractor = resnet12_wide()
        self.pooling = nn.AvgPool2d(kernel_size=5, stride=5)
        self.linear = nn.Linear(640, 128)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        pooled_map = self.pooling(self.extractor(x))
        flat = pooled_map.view(pooled_map.size(0), -1)
        projected = self.act(self.linear(flat))
        return flat, projected
    
class BaseEncoderClass(nn.Module):
    """Wide ResNet-12 backbone plus 5x5 average pooling; returns flattened features."""

    def __init__(self):
        super(BaseEncoderClass, self).__init__()
        self.extractor = resnet12_wide()
        self.pooling = nn.AvgPool2d(kernel_size=5, stride=5)

    def forward(self, x):
        pooled_map = self.pooling(self.extractor(x))
        batch = pooled_map.size(0)
        return pooled_map.view(batch, -1)

class distLinear(nn.Module):
    """Bias-free linear classifier whose class-weight rows are kept at unit L2 norm.

    When the inputs are L2-normalised (as the callers in this file do), the
    logits are scaled cosine similarities between features and class weights.

    Args:
        indim: input feature size.
        outdim: number of classes.
        scale: multiplier applied to the logits (defaults to the previously
            hard-coded 10).
    """

    def __init__(self, indim, outdim, scale=10):
        super(distLinear, self).__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        self.scale = scale

    def forward(self, x):
        # Renormalise the weight rows before every forward.  This writes
        # through `.data`, i.e. outside autograd, replacing the stored tensor.
        self.L.weight.data = nn.functional.normalize(self.L.weight.data, dim=1)
        return self.scale * self.L(x)
    
class linearC(nn.Module):
    """Thin wrapper around an ``nn.Linear(indim, outdim)`` (with bias) stored as ``self.L``."""

    def __init__(self, indim, outdim):
        super(linearC, self).__init__()
        self.L = nn.Linear(indim, outdim)

    def forward(self, x):
        layer = self.L
        return layer(x)
    
class MoCo(nn.Module):
    """MoCo-style momentum-contrast branch.

    ``encoder_q`` is trained by backprop.  ``encoder_k`` starts as a copy of
    it, receives no gradients, and is updated as an exponential moving
    average of ``encoder_q`` (momentum ``m``).  L2-normalised key embeddings
    are stored in a FIFO ``queue`` buffer of ``K`` columns and serve as
    negatives.

    Args:
        base_encoder: zero-argument callable (e.g. a class) building an
            encoder whose forward returns ``(features, embedding)``.
        dim: embedding size (rows of the queue).
        K: queue length; must be divisible by the global (all-rank) batch size.
        m: momentum of the key-encoder update.
        T: temperature dividing the logits.
    """

    def __init__(self, base_encoder, dim=128, K=2048, m=0.999, T=0.1):
        super(MoCo, self).__init__()
        self.K = K
        self.T = T
        self.m = m

        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()

        # Key encoder: exact copy of the query encoder, excluded from autograd.
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        self.register_buffer("queue", torch.randn(dim, K))
        # Assigning to a registered buffer name replaces the buffer's tensor.
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """EMA update: ``k <- m * k + (1 - m) * q`` for every parameter pair."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Write the keys gathered from all ranks into the queue at ``queue_ptr`` and advance it."""
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns this rank's slice of the globally shuffled batch and the index
        needed to undo the shuffle.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # Random permutation drawn on the CPU generator (as before) and moved
        # to the input's device.  Previously this relied on the module-global
        # `device` rather than where `x` actually lives.
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus so every rank uses rank 0's permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """Return ``(q_high, logits)``.

        ``q_high`` is the query encoder's first output.  ``logits`` holds the
        positive-key similarity in column 0, followed by similarities to the
        queued negatives, divided by ``T``.
        """
        q_high, q = self.encoder_q(im_q)
        q = nn.functional.normalize(q, dim=1)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
            _, k = self.encoder_k(im_k)
            k = nn.functional.normalize(k, dim=1)
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T
        self._dequeue_and_enqueue(k)

        return q_high, logits

@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    Returns the tensors from every rank concatenated along dim 0, in rank
    order.  When no default process group is initialised (a single-process
    run), the result is a copy of *tensor*, i.e. what a world of size one
    would produce, instead of an error.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return torch.cat([tensor], dim=0)

    # Receive buffers are fully overwritten by all_gather.
    tensors_gather = [torch.empty_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor.contiguous(), async_op=False)

    return torch.cat(tensors_gather, dim=0)


####################################################################################

class classFeatureExtractor(nn.Module):
    """Supervised branch: backbone -> L2-normalised features -> ``distLinear`` classifier.

    ``forward(x)`` returns ``(normalised_features, logits)``.

    Args:
        base_enceoder: (sic) zero-argument callable building the backbone.
        num_classes: classifier output size.
        dim: backbone feature size (classifier input).
    """

    def __init__(self, base_enceoder, num_classes=64, dim=640):
        super(classFeatureExtractor, self).__init__()
        self.base_encoder = base_enceoder()
        self.classifier = distLinear(dim, num_classes)

    def forward(self, x):
        feats = nn.functional.normalize(self.base_encoder(x))
        scores = self.classifier(feats)
        return feats, scores

    
class featureAugNet(nn.Module):
    """Joint model: a MoCo branch (``instFeatExt``) plus a supervised branch (``classFeatExt``).

    ``forward(im_q, im_k, im_s)`` returns
    ``(q_high, logits_u, logits_us, s_high, logits_s)``:
      * q_high    -- query-encoder features of ``im_q``, L2-normalised;
      * logits_u  -- MoCo contrastive logits (positive key in column 0);
      * logits_us -- ``q_high`` scored by ``self.classifier``, which mirrors
        the supervised classifier's weights (see ``update_classifier``);
      * s_high    -- normalised supervised features of ``im_s``;
      * logits_s  -- supervised classifier logits for ``im_s``.

    Args:
        num_classes: number of classes for both classifiers.
        dim: classifier input size; must match the backbone feature size.
    """
    def __init__(self, num_classes=100, dim=640):
        super(featureAugNet, self).__init__()
        self.num_classes = num_classes
        self.instFeatExt = MoCo(BaseEncoder, dim=128, K=2048, m=0.999, T=0.1)
        self.classFeatExt = classFeatureExtractor(BaseEncoderClass, num_classes=num_classes, dim=dim)
        self.classifier = distLinear(dim, num_classes)
   
    @torch.no_grad()
    def update_classifier(self):
        """Point ``self.classifier``'s parameters at ``classFeatExt.classifier``'s tensors.

        The assignment goes through ``.data``, so the tensors are shared
        (not copied) and no gradient is tracked.  Because this runs at the
        start of every forward, any optimizer update applied to
        ``self.classifier`` is overwritten on the next forward.
        NOTE(review): ``distLinear.forward`` rebinds ``weight.data`` to a new
        normalised tensor, so the sharing lasts only until one of the two
        classifiers runs.
        """
        for param_q, param_k in zip(self.classFeatExt.classifier.parameters(), self.classifier.parameters()):
            param_k.data = param_q.data
    

    def forward(self, im_q, im_k, im_s):
#         batch_size = labels.shape[0]
#         kl_loss = torch.tensor(0.).to(device)
        # Contrastive branch on the two augmented views.
        q_high, logits_u = self.instFeatExt(im_q, im_k)
        q_high = nn.functional.normalize(q_high, dim=1)
        # Must precede self.classifier(...) so it scores with the supervised weights.
        self.update_classifier()
        logits_us = self.classifier(q_high)
        # Supervised branch on the third view.
        s_high, logits_s = self.classFeatExt(im_s)
        return q_high, logits_u, logits_us, s_high, logits_s


def MoCoModel():
    """Build a stand-alone MoCo branch: BaseEncoder, 128-d keys, 2048-entry queue, m=0.999, T=0.1."""
    settings = dict(dim=128, K=2048, m=0.999, T=0.1)
    return MoCo(BaseEncoder, **settings)


###################################################################################################

def adjust_learning_rate(opt, optimizer, epoch, F_txt, gamma=0.2, step=35):
    """Step-decay schedule: set every param group's lr to ``opt.lr * gamma ** (epoch // step)``.

    With the defaults, the base learning rate is multiplied by 0.2 every 35
    epochs.  The new value is printed to stdout and to *F_txt*.

    Args:
        opt: object with the base learning rate ``opt.lr``.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current (0-based) epoch index.
        F_txt: open text file receiving a copy of the log line.
        gamma: multiplicative decay factor.
        step: number of epochs between decays.

    Returns:
        The learning rate that was set.
    """
    lr = opt.lr * (gamma ** (epoch // step))
    print('Learning rate: %f' % lr)
    print('Learning rate: %f' % lr, file=F_txt)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# def adjust_learning_rate(opt, optimizer, epoch, F_txt):
#     lr = opt.lr * (0.5 ** (epoch // 10))
#     lr2 = opt.lr2 * (0.5 ** (epoch // 5))
#     print('learning rate: %f' % lr)
#     print('learning rate: %f' % lr2)
#     print('Learning rate: %f' % lr, file=F_txt)
#     optimizer.param_groups[0]['lr'] = lr
#     optimizer.param_groups[1]['lr'] = lr
#     optimizer.param_groups[2]['lr'] = lr2
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr


def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` of a Student-t confidence interval for *data*.

    The interval is ``mean +/- half_width``, with
    ``half_width = sem * t.ppf((1 + confidence) / 2, n - 1)``.

    Args:
        data: sized sequence of numbers (e.g. per-episode accuracies).
        confidence: two-sided confidence level.
    """
    # Local import: the module header does not import scipy (the previous
    # `scipy` / `sp` references raised NameError).
    import scipy.stats

    a = [1.0 * np.array(x) for x in data]
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Public `ppf` instead of the private `_ppf`.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialise *state* to *filename* with ``torch.save``."""
    torch.save(obj=state, f=filename)


class AverageMeter(object):
    """Track the latest value and the count-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and count."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as observed *n* times and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, of score matrix *output* against integer *target*.

    Args:
        output: ``(batch, classes)`` scores.
        target: ``(batch,)`` class indices.
        topk: the k values to evaluate.

    Returns:
        One single-element numpy array per k in *topk*, in the same order.
    """
    with torch.no_grad():
        n = target.size(0)
        _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
        pred = top_idx.t()
        hits = pred.eq(target.view(1, -1).expand_as(pred))
        scale = 100.0 / n
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale).cpu().detach().numpy()
            for k in topk
        ]


def set_save_path(opt):
    """
    Settings of the save path.

    NOTE: ``opt.outf`` is overwritten with the fixed directory ``'Shot'``
    (relative to the working directory); any ``--outf`` value is ignored.

    Returns:
        ``(opt.outf, F_txt)`` where F_txt is ``opt_resutls.txt`` opened in
        append mode; the caller is responsible for closing it.
    """
    opt.outf = 'Shot'

    # exist_ok avoids the check-then-create race when several DDP ranks call
    # this at start-up (os.makedirs would raise FileExistsError on the loser).
    os.makedirs(opt.outf, exist_ok=True)

    # save the opt and results to txt file
    txt_save_path = os.path.join(opt.outf, 'opt_resutls.txt')
    F_txt = open(txt_save_path, 'a+')

    return opt.outf, F_txt


def set_save_test_path(opt, finetune=False):
    """
    Settings of the save path for test results.

    Creates ``opt.outf`` if needed and opens ``Test_resutls.txt`` (or
    ``Test_Finetune_resutls.txt`` when *finetune* is True) in append mode.
    The caller is responsible for closing the returned file.
    """
    # exist_ok: safe when multiple ranks create the directory concurrently.
    os.makedirs(opt.outf, exist_ok=True)

    # save the opt and results to txt file
    if finetune:
        txt_save_path = os.path.join(opt.outf, 'Test_Finetune_resutls.txt')
    else:
        txt_save_path = os.path.join(opt.outf, 'Test_resutls.txt')
    F_txt_test = open(txt_save_path, 'a+')

    return F_txt_test


def get_resume_file(checkpoint_dir, F_txt, map_location=None):
    """Load a checkpoint to resume from, logging to stdout and *F_txt*.

    Args:
        checkpoint_dir: path to the checkpoint *file* (despite the name).
        F_txt: open text file that receives a copy of every message.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'``) so callers
            can avoid materialising tensors on the device they were saved
            from.  ``None`` keeps torch's default behaviour.

    Returns:
        The loaded checkpoint (expected to contain ``'epoch_index'``), or
        ``None`` when no file exists at *checkpoint_dir*.
    """
    def log(msg):
        print(msg)
        print(msg, file=F_txt)

    if not os.path.isfile(checkpoint_dir):
        log("=> no checkpoint found at '{}'".format(checkpoint_dir))
        return None

    log("=> loading checkpoint '{}'".format(checkpoint_dir))
    checkpoint = torch.load(checkpoint_dir, map_location=map_location)
    log("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_dir, checkpoint['epoch_index']))
    return checkpoint
    
##############################################################################################
def train(train_loader, model, criterions, optimizer, batch_size, epoch_index, dataset_length, F_txt, device, rank):
    """Run one training epoch of the joint contrastive + supervised model.

    Args:
        train_loader: iterable of ``(imgs, labels)``.  ``imgs`` stacks three
            views per sample along dim 1: MoCo query, MoCo key, supervised
            view (see FewShotDataSet.__getitem__).
        model: wrapped model returning
            ``(feature_u, logits_u, logits_us, feature_s, logits_s)``
            (presumably DDP(featureAugNet) -- confirm at the call site).
        criterions: sequence whose element 0 is the classification criterion
            applied to (logits, labels); other elements are unused here.
        optimizer: optimizer stepped once per iteration.
        batch_size: nominal batch size used to derive the step count.
        epoch_index: epoch number, only used for logging.
        dataset_length: number of training samples.
        F_txt: open text file receiving a copy of the progress log.
        device: device the batches are moved to.
        rank: process rank; only rank 0 logs.

    Returns:
        The ``AverageMeter`` of the total loss over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1_s = AverageMeter()
    top1_u = AverageMeter()
    end = time.time()
    train_iterator = iter(train_loader)
    entropy_criterion = criterions[0]
    # Fixed number of steps per epoch; the loader is restarted if exhausted.
    iter_length = int(dataset_length / batch_size)
    for iter_num in range(iter_length):
        try:
            imgs, labels = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            imgs, labels = next(train_iterator)
        imgs = imgs.to(device)
        labels = labels.to(device)
        batch_size = imgs.shape[0]
        im_q, im_k, im_s = imgs[:, 0, ...], imgs[:, 1, ...], imgs[:, 2, ...]
        data_time.update(time.time() - end)

        feature_u, logits_u, logits_us, feature_s, logits_s = model(im_q, im_k, im_s)
        # MoCo puts the positive key in column 0, so the InfoNCE target is always 0.
        labels_u = torch.zeros(logits_u.shape[0], dtype=torch.long).to(device)

        # Cross-alignment between unsupervised and supervised features:
        # reward similarity of the mean-centred, re-normalised features
        # (weight 1.5) and of the raw features' dot product (weight 0.5).
        feature_u_dnorm = nn.functional.normalize(feature_u - torch.mean(feature_u, dim=1, keepdim=True), dim=1)
        feature_s_dnorm = nn.functional.normalize(feature_s - torch.mean(feature_s, dim=1, keepdim=True), dim=1)
        u_s_crossalign_loss = (-1.5 * torch.mean(torch.einsum('nc,nc->n', [feature_u_dnorm, feature_s_dnorm]))
                               - 0.5 * torch.mean(torch.einsum('nc,nc->n', [feature_u, feature_s])))

        # Supervised classification losses on both branches.
        entropy_loss = entropy_criterion(logits_s, labels)
        entropy_loss_u = entropy_criterion(logits_us, labels)
        # Contrastive InfoNCE loss.
        infoNCELoss = entropy_criterion(logits_u, labels_u)
        loss = entropy_loss + infoNCELoss + entropy_loss_u + u_s_crossalign_loss

        prec1_s = accuracy(logits_s, labels, topk=(1,))
        prec1_u = accuracy(logits_us, labels, topk=(1,))
        losses.update(loss.item(), batch_size)
        top1_s.update(prec1_s[0].item(), batch_size)
        top1_u.update(prec1_u[0].item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        # ============== print the intermediate results ==============#
        if rank == 0 and iter_num % opt.print_freq == 0 and iter_num != 0:
            message = ('Eposide-({0}): [{1}/{2}]\t'
                       'entropy infonce u&s uc {3} {4} {5} {6}\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                       'Prec1 {top1_s.avg:.3f} ({top1_u.avg:.3f})\t'.format(
                           epoch_index, iter_num, iter_length,
                           round(entropy_loss.item(), 3), round(infoNCELoss.item(), 3),
                           round(u_s_crossalign_loss.item(), 3), round(entropy_loss_u.item(), 3),
                           batch_time=batch_time,
                           data_time=data_time, loss=losses, top1_s=top1_s, top1_u=top1_u))
            print(message)
            print(message, file=F_txt)

    return losses

###########################################################################
def val_train(train_num, data, train_model, model, criterion, optimizer):
    """Fit a classifier head on frozen support-set features.

    Args:
        train_num: number of optimisation steps.
        data: ``[x, y]`` support images and labels.
        train_model: wrapped model; ``train_model.module.classFeatExt(x)``
            must return ``(features, logits)`` (only features are used).
        model: the head being trained (features -> logits).
        criterion: loss on ``(logits, y)``.
        optimizer: optimizer over *model*'s parameters only.

    Returns:
        The trained *model* (same object).

    The backbone is treated as frozen: support features are computed once
    under ``torch.no_grad()``.  Previously they were recomputed with autograd
    on every step.  Each ``backward`` then also accumulated gradients into the
    backbone's parameters (which may fire DDP gradient hooks outside a DDP
    forward) and re-ran the backbone ``train_num`` times; ``retain_graph=True``
    was unnecessary because the graph was rebuilt every step anyway.  The
    input is the same on every step, so the features are too.
    """
    x, y = data[0], data[1]
    with torch.no_grad():
        feat, _ = train_model.module.classFeatExt(x)

    for _ in range(train_num):
        logits = model(feat)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model


def val(train_model, criterion, epoch, rank):
    top1_val = AverageMeter()
    for task_num in range(1):
        model = distLinear(indim=640, outdim=5).to(device)
        checkpoint_path = "dist_initial_weights.pt"
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)
        dist.barrier()
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])
        val_optimizer = optim.Adam(model.parameters())
        test_loader = get_dataloader(opt, 'val')
        X_q, Y_q, X_s, Y_s = test_loader(10*epoch + task_num)     # 0 is rand seed
        X_q, Y_q, X_s, Y_s = X_q.to(device), Y_q.to(device), X_s.to(device), Y_s.to(device)
        query_nums = Y_q.shape[-1]
        model.train()
#         print("############################  Before Train  ################################")
#         print(model.L.weight)
        model = val_train(30, [X_s, Y_s], train_model, model, criterion, val_optimizer)
#         print("############################  After Train  ################################")
#         print(model.L.weight)
        model.eval()
        feats, _ = train_model.module.classFeatExt(X_q)
        logit = model(feats)
        prec1 = accuracy(logit, Y_q, topk=(1,))
        top1_val.update(prec1[0].item(), query_nums)
#         top1.update(prec1[0].item(), query_nums)
        del val_optimizer
        del model
        if rank == 0:
            print('Eposide-({0}): [{1}/{2}]\t'
                  'prec1 is {3}\t'.format(epoch, task_num, 10, prec1[0].item()))
    
    if rank == 0:               
#         print("测试结果为" + str(top1.avg))
        print("第%d代验证结果为"%epoch_item + str(top1_val.avg))
    return top1_val.avg
    
    
def test(train_model, criterion, epoch, rank, num_tasks=1):
    """Evaluate the backbone on few-shot tasks drawn from the 'test' split.

    For each task a fresh ``distLinear(indim=640, outdim=5)`` head is built and
    given identical initial weights on every rank (rank 0 saves its init, all
    ranks load it after a barrier). The head is wrapped in DDP, fine-tuned on
    the support set via ``val_train(30, ...)`` using ``train_model`` as the
    backbone, and then scored on the query set.

    Args:
        train_model: DDP-wrapped backbone; ``train_model.module.classFeatExt``
            is used to extract query features.
        criterion: loss passed through to ``val_train``.
        epoch: current epoch; used to derive the task seed and in log messages.
        rank: process rank; only rank 0 writes the shared init file and prints.
        num_tasks: number of tasks to evaluate (default 1, the original
            behaviour). Task seeds are ``10 * epoch + task_num``, so values
            above 10 make seeds overlap with those of later epochs.

    Returns:
        The ``AverageMeter`` average of per-task top-1 precision on this rank
        (each task weighted by its query count; no cross-rank reduction).
    """
    top1_test = AverageMeter()
    checkpoint_path = "dist_initial_weights.pt"
    for task_num in range(num_tasks):
        model = distLinear(indim=640, outdim=5).to(device)
        # Rank 0 writes its random init. The barrier guarantees the file exists
        # before every rank loads it, so all heads start from the same weights.
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)
        dist.barrier()
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])
        test_optimizer = optim.Adam(model.parameters())

        test_loader = get_dataloader(opt, 'test')
        X_q, Y_q, X_s, Y_s = test_loader(10 * epoch + task_num)  # argument is the task's random seed
        X_q, Y_q, X_s, Y_s = X_q.to(device), Y_q.to(device), X_s.to(device), Y_s.to(device)
        query_nums = Y_q.shape[-1]

        model.train()
        model = val_train(30, [X_s, Y_s], train_model, model, criterion, test_optimizer)

        model.eval()
        # Scoring the query set is pure inference; no autograd graph is needed.
        with torch.no_grad():
            feats, _ = train_model.module.classFeatExt(X_q)
            logit = model(feats)
            prec1 = accuracy(logit, Y_q, topk=(1,))
        top1_test.update(prec1[0].item(), query_nums)
        del test_optimizer
        del model
        if rank == 0:
            print('Eposide-({0}): [{1}/{2}]\t'
                  'prec1 is {3}\t'.format(epoch, task_num, num_tasks, prec1[0].item()))

    if rank == 0:
        # Message reads "Test result for epoch %d is ".
        print("第%d代测试结果为" % epoch + str(top1_test.avg))
    return top1_test.avg
##############################################################################


if __name__ == "__main__":
    # Per-process entry point of the DDP training script. These statements are
    # kept at module scope on purpose: names such as `epoch_item` become module
    # globals, and functions defined earlier in the file may read them.
    rank = opt.rank
    # Scale the base learning rate linearly with the number of processes.
    opt.lr *= opt.world_size
    batch_size = opt.batch_size
    # NOTE(review): `top1` is not used in this block; it may be read as a global
    # elsewhere in the file — confirm before removing.
    top1 = AverageMeter()
    opt.outf, F_txt = set_save_path(opt)

    best_prec1 = 0  # best validation accuracy so far; updated after every epoch
    epoch_index = 0

    dataset = FewShotDataSet(data_dir)
    dataset_length = len(dataset)
    # Shards the dataset across processes; reshuffled per epoch via set_epoch().
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    # NOTE(review): the per-process batch size is hard-coded to 64 here, while
    # `train` below receives 128 and `batch_size` only sizes the worker count —
    # confirm which value is intended.
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, 64, drop_last=True)
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=train_batch_sampler,
                                               pin_memory=True,
                                               num_workers=nw)

    model = featureAugNet().to(device)
    checkpoint_path = "initial_weights.pt"
    if rank == 0:
        torch.save(model.state_dict(), checkpoint_path)
    dist.barrier()
    # NOTE(review): the saved initial weights are never loaded back here; ranks
    # start in sync because DDP broadcasts rank 0's parameters at construction.
    pretrain_path = ""  # unused placeholder for resuming from a checkpoint
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = nn.MSELoss(reduction='sum')

    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    print(opt)
    print(opt, file=F_txt)
    print(model, file=F_txt)

    # ============================================ Training phase ========================================
    print('===================================== Training on the train set =====================================')
    print('===================================== Training on the train set =====================================',
          file=F_txt)
    print('Learning rate: %f' % opt.lr)
    print('Learning rate: %f' % opt.lr, file=F_txt)

    Train_losses = []
    Val_losses = []
    Test_losses = []

    for epoch_item in range(1, 201):  # epochs 1..200; opt.epochs is not consulted
        print('==================== Epoch %d ====================' % epoch_item)
        print('==================== Epoch %d ====================' % epoch_item, file=F_txt)
        opt.current_epoch = epoch_item
        train_sampler.set_epoch(epoch_item)

        model.train()
        train_loss = train(train_loader, model, [criterion1, criterion2], optimizer, 128,
                           epoch_item, dataset_length, F_txt, device, rank)
        Train_losses.append(train_loss)
        print("################################### val ############################")
        model.eval()
        acc = val(model, criterion1, epoch_item, rank)
        print("################################### test ############################")
        test(model, criterion1, epoch_item, rank)
        adjust_learning_rate(opt, optimizer, epoch_item, F_txt)

        # Record the new best before saving so best_model.pth.tar is rewritten
        # only when validation accuracy actually improves.
        is_best = acc > best_prec1
        if is_best:
            best_prec1 = acc

        # DDP keeps the weights identical on every rank, so only rank 0 writes
        # checkpoints; concurrent writes to one path from every process can
        # corrupt the file.
        if rank == 0:
            state = {
                'epoch_index': epoch_item,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            save_checkpoint(state, os.path.join(opt.outf, 'epoch_tmp.pth.tar'))
            if is_best:
                save_checkpoint(state, os.path.join(opt.outf, 'best_model.pth.tar'))
            if epoch_item % 20 == 0:
                save_checkpoint(state, os.path.join(opt.outf, 'save_epoch_%d.pth.tar' % epoch_item))

    print('======================================== Training is END ========================================\n')
    print('======================================== Training is END ========================================\n',
          file=F_txt)
    F_txt.close()
    if rank == 0:
        # Remove the temporary init-weight files: this block's and the
        # "dist_initial_weights.pt" file that val()/test() write.
        for tmp_path in (checkpoint_path, "dist_initial_weights.pt"):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    dist.destroy_process_group()

Overwriting cv_example.py


### 