In [None]:
import jax
import jax.numpy as jnp

class OptimizeSeq:
    """Three-stage relaxation schedule over sequence logits (draft driver).

    Stage one progressively blends the logits toward their softmax, stage two
    applies a softmax whose temperature is annealed down to a small value, and
    `get_final_sequence` turns the result into a straight-through one-hot
    sequence (hard forward value, soft gradient).
    """

    def __init__(self, featurised_examples, model_runner, fold_input, seed):
        """Runs the initial inference and stores the starting logits.

        Args:
          featurised_examples: Sequence of featurised examples; only the first
            one is used.
          model_runner: Object exposing `run_inference(example, rng_key)`,
            assumed to return `(seq_logits, batch)` -- TODO confirm.
            NOTE(review): the `ModelRunner` class in this notebook exposes
            `get_inference_feature`, not `run_inference`.
          fold_input: Object with an `rng_seeds` sequence; its first entry is
            stored on `self.seed`.
          seed: Integer seed used to build `self.rng_key`.
        """
        self.example = featurised_examples[0]
        # NOTE(review): `self.seed` comes from fold_input, while the PRNG key is
        # built from the explicit `seed` argument -- confirm which is intended.
        self.seed = fold_input.rng_seeds[0]
        self.rng_key = jax.random.PRNGKey(seed)
        # Keep the batch (previously discarded) so later steps can reuse it.
        self.logits, self.batch = model_runner.run_inference(self.example, self.rng_key)

    def compute_stage_one(self, iteration, temperature=1.0):
        """Stage one: blend logits toward softmax(logits / temperature).

        For step in range(iteration):
            lambda = (step + 1) / iteration
            logits <- (1 - lambda) * logits + lambda * softmax(logits / temperature)

        Args:
          iteration: Number of blending steps; 0 leaves the logits unchanged.
          temperature: Softmax temperature.

        Returns:
          The updated `self.logits`.
        """
        for step in range(iteration):
            lambda_ = (step + 1) / iteration
            self.logits = (1 - lambda_) * self.logits + lambda_ * jax.nn.softmax(
                self.logits / temperature
            )
        return self.logits

    def compute_stage_two(self, iteration, temperature_initial=1e-2, step_fn=None):
        """Stage two: softmax with an annealed temperature.

        For step in range(iteration):
            temperature = t0 + (1 - t0) * (1 - (step + 1) / iteration) ** 2
        where t0 = `temperature_initial`, so the last step uses exactly t0.

        Args:
          iteration: Number of annealing steps; 0 leaves the logits unchanged.
          temperature_initial: Temperature reached at the final step.
          step_fn: Optional callable `step_fn(seq_logits) -> seq_logits` run
            before each softmax (e.g. a gradient update). The class has no
            optimisation step of its own, so it must be supplied here.

        Returns:
          The updated `self.logits`.
        """
        for step in range(iteration):
            temperature = temperature_initial + (1 - temperature_initial) * (
                1 - (step + 1) / iteration
            ) ** 2
            if step_fn is not None:
                self.logits = step_fn(self.logits)
            self.logits = jax.nn.softmax(self.logits / temperature)
        return self.logits

    def get_final_sequence(self, softmax_logits):
        """Straight-through one-hot of `softmax_logits`.

        The forward value is one_hot(argmax(softmax_logits, axis=-1)); the
        gradient is that of `softmax_logits`, because the difference between the
        two is wrapped in `jax.lax.stop_gradient`.
        """
        # argmax over the last (residue-type) axis so every position gets its
        # own one-hot row; without `axis` argmax flattens the whole array.
        hard = jax.nn.one_hot(
            jnp.argmax(softmax_logits, axis=-1),
            softmax_logits.shape[-1],
            dtype=softmax_logits.dtype,
        )
        return jax.lax.stop_gradient(hard - softmax_logits) + softmax_logits

    def process(self, step):
        """Runs stage one and stage two for `step` iterations each, then hardens the result."""
        self.compute_stage_one(step)
        stage_two_output = self.compute_stage_two(step)
        return self.get_final_sequence(stage_two_output)

