
Bug with template mask for batch inference #197

Closed · jproney opened this issue Aug 10, 2022 · 9 comments

@jproney commented Aug 10, 2022

Hello, my name is James, and I'm working on training a new AlphaFold variant using OpenFold. Thanks for the great tool!

I think I may have found a bug in how the code processes templates for batch sizes larger than 1 (either that or I'm doing something wrong, in which case help would also be appreciated!). Here's a code snippet that reproduces the problem:

import torch
import torch.nn as nn
import numpy as np

from openfold.model.model import AlphaFold
from openfold.config import model_config
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.data import data_transforms

model_name = "model_1_ptm"

conf = model_config(model_name, train=True)
conf.data.common.max_recycling_iters = 0
conf.data.train.subsample_templates = False
conf.data.train.max_msa_clusters = 1
conf.data.train.max_extra_msa = 1
conf.data.train.max_templates = 1

# copied from openfold/test/data_utils.py
def random_template_feats(n_templ, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
        "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
        "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
        "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
        "template_all_atom_mask": np.random.randint(
            0, 2, (*b, n_templ, n, 37)
        ),
        "template_all_atom_positions": 
            np.random.rand(*b, n_templ, n, 37, 3) * 10,
        "template_torsion_angles_sin_cos": 
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_alt_torsion_angles_sin_cos": 
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_torsion_angles_mask": 
            np.random.rand(*b, n_templ, n, 7),
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch


def random_extra_msa_feats(n_extra, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
            np.int64
        ),
        "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
        "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
            np.float32
        ),
        "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
    }
    return batch

n_templ = 1
n_res = 256
n_extra_seq = 1
n_seq = 1
bsize = 2

model = AlphaFold(conf).cuda()

batch = {}


tf = torch.randint(conf.model.input_embedder.tf_dim - 1, size=(bsize, n_res))
batch["target_feat"] = nn.functional.one_hot(tf, conf.model.input_embedder.tf_dim).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)

batch["target_feat"] = torch.rand((bsize, n_res, conf.model.input_embedder.tf_dim))
batch["residue_index"] = torch.rand((bsize, n_res))
batch["msa_feat"] = torch.rand((bsize, n_seq, n_res, conf.model.input_embedder.msa_dim))


t_feats = random_template_feats(n_templ, n_res, batch_size=bsize)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})

extra_feats = random_extra_msa_feats(n_extra_seq, n_res, batch_size=bsize)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})

batch["msa_mask"] = torch.randint(low=0, high=2, size=(bsize, n_seq, n_res)).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(bsize, n_res)).float()
batch.update(data_transforms.make_atom14_masks(batch))

batch["no_recycling_iters"] = torch.tensor(0.)

batch = tensor_tree_map(lambda t: t.unsqueeze(-1).cuda(), batch)

out = model(batch)

In this code I'm basically just running inference on the model with a batch size of 2, with templates enabled. For this demo I've created dummy inputs using the code from the /openfold/tests/ directory, although I've also had the same problem with a real data pipeline.

The code above crashes with the error RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 3, which occurs on line 189 of /openfold/openfold/model/model.py:

t = t * (torch.sum(batch["template_mask"], dim=-1) > 0) 

This line masks out activations from templates that, according to batch["template_mask"], don't exist. However, there seems to be a dimension mismatch: if I print the shapes, t is [2, 256, 256, 128] and batch["template_mask"] is [2]. Under the PyTorch broadcasting rules (https://pytorch.org/docs/stable/notes/broadcasting.html), those shapes aren't broadcast-compatible for multiplication. If I change the code to the following:

t = t * (torch.sum(batch["template_mask"], dim=-1) > 0).view([-1,1,1,1])

Then everything works fine. Is this a real bug in the code, or have I done something wrong to trigger this error? Thanks! For reference, my environment is the following:

  • Python 3.10.4
  • PyTorch 1.12.1
  • Numpy 1.23.1
  • Cuda 11.1
  • Latest OpenFold commit (6e930a6ca4accb14aa128ae40bd3f27906796589)
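
To make the shape mismatch concrete, here's a minimal standalone sketch of the broadcasting behavior (toy shapes copied from my printout above; this is just an illustration, not OpenFold code):

import torch

# Toy shapes from the failing line: t is [batch, N_res, N_res, channels],
# template_mask is [batch, n_templ]
t = torch.rand(2, 256, 256, 128)
template_mask = torch.ones(2, 1)

mask = torch.sum(template_mask, dim=-1) > 0  # shape [2]

try:
    t * mask  # trailing dims are 128 vs 2, so broadcasting fails
except RuntimeError as e:
    print(e)  # "The size of tensor a (128) must match the size of tensor b (2) ..."

# Reshaping the mask to [batch, 1, 1, 1] lets it broadcast over the last three dims
t = t * mask.view([-1, 1, 1, 1])
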
@gahdritz (Collaborator)

Yes, this is a bug. I'll fix this later today.

@jproney (Author) commented Aug 11, 2022

Great, thanks! As a heads up, I think there's also a related bug in the code path where inplace_safe=True, which triggers an error on line 156 (t_pair[..., i, :, :, :] = t) with the message:

RuntimeError: expand(torch.cuda.FloatTensor{[2, 1, 256, 256, 64]}, size=[2, 256, 256, 64]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (5)

You can reproduce this error by using the above snippet and replacing the last line with:

model.eval()
with torch.no_grad():
    out = model(batch)

Thanks again! Really appreciate your quick response.

@gofreelee

> Yes, this is a bug. I'll fix this later today.

Hi, may I ask when this bug is expected to be fixed?

@gahdritz (Collaborator) commented Aug 17, 2022

Sorry for the delay here. I just pushed e56b597, which fixes the template masking bug. I was unable to reproduce the second issue. Perhaps a product of your own changes? Could you test the newest version and verify that it works?

@jproney (Author) commented Aug 18, 2022

Hi Gustaf. Thanks so much for the change! The first error is gone, and the model works great when inplace_safe=False. However, I can still reproduce the second error, even after rolling back all changes I've made. I think I can create a workaround pretty easily, but in case you want to investigate the issue further, I can provide some more details. Here's a snippet I can run to produce the error:

import torch
import torch.nn as nn
import numpy as np

from openfold.model.model import AlphaFold
from openfold.config import model_config
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.data import data_transforms

model_name = "model_1_ptm"

conf = model_config(model_name, train=True)
conf.data.common.max_recycling_iters = 0
conf.data.train.subsample_templates = False
conf.data.train.max_msa_clusters = 1
conf.data.train.max_extra_msa = 1
conf.data.train.max_templates = 2

# copied from openfold/test/data_utils.py
def random_template_feats(n_templ, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
        "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
        "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
        "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
        "template_all_atom_mask": np.random.randint(
            0, 2, (*b, n_templ, n, 37)
        ),
        "template_all_atom_positions": 
            np.random.rand(*b, n_templ, n, 37, 3) * 10,
        "template_torsion_angles_sin_cos": 
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_alt_torsion_angles_sin_cos": 
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_torsion_angles_mask": 
            np.random.rand(*b, n_templ, n, 7),
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch


def random_extra_msa_feats(n_extra, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
            np.int64
        ),
        "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
        "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
            np.float32
        ),
        "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
    }
    return batch

n_templ = 2
n_res = 256
n_extra_seq = 1
n_seq = 1
bsize = 2

model = AlphaFold(conf).cuda()

batch = {}


tf = torch.randint(conf.model.input_embedder.tf_dim - 1, size=(bsize, n_res))
batch["target_feat"] = nn.functional.one_hot(tf, conf.model.input_embedder.tf_dim).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)

batch["residue_index"] = torch.rand((bsize, n_res))
batch["msa_feat"] = torch.rand((bsize, n_seq, n_res, conf.model.input_embedder.msa_dim))


t_feats = random_template_feats(n_templ, n_res, batch_size=bsize)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})

extra_feats = random_extra_msa_feats(n_extra_seq, n_res, batch_size=bsize)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})

batch["msa_mask"] = torch.randint(low=0, high=2, size=(bsize, n_seq, n_res)).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(bsize, n_res)).float()
batch.update(data_transforms.make_atom14_masks(batch))

batch["no_recycling_iters"] = torch.tensor(0.)

batch = tensor_tree_map(lambda t: t.unsqueeze(-1).cuda(), batch)

model.eval()
with torch.no_grad():
    out = model(batch)

The error occurs on line 156 of model.py, the assignment t_pair[..., i, :, :, :] = t. If I print the shapes of the tensors involved, t_pair is [2, 2, 256, 256, 64] and t is [2, 1, 256, 256, 64]. I think t is retaining an extra singleton dimension along the template axis, which causes the shapes to mismatch during the assignment. If I change the assignment to t_pair[..., i, :, :, :] = t.squeeze(), everything works fine. I'm not sure why the reproduction isn't working on your end; I'd just make sure that the inplace_safe=True path is actually running.
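
To make the mismatch concrete, here's a minimal standalone sketch (toy shapes copied from my printout above; not OpenFold code, and squeeze(-4) is just one way of dropping the singleton template dim):

import torch

# Toy shapes from the failing assignment: t_pair is [batch, n_templ, N_res, N_res, c_t],
# t retains a singleton template dim
t_pair = torch.zeros(2, 2, 256, 256, 64)
t = torch.zeros(2, 1, 256, 256, 64)
i = 0

try:
    t_pair[..., i, :, :, :] = t  # the slice is 4-D but t is 5-D -> expand() error
except RuntimeError as e:
    print(e)

# Dropping the singleton template dim makes the shapes line up
t_pair[..., i, :, :, :] = t.squeeze(-4)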

Thanks again for the other bug fix! Let me know if you want any other info.

@gahdritz (Collaborator)

Try 2986436.

@jproney (Author) commented Aug 19, 2022

Hi again! The inplace_safe=True path works great now, but I think the change may have introduced another bug into the inplace_safe=False path. In that path, I'm getting an error on line 168 of model.py (t = self.template_pair_stack(t_pair, ...)). You can reproduce it with the code snippet I posted at the very beginning of this issue. I think the problem is that dropping the extra dimension breaks the concatenation on line 163 (t_pair = torch.cat(pair_embeds, dim=templ_dim)) when inplace_safe=False. Thanks!
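
Here's a standalone sketch of what I think is going on (toy shapes, and templ_dim = -4 is my assumption for illustration; this isn't the actual OpenFold code):

import torch

templ_dim = -4
# Per-template pair embeddings with the singleton template dim already dropped:
# each is [batch, N_res, N_res, c_t]
pair_embeds = [torch.zeros(2, 8, 8, 4) for _ in range(2)]

# cat along templ_dim now folds the templates into the batch dim instead of
# rebuilding [batch, n_templ, N_res, N_res, c_t]
bad = torch.cat(pair_embeds, dim=templ_dim)
print(bad.shape)  # torch.Size([4, 8, 8, 4])

# Keeping (or re-adding) the singleton template dim before the cat restores the layout
good = torch.cat([p.unsqueeze(templ_dim) for p in pair_embeds], dim=templ_dim)
print(good.shape)  # torch.Size([2, 2, 8, 8, 4])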

@gahdritz (Collaborator) commented Aug 19, 2022

I noticed the same thing earlier and fixed it in 349fdbd.

@jproney (Author) commented Aug 20, 2022

Everything looks good! Thanks so much for addressing these issues.

jproney closed this as completed Aug 20, 2022