## Task 1

In [58]:
from dataclasses import dataclass
import inspect
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device

device(type='cuda')

### GPT-2 model (125M) implementation

In [59]:
class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

Multi-head Self-attention

Attention(Q, K, V ) = softmax(Q @ K.T / √d_k) @ V (taken from "Attention is All you Need" paper)

In [60]:
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # check for Rotary Positional Embedding
        if (config.emb_type_RoPE == True):
            # print("Using RoPE", end=' ')
            self.rotary_emb = RotaryEmbedding(dim = config.n_embd // config.n_head)

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        # self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.flash = False
        if not self.flash:
            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        if (config.emb_type_RoPE == True):
            # RoPE here
            q, k = self.rotary_emb.rotate_queries_and_keys(q, k)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

In [61]:
class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)

        if (config.attn_type_GQA == True):
            self.attn = GroupedQueryAttention(config)

        else:
            if (config.attn_type_SWA == True):
                print("Sliding Window Attention not implemented yet. Using Self Attention by default !!!")
            self.attn = CausalSelfAttention(config)

        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

GPT-2 design of using both token and positional embeddings

> (wte = nn.Embedding(config.vocab_size, config.n_embd) is the token embedding applied on Input to the network idx, the context vector U ==> W_e is self.transformer.wte(idx))

> (wpe = nn.Embedding(config.block, config.n_embd) is the positional embedding applied on length of the Input to the network, pos = torch.arange(0, t, dtype=torch.long, device=device) ==> W_p is self.transformer.wpe(pos))

> Both are added and passed on to the transformer block

> h0 = U*W_e + W_p

> h_l = transformer_block(h_{l−1})∀i ∈ [1, n]

Transformer layers (n_layer=12) with multi-head (n_head=12) self-attention and position-wise feedforward network

In [62]:
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        if (config.emb_type_RoPE == True):
            print("Using RoPE\n")

        if (config.attn_type_GQA == True):
            print("Using GQA\n")

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself, using GPT-2 design of using both token and positional embeddings here
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

### Configuration for GPT-2 model

In [63]:
@dataclass
class GPTConfig:
    block_size: int = 1024
    # vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    emb_type_RoPE: bool = False # for Rotatory Position Embedding
    attn_type_GQA: bool = False # for Grouped Query attention
    attn_type_SWA: bool = False # for Sliding Window attention
    gqa_groups: int = 6 # number of groups for GQA
    heads_per_group: int = 2 # heads/queries per group for GQA

### Compile the model

In [64]:
config = GPTConfig()
model = GPT(config)
model = model.to(device)
optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)

number of parameters: 123.65M
num decayed parameter tensors: 50, with 124,318,464 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True


### Load the original GPT-2 125M model checkpoints

In [65]:
from transformers import GPT2LMHeadModel
model_hf = GPT2LMHeadModel.from_pretrained('gpt2')
sd_hf = model_hf.state_dict()
print(sd_hf.keys())

odict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.

### Copy the loaded weights in our implementation of GPT-2

In [66]:
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
for i in sd_hf:
    if any(i.endswith(w) for w in transposed):
        print(True, end='')
        sd_hf[i] = sd_hf[i].t()

TrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrueTrue

In [67]:
model_state = model.state_dict()
for i in model_state.keys():
    if not i.endswith('.attn.bias'):
        model_state[i] = sd_hf[i]

In [68]:
model.load_state_dict(model_state)

<All keys matched successfully>

### Run a sample prediction with the original weights combined with our implementation to validate the correct working of created model

In [69]:
model.eval()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [70]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = '''What is the best'''
encoded_input = tokenizer(text, return_tensors='pt')
# print(encoded_input)
outputs = model.generate(idx = encoded_input['input_ids'].to(device), max_new_tokens = 20, temperature = 0.8, top_k = 200)
print(tokenizer.decode(outputs[0]))

What is the best way to get to your position without being forced by your opponents to fight the same fights around you in


### References

1. Attention research paper - https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

2. GPT research paper - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf

3. GPT2 research paper - https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf

4. Github used - https://github.com/karpathy/nanoGPT/blob/master/model.py#L6

5. Huggingface resources used - https://huggingface.co/gpt2
https://huggingface.co/transformers/v3.0.2/_modules/transformers/tokenization_gpt2.html#GPT2Tokenizer
https://discuss.huggingface.co/t/how-to-decode-gpt2/16160

