In [1]:
import numpy as np
import os
import struct

def read_idx(filename):
    """Read an IDX file (the format MNIST is distributed in) into a NumPy array.

    IDX layout: two zero bytes, a one-byte element-type code, a one-byte number
    of dimensions, then one big-endian uint32 size per dimension, followed by
    the raw big-endian payload.

    Parameters:
        filename: path to the IDX file.

    Returns:
        np.ndarray with the shape given in the header and the element type
        given by the type code, in native byte order (uint8 for MNIST images
        and labels; for uint8 the result is a read-only view of the file bytes).

    Raises:
        ValueError: if the magic bytes or the type code are invalid, or if the
            payload size does not match the header shape (from ``reshape``).
    """
    # IDX type code -> big-endian NumPy dtype.
    idx_dtypes = {
        0x08: np.dtype('>u1'),
        0x09: np.dtype('>i1'),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0 or data_type not in idx_dtypes:
            raise ValueError(
                f'{filename}: not a valid IDX file (magic={zero:#06x}, type={data_type:#04x})'
            )
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        payload = f.read()
    data = np.frombuffer(payload, dtype=idx_dtypes[data_type])
    # Multi-byte IDX data is big-endian; convert to native order (no copy for uint8).
    data = data.astype(data.dtype.newbyteorder('='), copy=False)
    return data.reshape(shape)

def load_mnist(image_path, label_path):
    """Load an MNIST image IDX file together with its label IDX file.

    Parameters:
        image_path: path to the IDX image file.
        label_path: path to the matching IDX label file.

    Returns:
        (images, labels): the arrays returned by ``read_idx`` for each file.

    Raises:
        ValueError: if the two files contain different numbers of samples,
            which would otherwise silently misalign images and labels.
    """
    images = read_idx(image_path)
    labels = read_idx(label_path)
    if len(images) != len(labels):
        raise ValueError(
            f'{image_path} has {len(images)} images but {label_path} has {len(labels)} labels'
        )
    return images, labels

# Each IDX file sits inside a directory of the same name under ./MNIST.
def _mnist_file(stem):
    """Return the path './MNIST/<stem>/<stem>'."""
    return f'./MNIST/{stem}/{stem}'


train_image_path = _mnist_file('train-images-idx3-ubyte')
train_label_path = _mnist_file('train-labels-idx1-ubyte')
test_image_path = _mnist_file('t10k-images-idx3-ubyte')
test_label_path = _mnist_file('t10k-labels-idx1-ubyte')

In [9]:
# Load the official train/test files, then carve a validation split off the
# training set: everything from index n_train onward becomes validation.
train_images, train_labels = load_mnist(train_image_path, train_label_path)
test_images, test_labels = load_mnist(test_image_path, test_label_path)

n_train = 50000
val_images, val_labels = train_images[n_train:], train_labels[n_train:]
train_images, train_labels = train_images[:n_train], train_labels[:n_train]

for description, array in (('Train images', train_images),
                           ('Train labels', train_labels),
                           ('Test images', test_images),
                           ('Test labels', test_labels)):
    print(f'{description} shape: {array.shape}')

Train images shape: (50000, 28, 28)
Train labels shape: (50000,)
Test images shape: (10000, 28, 28)
Test labels shape: (10000,)


In [11]:
import numpy as np
import os
import struct
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class MNISTDataset(Dataset):
    """Map-style dataset over in-memory image and label arrays.

    Item ``idx`` is ``(image, label)``. The image is passed through
    ``transform`` when a (truthy) transform is set; the label is returned as is.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # The dataset size is defined by the image collection.
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        if not self.transform:
            return sample, self.labels[idx]
        return self.transform(sample), self.labels[idx]
    
# MNIST preprocessing: ToTensor turns each uint8 HxW array into a float tensor
# scaled to [0, 1], then Normalize computes (x - 0.1307) / 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Wrap the train, validation and test arrays with the same preprocessing.
train_dataset, val_dataset, test_dataset = (
    MNISTDataset(images, labels, transform=transform)
    for images, labels in ((train_images, train_labels),
                           (val_images, val_labels),
                           (test_images, test_labels))
)

In [12]:
# Create the MNISTAdditionDataset
class MNISTAdditionDataset(Dataset):
    """Pairs each sample of a digit dataset with a randomly chosen partner sample.

    Item ``idx`` is ``(image, c_label, y_label)`` where:
      * ``image`` combines both digit images: concatenated along the last
        dimension by default, or stacked along a new leading dimension when
        ``stack`` is True;
      * ``c_label`` is a length-20 tensor: a one-hot of the first digit in
        positions 0-9 and a one-hot of the second digit in positions 10-19;
      * ``y_label`` is the sum of the two digits, as a Python int.

    The partner of sample ``idx`` is ``dataset[self.index_dataset2[idx]]``. The
    pairing is drawn once, at construction time.
    """
    # Class-level switch; set on the class or an instance to stack instead of concatenate.
    stack = False

    def __init__(self, dataset, transform=None, seed=None):
        """
        Parameters:
            dataset: map-style dataset whose items are ``(image_tensor, digit_label)``.
                In this notebook the wrapped ``MNISTDataset`` already applies
                ToTensor + Normalize, so nothing more is applied by default.
            transform (callable, optional): extra transform applied to each of
                the two images after they are fetched from ``dataset``.
            seed (int, optional): seed for the partner pairing. ``None`` draws
                from NumPy's global RNG (so ``np.random.seed`` still controls it).
        """
        self.dataset = dataset
        # Only an explicitly requested extra transform is applied. Re-applying the
        # dataset's own ToTensor+Normalize to tensors would fail / double-normalize.
        self.transform = transform
        n = len(dataset)
        if n == 0:
            # randint/integers reject an empty [0, 0) range.
            self.index_dataset2 = np.zeros(0, dtype=np.int64)
        elif seed is None:
            self.index_dataset2 = np.random.randint(0, n, n)
        else:
            self.index_dataset2 = np.random.default_rng(seed).integers(0, n, n)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if self.transform is not None:
            image = self.transform(image)

        image2, label2 = self.dataset[int(self.index_dataset2[idx])]
        if self.transform is not None:
            image2 = self.transform(image2)

        # Plain ints collate to int64 tensors; NumPy uint8 labels would collate to uint8.
        label, label2 = int(label), int(label2)

        if self.stack:
            image = torch.stack((image, image2), dim=0)
        else:
            image = torch.cat((image, image2), dim=-1)

        c_label = torch.zeros(20)
        c_label[label] = 1
        c_label[label2 + 10] = 1

        y_label = label + label2

        return image, c_label, y_label
    
    
# Build one addition dataset per split; each pairs samples only within its own split.
# The splits are constructed in train -> val -> test order.
train_addition_dataset, val_addition_dataset, test_addition_dataset = (
    MNISTAdditionDataset(split)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [2]:
import torchvision
import torch.nn as nn

class ResNet(nn.Module):
    """torchvision ResNet used as a convolutional feature extractor.

    The classification head (``avgpool`` and ``fc``) is removed and, optionally,
    the last residual stages are cropped. ``forward`` then returns the feature
    map of the last kept stage, which has ``self.out_channels`` channels.
    """

    def __init__(self,
                 model_name: str = 'resnet18',
                 pretrained: bool = True,
                 layers_to_freeze: int = -1,
                 layers_to_crop=(4,),
                 ):
        """
        Parameters:
            model_name: name of a torchvision ResNet constructor (case-insensitive),
                e.g. 'resnet18', 'resnet34', 'resnet50'.
            pretrained: load the 'IMAGENET1K_V1' ImageNet weights.
            layers_to_freeze: only used with pretrained weights. -1 freezes nothing,
                0 freezes the stem (conv1 + bn1), and k >= 1 additionally freezes
                layer1..layerk.
            layers_to_crop: iterable of residual stages to remove (3 and/or 4).
                Cropping layer3 requires cropping layer4 too, otherwise layer2's
                output would be fed into layer4.

        Raises:
            ValueError: for a non-ResNet model name or an invalid crop combination.
        """
        super().__init__()
        self.model_name = model_name.lower()
        self.layers_to_freeze = layers_to_freeze
        layers_to_crop = tuple(layers_to_crop)

        constructor = getattr(torchvision.models, self.model_name, None)
        if not self.model_name.startswith('resnet') or not callable(constructor):
            raise ValueError(f'unsupported ResNet variant: {model_name!r}')
        if 3 in layers_to_crop and 4 not in layers_to_crop:
            raise ValueError('cropping layer3 requires cropping layer4 as well')

        # New-style weight names. Some variants (e.g. resnet50) also offer
        # 'IMAGENET1K_V2'; resnet18 only has V1.
        weights = 'IMAGENET1K_V1' if pretrained else None
        self.model = constructor(weights=weights)
        # Only the convolutional trunk is used.
        self.model.avgpool = None
        self.model.fc = None

        if 4 in layers_to_crop:
            self.model.layer4 = None
        if 3 in layers_to_crop:
            self.model.layer3 = None

        if pretrained:
            self._freeze(layers_to_freeze)

        # Last-stage width: 512 for the 18/34 variants, 2048 otherwise.
        # Each cropped stage halves it.
        out_channels = 512 if ('18' in self.model_name or '34' in self.model_name) else 2048
        if self.model.layer4 is None:
            out_channels //= 2
        if self.model.layer3 is None:
            out_channels //= 2
        # With the defaults (resnet18, layer4 cropped) this is 256.
        self.out_channels = out_channels

    def _freeze(self, layers_to_freeze: int) -> None:
        """Disable gradients for the stem and the first ``layers_to_freeze`` stages.

        NOTE: requires_grad_(False) does not stop BatchNorm running-statistics
        updates in train mode; put frozen BN layers in eval() if that matters.
        """
        if layers_to_freeze < 0:
            return
        stages = [self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]
        for module in [self.model.conv1, self.model.bn1, *stages[:layers_to_freeze]]:
            # Cropped stages are None and have nothing to freeze.
            if module is not None:
                module.requires_grad_(False)

    def forward(self, x):
        """Run the stem and every kept residual stage; return the final feature map.

        NOTE(review): torchvision ResNets build conv1 for 3-channel input, so
        single-channel MNIST tensors presumably need expanding (e.g.
        ``x.repeat(1, 3, 1, 1)``) before this call. Confirm at the call site.
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        # Cropped stages were set to None in __init__ and are skipped.
        for stage in (self.model.layer3, self.model.layer4):
            if stage is not None:
                x = stage(x)
        return x

In [3]:
import torch.nn.functional as F

class GeMPool(nn.Module):
    """Generalized-mean (GeM) pooling followed by flattening and L2 normalization.

    Implementation of GeM as in https://github.com/filipradenovic/cnnimageretrieval-pytorch.
    Flatten and norm are added so the module can be used as a single aggregation layer.
    """

    def __init__(self, p=3, eps=1e-6):
        """
        Initializes the GeM pooling layer.

        Parameters:
            p (float, optional): initial value of the learnable GeM exponent. Defaults to 3.
                p = 1 reduces to average pooling; larger p moves the result towards max pooling.
            eps (float, optional): lower bound applied to the activations before the power.
                It keeps x ** p real and finite for zero/negative inputs when p is fractional
                or learned. Defaults to 1e-6.
        """
        super().__init__()
        # A 1-element trainable parameter, so the exponent is learned with the network.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """
        Applies GeM pooling, then flattens and L2-normalizes the result.

        Parameters:
            x (torch.Tensor): feature map whose last two dims are spatial (H, W),
                e.g. (batch, channels, H, W).

        Returns:
            torch.Tensor: for a (batch, channels, H, W) input, a (batch, channels) tensor.
            Each row is the per-channel generalized mean
            (mean over H, W of clamp(x, eps) ** p) ** (1 / p), scaled to unit L2 norm
            along dim 1. Dim 1 is the channel/descriptor dimension, not the batch dimension.
        """
        # Average over a kernel covering the whole spatial extent = global mean of x ** p.
        x = F.avg_pool2d(x.clamp(min=self.eps).pow(self.p),
                         (x.size(-2), x.size(-1))).pow(1. / self.p)
        # (B, C, 1, 1) -> (B, C); the batch dimension is kept.
        x = x.flatten(1)
        return F.normalize(x, p=2, dim=1)


: 

In [None]:
class feature_extract(nn.Module):
    """Backbone followed by an aggregation head: image batch -> descriptors.

    Parameters:
        backbone (nn.Module, optional): feature-map extractor; defaults to
            ``ResNet()`` with its default arguments.
        pool (nn.Module, optional): aggregation head; defaults to ``GeMPool()``.

    The submodules keep the attribute names ``resnet`` and ``gem``, so the
    state_dict keys are the same as before the parameters were added.
    """

    def __init__(self, backbone=None, pool=None):
        super().__init__()
        self.resnet = ResNet() if backbone is None else backbone
        self.gem = GeMPool() if pool is None else pool

    def forward(self, x):
        # NOTE(review): with the default torchvision-based backbone the input
        # presumably must have 3 channels; 1-channel MNIST images need expanding.
        return self.gem(self.resnet(x))
    
    