## Encrypted Transformers

This notebook documents the steps involved in training and using an encrypted transformer model. The model is trained in PyTorch, and its weights are then encrypted with VENumpy. The synthetic medical training data was generated with ChatGPT.

In [36]:
# Import Dependencies
import sys
sys.path.append('..')

from transformers import DataCollatorForLanguageModeling, AutoTokenizer, DataCollatorWithPadding, BertTokenizerFast

import time
import numpy as np
import os
import pickle
from tqdm.notebook import tqdm
import venumpy
import torch
import pandas as pd
# import math

from venumMLlib.deep_learning.transformer.transformer import *
from venumMLlib.venum_tools import *
from venumMLlib.approx_functions import *

In [37]:
# Load the synthetic medical reports
data_folder = "../demos/transformer_demo/data/"
data = pd.read_csv(data_folder + "chatgpt_medical_reports_rare_diseases.csv", encoding='utf-8', index_col=False)

# Map each condition name to an integer label
data.columns = ['idx', 'condition', 'text']
label_mapping = dict(zip(data.condition.unique(), range(data.condition.nunique())))
data['label'] = data['condition'].map(label_mapping)

In [38]:
# Reverse the dictionary so that integer labels map back to condition names
reversed_label_mapping = {value: key for key, value in label_mapping.items()}
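
To illustrate the two mappings, the sketch below rebuilds them for a handful of made-up condition names. The list `toy_conditions` is hypothetical and stands in for `data.condition`; `dict.fromkeys` is used to mimic the first-appearance ordering that `unique()` produces.

```python
# Hypothetical condition column; stands in for data.condition
toy_conditions = ["Wilson's Disease", "Fabry Disease", "Wilson's Disease", "Gaucher Disease"]

# Unique values in order of first appearance
toy_unique = list(dict.fromkeys(toy_conditions))

# Forward mapping: condition name -> integer label
toy_label_mapping = {name: idx for idx, name in enumerate(toy_unique)}

# Reverse mapping: integer label -> condition name
toy_reversed_mapping = {idx: name for name, idx in toy_label_mapping.items()}

toy_labels = [toy_label_mapping[c] for c in toy_conditions]
print(toy_labels)
print(toy_reversed_mapping[toy_labels[0]])
```

Repeated conditions share one label, and the reverse mapping recovers the original name from any label.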

## VENumpy Instance
Create a VENumpy secret context, setting the security level to 128 and the precision to 6.

In [39]:
ctx = venumpy.SecretContext.new_with_security(128)
ctx.precision = 6

## Load weights from pre-trained model

In this step, we load the weights from the pre-trained model and encrypt them with VENumpy, using the `encrypt_array` helper from `venum_tools`. Each weight tensor is transposed and converted to a NumPy array before encryption.

In [42]:
model_path = "../demos/transformer_demo/model/"
state_dict = torch.load(model_path + 'medical_2heads.pth')

In [43]:
encrypted_state_dict = {}

In [44]:
num_heads = 2

In [45]:
# Transpose each tensor, convert it to a NumPy array, and encrypt it
for k in tqdm(state_dict.keys()):
    weight = state_dict[k].T.numpy()
    encrypted_state_dict[k] = encrypt_array(weight, ctx)

  0%|          | 0/21 [00:00<?, ?it/s]
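
PyTorch's `nn.Linear` stores its weight with shape `(out_features, in_features)`, which is presumably why the loop transposes each tensor: the transposed matrix can be applied as `x @ W`. The sketch below shows this together with an elementwise wrapper. `ToyCiphertext`, `toy_encrypt_array`, and `toy_decrypt_array` are hypothetical stand-ins that keep values in the clear; they are not the `venum_tools` implementation, which may work differently.

```python
import numpy as np

class ToyCiphertext:
    """Hypothetical stand-in for a ciphertext; it stores the value in the clear."""
    def __init__(self, value):
        self._value = float(value)

    def decrypt(self):
        return self._value

def toy_encrypt_array(arr):
    """Wrap every element of arr in a ToyCiphertext, preserving the shape."""
    out = np.empty(arr.shape, dtype=object)
    for idx in np.ndindex(arr.shape):
        out[idx] = ToyCiphertext(arr[idx])
    return out

def toy_decrypt_array(enc):
    """Inverse of toy_encrypt_array: call decrypt() on every element."""
    out = np.empty(enc.shape, dtype=float)
    for idx in np.ndindex(enc.shape):
        out[idx] = enc[idx].decrypt()
    return out

# A PyTorch-style linear weight of shape (out_features, in_features) = (3, 2)
toy_weight = np.arange(6, dtype=float).reshape(3, 2)
toy_x = np.array([[1.0, 2.0]])             # one input row with in_features = 2

toy_weight_t = toy_weight.T                # shape (2, 3), so toy_x @ toy_weight_t is valid
toy_enc_weight = toy_encrypt_array(toy_weight_t)
toy_result = toy_x @ toy_decrypt_array(toy_enc_weight)

print(toy_enc_weight.shape)
print(toy_result)
```

The encrypted array keeps the shape of the transposed weight, so it can be used in the same positions as the plaintext matrix.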

In [46]:
# Keep a plaintext copy of the embedding table for the token lookup later on
embedding_weights = state_dict['embeddings.weight'].numpy()

## Model hyperparameters
Next, we define the model hyperparameters, which should match the configuration of the pre-trained model.

In [13]:
max_seq_len = 20
d_model = 8
num_heads = 2
d_ff = 32

## Tokenizer
Load the tokenizer associated with the pre-trained model, in this case the BERT tokenizer (`bert-base-uncased`). The tokenizer converts the synthetic medical text into token IDs that can be fed into the model, padding or truncating each input to `max_seq_len` tokens.

In [47]:
# Load a tokenizer (example using BERT tokenizer)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [48]:
# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_seq_len)

In [16]:
def tokenize_input(sentence, tokenizer, max_seq_len):
    inputs = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,
        max_length=max_seq_len,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    # Convert the tensor to a numpy array
    input_ids_numpy = inputs['input_ids'].squeeze().numpy()
    return input_ids_numpy
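
To make the effect of `padding='max_length'` and `truncation=True` concrete, here is a toy sketch that operates on token IDs that have already been produced. `toy_pad_and_truncate` is a hypothetical helper, not part of the tokenizer API; the special-token IDs are those of the `bert-base-uncased` vocabulary.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # [CLS], [SEP], [PAD] in bert-base-uncased

def toy_pad_and_truncate(token_ids, max_len):
    """Hypothetical helper: truncate, add [CLS]/[SEP], then pad to max_len."""
    body = token_ids[: max_len - 2]          # leave room for [CLS] and [SEP]
    ids = [CLS_ID] + body + [SEP_ID]
    return ids + [PAD_ID] * (max_len - len(ids))

short_ids = toy_pad_and_truncate([7, 8, 9], max_len=8)
long_ids = toy_pad_and_truncate(list(range(1000, 1020)), max_len=8)

print(short_ids)
print(long_ids)
```

Both results have exactly `max_len` entries: the short input is padded with `[PAD]`, and the long input is cut so that `[SEP]` still fits at the end.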


In [17]:
vocab_size = tokenizer.vocab_size
class_size = len(reversed_label_mapping)
batch_size = 1000

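## Softmax
We pass the tokenized input through the model and then apply log-softmax to the output to obtain the log-probability of each class. Log-softmax preserves the ordering of the class probabilities, so the most likely class can still be selected with an argmax.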

In [18]:
def log_softmax(x):
    # Subtract the max for numerical stability
    x_max = np.max(x, axis=-1, keepdims=True)
    
    # Log of the sum of exponentials of the input elements
    log_sum_exp = np.log(np.sum(np.exp(x - x_max), axis=-1, keepdims=True))
    
    return x - x_max - log_sum_exp
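
A quick sanity check of this function, using made-up logits: exponentiating the result should give probabilities that sum to 1, the argmax should match that of the raw logits, and large inputs should not overflow thanks to the max subtraction. The array `toy_logits` is illustrative only.

```python
import numpy as np

def log_softmax(x):
    # Same definition as in the cell above
    x_max = np.max(x, axis=-1, keepdims=True)
    log_sum_exp = np.log(np.sum(np.exp(x - x_max), axis=-1, keepdims=True))
    return x - x_max - log_sum_exp

toy_logits = np.array([[2.0, 1.0, 0.1],
                       [1000.0, 1001.0, 1002.0]])  # second row would overflow a naive exp
toy_log_probs = log_softmax(toy_logits)

print(np.exp(toy_log_probs).sum(axis=-1))
print(np.argmax(toy_log_probs, axis=-1), np.argmax(toy_logits, axis=-1))
```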

In [19]:
def texts_to_batch_indices(texts, max_seq_len):
    """
    Converts a batch of texts to a list of token-ID arrays using the global tokenizer.

    Args:
    - texts (list of str): The input texts.
    - max_seq_len (int): Length to which each text is padded or truncated.

    Returns:
    - list of np.ndarray: One array of token IDs per input text.
    """
    batch_indices = []
    for text in texts:
        indices = tokenize_input(text, tokenizer, max_seq_len)
        batch_indices.append(indices)
    return batch_indices


## Inference Text
Create the prompt, a list of symptoms used to produce the final diagnosis, then tokenize, embed, and encrypt it.

In [20]:
texts = ["Liver Function Tests (LFTs): Results indicating liver dysfunction (AST, ALT, Bilirubin elevated)"]
print(texts[0])

Liver Function Tests (LFTs): Results indicating liver dysfunction (AST, ALT, Bilirubin elevated)


In [21]:
batch_indices = texts_to_batch_indices(texts, max_seq_len)

In [22]:
batch_size = len(texts)

In [23]:
# Look up the plaintext embeddings, then encrypt them before they enter the model
embeddings = Embeddings(embedding_weights)
embedding_output = embeddings.forward(batch_indices, batch_size, max_seq_len)
embedding_output = encrypt_array(embedding_output, ctx)

In [24]:
print(embedding_output[0][0])

[<venumpy.Ciphertext object at 0x33dc13990>
 <venumpy.Ciphertext object at 0x37bd54890>
 <venumpy.Ciphertext object at 0x33edac990>
 <venumpy.Ciphertext object at 0x33dc0bb10>
 <venumpy.Ciphertext object at 0x34202cbd0>
 <venumpy.Ciphertext object at 0x33a5c7a90>
 <venumpy.Ciphertext object at 0x342006790>
 <venumpy.Ciphertext object at 0x342004090>]
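
The printed row holds 8 ciphertexts, one per embedding dimension (`d_model = 8`) of the first token. The implementation of the `Embeddings` class is not shown in this notebook; conventionally, an embedding lookup is a row-indexing operation on the weight table, as in the plaintext sketch below. All `toy_` names and sizes are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

toy_vocab_size, toy_d_model, toy_seq_len = 10, 8, 5   # toy sizes, not the notebook's
toy_embedding_weights = rng.normal(size=(toy_vocab_size, toy_d_model))

# One sequence of token IDs, already padded to toy_seq_len
toy_batch_indices = np.array([[1, 4, 4, 0, 0]])

# Row lookup: each token ID selects one row of the embedding table
toy_embedding_output = toy_embedding_weights[toy_batch_indices]

print(toy_embedding_output.shape)   # (batch_size, seq_len, d_model)
```

Identical token IDs map to identical embedding rows, which is why the lookup can be done in plaintext before the result is encrypted.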


## Encrypted Transformer Class
The encrypted weights (`encrypted_state_dict`) are passed as an argument to the Encrypted Transformer Class (`TransformerModule`) along with the hyperparameters listed previously.

In [25]:
transformer = TransformerModule(
    encrypted_state_dict,
    max_seq_len=max_seq_len,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    vocab_size=vocab_size,
)

In [26]:
output_linear = transformer.forward(embedding_output, ctx, batch_size)
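
The internals of `TransformerModule` are not shown in this notebook. For orientation, the sketch below gives the plaintext multi-head self-attention that a standard transformer layer computes, using the same `d_model` and `num_heads` as above. It is a reference under that assumption, not the library's implementation: an encrypted version would need to replace softmax and other non-polynomial functions with approximations, which may be the role of the imported `approx_functions` module. All weights here are random placeholders.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Plaintext scaled dot-product self-attention for x of shape (seq_len, d_model)."""
    seq_len, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)

    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)   # (num_heads, seq_len, seq_len)
    weights = softmax(scores)
    context = weights @ v                                     # (num_heads, seq_len, head_dim)

    # Merge heads back to (seq_len, d_model) and apply the output projection
    merged = context.transpose(1, 0, 2).reshape(seq_len, d_model)
    return merged @ w_o, weights

rng = np.random.default_rng(0)
toy_x = rng.normal(size=(20, 8))                              # (max_seq_len, d_model)
toy_w_q, toy_w_k, toy_w_v, toy_w_o = (rng.normal(size=(8, 8)) for _ in range(4))

toy_out, toy_attn = multi_head_self_attention(toy_x, toy_w_q, toy_w_k, toy_w_v, toy_w_o, num_heads=2)
print(toy_out.shape, toy_attn.shape)
```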

In [27]:
output = log_softmax([[i.decrypt() for i in batch] for batch in output_linear])

## Decryption
Decrypt the output of the encrypted transformer with `decrypt_array`.

In [28]:
output = decrypt_array(output_linear)

## Predicted Disease in Plaintext
Finally, we print the predicted class for the supplied list of symptoms, chosen as the class with the highest log-softmax score.

In [29]:
probs = log_softmax(output)
predicted_class_idx = np.argmax(probs)
predicted_class = reversed_label_mapping[predicted_class_idx]
print(predicted_class)

Wilson's Disease
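
The cell above calls `np.argmax` without an axis, which works here because the batch contains a single text. For a batch with several inputs, one option is to take the argmax along the class axis and map each index separately, as sketched below. The label names and log-probabilities are made up for illustration.

```python
import numpy as np

toy_reversed_label_mapping = {0: "Condition A", 1: "Condition B", 2: "Condition C"}

# Hypothetical log-probabilities for a batch of two inputs
toy_log_probs = np.log(np.array([[0.1, 0.7, 0.2],
                                 [0.6, 0.3, 0.1]]))

# argmax along the class axis gives one index per input
toy_predicted_idx = np.argmax(toy_log_probs, axis=-1)
toy_predicted = [toy_reversed_label_mapping[int(i)] for i in toy_predicted_idx]
print(toy_predicted)
```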
