In [1]:
import networkx
import obonet
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
import torch
from tqdm import tqdm

from cafa_5.dataload import read_go_code_info_accr_weight_from_txt_as_df

In [2]:
# --- GO term vocabulary -------------------------------------------------------
# Read the information-accretion table and fix a deterministic (sorted) ordering
# of GO codes; a code's position in `go_codes` is its column index everywhere below.
ia_w_df = read_go_code_info_accr_weight_from_txt_as_df(
    "../kaggle/input/cafa-5-protein-function-prediction/IA.txt"
)
go_codes = ia_w_df["go_code"].tolist()
go_codes.sort()
go_code_idx_map = dict(zip(go_codes, range(len(go_codes))))

# --- GO graph -----------------------------------------------------------------
go_code_graph = obonet.read_obo(
    "../kaggle/input/cafa-5-protein-function-prediction/Train/go-basic.obo"
)

# Dense boolean adjacency matrix with rows/columns ordered like `go_codes`
# (positional arguments are kept exactly as in the original call), moved to the GPU.
# Row i marks the nodes that node i has an edge to in `go_code_graph`.
# NOTE(review): the variable name says "childs"; whether an edge target is a child
# or a parent depends on the edge direction obonet produces - confirm.
go_codes_adjacency = networkx.adjacency_matrix(go_code_graph, go_codes, bool)
go_codes_incl_childs_mask = torch.from_numpy(go_codes_adjacency.toarray()).to("cuda")

# Set the main diagonal so that every row also selects its own column.
diagonal_idx = list(range(len(go_codes)))
go_codes_incl_childs_mask[diagonal_idx, diagonal_idx] = True

# All graph nodes in reversed topological order.
go_codes_reverse_top_sort = list(networkx.topological_sort(go_code_graph))[::-1]

In [3]:
from cafa_5.model import FFN

# The network maps one score per GO code to one score per GO code, so the first
# two (positional) size arguments are both the vocabulary size.
num_go_codes = len(go_codes)

# Keyword options forwarded to FFN unchanged.
ffn_options = dict(
    num_layers=1,
    hidden_size=2048,
    hidden_activation=torch.nn.ReLU(),
    dropout=0.1,
    residual_connections=True,
    batch_normalization=True,
)

ffn = FFN(num_go_codes, num_go_codes, torch.nn.Sigmoid(), **ffn_options).to("cuda")
ffn.train()  # training mode

# Plain SGD with learning rate 1, optimising binary cross-entropy on probabilities.
optimizer = torch.optim.SGD(ffn.parameters(), lr=1)
loss = torch.nn.BCELoss()

In [7]:
# Train `ffn` to reproduce max-propagation of random scores through the GO graph.
#
# Each step draws a random batch of per-term scores, builds the propagated target
# (every term takes the max over itself and the terms selected by its mask row),
# and takes one SGD step on the BCE between the network output and that target.
#
# NOTE(review): row i of `go_codes_incl_childs_mask` marks the graph successors of
# term i. obonet documents its edges as child -> parent, so the "childs" in the
# name may in fact be parents - confirm the intended propagation direction.

# Loop-invariant: column indices in propagation order. Graph nodes without an
# entry in `go_code_idx_map` have no column in the score matrix and are skipped.
go_code_idx_reverse_top_sort = [
    go_code_idx_map[go_code]
    for go_code in go_codes_reverse_top_sort
    if go_code in go_code_idx_map
]

for step in (pbar := tqdm(range(1000))):

    # Raw input scores, sampled directly on the GPU (no host-to-device copy).
    go_codes_probs = torch.rand((32, len(go_codes)), device="cuda")

    # Target scores. Start from a copy of the raw scores and read from the target
    # itself, so that values already propagated to earlier terms in the reverse
    # topological order carry over transitively. Reading from the raw scores (as
    # before) made the processing order irrelevant and propagated only one hop.
    go_codes_new_probs = go_codes_probs.clone()
    for go_code_idx in go_code_idx_reverse_top_sort:
        go_codes_new_probs[:, go_code_idx] = (
            go_codes_new_probs[:, go_codes_incl_childs_mask[go_code_idx]].max(dim=1).values
        )

    optimizer.zero_grad()
    go_codes_infer_probs = ffn(go_codes_probs)
    step_loss = loss(go_codes_infer_probs, go_codes_new_probs)
    step_loss.backward()
    # Apply the gradients; without this call the weights never change.
    optimizer.step()
    pbar.set_description(f"Loss: {step_loss.item()}")

Loss: 0.7162948846817017:   2%|▏         | 23/1000 [01:59<1:24:44,  5.20s/it]


KeyboardInterrupt: 

In [8]:
# Display the input batch left over from the last executed training step.
# NOTE(review): the saved output below matches the saved output of the
# `go_codes_new_probs` cell value-for-value, which separately allocated input and
# target tensors should not produce - presumably stale output from an earlier
# version of the loop; re-run to refresh.
go_codes_probs

tensor([[0.3986, 0.9084, 0.4510,  ..., 0.9754, 0.9229, 0.9402],
        [0.2172, 0.8448, 0.6664,  ..., 0.9633, 0.9432, 0.9432],
        [0.2763, 0.9230, 0.9553,  ..., 0.9708, 0.9496, 0.9633],
        ...,
        [0.9279, 0.9889, 0.9889,  ..., 0.9959, 0.9889, 0.9889],
        [0.3017, 0.7298, 0.7817,  ..., 0.9862, 0.9862, 0.9862],
        [0.2457, 0.9363, 0.9363,  ..., 0.9900, 0.9363, 0.9522]],
       device='cuda:0')

In [9]:
# Display the target batch left over from the last executed training step.
# NOTE(review): the saved output below matches the saved output of the
# `go_codes_probs` cell value-for-value - presumably stale; re-run to refresh.
go_codes_new_probs

tensor([[0.3986, 0.9084, 0.4510,  ..., 0.9754, 0.9229, 0.9402],
        [0.2172, 0.8448, 0.6664,  ..., 0.9633, 0.9432, 0.9432],
        [0.2763, 0.9230, 0.9553,  ..., 0.9708, 0.9496, 0.9633],
        ...,
        [0.9279, 0.9889, 0.9889,  ..., 0.9959, 0.9889, 0.9889],
        [0.3017, 0.7298, 0.7817,  ..., 0.9862, 0.9862, 0.9862],
        [0.2457, 0.9363, 0.9363,  ..., 0.9900, 0.9363, 0.9522]],
       device='cuda:0')