In [1]:
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel, EvalPrediction
from models.modeling_moebert import MoEBertForSentenceSimilarity, BertForSentenceSimilarity
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, hamming_loss, confusion_matrix
from models.load_model import MoEBertLoadWeights
from data_zoo import *
from utils import get_yaml, log_metrics, load_model
from trainer import HF_trainer

In [2]:
from models.losses import clip_loss

In [3]:
# Sanity check: average clip_loss over 1000 pairs of independent random
# (8, 768) embeddings. The first argument is drawn before the second.
losses = [
    clip_loss(torch.rand(8, 768), torch.rand(8, 768)).item()
    for _ in range(1000)
]

sum(losses) / len(losses)

6.842775687217713

In [4]:
import torch
from torch.utils.data import Dataset as TorchDataset
from datasets import load_dataset


def data_collator(features):
    """Stack per-example tensors into one batched tensor per key.

    Args:
        features: non-empty sequence of dicts mapping the same keys to
            tensors of matching shape.

    Returns:
        Dict with the keys of ``features[0]``; each value is the
        ``torch.stack`` of that key's tensors across all examples.
    """
    collated = {}
    for name in features[0]:
        per_example = [example[name] for example in features]
        collated[name] = torch.stack(per_example)
    return collated


class TextDataset(TorchDataset):
    """Paired-text dataset that tokenizes both texts of an example on access.

    Args:
        a: sequence of first texts.
        b: sequence of second texts, aligned with ``a``.
        c_labels: per-example labels, returned as float tensors under 'labels'.
        r_labels: per-example integer labels, returned as long tensors under
            'r_labels'; also used as an index into ``domains``.
        tokenizer: callable tokenizer. With ``return_tensors='pt'`` it is
            expected to return 'input_ids' and 'attention_mask' tensors with
            a leading batch dimension of 1 (HuggingFace-style) -- TODO confirm
            for non-HF tokenizers.
        domains: sequence of domain strings, indexed by the example's r_label.
        add_tokens: when truthy, the first token id of both sequences is
            overwritten with the first token id of the example's domain string.
        max_length: length every sequence is padded/truncated to.
    """

    def __init__(self, a, b, c_labels, r_labels, tokenizer, domains, add_tokens, max_length=512):
        self.a = a
        self.b = b
        self.c_labels = c_labels
        self.r_labels = r_labels
        self.tokenizer = tokenizer
        self.domains = domains
        self.max_length = max_length
        self.add_tokens = add_tokens

    def __len__(self):
        """Number of examples (length of ``a``)."""
        return len(self.a)

    def _encode(self, text):
        """Tokenize one text as PyTorch tensors padded/truncated to ``max_length``."""
        return self.tokenizer(text,
                              return_tensors='pt',
                              padding='max_length',
                              truncation=True,
                              max_length=self.max_length)

    def __getitem__(self, idx): # Maybe need a version for non MOE
        """Return the tokenized pair and both labels for example ``idx``.

        Returns:
            Dict with 'input_ids_a', 'attention_mask_a', 'input_ids_b',
            'attention_mask_b' (batch dimension removed), 'labels' (float)
            and 'r_labels' (long).
        """
        r_label = torch.tensor(self.r_labels[idx], dtype=torch.long)
        c_label = torch.tensor(self.c_labels[idx], dtype=torch.float)
        tokenized_a = self._encode(self.a[idx])
        tokenized_b = self._encode(self.b[idx])
        if self.add_tokens:
            # First token id of the domain string, encoded without special tokens.
            domain_token = self.tokenizer(self.domains[int(r_label.item())],
                                          add_special_tokens=False).input_ids[0]
            # Overwrite the first position (the cls token) with the domain token.
            tokenized_a['input_ids'][0][0] = domain_token
            tokenized_b['input_ids'][0][0] = domain_token
        # squeeze(0) drops only the leading batch dimension. A bare squeeze()
        # would also collapse the sequence dimension when max_length == 1.
        return {
            'input_ids_a': tokenized_a['input_ids'].squeeze(0),
            'attention_mask_a': tokenized_a['attention_mask'].squeeze(0),
            'input_ids_b': tokenized_b['input_ids'].squeeze(0),
            'attention_mask_b': tokenized_b['attention_mask'].squeeze(0),
            'labels': c_label,
            'r_labels': r_label
        }


