In [None]:
import os
import torch
# Export the installed torch version as an environment variable so that the
# shell-level `${TORCH}` expansion in the pip commands below selects the wheel
# index matching this torch build.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# torch-scatter / torch-sparse come from the PyG wheel index for this torch
# version; PyG itself is installed from the default branch of its GitHub repo.
# NOTE(review): none of these installs are version-pinned, so a re-run may pull
# different code — consider pinning.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.0.1+cu118
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m43.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


In [None]:
from itertools import chain
from math import ceil, log2
from copy import deepcopy

import numpy as np
from torch_geometric.nn.aggr import Aggregation, MaxAggregation
import torch
from torch import nn
import torch_geometric.nn as gnn
from torch.utils.data import DataLoader, Dataset

# Seed handed to torch.manual_seed in later cells so runs are reproducible.
SEED = 42

- Implementation of the LCM aggregation using the PyG `Aggregation` interface
- Implementation of the bitwise embedding layer

In [None]:
class LCMAggregation(Aggregation):
    """Learnable commutative-monoid (LCM) aggregation on the PyG ``Aggregation`` API.

    The elements of each group are reduced pairwise, level by level, in a
    balanced binary tree. The binary operator is a GRU cell evaluated in both
    argument orders and averaged, so it is symmetric in its two operands.

    Args:
        in_channels: Feature size of each input element.
        out_channels: Feature size of each aggregated output.

    The GRU cell works entirely in ``out_channels`` space. When
    ``in_channels != out_channels`` the inputs are first mapped by a learnable
    linear layer; when the two sizes are equal no projection is created, so
    the parameter set is exactly that of the GRU cell.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # The GRU cell takes both operands at width `out_channels`; without
        # this projection any in_channels != out_channels input crashes.
        self.lin = (nn.Linear(in_channels, out_channels)
                    if in_channels != out_channels else None)
        # learnable parameter
        self.gru_cell = nn.GRUCell(out_channels, out_channels)


    def reset_parameters(self):
        """Re-initialise every learnable parameter of the aggregation."""
        if self.lin is not None:
            self.lin.reset_parameters()
        self.gru_cell.reset_parameters()


    def _bin_op(self, x, y):
        """Symmetric binary operator: mean of the GRU cell applied both ways."""
        return (self.gru_cell(x, y) + self.gru_cell(y, x)) / 2


    def forward(
        self,
        x,
        index = None,
        ptr = None,
        dim_size = None,
        dim = -2,
        max_num_elements = None,
    ):
        """Reduce the elements of every group to a single feature vector.

        Arguments are forwarded unchanged to ``Aggregation.to_dense_batch``.

        Returns:
            One row of ``out_channels`` features per group.

        Raises:
            ValueError: If the dense batch contains no element slots at all.
        """
        # Project before densifying so that padded slots are not affected by
        # the projection's bias.
        if self.lin is not None:
            x = self.lin(x)

        # The padding mask is discarded: padded slots (presumably zero-filled
        # by to_dense_batch) take part in the reduction like real elements.
        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                   max_num_elements=max_num_elements)

        # (groups, slots, features) -> one (groups, features) tensor per slot.
        level = list(x.permute(1, 0, 2).unbind(0))
        if not level:
            raise ValueError('LCMAggregation received no elements to aggregate')

        # Each pass halves the number of nodes in the tree level; an unpaired
        # trailing node is carried to the next level unchanged.
        while len(level) > 1:
            merged = [
                self._bin_op(level[i], level[i + 1])
                for i in range(0, len(level) - 1, 2)
            ]
            if len(level) % 2 == 1:
                merged.append(level[-1])
            level = merged

        return level[0]


    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')



class Emb(nn.Module):
    """Bitwise embedding: one two-entry lookup table per bit position.

    A bit-vector is embedded by looking up each bit in the table belonging to
    its position and summing the resulting vectors.

    Args:
        num_bits: Number of bit positions (one embedding table each).
        emb_dim: Size of every embedding vector.
    """

    def __init__(self, num_bits, emb_dim):
        super().__init__()
        tables = []
        for _ in range(num_bits):
            tables.append(nn.Embedding(2, emb_dim))
        self.embs = nn.ModuleList(tables)

    def forward(self, bitvecs):
        """Embed a 2-D batch of bit-vectors (rows = items, columns = bits)."""
        per_bit = []
        # Transposing makes iteration go over bit positions, pairing each
        # column with the table for that position.
        for table, column in zip(self.embs, bitvecs.T):
            per_bit.append(table(column))
        stacked = torch.stack(per_bit)
        return stacked.sum(0)

In [None]:
# Smoke test: 14 one-dimensional elements split into four groups of sizes
# 3, 8, 2 and 1.
x = torch.ones(14, 1)
index = torch.tensor([0]*3 + [1]*8 + [2]*2 + [3]*1)

# NOTE(review): the recorded output below equals the group sizes, which a
# GRU-based reduction would not be expected to produce; it presumably comes
# from an earlier version of the aggregation — re-run to confirm.
aggr = LCMAggregation(1, 1)
aggr(x, index)

tensor([[3.],
        [8.],
        [2.],
        [1.]])

verify that LCM agg at least accepts and produces tensors of the correct shapes

In [None]:
def test_lcm_aggregation():
    """Shape and repr check for LCMAggregation on six elements in three groups."""
    features = torch.randn(6, 16)
    group_of = torch.tensor([0, 0, 1, 1, 1, 2])

    aggr = LCMAggregation(16, 32)
    expected_repr = 'LCMAggregation(16, 32)'
    assert str(aggr) == expected_repr

    # One output row per group, at the aggregation's output width.
    result = aggr(features, group_of)
    assert tuple(result.shape) == (3, 32)

test_lcm_aggregation()

RuntimeError: ignored

generate training and validation sets + custom dataloaders

In [None]:
# Re-seed before the datasets are generated; they draw from torch's global RNG.
torch.manual_seed(SEED)


def bitvecs_to_ints(bitvecs):
    """Convert bit-vectors (most significant bit first) to integers.

    Args:
        bitvecs: Tensor whose last dimension indexes the bits of each number.

    Returns:
        Tensor with the last dimension reduced to the encoded integer.
    """
    # Take the width from the input itself rather than from a global, so the
    # helper works for any bit-width and does not depend on definition order.
    width = bitvecs.shape[-1]
    place_values = torch.pow(
        2, torch.arange(width, device=bitvecs.device).flip(0))
    return (bitvecs * place_values.reshape(1, -1)).sum(-1)


# Bit-width of the integers used throughout the notebook.
num_bits = 8

class Random2ndMinimumDataset(Dataset):
    """Random integer multisets labelled with their second-smallest element.

    Every sample is a pair ``(bitvecs, target)``:

    * ``bitvecs`` — integer tensor of shape ``(sz, num_bits)`` with entries in
      {0, 1}; each row encodes one integer, most significant bit first.
    * ``target`` — float copy of the row encoding the second-smallest integer
      (duplicates count separately, so it may equal the minimum).

    Args:
        dataset_sz: Number of samples to generate.
        num_bits: Bit-width of the encoded integers.
        multiset_sz: Either a fixed multiset size, or a ``(low, high)`` tuple
            passed to ``torch.randint`` (``high`` is exclusive). Sizes must be
            at least 2, since the two smallest elements are requested.
    """

    def __init__(self, dataset_sz, num_bits, multiset_sz):
        self.generate_dataset(dataset_sz, num_bits, multiset_sz)

    def generate_dataset(self, dataset_sz, num_bits, multiset_sz):
        """(Re)generate ``self.dataset`` using torch's global RNG."""
        self.dataset = []
        # Place value of every bit position, MSB first. Built from the
        # `num_bits` argument so the class does not depend on a global.
        place_values = torch.pow(2, torch.arange(num_bits).flip(0))

        for _ in range(dataset_sz):
            # randomly sample multiset size
            if isinstance(multiset_sz, tuple):
                # Convert to a plain int before using it as a dimension.
                sz = int(torch.randint(*multiset_sz, (1,)))
            else:
                sz = multiset_sz

            # randomly sample a multiset of integers of size `sz`, encoded as bit-vectors
            bitvecs = torch.randint(0, 2, (sz, num_bits))

            # convert bit-vectors to integers
            ints = (bitvecs * place_values).sum(-1)

            # find the second smallest element in the multiset, which is the
            # target; topk returns the *largest* values unless largest=False
            target_idx = torch.topk(ints, 2, largest=False).indices[-1]
            target = bitvecs[target_idx].float()

            self.dataset.append((bitvecs, target))

    def __getitem__(self, i):
        return self.dataset[i]

    def __len__(self):
        return len(self.dataset)


