In [None]:
# -*- coding: utf-8 -*-
import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
import random
import argparse
import os.path as osp
import os
import math
from os.path import join
from torch.utils.data import Dataset, DataLoader
import torch.utils.model_zoo as model_zoo
import shutil
import time
import pprint
from os.path import basename, dirname, join, exists, isdir
import glob
from PIL import Image
import os.path as osp
from PIL import Image
import numpy as np
import sklearn
import pickle
from sklearn.linear_model import LogisticRegression

In [None]:
#utils
def set_gpu(x):
    """Restrict visible CUDA devices to *x* (e.g. ``'0'`` or ``'0,1'``) and report it."""
    os.environ.update(CUDA_VISIBLE_DEVICES=x)
    print('using gpu:', x)


def ensure_path(path):
    """Make sure directory *path* exists.

    A missing directory is created (with parents). If it already exists the
    user is asked on stdin; any answer other than ``'n'`` deletes the
    directory tree and recreates it empty.
    """
    if not os.path.exists(path):
        os.makedirs(path)
        return
    answer = input('{} exists, remove? ([y]/n)'.format(path))
    if answer == 'n':
        return
    shutil.rmtree(path)
    os.makedirs(path)


class Averager():
    """Running arithmetic mean of the values passed to :meth:`add`."""

    def __init__(self):
        self.n = 0  # how many values have been added so far
        self.v = 0  # current mean (0 before anything is added)

    def add(self, x):
        """Fold *x* into the running mean."""
        total = self.v * self.n + x
        self.n += 1
        self.v = total / self.n

    def item(self):
        """Return the current mean."""
        return self.v


def count_acc(logits, label):
    """Return the fraction of rows whose arg-max logit equals ``label``.

    Args:
        logits: 2-D score tensor; the class axis is dim 1.
        label: integer class targets, one per row of ``logits``.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    pred = torch.argmax(logits, dim=1)
    # .float() keeps the comparison on whatever device the inputs live on;
    # casting through torch.cuda.FloatTensor forced CUDA and failed on CPU.
    return (pred == label).float().mean().item()

# def count_acc2(logits, label):
#     pred = torch.argmax(logits, dim=1)
    
#     return (pred == label).type(torch.cuda.FloatTensor).mean().item()

def dot_metric(a, b):
    """Pairwise dot products: entry (i, j) is ``<a[i], b[j]>`` for 2-D ``a``, ``b``."""
    return a.mm(b.transpose(0, 1))


def euclidean_metric(a, b):
    """Negative squared Euclidean distance between every row of ``a`` and of ``b``.

    Args:
        a: (n, d) tensor.
        b: (m, d) tensor.

    Returns:
        (n, m) tensor whose entry (i, j) is ``-||a[i] - b[j]||^2``.
    """
    diff = a[:, None, :] - b[None, :, :]  # broadcast to (n, m, d)
    return -(diff ** 2).sum(dim=2)


class Timer():
    """Wall-clock stopwatch that starts when constructed."""

    def __init__(self):
        self.o = time.time()  # start timestamp

    def measure(self, p=1):
        """Format the elapsed time divided by ``p`` (truncated to whole seconds).

        Returns ``'<s>s'`` under a minute, ``'<m>m'`` (rounded minutes) under
        an hour, and ``'<h>h'`` with one decimal otherwise.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds < 60:
            return '{}s'.format(seconds)
        if seconds < 3600:
            return '{}m'.format(round(seconds / 60))
        return '{:.1f}h'.format(seconds / 3600)


