<a href="https://colab.research.google.com/github/Ronnypetson/titanic/blob/master/MNIST_Maromba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam

import pandas as pd
from sklearn.metrics import accuracy_score
import matplotlib.pylab as plt
import time
from IPython import display
from IPython.core.debugger import Pdb

def breakpoint():
    Pdb().set_trace()

device = "cuda:0" if torch.cuda.is_available() else "cpu"

%matplotlib inline

### Dataset

In [None]:
tr = ToTensor()

def _transform(x):
  return (tr(x) * 2.0 - 1.0).reshape(-1)

bsize = 32

MNIST_train_data = MNIST(
    'MNIST_root/',
    download=True,
    train=True,
    transform=_transform,
)
train_data_loader = torch.utils.data.DataLoader(
    MNIST_train_data,
    batch_size=bsize,
    shuffle=True,
    num_workers=1,
)

MNIST_test_data = MNIST(
    'MNIST_root_test/',
    download=True,
    train=False,
    transform=_transform,
)
test_data_loader = torch.utils.data.DataLoader(
    MNIST_test_data,
    batch_size=bsize,
    shuffle=True,
    num_workers=1,
)

In [3]:
def _posenc(shape, d=32):
  """
  3D Positional encodings (sin(row) + cos(col))
  """
  assert len(shape) == 2
  rows, cols = shape
  idx_sin = np.zeros((rows, d))
  idx_cos = np.zeros((cols, d))
  for idx in range(rows):
    _x = (np.arange(0, d) / d) * (4 * np.pi * (1 + idx / rows))
    idx_sin[idx] = np.sin(_x)
  for idx in range(cols):
    _x = (np.arange(0, d) / d) * (4 * np.pi * (1 + idx / cols))
    idx_cos[idx] = np.cos(_x)
  idx_sin = torch.from_numpy(idx_sin)
  idx_cos = torch.from_numpy(idx_cos)
  idx = (
      idx_sin.reshape((rows, 1, d)).repeat(1, cols, 1)
      + idx_cos.reshape((1, cols, d)).repeat(rows, 1, 1)
  )
  idx = idx.reshape(rows * cols, d) / 2.0
  return idx

In [None]:
rows, cols, d = 5, 5, 10
pos = _posenc((rows, cols), d).reshape(rows, cols, d)
fig, axs = plt.subplots(nrows=rows, ncols=cols, layout=None)
for row in range(rows):
  for col in range(cols):
    axs[row][col].plot(range(d), pos[row, col].numpy())
plt.show()
print(pos[0] @ pos[1].T)

In [None]:
x, y = MNIST_train_data[0]
plt.imshow(np.array(x.reshape(28, 28))), y

### Classe Tensor Maromba

