In [1]:
import numpy as np
import pandas as pd
from simpletransformers.classification import MultiLabelClassificationModel, MultiLabelClassificationArgs

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score

In [3]:
# Load the BC7 LitCovid training CSV (path is relative to the working directory).
train_data = pd.read_csv('BC7-LitCovid-Train.csv')

In [4]:
# Preview the first rows to see which columns are available.
train_data.head()

Unnamed: 0,pmid,journal,title,abstract,keywords,pub_type,authors,doi,label
0,32519164,J Thromb Thrombolysis,Potential role for tissue factor in the pathog...,"In December 2019, a new and highly contagious ...",covid-19;il-6;sars-cov-2;tnf-alpha;thrombosis;...,Journal Article;Review,"Bautista-Vargas, Mario;Bonilla-Abadia, Fabio;C...",10.1007/s11239-020-02172-x,Treatment;Mechanism
1,32691006,J Tradit Complement Med,Dietary therapy and herbal medicine for COVID-...,"A novel coronavirus disease (COVID-19), transm...",covid-19;coronavirus;dietary therapy;herbal me...,Journal Article;Review,"Panyod, Suraphan;Ho, Chi-Tang;Sheen, Lee-Yan",10.1016/j.jtcme.2020.05.004,Treatment;Prevention
2,32858315,J Affect Disord,First report of manic-like symptoms in a COVID...,"BACKGROUND: In December 2019, the novel corona...",cerebrospinal fluid;igg;manic-like symptoms;sa...,Case Reports;Journal Article,"Lu, Shaojia;Wei, Ning;Jiang, Jiajun;Wu, Lingli...",10.1016/j.jad.2020.08.031,Case Report
3,32985329,J Dent Res,Epidemiological Investigation of OHCWs with CO...,During the coronavirus disease 2019 (COVID-19)...,dental education;dental public health;infectio...,"Journal Article;Research Support, Non-U.S. Gov't","Meng, L;Ma, B;Cheng, Y;Bian, Z",10.1177/0022034520962087,Prevention
4,32812051,J Antimicrob Chemother,The impact of sofosbuvir/daclatasvir or ribavi...,OBJECTIVES: Sofosbuvir and daclatasvir are dir...,,Journal Article;Randomized Controlled Trial;Re...,"Eslami, Gholamali;Mousaviasl, Sajedeh;Radmanes...",10.1093/jac/dkaa331,Treatment


In [5]:
# (rows, columns) of the raw training frame.
train_data.shape

(24960, 9)

In [24]:
def split_semicolon(string):
    """Return the ';'-separated parts of ``string`` as a list.

    The value is passed through ``str()`` first, so a non-string input is
    split on its string form rather than raising.
    """
    parts = str(string).split(';')
    return parts


# Start the processed frame with one list of label names per article.
label_name_lists = train_data['label'].apply(split_semicolon)
processed_train_data = pd.DataFrame({'labels': label_name_lists})

In [25]:
# Build the model input text and a rough length for every article in one
# vectorised pass. Growing the frame cell-by-cell with .loc is slow on ~25k rows
# and only works when the index is exactly 0..n-1.
# fillna('') keeps a missing title/abstract from failing on str + NaN.
# NOTE(review): '[CLS]'/'[SEP]' are inserted as literal text with no surrounding
# spaces; the BERT tokenizer presumably adds its own special tokens as well —
# confirm that this duplication is intended.
article_text = (
    '[CLS]' + train_data['title'].fillna('').astype(str)
    + '[SEP]' + train_data['abstract'].fillna('').astype(str)
)
processed_train_data['text'] = article_text
# 'len' is the number of whitespace-separated words (stored as integers),
# not the number of tokenizer word pieces.
processed_train_data['len'] = article_text.str.split().str.len()

In [27]:
# Number of articles whose whitespace-split word count is at least 512.
# NOTE(review): 512 looks like BERT's maximum sequence length, but 'len' counts
# words rather than word-piece tokens — confirm this is the intended check.
np.count_nonzero(processed_train_data['len'] >= 512)

163

In [28]:
# Encode each article's list of label names as one multi-hot row.
label_mlb = MultiLabelBinarizer()
label_column = processed_train_data['labels']
label_mle = label_mlb.fit_transform(label_column)
# Show the encoded matrix shape, then the class order of its columns.
print(label_mle.shape, label_mlb.classes_, sep='\n')

(24960, 7)
['Case Report' 'Diagnosis' 'Epidemic Forecasting' 'Mechanism' 'Prevention'
 'Transmission' 'Treatment']


In [29]:
# Column-wise sum of the multi-hot matrix: how many articles carry each class,
# in the column order reported by label_mlb.classes_.
label_mle.sum(axis=0)

array([ 2063,  6193,   645,  4438, 11102,  1088,  8717])

In [30]:
# Replace the label-name lists with their multi-hot vectors (plain Python lists).
# NOTE(review): this overwrites 'labels' in place, so re-running the binarizer
# cell afterwards would fit on the 0/1 vectors instead of the label names.
processed_train_data['labels'] = label_mle.tolist()
processed_train_data.head()

