In [16]:
import torch
import tqdm

from tqdm.notebook import tqdm
from x_transformers import TransformerWrapper, Decoder
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from x_transformers import AutoregressiveWrapper

In [3]:
# constants
vocab_size = 50257 # vocab size corresponding to gpt2 tokenizer

In [15]:
# Rotary embeddings
model_rotary = TransformerWrapper(
    num_tokens = vocab_size,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 256,
        depth = 3,
        heads = 4,
        rotary_pos_emb = True  # Enable rotary embeddings
    )
)

# ALiBi embeddings
model_alibi = TransformerWrapper(
    num_tokens = vocab_size,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 256,
        depth = 3,
        heads = 4,
        alibi_pos_bias = True  # Enable ALiBi
    )
)

# No positional embeddings
model_no_pos = TransformerWrapper(
    num_tokens = vocab_size,
    max_seq_len = 512,
    use_abs_pos_emb = False,  # Disable absolute positional embeddings
    attn_layers = Decoder(
        dim = 256,
        depth = 3,
        heads = 4,
        disable_abs_pos_emb = True  # Disable absolute positional embeddings in attention layers
    )
)

In [11]:
# loading wikitext
wikitext_dataset = load_dataset("wikitext", "wikitext-2-v1")

# remove empty samples
filtered_dataset = wikitext_dataset.filter(lambda sample: len(sample['text'].strip()) > 0)

# load gpt2 tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# tokenizing function
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=1024)

# tokenized version of wikitext
tokenized_wikitext = filtered_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

Map:   0%|          | 0/23767 [00:00<?, ? examples/s]

In [20]:
class WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, tokenized_dataset, split='train', max_length=512):
        self.tokenized_dataset = tokenized_dataset[split]
        self.max_length = max_length

    def __len__(self):
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        item = self.tokenized_dataset[idx]
        input_ids = item['input_ids'][:self.max_length]  # Truncate if too long
        input_ids = input_ids + [0] * (self.max_length - len(input_ids))  # Pad if too short
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long)
        }
# Create the dataset
wikitext_dataset = WikiTextDataset(tokenized_wikitext, split='train')

In [21]:
def train_model(model, dataset, num_epochs, batch_size, learning_rate):
    model = AutoregressiveWrapper(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    for epoch in tqdm(range(num_epochs)):
        total_loss = 0
        for batch in tqdm(dataloader):
            input_ids = batch['input_ids'].to(device)
            
            optimizer.zero_grad()
            loss = model(input_ids)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

In [22]:
# Train each model on WikiText-103
for model in [(model_rotary, "rotary"), (model_alibi, "alibi"), (model_no_pos, "nopos")]:
    print(f"Training {model[1]} on WikiText-103")
    train_model(model[0], wikitext_dataset, num_epochs=5, batch_size=32, learning_rate=3e-4)

Training TransformerWrapper on WikiText-103


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/743 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 0 has a total capacity of 23.46 GiB of which 32.81 MiB is free. Including non-PyTorch memory, this process has 23.43 GiB memory in use. Of the allocated memory 23.17 GiB is allocated by PyTorch, and 71.81 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [19]:
def evaluate_model(model, dataset):
    model.eval()
    total_loss = 0
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
    device = next(model.parameters()).device
    
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            loss = model(input_ids)
            total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluate each model on WikiText-103
for model in [model_rotary, model_alibi, model_no_pos]:
    model = AutoregressiveWrapper(model)
    wikitext_loss = evaluate_model(model, WikiTextDataset(tokenized_wikitext, split='test'))
    print(f"Model: {model.__class__.__name__}")
    print(f"WikiText-103 Loss: {wikitext_loss}")

OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 23.46 GiB of which 416.81 MiB is free. Including non-PyTorch memory, this process has 23.05 GiB memory in use. Of the allocated memory 22.39 GiB is allocated by PyTorch, and 487.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)