<a href="https://colab.research.google.com/github/Momilijaz96/AlphaFold-V1-PyTorch/blob/main/Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import libraries
import numpy as np
import keras
import tensorflow as tf
from keras.models import Model
# Activation and Regularization
from keras.regularizers import l2
from keras.activations import softmax
from keras import backend as K

# Keras layers
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers import Dense, Dropout, Flatten, Input, BatchNormalization, Activation, Add
# Import TensorFlow and json (for configuration)
import json
import tensorflow as tf

# Avoid python call depth errors: raise the interpreter recursion limit
# (CPython's default is usually 1000) to 5000.
# NOTE(review): presumably needed for the deeply nested Keras layer stacks
# built below — confirm this is still required.
import sys
sys.setrecursionlimit(5000)


In [None]:
def map_weights(keras_weight,alphafold_weight):
  """Check that a Keras weight and a checkpoint array have the same shape.

  Arguments:
    keras_weight: Keras/TF variable (any object exposing ``.shape``).
    alphafold_weight: array loaded from the AlphaFold checkpoint.
  Returns:
    True when both shapes are equal.
  Raises:
    AssertionError: when the shapes differ; ``args`` is
      ``(keras_weight.shape, alphafold_weight.shape)``.
  """
  # Explicit raise instead of an ``assert`` statement: asserts are stripped
  # under ``python -O``, which would silently accept mismatched weights.
  if not (keras_weight.shape == alphafold_weight.shape):
    raise AssertionError(keras_weight.shape, alphafold_weight.shape)
  return True
    


# Checkpoint batch-norm variable names copied onto the Keras
# BatchNormalization weights of the same name.
_BN_PARAMS = ('beta', 'moving_mean', 'moving_variance')

# Checkpoint conv leaf name -> Keras Conv2D weight name.
_CONV_LEAVES = {'b': 'bias', 'w': 'kernel'}

# Residual sub-layer suffix in the checkpoint ('res<i>_<suffix>') ->
# (Keras sub-layer name, conv leaves that are mapped, whether batch-norm
# variables are mapped).
_RES_SUBLAYERS = {
  '1x1': ('c_up', ('b', 'w'), False),
  '1x1h': ('c_down', ('b', 'w'), True),
  '3x3h': ('c_dia', ('w',), True),
}


def _alphafold_to_keras_name(tf_name):
  """Translate one AlphaFold checkpoint variable name to a Keras weight name.

  Arguments:
    tf_name (str): checkpoint variable name, e.g. 'Deep2D/res2_1x1h/w'.
  Returns:
    The Keras weight name (without the ':0' suffix), or None when the variable
    has no counterpart in the Keras model and should be skipped.
  """
  main_module, *others = tf_name.split('/')

  if main_module == 'position_specific_bias':
    return main_module + '/b'
  if main_module not in ('Deep2D', 'Deep2DExtra') or len(others) < 2:
    return None

  layer, leaf = others[0], others[1]
  # Batch-norm variables of conv sub-layers sit under a scope that repeats the
  # layer name, e.g. 'Deep2D/output_reshape/output_reshape/beta'.
  bn_leaf = None
  if leaf == layer and len(others) > 2 and others[2] in _BN_PARAMS:
    bn_leaf = others[2]

  if layer.startswith('conv'):
    if leaf in _CONV_LEAVES:
      return f'{main_module}/{layer}/conv2d/{_CONV_LEAVES[leaf]}'
    return None

  if layer.startswith('output_reshape'):
    if leaf == 'w':
      return f'{main_module}/output_reshape/conv2d/kernel'
    if bn_leaf is not None:
      return f'{main_module}/output_reshape/batch_norm/{bn_leaf}'
    return None

  if layer.startswith('res'):
    res_parts = layer.split('_')
    if len(res_parts) == 1:
      # Batch norm applied to the residual block input; not nested.
      if leaf in _BN_PARAMS:
        return f'{main_module}/{layer}/batch_norm/{leaf}'
      return None
    spec = _RES_SUBLAYERS.get(res_parts[1])
    if spec is None:
      return None
    sublayer, conv_leaves, has_bn = spec
    prefix = f'{main_module}/{res_parts[0]}/{sublayer}'
    if leaf in conv_leaves:
      return f'{prefix}/conv2d/{_CONV_LEAVES[leaf]}'
    if has_bn and bn_leaf is not None:
      return f'{prefix}/batch_norm/{bn_leaf}'
    return None

  return None


