In [16]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from scipy.optimize import linear_sum_assignment

## FSPOOL: a differentiable, sorting-based pooling method for variable-size sets

We are given a set of n feature vectors X = [x(1), . . . , x(n)] where each x(i) is a column vector of
dimension d placed in some arbitrary order in the columns of X ∈ Rd×n. From this, the goal is to
produce a single feature vector in a way that is invariant to permutation of the columns in the matrix.
We first sort each of the d features across the elements of the set by numerically sorting within the
rows of X to obtain the matrix of sorted features ~X:
~Xi,j = SORT(Xi,:)j (3)
where Xi,: is the ith row of X and SORT(·) sorts a vector in descending order. 

hile this may appear
strange since the columns of ~X no longer correspond to individual elements of the set, there are good
reasons for this. A transformation (such as with an MLP) prior to the pooling can ensure that the
features being sorted are mostly independent so that little information is lost by treating the features
independently. Also, if we were to sort whole elements by one feature, there would be discontinuities
whenever two elements swap order

Then, we apply a learnable weight matrix W ∈ Rd×n to ~X by elementwise multiplying and summing
over the columns (row-wise dot products).
yi =
n∑
j
Wi,j ~Xi,j (4)
y ∈ Rd is the final pooled representation of ~X. 

In [17]:
from torch.nn.parameter import Parameter


class FSPool(nn.Module):
    def __init__(self, in_channels, n_pieces, relaxed=False):
        super().__init__()
        self.n_pieces = n_pieces
        self.weight = nn.Parameter(torch.zeros(in_channels, n_pieces+1))
        self.relaxed = relaxed

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x, n=None):#FSPool
        assert x.size(1) == self.weight.size(0), "incorrect number of input channels in weight"

        if n is None:
            n = torch.full((x.size(0),),x.size(2), dtype=torch.long, device=x.device)

        sizes, mask = fill_sizes(n, x)
        mask = mask.expand_as(x)

        weight = self.determine_weight(sizes)

        x = x + (1-mask).float() * -99999

        if self.relaxed:
            x, perm = cont_sort(x, temp=self.relaxed)
        else:
            x, perm = x.sort(dim=2, descending=True)

        x = (x*weight*mask.float()).sum(dim=2)

        return x, perm
    
    def forward_transpose(self, x, perm, n=None):#FSUnpool
        if n is None:
            n = x.new(x.size(0)).fill_(perm.size(2)).long()

        sizes, mask = fill_sizes(n)
        mask = mask.expand(mask.size(0), x.size(1), mask.size(2))

        weight = self.determine_weight(sizes)

        x = x.unsqueeze(2) * weight * mask.float()

        if self.relaxed:
            x, _ = cont_sort(x, perm)
        else:
            x = x.scatter(2, perm, x)

        return x, mask
    
    def determine_weight(self, sizes):
        """
        Piecewise Linear Function. Evaluates f at ratios in sizes.
        faster as we know most terms are 0
        
        """
        weight = self.weight.unsqueeze(0)
        weight = weight.expand(sizes.size(0), weight.size(1), weight.size(2))

        index = self.n_pieces * sizes
        index = index.unsqueeze(1)
        index = index.expand(index.size(0), weight.size(1), index.size(2))

        idx = index.long()
        frac = index.frac()
        left = weight.gather(2, idx)
        right = weight.gather(2, (idx+1).clamp(max=self.n_pieces))

        return (1-frac)*left + frac*right
    


def fill_sizes(sizes, x=None):
    """
    Each set size n is turned to [0/(n-1), 1/(n-1), .... ,1, 0, 0, ... ,0]
    """

    if x is not None:
        max_size = x.size(2)
    else:
        max_size = sizes.max()

    size_tensor = sizes.new(sizes.size(0), max_size).float().fill_(-1)

    size_tensor = torch.arange(end=max_size, device=sizes.device, dtype=torch.float32)
    size_tensor = size_tensor.unsqueeze(0) / (sizes.float()-1).clamp(min=1).unsqueeze(1)

    mask = size_tensor<=1
    mask = mask.unsqueeze(1)

    return size_tensor.clamp(max=1), mask.float()

def deterministic_sort(s, tau):

    """
    s:input elements to be sorted shape: b x n x 1
    tau: temperature for relaxation: scalar
    """

    n = s.size()[1]

    one = torch.ones((n,1), dtype=torch.float32, device=s.device)
    A_s = torch.abs(s - s.permute(0, 2, 1))
    B = torch.matmul(A_s, torch.matmul(one, one.transpose(0 , 1)))
    scaling = (n + 1 - 2 * (torch.arange(n, device=s.device) + 1)).type(torch.float32)
    C = torch.matmul(s, scaling.unsqueeze(0))
    P_max = (C-B).permute(0, 2, 1)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat

def cont_sort(x, perm=None, temp=1):
    original_size = x.size()

    x = x.view(-1, x.size(2), 1)

    if perm is None:
        perm = deterministic_sort(x, temp)
    else:
        perm = perm.transpose(1,2)

    x = perm.matmul(x)
    x = x.view(original_size)
    return x, perm



In [18]:
class ELementEncoder(nn.Module):
    """
    Encodes each instance into a latent embedding
    """
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x): #(B, N, D)
        return self.net(x) #(B, N, H)
    
