In [1]:
# !wget "https://www.dropbox.com/scl/fi/e6oqpx6iuos7kn9m139z7/wikitext-103-raw-v1.zip?rlkey=81evwbaqfkxtckj8zhks7yied&st=6ept2pdm&dl=0"
# !unzip -q "wikitext-103-raw-v1.zip?rlkey=81evwbaqfkxtckj8zhks7yied&st=6ept2pdm&dl=0"
# !rm -rf "wikitext-103-raw-v1.zip?rlkey=81evwbaqfkxtckj8zhks7yied&st=6ept2pdm&dl=0"

In [17]:
from typing import Optional, Any

import torch
import time
import random

from collections import defaultdict
from torch.utils.data.dataset import Dataset
from torch.utils.data import Sampler, DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

from transformer import TransformerModel, generate_square_subsequent_mask

For each of the implemented methods (and all variations of the third and fourth methods), mock one training epoch and measure minimum, maximum, mean and median batch processing times.
To mock a training epoch, you need to construct a small GPT-2-like model: use `nn.Embedding` layer, `PositionalEncoding` class from `transformer.py` file and a single `nn.TransformerDecoder` layer with a hidden size of 1024 and 8 heads.
For tokenization, use the `.tokenize()` method of `AutoTokenizer.from_pretrained("bert-base-uncased")`.
Run one epoch **without a backward pass** to measure the iteration time.
Make sure you've [warmed up](https://forums.developer.nvidia.com/t/why-warm-up/48565) the GPU before computing the statistics and do not forget about asynchronous CUDA kernel execution.

Keep in mind that all padding in this task must be **implemented by you**: unlike the seminar, PyTorch’s default collation padding is not allowed.
For all subproblems, drop all sequences exceeding 640 tokens.
Feel free to modify the keyword arguments of functions.

In [3]:
def get_gpt2_model() -> torch.nn.Module:
    return TransformerModel(
        ntoken=30523,
        d_model=256,
        nhead=8,
        d_hid=1024,
        nlayers=1,
        dropout=0.3,
    )

In [None]:
def run_epoch(dataloader):
    model = get_gpt2_model()
    device = "cuda" if torch.cuda.is_available() else 1/0

    model.to(device)
    model.eval()

    times = []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= 500:
                break

            batch = batch.to(device)
            mask = generate_square_subsequent_mask(batch.shape[0]).to(device)
            model(batch, mask)

        torch.cuda.synchronize()

        for batch in tqdm(dataloader):
            batch = batch.to(device)
            mask = generate_square_subsequent_mask(batch.shape[0]).to(device)

            start = time.perf_counter()
            model(batch, mask)

            torch.cuda.synchronize()
            end = time.perf_counter()

            times.append(end - start)

    times = torch.tensor(times)
    stats = {
        "min": float(times.min()),
        "max": float(times.max()),
        "mean": float(times.mean()),
        "median": float(times.median()),
    }

    return stats

In [4]:
data_path = "wikitext-103-raw-v1"
MAX_LENGTH = 640
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [5]:
def read_data(data_path: str):
    with open(data_path + "/train-00000-of-00002.txt", "r") as f:
        data = f.readlines()

    with open(data_path + "/train-00001-of-00002.txt", "r") as f:
        data += f.readlines()

    data = [el.strip(" \n=") for el in data]
    
    return data

In [6]:
class WikiTextDataset(Dataset):
    def __init__(
        self, 
        data_path: str, 
        max_length: int = MAX_LENGTH,
        tokenizer: Any = tokenizer,
    ):
        self.max_length = max_length
        self.tokenizer = tokenizer

        self.data = read_data(data_path)
        
    def __getitem__(self, idx: int) -> str:
        return self.data[idx]
    
    def __len__(self) -> int:
        return len(self.data)

In [7]:
wikitext_dataset = WikiTextDataset(data_path)

In [None]:
def brain_collate_fn(
    batch: list[str], 
    max_length: Optional[int] = MAX_LENGTH,
) -> tuple[torch.Tensor, torch.Tensor]:
    text_list = []
    for _text in batch:
        processed_text = tokenizer(
            _text, 
            return_tensors="pt", 
            truncation=True,
            max_length=max_length,
        )["input_ids"].squeeze(0)
        pad_size = max_length - len(processed_text)
        processed_text = torch.cat(
            [
                processed_text, 
                torch.full((pad_size, ), tokenizer.pad_token_id)
            ],
            dim=0,
        )
        text_list.append(processed_text)
    
    text_list = torch.stack(text_list, dim=0)

    return text_list

In [10]:
brain_dataloader = DataLoader(
    wikitext_dataset,
    collate_fn=brain_collate_fn,
    batch_size=64,
)
run_epoch(brain_dataloader)

100%|██████████| 18204/18204 [19:00<00:00, 15.96it/s]


