In [1]:
import pandas as pd
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer
from torch import cuda

# Verify if cuda is enabled
device = 'cuda' if cuda.is_available() else 'cpu'
print(f'Runs on {device}')

Runs on cpu


In [2]:
MAX_LEN = 512 # Max tokens
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 20
EPOCHS = 2
LEARNING_RATE = 1e-05
THRESHOLD = 0.5
TRAIN_BACKBONE = True # Specify if we train the backbone (DistilBERT) and the head or only the head

In [3]:
# Init wandb

import wandb

wandb.login()
wandb.init(
    # set the wandb project where this run will be logged
    project="wat-distilbert",
    
    # track hyperparameters and run metadata
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "distilbert",
        "epochs": EPOCHS,
        "train_distilbert": TRAIN_BACKBONE,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: maherwshwshny (mlfortm). Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

In [5]:
id2label = {0: "Irrelevant", 1: "Relevant"}
label2id = {"Irrelevant": 0, "Relevant": 1}

df = pd.read_csv('C:\\Users\\DELL 5540\\Desktop\\few-shot-learning-LLM-for-irrelevance-classification-main\\data\\WaTA_dataset.csv', encoding = "ISO-8859-1")
df['Class'] = df['Class'].apply(label2id.get)
df.head(10)

Unnamed: 0,Sentence,Class
0,The party sends a warrant possession request a...,1
1,The Client Service Back Office as part of the ...,1
2,Then the SCT Warrant Possession is forwarded t...,1
3,The SCT physical file is stored by the Back Of...,1
4,When the report is received the respective SCT...,1
5,Then Back Office attaches the new SCT document...,1
6,After that some other MC internal staff receiv...,1
7,As a basic principle ACME AG receives invoices...,0
8,These are received by the Secretariat in the c...,1
9,In ACME Financial Accounting a software specia...,1


In [6]:
# Check dataset balance between relevant and irrelevant sentences

def verify_data_balance(df):
    count = df['Class'].value_counts()

    print(f'Number of irrelevant sentences: {count[0]}')
    print(f'Number of relevant sentences: {count[1]}')
    print(f'Percentage of irrelevant: {count[0] / (count[0] + count[1])}')
    print(f'Percentage of relevant: {count[1] / (count[0] + count[1])}')
    
verify_data_balance(df)

Number of irrelevant sentences: 6071
Number of relevant sentences: 19040
Percentage of irrelevant: 0.24176655648918802
Percentage of relevant: 0.758233443510812


In [7]:
# Define the torch dataset

def tokenize(sentence, tokenizer):
    inputs = tokenizer.encode_plus(
        sentence,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True
    )
    ids = inputs['input_ids']
    mask = inputs['attention_mask']
    return torch.tensor(ids, dtype=torch.long), torch.tensor(mask, dtype=torch.long)

class RelevanceDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.len = len(data)
        self.data = data
        self.tokenizer = tokenizer
        
    def __getitem__(self, index):
        sentence = self.data.Sentence[index]
        
        ids, mask = tokenize(sentence, self.tokenizer)

        return {
            'ids': ids,
            'mask': mask,
            'sentence': sentence,
            'targets': torch.tensor(self.data.Class[index], dtype=torch.float)
        } 
    
    def __len__(self):
        return self.len

In [8]:
# Split dataset into training and test set and instantiate datasets for torch

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

train_size = 0.9
train_dataset = df.sample(frac=train_size,random_state=200)
test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print('Train data balance:')
verify_data_balance(train_dataset)
print('Test data balance:')
verify_data_balance(test_dataset)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = RelevanceDataset(train_dataset, tokenizer)
testing_set = RelevanceDataset(test_dataset, tokenizer)



Train data balance:
Number of irrelevant sentences: 5462
Number of relevant sentences: 17138
Percentage of irrelevant: 0.24168141592920353
Percentage of relevant: 0.7583185840707964
Test data balance:
Number of irrelevant sentences: 609
Number of relevant sentences: 1902
Percentage of irrelevant: 0.24253285543608125
Percentage of relevant: 0.7574671445639187
FULL Dataset: (25111, 2)
TRAIN Dataset: (22600, 2)
TEST Dataset: (2511, 2)


In [9]:
# Create data loaders, one for training and another for testing

train_params = {
    'batch_size': TRAIN_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0
}

test_params = {
    'batch_size': VALID_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0
}

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

In [10]:
# Define the binary classification model to discriminate between relevant or irrelevant sentences.
# The model uses DistilBERT as a backbone and a binary classification head

class BinaryClassifier(torch.nn.Module):
    def __init__(self):
        super(BinaryClassifier, self).__init__()
        self.backbone = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head = torch.nn.Sequential(
            torch.nn.Linear(768, 768),
            torch.nn.ReLU(),
            torch.nn.Linear(768, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask):
        backbone_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = backbone_out[0]
        return self.head(hidden_state[:, 0])
    
    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
            
    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

In [11]:
# Instantiate classification model and push it to the GPU

model = BinaryClassifier()
model.to(device)



BinaryClassifier(
  (backbone): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Li

In [None]:
# Run this cell if you want to use pretrained classification model
# saved_model = wandb.restore('classifier.bin', run_path="wat-distilbert/ru3wl9xi")8p404q8d
saved_model = wandb.restore('classifier.bin', run_path="wat-distilbert/a1ibmen6") # Run comic-surf-18
#saved_tokenizer = wandb.restore('tokenizer.bin', run_path="wat-distilbert/ru3wl9xi")
model = torch.load(saved_model.name)
model.to(device)
#tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
#loaded_tokenizer = loaded_tokenizer.load_vocabulary(saved_tokenizer.name)

In [13]:
# Define loss and optimizer
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [14]:
# Define metrics

def compute_accuracy(tp, tn, fp, fn):
    return (tp + tn) / (tp + tn + fp + fn)

def compute_precision(tp, fp):
    return tp / (tp + fp)

def compute_recall(tp, fn):
    return tp / (tp + fn)

def compute_f1(tp, fn, fp):
    return tp / (tp + (fn + fp) / 2)

def pred_to_class(pred, threshold=0.5):
    return (pred >= THRESHOLD).float() 

In [15]:
# Define method to visualize the results

def test_example(model, tokenizer, nb_relevant=20, nb_irrelevant=20):
    for i in range(len(testing_set)):
        test_data = testing_set[i]
        ids, mask, sentence, target = test_data['ids'], test_data['mask'], test_data['sentence'], test_data["targets"]
        if target.item() == 1:
            if nb_relevant > 0:
                nb_relevant -= 1
            else:
                continue
        elif target.item() == 0:
            if nb_irrelevant > 0:
                nb_irrelevant -= 1
            else:
                continue
        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        outputs = model(ids, mask)
        pred_class = pred_to_class(outputs, THRESHOLD)
        print(sentence)
        print(f"pred: {id2label[pred_class.item()]}, target: {id2label[target.item()]}")
    
test_example(model, tokenizer)

The SCT physical file is stored by the Back Office awaiting a report to be sent by the Police
pred: Irrelevant, target: Relevant
Then Back Office attaches the new SCT document and stores the expanded SCT physical file
pred: Irrelevant, target: Relevant
Based on the statements of the cost center managers she will proceed with the clarification with the vendor but if necessary she consults the cost center managers by telephone or e mail again
pred: Irrelevant, target: Relevant
When all inconsistencies are resolved the copy of the invoice is sent to the cost center managers again and the process continues
pred: Irrelevant, target: Relevant
Once the requirement is registered the request is received by the immediate supervisor of the employee requesting the vacation
pred: Irrelevant, target: Relevant
If the request is asked to make a change then it is returned to the petitioner employee who can review the comments for the change request
pred: Irrelevant, target: Relevant
Here the informatio

In [16]:
# EVALUATION

def evaluate():
    print("=========EVAL=========")
    model.eval()
    tp, fp, fn, tn = 0, 0, 0, 0
    total_loss = 0
    nb_steps = 0
    with torch.no_grad():
        for i, data in enumerate(testing_loader):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float).reshape(-1, 1)

            outputs = model(ids, mask)
            loss = loss_function(outputs, targets)
            total_loss += loss.item()

            pred_class = pred_to_class(outputs, THRESHOLD)
            t_preds, t_targets = pred_class == 1, targets == 1
            f_preds, f_targets = pred_class == 0, targets == 0
            tp += (t_preds & t_targets).sum().item()
            fp += (t_preds & f_targets).sum().item()
            fn += (f_preds & t_targets).sum().item()
            tn += (f_preds & f_targets).sum().item()

            nb_steps += 1

            if i % 100 == 0:
                print(f"Validation Loss per 100 steps: {loss.item()}")
                print(f"Validation Accuracy per 100 steps: {compute_accuracy(tp, tn, fp, fn)}")
                print(f"Validation Precision per 100 steps: {compute_precision(tp, fp)}")
                print(f"Validation Recall per 100 steps: {compute_recall(tp, fn)}")
                print(f"Validation f1 per 500 steps: {compute_f1(tp, fn, fp)}")

    avg_loss = total_loss / nb_steps
    accuracy = compute_accuracy(tp, tn, fp, fn)
    precision = compute_precision(tp, fp)
    recall = compute_recall(tp, fn)
    f1 = compute_f1(tp, fn, fp)
    wandb.log({"eval_loss": avg_loss, "eval_accuracy": accuracy, "eval_precision": precision, "eval_recall": recall, "eval_f1": f1})
    print(f"Avg loss: {avg_loss}")
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"f1: {f1}")
    
    print("=========END EVAL=========")

In [17]:
# TRAINING

if TRAIN_BACKBONE:
    model.unfreeze_backbone()

for i in range(EPOCHS):
    total_loss = 0
    nb_steps = 0
    tp, fp, fn, tn = 0, 0, 0, 0 
    model.train()
    for j,data in enumerate(training_loader):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float).reshape(-1, 1)

        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        pred_class = pred_to_class(outputs, THRESHOLD)
        t_preds, t_targets = pred_class == 1, targets == 1
        f_preds, f_targets = pred_class == 0, targets == 0
        tp += (t_preds & t_targets).sum().item()
        fp += (t_preds & f_targets).sum().item()
        fn += (f_preds & t_targets).sum().item()
        tn += (f_preds & f_targets).sum().item()
        
        nb_steps += 1
        
        if j > 0 and j % 500 == 0:
            print(f"Training Loss per 500 steps: {loss.item()}")
            print(f"Training Accuracy per 500 steps: {compute_accuracy(tp, tn, fp, fn)}")
            print(f"Training Precision per 500 steps: {compute_precision(tp, fp)}")
            print(f"Training Recall per 500 steps: {compute_recall(tp, fn)}")
            print(f"Training f1 per 500 steps: {compute_f1(tp, fn, fp)}")
    
    avg_loss = total_loss / nb_steps
    accuracy = compute_accuracy(tp, tn, fp, fn)
    precision = compute_precision(tp, fp)
    recall = compute_recall(tp, fn)
    f1 = compute_f1(tp, fn, fp)
    wandb.log({"loss": avg_loss, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1})
    print(f"Epoch {i}, avg loss: {avg_loss}")
    print(f"Epoch {i}, accuracy: {accuracy}")
    print(f"Epoch {i}, precision: {precision}")
    print(f"Epoch {i}, recall: {recall}")
    print(f"Epoch {i}, f1: {f1}")
    evaluate()
    
    
    test_example(model, tokenizer)

