In [1]:
from transformers import AlbertConfig, AlbertModel
from transformers.modeling_albert import AlbertMLMHead
import torch
from torch import nn
import pytorch_lightning as pl

## Defining custom Albert Config

Refer to [this](https://github.com/huggingface/transformers/blob/48cc224703a8dd8d03d2721c8651fea8704d994b/src/transformers/models/albert/configuration_albert.py#L33) link to understand the meaning of the parameters

In [8]:
# Hyperparameters
vocab_size = 10
embedding_size = 16
hidden_size = 768
num_attention_heads = 12
intermediate_size = 3072


In [9]:
# Defining Custom Albert Model config
custom_config = AlbertConfig(
    vocab_size=vocab_size, # A, C, T, G, U, UNK, MASK, PAD, CLS, SEP
    embedding_size=embedding_size, #this will be scaled to 32 and 64 for ablation experiments
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)

In [10]:
custom_model = AlbertModel(custom_config) # custom model

In [11]:
#downloading from huggingface the pretrained model
pretrained_model = AlbertModel.from_pretrained('albert-base-v2', return_dict=True)
# pretrained_model.save_pretrained('./albert_base_v2')

In [14]:
# masked language modelling head
# this MLM head will be put on top our custom config albert
# and trained on miRNA and mRNA sequences separately.
mlm_head = AlbertMLMHead(custom_config) 

In [5]:
pretrained_model

AlbertModel(
  (embeddings): AlbertEmbeddings(
    (word_embeddings): Embedding(30000, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (encoder): AlbertTransformer(
    (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
    (albert_layer_groups): ModuleList(
      (0): AlbertLayerGroup(
        (albert_layers): ModuleList(
          (0): AlbertLayer(
            (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (attention): AlbertAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (attention_dropout): Dropout(p=0, inplace=False)
      

In [10]:
custom_model

AlbertModel(
  (embeddings): AlbertEmbeddings(
    (word_embeddings): Embedding(10, 16, padding_idx=0)
    (position_embeddings): Embedding(512, 16)
    (token_type_embeddings): Embedding(2, 16)
    (LayerNorm): LayerNorm((16,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (encoder): AlbertTransformer(
    (embedding_hidden_mapping_in): Linear(in_features=16, out_features=768, bias=True)
    (albert_layer_groups): ModuleList(
      (0): AlbertLayerGroup(
        (albert_layers): ModuleList(
          (0): AlbertLayer(
            (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (attention): AlbertAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (attention_dropout): Dropout(p=0, inplace=False)
              

In [24]:
# Since the pretrained model has a vocabulary of 30k
# and embedding size of 128, we need to downscale it
# to our requirements
# Thus, here the embeddings are not pretrained but
# only the main model is. Objective of this is to 
# leverage the latent space of pretrained model.

pretrained_model.resize_token_embeddings(10)
pretrained_model.set_input_embeddings(nn.Embedding(10,16, padding_idx=0))
pretrained_model

AlbertModel(
  (embeddings): AlbertEmbeddings(
    (word_embeddings): Embedding(10, 16, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (encoder): AlbertTransformer(
    (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
    (albert_layer_groups): ModuleList(
      (0): AlbertLayerGroup(
        (albert_layers): ModuleList(
          (0): AlbertLayer(
            (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (attention): AlbertAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (attention_dropout): Dropout(p=0, inplace=False)
          

In [13]:
mlm_head

AlbertMLMHead(
  (LayerNorm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dense): Linear(in_features=768, out_features=16, bias=True)
  (decoder): Linear(in_features=16, out_features=10, bias=True)
)