In [1]:
import torch

from cafa_5.dataset import CAFA5Dataset, collate_data_dict
from cafa_5.model import CAFA5EmbeddingsFFN

# Pick the compute target once: use the GPU when PyTorch reports that CUDA is
# available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Preview the first entry of the dataset's `prots_amino_acids` attribute.
# NOTE(review): this cell was executed as In [4] but sits above the cell that
# creates `cafa_5_train_data` (executed as In [2]). On a fresh kernel run
# top-to-bottom it raises NameError; move it below the dataset-construction
# cell so that Restart & Run All works.
cafa_5_train_data.prots_amino_acids[0]

'MNSVTVSHAPYTITYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPDKFFIQLKQPLRNKRVCVCGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINVLLELDNKVPINWAQGFIY'

In [2]:
# Every input file the dataset constructor is given, keyed by the constructor's
# keyword-argument names (FASTA sequences, TSV terms, IA.txt weights, the OBO
# graph and three .npy embedding files).
cafa_5_dataset_paths = dict(
    prots_amino_acids_fasta_path="../kaggle/input/cafa-5-protein-function-prediction/Train/train_sequences.fasta",
    prots_go_codes_tsv_path="../kaggle/input/cafa-5-protein-function-prediction/Train/train_terms.tsv",
    go_codes_info_accr_weights_txt_path="../kaggle/input/cafa-5-protein-function-prediction/IA.txt",
    go_code_graph_obo_path="../kaggle/input/cafa-5-protein-function-prediction/Train/go-basic.obo",
    prots_t5_embeds_npy_path="../kaggle/input/t5embeds/train_embeds.npy",
    prots_protbert_embeds_npy_path="../kaggle/input/protbert-embeddings-for-cafa5/train_embeddings.npy",
    prots_esm2_embeds_npy_path="../kaggle/input/4637427/train_embeds_esm2_t36_3B_UR50D.npy",
)
cafa_5_train_data = CAFA5Dataset(**cafa_5_dataset_paths)

# Show the first sample as this cell's output.
cafa_5_train_data[0]

{'id': 'P20536',
 'data': {'t5_embeddings': tensor([[ 0.0495, -0.0329,  0.0325,  ..., -0.0435,  0.0965,  0.0731]]),
  'protbert_embeddings': tensor([[ 0.1554,  0.0354,  0.0897,  ..., -0.0395, -0.0736,  0.0459]]),
  'esm2_embeddings': tensor([[-0.0069,  0.0079,  0.0027,  ...,  0.0257, -0.0288, -0.0095]])},
 'go_codes': tensor([[0, 0, 0,  ..., 0, 0, 0]])}

In [3]:
# Depth of the network. The same value is embedded in the checkpoint folder
# name below, so both are driven from this single constant and cannot drift
# apart when the depth is changed.
NUM_LAYERS = 4
# Set to False to build and inspect the model without calling fit().
RUN_TRAINING = True

# Build the feed-forward model; the output width is the number of GO codes
# exposed by the training dataset.
cafa_5_model = CAFA5EmbeddingsFFN(
    n_go_codes = len(cafa_5_train_data.go_codes),
    num_layers = NUM_LAYERS,
    hidden_size = 2048,
    hidden_activation = torch.nn.ReLU(),
    dropout = 0.1,
    batch_normalization = True,
    residual_connections = True
)
cafa_5_model.to(device)
display(cafa_5_model)

# Total number of scalar entries across all parameter tensors of the model.
n_params = sum(params.numel() for params in cafa_5_model.parameters())
print("# of parameters:", n_params)

if RUN_TRAINING:
    cafa_5_model.fit(
        cafa_5_train_data,
        epochs=64,
        batch_size=32,
        collate_fn = collate_data_dict,
        # The printed model ends in a Sigmoid output activation, so the loss
        # receives probabilities (BCELoss) rather than raw logits.
        loss_fn = torch.nn.BCELoss(),
        optimizer_type = torch.optim.Adam,
        optimizer_kwargs = {"amsgrad": True},
        validation_size = 0.1,
        verbose = True,
        # Resolves to "../kaggle/weights/4_ffn/" for NUM_LAYERS = 4.
        checkpoint_save_folder_path = f"../kaggle/weights/{NUM_LAYERS}_ffn/"
    )

CAFA5EmbeddingsFFN(
  (ffn): FFN(
    (hidden_activation): ReLU()
    (output_activation): Sigmoid()
    (ffn): ModuleDict(
      (linear_0): Linear(in_features=4608, out_features=2048, bias=True)
      (batch_norm_0): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation_0): ReLU()
      (dropout_0): Dropout(p=0.1, inplace=False)
      (linear_1): Linear(in_features=2048, out_features=2048, bias=True)
      (batch_norm_1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation_1): ReLU()
      (dropout_1): Dropout(p=0.1, inplace=False)
      (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
      (batch_norm_2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation_2): ReLU()
      (dropout_2): Dropout(p=0.1, inplace=False)
      (linear_3): Linear(in_features=2048, out_features=2048, bias=True)
      (batch_norm_3): BatchNorm1d(2048

# of parameters: 110669040


- Epoch: 0, Mode: train, Loss: 0.004203 (0.003027),F-score : 0.205661 (0.345917): 100%|██████████| 4001/4001 [03:08<00:00, 21.24it/s]
- Epoch: 0, Mode: validation, Loss: 0.003339, F-score : 0.261915: 100%|██████████| 445/445 [03:29<00:00,  2.12it/s]  
- Epoch: 1, Mode: train, Loss: 0.003406 (0.002312),F-score : 0.235848 (0.353965):  18%|█▊        | 732/4001 [00:57<04:16, 12.75it/s]


KeyboardInterrupt: 