{'min': 0.027524013072252274,
 'max': 0.08858683705329895,
 'mean': 0.0510968379676342,
 'median': 0.049413640052080154}

In [11]:
def big_brain_collate_fn(
    batch: list[str], 
    max_length: Optional[int] = MAX_LENGTH,
) -> tuple[torch.Tensor, torch.Tensor]:
    processed_text_list = [
        tokenizer(
            _text, 
            return_tensors="pt", 
            truncation=True,
            max_length=max_length,
        )["input_ids"].squeeze(0)
        for _text in batch
    ]
    max_length = max(len(_text) for _text in processed_text_list)
    text_list = []
    for processed_text in processed_text_list:
        pad_size = max_length - len(processed_text)
        processed_text = torch.cat(
            [
                processed_text, 
                torch.full((pad_size, ), tokenizer.pad_token_id)
            ],
            dim=0,
        )
        text_list.append(processed_text)
    
    text_list = torch.stack(text_list, dim=0)

    return text_list

In [14]:
big_brain_dataloader = DataLoader(
    wikitext_dataset,
    collate_fn=big_brain_collate_fn,
    batch_size=64,
)
run_epoch(big_brain_dataloader)

100%|██████████| 18204/18204 [12:00<00:00, 25.28it/s]


{'min': 0.0015366340521723032,
 'max': 0.060972291976213455,
 'mean': 0.02851749025285244,
 'median': 0.027600599452853203}

In [69]:
class WikiTextDatasetTokenized(Dataset):
    def __init__(
        self, 
        data_path: str, 
        max_length: int = MAX_LENGTH,
        tokenizer: Any = tokenizer,
    ):
        self.max_length = max_length
        self.tokenizer = tokenizer

        self.data = [
            self.tokenizer(
                sentence, 
                return_tensors="pt", 
                truncation=True, 
                max_length=self.max_length,
            )["input_ids"].squeeze(0)
            for sentence in read_data(data_path)
        ]
        
    def __getitem__(self, idx: int) -> str:
        return self.data[idx]
    
    def __len__(self) -> int:
        return len(self.data)

In [75]:
class UltraBigBrainBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, dataset, batch_size, k):
        self.batch_size = batch_size
        self.k = k

        self.length_to_indices = defaultdict(list)
        for i, item in enumerate(dataset):
            self.length_to_indices[len(item)].append(i)

        self.lengths = list(self.length_to_indices.keys())
        self.num_batches = len(dataset) // batch_size

    def __iter__(self):
        for _ in range(self.num_batches):
            min_len = random.choice(self.lengths)
            
            candidates = []
            for l in range(min_len, min_len + self.k + 1):
                candidates += self.length_to_indices[l]

            if not candidates:
                continue

            batch = random.sample(
                candidates,
                k=min(self.batch_size, len(candidates))
            )
            yield batch

    def __len__(self):
        return self.num_batches

In [76]:
def ultra_big_brain_collate_fn(
    batch: list[str], 
    max_length: Optional[int] = MAX_LENGTH,
) -> tuple[torch.Tensor, torch.Tensor]:
    max_length = max(len(_text) for _text in batch)
    text_list = []
    for processed_text in batch:
        pad_size = max_length - len(processed_text)
        processed_text = torch.cat(
            [
                processed_text, 
                torch.full((pad_size, ), tokenizer.pad_token_id)
            ],
            dim=0,
        )
        text_list.append(processed_text)
    
    text_list = torch.stack(text_list, dim=0)

    return text_list

In [73]:
tokenized_dataset = WikiTextDatasetTokenized(
    data_path,
    MAX_LENGTH,
    tokenizer,
)

In [95]:
ultra_big_brain_batch_sampler = UltraBigBrainBatchSampler(
    dataset=tokenized_dataset,
    batch_size=64,
    k=5,
)
ultra_big_brain_dataloader = torch.utils.data.DataLoader(
    tokenized_dataset,
    collate_fn=ultra_big_brain_collate_fn,
    batch_sampler=ultra_big_brain_batch_sampler,
)

In [96]:
(next(iter(ultra_big_brain_dataloader))==0).sum(dim=1)

tensor([1, 2, 5, 3, 0, 4, 2, 4, 3, 1, 5, 1, 1, 2, 1, 4, 4, 4, 4, 1, 0, 2, 2, 1,
        4, 1, 1, 3, 1, 1, 3, 3, 5, 4, 1, 0, 0, 0, 4, 3, 1, 1, 3, 4, 3, 4, 0, 4,
        1, 5, 5, 4, 0, 2, 2, 3, 4, 0, 1, 2, 4, 2, 4, 0])

In [97]:
run_epoch(ultra_big_brain_dataloader)

100%|██████████| 18203/18203 [07:39<00:00, 39.57it/s]


{'min': 0.0008770970162004232,
 'max': 0.0751737430691719,
 'mean': 0.02448800764977932,
 'median': 0.024018418043851852}