In [5]:
#imports
import numpy as np
import pandas as pd

#torch imports
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from early_stopping_pytorch import EarlyStopping


#transformer imports
from transformers import DistilBertModel, DistilBertTokenizer
from transformers import DataCollatorWithPadding
from transformers import DistilBertModel
from transformers import pipeline

#sklearn imports
from sklearn.preprocessing import OneHotEncoder

#plotting
import matplotlib.pyplot as plt

#misc. imports
import sys
import os
from tqdm import tqdm
import time
import logging

#set up logging config
# Root logger: DEBUG level, timestamped records, written to a stream handler.
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger before applying this configuration. Without it, basicConfig() is a
# silent no-op whenever the root logger is already configured, so re-running
# this cell after editing the level or format would change nothing until the
# kernel is restarted.
# NOTE: DEBUG on the root logger also surfaces debug records from third-party
# libraries that propagate to it (e.g. the HTTPS connection lines seen when
# the tokenizer is fetched).
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
    force=True,
)

In [11]:
# Add project root to sys.path
# A notebook has no __file__, so the kernel's current working directory stands
# in for "the directory this notebook lives in". (The previous
# os.path.abspath("__main__") resolved the literal string "__main__" against
# the working directory and then stripped it again with dirname, which yields
# the same directory in a roundabout way.)
script_dir = os.getcwd()
# The project root is taken to be the parent of that directory.
project_root = os.path.dirname(script_dir)
# Register the path only once so re-running this cell does not pile up
# duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

#data loader imports
from Dataset.data_loaders import create_dataloaders

2025-02-25 18:10:35,307 - INFO - PyTorch version 2.1.0.dev20230610 available.


In [12]:
# Build the data loaders: create_dataloaders (imported from
# Dataset.data_loaders) is asked for a batch size of 64 and its result is
# unpacked into the train, validation and test loaders, in that order.
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
) = create_dataloaders(batch_size=64)

2025-02-25 18:12:30,075 - INFO - Loading tokenizer...

2025-02-25 18:12:30,080 - DEBUG - Starting new HTTPS connection (1): huggingface.co:443
2025-02-25 18:12:30,351 - DEBUG - https://huggingface.co:443 "HEAD /distilbert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2025-02-25 18:12:30,424 - INFO - Successfully loaded tokenizer.

2025-02-25 18:12:30,425 - INFO - Loading Datasets...

2025-02-25 18:12:30,426 - INFO - Checking Split Choice: Train...
2025-02-25 18:12:30,427 - INFO - Selected valid split.

2025-02-25 18:12:30,430 - INFO - Attempting to load data

2025-02-25 18:12:30,454 - INFO - Successfully Loaded Dataset.

2025-02-25 18:12:30,455 - INFO - Checking Split Choice: Valid...
2025-02-25 18:12:30,456 - INFO - Selected valid split.

2025-02-25 18:12:30,457 - INFO - Attempting to load data

2025-02-25 18:12:30,491 - INFO - Successfully Loaded Dataset.

2025-02-25 18:12:30,492 - INFO - Checking Split Choice: Test...
2025-02-25 18:12:30,493 - INFO - Selected valid

In [13]:
class SimpleClassifierModel(nn.Module):
    """Classification head stacked on top of a pretrained encoder.

    The encoder's last hidden state at sequence position 0 is fed through two
    ``Linear -> BatchNorm1d -> LeakyReLU -> Dropout`` blocks (768 and 256
    units) followed by a final ``Linear`` layer with ``output_dim`` outputs.
    No softmax is applied; the raw output of the last linear layer is returned.

    Args:
        simplified_model: Encoder module. It must expose
            ``config.hidden_size`` and, when called with ``input_ids`` and
            ``attention_mask`` keyword arguments, return an object with a
            ``last_hidden_state`` attribute indexable as ``[:, 0, :]``.
        output_dim: Number of output features of the final linear layer.
        dropout_rate: Drop probability used by the shared dropout layer.
    """

    def __init__(self, simplified_model, output_dim=3, dropout_rate=0.3):
        super().__init__()

        self.pretrained = simplified_model
        self.activation = nn.LeakyReLU()
        self.dropout_layer = nn.Dropout(dropout_rate)

        self.linear_1 = nn.Linear(self.pretrained.config.hidden_size, 768)
        self.linear_2 = nn.Linear(768, 256)
        self.linear_3 = nn.Linear(256, output_dim)

        # NOTE: despite the "layer_norm" names these are BatchNorm1d layers;
        # the attribute names are kept so existing state_dict keys still load.
        # Each norm is sized from the linear layer whose output it normalises.
        # Sizing the first one from the encoder's hidden size (as before) only
        # works when that hidden size happens to be 768.
        self.layer_norm_1 = nn.BatchNorm1d(self.linear_1.out_features)
        self.layer_norm_2 = nn.BatchNorm1d(self.linear_2.out_features)

    def forward(self, x):
        """Run the encoder and the classification head.

        Args:
            x: Mapping with ``"input_ids"`` and ``"attention_mask"`` entries,
                forwarded unchanged to the encoder.

        Returns:
            Output of the final linear layer (last dimension ``output_dim``).
        """
        # Pretrained model output: take the last hidden state at sequence
        # position 0 (presumably the [CLS] token for a BERT-style encoder;
        # confirm against the tokenizer in use).
        encoder_out = self.pretrained(
            input_ids=x["input_ids"],
            attention_mask=x["attention_mask"],
        )
        pooled = encoder_out.last_hidden_state[:, 0, :]

        # first fully connected block
        hidden = self.linear_1(pooled)
        hidden = self.layer_norm_1(hidden)
        hidden = self.activation(hidden)
        hidden = self.dropout_layer(hidden)

        # second block
        hidden = self.linear_2(hidden)
        hidden = self.layer_norm_2(hidden)
        hidden = self.activation(hidden)
        hidden = self.dropout_layer(hidden)

        # third block: final projection to output_dim
        return self.linear_3(hidden)