In [1]:
import torch
import numpy as np
import os

# 数据预处理
root_dir = 'D:/数据库/小样本数据库/data/omniglot2/raw'

import torchvision.transforms as transforms
from PIL import Image

In [2]:
#an example of img_items:
#( '0709_17.png',                                                                             #file
# 'Alphabet_of_the_Magi/character01',                                                         #r[-2] + "/" + r[-1]
#'D:\\数据库\\小样本数据库\\data\\omniglot2\\raw\\images_evaluation\\Tibetan\\character32')   #root

In [3]:
#返回一个img_items列表，包含所有图片样本，每个样本都是一个三元组（图片名词，语言/字母类别，图片样本所在的目录）
def find_classes(root_dir):
    img_items = []
    #root='D:\\数据库\\小样本数据库\\data\\omniglot2\\raw\\images_evaluation\\Tibetan\\character32'
    #dirs=[]
    #files=['1585_01.png', '1585_02.png', '1585_03.png',,,,,'1585_20.png']
    for (root, dirs, files) in os.walk(root_dir):
        for file in files:  #遍历某种语言某类字母的20张图片
            if (file.endswith("png")):
                r = root.split('/')
                img_items.append((file, r[-2] + "/" + r[-1], root))
    print("== Found %d items " % len(img_items))   #32460个样本
    return img_items

img_items = find_classes(root_dir)
print(img_items[0])

== Found 32460 items 
('0709_01.png', 'omniglot2/raw\\images_background\\Alphabet_of_the_Magi\\character01', 'D:/数据库/小样本数据库/data/omniglot2/raw\\images_background\\Alphabet_of_the_Magi\\character01')


In [4]:
## 构建一个词典{class:idx}  把所有类（具体到所有语言的所有字母类别）构成一个映射，索引按顺序排列
def index_classes(items):
    class_idx = {}
    count = 0
    for item in items:
        if item[1] not in class_idx:
            class_idx[item[1]] = count
            count += 1
    print('== Found {} classes'.format(len(class_idx)))   #1623类
    return class_idx

class_idx = index_classes(img_items)

== Found 1623 classes


In [5]:
temp = dict()  #把所有样本生成一个字典按类别保存，(类别索引：[图片1，图片2，，，图片20])
for imgname, classes, dirs in img_items:
    img = '{}/{}'.format(dirs, imgname)  #某张图片的地址
    label = class_idx[classes]  #某张图片的类别对应的索引
    transform = transforms.Compose([lambda img: Image.open(img).convert('L'),  #转成灰度图
                                    lambda img: img.resize((28, 28)),  #统一大小
                                    lambda img: np.reshape(img, (28, 28, 1)),
                                    lambda img: np.transpose(img, [2, 0, 1]),   #通道在前
                                    lambda img: img / 255.  #归一化
                                    ])
    img = transform(img)  #图像处理
    if label in temp.keys():
        temp[label].append(img)
    else:
        temp[label] = [img]
        
#print(len(temp))  #1623

In [6]:
#存放所有样本图片[[img1,,,,20张图片],[],[],[],,,,,1623个[]]
img_list = []   
for label, imgs in temp.items():
    img_list.append(np.array(imgs))
img_list = np.array(img_list).astype(np.float)
#print('data shape:{}'.format(img_list.shape))  # (1623, 20, 1, 28, 28)
#print(img_list[0][0].shape)  #一张图片的像素数组  (1, 28, 28)

# 分割训练集和测试集

In [7]:
x_train = img_list[:1200]  #前1200类作为训练集
x_val = img_list[1200:1400]  #后200类作为验证集
x_test=img_list[1400:]  #后223类作为测试集
num_classes = img_list.shape[0]  #训练集+测试集一共有1623类

In [8]:
#把训练集转换成一个个样本存放的形式，并生成对应样本的标签
imgs=[]  #训练集所有图片
labels=[] #imgs对应的标签
for i in range(len(x_train)):
    for j in range(20):
        img=x_train[i][j]
        imgs.append(img)
        labels.append(i)
        
#验证集
val_imgs=[]  
val_labels=[] 
for i in range(len(x_val)):
    for j in range(20):
        img=x_val[i][j]
        val_imgs.append(img)
        val_labels.append(i)
        