In [14]:
class MTensor:
  def __init__(
      self,
      values: torch.Tensor,
      indices: torch.Tensor,
      indexer: nn.Module=nn.Identity(),
    ):
    assert values.shape == indices.shape[:-1]
    self.data = values
    self.idx = indices
    self.indexer = indexer

  def __getitem__(self, idx):
    return MTensor(self.data[idx], self.idx[idx], self.indexer)

  def __setitem__(self, idx, value):
    self.data[idx] = value.data
    self.idx[idx] = value.idx

  def __delitem__(self, idx):
    del self.data[idx]
    del self.idx[idx]

  def __len__(self):
    return len(self.data)

  @staticmethod
  def cat(mts, dim=0):
    values = [mt.data for mt in mts]
    indices = [mt.idx for mt in mts]
    values = torch.cat(values, dim=dim)
    indices = torch.cat(indices, dim=dim)
    mt = MTensor(values, indices)
    return mt

  @staticmethod
  def unsqueeze(mt, dim=0):
    assert dim != -1
    assert dim < len(mt.idx.shape) - 1
    mt.data = mt.data.unsqueeze(dim)
    mt.idx = mt.idx.unsqueeze(dim)
    return mt

  @staticmethod
  def squeeze(mt, dim=0):
    assert dim != -1
    assert dim < len(mt.idx.shape) - 1
    mt.data = mt.data.squeeze(dim)
    mt.idx = mt.idx.squeeze(dim)
    return mt

  def _gbmd(self, u, v, idxu, idxv) -> torch.Tensor:
    """
    'General Batch Maromba Dot'
    Shorter implementation for the 'batch maromba dot' operation.
    u: M x d_u
    v: N x d_v
    idxu: M x d_u x d_idx
    idxv: N x d_v x d_idx
    """
    m, d_u = u.shape
    n, d_v = v.shape
    d_idx = idxu.shape[-1]
    assert (m, d_u, d_idx) == idxu.shape
    assert (n, d_v, d_idx) == idxv.shape
    # uidxu: M x d_idx
    # vidxv: N x d_idx
    uidxu = torch.bmm(u.reshape(m, 1, d_u), idxu).squeeze(1)
    vidxv = torch.bmm(v.reshape(n, 1, d_v), idxv).squeeze(1)
    dot = uidxu @ vidxv.T
    return dot

  def _xor_idx(self, idxu, idxv):
    """
    idxu: M x d_u x d_idx
    idxv: N x d_v x d_idx
    """
    m, d_u, d_idx = idxu.shape
    n, d_v, _ = idxv.shape
    assert d_idx == idxv.shape[-1]
    # idxu: (M * d_u) x d_idx x 1
    # idxv: (N * d_v) x d_idx x 1
    idxu = idxu.reshape(m * d_u, d_idx, 1)
    idxv = idxv.reshape(n * d_v, d_idx, 1)
    # siiT: M x d_idx x d_idx
    # sjjT: N x d_idx x d_idx
    siiT = torch.bmm(idxu, idxu.permute(0, 2, 1))
    siiT = siiT.reshape(m, d_u, d_idx, d_idx).sum(dim=1)
    sjjT = torch.bmm(idxv, idxv.permute(0, 2, 1))
    sjjT = sjjT.reshape(n, d_v, d_idx, d_idx).sum(dim=1) ###
    # siiT: (M * N) x d_idx x d_idx
    # sjjT: (M * N) x d_idx x d_idx
    siiT = siiT.unsqueeze(1).repeat(1, n, 1, 1).reshape(m * n, d_idx, d_idx)
    sjjT = sjjT.unsqueeze(0).repeat(m, 1, 1, 1).reshape(m * n, d_idx, d_idx)
    # si: (M * N) x d_idx x 1
    # sj: (M * N) x d_idx x 1
    si = idxu.reshape(m, d_u, d_idx).sum(dim=1).unsqueeze(1)
    si = si.repeat(1, n, 1).reshape(m * n, d_idx, 1)
    sj = idxv.reshape(n, d_v, d_idx).sum(dim=1).unsqueeze(0)
    sj = sj.repeat(m, 1, 1).reshape(m * n, d_idx, 1)
    diag_siiT_sjjT = torch.diagonal(torch.bmm(siiT, sjjT), dim1=1, dim2=2)
    diag_siiT_sjjT = diag_siiT_sjjT.unsqueeze(-1)
    xor_idx = torch.bmm(siiT, sj) + torch.bmm(sjjT, si) - 2 * diag_siiT_sjjT
    xor_idx = xor_idx.reshape(m, n, d_idx) / d_u
    return xor_idx

  def __matmul__(self, b):
    """
    Useful for computing m-product between a batch of inputs (N x ...) and a
    parameter matrix (m x n).

    self.data: pre_shape(self) x in_dim(self)
    self.data.idx: pre_shape(self) x in_dim(self) x d_idx
    b.data: pre_shape(b) x in_dim(b)
    b.idx: pre_shape(b) x in_dim(b) x d_idx

    Returns 'mdot'
    mdot.data: pre_shape(self) x pre_shape(b)
    mdot.idx: pre_shape(self) x pre_shape(b) x d_idx
    """
    apre = self.data.shape[:-1]
    bpre = b.data.shape[:-1]
    d_idx = self.idx.shape[-1]
    assert d_idx == b.idx.shape[-1]
    aidx = self.idx.reshape(*((-1,) + self.idx.shape[-2:]))
    bidx = b.idx.reshape(*((-1,) + b.idx.shape[-2:]))
    mdot = self._gbmd(
        self.data.reshape(-1, self.data.shape[-1]),
        b.data.reshape(-1, b.data.shape[-1]),
        aidx,
        bidx
    )
    mdot = mdot.reshape(apre + bpre)
    midx = self._xor_idx(aidx, bidx)
    midx = midx.reshape(apre + bpre + (d_idx,))
    mdot = MTensor(mdot, midx, self.indexer)
    return mdot

  def __mul__(self, b):
    """
    self: N x out_dim_a x in_dim_a (x idx_dim)
    b:    N x out_dim_b x in_dim_b (x idx_dim)
    """
    # ...

In [None]:
# a = MTensor(torch.randn(1, 2), torch.randn(1, 2, 3))
# b = MTensor(torch.randn(1, 2), torch.randn(1, 2, 3))
# c = MTensor.cat([a, b])
# print(a.data)
# print(b.data)
# print(c.data)
# print(a.idx)
# print(b.idx)
# print(c.idx)
# print(a.data)
# print(a.idx)
# idx = torch.tensor([1, 2]).long()
# b = a[:, idx]
# print(b.data)
# print(b.idx)

### Classe do Módulo Treinável