## Task 3 - requires user to give input (sGPU, DDP, FSDP)

### Prepare dataset for training (using tinyshakespeare/input.txt)

Reference link for dataset -- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [45]:
pip install tiktoken



In [46]:
import numpy as np
import requests
import tiktoken

In [47]:
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
input_file_path = './data.txt'
with open(input_file_path, 'w') as f:
    f.write(requests.get(data_url).text)
with open(input_file_path, 'r') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
print(f"train has {len(train_ids):,} tokens")
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(val_ids):,} tokens")

train_file_path = './train.bin'
val_file_path = './val.bin'
# export to bin files
train_ids = np.array(train_ids, dtype=np.int16)
val_ids = np.array(val_ids, dtype=np.int16)
train_ids.tofile(train_file_path)
val_ids.tofile(val_file_path)

train has 301,966 tokens
train has 36,059 tokens


In [48]:
print((len(train_data)))
print(len(train_ids))
print(train_data[0], train_ids[0])

1003854
301966
F 5962


We have saved the training and validation files in "train.bin" and "val.bin" in the current directory

### Training function

In [49]:
import os
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

In [50]:
def get_batch(data, device_type, batch_size = 2):
    block_size = config.block_size
    ix = torch.randint(len(data) - config.block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [51]:
@torch.no_grad()
def estimate_loss(model, train_data, val_data, device_type, batch_size = 2):
    out = {}
    eval_iters = 20
    model.eval()
    split = 'train'
    for data in [train_data, val_data]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(data, device_type, batch_size = 2)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean() / batch_size
        split = 'val'
    model.train()
    return out

In [52]:
def train_loop(rank, world_size, partial_epochs, iters, train_data, device_type, optimizer, model, val_data, mode='sGPU', batch_size = 2):
    if (mode == 'sGPU'):
        pass
    elif (mode == 'DDP'):
        setup(rank, world_size)
        model = model.to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)
        optimizer = torch.optim.AdamW(ddp_model.parameters(), weight_decay=1e-2, lr=1e-4, betas=(0.9, 0.95))
    elif (mode == 'FSDP'):
        setup(rank, world_size)
        model = model.to(rank)
        sampler = DistributedSampler(train_data, rank=rank, num_replicas=world_size, shuffle=True)
        train_kwargs = {'batch_size': batch_size, 'sampler': sampler, 'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
        train_loader = torch.utils.data.DataLoader(train_data,**train_kwargs)
        torch.cuda.set_device(rank)
        fsdp_model = FSDP(model)
        optimizer = torch.optim.AdamW(fsdp_model.parameters(), weight_decay=1e-2, lr=1e-4, betas=(0.9, 0.95))

    # common for all
    for epoch in range(1, partial_epochs+1):
        losses = 0
        total = 0

        if (mode == 'FSDP'):
            ddp_loss = torch.zeros(2).to(rank)
            sampler.set_epoch(epoch)

        for iter in tqdm(range(iters)):
            X, Y = get_batch(train_data, device_type, batch_size = 2)

            if (mode == 'sGPU'):
                X, Y = X.to(device), Y.to(device)
            else:
                X, Y = X.to(rank), Y.to(rank)

            optimizer.zero_grad()

            if (mode == 'sGPU'):
                logits, loss = model(X, Y)
            elif (mode == 'DDP'):
                logits, loss = ddp_model(X, Y)
            elif (mode == 'FSDP'):
                logits, loss = fsdp_model(X, Y)

            loss.backward()
            optimizer.step()

            if (mode != 'FSDP'):
                  losses += loss.item()
                  total += len(X)
            else:
                ddp_loss[0] += loss.item()
                ddp_loss[1] += len(X)

        if (mode == 'FSDP'):
            dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

        print(f"Epoch:{epoch}, Loss:{losses/total}" if (mode != 'FSDP') else f"Epoch:{epoch}, Loss:{ddp_loss[0] / ddp_loss[1]}")
        print(estimate_loss(model, train_data, val_data, device_type, batch_size = 2))
        print("------------------------\n")

    if (mode != 'sGPU'):
        cleanup()

In [53]:
def train(train_data, val_data, device_type, mode='sGPU'):
    '''
      compatible value of mode = sGPU, DDP, and FSDP
    '''
    assert mode == 'sGPU' or mode == 'DDP' or mode == 'FSDP', "compatible value of mode = sGPU, DDP, and FSDP"

    partial_epochs = 5
    iters = 20
    batch_size = 2
    config = GPTConfig()
    model = GPT(config)
    n_gpus = torch.cuda.device_count()
    world_size = n_gpus
    rank = 0

    if (mode == 'sGPU'):
        model = model.to(device)
        optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
        train_loop(rank, world_size, partial_epochs, iters, train_data, device_type, optimizer, model, val_data, mode, batch_size)

    else: # (mode == 'DDP' or mode == 'FSDP'):
        if n_gpus < 2:
            print(f"Requires at least 2 GPUs to run Multiprocessing using DDP or FSDP, but got {n_gpus}")
            train_loop(rank, world_size, partial_epochs, iters, train_data, device_type, None, model, val_data, mode, batch_size)
        else:
            print(f"Utilizing {n_gpus} GPUs\n")
            mp.spawn(train_loop,
                args=(world_size, partial_epochs, iters, train_data, device_type, None, model, val_data, mode, batch_size,),
                nprocs=world_size,
                join=True)

    return model

### Training on the prepared data - give input by expanding this block

In [54]:
train_data = np.memmap('train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('val.bin', dtype=np.uint16, mode='r')
while (True):
    mode = input("Enter the mode in which you want to train the file (compatible value of mode = sGPU, DDP, and FSDP)")
    if (mode == 'sGPU' or mode == 'DDP' or mode == 'FSDP'):
        break;
    else:
        print("compatible value of mode = sGPU, DDP, and FSDP")

# call training function with mode specified
model = train(train_data, val_data, device_type, mode=mode)

Enter the mode in which you want to train the file (compatible value of mode = sGPU, DDP, and FSDP)FSDP
number of parameters: 123.65M
Requires at least 2 GPUs to run Multiprocessing using DDP or FSDP, but got 1


100%|██████████| 20/20 [00:12<00:00,  1.54it/s]


Epoch:1, Loss:4.34592342376709
{'train': tensor(3.8753), 'val': tensor(3.8435)}
------------------------



100%|██████████| 20/20 [00:13<00:00,  1.49it/s]


Epoch:2, Loss:3.586996078491211
{'train': tensor(3.3317), 'val': tensor(3.3227)}
------------------------



100%|██████████| 20/20 [00:13<00:00,  1.54it/s]


Epoch:3, Loss:3.2246177196502686
{'train': tensor(3.0996), 'val': tensor(3.1252)}
------------------------



100%|██████████| 20/20 [00:12<00:00,  1.56it/s]


Epoch:4, Loss:3.0579822063446045
{'train': tensor(3.0534), 'val': tensor(3.0596)}
------------------------



100%|██████████| 20/20 [00:13<00:00,  1.53it/s]


Epoch:5, Loss:3.0134801864624023
{'train': tensor(2.9528), 'val': tensor(3.0299)}
------------------------



If you encounter "RuntimeError: trying to initialize the default process group twice!", run this cell by uncommenting

In [55]:
# cleanup()

Generate a prediction after training on the dataset

In [56]:
enc = tiktoken.get_encoding("gpt2")
text = '''What is the role'''
outputs = model.generate(idx = torch.tensor([enc.encode_ordinary(text)]).to(device), max_new_tokens = 20, temperature = 0.8, top_k = 200)
print(enc.decode(outputs[0].tolist()))

What is the role


GLES:

 of hear'll?

As her off in a like


### References

DDP - https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

FSDP - https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#how-to-use-fsdp

Unused code

In [57]:
# from tqdm import tqdm
# def train(model, optimizer, train_data, val_data, device_type):
#     partial_epochs = 1
#     iters = 50
#     train_kwargs = {'batch_size': batch_size, 'num_workers': 2,
#                     'pin_memory': True,
#                     'shuffle': False}
#     train_loader = torch.utils.data.DataLoader(train_data,**train_kwargs)

#     for epoch in range(1, partial_epochs+1):
#         progress = tqdm(enumerate(train_loader), desc="Epoch: {}".format(epoch), total=len(train_loader))
#         loss = 0
#         total = 0
#         for iter, (X, Y) in progress:
#             X, Y = X.to(device), Y.to(device)
#             print(X)
#             optimizer.zero_grad()
#             logits, loss = model(X, Y)
#             loss.backward()
#             optimizer.step()
#             progress.update(1)
#             loss += loss.item()
#             total += len(X)
#         print(f"Epoch{epoch}, Loss:{loss/total}")

# train_data = np.memmap('train.bin', dtype=np.uint16, mode='r')
# val_data = np.memmap('val.bin', dtype=np.uint16, mode='r')
# train(model, optimizer, train_data, val_data, device_type)

## Task 2

### RoPE (Rotary Positional Embedding)

In [27]:
pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


#### Defining helper functions and Class to apply Rotary Positional Embedding

In [28]:
from math import pi, log

import torch
from torch.cuda.amp import autocast
from torch import einsum, broadcast_tensors

from einops import rearrange, repeat

In [29]:
# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# rotary embedding helper functions

def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

@autocast(enabled = False)
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
    rot_dim, seq_len = freqs.shape[-1], t.shape[seq_dim]
    freqs = freqs[-seq_len:].to(t)

    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
    t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t_left, t, t_right), dim = -1)