#测试集
test_imgs=[]  
test_labels=[] 
for i in range(len(x_test)):
    for j in range(20):
        img=x_test[i][j]
        test_imgs.append(img)
        test_labels.append(i)

In [9]:
import torch.utils.data as data

class OmniglotDataset(data.Dataset):
    def __init__(self,imgs,labels):
        self.imgs=imgs
        self.labels=labels
    def __getitem__(self, idx):
        return self.imgs[idx], self.labels[idx]
    def __len__(self):
        return len(self.imgs)

In [10]:
class PrototypicalBatchSampler(object):
    def __init__(self, labels, classes_per_it, num_samples, iterations):
        super(PrototypicalBatchSampler, self).__init__()
        self.labels = labels
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples  #s+q
        self.iterations = iterations

        self.classes, self.counts = np.unique(self.labels, return_counts=True)
        #print(self.classes, self.counts )   #[0~1199]   [20,,,,,,20]
        self.classes = torch.LongTensor(self.classes)

        self.idxs = range(len(self.labels)) #[0~2399]
        self.indexes = np.empty((len(self.classes), max(self.counts)), dtype=int) * np.nan
        self.indexes = torch.Tensor(self.indexes)
        
        self.numel_per_class = torch.zeros_like(self.classes)
        for idx, label in enumerate(self.labels):
            label_idx = np.argwhere(self.classes == label).item()
            self.indexes[label_idx, np.where(np.isnan(self.indexes[label_idx]))[0][0]] = idx
            self.numel_per_class[label_idx] += 1
        #print(self.indexes[2][15])  #  55     self.indexes ([1200, 20])   表示第n类的第k张图片在数据集中排第几个样本
        #print(self.numel_per_class[0])  #   20    self.numel_per_class([1200])  
    
    #产生一个batch样本的索引
    def __iter__(self):
        spc = self.sample_per_class
        cpi = self.classes_per_it

        for it in range(self.iterations):  #一个batch包含self.iterations个任务,iterations相当于eposides
            batch_size = spc * cpi  #一个任务需要的数据量（支持集+查询集）
            batch = torch.LongTensor(batch_size)
            c_idxs = torch.randperm(len(self.classes))[:cpi]  #从数据集1200个类别中随机选取cpi个类别
            #print(c_idxs)  #([ 335,   56, 1139,  991,  137])
            for i, c in enumerate(self.classes[c_idxs]):   #从被选出的每个类别中随机选取spc个样本
                s = slice(i * spc, (i + 1) * spc) #每个类被选出的所有样本在新生成batch中的存放位置
                #print('c=%d',c)  #c分别为335,   56, 1139,  991,  137
                #print('i=%d',i) #0~4
                #label_idx = torch.arange(len(self.classes)).long()[self.classes == c].item()
                label_idx=c
                sample_idxs = torch.randperm(20)[:spc]  #从某一类20个样本中随机选择spc个样本
                batch[s] = self.indexes[label_idx][sample_idxs]
            batch = batch[torch.randperm(len(batch))]
            yield batch
    def __len__(self):
        return self.iterations

In [11]:
traindata=OmniglotDataset(imgs,labels)
trainsampler=PrototypicalBatchSampler(labels,5,15,28)
dataloader = torch.utils.data.DataLoader(traindata, batch_sampler=trainsampler)

#验证集
valdata=OmniglotDataset(val_imgs,val_labels)
valsampler=PrototypicalBatchSampler(val_labels,5,15,28)
val_dataloader = torch.utils.data.DataLoader(valdata, batch_sampler=valsampler)

#测试集
testdata=OmniglotDataset(test_imgs,test_labels)
testsampler=PrototypicalBatchSampler(test_labels,5,15,28)
test_dataloader = torch.utils.data.DataLoader(testdata, batch_sampler=testsampler)

In [12]:
#len(next(iter(dataloader))) #[75个图片矩阵，75个标签]

# 构建原型网络模型

In [13]:
import torch.nn as nn

def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )


class ProtoNet(nn.Module):
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        x=x.type(torch.cuda.FloatTensor)
        x = self.encoder(x)
        return x.view(x.size(0), -1)
    
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = ProtoNet().to(device)
print(model)

ProtoNet(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), st