In [None]:
class ModelRunner:
  """Helper class to run structure prediction and sequence-design stages.

  Holds the model config, target device and parameter directory, and exposes
  jitted forward passes for `model.Model` and `model.Af3Design`, plus the Adam
  optimiser (`self.o`) used to update sequence logits.
  """

  def __init__(
      self,
      config: model.Model.Config,
      device: jax.Device,
      model_dir: pathlib.Path,
  ):
    self._model_config = config
    self._device = device
    self._model_dir = model_dir
    # Adam with learning rate 1.0: `update_grad` returns the negated update,
    # presumably so the caller can apply its own learning-rate schedule.
    self.o = optax.adam(1.0)

  @functools.cached_property
  def model_params(self) -> hk.Params:
    """Loads model parameters from the model directory (cached after first access)."""
    return params.get_model_haiku_params(model_dir=self._model_dir)

  @functools.cached_property
  def _model(
      self
  ) -> Callable[[jnp.ndarray, features.BatchDict], model.ModelResult]:
    """Returns a jitted `model.Model` forward pass called as `(rng_key, batch)`."""

    @hk.transform
    def forward_fn(batch):
      return model.Model(self._model_config)(batch)

    # Haiku `apply` takes params first and the rng key second; params are
    # bound here so callers only pass (rng_key, batch).
    return functools.partial(
        jax.jit(forward_fn.apply, device=self._device), self.model_params
    )

  @functools.cached_property
  def _designmodel(
      self
  ) -> Callable[[jnp.ndarray, jnp.ndarray, features.BatchDict], model.ModelResult]:
    """Returns a jitted `model.Af3Design` forward pass called as `(rng_key, seq_logits, batch)`."""

    @hk.transform
    def forward_fn(seq_logits, batch):
      return model.Af3Design(self._model_config)(seq_logits, batch)

    # Params are bound first, as in `_model`.
    return functools.partial(
        jax.jit(forward_fn.apply, device=self._device),
        self.model_params,
    )

  def get_inference_feature(
      self,
      featurised_example: features.BatchDict,
      rng_key=None,
      af3design: bool = True,
  ) -> model.ModelResult:
    """Moves a featurised example to the device and prepares it for inference.

    Args:
      featurised_example: Feature dict from the featurisation pipeline;
        invalidly-typed features are removed before it is device-put.
      rng_key: PRNG key for the plain forward pass (`af3design=False`).
        Defaults to `jax.random.PRNGKey(42)`, created on demand.
      af3design: If True, no forward pass is run and the call returns
        `(seq_logits, device_batch)`, where `seq_logits` is a one-hot encoding
        of the batch's `aatype`. If False, `model.Model` is run and its
        post-processed result dict (see `diffusion_result`) is returned.
    """
    featurised_example = jax.device_put(
        jax.tree_util.tree_map(
            jnp.asarray, utils.remove_invalidly_typed_feats(featurised_example)
        ),
        self._device,
    )
    if af3design:
      batch = feat_batch.Batch.from_data_dict(featurised_example)
      seq_logits = jax.nn.one_hot(
          batch.token_features.aatype,
          residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP,
      )
      return seq_logits, featurised_example

    if rng_key is None:
      rng_key = jax.random.PRNGKey(42)
    result = self._model(rng_key, featurised_example)
    return self.diffusion_result(result)

  def diffusion_result(self, result):
    """Converts a model result to host numpy arrays tagged with the params identifier.

    bfloat16 leaves are upcast to float32, and the returned dict gains an
    `'__identifier__'` entry holding the bytes of the params' identifier.
    """
    result = jax.tree.map(np.asarray, result)
    result = jax.tree.map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
        result,
    )
    result = dict(result)
    identifier = self.model_params['__meta__']['__identifier__'].tobytes()
    result['__identifier__'] = identifier
    return result

  def updata_seq(self, seq_logits, opt):
    """Maps raw sequence logits to the relaxed sequence for the current design stage.

    Args:
      seq_logits: Raw sequence logits; the last axis is the residue-type axis.
      opt: Dict with keys 'step', 'iteration', 't' and 'stage':
        stage == 1: blend toward softmax(logits / t) with weight
          (step + 1) / iteration;
        stage == 2: softmax with temperature annealed down to 1e-2 at the
          last step;
        any other stage: straight-through one-hot (hard argmax forward value,
          softmax gradient).

    Returns:
      Array with the same shape and dtype as `seq_logits`.
    """
    step = opt['step']
    iteration = opt['iteration']
    t = opt['t']
    stage = opt['stage']

    def stage_1_fn(logits):
      lambda_ = (step + 1) / iteration
      return (1 - lambda_) * logits + lambda_ * jax.nn.softmax(logits / t)

    def stage_2_fn(logits):
      temperature_initial = 1e-2
      temperature = temperature_initial + (1 - temperature_initial) * (
          1 - (step + 1) / iteration
      ) ** 2
      return jax.nn.softmax(logits / temperature)

    def stage_3_fn(logits):
      probs = jax.nn.softmax(logits)
      # argmax over the residue-type axis so each position gets its own
      # one-hot row (without `axis` argmax flattens the array). Matching the
      # dtype keeps all `cond` branches type-consistent.
      hard = jax.nn.one_hot(
          jnp.argmax(probs, axis=-1), probs.shape[-1], dtype=probs.dtype
      )
      # Straight-through: forward value is `hard`, gradient flows via `probs`.
      return jax.lax.stop_gradient(hard - probs) + probs

    return jax.lax.cond(
        stage == 1,
        stage_1_fn,
        lambda logits: jax.lax.cond(stage == 2, stage_2_fn, stage_3_fn, logits),
        seq_logits,
    )

  def get_model(self, example):
    """Returns `value_and_grad` of the design loss with respect to `seq_logits`.

    The returned function is called as `(seq_logits, batch, rng)` and yields
    `((loss, result), grad)`, where loss = distogram contact loss + confidence
    (pLDDT) loss. `batch` must contain a 'design_opt' dict consumed by
    `updata_seq`.
    """

    @remat
    def _loss_fn(seq_logits, batch, rng):
      opt = batch['design_opt']
      relaxed_seq = self.updata_seq(seq_logits, opt)
      result = self._designmodel(rng, relaxed_seq, batch)
      plddt_loss = confidence_loss(
          result['predicted_lddt'], batch['is_ligand'], example
      )
      dis_loss = contact_loss_dgram(
          result['distogram']['probs_logits'],
          result['distogram']['bin_edges'],
          batch['entity_id'],
      )
      return dis_loss + plddt_loss, result

    return jax.value_and_grad(_loss_fn, argnums=0, has_aux=True)

  def update_grad(self, grad, params, state):
    """Runs one Adam update and returns `(new_state, descent_direction)`.

    optax updates are meant to be added to params; they are negated here so
    the caller can apply `params - lr * descent_direction`.
    """
    updates, new_state = self.o.update(grad, state, params)
    grad = jax.tree_util.tree_map(lambda x: -x, updates)
    return new_state, grad

