In [1]:
import os
import re
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import bibtexparser
from bibtexparser.bparser import BibTexParser

In [2]:
saved_model_dir = "./SavedModels/"

In [3]:
train_val_dataset = pd.read_csv("./data/title_abstract_train_val_dataset.csv")
eval_dataset = pd.read_csv("./data/title_abstract_eval_dataset.csv")

In [4]:
train_val_dataset.tail(10)

Unnamed: 0,Title_and_Abstract,Accepted_for_Full_Text
990,Guest Editorial: An End-to-End Machine Learnin...,0
991,Guest Editorial Special Issue on Privacy and S...,0
992,Guest Editorial Special Issue on Emerging Tren...,0
993,Guest Editorial Special Issue on Advanced Cogn...,0
994,Enhancing Smart Agriculture Scenarios with Low...,1
995,High Voltage Discharge Exhibits Severe Effect ...,0
996,Heterogeneous GNN-RL-Based Task Offloading for...,1
997,Optimized Data Fusion With Scheduled Rest Peri...,0
998,FarmEdge: A Unified Edge Computing Framework E...,1
999,"5G Network: Architecture, Protocols, Challenge...",0


In [5]:
eval_dataset.head(5)

Unnamed: 0,Title_and_Abstract,Accepted_for_Full_Text
0,Influence of artificial intelligence (AI) on f...,0
1,Software engineering approaches for tinyml bas...,0
2,Artificial intelligence in practice: how 50 su...,0
3,Industry 4.0: Industrial internet of things (I...,0
4,Artificial intelligence and biological misuse:...,0


In [6]:
train_val_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 2 columns):
 #   Column                  Non-Null Count  Dtype 
---  ------                  --------------  ----- 
 0   Title_and_Abstract      996 non-null    object
 1   Accepted_for_Full_Text  1000 non-null   int64 
dtypes: int64(1), object(1)
memory usage: 15.8+ KB


In [7]:
eval_dataset.drop("Accepted_for_Full_Text", axis = 1, inplace=True)

In [8]:
eval_dataset

Unnamed: 0,Title_and_Abstract
0,Influence of artificial intelligence (AI) on f...
1,Software engineering approaches for tinyml bas...
2,Artificial intelligence in practice: how 50 su...
3,Industry 4.0: Industrial internet of things (I...
4,Artificial intelligence and biological misuse:...
...,...
10498,Wireless Sensor Network Based Greenhouse Monit...
10499,agroString 2.0: A Distributed-Ledger based Sma...
10500,Performance of Routing Protocol for Low-Power ...
10501,Churn-Tolerant Leader Election Protocols


In [9]:
train_data, val_data = train_test_split(train_val_dataset, test_size=0.2, random_state=42)
train_data.shape, val_data.shape

((800, 2), (200, 2))

In [10]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [11]:
train_encodings = tokenizer(train_data['Title_and_Abstract'].astype(str).tolist(), truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_data['Title_and_Abstract'].astype(str).tolist(), truncation=True, padding=True, max_length=512)
test_encodings = tokenizer(eval_dataset['Title_and_Abstract'].astype(str).tolist(), truncation=True, padding=True, max_length=512)

In [12]:
class TrainValDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)



class TestDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings['input_ids'])


In [13]:
# Create the PyTorch datasets
train_dataset = TrainValDataset(train_encodings, train_data['Accepted_for_Full_Text'].tolist())
val_dataset = TrainValDataset(val_encodings, val_data['Accepted_for_Full_Text'].tolist())

test_dataset = TestDataset(test_encodings)