def load_alphafold_ckpt(ckpt_path,model_weights):
  """Map AlphaFold TF checkpoint values onto a list of Keras model weights.

  Checkpoint variables whose names have no Keras counterpart are skipped.

  Arguments:
    ckpt_path: TF checkpoint path (str or path-like), ending in .ckpt.
    model_weights: Keras model weights (e.g. ``model.weights``), in the order
      expected by ``model.set_weights``.
  Returns:
    List in the order of ``model_weights`` (after de-duplicating names), where
    every weight with a checkpoint counterpart is replaced by the numpy array
    read from the checkpoint; other entries are the original weight objects.
  Raises:
    KeyError: a mapped name has no matching Keras weight.
    AssertionError: a checkpoint array's shape differs from the Keras weight.
  """
  ckpt_path = str(ckpt_path)
  # Open the checkpoint once instead of once per variable.
  reader = tf.train.load_checkpoint(ckpt_path)
  kM_w = {w.name.split(':')[0]: w for w in model_weights}  # keras model weights

  for tf_name, _ in tf.train.list_variables(ckpt_path):
    map_name = _alphafold_to_keras_name(tf_name)
    if map_name is None:
      # Unmapped variable: skip it rather than reuse a previous mapping.
      continue
    tf_var = reader.get_tensor(tf_name)
    map_weights(kM_w[map_name], tf_var)  # raises on shape mismatch
    kM_w[map_name] = tf_var

  return list(kM_w.values())




In [None]:
 class AlphaFoldConvLayer(tf.keras.layers.Layer):
   """Conv2D layer, optionally followed by batch normalization and an ELU.

   The convolution uses stride 1, 'same' padding, a random-normal kernel
   initializer and L2(1e-4) kernel regularization. Its bias is disabled
   whenever batch normalization follows it.
   """
   def __init__(self, num_filters,
                kernel_size,
                non_linearity=True,
                batch_norm=False,
                atrou_rate=1,
                name=None):
     """Build the convolution, batch-norm and activation sub-layers.

     Arguments:
       num_filters (int): number of output channels.
       kernel_size (int or tuple): convolution kernel size.
       non_linearity (bool): apply an ELU activation last when True.
       batch_norm (bool): apply BatchNormalization (scale=False,
         momentum=0.999) after the convolution when True.
       atrou_rate (int): dilation rate of the convolution.
       name (str, optional): layer name, assigned to ``self._name``.
     """
     super().__init__()
     if name is not None:
       self._name = name

     self.batch_norm = (
         BatchNormalization(scale=False, momentum=0.999, fused=True,
                            name='batch_norm')
         if batch_norm else None)
     self.elu = Activation('elu') if non_linearity else None
     self.conv = Conv2D(
         num_filters, kernel_size,
         strides=1,
         padding='same',
         data_format='channels_last',
         kernel_initializer='random_normal',
         kernel_regularizer=l2(1e-4),
         dilation_rate=atrou_rate,
         use_bias=not batch_norm,  # BN's beta makes a conv bias redundant
         name="conv2d")

   def call(self, x):
     """Apply the convolution, then batch norm and ELU when enabled."""
     y = self.conv(x)
     for post_op in (self.batch_norm, self.elu):
       if post_op:
         y = post_op(y)
     return y