In [17]:
class MModule(nn.Module):
  def __init__(self, n_params=600, idx_dim=32, samples=32, sets=64, device="cpu"):
    super().__init__()
    self.idx_dim = idx_dim
    self.samples = samples
    self.sets = sets
    self.device = device
    self.n_params = n_params
    self.W = nn.Parameter(torch.randn((1, n_params), device=device))
    self.W_idx = nn.Parameter(torch.randn((1, n_params, idx_dim), device=device))
    self.MW = MTensor(self.W, self.W_idx)
    self.activation = nn.ReLU()

  def _msample(self, x: MTensor, n_sets, n_samples):
    """
    x.data: N x in_dim
    x.idx: N x in_dim x idx_dim

    Returns
    x_sets: N x n_sets x n_samples
    """
    n, in_dim, idx_dim = x.idx.shape
    assert x.data.shape == (n, in_dim)
    x_sets = []
    for _ in range(n_sets):
      idx = np.random.choice(in_dim, n_samples, replace=True)
      idx = torch.tensor(idx).long()
      # x_sampled.data: N x 1 x n_samples
      x_sampled = MTensor.unsqueeze(x[:, idx], dim=1)
      x_sets.append(x_sampled)
    # x_sets.data: N x n_sets x n_samples
    x_sets = MTensor.cat(x_sets, dim=1)
    return x_sets

  def _W_step(self, x: MTensor):
    """
    x.data: N x in_dim
    x.idx: N x in_dim x idx_dim
    """
    n, in_dim, idx_dim = x.idx.shape
    assert x.data.shape == (n, in_dim)
    # Put 1 into x
    one = MTensor(
        torch.ones(n, 1).to(self.device),
        torch.zeros(n, 1, idx_dim).to(self.device),
    )
    x = MTensor.cat([x, one], dim=1)
    # Sample W
    W_sets = []
    for _ in range(self.sets):
      idx = np.random.choice(self.n_params, self.samples, replace=True)
      idx = torch.tensor(idx).long()
      W_sets.append(self.MW[:, idx])
    W_sets = MTensor.cat(W_sets, dim=0)
    # mdot: N x sets
    mdot = x @ W_sets
    mdot.data = self.activation(mdot.data)
    return mdot

  def _pool_step(self, x: MTensor):
    """
    x.data: N x in_dim
    x.idx: N x in_dim x idx_dim
    """
    n, in_dim, idx_dim = x.idx.shape
    assert x.data.shape == (n, in_dim)
    # x0: N x samples
    # x1: N x sets x samples
    x0 = self._msample(x, 1, self.samples)
    x0 = MTensor.squeeze(x0, 1)
    x1 = self._msample(x, self.sets, self.samples)
    # mdot.data: N x sets
    mdot = x0 @ x1
    mdot.data = self.activation(mdot.data)
    return mdot

  def forward(self, x: MTensor, n_steps=3):
    """
    x.data: N x in_dim
    x.idx: N x in_dim x idx_dim
    """
    # pool: N x sets
    pool = self._W_step(x)
    for step in range(n_steps):
      print(pool.data.shape)
      # pool: N x ((step + 2) * sets)
      if step % 2 == 0:
        pool_new = self._pool_step(pool)
      else:
        pool_new = self._W_step(pool)
      print(pool_new.data.shape)
      pool = MTensor.cat([pool, pool_new], dim=1)
    return pool

In [18]:
model = MModule(n_params=50, idx_dim=16, samples=16, sets=16, device="cpu")
x = MTensor(torch.randn(10, 15), torch.randn(10, 15, 16))
pool = model.forward(x)
print(pool.data.shape)
print(pool.idx.shape)

torch.Size([10, 16])
torch.Size([10, 10, 16])


RuntimeError: ignored

### Função de Custo

In [None]:
def maromba_loss(y_true, y_pred, true_index, pred_index, debug=False):
  """
  y_true: N x d_out
  y_pred: N x d_out
  true_index: N x d_out x d_index
  pred_index: N x d_out x d_index
  """
  n, d_out = y_true.shape
  assert y_true.shape == y_pred.shape
  assert true_index.shape == pred_index.shape
  # index_match: N x d_out x d_out
  index_match = torch.bmm(pred_index, true_index.permute(0, 2, 1))
  # y_true_match: N x 1 x d_out
  # y_pred_match: N x 1 x d_out
  y_true_match = torch.bmm(y_true.unsqueeze(1), index_match.permute(0, 2, 1))
  y_pred_match = torch.bmm(y_pred.unsqueeze(1), index_match)
  huber = nn.HuberLoss()
  match_loss_lr = huber(y_pred, y_true_match.squeeze(1))
  match_loss_rl = huber(y_true, y_pred_match.squeeze(1))
  loss = match_loss_lr + match_loss_rl
  return loss

### -------