In [1]:
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
import os

In [2]:
wandb.require("service")

In [3]:
!python3 -m wandb login eb7b1964fb84cd81de96b2a273ecf2bb6254aeac

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/alexeyorlov53/.netrc


### Upload config

In [4]:
import yaml

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
print(config)

{'batch_size': 16, 'warm_up': 10, 'epochs': 100, 'load_graph_model': 'pretrained_gcn', 'save_every_n_epochs': 5, 'fp16_precision': False, 'init_lr': 5e-05, 'weight_decay': '1e-5', 'gpu': 'cuda:2', 'pretrained_roberta_name': 'molberto_ecfp0_2M', 'roberta_model': {'vocab_size': 30522, 'max_position_embeddings': 514, 'hidden_size': 768, 'num_attention_heads': 12, 'num_hidden_layers': 6, 'type_vocab_size': 1}, 'graph_model_type': 'gcn', 'graph_model': {'num_layer': 5, 'emb_dim': 300, 'feat_dim': 512, 'drop_ratio': 0, 'pool': 'mean'}, 'graph_aug': 'node', 'dataset': {'num_workers': 12, 'valid_size': 0.1}, 'ntxent_loss': {'temperature': 0.1, 'use_cosine_similarity': True}, 'loss_params': {'alpha': 1.0, 'beta': 1.0, 'gamma': 1.0}}


In [5]:
print('batch_size =', config['batch_size'])

batch_size = 16


In [6]:
print('running on device:', config['gpu'])
device = torch.device(config['gpu']) if torch.cuda.is_available() else torch.device('cpu')

running on device: cuda:2