In [14]:
optim=torch.optim.Adam(params=model.parameters(),lr=0.001)
lr_scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optim,gamma=0.5,step_size=20)

# 定义损失函数

In [15]:
from torch.nn import functional as F

#欧式距离
def euclidean_dist(x, y):
    # x: N x D
    # y: M x D
    n = x.size(0)  #查询集的样本数=原型的数*n_query
    m = y.size(0)  #原型的数
    d = x.size(1)  #每个样本的嵌入网络输出维度=每个原型的嵌入维度
    if d != y.size(1):
        raise Exception

    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    return torch.pow(x - y, 2).sum(2)  #返回（n,m）

def prototypical_loss(input, target, n_support):
    target_cpu = target.to('cpu')   #75个样本的类别
    input_cpu = input.to('cpu')  #75个样本全部网络嵌入后的输出结果

    #返回出类别c选出的支持集样本索引
    def supp_idxs(c):
        # FIXME when torch will support where as np
        return target_cpu.eq(c).nonzero()[:n_support].squeeze(1)  #从类别c的所有样本中取n_support个作为支持集

    # FIXME when torch.unique will be available on cuda too
    classes = torch.unique(target_cpu)  #一共有哪些类别
    n_classes = len(classes)  #一共有多少给类别
    # FIXME when torch will support where as np
    # assuming n_query, n_target constants
    n_query = target_cpu.eq(classes[0].item()).sum().item() - n_support  #每个类别各有多少个样本组成查询集

    support_idxs = list(map(supp_idxs, classes))   #全部支持集样本的索引[[类别1]，,,[类别n_classes]]

    prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs])   #计算出支持集每个类别的原型
    # FIXME when torch will support where as np
    #全部查询集样本的索引
    query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero()[n_support:], classes))).view(-1)

    query_samples = input.to('cpu')[query_idxs]     #取出所有支持集样本的嵌入向量
    dists = euclidean_dist(query_samples, prototypes)  #返回（查询集样本数，原型类别数）的距离矩阵

    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)

    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()

    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
    _, y_hat = log_p_y.max(2)  #把log_p_y最小的原型对应的类别定为查询集样本的类别
    acc_val = y_hat.eq(target_inds.squeeze()).float().mean()

    return loss_val,  acc_val

# 训练模型

In [16]:
from tqdm import tqdm

train_loss = []
train_acc = []

val_loss = []
val_acc = []

best_acc = 0
best_model_path='C:/Users/user/Desktop/proto_net/best_protonet/best_model.pth'
last_model_path='C:/Users/user/Desktop/proto_net/best_protonet/last_model.pth'

#遍历epoch训练
for epoch in range(5):
    print('=== Epoch: {} ==='.format(epoch))
    #取一个batch的训练数据
    tr_iter = iter(dataloader)
    model.train()
    for batch in tqdm(tr_iter):  #遍历每个batch   #len(next(iter(dataloader))) #[75个图片矩阵，75个标签]
        optim.zero_grad()  #梯度清零
        x, y = batch
        #print(len(x))  #75
        #print(x.shape)  #([75, 1, 28, 28])
        x, y = x.to(device), y.to(device)
        model_output = model(x)  #前传
        loss, acc = prototypical_loss(model_output, target=y,n_support=5)  #计算loss
        loss.backward()  #反传
        optim.step()  #参数更新
        train_loss.append(loss.item())   #加入这个batch的计算loss
        train_acc.append(acc.item())  #加入这个batch的计算acc
    avg_loss = np.mean(train_loss[-100:])  #计算本次epoch的所有batch迭代的平均损失
    avg_acc = np.mean(train_acc[-100:])   #计算本次epoch的所有batch迭代的平均准确率
    print('Avg Train Loss: {}, Avg Train Acc: {}'.format(avg_loss, avg_acc))
    lr_scheduler.step()   #更新学习率
    
    #一边训练一边验证
    val_iter = iter(val_dataloader)
    model.eval()
    for batch in val_iter:
        x, y = batch
        x, y = x.to(device), y.to(device)
        model_output = model(x)
        loss, acc = prototypical_loss(model_output, target=y,n_support=5)
        val_loss.append(loss.item())
        val_acc.append(acc.item())

    avg_loss = np.mean(val_loss[-100:])
    avg_acc = np.mean(val_acc[-100:])

    #如果当前模型的验证效果比先前好，则把当前模型记为最佳模型
    postfix = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(best_acc)

    print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(  #输出本次验证的平均损失和准确率
            avg_loss, avg_acc, postfix))

    #保存某个epoch内所有batch得到的最佳模型、准确率
    if avg_acc >= best_acc:
        torch.save(model.state_dict(), best_model_path)
        best_acc = avg_acc
        best_state = model.state_dict()

