In [79]:
%load_ext autoreload
%autoreload 2
    
import numpy as np
import argparse
import torch
import os
import glob
import random
from evodiff.utils import Tokenizer
import pathlib
from sequence_models.datasets import UniRefDataset
from tqdm import tqdm
from evodiff.plot import aa_reconstruction_parity_plot
import pandas as pd
from evodiff.pretrained import load_sequence_checkpoint
from matplotlib import pyplot as plt
import pkg_resources
from evodiff.utils import Tokenizer


home = str(pathlib.Path.home())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [151]:
tokenizer = Tokenizer()
model = ByteNetLMTime()
collater = SimpleCollater(tokenizer)

[1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128]


In [152]:
data_train = UniRefDataset('data/uniref50/', 'train', structure=False, max_len=1024)
data_valid = UniRefDataset('data/uniref50/', 'test', structure=False, max_len=1024)

D =10
seqs = [data_train[i] for i in range(D)]
data = collater(seqs)

In [153]:
sum(p.numel() for p in model.parameters())


43683359

In [154]:
model(data[0], torch.zeros(D).float(), torch.zeros(data[0].shape).float()).shape

RuntimeError: The size of tensor a (543) must match the size of tensor b (10) at non-singleton dimension 1

In [82]:
def _pad(tokenized, value, dim=2):
    """
    Utility function that pads batches to the same length.

    tokenized: list of tokenized sequences
    value: pad index
    """
    batch_size = len(tokenized)
    max_len = max(len(t) for t in tokenized)
    if dim == 3: # dim = 3 (one hot)
        categories = tokenized[0].shape[-1]
        output = torch.zeros((batch_size, max_len, categories)) + value
        for row, t in enumerate(tokenized):
            output[row, :len(t), :] = t
    elif dim == 2: # dim = 2 (tokenized)
        output = torch.zeros((batch_size, max_len)) + value
        for row, t in enumerate(tokenized):
            output[row, :len(t)] = t
    else:
        print("padding not supported for dim > 3")
    return output

class SimpleCollater(object):
    def __init__(self, tokenizer=Tokenizer()):
        self.tokenizer = tokenizer

    def __call__(self, sequences):
        tokenized = [torch.tensor(self.tokenizer.tokenize(s)) for s in sequences]
        tokenized = _pad(tokenized, self.tokenizer.pad_id)
        masks = tokenized != self.tokenizer.pad_id
        return tokenized.to(torch.long), masks

In [150]:
from sequence_models.layers import PositionFeedForward
from sequence_models.convolutional import ByteNetBlock
from torch import nn
import torch.nn.functional as F

class ByteNetLMTime(nn.Module):
    """Stacked residual blocks from ByteNet paper defined by n_layers

         Shape:
            Input: (N, L,)
            input_mask: (N, L, 1), optional
            Output: (N, L, d)
    """

    def __init__(self, n_tokens=31, d_embedding=128, d_model=1024, n_layer=16,
                 kernel_size=5, r=128, rank=None, n_frozen_embs=None,
                 padding_idx=None, causal=False, dropout=0.1, slim=True, activation='gelu',
                 schedule_conditioning=True):
        """
        :param n_tokens: number of tokens in token dictionary
        :param d_embedding: dimension of embedding
        :param d_model: dimension to use within ByteNet model, //2 every layer
        :param n_layers: number of layers of ByteNet block
        :param kernel_size: the kernel width
        :param r: used to calculate dilation factor
        :padding_idx: location of padding token in ordered alphabet
        :param causal: if True, chooses MaskedCausalConv1d() over MaskedConv1d()
        :param rank: rank of compressed weight matrices
        :param n_frozen_embs: number of frozen embeddings
        :param slim: if True, use half as many dimensions in the NLP as in the CNN
        :param activation: 'relu' or 'gelu'
        :param down_embed: if True, have lower dimension for initial embedding than in CNN layers
        """
        super().__init__()
        self.time_encoding = TimestepEmbedder(d_embedding) # Timestep encoding
        self.time_mod_layer = nn.Linear(d_embedding, d_model)
        if schedule_conditioning:
            self.s_embed_input = TimestepEmbedder(d_model)
            self.s_embed_block = TimestepEmbedder(d_embedding)
        self.embedder = nn.Embedding(n_tokens, d_model, padding_idx=padding_idx)
        log2 = int(np.log2(r)) + 1
        dilations = [2 ** (n % log2) for n in range(n_layer)]
        print(dilations)
        d_h = d_model
        if slim:
            d_h = d_h // 2
        self.layers = nn.ModuleList([
            ByteNetBlock(d_model, d_h, d_model, kernel_size, dilation=d, causal=causal, rank=rank,
                         activation=activation)
            for d in dilations
        ])
        self.c_mod_layers = nn.ModuleList([nn.Linear(d_embedding, 2*d_model) for d in dilations])
        self.dropout = dropout
        self.decoder = PositionFeedForward(d_model, n_tokens)
        self.last_norm = nn.LayerNorm(d_model)

    def forward(self, x, t, S, input_mask=None):
        """
        :param x: (batch, length)
        :param y: (batch)
        :param input_mask: (batch, length, 1)
        :return: (batch, length,)
        """
        x = self.embedder(x)
        c = F.silu(self.time_encoding(t))[:, None, :]

        x = x + self.time_mod_layer(c)
        if S is not None:
            bs, seq_len = S.shape[0], S.shape[1]
            S_out = F.silu(self.s_embed_input(S.reshape(-1))).reshape(bs, seq_len, -1)
            x = x + S_out
            
            # WIP, this is approximately correct but not thoroughly tested
            S_out = F.silu(self.s_embed_block(S.reshape(-1))).reshape(bs, seq_len, -1)
            c = c[ + S_out

        for layer, c_layer in zip(self.layers, self.c_mod_layers):
            x = layer(x, input_mask=input_mask)
            c_mod = c_layer(c)
            modulate_fused(x, *c_mod.chunk(2, dim=-1))
            if self.dropout > 0.0:
                x = F.dropout(x, self.dropout)
        return self.decoder(self.last_norm(x))


In [129]:
import math

# function overload
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
  return x * (1 + scale) + shift

@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
  return modulate(x, shift, scale)

class TimestepEmbedder(nn.Module):
  """
  Embeds scalar timesteps into vector representations.
  """
  def __init__(self, hidden_size, frequency_embedding_size=256):
    super().__init__()
    self.mlp = nn.Sequential(
      nn.Linear(frequency_embedding_size, hidden_size, bias=True),
      nn.SiLU(),
      nn.Linear(hidden_size, hidden_size, bias=True))
    self.frequency_embedding_size = frequency_embedding_size

  @staticmethod
  def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
      - math.log(max_period)
      * torch.arange(start=0, end=half, dtype=torch.float32)
      / half).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
      embedding = torch.cat(
        [embedding,
         torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

  def forward(self, t):
    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
    t_emb = self.mlp(t_freq)
    return t_emb