In [1]:
import pandas as pd
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
df = pd.read_csv('purposes_added_metadata_v2.csv')

In [4]:
# Label mapping
label_names = [
    "Methodologies or Experimental Designs",
    "Theoretical Frameworks or Models",
    "Research Gaps",
    "Quantitative Data and Analysis",
    "Literature Review",
    "Comparative Works",
    "Exploring Interdisciplinary Connection"
]
num_labels = len(label_names)

In [5]:
# Convert labels to multi-hot encoding
def labels_to_multihot(labels):
    if labels == "[]":  # Unlabelled
        return None
    labels = eval(labels)  # Convert string list to actual list
    multihot = [1 if label_names[i] in labels else 0 for i in range(num_labels)]
    return multihot

In [7]:
df["multihot_labels"] = df["tags"].apply(labels_to_multihot)

In [8]:
labelled_df = df[df["TagCount"] > 0].copy()
unlabelled_df = df[df["TagCount"] == 0].copy()

In [10]:
train_df, val_df = train_test_split(labelled_df, test_size=0.2, random_state=42)
print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test (unlabelled) size: {len(unlabelled_df)}")

Train size: 5508, Validation size: 1377, Test (unlabelled) size: 2433


In [11]:
# Custom Dataset
class ArticleDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len=128):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        abstract = str(self.data.iloc[idx]["abstract"])
        encoding = self.tokenizer(
            abstract,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        
        if self.data.iloc[idx]["TagCount"] > 0:
            labels = torch.tensor(self.data.iloc[idx]["multihot_labels"], dtype=torch.float)
        else:
            labels = torch.zeros(num_labels)  # Placeholder for unlabelled
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

In [12]:
# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [13]:
# Create datasets
train_dataset = ArticleDataset(train_df, tokenizer)
val_dataset = ArticleDataset(val_df, tokenizer)
test_dataset = ArticleDataset(unlabelled_df, tokenizer)

In [14]:
# Create dataloaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [30]:
class MultiLabelBERT(nn.Module):
    def __init__(self, num_labels):
        super(MultiLabelBERT, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(768, num_labels)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # [CLS] token
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [40]:
# Initialise model
model = MultiLabelBERT(num_labels=num_labels).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy for multi-label

In [41]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct_preds = 0
    total_preds = 0
    
    for batch in tqdm(loader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        preds = torch.sigmoid(logits) > 0.5
        correct_preds += torch.sum(preds == labels.bool()).item()
        total_preds += labels.numel()
        
    avg_loss = total_loss / len(loader)
    accuracy = correct_preds / total_preds
    return avg_loss, accuracy

def eval_epoch(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct_preds = 0
    total_preds = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            
            total_loss += loss.item()
            
            preds = torch.sigmoid(logits) > 0.5
            correct_preds += torch.sum(preds == labels.bool()).item()
            total_preds += labels.numel()
    
    avg_loss = total_loss / len(loader)
    accuracy = correct_preds / total_preds
    return avg_loss, accuracy

In [42]:
epochs = 20
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = eval_epoch(model, val_loader, criterion)
    
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")


Epoch 1/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.11it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.83it/s]


Train Loss: 0.4385, Train Acc: 0.8018
Val Loss: 0.3988, Val Acc: 0.8283

Epoch 2/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.88it/s]


Train Loss: 0.3746, Train Acc: 0.8379
Val Loss: 0.3769, Val Acc: 0.8370

Epoch 3/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.08it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.90it/s]


Train Loss: 0.3308, Train Acc: 0.8625
Val Loss: 0.3670, Val Acc: 0.8431

Epoch 4/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.08it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.96it/s]


Train Loss: 0.2762, Train Acc: 0.8906
Val Loss: 0.3794, Val Acc: 0.8450

Epoch 5/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.92it/s]


Train Loss: 0.2154, Train Acc: 0.9216
Val Loss: 0.4054, Val Acc: 0.8393

Epoch 6/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.10it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.92it/s]


Train Loss: 0.1591, Train Acc: 0.9464
Val Loss: 0.4474, Val Acc: 0.8456

Epoch 7/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.88it/s]


Train Loss: 0.1181, Train Acc: 0.9632
Val Loss: 0.4842, Val Acc: 0.8441

Epoch 8/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.95it/s]


Train Loss: 0.0889, Train Acc: 0.9730
Val Loss: 0.5350, Val Acc: 0.8394

Epoch 9/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.08it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.98it/s]


Train Loss: 0.0620, Train Acc: 0.9835
Val Loss: 0.5680, Val Acc: 0.8383

Epoch 10/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.97it/s]


Train Loss: 0.0473, Train Acc: 0.9881
Val Loss: 0.5875, Val Acc: 0.8424

Epoch 11/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.94it/s]


Train Loss: 0.0359, Train Acc: 0.9918
Val Loss: 0.6241, Val Acc: 0.8429