In [14]:
# Create DataLoaders for train, validation, and test
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, generator=torch.Generator().manual_seed(42))
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [15]:
def set_seed(seed=42):
    """Set all random seeds to a fixed value and possibly disable nondeterministic algorithms."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [16]:
set_seed()

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

# Choose device available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

optimizer = AdamW(model.parameters(), lr=5e-5)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
def train_and_validate(model, train_dataloader, val_dataloader, optimizer, device, epochs=3, saved_models_path=saved_model_dir):
    """Train and validate the model.
    
    Args:
        model (torch.nn.Module): The model to train and validate.
        train_dataloader (DataLoader): DataLoader for training data.
        val_dataloader (DataLoader): DataLoader for validation data.
        optimizer (torch.optim.Optimizer): Optimizer for the model.
        device (torch.device): Device to run the model computation.
        epochs (int): Number of epochs to train the model.
        saved_models_path (str): Directory path where the model will be saved.
    """
    model = model.to(device)
    best_val_accuracy = 0

    # Get the model name from its class
    model_name = model.__class__.__name__

    total_steps = len(train_dataloader) * epochs
    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    for epoch in range(epochs):
        # Training Phase
        model.train()
        total_loss = 0
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_dataloader)
        print(f"Average training loss: {avg_train_loss:.4f}")

        # Validation Phase
        model.eval()
        total_val_loss = 0
        total_correct = 0
        total_examples = 0

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch + 1}/{epochs} - Validation"):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                total_val_loss += loss.item()

                logits = outputs.logits
                predictions = torch.argmax(logits, dim=1)
                total_correct += (predictions == batch['labels']).sum().item()
                total_examples += batch['labels'].size(0)

        avg_val_loss = total_val_loss / len(val_dataloader)
        val_accuracy = total_correct / total_examples
        print(f"Average validation loss: {avg_val_loss:.4f}")
        print(f"Validation accuracy: {val_accuracy:.4f}")

        # Save the best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Construct the filename for saving the model
            save_path = os.path.join(saved_models_path, f"{model_name}_best_model.pth")
            torch.save(model.state_dict(), save_path)
            print(f"Saved improved model at {save_path}")

In [18]:
train_and_validate(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    device=device,
    epochs=6
)

Epoch 1/6 - Training: 100%|██████████| 50/50 [00:24<00:00,  2.06it/s]


Average training loss: 0.5278


Epoch 1/6 - Validation: 100%|██████████| 13/13 [00:02<00:00,  6.26it/s]


Average validation loss: 0.4467
Validation accuracy: 0.8050
Saved improved model at ./SavedModels/DistilBertForSequenceClassification_best_model.pth


Epoch 2/6 - Training: 100%|██████████| 50/50 [00:24<00:00,  2.06it/s]


Average training loss: 0.3249


Epoch 2/6 - Validation: 100%|██████████| 13/13 [00:02<00:00,  6.20it/s]


Average validation loss: 0.3343
Validation accuracy: 0.8550
Saved improved model at ./SavedModels/DistilBertForSequenceClassification_best_model.pth


Epoch 3/6 - Training: 100%|██████████| 50/50 [00:24<00:00,  2.04it/s]


Average training loss: 0.1727


Epoch 3/6 - Validation: 100%|██████████| 13/13 [00:02<00:00,  6.14it/s]


Average validation loss: 0.4581
Validation accuracy: 0.8450


Epoch 4/6 - Training: 100%|██████████| 50/50 [00:24<00:00,  2.04it/s]


Average training loss: 0.0929


Epoch 4/6 - Validation: 100%|██████████| 13/13 [00:02<00:00,  6.13it/s]


Average validation loss: 0.4306
Validation accuracy: 0.8550


Epoch 5/6 - Training: 100%|██████████| 50/50 [00:24<00:00,  2.03it/s]


Average training loss: 0.0379


Epoch 5/6 - Validation: 100%|██████████| 13/13 [00:02<00:00,  6.13it/s]


Average validation loss: 0.5010
Validation accuracy: 0.8550


Epoch 6/6 - Training: 100%|██████████| 50/50 [00:24<00:00,  2.04it/s]


Average training loss: 0.0203


Epoch 6/6 - Validation: 100%|██████████| 13/13 [00:02<00:00,  6.13it/s]

Average validation loss: 0.5064
Validation accuracy: 0.8500



