In [1]:
from torchvision.datasets import MNIST
from torchvision import transforms

mean, std = 0.1307, 0.3081

train_dataset = MNIST('../data/MNIST', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((mean,), (std,))
                             ]))
test_dataset = MNIST('../data/MNIST', train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((mean,), (std,))
                            ]))
n_classes = 10

In [93]:
train_dataset.train_labels



tensor([5, 0, 4,  ..., 5, 6, 8])

In [3]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable

from trainer import fit
import numpy as np
cuda = torch.cuda.is_available()

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler

import torch.nn as nn
import torch.nn.functional as F

from itertools import combinations

import numpy as np
import torch


VIS_df = pd.read_csv('VIS_features.csv')
NIR_df = pd.read_csv('NIR_features.csv')

In [4]:
VIS_df.drop(['Unnamed: 0'], axis=1, inplace=True)
NIR_df.drop(['Unnamed: 0'], axis=1, inplace=True)

In [5]:
from torch.utils.data import Dataset,DataLoader,Subset
import torch

class CustomDataset(Dataset):        

    def __init__(self, dataset_file, transform=None):
        super().__init__()
        dataset = dataset_file
        self.labels_frame = np.array(pd.DataFrame(dataset['class']), dtype=np.float32).squeeze(1)
        self.features_frame = np.array(dataset.drop(['class'], axis=1), dtype=np.float32)
        self.transform = transform

    def __len__(self):
        return len(self.features_frame)

    def __getitem__(self, idx):
        features = self.features_frame[idx]
        label = self.labels_frame[idx]

        if self.transform:
            features = self.transform(features)

        return features, label


In [6]:
train_size = 0.8
test_size = 1 - train_size
VIS_train_dataset, VIS_test_dataset = train_test_split(VIS_df, test_size=test_size)
NIR_train_dataset, NIR_test_dataset = train_test_split(NIR_df, test_size=test_size)
VIS_train_dataset = CustomDataset(VIS_train_dataset)
VIS_test_dataset = CustomDataset(VIS_test_dataset)
NIR_train_dataset = CustomDataset(NIR_train_dataset)
NIR_test_dataset = CustomDataset(NIR_test_dataset)



# # Create data loaders
# VIS_train_loader = DataLoader(VIS_train_dataset, batch_size=64)
# VIS_test_loader = DataLoader(VIS_test_dataset, batch_size=64)
# NIR_train_loader = DataLoader(NIR_train_dataset, batch_size=64)
# NIR_test_loader = DataLoader(NIR_test_dataset, batch_size=64)



Dataset Sampler

In [7]:
class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
    Returns batches of size n_classes * n_samples
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return self.n_dataset // self.batch_size
    
#print(type(VIS_dataset))
VIS_train_batch_sampler = BalancedBatchSampler(torch.tensor(VIS_train_dataset.labels_frame), n_classes=16, n_samples=4)
VIS_test_batch_sampler = BalancedBatchSampler(torch.tensor(VIS_test_dataset.labels_frame), n_classes=16, n_samples=4)
NIR_train_batch_sampler = BalancedBatchSampler(torch.tensor(NIR_train_dataset.labels_frame), n_classes=16, n_samples=4)
NIR_test_batch_sampler = BalancedBatchSampler(torch.tensor(NIR_test_dataset.labels_frame), n_classes=16, n_samples=4)

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
VIS_train_loader = torch.utils.data.DataLoader(VIS_train_dataset, batch_sampler=VIS_train_batch_sampler, **kwargs)
VIS_test_loader = torch.utils.data.DataLoader(VIS_test_dataset, batch_sampler=VIS_test_batch_sampler, **kwargs)

NIR_train_loader = torch.utils.data.DataLoader(NIR_train_dataset, batch_sampler=NIR_train_batch_sampler, **kwargs)
NIR_test_loader = torch.utils.data.DataLoader(NIR_test_dataset, batch_sampler=NIR_test_batch_sampler, **kwargs)




# NIR_train_batch_sampler = BalancedBatchSampler(NIR_train_dataset.labels_frame, n_classes=16, n_samples=4)
# NIR_test_batch_sampler = BalancedBatchSampler(NIR_test_dataset.labels_frame, n_classes=16, n_samples=4)

Network

In [8]:

class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.fc = nn.Sequential(nn.Linear(512, 256),
                                nn.ReLU(),
                                nn.Linear(256, 128),
                                nn.ReLU(),
                                nn.Linear(128, 64),
                                )

    def forward(self, x):
        output = self.fc(x)
        return output

    def get_embedding(self, x):
        return self.forward(x)

Defining Triplet Loss function

