# Pre-training experiments
This notebook contains various experiments with pre-training pipeline, inspired by MOP paper https://arxiv.org/abs/2109.04810

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import numpy as np
import pandas as pd

## Prepare small graph

In [None]:
DATASET_DIR = "/content/drive/MyDrive/Code/mop/kg_dir/old/S20Rel"
PART_ENTITIES = -1

In [None]:
cut_partitions = []
mapping = {}
with open(os.path.join(DATASET_DIR, "partition_20.txt")) as f:
  lines = f.readlines()
  for i in range(len(lines)):
    cut_partitions.append(lines[i].strip().split("\t")[:PART_ENTITIES])
    for idx in cut_partitions[-1]:
      mapping[idx] = i

In [None]:
entities = []
with open(os.path.join(DATASET_DIR, "entity2id.txt")) as f:
  for line in f.readlines():
    for c in line.strip().split("\t"):
      if c.strip() in mapping:
        entities.append(line)

In [None]:
len(entities)

In [None]:
d = {}
with open(os.path.join(DATASET_DIR, "train2id.txt")) as f:
  count = int(f.readline().strip())
  print(count)
  for line in f.readlines():
    dat = line.strip().split("\t")
    if not dat[0] in d:
      d[dat[0]] = np.zeros(20)
    d[dat[0]][int(dat[2])] += 1

In [None]:
np.mean(list(map(lambda x: np.max(x), d.values())))

In [None]:
np.median(list(map(lambda x: np.max(x), d.values())))

In [None]:
train_dummy = []
counts = np.zeros(len(cut_partitions))
with open(os.path.join(DATASET_DIR, "train2id.txt")) as f:
  lines = f.readlines()
  for i in range(1, len(lines)):
    ids = lines[i].strip().split("\t")
    if ids[0] in mapping and ids[1] in mapping and mapping[ids[0]] == mapping[ids[1]]:
      train_dummy.append(lines[i])
      counts[mapping[ids[0]]] += 1

In [None]:
counts

In [None]:
with open("entity2id.txt", "w") as f:
  f.write(str(len(entities)) + '\n' + "".join(entities))

with open("train2id.txt", "w") as f:
  f.write(str(len(train_dummy)) + '\n' + "".join(train_dummy))

In [None]:
for i in range(len(cut_partitions)):
  cut_partitions[i] = "\t".join(cut_partitions[i])


with open("partition_20.txt", "w") as f:
  f.write("\n".join(cut_partitions))

## Dataset preparation

In [None]:
from tqdm import tqdm
from cprint import cprint
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
import torch
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from datasets import Dataset

In [None]:
# TODO: Simplify everything here, there is probably no need for both custom collate_fn and generator_fn
"""
Generator function for link prediction (for now, no negative sampling)
"""
def generator_fn_lp(data, id2ent, id2rel, entity_list):
  label = list(id2rel.keys()).index(data[2])
  return ((id2ent[data[0]], id2ent[data[1]]), label)

"""
Generator function for entity prediction
"""
def generator_fn_ep(data, id2ent, id2rel, entity_list):
  label = entity_list.index(data[1])
  return ((id2ent[data[0]], id2rel[data[2]]), label)