In [None]:
class AlphaFoldResBlock(tf.keras.layers.Layer):
  """Pre-activation bottleneck residual block.

  Layout: [optional batch norm] -> ELU -> 1x1 conv to num_filters // 2
  (BN + ELU) -> dilated conv (BN + ELU) -> 1x1 conv back to num_filters
  (no activation) -> [optional dropout] -> add the block input.
  """
  def __init__(self,
              num_filters,
              kernel_size,
              batch_norm=False,
              atrou_rate=1,
              dropout_keep_prob=1.0,
              name=None):
    """Make a residual block.

    Arguments:
        num_filters (int): channels of the block input and output; the two
          inner convolutions use num_filters // 2 channels.
        kernel_size (int): square kernel size of the main (dilated) conv.
        batch_norm (bool): whether to batch-normalize the block input before
          the first ELU.
        atrou_rate (int): dilation rate of the main (dilated) conv layer.
        dropout_keep_prob (float): keep probability of a dropout applied to
          the residual branch before the skip connection; 1.0 disables it.
        name (str, optional): layer name, assigned to ``self._name``.
    """
    super(AlphaFoldResBlock, self).__init__()
    if name is not None:
      self._name = name

    self.batch_norm = BatchNormalization(scale=False, momentum=0.999, fused=True, name='batch_norm') if batch_norm else None
    self.elu = Activation('elu')

    # Reduce channels to half with a 1x1 conv.
    self.conv_down = AlphaFoldConvLayer(num_filters//2, 1, non_linearity=True, batch_norm=True, name='c_down')

    # Main dilated convolution; the block's atrou_rate must reach this layer.
    self.conv_dilated = AlphaFoldConvLayer(num_filters//2, kernel_size, non_linearity=True, batch_norm=True,
                                           atrou_rate=atrou_rate, name='c_dia')

    # Restore num_filters channels with a 1x1 conv (no activation, with bias).
    self.conv_up = AlphaFoldConvLayer(num_filters, 1, False, name='c_up')

    # Optional dropout on the residual branch.
    self.dropout = Dropout(1-dropout_keep_prob) if dropout_keep_prob<1.0 else None

    # Skip connection.
    self.skip_connect = Add()

  def call(self, input_node):
    """Return input_node + residual_branch(input_node)."""
    x = input_node

    if self.batch_norm is not None:
      # Keep the normalized tensor; it is what feeds the ELU below.
      x = self.batch_norm(x)

    x = self.elu(x)
    x = self.conv_down(x)
    x = self.conv_dilated(x)
    x = self.conv_up(x)

    if self.dropout is not None:
      x = self.dropout(x)

    return self.skip_connect([x, input_node])

In [None]:
class AlphaFoldResBlockStack(tf.keras.layers.Layer):
  """Sequence of num_blocks layers: a conv layer, residual blocks, a conv layer.

  Layer i (0-based) uses dilation rate atrou_rates[i % len(atrou_rates)].
  The first and last layers are AlphaFoldConvLayer instances named
  'conv<i+1>'; the layers in between are AlphaFoldResBlock instances named
  'res<i+1>'.
  """
  def __init__(self,
    num_features=40,
    num_predictions=1,
    num_channels=32,
    num_blocks=2,
    filter_size=3,
    batch_norm=False,
    atrou_rates=None,
    dropout_keep_prob=1.0,
    name=None):
    """Create the stacked layers.

    Arguments:
      num_features (int): number of input channels; accepted for interface
        compatibility, not used when building the layers.
      num_predictions (int): output channels of the final conv layer.
      num_channels (int): output channels of the first conv layer and the
        input/output channels of every residual block.
      num_blocks (int): total number of layers, including the first and last
        conv layers.
      filter_size (int): kernel size of the first/last conv layers and of each
        residual block's main conv.
      batch_norm (bool): forwarded to the residual blocks only.
      atrou_rates (list of int, optional): dilation rates cycled over the
        layers; defaults to [1].
      dropout_keep_prob (float): keep probability forwarded to the residual
        blocks.
      name (str, optional): layer name, assigned to ``self._name``.
    """
    super().__init__()
    if name is not None:
      self._name = name

    rates = [1] if atrou_rates is None else atrou_rates
    last_index = num_blocks - 1

    self.blocks = []
    for index in range(num_blocks):
      rate = rates[index % len(rates)]
      if index == 0 or index == last_index:
        # Plain conv layer at both ends; the last one sets the output width.
        width = num_predictions if index == last_index else num_channels
        layer = AlphaFoldConvLayer(width, filter_size,
                                   non_linearity=True,
                                   atrou_rate=rate,
                                   name=f'conv{index + 1}')
      else:
        layer = AlphaFoldResBlock(num_channels, filter_size,
                                  batch_norm=batch_norm,
                                  atrou_rate=rate,
                                  dropout_keep_prob=dropout_keep_prob,
                                  name=f'res{index + 1}')
      self.blocks.append(layer)

  def call(self, x):
    """Run x through every stacked layer in order."""
    out = x
    for layer in self.blocks:
      out = layer(out)
    return out


In [None]:
class PositionBias(tf.keras.layers.Layer):
  """Add learned position-specific (diagonal-offset) biases to cropped logits.

  A (bias_size, num_bins) table holds one bias vector per diagonal offset.
  Offsets beyond bias_size - 1 reuse the last row, and below-diagonal offsets
  mirror the above-diagonal ones.
  """
  def __init__(self, bias_size):
    """Store the table size; the table itself is created in build().

    Arguments:
      bias_size (int): number of diagonal offsets (0 .. bias_size - 1) that
        get their own learned bias row.
    """
    super(PositionBias, self).__init__()
    self.bias_size = bias_size
    self._name='position_specific_bias'

  def build(self, input_shape):
    """Create the zero-initialised bias table.

    Arguments:
      input_shape: shapes of [x, crop_x, crop_y]. Dims 1, 2 and 3 of x's shape
        are read as crop_size_x, crop_size_y and num_bins; they are assumed to
        be static (they feed tf.range / slicing in call) — TODO confirm.
    """
    main_input_shape = input_shape[0]
    self.crop_size_x = main_input_shape[1]
    self.crop_size_y = main_input_shape[2]
    self.num_bins = main_input_shape[3]

    b_init = tf.zeros_initializer()
    self.b = tf.Variable(initial_value=b_init(shape=(self.bias_size, self.num_bins), dtype=tf.float32), trainable=True,name='b')


  def call(self, inputs):
    """Return x plus the bias crop selected by crop_x / crop_y.

    Arguments:
      inputs: [x, crop_x, crop_y]. crop_x and crop_y are indexed as [:, 0] and
        [:, 1]; presumably per-example (start, end) crop coordinates — TODO
        confirm against the data pipeline.
    Returns:
      x + biases, with biases reshaped to
      (batch, crop_size_y, crop_size_x, num_bins).
      NOTE(review): x's dims 1/2 were read as (crop_size_x, crop_size_y), so
      the two only line up for square crops.
    """
    x, crop_x, crop_y = inputs

    # These are required because all inputs are fed in as floats (at least with build()).
    crop_x = tf.cast(crop_x, tf.int32)
    crop_y = tf.cast(crop_y, tf.int32)

    # First pad the biases with a copy of the final value to the maximum length.
    max_off_diag = tf.reduce_max(tf.maximum(
      tf.abs(crop_x[:, 1] - crop_y[:, 0]),
      tf.abs(crop_y[:, 1] - crop_x[:, 0])))
    padded_bias_size = tf.maximum(self.bias_size, max_off_diag)
    biases = tf.concat([self.b, tf.tile(self.b[-1:, :], [padded_bias_size - self.bias_size, 1])], axis=0)
    # Now prepend a mirror image (excluding 0th elt) for below-diagonal.
    biases = tf.concat([tf.reverse(biases[1:, :], axis=[0]), biases], axis=0)

    # Which diagonal of the full matrix each crop starts on (top left):
    start_diag = crop_x[:, 0:1] - crop_y[:, 0:1]  # B x 1

    # Relative offset of each row within a crop:
    # (off-diagonal decreases as y increases)
    increment = tf.expand_dims(-tf.range(0, self.crop_size_y), 0)  # 1 x crop_size_y

    # Index of diagonal of first element of each row, flattened.
    row_offsets = tf.reshape(start_diag + increment, [-1])  # B*crop_size_y

    # Make it relative to the start of the biases array. (0-th diagonal is in
    # the middle at position padded_bias_size - 1)
    row_offsets += padded_bias_size - 1

    # Map_fn to build the individual rows.
    # B*cropsizey x cropsizex x num_bins
    cropped_biases = tf.map_fn(lambda i: biases[i:i+self.crop_size_x, :], elems=row_offsets, fn_output_signature=tf.float32)
    cropped_biases = tf.reshape(cropped_biases, [-1, self.crop_size_y, self.crop_size_x, self.num_bins])

    return x + cropped_biases


In [None]:
class AlphaFoldNetwork(tf.keras.Model):
  """AlphaFold v1 2D network: input pair features -> distance predictions."""
  def __init__(self, config):
    """Assemble the network from a configuration dictionary.

    Arguments:
      config (dict): nested network configuration. Keys read here:
        'network_2d_deep' (with 'num_filters', 'extra_blocks',
        'num_layers_per_block', 'num_blocks', 'use_batch_norm'),
        'num_bins', 'position_specific_bias_size' and 'reshape_layer'.
    """
    super(AlphaFoldNetwork, self).__init__()

    network_2d_deep = config['network_2d_deep']
    output_dimension = config['num_bins']
    # Hard-coded input feature channel count. AlphaFoldResBlockStack does not
    # currently use num_features; Keras layers infer input channels when first
    # called.
    num_features = 1878

    # Size of the position-specific bias table (falsy disables it).
    self.position_specific_bias_size=config['position_specific_bias_size']

    # Optional front stack with 2 * num_filters channels and dilation rates
    # cycling through [1, 2, 4, 8]; omitted when 'extra_blocks' is falsy.
    self.Deep2DExtra = AlphaFoldResBlockStack(
                                      num_features=num_features,
                                      num_predictions=2*network_2d_deep['num_filters'],
                                      num_channels= 2*network_2d_deep['num_filters'],
                                      num_blocks=network_2d_deep['extra_blocks'] * network_2d_deep['num_layers_per_block'],
                                      filter_size=3,
                                      batch_norm=network_2d_deep['use_batch_norm'],
                                      atrou_rates = [1,2,4,8],
                                      dropout_keep_prob=1.0,
                                      name='Deep2DExtra'
    ) if network_2d_deep['extra_blocks'] else None

    # The front stack outputs 2 * num_filters channels.
    num_features = 2 * network_2d_deep['num_filters']

    # Main stack with num_filters channels, dilation rates cycling through
    # [1, 2, 4, 8]. Its output width is num_filters when a reshape layer
    # follows, otherwise num_bins directly.
    self.Deep2D = AlphaFoldResBlockStack(
                                      num_features=num_features,
                                      num_predictions=network_2d_deep['num_filters'] if config['reshape_layer'] else output_dimension,
                                      num_channels= network_2d_deep['num_filters'],
                                      num_blocks=network_2d_deep['num_blocks'] * network_2d_deep['num_layers_per_block'],
                                      filter_size=3,
                                      batch_norm=network_2d_deep['use_batch_norm'],
                                      atrou_rates = [1,2,4,8],
                                      dropout_keep_prob=1.0,
                                      name='Deep2D'
                                      )
    # With config['reshape_layer'], a 1x1 conv maps the num_filters-channel
    # output of Deep2D to num_bins channels.
    self.reshape = AlphaFoldConvLayer(
          num_filters=output_dimension,
          kernel_size=1,
          non_linearity=False,
          batch_norm=network_2d_deep['use_batch_norm'],
          name='Deep2D/output_reshape'
    ) if config['reshape_layer'] else None

    # Position-specific biases.
    if self.position_specific_bias_size:
      self.position_specific_bias=PositionBias(config['position_specific_bias_size'])

  def call(self, input):
    """Run the network.

    Arguments:
      input: (x, crop_x, crop_y); crop_x / crop_y are only used by the
        position-specific bias layer.
    Returns:
      Output of the final layer (num_bins channels), with position-specific
      biases added when enabled.
    """
    x, crop_x, crop_y = input

    if self.Deep2DExtra is not None:
      x = self.Deep2DExtra(x)

    x = self.Deep2D(x)

    if self.reshape is not None:
      x = self.reshape(x)

    if self.position_specific_bias_size:
      # PositionBias already returns x + biases; adding its output to x again
      # would double x.
      x = self.position_specific_bias([x, crop_x, crop_y])
    return x

  def load_weights(self, tf_ckpt):
    """Load an AlphaFold TF checkpoint into this model.

    The model must already be built (e.g. called once) so that self.weights
    is populated. Weights without a checkpoint counterpart keep their current
    values.
    NOTE(review): this overrides tf.keras.Model.load_weights with a different
    signature.

    Arguments:
      tf_ckpt: TF checkpoint path, passed to load_alphafold_ckpt.
    """
    new_w = load_alphafold_ckpt(tf_ckpt, self.weights)
    # set_weights expects numpy values; unmapped entries are still variables.
    new_w = [w if isinstance(w, (np.ndarray, np.generic)) else w.numpy()
             for w in new_w]
    self.set_weights(new_w)


    