In [9]:
class OnlineTripletLoss(nn.Module):
    """
    Online Triplets loss
    Takes a batch of embeddings and corresponding labels.
    Triplets are generated using triplet_selector object that take embeddings and targets and return indices of
    triplets
    """

    def __init__(self, margin, triplet_selector):
        super(OnlineTripletLoss, self).__init__()
        self.margin = margin
        self.triplet_selector = triplet_selector

    def forward(self, embeddings, target):

        triplets = self.triplet_selector.get_triplets(embeddings, target)

        if embeddings.is_cuda:
            triplets = triplets.cuda()

        ap_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 1]]).pow(2).sum(1)  # .pow(.5)
        an_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 2]]).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(ap_distances - an_distances + self.margin)

        return losses.mean(), len(triplets)

Hard Mining Strategy

In [10]:
def pdist(vectors):
    distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
        dim=1).view(-1, 1)
    return distance_matrix

class TripletSelector:
    def __init__(self):
        pass

    def get_triplets(self, embeddings, labels):
        raise NotImplementedError


class AllTripletSelector(TripletSelector):
    def __init__(self):
        super(AllTripletSelector, self).__init__()

    def get_triplets(self, embeddings, labels):
        labels = labels.cpu().data.numpy()
        triplets = []
        for label in set(labels):
            label_mask = (labels == label)
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 2:
                continue
            negative_indices = np.where(np.logical_not(label_mask))[0]
            anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs

            # Add all negatives for all positive pairs
            temp_triplets = [[anchor_positive[0], anchor_positive[1], neg_ind] for anchor_positive in anchor_positives
                             for neg_ind in negative_indices]
            triplets += temp_triplets

        return torch.LongTensor(np.array(triplets))


def hardest_negative(loss_values):
    hard_negative = np.argmax(loss_values)
    return hard_negative if loss_values[hard_negative] > 0 else None


def random_hard_negative(loss_values):
    hard_negatives = np.where(loss_values > 0)[0]
    return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None


def semihard_negative(loss_values, margin):
    semihard_negatives = np.where(np.logical_and(loss_values < margin, loss_values > 0))[0]
    return np.random.choice(semihard_negatives) if len(semihard_negatives) > 0 else None


class FunctionNegativeTripletSelector(TripletSelector):
    def __init__(self, margin, negative_selection_fn, cpu=True):
        super(FunctionNegativeTripletSelector, self).__init__()
        self.cpu = cpu
        self.margin = margin
        self.negative_selection_fn = negative_selection_fn

    def get_triplets(self, embeddings, labels):
        if self.cpu:
            embeddings = embeddings.cpu()
        distance_matrix = pdist(embeddings)
        distance_matrix = distance_matrix.cpu()

        labels = labels.cpu().data.numpy()
        triplets = []

        for label in set(labels):
            label_mask = (labels == label)
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 2:
                continue
            negative_indices = np.where(np.logical_not(label_mask))[0]
            anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs
            anchor_positives = np.array(anchor_positives)

            ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]]
            for anchor_positive, ap_distance in zip(anchor_positives, ap_distances):
                loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array([anchor_positive[0]])), torch.LongTensor(negative_indices)] + self.margin
                loss_values = loss_values.data.cpu().numpy()
                hard_negative = self.negative_selection_fn(loss_values)
                if hard_negative is not None:
                    hard_negative = negative_indices[hard_negative]
                    triplets.append([anchor_positive[0], anchor_positive[1], hard_negative])

        if len(triplets) == 0:
            triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]])

        triplets = np.array(triplets)

        return torch.LongTensor(triplets)


def HardestNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
                                                                                 negative_selection_fn=hardest_negative,
                                                                                 cpu=cpu)


def RandomNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
                                                                                negative_selection_fn=random_hard_negative,
                                                                                cpu=cpu)


def SemihardNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
                                                                                  negative_selection_fn=lambda x: semihard_negative(x, margin),
                                                                                  cpu=cpu)

Training the model

In [11]:


from metrics import AverageNonzeroTripletsMetric

margin = 1.
embedding_net = EmbeddingNet()
model = embedding_net
if cuda:
    model.cuda()
loss_fn = OnlineTripletLoss(margin,RandomNegativeTripletSelector(margin))
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 20
log_interval = 50

fit(NIR_train_loader, NIR_test_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=[AverageNonzeroTripletsMetric()])




