In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import time
import torch
import psutil
from transformers import BertModel, BertTokenizer, BertConfig
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

## Custom Classes

In [3]:
# Custom BertSelfAttention to simulate Grouped Query Attention
class CustomBertSelfAttention(BertSelfAttention):
    def __init__(self, config):
        super().__init__(config)
        # Override the number of attention heads for GQA simulation
        self.num_attention_heads = 8 # Drastically reduce heads to 2 for clear difference. You can still use 8 for noticeable speed difference
        
        self.attention_head_size = int(self.all_head_size / self.num_attention_heads)
        
        self.query = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.key = torch.nn.Linear(config.hidden_size, self.all_head_size)
        self.value = torch.nn.Linear(config.hidden_size, self.all_head_size)
        
        self.dropout = torch.nn.Dropout(config.attention_probs_dropout_prob)


In [4]:
# Custom BertLayer to replace the self-attention with CustomBertSelfAttention
class CustomBertLayer(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = CustomBertSelfAttention(config)
        self.intermediate = torch.nn.Linear(config.hidden_size, config.intermediate_size)
        self.output = torch.nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, 
        hidden_states, 
        attention_mask=None, 
        head_mask=None, 
        encoder_hidden_states=None, 
        encoder_attention_mask=None, 
        past_key_value=None, 
        output_attentions=False):
            
        attention_output = self.attention(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
        
        intermediate_output = self.intermediate(attention_output[0])
        
        layer_output = self.output(intermediate_output)
        layer_output = self.dropout(layer_output)
        layer_output = self.LayerNorm(layer_output + attention_output[0])
        
        return (layer_output,) + attention_output[1:]

In [5]:
# Custom BertEncoder to use the CustomBertLayer
class CustomBertEncoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = torch.nn.ModuleList([CustomBertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
            self, 
            hidden_states, 
            attention_mask=None, 
            head_mask=None, 
            encoder_hidden_states=None, 
            encoder_attention_mask=None, 
            past_key_values=None, 
            use_cache=None, 
            output_attentions=False, 
            output_hidden_states=False, 
            return_dict=True):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_values[i] if past_key_values is not None else None,
                output_attentions,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

In [6]:
# Custom BERT model using the CustomBertEncoder
class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = CustomBertEncoder(config)

## Utility Functions

In [7]:
# Define a function to measure inference speed
def measure_inference_speed(model, tokenizer, input_text, num_runs=10):
    inputs = tokenizer(input_text, return_tensors='pt')
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for _ in range(num_runs):
            outputs = model(**inputs)
        end_time = time.time()
    avg_time = (end_time - start_time) / num_runs
    return avg_time

In [8]:
# Compare memory usage
def measure_memory_usage(model, tokenizer, input_text):
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors='pt')
    
    # Move inputs to the same device as the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.to(device)
    
    # Run the model to ensure it's loaded into memory
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    
    if torch.cuda.is_available():
        memory_usage = torch.cuda.memory_allocated()
    else:
        process = psutil.Process()
        memory_usage = process.memory_info().rss  # in bytes

    return memory_usage

## Main Execution Code

In [9]:
# Load the pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

In [10]:
# Instantiate the custom BERT model with simulated GQA
config = BertConfig.from_pretrained(model_name)
custom_model = CustomBertModel(config)

In [11]:
# Example input text
input_text = "This is a sample input text to measure the inference speed of the model."

In [12]:
# Measure the inference speed of all models
base_model_speed = measure_inference_speed(model, tokenizer, input_text)
gqa_model_speed = measure_inference_speed(custom_model, tokenizer, input_text)

print(f"Base BERT model inference speed: {base_model_speed:.6f} seconds")
print(f"GQA BERT model inference speed: {gqa_model_speed:.6f} seconds")
print(f"Speedup: {base_model_speed / gqa_model_speed:.2f}x")

Base BERT model inference speed: 0.081608 seconds
GQA BERT model inference speed: 0.022804 seconds
Speedup: 3.58x


In [13]:
# Measure memory usage for both models
base_model_memory = measure_memory_usage(model, tokenizer, input_text)
gqa_model_memory = measure_memory_usage(custom_model, tokenizer, input_text)
memory_diff = base_model_memory - gqa_model_memory

# Print the results
print(f"Base BERT Model Memory Usage: {base_model_memory} bytes")
print(f"GQA BERT Model Memory Usage: {gqa_model_memory} bytes")
print(f"% difference in memory usage: {(memory_diff / base_model_memory) * 100:.2f}%")

Base BERT Model Memory Usage: 1497899008 bytes
GQA BERT Model Memory Usage: 1497903104 bytes
% difference in memory usage: -0.00%


### Notes
- The differences in memory usage may be negligible when running on a CPU. 
    - The memory savings and performance improvements of Grouped Query Attention (GQA) are more noticeable in GPU environments due to the higher parallelism and larger memory capacities involved.
- Larger Model: Switching to a larger model (bert-large-uncased) will make the difference in computational requirements more pronounced.
- Longer Input Text: Using a longer input text will highlight differences in execution time more clearly.
- Inference Speed Focus: Measuring inference speed will likely show a more noticeable difference on the CPU compared to memory usage.