Training Loss per 500 steps: 0.15832293033599854
Training Accuracy per 500 steps: 0.7592315369261478
Training Precision per 500 steps: 0.7655154505323293
Training Recall per 500 steps: 0.9794019933554817
Training f1 per 500 steps: 0.8593499489870281
Training Loss per 500 steps: 0.37487879395484924
Training Accuracy per 500 steps: 0.7667332667332667
Training Precision per 500 steps: 0.7832100794302931
Training Recall per 500 steps: 0.9525316455696202
Training f1 per 500 steps: 0.8596122050202916
Training Loss per 500 steps: 0.6147935390472412
Training Accuracy per 500 steps: 0.780896069287142
Training Precision per 500 steps: 0.8003163673583326
Training Recall per 500 steps: 0.9466211754347348
Training f1 per 500 steps: 0.8673423082740886
Training Loss per 500 steps: 0.6158429980278015
Training Accuracy per 500 steps: 0.7905422288855573
Training Precision per 500 steps: 0.8124510292755894
Training Recall per 500 steps: 0.9406234537357744
Training f1 per 500 steps: 0.8718517103000191
Tra

In [18]:
# Evaluate current model
evaluate()

Validation Loss per 100 steps: 0.48011723160743713
Validation Accuracy per 100 steps: 0.8
Validation Precision per 100 steps: 0.8
Validation Recall per 100 steps: 0.9230769230769231
Validation f1 per 500 steps: 0.8571428571428571
Validation Loss per 100 steps: 0.1415138840675354
Validation Accuracy per 100 steps: 0.8524752475247525
Validation Precision per 100 steps: 0.9096306068601583
Validation Recall per 100 steps: 0.8954545454545455
Validation f1 per 500 steps: 0.9024869109947644
Avg loss: 0.3235193288752011
Accuracy: 0.8554360812425329
Precision: 0.9060686015831134
Recall: 0.9027339642481599
f1: 0.9043982091124572


