In [1]:
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def exists(val):
  """Return True when ``val`` is not None (used to test optional values)."""
  return val is not None


# Alias kept so code that calls the original misspelled name keeps working.
exits = exists

def default(val, d):
  """Return ``val`` unless it is None, in which case fall back to ``d``."""
  if val is None:
    return d
  return val

In [None]:
class Residual(nn.Module):
  """Skip connection: run the wrapped callable ``fn`` and add its input back on."""

  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x, **kwargs):
    # Keyword arguments are forwarded untouched to the wrapped callable.
    branch = self.fn(x, **kwargs)
    return branch + x

In [None]:
class PreNorm(nn.Module):
  """Apply ``nn.LayerNorm`` over the last dimension, then the wrapped callable.

  Args:
    dim: size of the input's last dimension (passed to ``nn.LayerNorm``).
    fn: callable applied to the normalized input; extra keyword arguments given
      to ``forward`` are forwarded to it.
  """

  def __init__(self, dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(dim)
    self.fn = fn

  # Defined at class level (it was previously nested inside __init__, which left
  # the module without a forward method).
  def forward(self, x, **kwargs):
    return self.fn(self.norm(x), **kwargs)

In [None]:
class GEGLU(nn.Module):
  """Gated GELU unit.

  Splits the last dimension into two halves ``(value, gate)`` and returns
  ``value * gelu(gate)``. Expects an even-sized last dimension, so the output's
  last dimension is half the input's.
  """

  def forward(self, x):
    halves = x.chunk(2, dim=-1)
    value, gate = halves
    gated = F.gelu(gate)
    return value * gated

In [None]:
class FeedForward(nn.Module):
  """Feed-forward block: Linear -> GEGLU -> Dropout -> Linear.

  The first projection emits ``2 * dim * mult`` features because GEGLU halves
  the last dimension; the second projection maps ``dim * mult`` back to ``dim``.
  """

  def __init__(self, dim, mult = 4, dropout = 0.):
    super().__init__()
    hidden = dim * mult
    self.net = nn.Sequential(
      nn.Linear(dim, hidden * 2),
      GEGLU(),
      nn.Dropout(dropout),
      nn.Linear(hidden, dim),
    )

  def forward(self, x, **kwargs):
    # Extra keyword arguments are accepted but ignored.
    return self.net(x)

In [None]:
class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Args:
    dim: feature size of the input's last dimension (and of the output).
    heads: number of attention heads.
    dim_head: feature size per head; q/k/v each use ``heads * dim_head`` features.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(
      self,
      dim,
      heads = 8,
      dim_head = 16,
      dropout = 0.
      ):
    super().__init__()
    inner_dim = dim_head * heads
    self.heads = heads
    self.scale = dim_head ** -0.5

    self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
    self.to_out = nn.Linear(inner_dim, dim)

    self.dropout = nn.Dropout(dropout)

  # Defined at class level (it was previously nested inside __init__, so the
  # module had no forward method).
  def forward(self, x):
    """Self-attend over axis 1 of a 3-D input ``x`` of shape (batch, seq, dim)."""
    b, n, _ = x.shape
    h = self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim = -1)
    # (b, n, h*d) -> (b, h, n, d); same layout as einops 'b n (h d) -> b h n d'.
    q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2) for t in (q, k, v))
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
    attn = sim.softmax(dim = -1)
    attn = self.dropout(attn)

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    # (b, h, n, d) -> (b, n, h*d); same layout as einops 'b h n d -> b n (h d)'.
    out = out.transpose(1, 2).reshape(b, n, -1)
    return self.to_out(out)

In [None]:
class MLP(nn.Module):
  """Stack of Linear layers with an activation between consecutive layers.

  Args:
    dims: layer widths ``[in, hidden..., out]``; one Linear per adjacent pair.
    act: activation module inserted after every Linear except the last.
      Defaults to ``nn.ReLU()``. The same instance is reused at every position.
  """

  def __init__(self, dims, act=None):
    super().__init__()
    widths = list(zip(dims[:-1], dims[1:]))
    n_linear = len(widths)
    activation = act if act is not None else nn.ReLU()
    modules = []
    for position, (fan_in, fan_out) in enumerate(widths, start=1):
      modules.append(nn.Linear(fan_in, fan_out))
      if position < n_linear:
        modules.append(activation)
    self.mlp = nn.Sequential(*modules)

  def forward(self, x):
    return self.mlp(x)

In [None]:
class Transformer(nn.Module):
  """Token embedding followed by ``depth`` residual pre-norm (attention, feed-forward) layers.

  Args:
    num_tokens: number of rows in the embedding table.
    dim: embedding / model width.
    depth: number of (attention, feed-forward) layer pairs.
    heads, dim_head: passed to ``Attention``.
    attn_dropout: dropout passed to ``Attention``.
    ff_dropout: dropout passed to ``FeedForward``.
  """

  def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
    super().__init__()
    self.embeds = nn.Embedding(num_tokens, dim)
    self.layers = nn.ModuleList([])

    for _ in range(depth):
      self.layers.append(nn.ModuleList([
        Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
        Residual(PreNorm(dim, FeedForward(dim, dropout = ff_dropout))),
      ]))

  # Defined at class level (it was previously nested inside __init__, so the
  # module had no forward method).
  def forward(self, x):
    """Embed integer token ids ``x`` and pass them through every layer in order."""
    x = self.embeds(x)

    # Each layer is wrapped in Residual, so the skip connection is applied inside.
    for attn, ff in self.layers:
      x = attn(x)
      x = ff(x)
    return x

