In [None]:
import torch
import torch.nn.functional as F 
import torch.nn as nn
from torch.autograd import Variable,Function
import torchvision.models as models
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import numpy as np
import scipy.ndimage
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
import os
import shutil
from scipy.io import loadmat
from scipy.io import savemat
from scipy.spatial.distance import pdist
import cv2
from img_enhance import enh_contrast
import math

# 不显示warning
import warnings
warnings.filterwarnings('ignore')

use_gpu = torch.cuda.is_available() # 检测是否可以使用GPU, use_gpu的值为True则可以使用GPU
print(use_gpu)

In [None]:
# Seed every RNG in play so repeated runs of the experiment give the same results.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds the RNGs of all CUDA devices
# cuDNN may otherwise choose nondeterministic convolution algorithms, which makes
# GPU results differ between runs even with identical seeds.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Number of identities = number of entries in the periocular class directory.
# NOTE(review): os.listdir also counts stray files (e.g. hidden files), while the
# ImageFolder used for labels counts only sub-directories — confirm they agree.
periocular_entries = os.listdir('Periocular_Class/')
classNum = len(periocular_entries)
print('classNum:%s' % classNum)

In [None]:
# Convolution helpers for the network used in training.

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding equals the dilation, so with stride 1 the spatial size is preserved.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                       groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d`` used for channel projection."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return layer
    
