Add sampling #7

Open · wants to merge 34 commits into base: main
Commits (34)
99d11e1  add basic sampling code (Jan 13, 2021)
8bb3322  Merge remote-tracking branch 'origin/main' into sample (Jan 13, 2021)
b9fd039  add prediction input / output fns (Jan 13, 2021)
cf76c6c  get sample_autoregressive working (ConnorJL, Jan 14, 2021)
c7ff6c4  truncate text tokens properly (ConnorJL, Jan 14, 2021)
4d51fd9  log model params to tensorboard (ConnorJL, Jan 14, 2021)
346871f  add vae decoding and write to jpeg (kingoflolz, Jan 14, 2021)
d13c330  unshift image outputs at decode time (kingoflolz, Jan 14, 2021)
f8a7449  dirty hack to use vae decoder params when training dalle (kingoflolz, Jan 14, 2021)
ff56d12  Move initialize_vae_weights to after lowering (leogao2, Jan 17, 2021)
4c4e0e0  fix vae checkpoint load in training (ConnorJL, Jan 17, 2021)
2c14bde  fix parameter count logging (ConnorJL, Jan 18, 2021)
130c26e  fix image vocab size (ConnorJL, Jan 18, 2021)
42e7677  add missing dep (lucidrains, Apr 2, 2021)
fcab065  ready tests for rewiring (lucidrains, Apr 3, 2021)
3858108  no 3.9 for tests (lucidrains, Apr 3, 2021)
002343a  make DALLE-mtf work with text and image logits created with separate … (lucidrains, Apr 3, 2021)
037aefb  add axial positional embedding (lucidrains, Apr 3, 2021)
e203de4  cleanup (lucidrains, Apr 3, 2021)
5ecb496  shift by text vocab size (not text seq len) (lucidrains, Apr 3, 2021)
ec5c298  make sure sampling can be forced to never start below a certain minim… (lucidrains, Apr 3, 2021)
3492d80  add unique pad token ids, which obviates the need to remove attention… (lucidrains, Apr 3, 2021)
3b76849  fix text_vocab_dim error (sdtblck, Apr 4, 2021)
ef96537  Update models.py (sdtblck, Apr 4, 2021)
a1634d4  fix bug (lucidrains, Apr 4, 2021)
a3d2115  fix bug with shift (lucidrains, Apr 4, 2021)
5f279ad  Add adam weight decay optimizer (sdtblck, Apr 4, 2021)
8db0d55  fix args.steps_per_checkpoint (sdtblck, Apr 4, 2021)
d3112ca  add variable scope (lucidrains, Apr 4, 2021)
33de203  do an early copy of inputs for labels (lucidrains, Apr 4, 2021)
4dbf727  tweak (lucidrains, Apr 4, 2021)
10dc024  fix initial positions at text_seq_len (lucidrains, Apr 5, 2021)
9dacd35  more cleanup (lucidrains, Apr 5, 2021)
c073fbc  make sure axial positional embedding is correctly shifted by one due … (lucidrains, Apr 6, 2021)
33 changes: 33 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,33 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest -s test.py
10 changes: 5 additions & 5 deletions configs/dalle_coco.json
@@ -7,22 +7,22 @@
},
"train_batch_size": 128,
"eval_batch_size": 128,
"predict_batch_size": 128,
"predict_batch_size": 16,
"steps_per_checkpoint": 5000,
"iterations": 1000,
"train_steps": 100000,
"predict_steps": 0,
"eval_steps": 0,
"n_channels": 3,
"bf_16": false,
"bf_16": true,
"recompute_grad": true,
"lr": 0.0001,
"model_path": "gs://neo-models/dalle_coco/",
"model_path": "gs://neo-models/dalle_coco_sample/",
"mesh_shape": "data:16,model:2",
"layout": "batch_dim:data",
"layout": "batch_dim:data,embed_dim:model",
"n_embd": 1024,
"text_vocab_size": 50258,
"image_vocab_size": 512,
"image_vocab_size": 2048,
"text_seq_len": 256,
"n_layers": 12,
"n_heads": 8,
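Alongside the smaller predict batch, bf_16 enabled, and the larger image vocabulary, the mesh layout now also splits embed_dim across the model axis. A minimal sketch, using mesh-tensorflow's standard string parsers (not code from this PR), of what the new mesh_shape and layout strings resolve to:

import mesh_tensorflow as mtf

# "data:16,model:2" describes a 32-way mesh: a 16-way data axis and a 2-way model axis
mesh_shape = mtf.convert_to_shape("data:16,model:2")
assert mesh_shape.size == 32

# batch_dim is sharded over the data axis; embed_dim (n_embd = 1024) over the model axis
layout_rules = mtf.convert_to_layout_rules("batch_dim:data,embed_dim:model")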
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
tensorflow==2.4.0
tensorflow-datasets
mesh_tensorflow==0.1.18
tpunicorn
lm_dataformat
3 changes: 2 additions & 1 deletion src/dalle_mtf/__init__.py
@@ -1 +1,2 @@
from .models import DALLE, DiscreteVAE
from .models import DALLE, DiscreteVAE
from .sample import sample_autoregressive
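With the new export, callers can pull the sampling entry point from the package root, e.g. (hypothetical caller code, not part of this diff):

from dalle_mtf import DALLE, DiscreteVAE, sample_autoregressive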
198 changes: 144 additions & 54 deletions src/dalle_mtf/models.py
@@ -142,7 +142,7 @@ class DALLE:

def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024,
n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train",
is_incremental_inference=False, context=None, loss_fn=None, params=None, eos_token_id=None,
is_incremental_inference=False, context=None, loss_fn=None, text_loss_weight = 0.15, params=None, padding_id=None,
activation_fn=None):

self.n_embd = n_embd
@@ -154,10 +154,11 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq
self.n_layers = n_layers
self.n_heads = n_heads
self.attn_mask = attn_mask
self.total_tokens = text_vocab_size + image_vocab_size + 1 # extra for EOS
self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id
self.total_tokens = text_vocab_size + text_seq_len + image_vocab_size # (this is the order of the embeddings as well [pad] [text tokens] [text padding tokens] [image tokens])

self.padding_id = 0 if padding_id is None else padding_id
self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd),
"text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size),
"text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size + text_seq_len),
"image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size),
"final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens),
"total_seq_dim": mtf.Dimension("total_seq_dim", self.total_seq_dim),
@@ -174,15 +175,23 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq
if loss_fn is None:
loss_fn = mtf.layers.softmax_cross_entropy_with_logits
self.loss_fn = loss_fn
self.text_loss_weight = text_loss_weight
if activation_fn is None:
activation_fn = mtf.relu
self.activation_fn = activation_fn
if self.is_incremental_inference:
assert self.context is not None, "must have context in incremental inference"
assert self.context['mode'] == 'incremental'
if params is None: # extra params
params = {}
self.params = defaultdict(lambda: None, params)

def shift_image_tokens(self, image_tokens):
return image_tokens + self.text_seq_len + self.dimensions['text_vocab_dim'].size

def unshift_image_tokens(self, image_tokens):
return image_tokens - (self.text_seq_len + self.dimensions['text_vocab_dim'].size)

def embedding(self, x, name):
embd_dim = self.dimensions["embed_dim"]
vocab_dim = self.dimensions["final_vocab_dim"]
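Combined with the numbers in configs/dalle_coco.json, the new vocabulary layout ([pad] [text tokens] [text padding tokens] [image tokens]) works out as below; a toy recomputation in plain Python (illustrative only, not code from the repo), including the round trip through the shift helpers:

text_vocab_size  = 50258   # from configs/dalle_coco.json
text_seq_len     = 256     # one unique pad id is reserved per text position
image_vocab_size = 2048

total_tokens = text_vocab_size + text_seq_len + image_vocab_size   # 52562 rows in the joint embedding
text_vocab_dim_size = text_vocab_size + text_seq_len               # 50514, the size of text_vocab_dim

def shift_image_ids(t):      # mirrors DALLE.shift_image_tokens
    return t + text_seq_len + text_vocab_dim_size

def unshift_image_ids(t):    # mirrors DALLE.unshift_image_tokens
    return t - (text_seq_len + text_vocab_dim_size)

# the two helpers are exact inverses, which is what decode-time unshifting relies on
assert unshift_image_ids(shift_image_ids(7)) == 7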
@@ -200,23 +209,58 @@ def embedding(self, x, name):
x = mtf.dropout(x, rate=embed_dropout, name="wte_dropout")
return x

def positional_embedding(self, x, name):
def axial_positional_embedding(self, mesh, name):
with tf.variable_scope(name):
axial_dim_side = int(sqrt(self.image_seq_len))

embd_dim = self.dimensions["embed_dim"]
axial_dim = mtf.Dimension("axial_dim", self.image_seq_len)

dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_side, axial_dim_side))]

axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
initializer=tf.random_normal_initializer(stddev=0.01),
master_dtype=self.variable_dtype.master_dtype,
slice_dtype=self.variable_dtype.slice_dtype,
activation_dtype=self.variable_dtype.activation_dtype)

axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
initializer=tf.random_normal_initializer(stddev=0.01),
master_dtype=self.variable_dtype.master_dtype,
slice_dtype=self.variable_dtype.slice_dtype,
activation_dtype=self.variable_dtype.activation_dtype)

axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
(axial_wpe_1, axial_wpe_2))
wpe = (axial_wpe_1 + axial_wpe_2) / 2

wpe = mtf.reshape(wpe, [axial_dim, embd_dim])
wpe = pad(wpe, [self.text_seq_len + 1, 0], axial_dim.name)
wpe = mtf.slice(wpe, 0, self.total_seq_dim, axial_dim.name)
wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["embed_seq_dim"])
return wpe


def absolute_positional_embedding(self, mesh, name):
with tf.variable_scope(name):
# Positional embedding
wpe = mtf.get_variable(x.mesh, "wpe",
wpe = mtf.get_variable(mesh, "wpe",
mtf.Shape([self.dimensions["embed_seq_dim"], self.dimensions["embed_dim"]]),
initializer=tf.random_normal_initializer(stddev=0.01),
master_dtype=self.variable_dtype.master_dtype,
slice_dtype=self.variable_dtype.slice_dtype,
activation_dtype=self.variable_dtype.activation_dtype)
position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \
self.is_incremental_inference else (self.context.position - 1)
pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
embed_dropout = self.params.get("embed_dropout", 0)
if embed_dropout > 0 and self.mode == "train":
pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout")
x += pos_emb
return x
return wpe

def apply_positional_embedding(self, x, wpe):
position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \
self.is_incremental_inference else (self.context.position - 1)
pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
embed_dropout = self.params.get("embed_dropout", 0)
if embed_dropout > 0 and self.mode == "train":
pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout")
x += pos_emb
return x

def get_attn_mask(self, mesh, nd, ns):
if not exists(self.attn_mask):
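axial_positional_embedding factorizes the image positional table: rather than learning one [image_seq_len, embed_dim] matrix, it learns a row table and a column table of side sqrt(image_seq_len), broadcasts both over the grid, and averages them. A small numpy sketch of the same construction, using the model defaults (image_seq_len = 1024, n_embd = 1024); this is an illustration, not the mtf code:

import numpy as np

image_seq_len, n_embd = 1024, 1024
side = int(np.sqrt(image_seq_len))                 # 32

wpe_rows = 0.01 * np.random.randn(side, n_embd)    # analogue of axial_wpe_1
wpe_cols = 0.01 * np.random.randn(side, n_embd)    # analogue of axial_wpe_2

# broadcast to [side, side, n_embd], average, flatten back to [image_seq_len, n_embd]
wpe = (wpe_rows[:, None, :] + wpe_cols[None, :, :]) / 2
wpe = wpe.reshape(image_seq_len, n_embd)

# the mtf version then left-pads by text_seq_len + 1 positions and slices to the total
# sequence length, so these rows line up with the image part of the combined sequence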
@@ -227,8 +271,13 @@ def get_attn_mask(self, mesh, nd, ns):
return self.attn_mask

def attention(self, x, n_state, mask, attention_type="global", name="attn"):
# x :: [batch, seq, n_embd]
batch_dim, seq_dim, embd_dim = x_shape = x.shape
if not self.is_incremental_inference:
# x :: [batch, seq, n_embd]
batch_dim, seq_dim, embd_dim = x_shape = x.shape
else:
batch_dim, embd_dim = x_shape = x.shape
seq_dim = self.dimensions['total_seq_dim']

assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"
with tf.variable_scope(name):
# Compute attention inputs
Expand All @@ -254,25 +303,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"):
self.context.record_new_states([k, v])

with tf.variable_scope("attention"):
if attention_type == "local":
# `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
radius = self.params.get("local_attention_radius", 256)
if self.is_incremental_inference:
q *= one_hot
a = mtf_transformer.attention.local_attention_1d(
q, k, v,
length_dim=k.shape[1],
key_dim=self.dimensions["kv_dim"],
value_dim=self.dimensions["kv_dim"],
radius=radius,
length_dim_num_splits=1,
fully_autoregressive=True,
attention_kwargs={},
)
if self.is_incremental_inference:
a = mtf.gather(a, self.context.position - 1, seq_dim)

elif attention_type == "global":
if attention_type == "global":
if exists(mask):
if not self.is_incremental_inference:
broadcasted_mask = mtf.broadcast(mask,
@@ -314,6 +345,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"):
a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")
return a


def mlp(self, x, n_state, name="mlp"):
residual_dropout = self.params.get("residual_dropout", 0)
with tf.variable_scope(name):
@@ -343,12 +375,17 @@ def transformer(self, x, mask):
x = mtf.recompute_grad(block_fn, [x])
else:
x = block_fn(x)
return x
return self.layer_norm(x)

def _loss(self, logits, labels):
def _loss(self, text_logits, image_logits, text_labels, image_labels):
with tf.variable_scope("loss_final"):
loss_batch = self.loss_fn(logits=logits, targets=labels,
vocab_dim=logits.shape[-1], z_loss=0.0)
text_loss_batch = self.loss_fn(logits=text_logits, targets=text_labels,
vocab_dim=text_logits.shape[-1], z_loss=0.0)

image_loss_batch = self.loss_fn(logits=image_logits, targets=image_labels,
vocab_dim=image_logits.shape[-1], z_loss=0.0)

loss_batch = text_loss_batch * self.text_loss_weight + image_loss_batch

with tf.variable_scope("reduce_mean_final"):
loss = mtf.reduce_mean(loss_batch)
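_loss now scores the text and image positions separately and recombines them, down-weighting the text term by text_loss_weight (the constructor default added in this PR is 0.15). A toy recombination with made-up loss values:

text_loss_weight = 0.15                      # new constructor default

text_loss_batch  = 3.2                       # illustrative per-example text cross-entropy
image_loss_batch = 4.7                       # illustrative per-example image cross-entropy

loss_batch = text_loss_batch * text_loss_weight + image_loss_batch   # 0.48 + 4.7 = 5.18
# the final scalar loss is then the mean of loss_batch over the batch, as before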
@@ -388,29 +425,82 @@ def layer_norm(self, x, name="layer_norm", axis=None, epsilon=1e-5):
x = x * g + b
return x

def to_logits(self, x):
with tf.variable_scope("to_logits"):
logits = self.linear(self.layer_norm(x), self.dimensions["final_vocab_dim"], name="linear_out")
def to_image_logits(self, x):
with tf.variable_scope("to_image_logits"):
if not self.is_incremental_inference:
x = mtf.slice(x, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = x.shape[1].name)

image_logits = self.linear(x, self.dimensions["image_vocab_dim"], name="linear_image_out")

# Go to full precision for the logits
return mtf.cast(logits, tf.float32)
image_logits = mtf.cast(image_logits, tf.float32)
return image_logits

def to_text_logits(self, x):
with tf.variable_scope("to_text_logits"):
text_tokens = mtf.slice(x, begin = 0, size = self.text_seq_len, slice_dim_name = x.shape[1].name)
text_logits = self.linear(text_tokens, self.dimensions["text_vocab_dim"], name="linear_text_out")

# Go to full precision for the logits
text_logits = mtf.cast(text_logits, tf.float32)
return text_logits

def forward(self, features, return_loss=True, return_logits=False):
inputs = features["tokens"]
tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding")
mesh = inputs.mesh

# make sure all padding gets turned into unique padding tokens

mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"])
input_range = mtf.range(mesh, self.dimensions['total_seq_dim'], tf.int32)

pad_mask = mtf.logical_and(mtf.less(input_range, self.text_seq_len), mtf.equal(inputs, 0)) # only mask in the positions less than text sequence length, and where the input is 0
pad_token_ids = input_range + self.text_seq_len # shift to the range of pad token ids, which come after text token ids, and before image token ids

inputs = mtf.where(pad_mask, pad_token_ids, inputs)

# save original inputs to be used as labels

orig_inputs = mtf.slice(inputs, begin = 0, size = self.total_seq_dim, slice_dim_name = inputs.shape[1].name)

if self.is_incremental_inference:
# reshape inputs if in inference mode
inputs = mtf.gather(inputs, self.context.position - 1, self.dimensions['total_seq_dim'])
inputs = mtf.reshape(inputs, [self.dimensions['batch_dim']])
else:
# add a <bos> to the inputs, and then remove the last token
inputs = pad(inputs, [1, 0], dim_name = inputs.shape[1].name, pad_value = self.padding_id)
inputs = mtf.slice(inputs, begin = 0, size = self.total_seq_dim, slice_dim_name = inputs.shape[1].name)

# embed text and image tokens jointly and add positional embeds

inputs = self.embedding(inputs, "embedding")

abs_pos_emb = self.absolute_positional_embedding(mesh, "positional_embedding")
axial_pos_emb = self.axial_positional_embedding(mesh, "axial_positional_embedding")

inputs = self.apply_positional_embedding(inputs, abs_pos_emb)
tokens = self.apply_positional_embedding(inputs, axial_pos_emb)

mask = self.get_attn_mask(mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"])
out = self.transformer(tokens, mask=mask)
logits = self.to_logits(out)

image_logits = self.to_image_logits(out)

if not return_loss:
return logits
image_logits = mtf.cast(image_logits, self.variable_dtype.master_dtype)
return image_logits # we only care about image logits, text logits will be used for loss and never used otherwise

text_logits = self.to_text_logits(out)

labels = orig_inputs # a <bos> is prepended to the inputs, so the labels are now the same as the original inputs

text_labels = mtf.slice(labels, begin = 0, size = self.text_seq_len, slice_dim_name = labels.shape[1].name)
image_labels = mtf.slice(labels, begin = self.text_seq_len, size = self.image_seq_len, slice_dim_name = labels.shape[1].name)

loss, loss_batch = self._loss(text_logits, image_logits, text_labels, image_labels)

labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id)
indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1
labels = mtf.gather(labels, indices, dim=labels.shape[1])
labels = mtf.rename_dimension(labels, "range", "total_seq_dim")
loss, loss_batch = self._loss(logits, labels)
if return_logits and return_loss:
# Cast back to checkpoint dtype
logits = mtf.cast(logits, self.variable_dtype.master_dtype)
return loss, loss_batch, logits
image_logits = mtf.cast(image_logits, self.variable_dtype.master_dtype)
return loss, loss_batch, image_logits # we only care about image logits, text logits will be used for loss and never used otherwise
return loss, loss_batch
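Stripped of the mesh-tensorflow plumbing, the training-path bookkeeping in forward is: remap text-position padding to unique pad ids, keep the remapped sequence as the labels, feed a one-token right shift of it (a <bos> in front, last token dropped) through the network, and split the targets at text_seq_len between the text and image heads. A numpy sketch of just that alignment, with toy lengths (not repository code):

import numpy as np

text_seq_len, image_seq_len = 4, 6
tokens = np.arange(10, 10 + text_seq_len + image_seq_len)    # stand-in for the remapped token ids

bos = 0                                                      # padding_id doubles as the <bos> token here
model_inputs = np.concatenate([[bos], tokens[:-1]])          # prepend <bos>, drop the last token
labels       = tokens                                        # the unshifted original sequence

text_labels  = labels[:text_seq_len]                         # scored against to_text_logits
image_labels = labels[text_seq_len:]                         # scored against to_image_logits

assert model_inputs.shape == labels.shape                    # logits and labels stay position-aligned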