# Training multisets have a random size: the tuple is unpacked into
# torch.randint(low, high), whose upper bound is exclusive, giving 2..16.
trainset = Random2ndMinimumDataset(2**16, num_bits, (2, 16+1))
# Validation multisets all have 32 elements, larger than any training multiset.
validset = Random2ndMinimumDataset(2**10, num_bits, 32)




# dataloaders

def collate_fn(samples):
    """Merge variable-sized multisets into one flat batch.

    Args:
        samples: Sequence of ``(bitvecs, target)`` pairs.

    Returns:
        ``(x, y, sizes)`` — all bit-vector rows concatenated along dim 0, the
        targets stacked into one tensor, and the row count of each sample.
    """
    multisets, targets, row_counts = [], [], []
    for sample in samples:
        multisets.append(sample[0])
        targets.append(sample[1])
        row_counts.append(sample[0].shape[0])
    return torch.cat(multisets), torch.stack(targets), torch.tensor(row_counts)

# Training batches hold 2**5 = 32 multisets and are shuffled; the whole
# validation set is served as a single batch.
train_dl = DataLoader(trainset, batch_size=2**5, collate_fn=collate_fn, shuffle=True)
valid_dl = DataLoader(validset, batch_size=len(validset), collate_fn=collate_fn)

In [None]:
# Peek at the first six training samples: each is a (bitvecs, target) pair and
# the number of rows in bitvecs varies from sample to sample.
for i, (x, y) in enumerate(iter(trainset)):
  print(x.shape, y.shape)
  if i >= 5:
    break