In [7]:
def _save_config_file(config, log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(os.path.join(log_dir, 'config.yml'), 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False, sort_keys=False)

### Upload and Split Dataset

In [8]:
dataframe = pd.read_csv("data_10k.csv")

In [9]:
dataframe = dataframe.drop(columns=['ecfp2', 'ecfp3', 'Molecular Weight', 'Bioactivities', 'AlogP', 'Polar Surface Area', 'CX Acidic pKa', 'CX Basic pKa'])

In [10]:
dataframe

Unnamed: 0,Smiles,ecfp1
0,COc1cc(C2(C)CCCc3nc(SCc4ncccn4)n(-c4ccc(F)cc4)...,"['2246728737', '864674487', '3217380708', '321..."
1,COC(=O)c1sc(NC(=O)C2c3ccccc3Oc3ccccc32)c(C(=O)...,"['2246728737', '864674487', '2246699815', '864..."
2,CC[C@H]1OC(=O)C[C@@H](O)[C@H](C)[C@@H](O[C@@H]...,"['2246728737', '2245384272', '2976033787', '31..."
3,Cc1cccc(-n2cc(C(=O)N3CCC[C@@H]([n+]4cc[nH]c4)C...,"['2246728737', '3217380708', '3218693969', '32..."
4,CCOC(=O)[C@H](C1CC1)N1C(=O)[C@@H](CC(=O)O)C[C@...,"['2246728737', '2245384272', '864674487', '224..."
...,...,...
9995,CCN1CCN(CC(O)c2ccc(Br)cc2)CC1,"['2246728737', '2245384272', '2092489639', '29..."
9996,O=C(O)CNC(=O)CNC(=O)CNC(=O)CSC(=O)c1ccccc1,"['864942730', '2246699815', '864662311', '2245..."
9997,O=C(N[C@@]12CCC[C@@](C#Cc3ccccn3)(CC1)C2)c1ccc...,"['864942730', '2246699815', '847961216', '2976..."
9998,CCOc1ccccc1-c1cc(C(=O)N2CCOCC2)c2ccccc2n1,"['2246728737', '2245384272', '864674487', '321..."


In [11]:
# this because pandas thinks columns with arrays are strings
def preprocess_data_dataset(df, column):
    for row in tqdm(range(len(df))):
        str_ints = eval(df.iloc[row][column])
        str_fingerprint = ' '.join(str_ints)
        df.at[row, column] = str_fingerprint

In [12]:
preprocess_data_dataset(dataframe, 'ecfp1')

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 10004.93it/s]


### Create Molecule Dataset
##### It will generate torch_geometric.data.Data objects for both bert and GIN/GCN models.

In [13]:
from rdkit import Chem

ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE, 
    Chem.rdchem.BondType.DOUBLE, 
    Chem.rdchem.BondType.TRIPLE, 
    Chem.rdchem.BondType.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

In [14]:
import random
import math
from copy import deepcopy
from torch_geometric.data import Data, Dataset

class MoleculeDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame, tokenizer, node_mask_percent=0.25, edge_mask_percent=0.25):
        super(Dataset, self).__init__()
        self.dataset = dataset
        self.node_mask_percent = node_mask_percent
        self.edge_mask_percent = edge_mask_percent

        self.tokenizer = tokenizer
        self.tokenizer.model_max_len = 512

    def get_graph_from_smiles(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return torch.tensor([[], []], dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    0
    
        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()
    
        type_idx = []
        chirality_idx = []
        atomic_number = []
        
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())
        
        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        node_feat = torch.cat([x1, x2], dim=-1)
    
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
    
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.long)
        num_nodes = N
        num_edges = M
        return node_feat, edge_index, edge_attr, num_nodes, num_edges

    def get_augmented_graph_copy(self, node_feat, edge_index, edge_attr, N, M):
        num_mask_nodes = max([1, math.floor(self.node_mask_percent * N)])
        num_mask_edges = max([0, math.floor(self.edge_mask_percent * M)])
        
        mask_nodes = random.sample(list(range(N)), num_mask_nodes)
        mask_edges_single = random.sample(list(range(M)), num_mask_edges)
        mask_edges = [2*i for i in mask_edges_single] + [2*i+1 for i in mask_edges_single]

        node_feat_new = deepcopy(node_feat)
        for atom_idx in mask_nodes:
            node_feat_new[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        edge_index_new = torch.zeros((2, 2*(M - num_mask_edges)), dtype=torch.long)
        edge_attr_new = torch.zeros((2*(M - num_mask_edges), 2), dtype=torch.long)
        count = 0
        for bond_idx in range(2*M):
            if bond_idx not in mask_edges:
                edge_index_new[:, count] = edge_index[:, bond_idx]
                edge_attr_new[count, :] = edge_attr[bond_idx, :]
                count += 1
        return Data(x=node_feat_new, edge_index=edge_index_new, edge_attr=edge_attr_new)

    def tokenize(self, item):
        return self.tokenizer(item, truncation=True, max_length=512, padding='max_length')

    def mlm(self, tensor):
        rand = torch.rand(tensor.shape)
        # mask random 15% where token is not 0 <s>, 1 <pad>, or 2 <s/>
        mask_arr = (rand < .15) * (tensor != 0) * (tensor != 1) * (tensor != 2)
        selection = torch.flatten(mask_arr.nonzero()).tolist()
        # mask tensor, token == 4 is our mask token
        tensor[selection] = 4
        return tensor

    def apply_mlm(self, sample):
        labels = torch.tensor(sample.input_ids)
        attention_mask = torch.tensor(sample.attention_mask)
        input_ids = self.mlm(labels.detach().clone())
        return Data(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def __getitem__(self, index):
        node_feat, edge_index, edge_attr, num_nodes, num_edges = self.get_graph_from_smiles(self.dataset['Smiles'][index])

        data_i = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)
        data_j = self.get_augmented_graph_copy(node_feat, edge_index, edge_attr, num_nodes, num_edges)

        ecfp = self.dataset['ecfp1'][index]
        data_for_bert = self.apply_mlm(self.tokenize(ecfp))
        return data_for_bert, data_i, data_j

    def __len__(self):
        return len(self.dataset)

    def get(self):
        pass
    def len(self):
        pass

In [15]:
from transformers import AutoTokenizer

model_name_bert = 'molberto_ecfp0_2M'
tokenizer = AutoTokenizer.from_pretrained(model_name_bert)
dataset = MoleculeDataset(dataframe, tokenizer)

In [16]:
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_train = len(dataset)
indices = list(range(num_train))
np.random.shuffle(indices)

split = int(np.floor(config['dataset']['valid_size'] * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=train_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

eval_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=valid_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

### Create Transformer Model

In [17]:
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        logits = torch.cat((positives, negatives), dim=1)
        logits = logits.abs() + 0.0001
        logits = torch.log(logits)
        logits /= self.temperature
        
        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)

In [18]:
from transformers import RobertaForMaskedLM
from transformers import RobertaConfig

if config['graph_model_type'] == 'gin':
    from MolCLR.models.ginet_molclr import GINet as GraphModel
elif config['graph_model_type'] == 'gcn':
    from MolCLR.models.gcn_molclr import GCN as GraphModel
else:
    raise ValueError('GNN model is not defined in config.')

class MolecularBertGraph(torch.nn.Module):
    def __init__(self):
        super(MolecularBertGraph, self).__init__()
        self.batch_size = config['batch_size']

        roberta_config = RobertaConfig(**config['roberta_model'])
        self.bert = RobertaForMaskedLM(roberta_config).from_pretrained(config['pretrained_roberta_name'], 
                                                                       config=roberta_config)
        self.graph_model = GraphModel(**config['graph_model'])
        # self.graph_model = self._load_graph_pretrained_weights(self.graph_model)

        self.out_graph_linear = torch.nn.Linear(2 * config['graph_model']['feat_dim'], 
                                                config['roberta_model']['hidden_size'], bias=True)
        # contrastive loss for MolCLR
        self.nt_xent_criterion = NTXentLoss(device, self.batch_size, **config['ntxent_loss'])
        # cosine distance as loss between models
        self.cosine_sim = torch.nn.CosineSimilarity(dim=-1)

    def forward(self, bert_batch, graph_batch1, graph_batch2):
        bert_output = self.bert(input_ids=bert_batch['input_ids'].view(self.batch_size, -1), 
                                 attention_mask=bert_batch['attention_mask'].view(self.batch_size, -1),
                                 labels=bert_batch['labels'].view(self.batch_size, -1), output_hidden_states=True)
        bert_loss = bert_output.loss
        bert_emb = bert_output.hidden_states[0][:, 0, :] # take emb for CLS token

        graph_loss, hidden_states_1, hidden_states_2 = self.graph_step(graph_batch1, graph_batch2)
        graph_emb = self.out_graph_linear(torch.cat((hidden_states_1, hidden_states_2), dim=-1))

        # bimodal_loss = ((1 - self.cosine_sim(bert_emb, graph_emb))**2).mean()
        bimodal_loss = self.nt_xent_criterion(bert_emb, graph_emb)
        return bert_loss, graph_loss, bimodal_loss

    def graph_step(self, xis, xjs):
        # get the representations and the projections
        ris, zis = self.graph_model(xis)  # [N,C]
    
        # get the representations and the projections
        rjs, zjs = self.graph_model(xjs)  # [N,C]
    
        # normalize projection feature vectors
        zis = torch.nn.functional.normalize(zis, dim=1)
        zjs = torch.nn.functional.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)
        return loss, ris, rjs
        
    def _load_graph_pretrained_weights(self, model):
        try:
            checkpoints_folder = os.path.join('MolCLR', 'ckpt', config['load_graph_model'], 'checkpoints')
            print(os.path.join(checkpoints_folder, 'model.pth'))
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'))
            
            model.load_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model

In [19]:
model = MolecularBertGraph().to(device)

In [20]:
print(model)

MolecularBertGraph(
  (bert): RobertaForMaskedLM(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-5): 6 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): Linear(in_fe

### Define utils

In [21]:
num_epoch = config['epochs']

optimizer = torch.optim.Adam(
    model.parameters(), config['init_lr'], 
    weight_decay=eval(config['weight_decay'])
)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, T_max=config['epochs']-config['warm_up'], 
#     eta_min=0, last_epoch=-1
# )

In [22]:
wandb.init(
    project="efcp_transformer",
    name="Pretrained RobertaForMaskedLM + pretrained MolCLR (GCN) 10k",
    config=config
)

[34m[1mwandb[0m: Currently logged in as: [33morlov-aleksei53[0m ([33mmoleculary-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Training (with validation)

In [23]:
alpha = config['loss_params']['alpha']
beta = config['loss_params']['beta']
gamma = config['loss_params']['gamma']

In [24]:
epoch_counter = 0

In [25]:
def train_loop():
    train_tqdm = tqdm(train_dataloader, unit="batch")
    train_tqdm.set_description(f'Epoch {epoch_counter}')
    bert_loss_sum, graph_model_loss_sum, bimodal_loss_sum, loss_sum = 0, 0, 0, 0
    
    model.train()
    for (bert_batch, graph_batch1, graph_batch2) in train_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)

        bert_loss, graph_loss, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss
        loss.backward()

        bert_loss_sum += bert_loss.item()
        graph_model_loss_sum += graph_loss.item()
        bimodal_loss_sum += bimodal_loss.item()
        loss_sum += loss.item()

        optimizer.step()
        train_tqdm.set_postfix(loss=loss.item(), bert_loss=bert_loss.item(), graph_loss=graph_loss.item(), bimodal_loss=bimodal_loss.item())
    return bert_loss_sum / len(train_dataloader), graph_model_loss_sum / len(train_dataloader), bimodal_loss_sum / len(train_dataloader), loss_sum / len(train_dataloader)

In [26]:
def eval_loop():
    eval_tqdm = tqdm(eval_dataloader, unit="batch")
    eval_tqdm.set_description(f'Epoch {epoch_counter}')
    bert_loss_sum, graph_model_loss_sum, bimodal_loss_sum, loss_sum = 0, 0, 0, 0
    
    model.eval()
    for (bert_batch, graph_batch1, graph_batch2) in eval_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)

        with torch.no_grad():
            bert_loss, graph_loss, bimodal_loss = model(bert_batch, graph_batch1, graph_batch2)

        loss = alpha * bert_loss + beta * graph_loss + gamma * bimodal_loss

        bert_loss_sum += bert_loss.item()
        graph_model_loss_sum += graph_loss.item()
        bimodal_loss_sum += bimodal_loss.item()
        loss_sum += loss.item()

        eval_tqdm.set_postfix(loss=loss.item(), bert_loss=bert_loss.item(), graph_loss=graph_loss.item(), bimodal_loss=bimodal_loss.item())
    return bert_loss_sum / len(eval_dataloader), graph_model_loss_sum / len(eval_dataloader), bimodal_loss_sum / len(eval_dataloader), loss_sum / len(eval_dataloader)

In [27]:
bert_loss, graph_loss, bimodal_loss, loss = eval_loop()

Epoch 0: 100%|█████████████████████████| 62/62 [00:20<00:00,  3.08batch/s, bert_loss=0.0132, bimodal_loss=54.8, graph_loss=3.39, loss=58.2]


In [28]:
print('bert_loss =', bert_loss)
print('graph_loss = ', graph_loss)
print('bimodal_loss =', bimodal_loss)
print('sum of losses =', loss)

bert_loss = 0.012725239319186057
graph_loss =  3.3833395934874013
bimodal_loss = 52.559547855008034
sum of losses = 55.955612613308816


In [29]:
from datetime import datetime

model_checkpoints_folder = os.path.join('ckpts')
dir_name = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(model_checkpoints_folder, dir_name)
_save_config_file(config, log_dir)

In [None]:
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(num_epoch):
    bert_loss, graph_loss, bimodal_loss, loss = train_loop()
    print('train', bert_loss, graph_loss, bimodal_loss, loss)
    
    wandb.log({"bert_loss/train": bert_loss}, step=epoch_counter)
    wandb.log({"graph_loss/train": graph_loss}, step=epoch_counter)
    wandb.log({"bimodal_loss/train": bimodal_loss}, step=epoch_counter)
    wandb.log({"loss/train": loss}, step=epoch_counter)

    bert_loss, graph_loss, bimodal_loss, loss = eval_loop()
    print('eval', bert_loss, graph_loss, bimodal_loss, loss)

    wandb.log({"bert_loss/eval": bert_loss}, step=epoch_counter)
    wandb.log({"graph_loss/eval": graph_loss}, step=epoch_counter)
    wandb.log({"bimodal_loss/eval": bimodal_loss}, step=epoch_counter)
    wandb.log({"loss/eval": loss}, step=epoch_counter)
    
    if loss < best_valid_loss:
        best_valid_loss = loss
        torch.save(model.state_dict(), os.path.join(log_dir, 'model.pth'))
    
    if (epoch_counter + 1) % config['save_every_n_epochs'] == 0:
        torch.save(model.state_dict(), os.path.join(log_dir, 'model_{}.pth'.format(str(epoch_counter))))

    # # warmup for the first few epochs
    # if epoch_counter >= config['warm_up']:
        # wandb.log({"cosine_lr_decay": scheduler.get_last_lr()[0]}, step=epoch_counter)
        # scheduler.step()

Epoch 0: 100%|██████████████████████| 562/562 [06:26<00:00,  1.45batch/s, bert_loss=0.00656, bimodal_loss=9.81, graph_loss=1.13, loss=10.9]


train 0.007989358557951621 1.6952791869428234 11.319391691811992 13.022660201978853


Epoch 0: 100%|██████████████████████████| 62/62 [00:19<00:00,  3.23batch/s, bert_loss=0.00743, bimodal_loss=13, graph_loss=1.53, loss=14.5]


eval 0.0073110865459086434 1.6785333098903779 9.3007484943636 10.986592908059396


Epoch 1: 100%|███████████████████████| 562/562 [06:29<00:00,  1.44batch/s, bert_loss=0.0086, bimodal_loss=4.65, graph_loss=1.11, loss=5.77]


train 0.007956480351773843 1.4875733537393956 6.564184535864833 8.059714372471982


Epoch 1: 100%|█████████████████████████| 62/62 [00:19<00:00,  3.19batch/s, bert_loss=0.0063, bimodal_loss=5.53, graph_loss=1.36, loss=6.89]


eval 0.007376911505425889 1.5299454648648538 5.768517832602224 7.305840246139034


Epoch 2: 100%|██████████████████████| 562/562 [06:30<00:00,  1.44batch/s, bert_loss=0.00842, bimodal_loss=4.38, graph_loss=2.09, loss=6.48]


train 0.008025767536853768 1.4425862434707926 5.416331642038881 6.866943648701461


Epoch 2: 100%|████████████████████████| 62/62 [00:17<00:00,  3.49batch/s, bert_loss=0.00977, bimodal_loss=4.89, graph_loss=1.41, loss=6.31]


eval 0.008072700862201953 1.4841395020484924 5.9083679106927685 7.4005801062430105


Epoch 3: 100%|██████████████████████| 562/562 [06:27<00:00,  1.45batch/s, bert_loss=0.00544, bimodal_loss=4.02, graph_loss=2.29, loss=6.31]


train 0.008125185062926515 1.3760550672261316 4.7827509722251484 6.166931237190219


Epoch 3: 100%|██████████████████████████| 62/62 [00:19<00:00,  3.26batch/s, bert_loss=0.00544, bimodal_loss=4.9, graph_loss=1.89, loss=6.8]


eval 0.007343835280006451 1.4862409214819632 5.226063032304087 6.719647784386912


Epoch 4: 100%|██████████████████████| 562/562 [06:29<00:00,  1.44batch/s, bert_loss=0.00615, bimodal_loss=4.03, graph_loss=1.72, loss=5.76]


train 0.007965413582166557 1.336401222333365 4.613786784355327 5.958153417526191


Epoch 4: 100%|████████████████████████| 62/62 [00:19<00:00,  3.23batch/s, bert_loss=0.00538, bimodal_loss=3.95, graph_loss=1.55, loss=5.51]


eval 0.007787514598138871 1.715786793539601 4.667580719917051 6.391155019883187


Epoch 5: 100%|███████████████████████| 562/562 [06:29<00:00,  1.44batch/s, bert_loss=0.0118, bimodal_loss=3.73, graph_loss=1.95, loss=5.69]


train 0.007985612127707866 1.4346057683547622 4.449056195194611 5.891647583225019


Epoch 5: 100%|█████████████████████████| 62/62 [00:16<00:00,  3.74batch/s, bert_loss=0.0065, bimodal_loss=4.12, graph_loss=1.81, loss=5.94]


eval 0.008077590662475315 1.6444563211933259 4.20133001189078 5.853863969925912


Epoch 6: 100%|█████████████████████████| 562/562 [06:27<00:00,  1.45batch/s, bert_loss=0.00559, bimodal_loss=3.87, graph_loss=1.12, loss=5]


train 0.007996226302961226 1.3905536666035228 4.270889707731607 5.669439612758541


Epoch 6: 100%|█████████████████████████| 62/62 [00:19<00:00,  3.24batch/s, bert_loss=0.0068, bimodal_loss=3.64, graph_loss=1.25, loss=4.91]


eval 0.007593946734202966 1.4398000874826986 4.353418104110226 5.800812159815142


Epoch 7: 100%|██████████████████████| 562/562 [06:28<00:00,  1.45batch/s, bert_loss=0.00719, bimodal_loss=3.69, graph_loss=1.28, loss=4.97]


train 0.008034338787681956 1.3743587332476077 4.451794335850617 5.834187406662096


Epoch 7: 100%|████████████████████████| 62/62 [00:19<00:00,  3.21batch/s, bert_loss=0.00584, bimodal_loss=3.51, graph_loss=1.03, loss=4.54]


eval 0.0077572131056278465 1.395942476487929 3.917857620023912 5.321557283401489


Epoch 8: 100%|██████████████████████| 562/562 [06:26<00:00,  1.45batch/s, bert_loss=0.00561, bimodal_loss=3.61, graph_loss=1.04, loss=4.66]


train 0.008037233571488795 1.3364662709397352 3.905933326249445 5.2504368337447955


Epoch 8: 100%|████████████████████████| 62/62 [00:16<00:00,  3.68batch/s, bert_loss=0.00615, bimodal_loss=3.73, graph_loss=2.13, loss=5.87]


eval 0.007644908780592584 1.4090487735886728 3.8433801474109774 5.260073838695403


Epoch 9: 100%|██████████████████████| 562/562 [06:29<00:00,  1.44batch/s, bert_loss=0.00722, bimodal_loss=6.22, graph_loss=1.62, loss=7.84]


train 0.008067078420683556 1.3113595122120134 4.356950473955093 5.676377052938387


Epoch 9: 100%|████████████████████████| 62/62 [00:19<00:00,  3.21batch/s, bert_loss=0.00446, bimodal_loss=3.76, graph_loss=1.56, loss=5.32]


eval 0.00760008372937239 1.4487391191144143 4.097002321673978 5.553341480993455


Epoch 10: 100%|██████████████████████| 562/562 [06:30<00:00,  1.44batch/s, bert_loss=0.0103, bimodal_loss=3.54, graph_loss=1.98, loss=5.53]


train 0.008194234022486199 1.3265256417179447 3.925053265595351 5.259773150033374


Epoch 10: 100%|███████████████████████| 62/62 [00:19<00:00,  3.24batch/s, bert_loss=0.00806, bimodal_loss=3.46, graph_loss=1.41, loss=4.88]


eval 0.007675332280116216 1.3811359126721658 3.703485065890897 5.092296315777686


Epoch 11: 100%|█████████████████████| 562/562 [06:26<00:00,  1.45batch/s, bert_loss=0.00833, bimodal_loss=3.75, graph_loss=1.09, loss=4.86]


train 0.008030255245595004 1.3319068326220393 3.737643939320303 5.077581028072859


Epoch 11: 100%|███████████████████████| 62/62 [00:19<00:00,  3.22batch/s, bert_loss=0.00575, bimodal_loss=3.45, graph_loss=1.61, loss=5.06]


eval 0.007619160625542845 1.4469382935954678 3.7553550774051296 5.209912538528442


Epoch 12: 100%|█████████████████████| 562/562 [06:30<00:00,  1.44batch/s, bert_loss=0.00807, bimodal_loss=4.32, graph_loss=1.35, loss=5.68]


train 0.008035762982071771 1.3755684946779678 3.9149762387801745 5.298580489548924


Epoch 12: 100%|███████████████████████| 62/62 [00:19<00:00,  3.21batch/s, bert_loss=0.00795, bimodal_loss=4.62, graph_loss=1.82, loss=6.45]


eval 0.00788802181130215 1.4859285498819044 3.9472135151586225 5.44103007162771


Epoch 13: 100%|█████████████████████| 562/562 [06:29<00:00,  1.44batch/s, bert_loss=0.00651, bimodal_loss=3.99, graph_loss=1.44, loss=5.44]


train 0.008127855315134989 1.320514210708625 3.870976698780399 5.199618756134739


Epoch 13: 100%|███████████████████████| 62/62 [00:19<00:00,  3.25batch/s, bert_loss=0.00749, bimodal_loss=3.64, graph_loss=1.41, loss=5.06]


eval 0.007226349887317947 1.4491123403272321 3.73506558710529 5.19140427343307


Epoch 14: 100%|█████████████████████| 562/562 [06:25<00:00,  1.46batch/s, bert_loss=0.00675, bimodal_loss=3.49, graph_loss=1.14, loss=4.64]


train 0.008051353368371247 1.2871967807571234 3.739602294680911 5.034850431930977


Epoch 14: 100%|██████████████████████| 62/62 [00:19<00:00,  3.25batch/s, bert_loss=0.00661, bimodal_loss=4.63, graph_loss=0.968, loss=5.61]


eval 0.007451166894527212 1.3944765713907057 3.6545000114748554 5.056427763354394


Epoch 15: 100%|█████████████████████| 562/562 [06:29<00:00,  1.44batch/s, bert_loss=0.00534, bimodal_loss=3.57, graph_loss=0.52, loss=4.09]


train 0.008051816115961195 1.3167755774543803 3.9068199056747543 5.231647293762804


Epoch 15: 100%|█████████████████████████| 62/62 [00:19<00:00,  3.21batch/s, bert_loss=0.0109, bimodal_loss=3.44, graph_loss=1.5, loss=4.96]


eval 0.00782375683587405 1.4066423648788082 3.7702989270610194 5.1847650235699065


Epoch 16: 100%|██████████████████████| 562/562 [06:27<00:00,  1.45batch/s, bert_loss=0.00906, bimodal_loss=3.5, graph_loss=3.07, loss=6.58]


train 0.008057041051449566 3.61256102151718 3.9942186345409243 7.614836715718606


Epoch 16: 100%|███████████████████████| 62/62 [00:19<00:00,  3.24batch/s, bert_loss=0.00879, bimodal_loss=3.69, graph_loss=3.46, loss=7.16]


eval 0.007590710208000195 3.5794558909631546 3.8927287863146876 7.479775382626441


Epoch 17: 100%|██████████████████████| 562/562 [06:26<00:00,  1.45batch/s, bert_loss=0.00837, bimodal_loss=4.01, graph_loss=4.1, loss=8.12]


train 0.008006201065609437 3.5445899004613803 3.856518708514149 7.409114823222585


Epoch 17: 100%|███████████████████████| 62/62 [00:19<00:00,  3.19batch/s, bert_loss=0.00451, bimodal_loss=3.48, graph_loss=3.31, loss=6.79]


eval 0.008009641108313394 3.4074941950459636 3.699527963515251 7.1150317884260605


Epoch 18: 100%|█████████████████████| 562/562 [06:26<00:00,  1.45batch/s, bert_loss=0.00599, bimodal_loss=4.55, graph_loss=3.02, loss=7.57]


train 0.00792925346069356 3.4145818104523356 3.746811584645743 7.169322638324996


Epoch 18: 100%|███████████████████████| 62/62 [00:19<00:00,  3.24batch/s, bert_loss=0.00486, bimodal_loss=3.67, graph_loss=3.17, loss=6.84]


eval 0.007415258325636387 3.4232115514816774 3.7325107474480905 7.163137543585993


Epoch 19: 100%|██████████████████████| 562/562 [08:02<00:00,  1.17batch/s, bert_loss=0.0115, bimodal_loss=4.01, graph_loss=4.07, loss=8.09]


train 0.007946137604436231 3.5553493007646337 3.971539313682882 7.534834756545749


Epoch 19: 100%|████████████████████████| 62/62 [00:21<00:00,  2.85batch/s, bert_loss=0.00669, bimodal_loss=3.87, graph_loss=3.9, loss=7.77]


eval 0.007670880242761585 4.714435800429313 4.0676272492254935 8.78973392517336


Epoch 20:  19%|███▉                 | 106/562 [01:38<07:19,  1.04batch/s, bert_loss=0.00667, bimodal_loss=3.57, graph_loss=5.61, loss=9.19]

In [None]:
wandb.finish()