In [19]:
# Save model to wandb

import os

out_model = os.path.join(wandb.run.dir, "classifier.bin")
out_tokenizer = os.path.join(wandb.run.dir, "tokenizer.bin")

torch.save(model, out_model)
tokenizer.save_vocabulary(out_tokenizer)

wandb.save('classifier.bin')
wandb.save('tokenizer.bin')

out_model_pt = os.path.join(wandb.run.dir, "classifier.pt")
torch.save(model.state_dict(), out_model_pt)

wandb.save('classifier.pt')

['c:\\Users\\DELL 5540\\Desktop\\few-shot-learning-LLM-for-irrelevance-classification-main\\wandb\\run-20240516_145006-hdbgjret\\files\\classifier.pt']

In [20]:
test_example(model, tokenizer, 8, 8)

The SCT physical file is stored by the Back Office awaiting a report to be sent by the Police
pred: Relevant, target: Relevant
Then Back Office attaches the new SCT document and stores the expanded SCT physical file
pred: Relevant, target: Relevant
Based on the statements of the cost center managers she will proceed with the clarification with the vendor but if necessary she consults the cost center managers by telephone or e mail again
pred: Relevant, target: Relevant
When all inconsistencies are resolved the copy of the invoice is sent to the cost center managers again and the process continues
pred: Relevant, target: Relevant
Once the requirement is registered the request is received by the immediate supervisor of the employee requesting the vacation
pred: Relevant, target: Relevant
If the request is asked to make a change then it is returned to the petitioner employee who can review the comments for the change request
pred: Relevant, target: Relevant
Here the information on the for