Epoch: 1/20. Train set: Average loss: 0.8928	Average nonzero triplets: 95.6086956521739
Epoch: 1/20. Validation set: Average loss: 0.8422	Average nonzero triplets: 22.4
Epoch: 2/20. Train set: Average loss: 1.0016	Average nonzero triplets: 90.8695652173913
Epoch: 2/20. Validation set: Average loss: 0.7712	Average nonzero triplets: 19.4
Epoch: 3/20. Train set: Average loss: 0.9046	Average nonzero triplets: 87.0
Epoch: 3/20. Validation set: Average loss: 0.7707	Average nonzero triplets: 16.4
Epoch: 4/20. Train set: Average loss: 1.0144	Average nonzero triplets: 83.26086956521739
Epoch: 4/20. Validation set: Average loss: 0.5914	Average nonzero triplets: 19.4
Epoch: 5/20. Train set: Average loss: 0.9990	Average nonzero triplets: 79.30434782608695
Epoch: 5/20. Validation set: Average loss: 0.6605	Average nonzero triplets: 10.8
Epoch: 6/20. Train set: Average loss: 0.9605	Average nonzero triplets: 80.26086956521739
Epoch: 6/20. Validation set: Average loss: 1.0014	Average nonzero triplets: 

In [20]:
class EmbeddingNetmnist(nn.Module):
    def __init__(self):
        super(EmbeddingNetmnist, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 2)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

    def get_embedding(self, x):
        return self.forward(x)

In [21]:

train_batch_sampler = BalancedBatchSampler(train_dataset.train_labels, n_classes=10, n_samples=25)
test_batch_sampler = BalancedBatchSampler(test_dataset.test_labels, n_classes=10, n_samples=25)

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
online_train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler, **kwargs)
online_test_loader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_batch_sampler, **kwargs)

# Set up the network and training parameters

from metrics import AverageNonzeroTripletsMetric

margin = 1.
embedding_net = EmbeddingNetmnist()
model = embedding_net
if cuda:
    model.cuda()
loss_fn = OnlineTripletLoss(margin, SemihardNegativeTripletSelector(margin))
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 20
log_interval = 50



In [23]:
fit(online_train_loader, online_test_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=[AverageNonzeroTripletsMetric()])

Epoch: 1/20. Train set: Average loss: 0.4940	Average nonzero triplets: 429.3389121338912
Epoch: 1/20. Validation set: Average loss: 0.4844	Average nonzero triplets: 200.82051282051282
Epoch: 2/20. Train set: Average loss: 0.5000	Average nonzero triplets: 118.14644351464435
Epoch: 2/20. Validation set: Average loss: 0.4895	Average nonzero triplets: 49.666666666666664
Epoch: 3/20. Train set: Average loss: 0.5012	Average nonzero triplets: 22.556485355648537
Epoch: 3/20. Validation set: Average loss: 0.4663	Average nonzero triplets: 9.538461538461538
Epoch: 4/20. Train set: Average loss: 0.4900	Average nonzero triplets: 6.2887029288702925
Epoch: 4/20. Validation set: Average loss: 290.7225	Average nonzero triplets: 1.794871794871795
Epoch: 5/20. Train set: Average loss: 1026.7666	Average nonzero triplets: 2.1757322175732217
Epoch: 5/20. Validation set: Average loss: 0.3323	Average nonzero triplets: 1.4615384615384615
Epoch: 6/20. Train set: Average loss: 0.3470	Average nonzero triplets: 1.

In [230]:
from sklearn.metrics.pairwise import cosine_similarity
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.ToTensor(),
])
def get_embeddings(model, image_loader):
    model.eval()
    embeddings = []
    with torch.no_grad():
        for images, _ in image_loader:
            images = images.to(device)
            outputs = model(images)
            embeddings.append(outputs.cpu().numpy())
    embeddings = np.vstack(embeddings)
    return embeddings

def calculate_cosine_similarity(embeddings1, embeddings2):
    return cosine_similarity(embeddings1, embeddings2)

# Define a function to check if similarity is within a threshold
def check_similarity(similarity, threshold):
    return similarity >= threshold




test_embeddings = get_embeddings(model, test_loader)

# Take two images from test dataset
image1, label1 = test_dataset[67]
image2, label2 = test_dataset[575]

print(image1.shape)


# Get embeddings for the two images
image1_embedding = model((image1.to(device).unsqueeze(0))).cpu().detach().numpy()
image2_embedding = model((image2.to(device).unsqueeze(0))).cpu().detach().numpy()

# Calculate cosine similarity
similarity = calculate_cosine_similarity(image1_embedding, image2_embedding)

# Define threshold
print(similarity)
print(label1,label2)
# Check if similarity is within threshold
if check_similarity(similarity, threshold=0.1):
    print("Genuine pair")
else:
    print("Fraud pair")

torch.Size([1, 28, 28])
[[1.]]
4 9
Genuine pair
