# Siamese Networks

Example: Classifying MNIST Images Using A Siamese Network In PyTorch

It's from theb blog of https://becominghuman.ai/siamese-networks-algorithm-applications-and-pytorch-implementation-4ffa3304c18.

But there are several typos in the original blog. Have modifies it to run successfully


In [1]:
import codecs
import errno
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision.datasets.mnist
from torchvision import transforms
from tqdm import tqdm

do_learn = True
save_frequency = 2
batch_size = 16
lr = 0.001
num_epochs = 10
weight_decay = 0.0001


In [2]:
def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)

def read_label_file(path):
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
    return torch.from_numpy(parsed).view(length).long()

def read_image_file(path):
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2051
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    images = []
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)


Lesson learned:
- Assert. The good coder always add the assert statement.

The most important thing is **implementing the siamese network is to build the images pair**.

This class use the `torchvision.dataset.mnist.py` as the reference, **the key** is to rewrite the function of getting the image pair.

In [56]:
# processed_folder = 'processed'
# training_file = 'training.pt'
# root='.'
# data_train_data,data_train_label=torch.load(os.path.join(root,processed_folder,training_file))

In [79]:
# train_labels_class = []
# for i in range(10):
#     indeces=(data_train_label==i).nonzero().squeeze()
#     train_labels_class.append(torch.index_select(data_train_label,0,indeces))
# train_labels_class

In [3]:
class BalancedMNISTPair(torch.utils.data.Dataset):
    """
    Dataset that on each iteration provides two random pairs of
    MNIST images. 
    
    One pair is of the same number (positive sample), oneis of 
    two different numbers (negative sample).
    
    The code uses the torchvision.dataset as the reference.
    """
    urls = [
      'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
      'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
      'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
      'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
   
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
            os.path.join(self.root, self.processed_folder, self.training_file))

            train_labels_class = []
            train_data_class = []
            for i in range(10):
#                 indices = torch.squeeze((self.train_labels == i).nonzero())  # ? maybe it's wrong
                indices=(self.train_labels==i).nonzero().squeeze()
                train_labels_class.append(torch.index_select(self.train_labels, 0, indices))
                train_data_class.append(torch.index_select(self.train_data, 0, indices))

            # generate balanced pairs
            self.train_data = []
            self.train_labels = []
            lengths = [x.shape[0] for x in train_labels_class]  # store the amount of each number image
            for i in range(10):
                for j in range(500): # create 500 pairs
                    rnd_cls = random.randint(0,8) # choose random class that is not the same class
                    if rnd_cls >= i:
                        rnd_cls = rnd_cls + 1
                    rnd_dist = random.randint(0, 100)
                    # why there are three images in the input? (base image, same class image, non-same class image)?
                    self.train_data.append(torch.stack([train_data_class[i][j], train_data_class[i][j+rnd_dist], train_data_class[rnd_cls][j]]))
                    self.train_labels.append([1,0])

            self.train_data = torch.stack(self.train_data)
            self.train_labels = torch.tensor(self.train_labels)

        else:
            self.test_data, self.test_labels = torch.load(
              os.path.join(self.root, self.processed_folder, self.test_file))

            test_labels_class = []
            test_data_class = []
            for i in range(10):
                indices = torch.squeeze((self.test_labels == i).nonzero())
                test_labels_class.append(torch.index_select(self.test_labels, 0, indices))
                test_data_class.append(torch.index_select(self.test_data, 0, indices))

            # generate balanced pairs
            self.test_data = []
            self.test_labels = []
            lengths = [x.shape[0] for x in test_labels_class]
            for i in range(10):
                for j in range(500): # create 500 pairs
                    rnd_cls = random.randint(0,8) # choose random class that is not the same class
                    if rnd_cls >= i:
                        rnd_cls = rnd_cls + 1

                rnd_dist = random.randint(0, 100)

                self.test_data.append(torch.stack([test_data_class[i][j], test_data_class[i][j+rnd_dist], test_data_class[rnd_cls][j]]))
                self.test_labels.append([1,0])

            self.test_data = torch.stack(self.test_data)
            self.test_labels = torch.tensor(self.test_labels)

    def __getitem__(self, index):
        if self.train:
            imgs, target = self.train_data[index], self.train_labels[index]
