# BERT NLP

## 1. Tokenisation

In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForQuestionAnswering

In [None]:
# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Load the cased BERT-base WordPiece tokenizer and show how many entries its vocabulary has.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-cased')
tokenizer.vocab_size

In [None]:
# A statement and a question about it, used as a two-segment input pair.
text_1, text_2 = (
    "I am good at both chess and boxing",
    "What am I good at?",
)

In [None]:
# Encode the two texts as one sequence pair. With special tokens enabled the
# layout is: [CLS] text_1 [SEP] text_2 [SEP]
indexed_tokens = tokenizer.encode(text=text_1, text_pair=text_2, add_special_tokens=True)
indexed_tokens

In [None]:
# Map every token ID back to its WordPiece string (special tokens included).
tokens = tokenizer.convert_ids_to_tokens(ids=indexed_tokens)
tokens

In [None]:
# Decode the IDs back into a single string to check the tokenisation round-trip.
decoded_text = tokenizer.decode(indexed_tokens)
# A bare assignment produces no cell output, so display the result explicitly.
decoded_text

## Segmentation of the text

In [None]:
# Read the special-token IDs from the tokenizer instead of hard-coding them.
cls_token_id = tokenizer.cls_token_id  # [CLS] token
sep_token_id = tokenizer.sep_token_id  # [SEP] token
# Only the last expression of a notebook cell is displayed, so show both IDs together.
cls_token_id, sep_token_id

In [None]:
def get_segment_ids(indexed_tokens, sep_token_id=102, *, target_device=None):
    """
    Create segment (token type) IDs that distinguish the text segments of a sequence.

    Each token receives the index of the segment it belongs to. A [SEP] token
    belongs to the segment it terminates; only the tokens *after* it move on to
    the next segment. For a pair encoded as ``[CLS] A [SEP] B [SEP]`` this gives
    ``0`` for ``[CLS] A [SEP]`` and ``1`` for ``B [SEP]``, which is the layout of
    BERT's ``token_type_ids``.

    Args:
        indexed_tokens: List of token IDs (may be empty).
        sep_token_id: ID of the separator token. The default 102 is the [SEP]
            ID in the standard bert-base vocabularies; passing
            ``tokenizer.sep_token_id`` is safer.
        target_device: Device on which to place the returned tensors. When
            omitted, the module-level ``device`` is used.

    Returns:
        segment_ids_tensor: int64 tensor of shape (1, len(indexed_tokens)) with segment IDs.
        tokens_tensor: int64 tensor of shape (1, len(indexed_tokens)) with the token IDs.
    """
    if target_device is None:
        target_device = device  # module-level device selected earlier in the notebook

    segment_ids = []
    segment_id = 0
    for token in indexed_tokens:
        segment_ids.append(segment_id)
        # Advance only after the [SEP] itself has been labelled, so it stays in the
        # segment it closes and a trailing [SEP] needs no special-casing.
        # NOTE(review): more than two segments yields IDs >= 2; bert-base only has
        # two token-type embeddings, so callers should pass at most one pair.
        if token == sep_token_id:
            segment_id += 1

    # Explicit dtype keeps empty inputs integral (torch.tensor([[]]) would be float32).
    segment_ids_tensor = torch.tensor([segment_ids], dtype=torch.long).to(target_device)
    tokens_tensor = torch.tensor([indexed_tokens], dtype=torch.long).to(target_device)

    return segment_ids_tensor, tokens_tensor


In [None]:
# Build the segment-ID and token-ID tensors for the encoded pair.
segments_tensors, tokens_tensor = get_segment_ids(
    indexed_tokens=indexed_tokens,
    sep_token_id=sep_token_id,
)

In [None]:
# Bar chart of the segment ID assigned to each token, labelled with the WordPiece tokens.
token_positions = range(len(tokens))
segment_values = segments_tensors[0].cpu().numpy()

plt.figure(figsize=(15, 3))
plt.bar(token_positions, segment_values, align='center')
plt.xticks(token_positions, tokens, rotation=45)
plt.ylabel('Segment ID')
plt.title('Token Segmentation')
plt.tight_layout()
# Also write the chart to disk alongside showing it inline.
plt.savefig('token_segmentation.png')
plt.show()