In [21]:
wandb.finish()

VBox(children=(Label(value='3.857 MB of 511.150 MB uploaded\r'), FloatProgress(value=0.007546268960314534, max…

0,1
accuracy,▁█
eval_accuracy,▁██
eval_f1,▁██
eval_loss,█▁▁
eval_precision,▁██
eval_recall,█▁▁
f1,▁█
loss,█▁
precision,▁█
recall,█▁

0,1
accuracy,0.85898
eval_accuracy,0.85544
eval_f1,0.9044
eval_loss,0.32352
eval_precision,0.90607
eval_recall,0.90273
f1,0.90891
loss,0.31651
precision,0.89081
recall,0.92776


# Evaluation with different thresholds

| THRESHOLD | RUN   | Avg Loss               | Accuracy               | Precision             | Recall                | F1                    |
|-----------|-------|------------------------|------------------------|-----------------------|-----------------------|-----------------------|
| 0.5       | run-1 | 0.7038330654244102     | 0.8570290720828355     | 0.8728854519091348    | 0.9495268138801262    | 0.9095945605640896    |
| 0.5       | run-2 | 0.831273452559438      | 0.8410991636798089     | 0.9219539584503088    | 0.8633017875920084    | 0.8916644040184633    |
| 0.5       | run-3 | 0.5872668912708168     | 0.8614097968936678     | 0.8940162271805274    | 0.926919032597266     | 0.9101703665462054    |
| 0.5       | run-4 | 0.9387986349253287     | 0.8566308243727598     | 0.8933673469387755    | 0.9206098843322819    | 0.9067840497151735    |
| 0.5       | run-6 | 0.3630746433008758     | 0.8546395858223815     | 0.881011403073872     | 0.9342797055730809    | 0.9068639959173258    |
| 0.5       | run-7 | 0.3518915092129083     | 0.8630027877339705     | 0.8841222879684418    | 0.9426919032597266    | 0.9124681933842239    |
| 0.6       | run-1 | 0.7061429985850636     | 0.8554360812425329     | 0.8759159745969711    | 0.9426919032597266    | 0.9080779944289693    |
| 0.7       | run-1 | 0.7037493148331239     | 0.8566308243727598     | 0.8801775147928994    | 0.9384858044164038    | 0.9083969465648855    |
| 0.8       | run-1 | 0.7040119320611536     | 0.8566308243727598     | 0.8847305389221557    | 0.9321766561514195    | 0.9078341013824884    |
| 0.9       | run-1 | 0.7037559155104978     | 0.8610115491835922     | 0.8955680081507896    | 0.9242902208201893    | 0.9097024579560156    |
| 0.9       | run-7 | 0.35272171691296594    | 0.8351254480286738     | 0.9355971896955504    | 0.8401682439537329    | 0.8853185595567867    |
| 0.95      | run-1 | 0.7070154592197108     | 0.8582238152130626     | 0.9017671517671517    | 0.9121976866456362    | 0.9069524307370622    |