In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, random_split, ConcatDataset, Dataset
from tqdm import tqdm

from torch.nn.utils.rnn import pad_sequence

from transformers import BertTokenizer
from datasets import load_dataset
from transformers import AutoTokenizer

In [2]:
ds = load_dataset("openai/gsm8k", "main")

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [3]:
ds

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 7473
    })
    test: Dataset({
        features: ['question', 'answer'],
        num_rows: 1319
    })
})

In [4]:
ds['train'][0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

In [5]:
if torch.cuda.is_available(): 
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
class RoPE(nn.Module):
    def __init__(self, emb_dim, max_len):
        super().__init__()
        self.dim = emb_dim
        self.max_len = max_len
        
        # Create a matrix of positional encodings
        self.position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # Shape: [max_len, 1]
        self.div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * -(math.log(10000.0) / emb_dim))  # Shape: [dim/2]
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, dim)
        seq_len = x.size(1)
        
        # Limit to the maximum length
        if seq_len > self.max_len:
            raise ValueError("seq_len exceeds max_len")
        
        # Compute the rotary positional encoding
        pos_enc = self.position[:seq_len].matmul(self.div_term.unsqueeze(0))  # Shape: [seq_len, dim/2]
        pos_enc = pos_enc.unsqueeze(0)  # Shape: [1, seq_len, dim/2]
        
        # Create the rotary encodings
        rotary_encodings = torch.zeros_like(x)
        rotary_encodings[:, :, 0::2] = torch.cos(pos_enc)  # Even indices
        rotary_encodings[:, :, 1::2] = torch.sin(pos_enc)  # Odd indices
        
        return x + rotary_encodings
        

In [7]:
class SelfAttention(nn.Module):
    def __init__(self, emb_dim, num_heads=8):
        super().__init__()
        self.embed_size = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        assert self.head_dim * num_heads == emb_dim, "Embedding size must be divisible by number of heads"

        self.fc_out = nn.Linear(emb_dim, emb_dim)

        self.cached_keys = None
        self.cached_values = None

    def forward(self, queries, keys, values, cache=True):
        N, seq_length, _ = queries.shape

        # Assuming queries, keys, and values are already normalized
        queries = queries.view(N, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        keys = keys.view(N, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        values = values.view(N, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.cached_keys is None or self.cached_values is None:
            if cache:
                self.cached_keys = keys
                self.cached_values = values
        else:
            keys = self.cached_keys
            values = self.cached_values

        # Calculate attention scores
        energy = torch.einsum("nhqd,nhkd->nhqk", [queries, keys])
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=-1)

        out = torch.einsum("nhql,nhld->nhqd", [attention, values]).reshape(N, seq_length, self.embed_size)

        return self.fc_out(out)

    def reset_cache(self):
        """Reset the cached keys and values."""
        self.cached_keys = None
        self.cached_values = None


In [8]:
class Encoder(nn.Module):
    def __init__(self, emb_dim, heads, max_len):
        super().__init__()
        
        self.rmsnorm1 = nn.RMSNorm(emb_dim)
        self.rope = RoPE(emb_dim, max_len)
        self.selfattention = SelfAttention(emb_dim, heads)
        self.rmsnorm2 = nn.RMSNorm(emb_dim)
        self.ff1 = nn.Linear(emb_dim, 3*emb_dim)
        self.ff2 = nn.Linear(3*emb_dim, emb_dim)
        
    def forward(self, x):
        xnorm = self.rmsnorm1(x)
        q = self.rope(xnorm)
        k = self.rope(xnorm)
        v = xnorm
        x = x + self.selfattention(q, k, v)
        xnorm = self.rmsnorm2(x)
        xnorm = F.relu(self.ff1(xnorm))
        xnorm = self.ff2(xnorm)
        x = x+xnorm
        
        return x

In [9]:
class LLaMA(nn.Module):
    def __init__(self, vocab_size, emb_dim, heads, layers, max_len):
        super().__init__()
        
        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.ModuleList(Encoder(emb_dim, heads, max_len)for _ in range(layers))
        self.rmsnorm = nn.RMSNorm(emb_dim)
        self.linear = nn.Linear(emb_dim, vocab_size)
        
    def forward(self, x):
        x = self.embeddings(x)
        for layer in self.encoder:
            x = layer(x)
        x = self.rmsnorm(x)
        x = F.softmax(self.linear(x), dim=-1)
        
        return x

In [10]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [11]:
# Preprocess data
def preprocess_function(examples):
    # Combine question and answer for input
    inputs = [f"question: {q} answer:" for q in examples['question']]
    targets = examples['answer']
    
    model_inputs = tokenizer(inputs, truncation=True, padding='max_length', max_length=1024)
    labels = tokenizer(targets, truncation=True, padding='max_length', max_length=1024)
    
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# Split dataset
train_data = ds['train'].map(preprocess_function, batched=True) 
test_data = ds['test'].map(preprocess_function, batched=True)


class QNADataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            'input_ids': torch.tensor(self.data[idx]['input_ids']),
            'labels': torch.tensor(self.data[idx]['labels'])
        }

train_dataset = QNADataset(train_data)
test_dataset = QNADataset(test_data)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, pin_memory=True, num_workers=4)
validation_dataloader = DataLoader(test_dataset, batch_size=8, pin_memory=True, num_workers=4)


model = LLaMA(tokenizer.vocab_size, emb_dim=768, heads=8, layers=4, max_len=1024)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=3e-5)

model

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

LLaMA(
  (embeddings): Embedding(30522, 768)
  (encoder): ModuleList(
    (0-3): 4 x Encoder(
      (rmsnorm1): RMSNorm((768,), eps=None, elementwise_affine=True)
      (rope): RoPE()
      (selfattention): SelfAttention(
        (fc_out): Linear(in_features=768, out_features=768, bias=True)
      )
      (rmsnorm2): RMSNorm((768,), eps=None, elementwise_affine=True)
      (ff1): Linear(in_features=768, out_features=2304, bias=True)
      (ff2): Linear(in_features=2304, out_features=768, bias=True)
    )
  )
  (rmsnorm): RMSNorm((768,), eps=None, elementwise_affine=True)
  (linear): Linear(in_features=768, out_features=30522, bias=True)
)

In [12]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)
model = model.to('cuda')
model.to(device)

num_epochs=1
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        # Move tensors to the specified device
        input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass

        logits = model(input_ids)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_dataloader)
    print(f"training : Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")
    
    model.eval()
    epoch_loss = 0
    for batch in tqdm(validation_dataloader):

        # Move tensors to the specified device
        input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        logits = model(input_ids)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(validation_dataloader)
    print(f"validation : Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

Let's use 2 GPUs!


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
100%|██████████| 935/935 [10:47<00:00,  1.44it/s]


training : Epoch 1/1, Loss: 9.4439


100%|██████████| 165/165 [00:56<00:00,  2.94it/s]

validation : Epoch 1/1, Loss: 9.4302