Epoch 12/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.09it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.99it/s]


Train Loss: 0.0286, Train Acc: 0.9938
Val Loss: 0.6467, Val Acc: 0.8417

Epoch 13/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.11it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.63it/s]


Train Loss: 0.0261, Train Acc: 0.9940
Val Loss: 0.6736, Val Acc: 0.8400

Epoch 14/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.05it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.61it/s]


Train Loss: 0.0220, Train Acc: 0.9949
Val Loss: 0.6945, Val Acc: 0.8414

Epoch 15/20


Training: 100%|██████████| 345/345 [00:48<00:00,  7.08it/s]
Validation: 100%|██████████| 87/87 [00:07<00:00, 11.92it/s]


Train Loss: 0.0184, Train Acc: 0.9958
Val Loss: 0.7059, Val Acc: 0.8404

Epoch 16/20


Training: 100%|██████████| 345/345 [00:59<00:00,  5.79it/s]
Validation: 100%|██████████| 87/87 [00:11<00:00,  7.77it/s]


Train Loss: 0.0162, Train Acc: 0.9961
Val Loss: 0.7363, Val Acc: 0.8393

Epoch 17/20


Training: 100%|██████████| 345/345 [01:24<00:00,  4.08it/s]
Validation: 100%|██████████| 87/87 [00:11<00:00,  7.90it/s]


Train Loss: 0.0146, Train Acc: 0.9964
Val Loss: 0.7606, Val Acc: 0.8402

Epoch 18/20


Training: 100%|██████████| 345/345 [01:24<00:00,  4.10it/s]
Validation: 100%|██████████| 87/87 [00:11<00:00,  7.86it/s]


Train Loss: 0.0156, Train Acc: 0.9962
Val Loss: 0.7806, Val Acc: 0.8412

Epoch 19/20


Training: 100%|██████████| 345/345 [01:24<00:00,  4.10it/s]
Validation: 100%|██████████| 87/87 [00:11<00:00,  7.86it/s]


Train Loss: 0.0174, Train Acc: 0.9949
Val Loss: 0.8016, Val Acc: 0.8403

Epoch 20/20


Training: 100%|██████████| 345/345 [01:23<00:00,  4.11it/s]
Validation: 100%|██████████| 87/87 [00:10<00:00,  7.91it/s]

Train Loss: 0.0156, Train Acc: 0.9956
Val Loss: 0.8310, Val Acc: 0.8375





In [43]:
def predict(model, loader):
    model.eval()
    predictions = []
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            
            logits = model(input_ids, attention_mask)
            preds = torch.sigmoid(logits) > 0.5
            predictions.extend(preds.cpu().numpy())
    
    return predictions

In [44]:
# Get predictions for unlabelled data
test_preds = predict(model, test_loader)

# Convert predictions to label names
predicted_labels = []
for pred in test_preds:
    labels = [label_names[i] for i in range(num_labels) if pred[i]]
    predicted_labels.append(labels)

# Add predictions to unlabelled_df
unlabelled_df["predicted_labels"] = predicted_labels
print(unlabelled_df[["abstract", "predicted_labels"]].head())

Predicting: 100%|██████████| 153/153 [00:16<00:00,  9.36it/s]

                                               abstract  \
1596  Background: Patients with pre-existing cirrhos...   
1600  The AAD COVID-19 Registry: Crowdsourcing Derma...   
1605  Abstract We present a case of coronavirus dise...   
1607  Background Coronavirus infectious disease 2019...   
1635           Echocardiography in the Time of COVID-19   

                                       predicted_labels  