#保存最终所有epoch得到的最佳模型
torch.save(model.state_dict(), last_model_path)

print('最后的结果： best_acc=%d, train_loss=%d, train_acc=%d, val_loss=%d, val_acc=%d' % (best_acc, train_loss[-1], train_acc[-1], val_loss[-1], val_acc[-1]))


  0%|                                                                                           | 0/28 [00:00<?, ?it/s]

=== Epoch: 0 ===


100%|██████████████████████████████████████████████████████████████████████████████████| 28/28 [00:02<00:00, 11.65it/s]


Avg Train Loss: 0.4625422563403845, Avg Train Acc: 0.8357142827340535


  7%|█████▉                                                                             | 2/28 [00:00<00:01, 18.52it/s]

Avg Val Loss: 1.0961973326546806, Avg Val Acc: 0.7800000054495675 (Best)
=== Epoch: 1 ===


100%|██████████████████████████████████████████████████████████████████████████████████| 28/28 [00:01<00:00, 18.66it/s]


Avg Train Loss: 0.398562701685088, Avg Train Acc: 0.8639285670859473


 11%|████████▉                                                                          | 3/28 [00:00<00:01, 22.05it/s]

Avg Val Loss: 0.7758758286280292, Avg Val Acc: 0.8110714278050831 (Best)
=== Epoch: 2 ===


100%|██████████████████████████████████████████████████████████████████████████████████| 28/28 [00:01<00:00, 19.15it/s]


Avg Train Loss: 0.34443896461189505, Avg Train Acc: 0.8811904724155154


 11%|████████▉                                                                          | 3/28 [00:00<00:01, 21.41it/s]

Avg Val Loss: 0.6436975156622273, Avg Val Acc: 0.8330952383223034 (Best)
=== Epoch: 3 ===


100%|██████████████████████████████████████████████████████████████████████████████████| 28/28 [00:01<00:00, 19.42it/s]


Avg Train Loss: 0.25318576232530177, Avg Train Acc: 0.9151999974250793


 11%|████████▉                                                                          | 3/28 [00:00<00:01, 22.24it/s]

Avg Val Loss: 0.5097465296834707, Avg Val Acc: 0.8537999975681305 (Best)
=== Epoch: 4 ===


100%|██████████████████████████████████████████████████████████████████████████████████| 28/28 [00:01<00:00, 19.96it/s]


Avg Train Loss: 0.1871448717173189, Avg Train Acc: 0.9335999995470047
Avg Val Loss: 0.3491128814406693, Avg Val Acc: 0.8827999967336655 (Best)
最后的结果： best_acc=0, train_loss=0, train_acc=0, val_loss=0, val_acc=0


# 测试模型

In [17]:
avg_acc = list()
for epoch in range(10):
    test_iter = iter(test_dataloader)
    for batch in test_iter:
        x, y = batch
        x, y = x.to(device), y.to(device)
        model_output = model(x)
        _, acc = prototypical_loss(model_output, target=y,n_support=5)
        avg_acc.append(acc.item())
avg_acc = np.mean(avg_acc)
print('Test Acc: {}'.format(avg_acc))

Test Acc: 0.9102857112884521


In [18]:
model=None
#加载训练最好的模型参数
model = ProtoNet().to(device)
model.load_state_dict(best_state)
print('Testing with best model..')
#测试模型
avg_acc = list()
for epoch in range(10):
    test_iter = iter(test_dataloader)
    for batch in test_iter:
        x, y = batch
        x, y = x.to(device), y.to(device)
        model_output = model(x)
        _, acc = prototypical_loss(model_output, target=y,n_support=5)
        avg_acc.append(acc.item())
avg_acc = np.mean(avg_acc)
print('Test Acc: {}'.format(avg_acc))

Testing with best model..
Test Acc: 0.9153571414096014