class BasicBlock(nn.Module):
    """ResNet basic residual block: conv3x3-BN-ReLU-conv3x3-BN, plus shortcut, then ReLU.

    Args:
        inplanes: input channels.
        planes: output channels (``expansion`` is 1, so also the block width).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut
            when the shape changes; ``None`` means an identity shortcut.
        groups, base_width: only ``1`` and ``64`` are supported.
        dilation: only ``1`` is supported.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` by default.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Submodules are created in a fixed order: it determines state_dict
        # registration order and the sequence of random initialisation draws.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual branch to ``x`` and add the (possibly projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        out += shortcut
        return self.relu(out)


class MyNet_Eye(nn.Module):
    """ResNet-style feature extractor for the periocular (eye-region) branch.

    The constructor is adapted from torchvision's ``ResNet``. ``forward`` returns
    the globally average-pooled feature vector of size ``512 * block.expansion``,
    not class scores: ``fc1`` is built (so it is part of ``state_dict``) but is
    never called in ``forward``.

    Args:
        block: residual block class, e.g. ``BasicBlock``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of ``fc1`` and number of rows of ``centers``.
        zero_init_residual: if True, zero the last BN scale of each residual
            branch so every branch initially outputs zero.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: three flags for stages 2-4; a True flag
            replaces that stage's stride-2 downsampling with dilation.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, block, layers, num_classes=classNum, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(MyNet_Eye, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * block.expansion, num_classes)
        # Per-class feature centres (plain tensor, not a registered buffer).
        # Sized from the constructor arguments so they always match fc1 and the
        # pooled feature width rather than the global classNum / a fixed 512.
        self.centers = torch.zeros(num_classes, 512 * block.expansion).type(torch.FloatTensor)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                elif type(m).__name__ == 'Bottleneck':
                    # No Bottleneck class is defined in any visible cell; matching
                    # by name avoids the NameError a direct reference would raise.
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may change stride/width; it then gets a 1x1
        projection shortcut. Updates ``self.inplanes`` (and ``self.dilation``
        when ``dilate``) for the next stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # The stem expects 3 channels, so the input must be single-channel; it is
        # replicated three times along dim 1, e.g. (B, 1, 224, 224) -> (B, 3, 224, 224).
        x = torch.cat((x, x, x), 1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Pooled features, shape (B, 512 * block.expansion); also kept on self.features.
        self.features = torch.flatten(x, 1)
        return self.features

    def forward(self, x):
        """Return the pooled feature vectors for a batch of single-channel images."""
        return self._forward_impl(x)

In [None]:
# 用于训练的网络

# NOTE(review): these redefine the identical helpers from the periocular-network
# cell; after both cells run, these later definitions are the ones in effect.

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution without bias whose padding matches its dilation."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution without bias."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    
class BasicBlock(nn.Module):
    """Residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, add shortcut, ReLU.

    NOTE(review): identical to the BasicBlock defined with the periocular
    network; running this cell rebinds the global name to a new, equivalent class.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        unsupported_width = groups != 1 or base_width != 64
        if unsupported_width:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Keep this creation order: it fixes state_dict layout and init RNG draws.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ReLU(residual_branch(x) + shortcut(x))."""
        residual = self.conv1(x)
        residual = self.relu(self.bn1(residual))
        residual = self.bn2(self.conv2(residual))
        shortcut = self.downsample(x) if self.downsample is not None else x
        return self.relu(residual + shortcut)


class MyNet_Iris(nn.Module):
    """ResNet-style feature extractor for the normalized-iris branch.

    Same architecture as the periocular network: the constructor is adapted from
    torchvision's ``ResNet``, and ``forward`` returns the average-pooled feature
    vector of size ``512 * block.expansion``. ``fc1`` is built (and therefore in
    ``state_dict``) but is never called in ``forward``.

    Args:
        block: residual block class, e.g. ``BasicBlock``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of ``fc1`` and number of rows of ``centers``.
        zero_init_residual: if True, zero the last BN scale of each residual
            branch so every branch initially outputs zero.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: three flags for stages 2-4; a True flag
            replaces that stage's stride-2 downsampling with dilation.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, block, layers, num_classes=classNum, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(MyNet_Iris, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * block.expansion, num_classes)
        # Per-class feature centres (plain tensor, not a registered buffer).
        # Sized from the constructor arguments so they always match fc1 and the
        # pooled feature width rather than the global classNum / a fixed 512.
        self.centers = torch.zeros(num_classes, 512 * block.expansion).type(torch.FloatTensor)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                elif type(m).__name__ == 'Bottleneck':
                    # No Bottleneck class is defined in any visible cell; matching
                    # by name avoids the NameError a direct reference would raise.
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may change stride/width; it then gets a 1x1
        projection shortcut. Updates ``self.inplanes`` (and ``self.dilation``
        when ``dilate``) for the next stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # The stem expects 3 channels, so the input must be single-channel; it is
        # replicated along dim 1. Iris tensors built by MyDataset are (1, 64, 512)
        # per sample, so a batch goes (B, 1, 64, 512) -> (B, 3, 64, 512).
        x = torch.cat((x, x, x), 1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Pooled features, shape (B, 512 * block.expansion); also kept on self.features.
        self.features = torch.flatten(x, 1)
        return self.features

    def forward(self, x):
        """Return the pooled feature vectors for a batch of single-channel iris images."""
        return self._forward_impl(x)

In [None]:
# Iris-Periocular-Attention-Feature-Fusion
class AttentionFusionNet(nn.Module):
    """Fuse periocular and iris features with a learned attention gate, then classify.

    Image batches are laid out as (batch_size, channel, height, width).

    Args:
        eye_net: periocular feature extractor. If None, falls back to the
            notebook-global ``model_eye`` (the original behaviour).
        iris_net: iris feature extractor. If None, falls back to the
            notebook-global ``model_iris``.

    NOTE(review): ``model_eye`` / ``model_iris`` are not defined in any visible
    cell. Prefer passing the two networks explicitly so the notebook does not
    depend on hidden kernel state.
    """

    def __init__(self, eye_net=None, iris_net=None):
        super(AttentionFusionNet, self).__init__()
        self.model_eye = eye_net if eye_net is not None else model_eye
        self.model_iris = iris_net if iris_net is not None else model_iris
        # Produces a per-dimension attention weight vector for the fused features.
        # The input is 2-D (batch, features), so the softmax runs over the feature
        # axis; dim=-1 is what the implicit default picked for 2-D input.
        self.fc1 = nn.Sequential(
            nn.Linear(1024, 1024, bias=False),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Softmax(dim=-1)
        )
        # Gate on the attention-weighted term; starts at 0, so the initial fused
        # feature equals the plain concatenation.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.fc2 = nn.Linear(1024, classNum, bias=False)
        # Per-class centres for the center loss (plain CPU tensor, not in state_dict).
        self.centers = torch.zeros(classNum, 1024).type(torch.FloatTensor)

    def forward(self, eye_image, iris_image):
        """Return ``(logits, fused_features)`` for a batch of paired images.

        The fused features are also stored on ``self.features`` for
        ``get_center_loss``.
        """
        eye_feat = self.model_eye(eye_image)
        iris_feat = self.model_iris(iris_image)
        # (batch, 1024), presumably two 512-d branch outputs — matches fc1's input size.
        self.features = torch.cat((eye_feat, iris_feat), dim=-1)
        weightfeatures = self.features * self.fc1(self.features)
        self.features = self.gamma * weightfeatures + self.features
        self.features = torch.flatten(self.features, 1)
        y = self.fc2(self.features)
        return y, self.features

    def get_center_loss(self, target, alpha):
        """Center loss for the features of the last ``forward`` call, plus a centre update.

        Args:
            target: class indices, shape (batch_size,), on any device.
            alpha: step size for moving each used centre towards its samples.

        Returns:
            ``(center_loss, centers)``: the MSE between the features and their
            class centres, and ``self.centers`` after the in-place update.
        """
        features = self.features
        batch_size = target.size(0)
        features_dim = features.size(1)
        # self.centers lives on the CPU; gather() requires the index on the same
        # device, so index with a CPU copy of the labels.
        target_cpu = target.detach().cpu().long()
        target_expand = target_cpu.view(batch_size, 1).expand(batch_size, features_dim)

        # Move the gathered centres to wherever the features are (CPU or GPU).
        centers_batch = self.centers.gather(0, target_expand).to(features.device)
        criterion = nn.MSELoss()
        center_loss = criterion(features, centers_batch)

        diff = centers_batch - features
        _, unique_reverse, unique_count = np.unique(
            target_cpu.numpy(), return_inverse=True, return_counts=True)
        # Number of occurrences of each sample's label within this batch.
        appear_times = torch.from_numpy(unique_count).gather(
            0, torch.from_numpy(unique_reverse).long())
        appear_times_expand = appear_times.view(-1, 1).expand(batch_size, features_dim).type(torch.FloatTensor)
        # Average each label's step over its occurrences (1e-6 guards against /0).
        diff_cpu = diff.detach().cpu() / appear_times_expand.add(1e-6)
        diff_cpu = alpha * diff_cpu

        for i in range(batch_size):
            self.centers[target_cpu[i]] -= diff_cpu[i].type(self.centers.type())
        return center_loss, self.centers

In [None]:
# Load the training and test sets.
# Preprocessing applied to the periocular images (despite the variable name):
# scale the shorter side to 224, centre-crop to 224x224, convert to a tensor.
transformer_IrisImage = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

class MyDataset(Dataset):
    """Paired (periocular, normalized-iris) samples addressed by iris image path.

    The periocular path is derived from each iris path by replacing the
    ``NormalizedIris_Class`` directory with ``Periocular_Class`` and the
    ``_imno.bmp`` suffix with ``.tiff``.

    Args:
        filenames: iris image paths.
        labels: class index per path.
        transform: applied to the periocular PIL image.

    Items are ``(image_eye, image_iris, label)``.
    """

    # (width, height) passed to cv2.resize for the iris image -> tensor (1, 64, 512).
    iris_size = (512, 64)

    def __init__(self, filenames, labels, transform):
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        iris_path = self.filenames[idx]
        image_iris = cv2.imread(iris_path, 0)  # flag 0: load as grayscale
        if image_iris is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # which would otherwise surface as a confusing cv2.resize error.
            raise FileNotFoundError('Cannot read iris image: %s' % iris_path)
        image_iris = cv2.resize(image_iris, self.iris_size)
        image_iris = enh_contrast(image_iris).astype(np.float32) / 255
        image_iris = torch.unsqueeze(torch.from_numpy(image_iris), 0).to(torch.float32)  # (1, 64, 512)

        eye_filename = iris_path.replace('NormalizedIris_Class', 'Periocular_Class')
        eye_filename = eye_filename.replace('_imno.bmp', '.tiff')
        # Close the file handle once the transform has produced a tensor.
        with Image.open(eye_filename) as eye_img:
            image_eye = self.transform(eye_img)  # (1, 224, 224) for a single-channel image
        return image_eye, image_iris, self.labels[idx]

In [None]:
# Build per-class train / val / test splits of the normalized iris images.
test_filenames = []  # .mat name for each test sample, in test_inputs order
data_dir = 'NormalizedIris_Class'
# Per-class fractions [train, val, test]. Only the first two are read: the test
# split takes whatever remains after int() truncation of the other two.
ratio = [0.6, 0.2, 0.2]
dataset = ImageFolder(data_dir)

# Group image paths by class index, each group in descending path order.
character = [[] for _ in range(len(dataset.classes))]
for x, y in dataset.samples:
    character[y].append(x)
for paths in character:
    # One sort per class yields the same final order as re-sorting after every append.
    paths.sort(reverse=True)

train_inputs, val_inputs, test_inputs = [], [], []
train_labels, val_labels, test_labels = [], [], []
for i, data in enumerate(character):
    num_sample_train = int(len(data) * ratio[0])
    num_sample_val = int(len(data) * ratio[1])
    num_val_index = num_sample_train + num_sample_val

    for x in data[:num_sample_train]:
        train_inputs.append(str(x))
        train_labels.append(i)
    for x in data[num_sample_train:num_val_index]:
        val_inputs.append(str(x))
        val_labels.append(i)
    for x in data[num_val_index:]:
        test_inputs.append(str(x))
        # basename is independent of directory depth and the OS path separator.
        test_filenames.append(os.path.basename(str(x)).replace('_imno.bmp', '.mat'))
        test_labels.append(i)

train_data = MyDataset(train_inputs, train_labels, transformer_IrisImage)
train_dataloader = DataLoader(train_data, batch_size=89, shuffle=True)
val_data = MyDataset(val_inputs, val_labels, transformer_IrisImage)
val_dataloader = DataLoader(val_data, batch_size=89, shuffle=False)
test_data = MyDataset(test_inputs, test_labels, transformer_IrisImage)
# Keep batch_size=1 and shuffle=False: the feature-export cell pairs batch i
# with test_filenames[i].
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

loader = {}
loader['train'] = train_dataloader
loader['val'] = val_dataloader
loader['test'] = test_dataloader

print(len(train_data))
print(len(train_dataloader))

In [None]:
# Prediction accuracy of the fused network on the test split.
# NOTE(review): AttentionFusionNet() relies on model_eye / model_iris globals.
device = torch.device('cuda' if use_gpu else 'cpu')
model = AttentionFusionNet()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('models/ND_Eye_Iris_Recognition_Resnet18.pkl', map_location=device))
model = model.to(device)
model.eval()

correct = 0
with torch.no_grad():
    for image_eye, image_iris, labels in loader['test']:
        image_eye, image_iris, labels = image_eye.to(device), image_iris.to(device), labels.to(device)
        outs, _ = model(image_eye, image_iris)
        predictions = torch.argmax(outs, 1)
        correct += (predictions == labels).sum().item()
print('test_accuracy: %.3f' % (correct / len(test_data)))

In [None]:
# Run the network on the test split and save each fused feature vector to
# test/<name>.mat (one file per test sample).
device = torch.device('cuda' if use_gpu else 'cpu')
model = AttentionFusionNet()
model.load_state_dict(torch.load('models/ND_Eye_Iris_Recognition_Resnet18.pkl', map_location=device))
model = model.to(device)
model.eval()

# Batch i is paired with test_filenames[i], which only holds when every batch
# contains exactly one sample (and the loader does not shuffle).
assert loader['test'].batch_size == 1, 'test loader must use batch_size=1'
os.makedirs('test', exist_ok=True)
with torch.no_grad():
    for i, (image_eye, image_iris, _labels) in enumerate(loader['test']):
        _, fused_feature = model(image_eye.to(device), image_iris.to(device))
        fused_feature = fused_feature.cpu().numpy()
        savemat('test/%s' % test_filenames[i], {'fused_feature': fused_feature})
        print('%d test/%s' % (i, test_filenames[i]))

In [None]:
# Compare the Euclidean distances between different features.
# Files whose names share the first 6 characters are treated as the same
# identity (genuine -> hd1); all other pairs are impostor pairs (-> hd2).
# NOTE(review): confirm the 6-character prefix really is the subject ID.

# True reproduces the earlier results, which include each sample's zero-distance
# comparison with itself among the genuine scores. Set False to exclude it; a
# probe with no other same-identity sample then gets an empty hd1.
INCLUDE_SELF_MATCH = True

filenames = sorted(f for f in os.listdir('test/') if f.endswith('.mat'))
# Load every feature file once instead of once per pair.
all_feats = [loadmat('test/%s' % name)['fused_feature'] for name in filenames]
os.makedirs('CASIAV3_test_same', exist_ok=True)
os.makedirs('CASIAV3_test_diff', exist_ok=True)

for i, name_i in enumerate(filenames):
    print(i, name_i)
    feat1 = all_feats[i]
    hd1 = []
    hd2 = []
    for j, name_j in enumerate(filenames):
        if j == i and not INCLUDE_SELF_MATCH:
            continue
        dist = np.sqrt(np.sum(np.square(np.subtract(feat1, all_feats[j]))))
        if name_i[0:6] == name_j[0:6]:
            hd1.append(dist)
        else:
            hd2.append(dist)
    savemat('CASIAV3_test_same/%s' % name_i, {'hd1': hd1})
    savemat('CASIAV3_test_diff/%s' % name_i, {'hd2': hd2})

In [None]:
# Max-min normalization: scale each probe's distances into [0, 1] using that
# probe's own genuine + impostor range, then pool all genuine scores into
# feats1 and all impostor scores into feats2.
filenames = sorted(os.listdir('CASIAV3_test_same/'))
feats1 = []
feats2 = []
for name in filenames:
    # loadmat returns at-least-2-D arrays; ravel() flattens them and also copes
    # with empty score lists.
    intra_class = np.ravel(loadmat('CASIAV3_test_same/%s' % name)['hd1'])
    inter_class = np.ravel(loadmat('CASIAV3_test_diff/%s' % name)['hd2'])
    total = np.concatenate((intra_class, inter_class))
    if total.size == 0:
        continue  # nothing to normalise for this probe
    lo = np.min(total)
    span = np.max(total) - lo
    if span == 0:
        span = 1  # all distances equal: map them to 0 instead of NaN
    # extend() with the arrays keeps their element dtype in the saved output.
    feats1.extend((intra_class - lo) / span)
    feats2.extend((inter_class - lo) / span)
savemat('CASIAV3_test_same.mat', {'feats1': feats1})
savemat('CASIAV3_test_diff.mat', {'feats2': feats2})