In [None]:
class TabTransformer(nn.Module):
  """Transformer over categorical columns plus an MLP head over all features.

  Each categorical column is mapped to its own id range in a shared embedding
  table (ranges start after ``num_special_tokens`` reserved ids). The embedded
  tokens are contextualized by a ``Transformer`` and flattened. They are then
  concatenated with the continuous features, which are optionally standardized
  with ``continuous_mean_std`` and then layer-normalized. The result is fed to
  an ``MLP``.

  Args:
    categories: number of unique values for each categorical column.
    num_continuous: number of continuous features (may be 0).
    dim: embedding width of each categorical token.
    depth, heads, dim_head, attn_dropout, ff_dropout: passed to ``Transformer``.
    dim_out: output width of the MLP head.
    mlp_hidden_mults: MLP hidden widths as multiples of the MLP input size.
    mlp_act: activation module for the MLP (MLP's own default when None).
    num_special_tokens: embedding ids reserved before the first category range.
    continuous_mean_std: optional tensor of shape ``(num_continuous, 2)``
      holding per-feature (mean, std) used to standardize ``x_cont``.
  """

  def __init__(self,
               *,
               categories,
               num_continuous,
               dim,
               depth,
               heads,
               dim_head = 16,
               dim_out = 1,
               mlp_hidden_mults = (4, 2),
               mlp_act = None,
               num_special_tokens = 2,
               continuous_mean_std = None,
               attn_dropout = 0.,
               ff_dropout = 0.
               ):
    # Must run before any submodule or buffer is assigned on an nn.Module.
    super().__init__()
    categories = list(categories)
    assert len(categories) > 0, 'at least one categorical column is required'
    assert all(n > 0 for n in categories), 'every category must have at least one unique value'

    self.num_categories = len(categories)
    self.num_unique_categories = sum(categories)
    self.num_special_tokens = num_special_tokens
    self.num_continuous = num_continuous
    total_tokens = self.num_unique_categories + self.num_special_tokens

    # Start id of each column's range: num_special_tokens, then running sums of
    # the category sizes. Registered as a buffer so it moves with .to(device).
    categories_offset = F.pad(torch.tensor(categories), (1, 0), value = num_special_tokens)
    categories_offset = categories_offset.cumsum(dim = -1)[:-1]
    self.register_buffer('categories_offset', categories_offset)

    if continuous_mean_std is not None:
      assert tuple(continuous_mean_std.shape) == (num_continuous, 2), \
        f'continuous_mean_std must have shape ({num_continuous}, 2)'
    self.register_buffer('continuous_mean_std', continuous_mean_std)

    # LayerNorm over the continuous features; skipped entirely when there are none.
    self.norm = nn.LayerNorm(num_continuous) if num_continuous > 0 else None

    self.transformer = Transformer(
      num_tokens = total_tokens,
      dim = dim,
      depth = depth,
      heads = heads,
      dim_head = dim_head,
      attn_dropout = attn_dropout,
      ff_dropout = ff_dropout
    )

    input_size = (dim * self.num_categories) + num_continuous
    l = input_size

    # NOTE(review): this was `1 * t`, almost certainly a typo for `l * t`, since
    # `l` was otherwise unused. The upstream tab-transformer-pytorch reference
    # uses `input_size // 8` as the base width — confirm the intended size.
    hidden_dimensions = [l * t for t in mlp_hidden_mults]
    all_dimensions = [input_size, *hidden_dimensions, dim_out]

    self.mlp = MLP(all_dimensions, act = mlp_act)

  def forward(self, x_categ, x_cont):
    """Run the model on categorical ids and continuous features.

    Args:
      x_categ: integer tensor whose last dimension has ``num_categories``
        columns. Column i is expected to hold values in ``[0, categories[i])``
        so that, after the offset, each column stays inside its own id range.
      x_cont: float tensor whose last dimension has ``num_continuous`` features.
        Ignored when ``num_continuous == 0``.

    Returns:
      Output of the MLP head (last dimension ``dim_out``).
    """
    assert x_categ.shape[-1] == self.num_categories, \
      f'expected {self.num_categories} categorical columns, got {x_categ.shape[-1]}'
    # Out-of-place add so the caller's tensor is not mutated.
    x_categ = x_categ + self.categories_offset
    x = self.transformer(x_categ)
    flat_categ = x.flatten(1)

    if self.num_continuous == 0:
      return self.mlp(flat_categ)

    assert x_cont.shape[-1] == self.num_continuous, \
      f'expected {self.num_continuous} continuous features, got {x_cont.shape[-1]}'
    if self.continuous_mean_std is not None:
      mean, std = self.continuous_mean_std.unbind(dim = -1)
      x_cont = (x_cont - mean) / std

    # Normalization, concatenation and the MLP run whether or not mean/std was
    # provided (previously forward returned None when it was absent).
    normed_cont = self.norm(x_cont)
    x = torch.cat((flat_categ, normed_cont), dim = -1)
    return self.mlp(x)