torch.Size([14, 8]) torch.Size([8])
torch.Size([8, 8]) torch.Size([8])
torch.Size([5, 8]) torch.Size([8])
torch.Size([7, 8]) torch.Size([8])
torch.Size([6, 8]) torch.Size([8])
torch.Size([13, 8]) torch.Size([8])


In [None]:
# Scratch cell: shows step by step how the per-sample sizes are expanded into
# the flat group-index vector that the training loop passes to the network.
dl = DataLoader(trainset, batch_size=3, collate_fn=collate_fn)
x, y, sizes = next(iter(dl))

# print(sizes)
# One list per sample, holding that sample's batch position repeated `sz` times.
a = [ [i] * sz for i, sz in enumerate(sizes) ]

print(a)
print()
# chain(*a) is a lazy iterator, so printing it shows only the object repr.
print(chain(*a))
print()
# Materialising the chain flattens the nested lists into one list ...
print(list(chain(*a)))
print()
# ... which is then turned into the index tensor.
print(torch.tensor(list(chain(*a))))


[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2]]

<itertools.chain object at 0x7f3dae4edc30>

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2])


In [None]:
# Scratch cell: chain(*[a, b]) iterates over the elements of a, then those of b.
# NOTE(review): the recorded output below shows the two lists printed whole,
# which does not match this loop — presumably stale; re-run to confirm.
a = [0,0,0]
b = [1,1]

for x in chain(*[a, b]):
  print(x)


[0, 0, 0]
[1, 1]


In [None]:
def f(*args):
  """Print the tuple of positional arguments exactly as it was received."""
  received = args
  print(received)

# Star-unpacking a tuple at the call site passes its items as separate arguments.
f(*(1,2,3))

(1, 2, 3)


create model, optimizer, and loss function

In [None]:
# Re-seed so the parameter initialisation of the model below is reproducible.
torch.manual_seed(SEED)


# Hidden width shared by the encoder, the aggregation and the decoder.
h = 128

# n = length of list
# (n, 8) -> (n, 128) -> (1, 128) -> (1, 8)

# label (y): [1, 0, 1, 1]
# pred  (x): [.1, .2, .9, .9]
# BCE loss: for each bit, compute ylogx + (1-y)log(1-x), then sum
# -loss = [ 1log(.1) + 0log(.9) ] + [ 0log(.2) + 1log(.8) ] + [ 1log(.9) + 0log(.1) ] + [ 1log(.9) + 0log(.1) ]
#  loss = 2.74





