In [1]:
from dataset import BERTDataset, SimBERTDataset
from tokenizer.tokenizers import ContrastiveTokenizer

In [2]:
# Constructor settings for the shared tokenizer, kept in one mapping so they are
# easy to inspect or tweak before the tokenizer is built.
tokenizer_config = {
    "tokenizer_path": "dictionary/bn_en",
    "vocab_size": 16000,
    "max_token_length": 9,
    "pad_token": "<pad>",
    "unk_token": "<unk>",
    "start_token": "<ben>",
    "end_token": "</ben>",
}
tk = ContrastiveTokenizer(**tokenizer_config)

# Load the tokenizer model file from disk into the instance created above.
tk.load_from_disk("tokenizer/dictionary/bn_en.model")

In [3]:
# Location of the combined eng_Latn-ben_Beng corpus.
# NOTE(review): this is a machine-specific absolute path; point it at your own copy
# (ideally via a configurable data directory) before running on another machine.
CORPUS_PATH = "/mnt/JaHiD/Zahid/RnD/ContrastiveBERT/bn_translator_data/dataset/BPCC-combined/combined/eng_Latn-ben_Beng/combined"

# SimBERTDataset has to be imported in the imports cell next to BERTDataset; without
# that import this cell fails with NameError on a fresh kernel.
# NOTE(review): the import assumes the class lives in the `dataset` module — confirm.
dataset = SimBERTDataset(
    corpus_path=CORPUS_PATH,
    tokenizer=tk,
    seq_len=128,
    padding=False,  # samples keep their own length; padding is done later per batch
    encoding="utf-8",
    corpus_lines=0,
    on_memory=True,
)

Loading Dataset: 33036862it [00:38, 866164.00it/s] 


In [4]:
# Grab only the first sample the dataset yields (None if it yields nothing).
x = next(iter(dataset), None)

In [5]:
# Display the first dataset sample fetched in the previous cell.
x

[{'bert_input': tensor([    2, 14401,   124, 15882,   488, 14401,   124,  5867, 11855,   676,
            146,  8383,   319,  8400, 15925,  4285, 15925, 15891, 10394,     4,
           3109, 13545,  3898,  6075, 15898, 15884,     3]),
  'bert_label': tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0, 8718,    0,    0,    0,    0,
             0,    0,    0]),
  'segment_label': tensor(0.)},
 {'bert_input': tensor([    6,  8769, 13692, 15946,  1270,  6663,  1655, 15892,  8769, 13692,
          15936, 15861,  5610, 13638,     1, 14735,  1441,   888, 14252,  3967,
           1524,  3020,   144,  9392, 15879,     7]),
  'bert_label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0]),
  'segment_label': tensor(1.)}]

In [6]:
# Hard-coded token ids for the first-side (Bengali) sequence. They appear to be a
# hand-edited copy of the sample's `bert_input` shown above: two positions differ.
bn_token_ids = [
    2, 14401, 124, 15882, 488, 14401, 124, 5867, 11855, 676,
    146, 8383, 319, 8400, 15925, 4285, 15925, 15891, 10394, 8718,
    3109, 13545, 4, 6075, 15898, 15884, 3,
]
tk.tokenizer.decode(bn_token_ids)

'তামিলনাড়ু - তামিলনাড়ুতে ৬০ দিন পর বৃহস্পতিবার কোভিড-১৯-এ দৈনিক সংক্রমণ ১০ হাজারের<mask> নেমেছ।'

In [7]:
# Hard-coded token ids for the second-side (English) sequence. They appear to be a
# hand-edited copy of the sample's `bert_input` shown above: three positions are set to 4.
en_token_ids = [
    6, 8769, 13692, 15946, 1270, 6663, 1655, 15892, 8769, 13692,
    15936, 15861, 5610, 4, 4, 4, 1441, 888, 14252, 3967,
    1524, 3020, 144, 9392, 15879, 7,
]
tk.tokenizer.decode(en_token_ids)

'<en> Tamil Nadu: After 60 days, Tamil Nadu’s daily<mask><mask><mask> case count dropped below 10,000 on Thursday.</en>'

In [8]:
# Display the tokenizer's padding token id.
tk.pad_token_id

0