def get_datasets_train(args, tokenizer):
    """Load every dataset in ``args['data_paths']`` and pool them per split.

    Args:
        args: mapping providing 'data_paths', 'domains', 'new_special_tokens',
            'max_length', 'a_col', 'b_col' and 'label_col'.
        tokenizer: tokenizer handed through to each ``TextDataset``.

    Returns:
        Tuple ``(train_dataset, valid_dataset, test_dataset)`` of
        ``TextDataset`` objects. Each example's r_label is the position of its
        source dataset within ``data_paths``.
    """
    data_paths = args['data_paths']
    domains = args['domains']
    add_tokens = args['new_special_tokens']
    max_length = args['max_length']
    a_col = args['a_col']
    b_col = args['b_col']
    label_col = args['label_col']

    split_names = ('train', 'valid', 'test')
    # One set of accumulators (texts a/b, c labels, r labels) per split.
    pooled = {name: {'a': [], 'b': [], 'c': [], 'r': []} for name in split_names}

    for source_idx, path in enumerate(data_paths):
        loaded = load_dataset(path)
        # Resolve all three splits first so a missing split raises before pooling.
        splits = [(name, loaded[name]) for name in split_names]
        for name, split in splits:
            bucket = pooled[name]
            bucket['a'].extend(split[a_col])
            bucket['b'].extend(split[b_col])
            labels = split[label_col]
            bucket['c'].extend(labels)
            bucket['r'].extend([source_idx] * len(labels))

    train_dataset, valid_dataset, test_dataset = (
        TextDataset(pooled[name]['a'], pooled[name]['b'],
                    pooled[name]['c'], pooled[name]['r'],
                    tokenizer, domains, add_tokens, max_length)
        for name in split_names
    )
    return train_dataset, valid_dataset, test_dataset

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the experiment YAML and keep its 'general_args' section as `args`.
yargs = get_yaml('yamls/SE/copd.yaml')
args = yargs['general_args']

In [7]:
# Build the model and its tokenizer from the experiment arguments.
model, tokenizer = load_model(args)
model = model.to(device)
# Cast to fp16 only on GPU. Half-precision operator coverage on CPU is limited
# (and version dependent), so the CPU fallback keeps the loaded precision.
if device.type == 'cuda':
    model = model.half()

BertForSentenceSimilarity(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.05, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [8]:
# Build the train/valid/test TextDataset objects from the datasets listed in args.
train_dataset, valid_dataset, test_dataset = get_datasets_train(args, tokenizer)

In [13]:
# Decode the token ids of the first training example back to text for a visual
# check; keys without 'input' in their name are skipped.
for k, v in train_dataset[0].items():
    if 'input' not in k:
        continue
    print(tokenizer.decode(v))

[CLS] chronic obstructive pulmonary disease ( copd ) is a systemic disease., several long non - coding rnas ( lncrnas ) have been identified to be aberrantly expressed in copd patients., this study investigated the role of lncrna cancer susceptibility candidate 2 ( casc2 ) in copd, as well as its potential mechanism., fifty smokers with copd and another 50 smokers without copd were recruited., receiver operating characteristic curve was constructed to assess the diagnostic value of casc2 in copd patients., 16hbe cells were treated with cigarette smoke extract ( cse ) to establish a cell model. qrt - pcr was used for the measurement of mrna levels., the cell viability and apoptosis were detected by using cell counting kit - 8 and flow cytometry assay., enzyme - linked immunosorbent assay was performed to detect the levels of proinflammatory cytokines., luciferase reporter assay was performed for the target gene analysis., serum casc2 was dramatically decreased in copd patients compared 

In [9]:
# Collate the first 4 training examples into one batch, then move the stacked
# tensors to the device in a single pass (one transfer per key instead of one
# per example). `batch` is only ever bound to the final dict.
examples = [train_dataset[i] for i in range(4)]
batch = {k: v.to(device) for k, v in data_collator(examples).items()}

In [13]:
# Check the shape of the batched attention mask for text a.
batch['attention_mask_a'].shape

torch.Size([4, 512])

In [20]:
# Forward pass: every entry of the batch dict is passed as a keyword argument.
out = model(**batch)

In [21]:
# Display the loss attached to the model output.
out.loss

tensor(7.6094, device='cuda:0', dtype=torch.float16, grad_fn=<DivBackward0>)