_utils_pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print *x* through the shared module-level PrettyPrinter.

    NOTE: defining this function rebinds the name ``pprint``, shadowing the
    standard-library module of the same name for code that runs afterwards.
    """
    printer = _utils_pp
    printer.pprint(x)


def l2_loss(pred, label):
    """Half the summed squared error between ``pred`` and ``label``, divided by ``len(pred)``."""
    squared = (pred - label) ** 2
    return squared.sum() / len(pred) / 2

# Kernel functions
class KernelBase(nn.Module):
    """Base class for all kernels; supports composition with ``+`` and ``*``.

    ``k1 + k2`` builds an :class:`AdditiveKernels` and ``k1 * k2`` a
    :class:`MultiplicativeKernels`; both evaluate their children on the same
    ``(x, y)`` pair and combine the results element-wise.
    """

    def __add__(self, other):
        return AdditiveKernels(self, other)

    def __mul__(self, other):
        return MultiplicativeKernels(self, other)


class MultiplicativeKernels(KernelBase):
    """Element-wise product ``k1(x, y) * k2(x, y)``."""

    def __init__(self, k1, k2):
        super().__init__()
        self.k1 = k1
        self.k2 = k2

    # Implemented as forward() instead of overriding __call__ so that
    # nn.Module's call machinery (forward hooks etc.) still runs.
    def forward(self, x, y):
        return self.k1(x, y) * self.k2(x, y)


class AdditiveKernels(KernelBase):
    """Element-wise sum ``k1(x, y) + k2(x, y)``."""

    def __init__(self, k1, k2):
        super().__init__()
        self.k1 = k1
        self.k2 = k2

    # forward() rather than __call__ for the same reason as above.
    def forward(self, x, y):
        return self.k1(x, y) + self.k2(x, y)


class RBF(KernelBase):
    """Squared-exponential kernel ``scale^2 * exp(-||x - y||^2 / (2 * length^2 * D))``.

    ``D`` is the size of the last (feature) axis of ``x``.

    Args:
        length: length-scale.
        scale: output scale.
        trainable: if True, ``length`` and ``scale`` are ``nn.Parameter``s.
    """

    def __init__(self, length=1., scale=1., trainable=False):
        super().__init__()
        if trainable:
            self.length = nn.Parameter(torch.tensor(length))
            self.scale = nn.Parameter(torch.tensor(scale))
        else:
            # Non-persistent buffers move with .to()/.cuda() but stay out of
            # state_dict, so checkpoints saved before this change still load.
            self.register_buffer('length', torch.tensor(length), persistent=False)
            self.register_buffer('scale', torch.tensor(scale), persistent=False)

    def forward(self, x, y):
        """Evaluate the kernel between rows of ``x`` and ``y`` (2-D or batched 3-D, as torch.cdist accepts)."""
        D = x.shape[-1]
        sq_dist = torch.cdist(x, y, p=2.0) ** 2
        return self.scale ** 2 * torch.exp(-sq_dist / (2 * self.length ** 2 * D))

    def __str__(self):
        return "RBF-kernel"


In [None]:
#backbone
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from torchvision import models
# **********************parameter*******************************************************
# Local ResNet-18 checkpoint loaded by both backbones below. The default is a
# machine-specific Windows path, while the data paths elsewhere in this
# notebook are Linux paths, so it can be overridden via RESNET18_WEIGHTS
# without editing the notebook.
import os
path = os.environ.get(
    "RESNET18_WEIGHTS",
    r"F:\Medical Image Segmentation\code\pretrained_model_weights\resnet18-5c106cde.pth",
)
# *************************BACKBONES****************************************************
# Feature-extraction backbone (torchvision ResNet-18; an older comment said "resnet12").
class backbonenetwork(nn.Module):
    """ResNet-18 trunk returning average-pooled feature maps.

    Weights are loaded from the module-level ``path``. ``fc`` is replaced by an
    empty ``nn.Sequential`` and ``forward`` stops after ``avgpool``, so the
    output keeps its trailing 1x1 spatial dimensions (it is not flattened).
    """

    def __init__(self):
        # Zero-argument super(): the notebook later rebinds the name
        # ``backbonenetwork`` to an instance, which makes
        # ``super(backbonenetwork, self)`` fail on any further construction.
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        # Load onto CPU first; load_state_dict copies the values into this
        # model's own parameters, so this also works on CPU-only hosts.
        self.model.load_state_dict(torch.load(path, map_location='cpu'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model.fc = nn.Sequential()

    def forward(self, x):
        m = self.model
        x = m.maxpool(m.relu(m.bn1(m.conv1(x))))
        x = m.layer4(m.layer3(m.layer2(m.layer1(x))))
        return m.avgpool(x)

    
class backbonenetwork2(nn.Module):
    """Second ResNet-18 trunk, built exactly like ``backbonenetwork``.

    Weights come from the module-level ``path``; ``forward`` returns the
    average-pooled (unflattened, 1x1 spatial) feature maps.
    """

    def __init__(self):
        # Zero-argument super(): the notebook later rebinds the name
        # ``backbonenetwork2`` to an instance, which would break the
        # explicit two-argument form on any further construction.
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        # CPU map_location keeps loading independent of where the file was saved.
        self.model.load_state_dict(torch.load(path, map_location='cpu'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model.fc = nn.Sequential()

    def forward(self, x):
        m = self.model
        for stage in (m.conv1, m.bn1, m.relu, m.maxpool,
                      m.layer1, m.layer2, m.layer3, m.layer4, m.avgpool):
            x = stage(x)
        return x

    
    
    
class mergedbackbone_encoder(nn.Module):
    """Run two backbones on the same batch and fuse their features.

    The two outputs are concatenated per sample along dim 1, flattened and
    projected by a 1024 -> 640 linear layer, so each stem is expected to
    produce 512 values per sample once flattened (e.g. ResNet-18 pooled
    features) — TODO confirm if other stems are used.

    Args:
        stem1: first feature extractor.
        stem2: second feature extractor.
    """

    def __init__(self, stem1, stem2):
        super().__init__()
        self.encoder1 = stem1
        self.encoder2 = stem2
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=1024, out_features=640)

    def forward(self, x):
        # Separate copy so an in-place op inside encoder1 cannot alter encoder2's input.
        x2 = x.clone()
        x1 = self.encoder1(x)
        x2 = self.encoder2(x2)
        # Concatenate per sample along the feature axis. Concatenating along
        # dim 0 stacked the two batches, and the following view(B, -1) then
        # glued together features of *different* images.
        out = torch.cat((x1, x2), dim=1)
        return self.fc(self.flatten(out))
    
    
    
def conv_block(in_channels, out_channels):
    """3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool.

    The BatchNorm scale is re-initialised from U(0, 1).
    """
    # The BatchNorm (and its random re-init) is created before the conv so the
    # RNG draws happen in the same order as before.
    norm = nn.BatchNorm2d(out_channels)
    nn.init.uniform_(norm.weight)  # for pytorch 1.2 or later
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        norm,
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*layers)


class Convnet(nn.Module):
    """Four stacked ``conv_block``s; returns flattened features with a leading singleton dim."""

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        widths = [x_dim, hid_dim, hid_dim, hid_dim, z_dim]
        self.encoder = nn.Sequential(
            *(conv_block(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:]))
        )
        # 64 channels * 5 * 5 for 84x84 inputs (84 -> 42 -> 21 -> 10 -> 5).
        self.out_channels = 1600

    def forward(self, x):
        feats = self.encoder(x)
        return torch.flatten(feats, start_dim=1)[None]
        
# Instantiate both backbones. These assignments rebind the class names to the
# instances, which later cells pass to mergedbackbone_encoder.
backbonenetwork = backbonenetwork()
backbonenetwork2 = backbonenetwork2()
# Freeze every parameter of the first backbone; keep the second trainable.
backbonenetwork.requires_grad_(False)
backbonenetwork2.requires_grad_(True)


In [None]:
#datasets
from torch.utils.data import Dataset
from torchvision import transforms
# Dataset root read by MiniImageNet and the triplet helpers. It points at an
# omniglot export despite the MiniImageNet class name; set TRIPLET_DATA_ROOT
# to use a different location without editing the notebook.
import os
ROOT_PATH = os.environ.get('TRIPLET_DATA_ROOT', '/root/autodl-tmp/triplet_mergednet/materials/omniglot')
# Default seed value (the subset helpers take their own seed_num argument).
seed_num = 42

class MiniImageNet(Dataset):
    """Few-shot image dataset described by ``<ROOT_PATH>/<setname>.csv``.

    Every CSV row after the header is ``name,wnid``; the image is read from
    ``<ROOT_PATH>/images/<name>``. Integer labels 0, 1, 2, ... are assigned in
    order of first appearance of each wnid (``self.wnids[label] == wnid``).
    Items are ``(image_tensor, label)`` with images resized/center-cropped to
    84x84 and ImageNet-normalised.
    """

    def __init__(self, setname):
        csv_path = osp.join(ROOT_PATH, setname + '.csv')
        # Context manager closes the CSV handle (it used to stay open).
        with open(csv_path, 'r') as csv_file:
            lines = [x.strip() for x in csv_file.readlines()][1:]

        data = []
        label = []
        self.wnids = []
        # Dict lookup gives each row its own class's index. The previous
        # running counter labelled every row with the newest class seen, which
        # was only correct when the CSV was grouped by class.
        wnid_to_label = {}

        for l in lines:
            if not l:
                continue  # tolerate blank lines instead of failing to unpack
            name, wnid = l.split(',')
            if wnid not in wnid_to_label:
                wnid_to_label[wnid] = len(self.wnids)
                self.wnids.append(wnid)
            data.append(osp.join(ROOT_PATH, 'images', name))
            label.append(wnid_to_label[wnid])

        self.data = data
        self.label = label

        self.transform = transforms.Compose([
            transforms.Resize(84),
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        # Close the image file as soon as the RGB copy has been made.
        with Image.open(path) as img:
            image = self.transform(img.convert('RGB'))
        return image, label

class CategoriesSampler():
    """Batch sampler that yields N-way episodes of dataset indices.

    Each of the ``n_batch`` batches picks ``n_cls`` distinct classes and
    ``n_per`` distinct indices from each (via ``torch.randperm``). Indices are
    laid out sample-major: the first ``n_cls`` entries hold one index of each
    chosen class, the next ``n_cls`` the second, and so on.
    """

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per

        labels = np.array(label)
        # m_ind[c]: tensor of dataset indices whose label equals c.
        self.m_ind = [
            torch.from_numpy(np.argwhere(labels == c).reshape(-1))
            for c in range(max(labels) + 1)
        ]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            picked = torch.randperm(len(self.m_ind))[:self.n_cls]
            columns = []
            for cls in picked:
                pool = self.m_ind[cls]
                columns.append(pool[torch.randperm(len(pool))[:self.n_per]])
            # (n_cls, n_per) -> transpose -> flatten gives sample-major order.
            yield torch.stack(columns).t().reshape(-1)



In [None]:
# ************************tripletloss dataset********************************************************************
# ************************tripletloss dataset********************************************************************
# ************************tripletloss dataset********************************************************************
def get_images_to_label(csv_path=None):
    """Index the training CSV two ways.

    Every row after the header is split on commas. The first column is what
    MiniImageNet treats as the image file name, the second as the class id.

    Args:
        csv_path: CSV to read; defaults to ``<ROOT_PATH>/train.csv``.

    Returns:
        ``(key_images, key_lable)`` where
        ``key_images`` maps second column -> first column of the *last* row
        carrying that value (class id -> one image name), and
        ``key_lable`` maps first column -> list of second-column values
        (image name -> [class id]).
        The local names read as if the columns were swapped, but the training
        loop relies on exactly these mappings, so they are kept.
    """
    if csv_path is None:
        csv_path = osp.join(ROOT_PATH, 'train' + '.csv')
    with open(csv_path, 'r') as csv_file:
        lines = csv_file.readlines()[1:]

    key_images = {}
    key_lable = {}

    for line in lines:
        # rstrip instead of line[:-1]: a final line without a trailing
        # newline used to lose its last real character.
        line = line.rstrip('\r\n')
        if not line:
            continue
        parts = line.split(',')
        first_col = parts[0]
        second_col = parts[1]
        key_images[second_col] = first_col
        key_lable.setdefault(first_col, []).append(second_col)

    return key_images, key_lable


def get_train_subset(train_label_to_images, nc, seed_num):
    """Pick ``nc`` images per label, without replacement, reproducibly.

    Reseeds the global ``random`` module with ``seed_num`` first, then samples
    labels in the mapping's iteration order.

    Returns:
        dict label -> list of ``nc`` sampled images.
    """
    random.seed(seed_num)
    subset = {}
    for train_label, images in train_label_to_images.items():
        subset[train_label] = random.sample(images, nc)
    return subset


def get_test_subset(test_image_to_label, val_subset, seed_num):
    """Return a reproducible random subset of ``val_subset`` (image, label) pairs.

    Reseeds the global ``random`` module with ``seed_num``. If ``val_subset``
    is not smaller than the mapping, the mapping itself is returned unchanged.
    """
    random.seed(seed_num)
    if val_subset < len(test_image_to_label):
        # random.sample needs a sequence; Python 3.11+ rejects dict views.
        # list() keeps the same order older versions used internally.
        sampled = random.sample(list(test_image_to_label.items()), val_subset)
        return dict(sampled)
    return test_image_to_label


#######################################
############ triplet stuff ############
#######################################
def generate_triplet(label_to_sentences, sentence_to_labels, mb_size=5, device='cuda'):
    """Sample a mini-batch of (anchor, positive, negative) image tensors.

    Two distinct classes are drawn from the keys of ``label_to_sentences``
    (its values are not read). Anchor and positive images come from the first
    class, negatives from the second. ``sentence_to_labels`` maps an image
    file name (under ``<ROOT_PATH>/images``) to a sequence whose first element
    is that image's class id.

    Args:
        label_to_sentences: mapping keyed by class id.
        sentence_to_labels: mapping image name -> [class id, ...].
        mb_size: number of triplets.
        device: device for the returned tensors (default keeps the old .cuda()).

    Returns:
        Three tensors of shape (mb_size, 3, 128, 128) on ``device``.
    """
    transformer = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])

    def _load(image_name):
        # Load one image as a normalised tensor; `with` closes the file promptly.
        with Image.open(os.path.join(ROOT_PATH, 'images', str(image_name))) as img:
            return transformer(img.convert('RGB'))

    # random.sample needs a sequence; dict views are rejected on Python 3.11+.
    label_p, label_n = random.sample(list(label_to_sentences.keys()), 2)

    # One pass over the index, comparing as strings on both sides (previously
    # only the positive test applied str()).
    value_p = []
    value_n = []
    for image_name, labels in sentence_to_labels.items():
        cls = str(labels[0])
        if cls == str(label_p):
            value_p.append(image_name)
        elif cls == str(label_n):
            value_n.append(image_name)

    anchor_list = []
    pos_list = []
    neg_list = []
    for _ in range(mb_size):
        # Without replacement when possible, so the anchor is not its own positive.
        anchor_path, pos_path = np.random.choice(value_p, 2, replace=len(value_p) < 2)
        neg_path = np.random.choice(value_n)
        anchor_list.append(_load(anchor_path))
        pos_list.append(_load(pos_path))
        neg_list.append(_load(neg_path))

    # Stack once after sampling (the batches used to be rebuilt every iteration).
    an = torch.stack(anchor_list).to(device)
    po = torch.stack(pos_list).to(device)
    ne = torch.stack(neg_list).to(device)
    return an, po, ne


def generate_triplet_batch(key_lable, train_image_to_embedding, device, mb_size=5):
    """Sample ``mb_size`` (anchor, positive, negative) triplets of precomputed embeddings.

    Samples image keys directly. The previous body called
    ``generate_triplet(key_lable)`` with too few arguments and indexed the
    embedding dict with the returned tensors, so it could never run.

    Args:
        key_lable: assumed mapping class label -> list of image keys
            (TODO confirm; no caller is visible in this notebook).
        train_image_to_embedding: mapping image key -> embedding (list, numpy
            array or tensor; all the same shape).
        device: target device for the returned tensors.
        mb_size: number of triplets.

    Returns:
        Three tensors stacked along a new leading dimension of size ``mb_size``.
    """
    labels = list(key_lable)
    anchor_list = []
    pos_list = []
    neg_list = []
    for _ in range(mb_size):
        label_p, label_n = random.sample(labels, 2)
        pool = key_lable[label_p]
        # Two distinct images when possible so the anchor is not its own positive.
        if len(pool) >= 2:
            anchor, pos = random.sample(pool, 2)
        else:
            anchor = pos = pool[0]
        neg = random.choice(key_lable[label_n])
        anchor_list.append(torch.as_tensor(train_image_to_embedding[anchor]))
        pos_list.append(torch.as_tensor(train_image_to_embedding[pos]))
        neg_list.append(torch.as_tensor(train_image_to_embedding[neg]))

    anchor_embeddings = torch.stack(anchor_list)
    pos_embeddings = torch.stack(pos_list)
    neg_embeddings = torch.stack(neg_list)

    return anchor_embeddings.to(device), pos_embeddings.to(device), neg_embeddings.to(device)


In [None]:
#Distribution Calibration
def distribution_calibration(query, base_means, base_cov, k, alpha=0.21):
    """Calibrate a Gaussian for one feature vector from its nearest base classes.

    The ``k`` base classes whose means are closest to ``query`` (Euclidean
    distance) are selected. The calibrated mean averages their means together
    with ``query`` itself. The calibrated covariance averages their covariance
    matrices and adds ``alpha`` to every entry.

    Args:
        query: 1-D feature vector.
        base_means: sequence of per-class mean vectors (same length as ``query``).
        base_cov: sequence of per-class covariance matrices aligned with ``base_means``.
        k: number of nearest classes; ``k >= len(base_means)`` uses all of them.
        alpha: constant added to the averaged covariance.

    Returns:
        ``(calibrated_mean, calibrated_cov)`` as numpy arrays.
    """
    query = np.asarray(query)
    means = np.asarray(base_means)
    dist = np.linalg.norm(means - query, axis=1)
    if k >= len(dist):
        # np.argpartition raises for an out-of-range kth; every class qualifies.
        index = np.arange(len(dist))
    else:
        index = np.argpartition(dist, k)[:k]
    calibrated_mean = np.concatenate([means[index], query[np.newaxis, :]]).mean(axis=0)
    # Index only the selected matrices instead of converting the whole
    # covariance list to one array on every call.
    calibrated_cov = np.mean([base_cov[i] for i in index], axis=0) + alpha

    return calibrated_mean, calibrated_cov

In [None]:
# Base-class feature statistics consumed by distribution_calibration:
# one mean vector and one covariance matrix per key of the pickled dict.
# NOTE(review): base_features_path is defined but never read; the statistics
# below come from val_features_path. Confirm which file is intended.
# NOTE(review): pickle.load can execute code from the file; only load trusted dumps.
base_means = []
base_cov = []
base_features_path = "/root/autodl-tmp/triplet_mergednet/materials/miniImagenet-20230719T122414Z-001/miniImagenet/base_features.plk"
val_features_path = "/root/autodl-tmp/triplet_mergednet/materials/miniImagenet-20230719T122414Z-001/miniImagenet/val_features.plk"
with open(val_features_path, 'rb') as f:
    data = pickle.load(f)
    # presumably data maps class key -> list of feature vectors (one per sample) — TODO confirm
    for key in data.keys():
        feature = np.array(data[key])
        mean = np.mean(feature, axis=0)
        # np.cov treats rows as variables, so transposing an (N, D) array gives a D x D covariance.
        cov = np.cov(feature.T)
        base_means.append(mean)
        base_cov.append(cov)

In [None]:
# ********************************tripletloss*******************************************************************
class TripletLoss(nn.Module):
    """Margin-based triplet loss over anchor / positive / negative embeddings.

    Per triplet: ``relu(d(anchor, positive) - d(anchor, negative) + margin)``.

    Args:
        margin: required gap between positive and negative distances.
        distance_type: ``"C"`` (negated cosine similarity) or ``"E"`` (squared
            Euclidean distance); case-insensitive, surrounding whitespace ignored.
        account_for_nonzeros: if True, average only over triplets with a
            positive loss instead of over all triplets.

    ``forward`` returns ``(loss, percent_activated)``, where
    ``percent_activated`` is the fraction of triplets with a positive loss.
    """

    def __init__(self, margin=0.4, distance_type="C", account_for_nonzeros=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.distance_type = distance_type.lower().strip()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.account_for_nonzeros = account_for_nonzeros

    def forward(self, anchor, positive, negative):
        if self.distance_type == "c":
            # Cosine "distance" = negated cosine similarity.
            distance_positive = -self.cos(anchor, positive)
            distance_negative = -self.cos(anchor, negative)
        elif self.distance_type == "e":
            # Squared Euclidean distance (no square root).
            distance_positive = (anchor - positive).pow(2).sum(1)
            distance_negative = (anchor - negative).pow(2).sum(1)
        else:
            raise Exception('please specify distance_type as C or E')
        losses = F.relu(distance_positive - distance_negative + self.margin)

        # Count active triplets with a single tensor op rather than a Python
        # loop over elements (which also synchronised the GPU per element).
        n_total = losses.numel()
        n_active = int((losses > 0).sum().item())
        percent_activated = n_active / n_total if n_total else 0.0

        if self.account_for_nonzeros:
            # With no active triplet the sum is already zero; returning it
            # avoids a ZeroDivisionError and keeps the autograd graph intact.
            loss = losses.sum() / n_active if n_active else losses.sum()
        else:
            loss = losses.mean()

        return loss, percent_activated


tripletloss = TripletLoss()

In [None]:
#Train
import argparse
import os.path as osp
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

if __name__ == '__main__':
    # Episodic training: prototypical cross-entropy mixed with a triplet loss;
    # accuracy is reported by a distribution-calibrated logistic-regression probe.
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=100)
    parser.add_argument('--proto-epoch', type=int, default=100)
    parser.add_argument('--save-epoch', type=int, default=30)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=30)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--theta', type=float, default=0.3)
    parser.add_argument('--save-path', default='./save-theta0.3/omniglot-5')
    parser.add_argument('--gpu', default='0')
    # Empty argv so the cell also runs inside Jupyter.
    args = parser.parse_args(args=[])
    pprint(vars(args))

    set_gpu(args.gpu)
    # torch.save() does not create directories.
    os.makedirs(args.save_path, exist_ok=True)

    acc_list = []
    trainset = MiniImageNet('train')
    train_sampler = CategoriesSampler(trainset.label, 50,
                                      args.train_way, args.shot + args.query)
    train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
                              num_workers=0, pin_memory=True)

    valset = MiniImageNet('val')
    val_sampler = CategoriesSampler(valset.label, 20,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=0, pin_memory=True)

    train_key_lable, train_key_images = get_images_to_label()

    model = mergedbackbone_encoder(backbonenetwork, backbonenetwork2).cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    def save_model(name):
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    percent_activated_list = []
    timer = Timer()

    p = args.shot * args.train_way  # number of support samples per episode
    num_sampled = int(750 / args.shot)  # synthetic features per support sample
    # Episode-local targets: CategoriesSampler lays an episode out
    # sample-major, so query j belongs to prototype j % train_way.
    proto_label = torch.arange(args.train_way).repeat(args.query).cuda()

    for epoch in range(1, min(args.max_epoch, args.proto_epoch) + 1):
        model.train()

        tl = Averager()
        ta = Averager()
        for i, batch in enumerate(train_loader, 1):
            anchor, pos, neg = generate_triplet(train_key_lable, train_key_images)
            triplet_loss, percent_activated = tripletloss(model(anchor), model(pos), model(neg))
            percent_activated_list.append(percent_activated)

            # Keep the dataset labels; the probe below needs them (they were
            # discarded before and `label` was read before assignment).
            data, label = [t.cuda() for t in batch]
            data_shot, data_query = data[:p], data[p:]

            # Embed each half once and reuse it for the loss and the probe.
            shot_emb = model(data_shot)
            query_emb = model(data_query)
            proto = shot_emb.reshape(args.shot, args.train_way, -1).mean(dim=0)

            logits = euclidean_metric(query_emb, proto)
            # Prototypical loss (previously referenced but never computed).
            loss = F.cross_entropy(logits, proto_label)
            total_loss = (1 - args.theta) * loss + args.theta * triplet_loss

            # ---- accuracy probe on detached features
            support_data = shot_emb.detach().cpu().numpy()
            support_label = label[:p].cpu().numpy()
            query_data = query_emb.detach().cpu().numpy()
            query_label = label[p:].cpu().numpy()

            sampled_data = []
            sampled_label = []
            # j, not i: i is the batch counter printed below.
            for j in range(p):
                mean, cov = distribution_calibration(support_data[j], base_means, base_cov, k=2)
                sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled))
                sampled_label.extend([support_label[j]] * num_sampled)
            # p = shot * train_way rows were augmented; the old reshape used
            # test_way and failed whenever train_way != test_way.
            sampled_data = np.concatenate(sampled_data)
            X_aug = np.concatenate([support_data, sampled_data])
            Y_aug = np.concatenate([support_label, sampled_label])
            # ---- train classifier
            classifier = LogisticRegression(max_iter=800).fit(X=X_aug, y=Y_aug)

            predicts = classifier.predict(query_data)
            acc = np.mean(predicts == query_label)
            acc_list.append(acc)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'
                  .format(epoch, i, len(train_loader), total_loss.item(), acc))

            tl.add(total_loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        # Step the schedule after the epoch's optimizer steps (PyTorch >= 1.1 order).
        lr_scheduler.step()

        proto = None; logits = None; total_loss = None; loss = None
        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.proto_epoch)))
        tl = tl.item()
        ta = ta.item()

        # Select the best checkpoint on the epoch-mean accuracy, not the last batch.
        if ta > trlog['max_acc']:
            trlog['max_acc'] = ta
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))


In [None]:
#test classifier
import argparse
import os.path as osp
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

if __name__ == '__main__':
    # Evaluate a trained encoder with a distribution-calibrated
    # logistic-regression classifier on test-set episodes.
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=200)
    parser.add_argument('--proto-epoch', type=int, default=200)
    parser.add_argument('--save-epoch', type=int, default=30)
    parser.add_argument('--load', default='/root/autodl-tmp/triplet_mergednet/save-theta0.3/omniglot-5/epoch-last.pth')
    parser.add_argument('--batch', type=int, default=500)
    parser.add_argument('--shot', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=20)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--theta', type=float, default=0.3)
    parser.add_argument('--save-path', default='./save-theta0.3/test-omniglot-5')
    parser.add_argument('--gpu', default='0')
    # Empty argv so the cell also runs inside Jupyter.
    args = parser.parse_args(args=[])
    pprint(vars(args))

    set_gpu(args.gpu)

    testset = MiniImageNet('test')
    test_sampler = CategoriesSampler(testset.label,
                                     args.batch, args.test_way, args.shot + args.query)
    test_loader = DataLoader(testset, batch_sampler=test_sampler,
                             num_workers=0, pin_memory=True)

    model = mergedbackbone_encoder(backbonenetwork, backbonenetwork2).cuda()

    test_acc = Averager()
    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    percent_activated_list = []
    timer = Timer()
    model.load_state_dict(torch.load(args.load))

    model.eval()

    p = args.shot * args.test_way  # support samples per episode
    num_sampled = int(750 / args.shot)  # synthetic features per support sample
    acc_list = []
    for i, batch in enumerate(test_loader, 1):
        data, label = [t.cuda() for t in batch]
        data_shot, data_query = data[:p], data[p:]
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            support_data = model(data_shot).cpu().numpy()
            query_data = model(data_query).cpu().numpy()
        support_label = label[:p].cpu().numpy()
        query_label = label[p:].cpu().numpy()

        sampled_data = []
        sampled_label = []
        # j, not i: keep the episode counter intact.
        for j in range(p):
            mean, cov = distribution_calibration(support_data[j], base_means, base_cov, k=2)
            sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled))
            sampled_label.extend([support_label[j]] * num_sampled)
        sampled_data = np.concatenate(sampled_data)
        X_aug = np.concatenate([support_data, sampled_data])
        Y_aug = np.concatenate([support_label, sampled_label])
        # ---- train classifier
        classifier = LogisticRegression(max_iter=800).fit(X=X_aug, y=Y_aug)

        predicts = classifier.predict(query_data)
        acc = np.mean(predicts == query_label)
        acc_list.append(acc)
        test_acc.add(acc)

        # Running mean over the episodes so far. The label used to be the
        # DataLoader object itself, which printed its repr/memory address.
        print('%s %d way %d shot  ACC : %f' % ('test', args.test_way, args.shot, float(np.mean(acc_list))))
        trlog['val_acc'].append(acc)

    if acc_list:
        # 95% confidence interval of the mean episode accuracy.
        ci95 = 1.96 * np.std(acc_list) / np.sqrt(len(acc_list))
        print('%d way %d shot  final ACC : %f +- %f'
              % (args.test_way, args.shot, float(np.mean(acc_list)), float(ci95)))
       