class SetEncoder(nn.Module):
    """"
    Encodes complete set into permutation invariant vector using FSPool
    """
    def __init__(self, hidden_dim, n_pieces=4):
        super().__init__()
        self.pool = FSPool(in_channels=hidden_dim, n_pieces=n_pieces)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        pooled, _ = self.pool(x.transpose(1, 2)) #(B, H)
        return self.out(pooled)
    
class SetDecoder(nn.Module):
    """
    Predicts an output set (or refined set)
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, pred_set, set_embed):
        set_embed = set_embed.unsqueeze(1).expand_as(pred_set)
        return pred_set + self.net(torch.cat([pred_set, set_embed], dim=-1))
    

class DSPN_FSPool(nn.Module):
    """
    Full DSPN Pipeline
    """
    def __init__(self, in_dim, hidden_dim, set_size, num_classes, n_pieces=4, num_iters=5):
        super().__init__()
        self.encoder = ELementEncoder(in_dim, hidden_dim)
        self.set_encoder = SetEncoder(hidden_dim, n_pieces)
        self.decoder = SetDecoder(hidden_dim)
        self.num_iters = num_iters
        self.set_size = set_size


        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_set):
        """
        input_set: (B, N, D)
        return: class logits, predicted set
        """

        B, N, D = input_set.shape

        encoded = self.encoder(input_set) #(B, N, H)

        pred_set = torch.zeros(B, self.set_size, encoded.size(-1), device=input_set.device)
        target_embed = self.set_encoder(encoded) #(B, H)

        for _ in range(self.num_iters):
            pred_embed = self.set_encoder(pred_set)
            pred_set = self.decoder(pred_set, target_embed-pred_embed)

        final_embed = self.set_encoder(pred_set)
        logits = self.classifier(final_embed)

        return logits, pred_set
    
    def compute_loss(self, input_set, labels):
        logits, pred_set = self.forward(input_set)
        cls_loss = F.cross_entropy(logits, labels)

        encoded = self.encoder(input_set)
        target_embed = self.set_encoder(encoded)
        pred_embed = self.set_encoder(pred_set)
        recon_loss = F.mse_loss(pred_embed, target_embed)

        return cls_loss + 0.1 * recon_loss, cls_loss, recon_loss

In [24]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./mnist', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./mnist', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)


In [25]:
def image_to_patches(x, patch_size=4):
    B, C, H, W = x.shape
    x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    x = x.contiguous().view(B, C, -1, patch_size, patch_size) #(B, 1, N, p, p)
    x = x.permute(0, 2, 1, 3, 4) #(B, N, C, p, p)
    x = x.reshape(B, x.size(1), -1) #(B, N, p)
    return x

In [26]:
patch_size = 4
set_size = (28 // patch_size) ** 2  # 49
patch_dim = patch_size * patch_size  # 16
hidden_dim = 128
num_classes = 10

model = DSPN_FSPool(in_dim=patch_dim, hidden_dim=hidden_dim, set_size=set_size,
                    num_classes=num_classes, n_pieces=8, num_iters=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [41]:
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(f'using device: {device}')

for epoch in range(25):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    loop = tqdm(train_loader, desc=f"Epoch {epoch}", leave = False)

    for images, labels in loop:
        images = images.to(device)
        labels = labels.to(device)

        # Convert images to patch sets
        x = image_to_patches(images, patch_size=patch_size)

        loss, cls_loss, recon_loss = model.compute_loss(x, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        preds = model.forward(x)[0].argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += images.size(0)

    print(f"Epoch {epoch}: loss={total_loss/total_samples:.4f}, acc={total_correct/total_samples:.4f}")


using device: cuda


                                                            

Epoch 0: loss=0.4895, acc=0.8663


                                                            

Epoch 1: loss=0.4695, acc=0.8719


                                                            

Epoch 2: loss=0.4638, acc=0.8741


                                                            

Epoch 3: loss=0.4437, acc=0.8810


                                                            

Epoch 4: loss=0.4366, acc=0.8834


                                                            

Epoch 5: loss=0.4227, acc=0.8865


                                                            

Epoch 6: loss=0.4195, acc=0.8878


                                                            

Epoch 7: loss=0.4089, acc=0.8917


                                                            

Epoch 8: loss=0.4062, acc=0.8926


                                                            

Epoch 9: loss=0.3902, acc=0.8982


                                                             

Epoch 10: loss=0.3932, acc=0.8964


                                                             

Epoch 11: loss=0.3823, acc=0.9006


                                                             

Epoch 12: loss=0.3800, acc=0.9027


                                                             

Epoch 13: loss=0.3844, acc=0.9018


                                                             

Epoch 14: loss=0.3689, acc=0.9045


                                                             

Epoch 15: loss=0.3675, acc=0.9056


                                                             

Epoch 16: loss=0.3623, acc=0.9071


                                                             

Epoch 17: loss=0.3652, acc=0.9080


                                                             

Epoch 18: loss=0.3463, acc=0.9116


                                                             

Epoch 19: loss=0.3472, acc=0.9123


                                                             

Epoch 20: loss=0.3522, acc=0.9114


                                                             

Epoch 21: loss=0.3414, acc=0.9140


                                                             

Epoch 22: loss=0.3320, acc=0.9150


                                                             

Epoch 23: loss=0.3341, acc=0.9164


                                                             

Epoch 24: loss=0.3311, acc=0.9164


