In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader

# ENCODER

In [2]:
class SSE(nn.Module):
    """Deep-Sets style set encoder: ``rho(pool_i(phi(x_i)))``.

    Each element is embedded independently by ``phi``. The embeddings are
    pooled over the set axis (dim 1), and the pooled vector is mapped by ``rho``.

    Args:
        input_size: feature size of each set element.
        hidden_size: width of the per-element embedding produced by ``phi``.
        output_size: size of the returned set embedding.
        pooling: one of ``'avg'``, ``'max'``, ``'sum'``.

    Raises:
        ValueError: if ``pooling`` is not a supported mode. Previously an
            unknown mode silently skipped pooling, and ``rho`` was applied to
            the per-element features.
    """

    _POOLING_MODES = ('avg', 'max', 'sum')

    def __init__(self, input_size, hidden_size, output_size, pooling='avg'):
        super().__init__()
        if pooling not in self._POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self._POOLING_MODES}, got {pooling!r}"
            )
        self.phi = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_size, output_size),
            nn.ReLU(),
        )
        self.pooling = pooling

    def forward(self, x, mask=None):
        """Encode a batch of sets.

        Args:
            x: padded sets; pooling is over dim 1. For the masked path, x is
                expected to be [B, N, input_size].
            mask: optional [B, N] tensor, True/nonzero on real elements. When
                given, padded elements are excluded from pooling. This matters
                because ``phi`` has biases, so an all-zero padding row does NOT
                embed to zero and would otherwise bias the pooled result.

        Returns:
            The ``rho`` output of the pooled embedding ([B, output_size] for 3-D x).
        """
        h = self.phi(x)
        if mask is None:
            pooled = self._pool(h)
        else:
            pooled = self._masked_pool(h, mask.bool())
        return self.rho(pooled)

    def _pool(self, h):
        """Pool over dim 1 without masking (the original behaviour)."""
        if self.pooling == 'avg':
            return torch.mean(h, dim=1)
        if self.pooling == 'max':
            return torch.max(h, dim=1).values
        return torch.sum(h, dim=1)

    def _masked_pool(self, h, mask):
        """Pool over dim 1 using only the positions where ``mask`` is True."""
        keep = mask.unsqueeze(-1)  # [B, N, 1], broadcasts over the feature axis
        if self.pooling == 'max':
            pooled = h.masked_fill(~keep, float('-inf')).max(dim=1).values
            # A set with no valid elements would pool to -inf; map it to 0 instead.
            return torch.where(keep.any(dim=1), pooled, torch.zeros_like(pooled))
        total = h.masked_fill(~keep, 0.0).sum(dim=1)
        if self.pooling == 'sum':
            return total
        # clamp avoids 0/0 for empty sets (they pool to 0).
        count = keep.sum(dim=1).clamp(min=1).to(h.dtype)
        return total / count


In [3]:
# Toy batch: two sets of 100-d vectors with different cardinalities (32 and 16).
batch = [torch.randn(32, 100), torch.randn(16, 100)]

# Zero-pad every set to the longest one -> X: [2, 32, 100].
X = nn.utils.rnn.pad_sequence(batch, batch_first=True)

# mask[i, j] is True iff row j of X[i] is a real element (not padding).
set_lengths = torch.tensor([len(s) for s in batch])
positions = torch.arange(X.size(1))
mask = positions.unsqueeze(0) < set_lengths.unsqueeze(1)

# pad_sequence already pads with 0.0 by default, so this only re-asserts zero padding.
X = X * mask.unsqueeze(-1)

In [4]:
# Smoke test: encode the padded toy batch. Mean pooling collapses the set axis.
# No mask is passed, so the zero-padded rows of X are pooled along with the real ones.
model = SSE(100, 256, 128, 'avg')
output = model(X)
print(output.size())

torch.Size([2, 128])


# DECODER

In [5]:
class SSD(nn.Module):
    """Set decoder: maps a latent vector to a fixed-size set of elements.

    A two-layer MLP (z_dim -> 2*z_dim -> max_elements*element_dim) whose flat
    output is reshaped into ``max_elements`` rows of size ``element_dim``.

    Args:
        z_dim: size of the latent input.
        element_dim: feature size of each decoded element.
        max_elements: number of elements emitted per latent vector.
    """

    def __init__(self, z_dim, element_dim, max_elements):
        super().__init__()
        widened = 2 * z_dim
        self.max_elements = max_elements
        self.element_dim = element_dim
        # Keep the submodule name "decoder" so state_dict keys are unchanged.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, widened),
            nn.ReLU(),
            nn.Linear(widened, max_elements * element_dim),
        )

    def forward(self, x):
        """Decode x into a tensor of shape [-1, max_elements, element_dim]."""
        flat = self.decoder(x)
        return flat.view(-1, self.max_elements, self.element_dim)

# S2S (SET-TO-SET) MODEL: ENCODER + DECODER

In [6]:
class S2S(nn.Module):
    """Set-to-set model: encode a set with SSE (mean pooling), then decode with SSD.

    Args:
        input_size: feature size of input elements.
        hidden_size: SSE per-element embedding width.
        z_dim: size of the latent set embedding shared by encoder and decoder.
        element_dim: feature size of each decoded element.
        max_elements: number of elements the decoder emits per set.
        output_size: accepted for call-site compatibility but not used by this class.
    """

    def __init__(self, input_size, hidden_size, z_dim, element_dim, max_elements, output_size):
        super().__init__()
        # Submodule names stay "encoder"/"decoder" so state_dict keys are unchanged.
        self.encoder = SSE(input_size, hidden_size, z_dim, pooling='avg')
        self.decoder = SSD(z_dim, element_dim, max_elements)

    def forward(self, x):
        """Map a batch of sets x to the decoder's output ([-1, max_elements, element_dim] per SSD).

        NOTE(review): no padding mask is passed to the encoder, so zero-padded
        rows of x are pooled together with the real elements.
        """
        return self.decoder(self.encoder(x))

# DATASET

In [7]:
class SyntheticSetDataset(Dataset):
    """Synthetic paired-set dataset for set-to-set mapping.

    Each sample is a pair (A, B):
      - A is an [n, elem_dim] set of standard-normal vectors.
      - B is an [m, elem_dim] set whose rows are rows of A plus Gaussian noise
        with std 0.1. Rows are sampled without replacement when m <= n and
        with replacement otherwise.

    n and m are drawn independently and uniformly from [min_len, max_len]. All
    samples are generated once, in __init__.

    Args:
        num_samples: number of (A, B) pairs.
        min_len, max_len: inclusive bounds on set sizes; require 1 <= min_len <= max_len.
        elem_dim: feature size of every element.
        seed: optional int. When given, a private torch.Generator seeded with it
            is used, so the dataset is reproducible without touching the global
            RNG. When None (the default), the global torch RNG is used, as before.

    Raises:
        ValueError: if the length bounds are invalid. min_len < 1 could produce
            an empty A, from which B cannot be sampled; max_len < min_len makes
            torch.randint fail.
    """

    def __init__(self, num_samples=1000, min_len=3, max_len=8, elem_dim=100, seed=None):
        if min_len < 1 or max_len < min_len:
            raise ValueError(
                f"need 1 <= min_len <= max_len, got min_len={min_len}, max_len={max_len}"
            )
        self.num_samples = num_samples
        self.min_len = min_len
        self.max_len = max_len
        self.elem_dim = elem_dim

        # None -> global RNG (original behaviour); otherwise an isolated, seeded generator.
        gen = torch.Generator().manual_seed(seed) if seed is not None else None
        # Pre-generate all sets.
        self.data = [self._make_pair(gen) for _ in range(num_samples)]

    def _make_pair(self, gen):
        """Draw one (A, B) pair; the RNG calls happen in the same order as before."""
        low, high = self.min_len, self.max_len + 1
        # Convert to Python ints so they can be used as sizes in torch factory functions.
        n = int(torch.randint(low, high, (1,), generator=gen).item())
        m = int(torch.randint(low, high, (1,), generator=gen).item())
        A = torch.randn(n, self.elem_dim, generator=gen)  # Modality A
        # B is built from rows of A so the two sets are correlated.
        if m <= n:
            indices = torch.randperm(n, generator=gen)[:m]
        else:
            indices = torch.randint(0, n, (m,), generator=gen)
        B = A[indices] + 0.1 * torch.randn(m, self.elem_dim, generator=gen)
        return A, B

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch, max_elements_decoder=8):
    """Collate (A, B) set pairs into padded float32 tensors plus boolean masks.

    Inputs A are zero-padded to the longest A in the batch. Targets B are
    truncated or zero-padded to exactly ``max_elements_decoder`` rows, so they
    line up with the decoder's fixed-size output.

    Args:
        batch: sequence of (A, B) pairs; A is [n_i, D], B is [m_i, D]
            (assumes every tensor shares the same D).
        max_elements_decoder: number of target rows to keep per sample.

    Returns:
        A_padded: [B, max_n, D]
        B_padded: [B, max_elements_decoder, D]
        A_mask:   [B, max_n] bool, True on real A rows
        B_mask:   [B, max_elements_decoder] bool, True on kept B rows
    """
    inputs, targets = zip(*batch)
    n_items = len(batch)
    out_len = max_elements_decoder
    feat = inputs[0].size(1)

    in_lens = [a.size(0) for a in inputs]
    out_lens = [min(b.size(0), out_len) for b in targets]  # truncate long targets
    in_len = max(in_lens)

    B_padded = torch.zeros(n_items, out_len, feat)
    A_padded = torch.zeros(n_items, in_len, feat)
    for row, (a, b, n_a, n_b) in enumerate(zip(inputs, targets, in_lens, out_lens)):
        A_padded[row, :n_a] = a
        B_padded[row, :n_b] = b[:n_b]

    # Position j is valid for sample i iff j < that sample's (clipped) length.
    A_mask = torch.arange(in_len).unsqueeze(0) < torch.tensor(in_lens).unsqueeze(1)
    B_mask = torch.arange(out_len).unsqueeze(0) < torch.tensor(out_lens).unsqueeze(1)
    return A_padded, B_padded, A_mask, B_mask


# CHAMFER LOSS

In [8]:
def chamfer_loss(pred, target, pred_mask, target_mask):
    """Symmetric masked Chamfer distance between two padded batches of sets.

    For every valid predicted element, take the Euclidean distance to its
    nearest valid target element in the same sample, and vice versa. Each
    direction is averaged over its valid elements across the whole batch, and
    the two averages are summed.

    Args:
        pred: [B, N, D] predicted sets (decoder output).
        target: [B, M, D] padded target sets; M may differ from N.
        pred_mask: [B, N] mask, True/nonzero on real predicted elements.
        target_mask: [B, M] mask, True/nonzero on real target elements.

    Returns:
        Scalar tensor. Elements whose sample has no valid partner in the other
        set are skipped; previously they made the loss inf. If nothing is left
        in a direction, that direction contributes 0.
    """
    # Convert to bool: with integer masks, `~` is a bitwise NOT, and indexing
    # would be integer (not boolean) indexing.
    pred_mask = pred_mask.bool()
    target_mask = target_mask.bool()

    # Pairwise Euclidean distances.
    dist = torch.norm(pred.unsqueeze(2) - target.unsqueeze(1), dim=-1)  # [B, N, M]

    # Invalid pairs must never be chosen as a nearest neighbour.
    valid_pair = pred_mask.unsqueeze(2) & target_mask.unsqueeze(1)      # [B, N, M]
    dist = dist.masked_fill(~valid_pair, float('inf'))

    nearest_for_pred = dist.min(dim=2).values    # [B, N]
    nearest_for_target = dist.min(dim=1).values  # [B, M]

    # A sample with an empty opposite set has only inf distances; leave it out.
    pred_sel = pred_mask & target_mask.any(dim=1, keepdim=True)
    target_sel = target_mask & pred_mask.any(dim=1, keepdim=True)

    return _masked_mean(nearest_for_pred, pred_sel) + _masked_mean(nearest_for_target, target_sel)


def _masked_mean(values, mask):
    """Mean of values[mask]; 0 (still attached to the autograd graph) if mask selects nothing."""
    picked = values[mask]
    return picked.sum() / max(picked.numel(), 1)


# DEFINING THE MODEL

In [9]:
# NOTE: the training cell below redefines these hyperparameters and rebuilds `model`.
input_size = 100
hidden_size = 64
z_dim = 128
element_dim = 100
max_elements = 8  # max output set size
output_size = element_dim  # same as element dim (S2S accepts it but does not use it)

# Fall back to CPU so the notebook also runs on machines without a GPU
# (a bare .cuda() raises there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = S2S(input_size, hidden_size, z_dim, element_dim, max_elements, output_size).to(device)


# TRAINING

In [12]:
# Dataset
dataset = SyntheticSetDataset(num_samples=1000, min_len=3, max_len=8, elem_dim=100)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, 
                        collate_fn=lambda batch: collate_fn(batch, max_elements_decoder=8))

# Model
input_size = 100
hidden_size = 256
z_dim = 256
element_dim = 100
max_elements = 8
output_size = element_dim

model = S2S(input_size, hidden_size, z_dim, element_dim, max_elements, output_size).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
epochs = 10

for epoch in range(epochs):
    total_loss = 0
    for A, B, A_mask, B_mask in dataloader:
        A, B = A.cuda(), B.cuda()
        A_mask, B_mask = A_mask.cuda(), B_mask.cuda()
        A = (A-A.mean(dim=1, keepdim=True))/A.std(dim=1, keepdim=True)
        B = (B-B.mean(dim=1, keepdim=True))/B.std(dim=1, keepdim=True)

        optimizer.zero_grad()
        B_hat = model(A)  # [B, max_elements, element_dim]

        # Prediction mask: all True since decoder always outputs max_elements
        pred_mask = torch.ones(B_hat.size(0), B_hat.size(1), dtype=torch.bool, device=B_hat.device)
        loss = chamfer_loss(B_hat, B, pred_mask=pred_mask, target_mask=B_mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")


Epoch 1/10, Loss: 20.8733
Epoch 2/10, Loss: 20.6866
Epoch 3/10, Loss: 20.6844
Epoch 4/10, Loss: 20.6649
Epoch 5/10, Loss: 20.6808
Epoch 6/10, Loss: 20.6679
Epoch 7/10, Loss: 20.6699
Epoch 8/10, Loss: 20.6899
Epoch 9/10, Loss: 20.6688
Epoch 10/10, Loss: 20.6840


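# RESULT: THIS NAIVE S2S MAPPER BARELY LEARNS — the Chamfer loss drops from 20.87 to ~20.7 after the first epoch and then stays flat.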