In [147]:
import numpy as np
import pandas as pd
import torch
import transformers

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup

In [148]:
# Seed applied to both the NumPy and PyTorch random number generators.
RANDOM_SEED = 11
# Hugging Face checkpoint name; SequenceCollate loads its tokenizer from it.
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'

In [149]:
# Seed the NumPy and PyTorch global RNGs with the shared notebook seed.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the chosen device is shown as the cell output

device(type='cuda')

# Similar book dataset and dataloader

In [150]:
# Load the table of similar-book pairs. SimilarBooksDataset reads the columns
# book1_title, book1_description, book2_title and book2_description from it.
similar_book_pairs = pd.read_csv('data/similar_book_pairs.csv')

In [151]:
class SimilarBooksDataset(Dataset):
  """Map-style dataset over a table of similar-book pairs.

  Each row is expected to provide the columns ``book1_title``,
  ``book1_description``, ``book2_title`` and ``book2_description``.
  Every sample is labelled with target class 1.
  """

  def __init__(self, similar_book_pairs):
    # Only a reference is stored; rows are looked up on demand in __getitem__.
    self.similar_book_pairs = similar_book_pairs

  def __len__(self):
    """Number of book pairs, i.e. the number of rows in the table."""
    return len(self.similar_book_pairs)

  @staticmethod
  def _format_book(title, description):
    """Render one book as a quoted title followed by its description."""
    return f'"{title}" - {description}'

  def __getitem__(self, item):
    """Return the text of both books at positional index ``item`` plus the label."""
    row = self.similar_book_pairs.iloc[item]
    first_book = self._format_book(row["book1_title"], row["book1_description"])
    second_book = self._format_book(row["book2_title"], row["book2_description"])

    return {
        'book1_sequence': first_book,
        'book2_sequence': second_book,
        'target_class': 1
    }

In [152]:
class SequenceCollate:
    """
    Collate to tokenize and apply the padding to the sequences with dataloader

    Each batch element is a dict with the keys ``book1_sequence``,
    ``book2_sequence`` and ``target_class``. The two texts are encoded
    together as a sentence pair, padded to the longest pair in the batch and
    truncated (longest side first) to ``max_length``.

    Args:
        pretrained_model_name: checkpoint whose tokenizer is loaded. When
            ``None`` the notebook-level ``PRE_TRAINED_MODEL_NAME`` is used.
        max_length: maximum number of tokens per encoded pair. ``None``
            leaves the limit to the tokenizer's own default.
    """
    def __init__(self, pretrained_model_name=None, max_length=None):
        # Resolve the default at construction time rather than in the
        # signature, so the constant is only needed when it is actually used.
        if pretrained_model_name is None:
            pretrained_model_name = PRE_TRAINED_MODEL_NAME
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
        self.max_length = max_length

    def __call__(self, batch):
        """
        Turn a list of dataset samples into padded tensors.

        Returns a dict with ``sequences`` (token ids), ``attention_masks``
        and ``target_classes`` (long tensor). When the tokenizer emits
        segment ids they are returned as ``token_type_ids``.
        """
        book1_sequences = [pair['book1_sequence'] for pair in batch]
        book2_sequences = [pair['book2_sequence'] for pair in batch]
        target_classes = [pair['target_class'] for pair in batch]

        encoded_sequences = self.tokenizer(
            book1_sequences,
            book2_sequences,
            padding='longest',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors='pt',
        )

        collated = {
            'sequences': encoded_sequences['input_ids'],
            'attention_masks': encoded_sequences['attention_mask'],
            'target_classes': torch.as_tensor(target_classes, dtype=torch.long)
        }
        # Segment ids tell a pair-input model which tokens belong to the first
        # book and which to the second. Dropping them would make the model
        # treat the whole input as a single segment. Not every tokenizer
        # produces them, so the key is only added when available.
        if 'token_type_ids' in encoded_sequences:
            collated['token_type_ids'] = encoded_sequences['token_type_ids']
        return collated

In [153]:
# Wrap the loaded pairs table in the dataset class defined above.
ds = SimilarBooksDataset(similar_book_pairs)

In [154]:
# Serve the dataset in batches of 4; SequenceCollate tokenizes and pads each batch.
dl = DataLoader(ds, batch_size=4, collate_fn=SequenceCollate())

In [157]:
# Sanity check: print the tensor shapes of the first 30 batches, then stop.
count = 0  # stays 0 if the loader yields nothing
for count, item in enumerate(dl, start=1):
    for key in ('sequences', 'attention_masks', 'target_classes'):
        print(item[key].shape)
    # Break right after the 30th batch so no further batch is fetched.
    if count >= 30:
        break

torch.Size([4, 488])
torch.Size([4, 488])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 336])
torch.Size([4, 336])
torch.Size([4])
torch.Size([4, 445])
torch.Size([4, 445])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 250])
torch.Size([4, 250])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 476])
torch.Size([4, 476])
torch.Size([4])
torch.Size([4, 504])
torch.Size([4, 504])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 472])
torch.Size([4, 472])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4, 449])
torch.Size([4, 449])
torch.Size([4])
torch.Size([4, 512])
torch.Size([4, 512])
torch.Size([4])
torch.Size([4,