1596                   [Quantitative Data and Analysis]  
1600                   [Quantitative Data and Analysis]  
1605                                                 []  
1607                   [Quantitative Data and Analysis]  
1635  [Theoretical Frameworks or Models, Comparative...  





In [45]:
unlabelled_df

Unnamed: 0,cord_uid,sha,source_x,title,doi,pmcid,pubmed_id,license,abstract,publish_time,...,pmc_json_files,url,s2_id,referenced_by_count,JournalName_DOI,tags,TagCount,multihot_labels,predicted_labels,predicted_count
1596,4i02f1ji,a08d4b749b4d5b73acfa15344195a097a9699c96,MedRxiv,Clinical course and risk factors for mortality...,10.1101/2020.04.24.20072611,,,medrxiv,Background: Patients with pre-existing cirrhos...,2020-04-28,...,,https://doi.org/10.1101/2020.04.24.20072611,216653286.0,0,,[],0,,[Quantitative Data and Analysis],1
1600,3q6idrmz,b2aef9bcd1a7b70237ed02ef4004c94f86017151,Elsevier; Medline; PMC,The AAD COVID-19 Registry: Crowdsourcing Derma...,10.1016/j.jaad.2020.04.045,PMC7162762,32305438.0,els-covid,The AAD COVID-19 Registry: Crowdsourcing Derma...,2020-04-17,...,,https://www.sciencedirect.com/science/article/...,215790286.0,54,Journal of the American Academy of Dermatology,[],0,,[Quantitative Data and Analysis],1
1605,ys7z7j8j,ff2e6962ae80b7f17be32dd59513bdeb3a2d6e92,Elsevier; Medline; PMC,Keratoconjunctivitis as the initial medical pr...,10.1016/j.jcjo.2020.03.003,PMC7124283,32284146.0,els-covid,Abstract We present a case of coronavirus dise...,2020-04-02,...,document_parses/pmc_json/PMC7124283.xml.json,https://api.elsevier.com/content/article/pii/S...,214758418.0,219,Canadian Journal of Ophthalmology,[],0,,[],0
1607,tgpqnoq2,129a24f0574121f32e5bb5bb3495d4ab684e79b6,MedRxiv,Heparin-induced thrombocytopenia is associated...,10.1101/2020.04.23.20076851,,,medrxiv,Background Coronavirus infectious disease 2019...,2020-04-28,...,,http://medrxiv.org/cgi/content/short/2020.04.2...,216588708.0,42,,[],0,,[Quantitative Data and Analysis],1
1635,bvdyhq6v,e1000a407a8d3581502550fff91c69ffed0f760e,Elsevier; Medline; PMC,Echocardiography in the Time of COVID-19,10.1016/j.echo.2020.04.011,PMC7146691,32503705.0,els-covid,Echocardiography in the Time of COVID-19,2020-04-10,...,,https://www.sciencedirect.com/science/article/...,215550941.0,33,Journal of the American Society of Echocardiog...,[],0,,"[Theoretical Frameworks or Models, Comparative...",1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9303,kb8dz8hd,,Elsevier; Medline; PMC; WHO,Is it super-spreading,10.1016/s0262-4079(20)30375-4,PMC7130545,32287797.0,els-covid,If the covid-19 virus is transmitted largely b...,2020-02-22,...,,https://api.elsevier.com/content/article/pii/S...,213341263.0,4,New Scientist,[],0,,[Theoretical Frameworks or Models],0
9304,4ikzvnmv,f85a9161da699b741dea23835c60dfbff51ce959,Medline; PMC,Direct Diagnostic Tests for Lyme Disease,10.1093/cid/ciy614,PMC6399434,30307486.0,green-oa,Borrelia burgdorferi was discovered to be the ...,2018-10-11,...,document_parses/pmc_json/PMC6399434.xml.json,https://doi.org/10.1093/cid/ciy614; https://ww...,52957133.0,68,Clinical Infectious Diseases,[],0,,[],2
9306,4zg3ms5b,d69fafa5d266102bd2666aac853ef94472ba0a8e,Elsevier; Medline; PMC,Outbreak of a new coronavirus: what anaestheti...,10.1016/j.bja.2020.02.008,PMC7124191,32115186.0,els-covid,Outbreak of a new coronavirus: what anaestheti...,2020-02-27,...,document_parses/pmc_json/PMC7124191.xml.json,https://doi.org/10.1016/j.bja.2020.02.008; htt...,211726841.0,238,British Journal of Anaesthesia,[],0,,[],0
9308,zsfylxav,f093aa0cf1ddef303ec4c049ab0be105587dd92c,Medline; PMC,Mapping Climate Change Vulnerabilities to Infe...,10.1289/ehp.1103805,PMC3295348,22113877.0,cc0,"Background: The incidence, outbreak frequency,...",2011-11-23,...,document_parses/pmc_json/PMC3295348.xml.json,https://www.ncbi.nlm.nih.gov/pubmed/22113877/;...,15853141.0,100,Environmental Health Perspectives,[],0,,"[Methodologies or Experimental Designs, Quanti...",1


In [46]:
unlabelled_df['predicted_count'] = unlabelled_df['predicted_labels'].apply(len)

In [47]:
frequency = unlabelled_df['predicted_count'].value_counts()
frequency

predicted_count
1    1484
0     538
2     365
3      46
Name: count, dtype: int64

In [48]:
from collections import Counter

In [49]:
all_tags = [tag for tags_list in unlabelled_df['predicted_labels'] for tag in tags_list]

tag_frequency = Counter(all_tags)

In [50]:
label_frequency_df = pd.DataFrame(tag_frequency.items(), columns=['Predicted_Labels', 'Frequency']).sort_values(by='Frequency', ascending=False)
label_frequency_df

Unnamed: 0,Predicted_Labels,Frequency
0,Quantitative Data and Analysis,863
4,Methodologies or Experimental Designs,700
1,Theoretical Frameworks or Models,303
3,Research Gaps,179
2,Comparative Works,157
5,Literature Review,137
6,Exploring Interdisciplinary Connection,13


In [51]:
unlabelled_df.to_csv("predicted_purposes.csv", index=False)