# class Net(nn.Module):
#   def __init__(self):
#     self.enc = nn.Sequential(
#         Emb(num_bits, h),
#         nn.Linear(h, h),
#         nn.Dropout(.5),
#         nn.GELU()) #, nn.Linear(h, h), nn.GELU())

#     self.agg = LCMAggregation(h, h)

#     self.dec = nn.Sequential(
#         nn.Linear(h, h),
#         nn.Dropout(.5),
#         nn.GELU(),
#         nn.Linear(h, num_bits)) #, nn.Sigmoid())

#   def forward(self, x, index):
#     x = self.enc(x)
#     x = self.agg(x, index)
#     x = self.dec(x)
#     return x




# Encoder: per-bit embeddings summed into an h-dim vector, followed by
# Linear -> Dropout -> GELU, applied to every multiset element independently.
enc = nn.Sequential(
    Emb(num_bits, h),
    nn.Linear(h, h),
    nn.Dropout(.5),
    nn.GELU()) #, nn.Linear(h, h), nn.GELU())

# Aggregation: reduces all encoded elements sharing an index to one h-dim vector.
agg = LCMAggregation(h, h)

# Decoder: maps the aggregated vector to num_bits raw logits (no sigmoid here).
dec = nn.Sequential(
    nn.Linear(h, h),
    nn.Dropout(.5),
    nn.GELU(),
    nn.Linear(h, num_bits)) #, nn.Sigmoid())

# Wire the three stages together; only the aggregation stage consumes `index`.
net = gnn.Sequential('x, index', [
    (enc, 'x -> x'),
    (agg, 'x, index -> x'),
    (dec, 'x -> x')
])



# using `BCEWithLogitsLoss` instead of `BCELoss`+Sigmoid for numerical stability
# reduction='none' keeps the per-bit losses; the training loop sums them over
# the bits and averages over the batch itself.
criterion = nn.BCEWithLogitsLoss(reduction='none')

# Adam over every parameter of the assembled network.
opt = torch.optim.Adam(params=net.parameters(), lr=1e-4)



In [None]:
# Re-seed so batch shuffling and dropout are reproducible for this run.
torch.manual_seed(SEED)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

net.to(device)

# Reserved for checkpointing the best model during validation.
best_state_dict = None
best_valid_acc = 0

print('Beginning training.')

for ep in range(1):
    # Make sure dropout is active while training.
    net.train()
    losses = []
    for x, y, sizes in train_dl:
        x = x.to(device)
        y = y.to(device)
        # index[j] is the batch position of the multiset that row j of x
        # belongs to: position k is repeated sizes[k] times.
        index = torch.repeat_interleave(
            torch.arange(len(sizes)), sizes).to(device)

        pred = net(x, index)

        # Per-bit BCE, summed over the bits and averaged over the batch.
        loss = criterion(pred, y).sum(-1).mean()
        losses += [loss.item()]

        # The optimisation step was previously unreachable because of a
        # leftover debugging `break`, so the model was never updated.
        opt.zero_grad()
        loss.backward()
        opt.step()

    print(f'ep {ep+1:04d}: mean train loss={sum(losses) / len(losses):.4f}')
    # losses = np.array(losses)
    # loss_mean, loss_std = losses.mean(), losses.std()

    # accs = []
    # for x, y, sizes in valid_dl:
    #     x = x.to(device)
    #     y = y.to(device)
    #     index = torch.tensor(list(chain(*[
    #         [i] * sz for i, sz in enumerate(sizes)
    #     ]))).to(device)

    #     pred = nn.functional.sigmoid(net(x, index)).round().squeeze()
    #     accuracy = ((pred != y).sum(-1) == 0).sum() / y.shape[0]
    #     accs += [accuracy.item()]
    # mean_acc = sum(accs) / len(accs)
    # if mean_acc > best_valid_acc:
    #     best_state_dict = deepcopy(net.state_dict())
    #     best_valid_acc = mean_acc

    # print(f'ep {ep+1:04d}: loss=[{loss_mean:.4f}, {loss_std:.4f}], acc={mean_acc*100:.2f}%')

Beginning training.
tensor(5.5676, grad_fn=<MeanBackward0>)
5.567600250244141
