In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import lightning_trainable.metrics as M

from adaptive_dg.models import InvariantLinear, MultiheadInvariantLinear

In [2]:
torch.set_printoptions(precision=2)

In [3]:
class SetDataset(Dataset):
    def __init__(self, set_data: torch.Tensor, set_labels: torch.Tensor, set_size: int):
        self.set_data = set_data
        self.set_labels = set_labels
        self.set_size = set_size

    def __len__(self):
        return len(self.set_data)

    def __getitem__(self, item):
        data = self.set_data[item]
        label = self.set_labels[item]
        set_data = self.set_data[torch.all(self.set_labels == label, dim=-1)]
        set_data = set_data[torch.randperm(len(set_data))[:self.set_size]]

        return data, set_data, label

In [4]:
set_size = 1024
set_means = [10, 10]
set_stds = [0.5, 1]

data = []
labels = []
for i, (mean, std) in enumerate(zip(set_means, set_stds)):
    set_data = torch.randn(10_000, 1) * std + mean
    set_label = F.one_hot(torch.tensor(i), num_classes=len(set_means)).repeat(10_000, 1)
    data.append(set_data)
    labels.append(set_label)

data = torch.cat(data)
labels = torch.cat(labels)

In [5]:
class SetNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.set_network = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            MultiheadInvariantLinear(num_heads=16, in_shape=(set_size, 32), out_shape=(32, 64)),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            MultiheadInvariantLinear(num_heads=16, in_shape=(32, 256), out_shape=(16, 128)),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            MultiheadInvariantLinear(num_heads=16, in_shape=(16, 32), out_shape=(1, 16)),
        )

        self.network = nn.Sequential(
            nn.Linear(1 + 16, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )

    def forward(self, x: torch.Tensor, xs: torch.Tensor):
        embedding = self.set_network(xs)
        embedding = torch.mean(embedding, dim=1)
        return self.network(torch.cat([x, embedding], dim=1))

In [6]:
network = SetNetwork().cuda()

In [7]:
watch = network.set_network[-1]

In [8]:
dataset = SetDataset(data, labels, set_size=set_size)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

In [9]:
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

In [None]:
weights = []
grads = []
initial_weight = list(watch.parameters())[0].clone().detach()

for epoch in tqdm(range(1000)):
    weights.append(list(watch.parameters())[0].clone().detach())

    for batch in dataloader:
        optimizer.zero_grad()

        batch_data, batch_set_data, batch_labels = batch

        logits = network(batch_data.cuda(), batch_set_data.cuda()).cpu()

        cross_entropy = F.cross_entropy(logits, batch_labels.float())
        accuracy = M.accuracy(logits, batch_labels.float())
        print(f"\rLogits: {logits[0]}, Cross Entropy: {cross_entropy.item():.02f}, Accuracy: {accuracy.item():.02f}", end="")

        cross_entropy.backward()
        optimizer.step()

    grads.append(list(watch.parameters())[0].grad.clone().detach())

  0%|          | 0/1000 [00:00<?, ?it/s]

Logits: tensor([0.06, 0.07], grad_fn=<SelectBackward0>), Cross Entropy: 0.69, Accuracy: 0.5750cy: 0.51

In [None]:
weights = torch.stack(weights)

torch.allclose(weights[-1], initial_weight)

In [None]:
grads = torch.stack(grads)

torch.allclose(grads, torch.zeros_like(grads))

In [None]:
grads.mean()

In [None]:
for i in range(weights.shape[1]):
    for j in range(weights.shape[2]):
        plt.plot(weights[:, i, j].detach().numpy())