# learned rotation helpers

def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    if exists(freq_ranges):
        rotations = einsum('..., f -> ... f', rotations, freq_ranges)
        rotations = rearrange(rotations, '... r f -> ... (r f)')

    rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
    return apply_rotary_emb(rotations, t, start_index = start_index)

In [30]:
# classes

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        learned_freq = False,
        use_xpos = True,
        xpos_scale_base = 512,
        interpolate_factor = 1.,
        theta_rescale_factor = 1.,
        seq_before_head_dim = False
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        # theta *= theta_rescale_factor ** (dim / (dim - 2))

        # θi = 10000**(-2i/d)
        freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))

        self.cache = dict()
        self.cache_scale = dict()
        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

        self.learned_freq = learned_freq

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos
        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.register_buffer('scale', scale)

    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor

    def rotate_queries_and_keys(self, q, k, seq_dim = None):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
        freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}')
        scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')
            scale = rearrange(scale, 'n d -> n 1 d')

        rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t, cache_key = None):
        assert self.use_xpos

        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if callable(t):
            t = t()

        scale = 1.
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, 'n -> n 1')
            scale = torch.cat((scale, scale), dim = -1)

        if exists(cache_key):
            self.cache[cache_key] = scale

        return scale

    @autocast(enabled = False)
    def forward(self, t, cache_key = None):
        should_cache = not self.learned_freq and exists(cache_key)

        if should_cache and cache_key in self.cache:
            return self.cache[cache_key]

        if callable(t):
            t = t()

        freqs = self.freqs

        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

        if should_cache:
            self.cache[cache_key] = freqs

        return freqs