# One design iteration: featurise, compute the design loss and its gradient
# w.r.t. the sequence logits, then take an Adam-direction step scaled by `lr`.
# NOTE(review): `model_runner`, `example`, `schedule` and `step` are not defined
# in this cell -- they must come from other cells; confirm before Restart & Run All.
seq_logits,batch = model_runner.get_inference_feature(example)
state = model_runner.o.init(seq_logits)
optimizer = jax.jit(model_runner.update_grad)
grad_fn = model_runner.get_model(example)
# NOTE(review): the loss function reads batch['design_opt'] (keys 'step',
# 'iteration', 't', 'stage'), but get_inference_feature returns the featurised
# example without adding that key -- unless the features already carry it,
# this call raises KeyError.
(loss, aux), grad= grad_fn(seq_logits, batch,jax.random.PRNGKey(42))
# Move the model output to host numpy arrays (bf16 -> f32) and tag it with the params identifier.
aux = model_runner.diffusion_result(aux)
# update_grad returns the negated Adam update, so `x - lr * g` below is a descent step.
state,grad = optimizer(grad, seq_logits, state)
lr =  schedule(step)

seq_logits = jax.tree_util.tree_map(lambda x,g:x-lr*g, seq_logits, grad)

In [13]:
import jax.numpy as jnp

# Example per-residue chain flags (residues with value 0 are masked out).
residue_chain = jnp.array([1, 1, 1, 0, 0])

# Pairwise (2D) mask: entry (i, j) is 1.0 only when residue i and residue j are both non-zero.
mask = jnp.where(residue_chain == 0, 0, 1.0)
mask = jnp.outer(mask, mask)

print(mask.shape)


(5, 5)


In [1]:
import json
import os

# Path to the input JSON. Set the AF3_INPUT_JSON environment variable to point
# elsewhere instead of editing this machine-specific default.
json_path = os.environ.get('AF3_INPUT_JSON', '/home/ge/input/test.json')
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
with open(json_path, 'r', encoding='utf-8') as f:
    json_str = f.read()

# Parse the JSON string, so we can detect its format.
raw_json = json.loads(json_str)