In [169]:
import sentencepiece as sp
import random

In [170]:
random.seed(1)

In [171]:
tokenizer  = sp.SentencePieceProcessor()
tokenizer.Load('D:/Workspace/PYTHON/NLP/Project2/tokenizer/unigram/unigram_tokenizer.model')
tokenizer

<sentencepiece.SentencePieceProcessor; proxy of <Swig Object of type 'sentencepiece::SentencePieceProcessor *' at 0x000001C6A479D980> >

In [173]:
with open("D:/Workspace/PYTHON/NLP/Project2/Data/sentences_cleaned.txt", "r", encoding="utf-8") as f:
    sentences = [line.strip() for line in f if line.strip()]
print(len(sentences))
print(sentences[0])

10247
میرا اس سے کوئی لینا دینا نہیں


In [174]:
PAD_ID = tokenizer.piece_to_id('<pad>')
UNK_ID = tokenizer.piece_to_id('<unk>')
SOS_ID = tokenizer.piece_to_id('<s>')
EOS_ID = tokenizer.piece_to_id('</s>')
MASK_ID = tokenizer.piece_to_id('<mask>')
print(PAD_ID,UNK_ID,SOS_ID,EOS_ID,MASK_ID)

0 1 2 3 4


In [175]:
# Tokenize all sentences
tokenized_data = []
for sent in sentences:
    tokens = tokenizer.encode(sent)
    tokenized_data.append(tokens)
print(f"Sample tokenized: {tokenized_data[0]}")
print(f"Decoded: {tokenizer.decode(tokenized_data[0])}")


Sample tokenized: [437, 19, 16, 101, 135, 59, 288, 6, 34]
Decoded: میرا اس سے کوئی لینا دینا نہیں


In [181]:
def create_span_corruption(tokens, mask_ratio=0.15, mean_span_length=1):

    # Calculate number of tokens to mask
    num_to_mask = max(1, int(len(tokens) * mask_ratio))
    
    # Calculate number of spans
    num_spans = max(1, int(num_to_mask / mean_span_length))
    
    # Randomly select span start positions
    possible_starts = list(range(len(tokens)))
    random.shuffle(possible_starts)
    
    masked_positions = set()
    spans = []
    
    for start_pos in possible_starts:
        if start_pos not in masked_positions:
            # Randomly determine span length 
            span_length = min(
                max(1, int(random.expovariate(1.0 / mean_span_length))), #(Poisson-like distribution)
                len(tokens) - start_pos
            )
            
            # Add positions to masked set
            span_positions = list(range(start_pos, start_pos + span_length))
            if not any(pos in masked_positions for pos in span_positions):
                masked_positions.update(span_positions)
                spans.append((start_pos, start_pos + span_length))
                
                if len(masked_positions) >= num_to_mask:
                    break
    
    # Sort spans by position
    spans.sort()
    
    # Create input sequence (with <mask> tokens)
    input_ids = [SOS_ID]
    target_ids = [SOS_ID]
    
    last_idx = 0
    for start, end in spans:
        # Add non-masked tokens before this span
        input_ids.extend(tokens[last_idx:start])
        
        # Add single <mask> token for the entire span
        input_ids.append(MASK_ID)
        
        # Add masked tokens to target
        
        target_ids.extend(tokens[start:end]) 
        target_ids.append(MASK_ID)
        last_idx = end
    
    # Add remaining non-masked tokens
    input_ids.extend(tokens[last_idx:])
    input_ids.append(EOS_ID)
    target_ids.append(EOS_ID)
    
    return input_ids, target_ids


In [182]:
dataset = []
for tokens in tokenized_data:
    input_ids, target_ids = create_span_corruption(tokens)
    dataset.append({
        'input_ids': input_ids,
        'target_ids': target_ids,
    })

In [183]:
len(dataset)

10247

In [184]:

print(f"Input IDs:  {dataset[0]['input_ids']}")
print(f"Target IDs: {dataset[0]['target_ids']}")
print(f"Input:  {tokenizer.decode(dataset[0]['input_ids'])}")
print(f"Target: {tokenizer.decode(dataset[0]['target_ids'])}")



Input IDs:  [2, 437, 19, 16, 101, 135, 59, 4, 6, 34, 3]
Target IDs: [2, 288, 4, 3]
Input:  میرا اس سے کوئی لینا<mask>ا نہیں
Target: دین<mask>


In [185]:
import torch
torch.save({
    'input_ids': [d['input_ids'] for d in dataset],
    'target_ids': [d['target_ids'] for d in dataset],
}, 'span_corruption_dataset.pt')


In [186]:

# Statistics
input_lengths = [len(d['input_ids']) for d in dataset]
target_lengths = [len(d['target_ids']) for d in dataset]

print(f"\nDataset Statistics:")
print(f"Total examples: {len(dataset)}")
print(f"Avg input length: {sum(input_lengths)/len(input_lengths):.1f}")
print(f"Avg target length: {sum(target_lengths)/len(target_lengths):.1f}")
print(f"Max input length: {max(input_lengths)}")
print(f"Max target length: {max(target_lengths)}")


Dataset Statistics:
Total examples: 10247
Avg input length: 15.1
Avg target length: 5.4
Max input length: 61
Max target length: 18