In [None]:
"""
generator_fn(data, id2ent, id2rel, ent2part, partition_counts) - function that generates textual input and label from relation
Input:
- data - list [entityid_1, entityid_2, relationid]
- id2ent - dict {entityid: textual_description}
- id2rel - dict {relationshipid: textual_description}
- ent2part - dict {entityid: (partitionid, entityid_within_partition)}
- partition_counts - list, number of entities per partition
Output:
tuple (input_text, label)
"""
def prepare_data(dataset_path, tokenizer, generator_fn=None, include_border_rels=True, num_workers=2, max_seq_length=64):

  id2ent = {}
  cprint.info("Processing entities for graph")
  with open(os.path.join(dataset_path, "entity2id.txt")) as f:
    total_entities = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2ent[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  id2rel = {}
  cprint.info("Processing relations for graph")
  with open(os.path.join(dataset_path, "relation2id.txt")) as f:
    total_relations = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2rel[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  ent2part = {}
  total_groups = 0
  cprint.info("Processing partitions for graph")
  with open(os.path.join(dataset_path, "partition_20.txt")) as f:
    for group_idx, line in tqdm(enumerate(f.readlines())):
      for entid in line.strip().split("\t"):
        if not int(entid) in id2ent:
            cprint.err(f"Partition №{group_idx} has entity with invalid id {id}. Skipping...")
            continue
        if int(entid) in ent2part:
            cprint.err(f"Entity with ID {id} belongs to multiple partitions. Skipping...")
            continue
        ent2part[int(entid)] = group_idx
      total_groups += 1

  entities = [set([]) for i in range(total_groups)]
  cprint.info(f"Getting entities for partitions")
  with open(os.path.join(dataset_path, "train2id.txt")) as f:
      count = int(f.readline())
      for line in tqdm(f.readlines()):
        data = list(map(int, line.strip().split("\t")))
        if not data[0] in ent2part or not data[1] in ent2part:
            cprint.err(f"No partition data available for triplet {data}. Skipping...")
            continue
        elif ent2part[data[0]] != ent2part[data[1]]:
            if include_border_rels:
                entities[ent2part[data[0]]].add(data[0])
                entities[ent2part[data[0]]].add(data[1])
                entities[ent2part[data[1]]].add(data[0])
                entities[ent2part[data[1]]].add(data[1])
            else:
                #cprint.err(f"Inter-partition link encountered: {data}. Skipping...")
                continue
        else:
            entities[ent2part[data[0]]].add(data[0])
            entities[ent2part[data[0]]].add(data[1])
            
  partition_counts = [len(entities[i]) for i in range(len(entities))]
  entity_list = [list(entities[i]) for i in range(len(entities))]  
  examples = [[] for i in range(total_groups)]
  cprint.info(f"Splitting edges into partitions. include_border_edges={include_border_rels}") 
  with open(os.path.join(dataset_path, "train2id.txt")) as f:
    count = int(f.readline())
    for line in tqdm(f.readlines()):
      data = list(map(int, line.strip().split("\t")))
      if not data[0] in id2ent or not data[1] in id2ent:
        #cprint.err(f"No entity data available for triplet {data}. Skipping...")
        continue
      elif not data[0] in ent2part or not data[1] in ent2part:
        #cprint.err(f"No partition data available for triplet {data}. Skipping...")
        continue
      elif not data[2] in id2rel:
        #cprint.err(f"Invalid relation id encountered for triplet {data}. Skipping...")
        continue
      elif ent2part[data[0]] != ent2part[data[1]]:
        if include_border_rels:
          examples[ent2part[data[1]]].append(generator_fn(data, id2ent, id2rel, entity_list[ent2part[data[1]]]))
        else:
          #cprint.err(f"Inter-partition link encountered: {data}. Skipping...")
          continue
      examples[ent2part[data[0]]].append(generator_fn(data, id2ent, id2rel, entity_list[ent2part[data[0]]]))

  datasets = []
  cprint.info("Preparing datasets")
  for group in tqdm(examples):
      batch_pair_text = [example[0] for example in group]
      text_features = tokenizer.batch_encode_plus(
        batch_pair_text,
        padding="max_length",
        max_length=max_seq_length,
        return_tensors="pt",
        truncation=True,
      )
      labels = torch.as_tensor([example[1] for example in group], dtype=torch.long)
      dataset = Dataset.from_dict({
        "input_ids": text_features.input_ids,
        "attention_mask": text_features.attention_mask,
        "token_type_ids": text_features.token_type_ids,
        "labels": labels,
      })
      dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
      datasets.append(dataset)
  return datasets, partition_counts, total_relations



In [None]:
def prepare_data_full_lp(dataset_path, tokenizer, num_workers=2, max_seq_length=64):

  id2ent = {}
  cprint.info("Processing entities for graph")
  with open(os.path.join(dataset_path, "entity2id.txt")) as f:
    total_entities = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2ent[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  id2rel = {}
  cprint.info("Processing relations for graph")
  with open(os.path.join(dataset_path, "relation2id.txt")) as f:
    total_relations = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2rel[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  examples = []
  with open(os.path.join(dataset_path, "train2id.txt")) as f:
    count = int(f.readline())
    for line in tqdm(f.readlines()):
      data = list(map(int, line.strip().split("\t")))
      if not data[0] in id2ent or not data[1] in id2ent:
        #cprint.err(f"No entity data available for triplet {data}. Skipping...")
        continue
      elif not data[2] in id2rel:
        #cprint.err(f"Invalid relation id encountered for triplet {data}. Skipping...")
        continue
      examples.append(((id2ent[data[0]], id2ent[data[1]]), list(id2rel.keys()).index(data[2])))

  batch_pair_text = [example[0] for example in examples]
  text_features = tokenizer.batch_encode_plus(
        batch_pair_text,
        padding="max_length",
        max_length=max_seq_length,
        return_tensors="pt",
        truncation=True,
  )
  labels = torch.as_tensor([example[1] for example in examples], dtype=torch.long)
  dataset = Dataset.from_dict({
        "input_ids": text_features.input_ids,
        "attention_mask": text_features.attention_mask,
        "token_type_ids": text_features.token_type_ids,
        "labels": labels,
  })
  dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
  return dataset, total_relations

In [None]:
def prepare_data_ep_mod(dataset_path, tokenizer, include_border_rels=True, num_workers=2, max_seq_length=64):

  id2ent = {}
  cprint.info("Processing entities for graph")
  with open(os.path.join(dataset_path, "entity2id.txt")) as f:
    total_entities = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2ent[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  id2rel = {}
  cprint.info("Processing relations for graph")
  with open(os.path.join(dataset_path, "relation2id.txt")) as f:
    total_relations = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2rel[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  ent2part = {}
  total_groups = 0
  cprint.info("Processing partitions for graph")
  with open(os.path.join(dataset_path, "partition_20.txt")) as f:
    for group_idx, line in tqdm(enumerate(f.readlines())):
      for entid in line.strip().split("\t"):
        if not int(entid) in id2ent:
            cprint.err(f"Partition №{group_idx} has entity with invalid id {id}. Skipping...")
            continue
        if int(entid) in ent2part:
            cprint.err(f"Entity with ID {id} belongs to multiple partitions. Skipping...")
            continue
        ent2part[int(entid)] = group_idx
      total_groups += 1

  entities = [set([]) for i in range(total_groups)]
  cprint.info(f"Getting entities for partitions")
  with open(os.path.join(dataset_path, "train2id.txt")) as f:
      count = int(f.readline())
      for line in tqdm(f.readlines()):
        data = list(map(int, line.strip().split("\t")))
        if not data[0] in ent2part or not data[1] in ent2part:
            cprint.err(f"No partition data available for triplet {data}. Skipping...")
            continue
        elif ent2part[data[0]] != ent2part[data[1]]:
            if include_border_rels:
                entities[ent2part[data[0]]].add(data[0])
                entities[ent2part[data[0]]].add(data[1])
                entities[ent2part[data[1]]].add(data[0])
                entities[ent2part[data[1]]].add(data[1])
            else:
                #cprint.err(f"Inter-partition link encountered: {data}. Skipping...")
                continue
        else:
            entities[ent2part[data[0]]].add(data[0])
            entities[ent2part[data[0]]].add(data[1])
            
  partition_counts = [len(entities[i]) for i in range(len(entities))]
  entity_list = [list(entities[i]) for i in range(len(entities))]
  examples = [{} for i in range(total_groups)]  
  cprint.info(f"Splitting edges into partitions. include_border_edges={include_border_rels}") 
  with open(os.path.join(dataset_path, "train2id.txt")) as f:
    count = int(f.readline())
    for line in tqdm(f.readlines()):
      data = list(map(int, line.strip().split("\t")))
      if not data[0] in id2ent or not data[1] in id2ent:
        #cprint.err(f"No entity data available for triplet {data}. Skipping...")
        continue
      elif not data[0] in ent2part or not data[1] in ent2part:
        #cprint.err(f"No partition data available for triplet {data}. Skipping...")
        continue
      elif not data[2] in id2rel:
        #cprint.err(f"Invalid relation id encountered for triplet {data}. Skipping...")
        continue
      elif ent2part[data[0]] != ent2part[data[1]]:
        if include_border_rels:
          partition = ent2part[data[1]]  
          if not (data[0], data[2]) in examples[partition]:
               examples[partition][(data[0], data[2])] = np.zeros(partition_counts[partition])
          examples[partition][(data[0], data[2])][entity_list[partition].index(data[1])] = 1
        else:
          #cprint.err(f"Inter-partition link encountered: {data}. Skipping...")
          continue
      partition = ent2part[data[0]]  
      if not (data[0], data[2]) in examples[partition]:
          examples[partition][(data[0], data[2])] = np.zeros(partition_counts[partition])
      examples[partition][(data[0], data[2])][entity_list[partition].index(data[1])] = 1

  datasets = []
  cprint.info("Preparing datasets")
  for group in tqdm(examples):
      batch_pair_text = [(id2ent[example[0]], id2ent[example[1]]) for example in group.keys()]
      labels = torch.as_tensor([group[example] for example in group.keys()], dtype=torch.long)
      text_features = tokenizer.batch_encode_plus(
        batch_pair_text,
        padding="max_length",
        max_length=max_seq_length,
        return_tensors="pt",
        truncation=True,
      )
      dataset = Dataset.from_dict({
        "input_ids": text_features.input_ids,
        "attention_mask": text_features.attention_mask,
        "token_type_ids": text_features.token_type_ids,
        "labels": labels,
      })
      dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
      datasets.append(dataset)
      break
  return datasets, partition_counts, total_relations



In [None]:
group_datasets_ep_mod_pubmedbert, part_ent_counts_ep, total_rel = prepare_data_ep_mod("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_pubmedbert, include_border_rels=True, num_workers=4)

[92mProcessing entities for graph[0m


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232147/232147 [00:00<00:00, 1056305.58it/s]


[92mProcessing relations for graph[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 132521.45it/s]


[92mProcessing partitions for graph[0m


20it [00:00, 132.99it/s]


[92mGetting entities for partitions[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1125828/1125828 [00:02<00:00, 557001.50it/s]


[92mSplitting edges into partitions. include_border_edges=True[0m


 76%|███████████████████████████████████████████████████████████████████████████████████▊                          | 857710/1125828 [01:51<02:34, 1736.09it/s]

In [11]:
#full_lp_rel_scibert, total_rel = prepare_data_rel_lp("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_scibert)
full_lp_rel_biobert, total_rel = prepare_data_rel_lp("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_biobert)

[92mProcessing entities for graph[0m


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232147/232147 [00:00<00:00, 1216471.09it/s]


[92mProcessing relations for graph[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 220173.44it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1125828/1125828 [00:33<00:00, 34066.09it/s]


[92mPreparing datasets[0m


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [34:14<00:00, 102.74s/it]


In [10]:
full_lp_rel_pubmedbert, total_rel = prepare_data_rel_lp("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_pubmedbert)

[92mProcessing entities for graph[0m


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232147/232147 [00:00<00:00, 1050678.29it/s]


[92mProcessing relations for graph[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 157680.60it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1125828/1125828 [00:43<00:00, 25953.63it/s]


[92mPreparing datasets[0m


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [35:18<00:00, 105.93s/it]


In [10]:
full_lp_scibert, total_rel = prepare_data_full_lp("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_scibert)
full_lp_biobert, total_rel = prepare_data_full_lp("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_biobert)
full_lp_pubmedbert, total_rel = prepare_data_full_lp("/vol/data/kg_dir/S20Rel", tokenizer=tokenizer_pubmedbert)

[92mProcessing entities for graph[0m


100%|██████████████████████████████| 232147/232147 [00:00<00:00, 1202160.98it/s]


[92mProcessing relations for graph[0m


100%|███████████████████████████████████████| 20/20 [00:00<00:00, 227951.30it/s]
100%|█████████████████████████████| 1125828/1125828 [00:02<00:00, 399722.42it/s]


[92mProcessing entities for graph[0m


100%|██████████████████████████████| 232147/232147 [00:00<00:00, 1179385.89it/s]


[92mProcessing relations for graph[0m


100%|███████████████████████████████████████| 20/20 [00:00<00:00, 233665.96it/s]
100%|█████████████████████████████| 1125828/1125828 [00:02<00:00, 394320.93it/s]


[92mProcessing entities for graph[0m


100%|██████████████████████████████| 232147/232147 [00:00<00:00, 1152992.31it/s]


[92mProcessing relations for graph[0m


100%|███████████████████████████████████████| 20/20 [00:00<00:00, 227333.55it/s]
100%|█████████████████████████████| 1125828/1125828 [00:02<00:00, 400559.56it/s]


In [7]:
def prepare_data_rel_lp(dataset_path, tokenizer, num_workers=2, max_seq_length=64):

  id2ent = {}
  cprint.info("Processing entities for graph")
  with open(os.path.join(dataset_path, "entity2id.txt")) as f:
    total_entities = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2ent[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

  id2rel = {}
  cprint.info("Processing relations for graph")
  with open(os.path.join(dataset_path, "relation2id.txt")) as f:
    total_relations = int(f.readline().strip())
    for line in tqdm(f.readlines()):
      id2rel[int(line.strip().split("\t")[1])] = line.strip().split("\t")[0]

            
  examples = [[] for i in range(total_relations)]
  with open(os.path.join(dataset_path, "train2id.txt")) as f:
    count = int(f.readline())
    for line in tqdm(f.readlines()):
      data = list(map(int, line.strip().split("\t")))
      if not data[0] in id2ent or not data[1] in id2ent:
        #cprint.err(f"No entity data available for triplet {data}. Skipping...")
        continue
      elif not data[2] in id2rel:
        #cprint.err(f"Invalid relation id encountered for triplet {data}. Skipping...")
        continue
      for i in range(total_relations):
          examples[i].append(((id2ent[data[0]], id2ent[data[1]]), int(i == list(id2rel.keys()).index(data[2]))))

  datasets = []
  cprint.info("Preparing datasets")
  for group in tqdm(examples):
      batch_pair_text = [example[0] for example in group]
      text_features = tokenizer.batch_encode_plus(
        batch_pair_text,
        padding="max_length",
        max_length=max_seq_length,
        return_tensors="pt",
        truncation=True,
      )
      labels = torch.as_tensor([example[1] for example in group], dtype=torch.long)
      dataset = Dataset.from_dict({
        "input_ids": text_features.input_ids,
        "attention_mask": text_features.attention_mask,
        "token_type_ids": text_features.token_type_ids,
        "labels": labels,
      })
      dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
      datasets.append(dataset)
  return datasets, total_relations



In [9]:
tokenizer_pubmedbert = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext") 
tokenizer_biobert = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1") 
tokenizer_scibert = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")



In [None]:
group_datasets_ep_pubmedbert, part_ent_counts_ep, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_ep,
                                tokenizer=tokenizer_pubmedbert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=True)

In [None]:
group_datasets_ep_biobert, part_ent_counts_ep, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_ep,
                                tokenizer=tokenizer_biobert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=True)

In [None]:
group_datasets_ep_scibert, part_ent_counts_ep, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_ep,
                                tokenizer=tokenizer_scibert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=True)

In [None]:
group_datasets_lp_no_border_biobert, part_ent_counts_lp_nb, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_lp,
                                tokenizer=tokenizer_biobert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=False)

In [None]:
group_datasets_lp_no_border_scibert, part_ent_counts_lp_nb, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_lp,
                                tokenizer=tokenizer_scibert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=False)

In [None]:
group_datasets_ep_no_border_biobert, part_ent_counts_ep_nb, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_ep,
                                tokenizer=tokenizer_biobert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=False)

In [None]:
group_datasets_ep_no_border_scibert, part_ent_counts_ep_nb, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_ep,
                                tokenizer=tokenizer_scibert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=False)

In [None]:
group_datasets_lp_biobert, part_ent_counts_lp, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_lp,
                                tokenizer=tokenizer_biobert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=True)

In [None]:
group_datasets_lp_scibert, part_ent_counts_lp, total_rel = prepare_data("/vol/data/kg_dir/S20Rel", 
                                generator_fn=generator_fn_lp,
                                tokenizer=tokenizer_scibert,
                                num_workers=4, 
                                max_seq_length=64, 
                                include_border_rels=True)

## Pre-training

In [12]:
import torch
import random
from adapters import SeqBnConfig
from adapters import AutoAdapterModel
from transformers import TrainingArguments, EvalPrediction
from adapters import AdapterTrainer
import wandb
import datetime
from torch.nn import CrossEntropyLoss



In [13]:
os.environ["WANDB_PROJECT"] = "Pre-Training" 
os.environ["WANDB_LOG_MODEL"] = ""

In [14]:
GLOBAL_ACC = []

In [15]:
class WeightedLossAdapterTrainer(AdapterTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        class_weights = torch.tensor([0.1, 0.9]).to(logits.device)  # Adjust based on your dataset
        loss_fct = CrossEntropyLoss(weight=class_weights)
        
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

In [16]:
def init_model(num_labels, seed, adapter_name, model_name):
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  base_config = AutoConfig.from_pretrained(model_name)
  model = AutoAdapterModel.from_pretrained(model_name, config=base_config)
  adapter_config = SeqBnConfig(reduction_factor=8)
  model.add_adapter(adapter_name, config=adapter_config)
  model.add_classification_head(adapter_name, num_labels=num_labels)
  model.train_adapter(adapter_name)
  model.set_active_adapters(adapter_name)
  model.to(device)
  return model

def compute_metrics(p: EvalPrediction):
    preds = np.argmax(p.predictions, axis=1)
    return {"acc": (preds == p.label_ids).mean(), }

def train_group(group_dataset, num_labels, training_args, adapter_name, model_name, seed=1):
    model = init_model(num_labels=num_labels, seed=seed, adapter_name=adapter_name, model_name=model_name)
    if num_labels == 2:
        trainer = WeightedLossAdapterTrainer(
            model=model,
            args=training_args,
            train_dataset=group_dataset,
            eval_dataset=group_dataset,
            compute_metrics=compute_metrics,
        )
    else:
        trainer = AdapterTrainer(
            model=model,
            args=training_args,
            train_dataset=group_dataset,
            eval_dataset=group_dataset,
            compute_metrics=compute_metrics,
        )
    train_res = trainer.train()
    wandb.log({"train_res": train_res})
    print(train_res)
    eval_res = trainer.evaluate()
    wandb.log({"eval_res": eval_res})
    print(eval_res)
    return model

def train_dataset(save_path, datasets, adapter_name, model_name, training_args, num_classes, seed=1, start_idx=0):
    assert len(datasets) == len(num_classes)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    for idx in tqdm(range(start_idx, len(datasets))):
        partition = datasets[idx]
        cprint.info(f"Training adapter for partition {idx}")
        training_args.run_name = adapter_name + f"_partition_{idx}_epoch_{training_args.num_train_epochs}_" + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        model = train_group(group_dataset=partition, 
                            num_labels=num_classes[idx], 
                            training_args=training_args, 
                            adapter_name=adapter_name, 
                            model_name=model_name, 
                            seed=seed)
        cprint.info(f"Saving adapter to {os.path.join(save_path, training_args.run_name)}")
        model.save_adapter(os.path.join(save_path, training_args.run_name), adapter_name)
        wandb.finish()
        del model
        torch.cuda.empty_cache()

In [17]:
training_args_scibert_pubmedbert = TrainingArguments(
        learning_rate=1e-4,
        num_train_epochs=1,
        logging_steps=20,
        output_dir="/vol/data/mkonov_output",
        overwrite_output_dir=True,
        weight_decay=0.01,
        per_device_train_batch_size=256,
        per_device_eval_batch_size=256,
        #report_to=[]
        report_to="wandb",
    )

training_args_biobert = TrainingArguments(
        learning_rate=1e-4,
        num_train_epochs=2,
        logging_steps=20,
        output_dir="/vol/data/mkonov_output",
        overwrite_output_dir=True,
        weight_decay=0.01,
        per_device_train_batch_size=256,
        per_device_eval_batch_size=256,
        #report_to=[]
        report_to="wandb",
    )

In [18]:
torch.cuda.empty_cache()

In [19]:
wandb.finish()

In [38]:
wandb.login(relogin=True)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: [32m[41mERROR[0m API key must be 40 characters long, yours was 48
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: [32m[41mERROR[0m API key must be 40 characters long, yours was 48
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ubuntu/.netrc


True

In [19]:
train_dataset(save_path="/vol/data/adapters/PubMedBERT/S20Rel_LP_full", 
              datasets=[full_lp_pubmedbert], 
              num_classes=[total_rel],
              training_args=training_args_scibert_pubmedbert, 
              adapter_name="pubmedbert_S20Rel_LP_full", 
              model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")



[92mTraining adapter for partition 0[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
[34m[1mwandb[0m: Currently logged in as: [33mkonovma[0m ([33mSoSe2024-NLP-Lab[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
20,2.7852
40,2.4488
60,1.8741
80,1.3283
100,0.9391
120,0.6581
140,0.4881
160,0.3851
180,0.3339
200,0.2978


TrainOutput(global_step=4398, training_loss=0.14737696604817604, metrics={'train_runtime': 2048.0233, 'train_samples_per_second': 549.714, 'train_steps_per_second': 2.147, 'total_flos': 3.832717137972941e+16, 'train_loss': 0.14737696604817604, 'epoch': 1.0})


{'eval_loss': 0.06044953688979149, 'eval_acc': 0.9729914338602345, 'eval_runtime': 1040.5788, 'eval_samples_per_second': 1081.925, 'eval_steps_per_second': 4.226, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_full/pubmedbert_S20Rel_LP_full_partition_0_epoch_1_2024-07-10 00:06:12[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▄▅█▄▃▃▆▄▄▄▃▃▂▂▃▃▆▂▃▃▂▄▁▃▃▃▂▁▂▄▃▂▃▂▃▁▅▁▁▄
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.97299
eval/loss,0.06045
eval/runtime,1040.5788
eval/samples_per_second,1081.925
eval/steps_per_second,4.226
total_flos,3.832717137972941e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.62627
train/learning_rate,0.0


100%|███████████████████████████████████████████| 1/1 [51:38<00:00, 3098.27s/it]


In [20]:
train_dataset(save_path="/vol/data/adapters/BioBERT/S20Rel_LP_full", 
              datasets=[full_lp_biobert], 
              num_classes=[total_rel],
              training_args=training_args_biobert, 
              adapter_name="biobert_S20Rel_LP_full", 
              model_name="dmis-lab/biobert-v1.1")



[92mTraining adapter for partition 0[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,2.7416
40,2.3392
60,1.6069
80,0.9945
100,0.6356
120,0.4544
140,0.3715
160,0.3084
180,0.2949
200,0.2703


TrainOutput(global_step=8796, training_loss=0.09819799975380675, metrics={'train_runtime': 4096.2044, 'train_samples_per_second': 549.693, 'train_steps_per_second': 2.147, 'total_flos': 7.611597721383322e+16, 'train_loss': 0.09819799975380675, 'epoch': 2.0})


{'eval_loss': 0.046076368540525436, 'eval_acc': 0.9779362389281488, 'eval_runtime': 1042.819, 'eval_samples_per_second': 1079.601, 'eval_steps_per_second': 4.217, 'epoch': 2.0}
[92mSaving adapter to /vol/data/adapters/BioBERT/S20Rel_LP_full/biobert_S20Rel_LP_full_partition_0_epoch_2_2024-07-10 00:57:51[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▅▅▅▅▄▄▄▅▂▃▃▄▂▃▃▃▂▄▃▂▃▁▂█▄▃▄▃▁▂▂▃▂▂▃▄▁▂▃▂
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.97794
eval/loss,0.04608
eval/runtime,1042.819
eval/samples_per_second,1079.601
eval/steps_per_second,4.217
total_flos,7.611597721383322e+16
train/epoch,2.0
train/global_step,8796.0
train/grad_norm,0.30144
train/learning_rate,0.0


100%|█████████████████████████████████████████| 1/1 [1:25:47<00:00, 5147.75s/it]


In [None]:
train_dataset(save_path="/vol/data/adapters/SciBERT/S20Rel_LP_full", 
              datasets=[full_lp_scibert], 
              num_classes=[total_rel],
              training_args=training_args_scibert_pubmedbert, 
              adapter_name="scibert_S20Rel_LP_full", 
              model_name="allenai/scibert_scivocab_uncased")



[92mTraining adapter for partition 0[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,2.5895
40,1.8361
60,1.0722
80,0.6386
100,0.4307
120,0.3326
140,0.2826
160,0.2499
180,0.2241
200,0.2133


In [20]:
train_dataset(save_path="/vol/data/adapters/PubMedBERT/S20Rel_LP_Rel", 
              datasets=full_lp_rel_pubmedbert, 
              num_classes=[2] * total_rel,
              training_args=training_args_scibert_pubmedbert, 
              adapter_name="pubmedbert_S20Rel_LP_Rel", 
              model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
              start_idx=9)



[92mTraining adapter for partition 9[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
[34m[1mwandb[0m: Currently logged in as: [33mkonovma[0m ([33mSoSe2024-NLP-Lab[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
20,0.6269
40,0.4749
60,0.3086
80,0.208
100,0.1533
120,0.1082
140,0.1137
160,0.1118
180,0.0798
200,0.0853


TrainOutput(global_step=4398, training_loss=0.040862478739351185, metrics={'train_runtime': 2046.6109, 'train_samples_per_second': 550.094, 'train_steps_per_second': 2.149, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.040862478739351185, 'epoch': 1.0})


{'eval_loss': 0.013923414051532745, 'eval_acc': 0.9938916068884412, 'eval_runtime': 1041.644, 'eval_samples_per_second': 1080.818, 'eval_steps_per_second': 4.222, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_9_epoch_1_2024-07-10 21:49:34[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▄▃▃▅▂▃▂▂▇▁▂▄▃▂▂▂▄█▃▂▁▁▁▁▂▃▁▂▆▁▁▂▃▁▂▁▂▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99389
eval/loss,0.01392
eval/runtime,1041.644
eval/samples_per_second,1080.818
eval/steps_per_second,4.222
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.38157
train/learning_rate,0.0




[92mTraining adapter for partition 10[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.515
40,0.312
60,0.1702
80,0.1217
100,0.0808
120,0.0255
140,0.0438
160,0.0373
180,0.0356
200,0.0225


TrainOutput(global_step=4398, training_loss=0.016734446771188776, metrics={'train_runtime': 2045.2049, 'train_samples_per_second': 550.472, 'train_steps_per_second': 2.15, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.016734446771188776, 'epoch': 1.0})


{'eval_loss': 0.0036929999478161335, 'eval_acc': 0.9989589884067549, 'eval_runtime': 1040.6933, 'eval_samples_per_second': 1081.806, 'eval_steps_per_second': 4.226, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_10_epoch_1_2024-07-10 22:41:12[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▃▂▂▂█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▂▁▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99896
eval/loss,0.00369
eval/runtime,1040.6933
eval/samples_per_second,1081.806
eval/steps_per_second,4.226
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.00893
train/learning_rate,0.0




[92mTraining adapter for partition 11[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.4049
40,0.1746
60,0.0598
80,0.0229
100,0.0108
120,0.0042
140,0.0035
160,0.004
180,0.0106
200,0.0039


TrainOutput(global_step=4398, training_loss=0.0036743453358664365, metrics={'train_runtime': 2049.5303, 'train_samples_per_second': 549.31, 'train_steps_per_second': 2.146, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.0036743453358664365, 'epoch': 1.0})


{'eval_loss': 6.067147842259146e-05, 'eval_acc': 0.9999857882376348, 'eval_runtime': 1044.0039, 'eval_samples_per_second': 1078.375, 'eval_steps_per_second': 4.213, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_11_epoch_1_2024-07-10 23:32:47[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,█▂▁▁▃▁▁▃▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99999
eval/loss,6e-05
eval/runtime,1044.0039
eval/samples_per_second,1078.375
eval/steps_per_second,4.213
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,2e-05
train/learning_rate,0.0




[92mTraining adapter for partition 12[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.3679
40,0.0904
60,0.0661
80,0.0382
100,0.0221
120,0.0324
140,0.0246
160,0.0163
180,0.0223
200,0.018


TrainOutput(global_step=4398, training_loss=0.015208293170943049, metrics={'train_runtime': 2046.3052, 'train_samples_per_second': 550.176, 'train_steps_per_second': 2.149, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.015208293170943049, 'epoch': 1.0})


{'eval_loss': 0.004600654821842909, 'eval_acc': 0.9965500946858667, 'eval_runtime': 1040.1317, 'eval_samples_per_second': 1082.39, 'eval_steps_per_second': 4.228, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_12_epoch_1_2024-07-11 00:24:30[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▂█▂▁▃▂▄▁▂▂▁▁▁▁▂▂▂▂▂▂▃▁▁▁▂▁▂▂▂▁▁▁▂▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▃▂▂▂▂▁▂▁▁▁▁▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99655
eval/loss,0.0046
eval/runtime,1040.1317
eval/samples_per_second,1082.39
eval/steps_per_second,4.228
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.04687
train/learning_rate,0.0




[92mTraining adapter for partition 13[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.6063
40,0.3966
60,0.24
80,0.1824
100,0.1471
120,0.1358
140,0.123
160,0.1117
180,0.0953
200,0.1095


TrainOutput(global_step=4398, training_loss=0.0404094222792445, metrics={'train_runtime': 2050.9841, 'train_samples_per_second': 548.921, 'train_steps_per_second': 2.144, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.0404094222792445, 'epoch': 1.0})


{'eval_loss': 0.014768562279641628, 'eval_acc': 0.9939751009923363, 'eval_runtime': 1042.9674, 'eval_samples_per_second': 1079.447, 'eval_steps_per_second': 4.217, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_13_epoch_1_2024-07-11 01:16:06[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▄▄▅▆▃▂▂▃▁▂▂▁▄▁▁▁▄▂█▂▃▁▁▁▂▁▁▂▁▂▁▁▄▁▂▂▁▂▂▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99398
eval/loss,0.01477
eval/runtime,1042.9674
eval/samples_per_second,1079.447
eval/steps_per_second,4.217
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,1.91641
train/learning_rate,0.0




[92mTraining adapter for partition 14[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.377
40,0.1048
60,0.0607
80,0.0341
100,0.0423
120,0.0292
140,0.0256
160,0.0229
180,0.0303
200,0.0283


TrainOutput(global_step=4398, training_loss=0.015502432919445771, metrics={'train_runtime': 2049.4898, 'train_samples_per_second': 549.321, 'train_steps_per_second': 2.146, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.015502432919445771, 'epoch': 1.0})


{'eval_loss': 0.007528392132371664, 'eval_acc': 0.9964683770522673, 'eval_runtime': 1041.8976, 'eval_samples_per_second': 1080.555, 'eval_steps_per_second': 4.221, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_14_epoch_1_2024-07-11 02:07:51[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▇▃▁▁▄▂█▁▂▂▁▂▁▁▂▃▂▂▂▂▂▂▂▁▃▂▁▂▂▂▂▁▂▁▂▂▁█▁▂
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▃▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁

0,1
eval/acc,0.99647
eval/loss,0.00753
eval/runtime,1041.8976
eval/samples_per_second,1080.555
eval/steps_per_second,4.221
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.03698
train/learning_rate,0.0




[92mTraining adapter for partition 15[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.5391
40,0.3485
60,0.1695
80,0.0812
100,0.0647
120,0.0387
140,0.0297
160,0.0442
180,0.0563
200,0.0143


TrainOutput(global_step=4398, training_loss=0.016954366793655275, metrics={'train_runtime': 2048.7245, 'train_samples_per_second': 549.526, 'train_steps_per_second': 2.147, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.016954366793655275, 'epoch': 1.0})


{'eval_loss': 0.002545964205637574, 'eval_acc': 0.9989598766419027, 'eval_runtime': 1042.1891, 'eval_samples_per_second': 1080.253, 'eval_steps_per_second': 4.22, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_15_epoch_1_2024-07-11 02:59:32[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▄▄▂▃▃▁█▁▇▄▂▃▂▂▄▂▂▂▁▂▁▅▁█▁▄▂▁▁▁▁▁▂▁▃▁▁▂▂▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99896
eval/loss,0.00255
eval/runtime,1042.1891
eval/samples_per_second,1080.253
eval/steps_per_second,4.22
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.13026
train/learning_rate,0.0




[92mTraining adapter for partition 16[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.6225
40,0.4185
60,0.1983
80,0.0904
100,0.0688
120,0.0512
140,0.0458
160,0.0453
180,0.0451
200,0.0525


TrainOutput(global_step=4398, training_loss=0.020305528221190274, metrics={'train_runtime': 2049.5492, 'train_samples_per_second': 549.305, 'train_steps_per_second': 2.146, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.020305528221190274, 'epoch': 1.0})


{'eval_loss': 0.005877979565411806, 'eval_acc': 0.9983532120359415, 'eval_runtime': 1040.8113, 'eval_samples_per_second': 1081.683, 'eval_steps_per_second': 4.226, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_16_epoch_1_2024-07-11 03:51:11[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▅▃▂▃▁▂▁▂▂▁▁█▂▂▁▅▁▁▃▁▁▂▄▂▃▁▃▁▁▂▁▁▁▇▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99835
eval/loss,0.00588
eval/runtime,1040.8113
eval/samples_per_second,1081.683
eval/steps_per_second,4.226
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.0182
train/learning_rate,0.0




[92mTraining adapter for partition 17[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.3989
40,0.0947
60,0.0058
80,0.0035
100,0.002
120,0.0004
140,0.0004
160,0.0002
180,0.0001
200,0.0002


TrainOutput(global_step=4398, training_loss=0.0023831287777394176, metrics={'train_runtime': 2047.3087, 'train_samples_per_second': 549.906, 'train_steps_per_second': 2.148, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.0023831287777394176, 'epoch': 1.0})


{'eval_loss': 8.597752412242698e-07, 'eval_acc': 1.0, 'eval_runtime': 1042.7521, 'eval_samples_per_second': 1079.67, 'eval_steps_per_second': 4.218, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_17_epoch_1_2024-07-11 04:42:50[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,1.0
eval/loss,0.0
eval/runtime,1042.7521
eval/samples_per_second,1079.67
eval/steps_per_second,4.218
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,4e-05
train/learning_rate,0.0




[92mTraining adapter for partition 18[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.4361
40,0.2165
60,0.0519
80,0.0147
100,0.0023
120,0.0053
140,0.0092
160,0.0026
180,0.0006
200,0.0009


TrainOutput(global_step=4398, training_loss=0.0037257085705447433, metrics={'train_runtime': 2047.1098, 'train_samples_per_second': 549.96, 'train_steps_per_second': 2.148, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.0037257085705447433, 'epoch': 1.0})


{'eval_loss': 4.134516711928882e-05, 'eval_acc': 0.9999857882376348, 'eval_runtime': 1042.837, 'eval_samples_per_second': 1079.582, 'eval_steps_per_second': 4.217, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_18_epoch_1_2024-07-11 05:34:28[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,█▂▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99999
eval/loss,4e-05
eval/runtime,1042.837
eval/samples_per_second,1079.582
eval/steps_per_second,4.217
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.00018
train/learning_rate,0.0




[92mTraining adapter for partition 19[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.534
40,0.1899
60,0.0823
80,0.0428
100,0.0335
120,0.0297
140,0.0143
160,0.0239
180,0.0133
200,0.0182


TrainOutput(global_step=4398, training_loss=0.008720280402778339, metrics={'train_runtime': 2051.7582, 'train_samples_per_second': 548.714, 'train_steps_per_second': 2.144, 'total_flos': 3.832118723463782e+16, 'train_loss': 0.008720280402778339, 'epoch': 1.0})


{'eval_loss': 0.0012482249876484275, 'eval_acc': 0.9994901530251513, 'eval_runtime': 1045.3945, 'eval_samples_per_second': 1076.941, 'eval_steps_per_second': 4.207, 'epoch': 1.0}
[92mSaving adapter to /vol/data/adapters/PubMedBERT/S20Rel_LP_Rel/pubmedbert_S20Rel_LP_Rel_partition_19_epoch_1_2024-07-11 06:26:05[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▃▂▂▂▃█▁▂▂▂▃▁▁▁▁▁▁▁▂▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▃▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99949
eval/loss,0.00125
eval/runtime,1045.3945
eval/samples_per_second,1076.941
eval/steps_per_second,4.207
total_flos,3.832118723463782e+16
train/epoch,1.0
train/global_step,4398.0
train/grad_norm,0.00211
train/learning_rate,0.0


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [9:28:17<00:00, 3099.74s/it]


In [None]:
train_dataset(save_path="/vol/data/adapters/BioBERT/S20Rel_LP_Rel", 
              datasets=full_lp_rel_biobert, 
              num_classes=[2] * total_rel,
              training_args=training_args_biobert, 
              adapter_name="biobert_S20Rel_LP_Rel", 
              model_name="dmis-lab/biobert-v1.1")



[92mTraining adapter for partition 0[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.3773
40,0.1905
60,0.0661
80,0.0134
100,0.0049
120,0.0056
140,0.0011
160,0.0003
180,0.0016
200,0.0006


TrainOutput(global_step=8796, training_loss=0.0016352417855824168, metrics={'train_runtime': 4101.1415, 'train_samples_per_second': 549.032, 'train_steps_per_second': 2.145, 'total_flos': 7.610400892365005e+16, 'train_loss': 0.0016352417855824168, 'epoch': 2.0})


{'eval_loss': 6.924245212758251e-07, 'eval_acc': 1.0, 'eval_runtime': 1043.2603, 'eval_samples_per_second': 1079.144, 'eval_steps_per_second': 4.216, 'epoch': 2.0}
[92mSaving adapter to /vol/data/adapters/BioBERT/S20Rel_LP_Rel/biobert_S20Rel_LP_Rel_partition_0_epoch_2_2024-07-11 07:17:51[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,1.0
eval/loss,0.0
eval/runtime,1043.2603
eval/samples_per_second,1079.144
eval/steps_per_second,4.216
total_flos,7.610400892365005e+16
train/epoch,2.0
train/global_step,8796.0
train/grad_norm,1e-05
train/learning_rate,0.0




[92mTraining adapter for partition 1[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.3912
40,0.1601
60,0.0583
80,0.022
100,0.0158
120,0.0142
140,0.0099
160,0.0018
180,0.0012
200,0.0021


TrainOutput(global_step=8796, training_loss=0.001784430009855478, metrics={'train_runtime': 4108.6248, 'train_samples_per_second': 548.032, 'train_steps_per_second': 2.141, 'total_flos': 7.610400892365005e+16, 'train_loss': 0.001784430009855478, 'epoch': 2.0})


{'eval_loss': 2.5048602765309624e-05, 'eval_acc': 0.9999911176485218, 'eval_runtime': 1048.8362, 'eval_samples_per_second': 1073.407, 'eval_steps_per_second': 4.193, 'epoch': 2.0}
[92mSaving adapter to /vol/data/adapters/BioBERT/S20Rel_LP_Rel/biobert_S20Rel_LP_Rel_partition_1_epoch_2_2024-07-11 08:43:44[0m


0,1
eval/acc,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,▅▂▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/acc,0.99999
eval/loss,3e-05
eval/runtime,1048.8362
eval/samples_per_second,1073.407
eval/steps_per_second,4.193
total_flos,7.610400892365005e+16
train/epoch,2.0
train/global_step,8796.0
train/grad_norm,4e-05
train/learning_rate,0.0




[92mTraining adapter for partition 2[0m


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss
20,0.4921
40,0.1434
60,0.0324
80,0.0321
100,0.0169
120,0.0156
140,0.0067
160,0.0173
180,0.0221
200,0.013


TrainOutput(global_step=8796, training_loss=0.0038998152203057975, metrics={'train_runtime': 4102.9596, 'train_samples_per_second': 548.788, 'train_steps_per_second': 2.144, 'total_flos': 7.610400892365005e+16, 'train_loss': 0.0038998152203057975, 'epoch': 2.0})


In [None]:
train_dataset(save_path="/vol/data/adapters/SciBERT/S20Rel_LP_Rel", 
              datasets=full_lp_rel_scibert, 
              num_classes=[2] * total_rel,
              training_args=training_args_scibert_pubmedbert, 
              adapter_name="scibert_S20Rel_LP_Rel", 
              model_name="allenai/scibert_scivocab_uncased")

In [None]:
train_dataset(save_path="/vol/data/adapters/PubMedBERT/S20Rel_EP", 
              datasets=group_datasets_ep_pubmedbert, 
              num_classes=part_ent_counts_ep,
              training_args=training_args_scibert_pubmedbert, 
              adapter_name="pubmedbert_S20Rel_EP", 
              model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")

In [None]:
train_dataset(save_path="/vol/data/adapters/BioBERT/S20Rel_EP", 
              datasets=group_datasets_ep_biobert, 
              num_classes=part_ent_counts_ep,
              training_args=training_args_biobert, 
              adapter_name="biobert_S20Rel_EP", 
              model_name="dmis-lab/biobert-v1.1")

In [None]:
train_dataset(save_path="/vol/data/adapters/SciBERT/S20Rel_EP", 
              datasets=group_datasets_ep_scibert, 
              num_classes=part_ent_counts_ep,
              training_args=training_args_scibert_pubmedbert, 
              adapter_name="scibert_S20Rel_EP", 
              model_name="allenai/scibert_scivocab_uncased")

In [None]:
train_dataset(save_path="/vol/data/adapters/BioBERT/S20Rel_LP", 
              datasets=group_datasets_lp_biobert, 
              num_classes=part_ent_counts_lp,
              training_args=training_args_biobert, 
              adapter_name="biobert_S20Rel_LP", 
              model_name="dmis-lab/biobert-v1.1")

In [None]:
train_dataset(save_path="/vol/data/adapters/SciBERT/S20Rel_LP", 
              datasets=group_datasets_lp_scibert, 
              num_classes=part_ent_counts_lp,
              training_args=training_args_scibert, 
              adapter_name="scibert_S20Rel_LP", 
              model_name="allenai/scibert_scivocab_uncased")

In [None]:
train_dataset(save_path="/vol/data/adapters/BioBERT/S20Rel_LP_NB", 
              datasets=group_datasets_lp_no_border_biobert, 
              num_classes=part_ent_counts_lp_nb,
              training_args=training_args_biobert, 
              adapter_name="biobert_S20Rel_LP_NB", 
              model_name="dmis-lab/biobert-v1.1")

In [None]:
train_dataset(save_path="/vol/data/adapters/SciBERT/S20Rel_LP_NB", 
              datasets=group_datasets_lp_no_border_scibert, 
              num_classes=part_ent_counts_lp_nb,
              training_args=training_args_scibert, 
              adapter_name="scibert_S20Rel_LP_NB", 
              model_name="allenai/scibert_scivocab_uncased")

In [None]:
train_dataset(save_path="/vol/data/adapters/BioBERT/S20Rel_EP_NB", 
              datasets=group_datasets_ep_no_border_biobert, 
              num_classes=part_ent_counts_ep_nb,
              training_args=training_args_biobert, 
              adapter_name="biobert_S20Rel_EP_NB", 
              model_name="dmis-lab/biobert-v1.1")

In [None]:
train_dataset(save_path="/vol/data/adapters/SciBERT/S20Rel_EP_NB", 
              datasets=group_datasets_ep_no_border_scibert, 
              num_classes=part_ent_counts_ep_nb,
              training_args=training_args_scibert, 
              adapter_name="scibert_S20Rel_EP_NB", 
              model_name="allenai/scibert_scivocab_uncased")