In [2]:
import torch
import numpy as np
from transformers import BertTokenizer
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm
Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [28]:
#Read the synthetically create data into a pandas dataframe
df = pd.read_csv("./clare2.csv")

In [29]:
df

Unnamed: 0,question,label
0,How do you cope with feelings of anxiety or pa...,1
1,What strategies do you use to manage symptoms ...,1
2,Can you share some techniques for improving se...,1
3,How do you recognize and address triggers that...,1
4,What are some effective ways to communicate wi...,1
...,...,...
463,What's your favorite way to celebrate a person...,5
464,"If you could have any exotic animal as a pet, ...",5
465,Share a moment when you felt completely at pea...,5
466,What's the most memorable gift you've ever giv...,5


In [5]:
#Create a BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
labels = {1:0,
          2:1,
          3:2,
          4:3,
          5:4
          }
#Create the Dataset class
class Dataset(torch.utils.data.Dataset):

    def __init__(self, df):

        self.labels = [labels[label] for label in df['label']]
        self.texts = [tokenizer(text, 
                               padding='max_length', max_length = 512, truncation=True,
                                return_tensors="pt") for text in df['question']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [6]:
#split the data into train,val and test.
np.random.seed(112)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
                                     [int(.8*len(df)), int(.9*len(df))])

print(len(df_train),len(df_val), len(df_test))

374 47 47


  return bound(*args, **kwds)


In [13]:
#Create a BERT classifier with a bert-base-cased, and a linear layer actiting as a classification head. A dropout layer is added to avoid overfitting.
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):

    def __init__(self, dropout=0.5):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

In [18]:
from torch.optim import Adam
from tqdm import tqdm
#Function to fine-tune and validate the model performance
def train(model, train_data, val_data, learning_rate, epochs):

    train, val = Dataset(train_data), Dataset(val_data)

    train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr= learning_rate)

    if use_cuda:

            model = model.cuda()
            criterion = criterion.cuda()

    for epoch_num in range(epochs):

            total_acc_train = 0
            total_loss_train = 0

            for train_input, train_label in tqdm(train_dataloader):

                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask)
                
                batch_loss = criterion(output, train_label.long())
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc

                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
            
            total_acc_val = 0
            total_loss_val = 0

            with torch.no_grad():

                for val_input, val_label in val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label.long())
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
                | Train Accuracy: {total_acc_train / len(train_data): .3f} \
                | Val Loss: {total_loss_val / len(val_data): .3f} \
                | Val Accuracy: {total_acc_val / len(val_data): .3f}')

In [19]:
#I trained the Model on colab, using a A100 to save time.
EPOCHS = 10
model = BertClassifier()
LR = 1e-6
              
train(model, df_train, df_val, LR, EPOCHS)

100%|██████████| 187/187 [11:09<00:00,  3.58s/it]


Epochs: 1 | Train Loss:  0.812                 | Train Accuracy:  0.193                 | Val Loss:  0.777                 | Val Accuracy:  0.404


In [9]:
#Load the trained model
def load_checkpoint(load_path, model):
    
    if load_path==None:
        return
    
    state_dict = torch.load(load_path, map_location="cpu")
    print(f'Model loaded from <== {load_path}')
    
    model.load_state_dict(state_dict['model_state_dict'])
    return model

In [15]:
model2 = load_checkpoint("model.pt",model)

Model loaded from <== model.pt


In [16]:
#function to evaluate the model performance.
def evaluate(model, test_data):
    
    test = Dataset(test_data)

    test_dataloader = torch.utils.data.DataLoader(test, batch_size=1)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    outlier_list = []
    y_pred = []

    if use_cuda:
        model = model.cuda()

    total_acc_test = 0
    with torch.no_grad():
        id = 0
        for test_input, test_label in test_dataloader:

              test_label = test_label.to(device)
              mask = test_input['attention_mask'].to(device)
              input_id = test_input['input_ids'].squeeze(1).to(device)
              
              output = model(input_id, mask)
              if (output.argmax(dim=1)!=test_label):
                outlier_list.append(id)
              y_pred.append(output.argmax(dim=1))
                         

              acc = (output.argmax(dim=1) == test_label).sum().item()
              total_acc_test += acc
              id = id+1
    
    print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')
    return outlier_list,y_pred

In [17]:
outlier_list,y_pred = evaluate(model2, df_test)

Test Accuracy:  1.000


In [26]:
df_test.iloc[2]['question']

'How do I know if my suicidal thoughts are a result of a mental health condition?'

In [27]:
df_test.iloc[2]

question    How do I know if my suicidal thoughts are a re...
label                                                       4
Name: 319, dtype: object