In [1]:
import torch
import torchvision
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import PIL.ImageOps 

In [2]:
def show_img(img0, title=None):
    """Display a (channels, height, width) image tensor with matplotlib.

    Args:
        img0: image tensor laid out as (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``. It is detached and moved to CPU
            before conversion, so GPU tensors are accepted as well.
        title: optional caption drawn at data coordinates (75, 8) inside a
            semi-transparent white box. Nothing is drawn when it is falsy.
    """
    img = img0.detach().cpu().numpy()
    plt.axis("off")
    if title:
        # The caption text is the `title` argument (previously an undefined `text`).
        plt.text(75, 8, title, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    # matplotlib expects (H, W, C); the tensor is (C, H, W).
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
                                                                     
def plot_loss(iteration, loss):
    """Draw a single loss curve on the current axes and show it.

    Args:
        iteration: x-axis values (the steps at which the loss was logged).
        loss: y-axis values, one per entry of ``iteration``.
    """
    axes = plt.gca()
    axes.plot(iteration, loss)
    plt.show()

# 定义dataset和dataloader

In [3]:
class SiameseNetworkDataset(Dataset):
    """Triplet dataset built on top of an ImageFolder-like object.

    Each item is ``(anchor, positive, negative)``:
    - the anchor is the image at ``index``;
    - the positive is an image with the same class label, a different file
      whenever the class has more than one image;
    - the negative is a uniformly random image with a different class label.

    Every image is loaded and converted to grayscale ('L' mode). It is then
    optionally inverted and passed through ``transform``.

    Args:
        imageFolderDataset: object exposing ``imgs``, a list of
            ``(path, class_index)`` pairs (e.g. torchvision's ``ImageFolder``).
        transform: optional callable applied to each PIL image.
        should_invert: if True, invert pixel intensities before ``transform``.
    """

    def __init__(self, imageFolderDataset, transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        # Group sample indices by class once. Choosing a positive is then a
        # direct pick, instead of rejection sampling over the whole dataset.
        self._indices_by_class = {}
        for idx, (_, label) in enumerate(imageFolderDataset.imgs):
            self._indices_by_class.setdefault(label, []).append(idx)

    def _sample_triplet_indices(self, index):
        """Return ``(anchor, positive, negative)`` indices into ``imgs``.

        Raises:
            ValueError: if the dataset has fewer than two classes, so no
                negative can exist (the search would otherwise never end).
        """
        imgs = self.imageFolderDataset.imgs
        anchor_label = imgs[index][1]
        if len(self._indices_by_class) < 2:
            raise ValueError("SiameseNetworkDataset needs images from at least two classes")
        # Prefer a different image of the same class. Fall back to the anchor
        # itself only when its class contains a single image.
        same_class = [i for i in self._indices_by_class[anchor_label] if i != index]
        positive = random.choice(same_class) if same_class else index
        # Uniform over images of other classes. With at least two classes this
        # rejection loop terminates with probability 1.
        while True:
            negative = random.randrange(len(imgs))
            if imgs[negative][1] != anchor_label:
                break
        return index, positive, negative

    def _load(self, path):
        """Open ``path`` as grayscale, optionally invert it, then apply ``transform``."""
        # Close the underlying file; convert() returns an independent, loaded image.
        with Image.open(path) as source:
            img = source.convert('L')
        if self.should_invert:
            img = PIL.ImageOps.invert(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        """Return the (anchor, positive, negative) triplet whose anchor is ``imgs[index]``."""
        imgs = self.imageFolderDataset.imgs
        anchor, positive, negative = self._sample_triplet_indices(index)
        return (self._load(imgs[anchor][0]),
                self._load(imgs[positive][0]),
                self._load(imgs[negative][0]))

    def __len__(self):
        return len(self.imageFolderDataset.imgs)
    
# Training-set root: an ImageFolder-style tree with one sub-directory per class.
# The default is the original author's local path. Set the PALM_TRAIN_DIR
# environment variable to point elsewhere without editing the notebook.
import os

train_dir = os.environ.get('PALM_TRAIN_DIR', 'D:/数据库/palmdata/iitd')
batch_size = 5

# Seed the RNGs that drive triplet sampling (`random`) and loader shuffling /
# weight init (`torch`), so Restart & Run All draws the same samples.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

folder_dataset = torchvision.datasets.ImageFolder(root=train_dir)

# Resize every grayscale image to 100x100 and convert it to a (1, 100, 100) tensor.
transform = transforms.Compose([transforms.Resize((100, 100)), transforms.ToTensor()])

siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset, transform=transform, should_invert=False)

train_dataloader = DataLoader(siamese_dataset, shuffle=True, batch_size=batch_size)

# 可视化几个训练图片

In [4]:
# Preview one batch of triplets as a grid with six images per row. When the
# batch is full (6 triplets), the rows are anchors (x1), positives (x2), then
# negatives (x3).
visual_dataloader = DataLoader(siamese_dataset, batch_size=6, shuffle=True)
first_batch = next(iter(visual_dataloader))
x1, x2, x3 = first_batch
concatenated = torch.cat([x1, x2, x3], dim=0)
show_img(torchvision.utils.make_grid(concatenated, nrow=6))

AttributeError: 'int' object has no attribute 'read'

# 构建模型

In [None]:
class SiameseNetwork(nn.Module):
    """Shared-weight CNN that embeds anchor, positive and negative inputs.

    The same ``cnn`` + ``fc`` stack is applied to all three inputs. Their
    embeddings therefore live in one space and can be compared with a
    triplet loss.

    Args:
        input_size: spatial size of the inputs, either an int (square) or a
            ``(height, width)`` pair. Each conv block (reflection pad 1 +
            3x3 conv, stride 1) preserves spatial size, so the first linear
            layer takes ``8 * height * width`` features. The default of
            100x100 matches the notebook's ``Resize((100, 100))``.
        embedding_dim: length of each output embedding (default 5).
    """

    def __init__(self, input_size=(100, 100), embedding_dim=5):
        super().__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        def conv_block(in_channels, out_channels):
            # Reflection padding of 1 before a 3x3 conv keeps H x W unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels, kernel_size=3),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
            ]

        # Same layer order as before, so state_dict keys (cnn.0 ... cnn.11) are unchanged.
        self.cnn = nn.Sequential(
            *conv_block(1, 4),
            *conv_block(4, 8),
            *conv_block(8, 8),
        )
        self.fc = nn.Sequential(
            nn.Linear(8 * height * width, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, embedding_dim),
        )

    def output(self, x):
        """Embed one batch shaped (batch, 1, H, W), where (H, W) == input_size.

        Returns a (batch, embedding_dim) tensor.
        """
        features = self.cnn(x)
        # Flatten everything except the batch dimension.
        features = torch.flatten(features, start_dim=1)
        return self.fc(features)

    def forward(self, x1, x2, x3):
        """Embed the three batches with the shared weights; returns three embeddings."""
        return self.output(x1), self.output(x2), self.output(x3)

# 训练模型

In [None]:
# Train the triplet network. Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = SiameseNetwork().to(device)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)  # L2 distance, margin 1.0
optimizer = optim.Adam(net.parameters(), lr=0.0005)

epochs = 5
counter = []       # global step at which each logged loss was recorded
loss_history = []  # loss logged every 10th batch
iteration_number = 0
steps_per_epoch = len(train_dataloader)

net.train()  # explicit: BatchNorm uses batch statistics while training
for epoch in range(epochs):
    for i, (img1, img2, img3) in enumerate(train_dataloader):
        # Anchor / positive (same class) / negative (other class) batches. With
        # the transform defined above, each is expected to be (batch, 1, 100, 100).
        img1, img2, img3 = img1.to(device), img2.to(device), img3.to(device)
        optimizer.zero_grad()
        output1, output2, output3 = net(img1, img2, img3)
        loss_triplet = criterion(output1, output2, output3)
        loss_triplet.backward()
        optimizer.step()
        if i % 10 == 0:
            # Use the true global step, so the x-axis stays correct even when
            # an epoch is not a multiple of 10 batches.
            iteration_number = epoch * steps_per_epoch + i
            counter.append(iteration_number)
            loss_history.append(loss_triplet.item())
    print("Epoch number: {} , Current loss: {:.4f}\n".format(epoch, loss_triplet.item()))

plot_loss(counter, loss_history)

# 测试模型（注意：以下评估仍使用训练集 `train_dataloader`，并非独立的测试集）

In [None]:
# Mean triplet loss over one full pass of the loader. NOTE: this reuses
# train_dataloader, so it measures fit on the training data rather than
# generalization. The triplets are also re-sampled randomly on every pass.
net.eval()  # BatchNorm must use its running statistics, not per-batch stats
device = next(net.parameters()).device
losses = []
with torch.no_grad():  # inference only: no autograd graph needed
    for img1, img2, img3 in train_dataloader:
        img1, img2, img3 = img1.to(device), img2.to(device), img3.to(device)
        output1, output2, output3 = net(img1, img2, img3)
        losses.append(criterion(output1, output2, output3).item())
print(sum(losses) / len(losses) if losses else float('nan'))

In [None]:
# Show one batch of triplets, then print each anchor's embedding distance to
# its positive ("p") and to its negative ("n"). After training, p should
# usually be smaller than n.
test_dataloader = DataLoader(siamese_dataset, shuffle=True, batch_size=6)
x1, x2, x3 = next(iter(test_dataloader))
concatenated = torch.cat((x1, x2, x3), 0)
show_img(torchvision.utils.make_grid(concatenated, nrow=6))

net.eval()  # use BatchNorm running statistics for inference
device = next(net.parameters()).device
with torch.no_grad():
    out1, out2, out3 = net(x1.to(device), x2.to(device), x3.to(device))
    # Per-sample Euclidean distance: the same metric TripletMarginLoss(p=2)
    # optimizes. (sqrt((a - b) ** 2) would only be element-wise |a - b|.)
    oushi1 = F.pairwise_distance(out1, out2, p=2)
    oushi2 = F.pairwise_distance(out1, out3, p=2)

for d1, d2 in zip(oushi1.tolist(), oushi2.tolist()):
    print('p:', d1)
    print('n:', d2)