In [15]:
from torch.utils.data import DataLoader
import torch


def collate_fn(batch, pad_token_id=0):
    """Collate paired samples into a single right-padded batch.

    Every element of ``batch`` is indexed as ``item[0]`` and ``item[1]`` (in this
    notebook: the Bengali and the English side of one sentence pair). Each side is a
    mapping holding ``'bert_input'``, ``'bert_label'`` and ``'segment_label'`` tensors.
    All ``item[0]`` entries come first in the output, followed by all ``item[1]``
    entries, so ``B`` pairs produce ``2 * B`` rows.

    Args:
        batch: Sequence of pairs as described above, as passed in by ``DataLoader``.
        pad_token_id: Id used to right-pad ``bert_input`` and to derive the attention
            mask. Defaults to 0, the value that was previously hard-coded.

    Returns:
        dict with
            ``bert_input``: token ids padded to the longest sequence in the batch,
            ``bert_label``: labels padded with 0,
            ``segment_label``: the per-sample segment labels stacked together,
            ``attention_mask``: float tensor, 1.0 where ``bert_input != pad_token_id``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        # Fail with a clear message instead of an opaque error from the padding step.
        raise ValueError("collate_fn received an empty batch")

    # Keep the original row order: every first-side sample, then every second-side one.
    samples = [pair[0] for pair in batch] + [pair[1] for pair in batch]

    pad_sequence = torch.nn.utils.rnn.pad_sequence
    padded_bert_inputs = pad_sequence(
        [sample["bert_input"] for sample in samples],
        batch_first=True,
        padding_value=pad_token_id,
    )
    # Labels are padded with 0, matching the value used for unlabelled positions in
    # the samples. Assumes each label sequence is as long as its input sequence.
    padded_bert_labels = pad_sequence(
        [sample["bert_label"] for sample in samples],
        batch_first=True,
        padding_value=0,
    )
    segment_labels = torch.stack([sample["segment_label"] for sample in samples])

    # Any position holding the pad id is masked out, including pad ids that were
    # already present in a sample before batching.
    attention_mask = (padded_bert_inputs != pad_token_id).float()

    return {
        "bert_input": padded_bert_inputs,
        "bert_label": padded_bert_labels,
        "segment_label": segment_labels,
        "attention_mask": attention_mask,
    }

# Two pairs per batch; collate_fn stacks both sides of each pair, giving four rows.
train_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

# Pull a single batch for inspection. next(iter(...)) fails immediately if the loader
# yields nothing, instead of leaving `data` unbound for the cells that follow.
data = next(iter(train_loader))

In [16]:
# Display the first collated batch pulled from train_loader.
data

{'bert_input': tensor([[    2, 14401,   124, 15882,   488, 14401,   124,  5867, 11855,   676,
            146,     4,     4,     4,     4,     4,     4,     4, 10394,  8718,
           3109, 13545,  3898,  6075, 15898, 15884,     3,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0],
         [    2,   396,  2587,   532, 14459,   662,   850,  1833,  2810, 15862,
           1307,  8985,  3461,     4,   823,  4851,  4170,  5456,   560,  1584,
          12912, 15877,   211, 15967,     4,  5873, 15898,  1066, 15889,   560,
            475, 15892,  2060,   227,  6297,   579,   860,     4,   181,   204,
            788,  6166,   105, 15889, 15884,     3,     0,     0,     0],
         [    6,  8769, 13692, 15946,  1270,  6663,  1655, 15892,  8769, 13692,
          15936, 15861,  5610, 13638,     1, 14735,  1441,   888, 14252,  3967,
           1524,  3020,   144,  9392, 

In [None]:
import torch
# Indices 0..31 laid out twice back to back (64 entries in total).
labels = torch.arange(32).repeat(2)

In [None]:
# Display the labels tensor built in the previous cell.
labels

In [None]:
# NOTE(review): `images` is not defined in any cell visible in this notebook, so this
# line raises NameError on a fresh kernel unless it is defined somewhere not shown.
# The statement also rebinds `images` to its own concatenation, so re-running the cell
# would feed the already-concatenated tensor back into torch.cat.
# TODO: confirm where `images` comes from, or remove this cell.
images = torch.cat(images, dim=0)