Unnamed: 0,labels,text,len
0,"[0, 0, 0, 1, 0, 0, 1]",[CLS]Potential role for tissue factor in the p...,169.0
1,"[0, 0, 0, 0, 1, 0, 1]",[CLS]Dietary therapy and herbal medicine for C...,198.0
2,"[1, 0, 0, 0, 0, 0, 0]",[CLS]First report of manic-like symptoms in a ...,243.0
3,"[0, 0, 0, 0, 1, 0, 0]",[CLS]Epidemiological Investigation of OHCWs wi...,305.0
4,"[0, 0, 0, 0, 0, 0, 1]",[CLS]The impact of sofosbuvir/daclatasvir or r...,256.0


In [31]:
# Hold out 1000 articles for evaluation. random_state pins the shuffle so the
# same split (and therefore comparable scores) is produced on every run.
X_train, X_test, Y_train, Y_test = train_test_split(
    processed_train_data['text'],
    processed_train_data['labels'],
    test_size=1000,
    random_state=42,
)

In [32]:
# Re-assemble each split into a two-column frame ('text', 'labels').
train = pd.DataFrame({'text': X_train, 'labels': Y_train})
test = pd.DataFrame({'text': X_test, 'labels': Y_test})

train.head()

Unnamed: 0,text,labels
14703,[CLS]Zoonotic and reverse zoonotic events of S...,"[0, 0, 0, 0, 1, 1, 0]"
7355,[CLS]Challenges of SARS-CoV-2 and lessons lear...,"[0, 0, 0, 0, 1, 0, 0]"
23159,[CLS]Risk of SARS-CoV-2 infection among contac...,"[0, 0, 0, 0, 1, 1, 0]"
23180,[CLS]Operational Considerations for Physical T...,"[0, 0, 0, 0, 1, 0, 0]"
1674,[CLS]Acidic electrolyzed water potently inacti...,"[0, 0, 0, 0, 1, 0, 0]"


In [14]:
# train.to_csv('train_data.csv')
# test.to_csv('test_data.csv')

In [33]:
from sklearn.metrics import accuracy_score

In [34]:
def weighted_f1(labels, preds, threshold=0.5):
    """Weighted-average F1 of thresholded multi-label scores.

    Scores strictly greater than ``threshold`` count as a positive
    prediction; everything else counts as a negative one.

    Parameters
    ----------
    labels : array-like
        Ground-truth label indicators, in a form accepted by ``f1_score``.
    preds : array-like
        Predicted scores laid out like ``labels``. Not modified.
    threshold : float, default 0.5
        Cut-off applied to ``preds``.

    Returns
    -------
    float
        ``f1_score(labels, binarised_preds, average='weighted')``.
    """
    # Build a fresh 0/1 array instead of overwriting ``preds`` in place, so the
    # caller's raw scores stay intact for anything that still needs them.
    # A single comparison also avoids re-testing values that were just set to 1
    # against the threshold, which zeroed every prediction when threshold >= 1.
    binary_preds = (np.asarray(preds) > threshold).astype(int)
    return f1_score(labels, binary_preds, average='weighted')

In [35]:
# Training configuration for the multi-label classifier.
model_args = MultiLabelClassificationArgs(
    num_train_epochs=2,
    # NOTE(review): per the simpletransformers docs this flag presumably only has
    # an effect when evaluate_during_training is enabled (left off below) — confirm.
    evaluate_each_epoch=True,
    overwrite_output_dir=True,
    # Fixed seed so repeated runs of the notebook train the same way.
    manual_seed=42,
    # evaluate_during_training=True,
    # save_model_every_epoch = True
)

In [36]:
# Multi-label BERT classifier: pretrained "bert-base-uncased" weights, seven
# output labels, configured by model_args and with CUDA switched off.
model = MultiLabelClassificationModel(
    model_type="bert",
    model_name="bert-base-uncased",
    args=model_args,
    num_labels=7,
    use_cuda=False,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultiLabelSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultiLabelSequenceClassification were not 

In [37]:
# temp_train = train.iloc[:10,:]
# temp_test = test.iloc[:10,:]

In [38]:
# Fine-tune on the training split, handing the held-out split over as eval_df.
# NOTE(review): eval_df is presumably only consumed when evaluate_during_training
# is enabled in the model args — confirm.
model.train_model(train, eval_df = test)

 10%|█         | 1/10 [00:06<00:55,  6.20s/it]
Epochs 0/2. Running Loss:    0.6668: 100%|██████████| 2/2 [00:07<00:00,  3.53s/it]
Epochs 1/2. Running Loss:    0.6229: 100%|██████████| 2/2 [00:07<00:00,  3.59s/it]
Epoch 2 of 2: 100%|██████████| 2/2 [00:42<00:00, 21.35s/it]


(4, 0.649782195687294)

In [39]:
# Evaluate on the held-out split; weighted_f1 is supplied as an extra metric
# under the keyword name "f1".
# NOTE(review): assumes eval_model calls extra metrics as func(labels, preds) and
# reports them in `result` under the keyword name — confirm against the docs.
result, model_outputs, wrong_predictions = model.eval_model(test,f1=weighted_f1)

 10%|█         | 1/10 [00:05<00:46,  5.21s/it]
Running Evaluation: 100%|██████████| 2/2 [00:01<00:00,  1.11it/s]
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
