In [1]:
from transformers import BertForSequenceClassification, BertTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
from src.data.make_dataset import DataProcessor
from src.models.train_model import train_model, evaluate_model

##### Load the data

In [2]:
# Path to the dataset
file_path = '../data/raw/filtered.tsv'

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

##### Preprocess the data and create data loaders

In [3]:
# Set the size of the dataset
dataset_size = 500000

# Create the data loaders
data_processor = DataProcessor(file_path, tokenizer, dataset_size)
train_loader, val_loader = data_processor.process()

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\sokos\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\sokos\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


##### Train the model

In [4]:
# Initialize the model and optimizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
optimizer = AdamW(model.parameters(), lr=5e-5)

# Initialize the learning rate scheduler
epochs = 10
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_loader)*epochs)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Train the model
train_model(model, train_loader, optimizer, scheduler, epochs)

  0%|          | 0/1407 [00:00<?, ?it/s]

  0%|          | 0/1407 [00:00<?, ?it/s]

##### Evaluate the model

In [6]:
evaluate_model(model, val_loader)

Validation Accuracy: 0.8492
Validation Recall: 0.8687815428983418
Validation Precision: 0.8607142857142858
Validation F1: 0.8647290993900253


(0.8492, 0.8687815428983418, 0.8607142857142858, 0.8647290993900253)