#### Compiling the model and set emb_type_RoPE = True for applying RoPE

In [31]:
config = GPTConfig()
config.emb_type_RoPE = True
model_RoPE = GPT(config)
model_RoPE = model_RoPE.to(device)
optimizer_RoPE = model_RoPE.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
model_RoPE.eval()

Using RoPE

number of parameters: 123.65M
num decayed parameter tensors: 50, with 124,318,464 parameters
num non-decayed parameter tensors: 98, with 121,344 parameters
using fused AdamW: True


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (rotary_emb): RotaryEmbedding()
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

Run a sample prediction (please note that training is not done for this model)

In [32]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = '''What is the best'''
encoded_input = tokenizer(text, return_tensors='pt')
outputs = model_RoPE.generate(idx = encoded_input['input_ids'].to(device), max_new_tokens = 20, temperature = 0.8, top_k = 200)
print(tokenizer.decode(outputs[0]))

What is the bestanes horr unexilatedoit editionsaccompanied w 1943 alas differe Speed yeastends potential puppies among Dusk orbitbrids


In [33]:
train_data = np.memmap('train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('val.bin', dtype=np.uint16, mode='r')

train_loop(0, torch.cuda.device_count(), 5, 20, train_data, device_type, optimizer_RoPE, model_RoPE, val_data, mode='sGPU', batch_size = 2)

100%|██████████| 20/20 [00:14<00:00,  1.37it/s]


Epoch:1, Loss:4.355218839645386
{'train': tensor(3.8761), 'val': tensor(3.8937)}
------------------------



100%|██████████| 20/20 [00:15<00:00,  1.30it/s]


Epoch:2, Loss:3.561240720748901
{'train': tensor(3.3037), 'val': tensor(3.3676)}
------------------------



100%|██████████| 20/20 [00:14<00:00,  1.36it/s]


Epoch:3, Loss:3.2067344427108764
{'train': tensor(3.1453), 'val': tensor(3.1720)}
------------------------



100%|██████████| 20/20 [00:14<00:00,  1.37it/s]


Epoch:4, Loss:3.0743162035942078
{'train': tensor(3.0499), 'val': tensor(3.0921)}
------------------------



100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


Epoch:5, Loss:2.9654118061065673
{'train': tensor(2.9142), 'val': tensor(2.9718)}
------------------------



After training

In [34]:
text = '''What is the best'''
encoded_input = tokenizer(text, return_tensors='pt')
outputs = model_RoPE.generate(idx = encoded_input['input_ids'].to(device), max_new_tokens = 20, temperature = 0.8, top_k = 200)
print(tokenizer.decode(outputs[0]))

What is the best
And and's of thy in so,
InUS:
S the the to me,


#### References

Github - https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py

### GQA (Group Query Attention)

Grouped query attention implementation with number of query heads = 12 with 2 heads grouped, hence new number of query heads become 6

6 groups and 2 query heads per group

> For converting a multi-head checkpoint to a GQA checkpoint, we
construct each group key and value head by meanpooling all the original heads within that group

> It implements a form of mean pooling where the averaging is done implicitly through the concatenation and projection operations

>> torch.cat([ ... for query in self.querys ], dim=2) line concatenates the attention outputs from all heads along the third dimension, and

>>proj(Z_s) applies the projection layer to the concatenated attention output

#### Defining Class to implement MQA and GQA

> MQA gives one key, value output for an input of any number querys

> We apply MQA on a group of 2 query heads

> And for GQA we create 6 blocks of MQA (6*2 = 12 heads)

In [35]:
class  MultiQueryAttention(nn.Module):
    r"""
    https://arxiv.org/pdf/1911.02150.pdf
    """
    def __init__(self, word_size, embed_dim, n_query): # 768, 64, 2
        super().__init__()
        self.n_query = n_query
        self.querys = nn.ModuleList([
            nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
            for _ in range(n_query)
        ])
        self.key = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.value = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.proj = nn.Linear(in_features=embed_dim*n_query,
                              out_features=embed_dim, bias=True)

    def forward(self, x):
        # x.shape --> [1, 4, 768]
        K = self.key(x)
        V = self.value(x)
        # print("1", F.scaled_dot_product_attention(self.querys[0](x), K, V).size())
        Z_s = torch.cat([
            F.scaled_dot_product_attention(query(x), K, V) for query in self.querys
        ], dim=2)
        # print("2", Z_s.size()) --> 1, 4, 128
        # Z_s = torch.mean(Z_s, dim = 0)
        Z = self.proj(Z_s) # --> 1, 4, 64
        return Z

class  GroupedQueryAttention(nn.Module):
    r"""
    https://arxiv.org/pdf/2305.13245.pdf
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        word_size = config.n_embd # 768
        embed_dim = config.n_embd // config.n_head # 768/12 = 64
        n_query_each_group = config.heads_per_group # 2
        n_grouped = config.gqa_groups # 6
        self.grouped = nn.ModuleList([MultiQueryAttention(word_size, embed_dim, n_query=n_query_each_group) for _ in range(n_grouped)])
        self.proj = nn.Linear(in_features=embed_dim*n_grouped, out_features=config.n_embd, bias=True)

    def forward(self, x, mask=None):
        # x.shape --> [1, 4, 768]
        Z_s = torch.cat([head(x) for head in self.grouped], dim=2)
        # print(Z_s.size()) --> [1, 4, 384]
        # Z_s = torch.mean(Z_s, dim = 0)
        Z = self.proj(Z_s) # --> [1, 4, 768]
        return Z

#### Compiling the model and set attn_type_GQA = True for applying GQA

In [36]:
config = GPTConfig()
config.attn_type_GQA = True
model_GQA = GPT(config)
model_GQA = model_GQA.to(device)
optimizer_GQA = model_GQA.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
model_GQA.eval()

Using GQA

number of parameters: 113.62M
num decayed parameter tensors: 398, with 114,291,456 parameters
num non-decayed parameter tensors: 446, with 116,736 parameters
using fused AdamW: True


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): GroupedQueryAttention(
          (grouped): ModuleList(
            (0-5): 6 x MultiQueryAttention(
              (querys): ModuleList(
                (0-1): 2 x Linear(in_features=768, out_features=64, bias=True)
              )
              (key): Linear(in_features=768, out_features=64, bias=True)
              (value): Linear(in_features=768, out_features=64, bias=True)
              (proj): Linear(in_features=128, out_features=64, bias=True)
            )
          )
          (proj): Linear(in_features=384, out_features=768, bias=True)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, 

Observe reduce number of parameters above for same number of heads

Run a sample prediction (please note that training is not done for this model)

In [37]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = '''What is the best'''
encoded_input = tokenizer(text, return_tensors='pt')
outputs = model_GQA.generate(idx = encoded_input['input_ids'].to(device), max_new_tokens = 20, temperature = 0.8, top_k = 200)
print(tokenizer.decode(outputs[0]))

What is the bestcluded smartphone coy hairst carbs remainderannel Army Venezuela Previously outer Newsletter benchmark tides Opposition Gender%" confid sequencing Vic


In [38]:
train_data = np.memmap('train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('val.bin', dtype=np.uint16, mode='r')

train_loop(0, torch.cuda.device_count(), 5, 20, train_data, device_type, optimizer_GQA, model_GQA, val_data, mode='sGPU', batch_size = 2)

100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


Epoch:1, Loss:4.2742954134941105
{'train': tensor(3.8272), 'val': tensor(3.8118)}
------------------------



100%|██████████| 20/20 [00:14<00:00,  1.42it/s]


Epoch:2, Loss:3.5456133127212524
{'train': tensor(3.3337), 'val': tensor(3.3423)}
------------------------



100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


Epoch:3, Loss:3.1780059695243836
{'train': tensor(3.0757), 'val': tensor(3.1482)}
------------------------



100%|██████████| 20/20 [00:13<00:00,  1.43it/s]


Epoch:4, Loss:3.051790547370911
{'train': tensor(2.9899), 'val': tensor(3.0486)}
------------------------



100%|██████████| 20/20 [00:13<00:00,  1.43it/s]


Epoch:5, Loss:2.96373770236969
{'train': tensor(2.9068), 'val': tensor(2.9822)}
------------------------



After training

In [39]:
text = '''What is the best'''
encoded_input = tokenizer(text, return_tensors='pt')
outputs = model_GQA.generate(idx = encoded_input['input_ids'].to(device), max_new_tokens = 20, temperature = 0.8, top_k = 200)
print(tokenizer.decode(outputs[0]))

What is the best, and, to,


And, the the be him,
 this prince




#### References

1. GQA - https://arxiv.org/pdf/2305.13245v2.pdf

2. Github - https://github.com/knotgrass/attention/blob/main/attn/attention.py

Unused code GQA

In [40]:
class GroupedQueryAttention2(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn_q = nn.Linear(config.n_embd, config.heads_per_group * config.n_embd, bias=config.bias)
        self.c_attn_kv = nn.Linear(config.n_embd, 2 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        # self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.flash = False
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.c_attn_q(x)
        k, v = self.c_attn_kv(x).split(self.n_embd, dim=2)
        # print(k.shape, v.shape, q.shape)

        # Group queries into 6 groups of 2 queries each
        num_groups = config.gqa_groups
        queries_per_group = config.heads_per_group
        # hs = C // num_groups
        # grouped_queries = []
        # for i in range(num_groups):
        #     group_start = i * hs
        #     group_end = (i+1)*hs
        #     group_queries = q[:, :, group_start:group_end]
        #     print(group_queries.shape)
        #     group_queries = group_queries.view(B, T, queries_per_group, hs//queries_per_group)
        #     print(group_queries.shape)
        #     mean_query = torch.mean(group_queries, dim=2)
        #     print(mean_query.shape)

        #     grouped_queries.append(mean_query)

        grouped_queries = q.view(B, T, queries_per_group, C)
        # print(grouped_queries.shape)
        q = torch.mean(grouped_queries, dim=2)

        # Reshape grouped queries to match the expected shape
        # print(grouped_queries.shape)
        # grouped_queries = torch.stack(grouped_queries, dim=0)
        q = q.view(B, T, num_groups, C // (num_groups)).transpose(1, 2)
        # print(grouped_queries.shape)
        k = k.view(B, T, num_groups, C // (num_groups)).transpose(1, 2)
        # q = grouped_queries.view(B, T, num_groups, C // num_groups)
        v = v.view(B, T, num_groups, C // (num_groups)).transpose(1, 2)
        # print(k.shape, q.shape, v.shape)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y