#             print(len(imgs))  # the output is 3
        else:
            imgs, target = self.test_data[index], self.test_labels[index]

        img_ar = []
        for i in range(len(imgs)):
            img = Image.fromarray(imgs[i].numpy(), mode='L')
            if self.transform is not None:
                img = self.transform(img)
            img_ar.append(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        return img_ar, target
   
    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
      
    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
         os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
   
    def download(self):
        """
        Download the MNIST data if it doesn't exist in processed_folder already.
        """
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # download files
        # In torchvision, this is from util.py
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
        else:
            raise

        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
                os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
         read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
         read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
         read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
         read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


Build the Siamese network

In [4]:
class Net(nn.Module):
    '''
    A simple network is built for this task
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 7)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, 5)
        self.conv3 = nn.Conv2d(128, 256, 5)
        self.linear1 = nn.Linear(2304, 512)

        self.linear2 = nn.Linear(512, 2)

    def forward(self, data):
        res = []
#         print(len(data))
        for i in range(2): # Siamese nets; sharing weights
            x = data[i]
            x = self.conv1(x)
            x = F.relu(x)
            x = self.pool1(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = self.conv3(x)
            x = F.relu(x)

            x = x.view(x.shape[0], -1)
            x = self.linear1(x)
            res.append(F.relu(x))

        res = torch.abs(res[1] - res[0])
        res = self.linear2(res) # calculate the residual
        return res


In [5]:
siamese_net=Net()
siamese_net

Net(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1))
  (linear1): Linear(in_features=2304, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=2, bias=True)
)

Set the functino of training and testing.

In [6]:
def train(model, device, train_loader, epoch, optimizer):
    model.train()
   
    for batch_idx, (data, target) in enumerate(train_loader):
        for i in range(len(data)):   # len(data) == 1
            data[i] = data[i].to(device)
         
        optimizer.zero_grad()
        output_positive = model(data[:2])  # first two images
        output_negative = model(data[0:3:2]) # the first and third image
      
        target = target.type(torch.LongTensor).to(device)
        target_positive = torch.squeeze(target[:,0])
        target_negative = torch.squeeze(target[:,1])
       
        loss_positive = F.cross_entropy(output_positive, target_positive)
        loss_negative = F.cross_entropy(output_negative, target_negative)
      
        loss = loss_positive + loss_negative
        loss.backward()
      
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx*batch_size, len(train_loader.dataset), 100. * batch_idx*batch_size / len(train_loader.dataset),
            loss.item()))

def test(model, device, test_loader):
     model.eval()
   
     with torch.no_grad():
        accurate_labels = 0
        all_labels = 0
        loss = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            for i in range(len(data)):
                data[i] = data[i].to(device)
            
            output_positive = model(data[:2])
            output_negative = model(data[0:3:2])
            
            target = target.type(torch.LongTensor).to(device)
            target_positive = torch.squeeze(target[:,0])
            target_negative = torch.squeeze(target[:,1])
            
            loss_positive = F.cross_entropy(output_positive, target_positive)
            loss_negative = F.cross_entropy(output_negative, target_negative)
            
            loss = loss + loss_positive + loss_negative
            
            accurate_labels_positive = torch.sum(torch.argmax(output_positive, dim=1) == target_positive).cpu()
            accurate_labels_negative = torch.sum(torch.argmax(output_negative, dim=1) == target_negative).cpu()
            
            accurate_labels = accurate_labels + accurate_labels_positive + accurate_labels_negative
            all_labels = all_labels + len(target_positive) + len(target_negative)
      
        accuracy = 100. * accurate_labels / all_labels
        print('Test accuracy: {}/{} ({:.3f}%)\tLoss: {:.6f}'.format(accurate_labels, all_labels, accuracy, loss))


In [7]:
def oneshot(model, device, data):
    model.eval()

    with torch.no_grad():
        for i in range(len(data)):
            data[i] = data[i].to(device)
      
        output = model(data)
        return torch.squeeze(torch.argmax(output, dim=1)).cpu().item()


In [29]:
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
   
    model = Net().to(device)
   
    if do_learn: # training mode
        train_loader = torch.utils.data.DataLoader(BalancedMNISTPair('.', train=True, download=True, transform=trans), batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(BalancedMNISTPair('.', train=False, download=True, transform=trans), batch_size=batch_size, shuffle=False)
      
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        for epoch in range(num_epochs):
            train(model, device, train_loader, epoch, optimizer)
            test(model, device, test_loader)
            if epoch & save_frequency == 0:  
                # & is a bit-wise operator. 4&2 ==0 since binary for 4 is 100 and binary for 2 is 10
                torch.save(model, 'siamese_{:03}.pt'.format(epoch))
                
    else: # prediction
        prediction_loader = torch.utils.data.DataLoader(BalancedMNISTPair('.', train=False, download=True, transform=trans), batch_size=1, shuffle=True)
#         model.load_state_dict(torch.load(load_model_path))
        model=torch.load(load_model_path)
        data = []
#         data.extend(next(iter(prediction_loader))[0][:3:2]) # this line tests different number
        data.extend(next(iter(prediction_loader))[0][:2]) # this line tests the same number 
        same = oneshot(model, device, data)
        if same > 0:
            print('These two images are of the same number')
        else:
            print('These two images are not of the same number')

Something wrong with the original code. 

To save and load model, there are two approaches:
- save `model.state_dict` 

```python
# save
torch.save(model.state_dict(), PATH)

# load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

```

- save entire model

```python
# save entire model
torch.save(model,PATH)

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
```

## To debug

There is something wrong with the `dataloader`. 

What's the return of `dataloader`?

数据集的构造的时候出了问题。取每个数据的时候似乎没有取到三个图片。

In [9]:
main()

Test accuracy: 19/20 (95.000%)	Loss: 0.457100


  "type " + obj.__name__ + ". It won't be checked "


Test accuracy: 19/20 (95.000%)	Loss: 0.204832
Test accuracy: 19/20 (95.000%)	Loss: 0.208629
Test accuracy: 20/20 (100.000%)	Loss: 0.045101
Test accuracy: 20/20 (100.000%)	Loss: 0.059529
Test accuracy: 20/20 (100.000%)	Loss: 0.008774


Test accuracy: 19/20 (95.000%)	Loss: 0.183692
Test accuracy: 19/20 (95.000%)	Loss: 0.172109
Test accuracy: 20/20 (100.000%)	Loss: 0.015457
Test accuracy: 20/20 (100.000%)	Loss: 0.000724


In [28]:
do_learn=False
load_model_path='./siamese_009.pt'
main()

These two images are of the same number


## Reference
1. Siamese Networks: Algorithm, Applications And PyTorch Implementation https://becominghuman.ai/siamese-networks-algorithm-applications-and-pytorch-implementation-4ffa3304c18

