NOTE: FILE IS CREATED IN KAGGLE, TO RUN THE CELLS PREFER USING KAGGLE (accelerator GPU - T4x2 , add dataset - all-the-news)

In [None]:
!pip install einops
!pip install fancy_einsum

Collecting einops
  Obtaining dependency information for einops from https://files.pythonhosted.org/packages/29/0b/2d1c0ebfd092e25935b86509a9a817159212d82aa43d7fb07eca4eeff2c2/einops-0.7.0-py3-none-any.whl.metadata
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Installing collected packages: fancy_einsum
Successfully installed fancy_einsum-0.0.3


In [None]:

%%writefile GPT2.py

import torch
import torch.nn as nn
from dataclasses import dataclass
import einops
import numpy as np
import math
import torch.nn.functional as F
from fancy_einsum import einsum
from transformers import GPT2Model
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW
import pandas as pd
import tqdm
import os
import functools
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
@dataclass
class configure_model():
  dim:int = 768
  eps:float = 1e-5
  vocab:int = 50257
  init_stddev:float = 0.02
  max_word:int = 1024
  n_head:int = 12
  d_head:int = 64
  d_mlp:int = 3072
  n_layers:int = 12



configure = configure_model()


class Layer_Norm(nn.Module):
  def __init__(self,configure):
    super().__init__()
    self.configure = configure
    self.w = nn.Parameter(torch.ones(configure.dim))
    self.b = nn.Parameter(torch.zeros(configure.dim))

  def forward(self,in_tensor):

    #taking mean along the dimension
    # batch_size --> number of sentences taken, pos --> number of words in sequence , dim --> size of the embedding layer
    #z=(x-mu)
    z= in_tensor - einops.reduce(in_tensor,"batch_size seq_len dim -> batch_size seq_len 1","mean")

    #variance+eps
    var = (einops.reduce(z**2,"batch_size seq_len dim -> batch_size seq_len 1","mean")+configure.eps).sqrt()

    #normalize
    l_norm = self.w*(z/var)+self.b

    return l_norm

class embed_input(nn.Module):
  def __init__(self,configure):
    super().__init__()

    self.embed_w = nn.Parameter(torch.empty((configure.vocab,configure.dim)))

    #changing the standard deviation to 0.02(mentioned in GPT2 original paper)
    nn.init.normal_(self.embed_w, std=configure.init_stddev)

  def forward(self,in_tokens):
    #retrieve the rows from corresponding tokens
    return self.embed_w[in_tokens,:]


class pos_embed(nn.Module):
  def __init__(self,configure):
    super().__init__()
    self.pos_w = nn.Parameter(torch.empty((configure.max_word,configure.dim)))
    nn.init.normal_(self.pos_w, std=configure.init_stddev)

  def forward(self,in_tokens):
    #take pos_embeddings upto number od words in sequence
    pos_embed = self.pos_w[:in_tokens.size(1),:]
    pos_embed = einops.repeat(pos_embed, "position dim -> batch position dim", batch=in_tokens.size(0))

    return pos_embed



class Attention(nn.Module):
    def __init__(self, configure):
        super().__init__()
        self.configure =configure

        # Initializing separate linear layers for queries, keys, and values
        self.linear_qkv = nn.Parameter(torch.empty((configure.dim, 3 * configure.n_head * configure.d_head)))
        nn.init.normal_(self.linear_qkv.data, std=configure.init_stddev)
        self.bias_qkv = nn.Parameter(torch.zeros(configure.n_head*3*configure.d_head))  # Adjust dimensions

        #initializing output weights
        self.w_o = nn.Parameter(torch.empty((configure.n_head*configure.d_head,configure.dim)))
        nn.init.normal_(self.w_o, std=configure.init_stddev)
        self.b_o = nn.Parameter(torch.zeros(configure.dim))

        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cuda"),persistent=False)

    def apply_causal_mask(self, attn_scores):

        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

    def forward(self, norm_input):



        reshaped_layer = self.linear_qkv.view(3, self.configure.n_head,self.configure.dim,self.configure.d_head)
        reshaped_bias = self.bias_qkv.view(3, self.configure.n_head, self.configure.d_head)

       # Separate linear layers for queries, keys, and values
        q_layer, k_layer, v_layer =   reshaped_layer
        q_bias, k_bias, v_bias = reshaped_bias

        q = einsum("batch query_pos dim, n_head dim d_head -> batch query_pos n_head d_head", norm_input, q_layer) + q_bias
        k = einsum("batch key_pos dim, n_head dim d_head -> batch key_pos n_head d_head", norm_input, k_layer) + k_bias
        v = einsum("batch key_pos dim, n_head dim d_head -> batch key_pos n_head d_head", norm_input, v_layer) + v_bias

        attn = einsum("batch query_pos n_head d_head, batch key_pos n_head d_head -> batch n_head query_pos key_pos", q, k)
        attn = attn / (math.sqrt(self.configure.d_head))
        attn = self.apply_causal_mask(attn)

        attn_prob = F.softmax(attn, dim=-1)

        att = einsum("batch n_head query_pos key_pos, batch key_pos n_head d_head -> batch query_pos n_head d_head", attn_prob, v)

        reshaped_o_layer = self.w_o.view(self.configure.n_head,self.configure.d_head,self.configure.dim)
        out = einsum("batch query_pos n_head d_head, n_head d_head dim -> batch query_pos dim", att, reshaped_o_layer)+self.b_o

        return out



class mlp(nn.Module):
  def __init__(self,configure):
    super().__init__()
    self.w1 = nn.Parameter(torch.empty((configure.dim,configure.d_mlp)))
    nn.init.normal_(self.w1, std=configure.init_stddev)
    self.b1 = nn.Parameter(torch.zeros(configure.d_mlp))

    self.w2 = nn.Parameter(torch.empty((configure.d_mlp,configure.dim)))
    nn.init.normal_(self.w2, std=configure.init_stddev)
    self.b2 = nn.Parameter(torch.zeros(configure.dim))

    self.gelu = nn.GELU()

  def forward(self, x):

     o1 = self.gelu(einsum("batch pos dim, dim d_mlp -> batch pos d_mlp", x, self.w1)+self.b1)
     o2 = einsum("batch pos d_mlp, d_mlp dim -> batch pos dim", o1,self.w2)+self.b2
     return o2

class GPT2_block(nn.Module):
  def __init__(self,configure):
    super().__init__()

    self.l1 = Layer_Norm(configure)
    self.self_att = Attention(configure)
    self.l2 = Layer_Norm(configure)
    self.mlp_out = mlp(configure)

  def forward(self,embeddings):
    norm_embeddings = self.l1(embeddings)
    self_att_out = self.self_att(norm_embeddings)
    out1 = embeddings+self_att_out

    out_embeddings_norm = self.l2(out1)
    mlp_out = self.mlp_out(out_embeddings_norm)
    out2 = out1 + mlp_out

    return out2

class logits_layer(nn.Module):
  def __init__(self,configure):
    super().__init__()
    self.w1 = nn.Parameter(torch.empty((configure.dim,configure.vocab)))
    nn.init.normal_(self.w1, std=configure.init_stddev)
    self.b1 = nn.Parameter(torch.zeros((configure.vocab), requires_grad=False))

  def forward(self, gpt_out):

    logits = einsum("batch position dim, dim vocab -> batch position vocab", gpt_out, self.w1) + self.b1
    return logits


class GPT2(nn.Module):
  def __init__(self):
    super().__init__()
    configure = configure_model()
    self.embed_token = embed_input(configure)
    self.embed_pos = pos_embed(configure)
    self.blocks = nn.ModuleList([GPT2_block(configure) for _ in range(configure.n_layers)])
    self.l = Layer_Norm(configure)


  def forward(self,tokens):
    embed1 = self.embed_token(tokens)
    embed2 = self.embed_pos(tokens)

    out = embed1+embed2

    for block in self.blocks:
      out = block(out)

    norm_out = self.l(out)


    return norm_out

class crossentropyloss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, tokens):
        log_probs = logits.log_softmax(dim=-1)
        pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
        return -pred_log_probs.mean()

class custom_dataset(Dataset):
    def __init__(self, tokenizer, max_length=1024,num_rows=50):
        super().__init__()

        self.dataset_path = '/kaggle/input/all-the-news/articles1.csv'
        self.dataframe = pd.read_csv(self.dataset_path,nrows=num_rows)
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Tokenize each entry and add it to the list
        self.content_list = [self.tokenize(content) for content in self.dataframe['content']]

    def tokenize(self, text):
        # Encode the text into tokens with truncation if necessary
        tokens = self.tokenizer.encode(text, max_length=self.max_length, truncation=True)
        return tokens

    def __len__(self):
        return len(self.content_list)

    def __getitem__(self, idx):
        return torch.tensor(self.content_list[idx])

def get_dataloader():
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    cus_dataset = custom_dataset(tokenizer)
    dataloader2 = DataLoader(
        cus_dataset,
        batch_size=1,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(cus_dataset)
    )
    return dataloader2



def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '5554'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def training(model, rank, world_size, train_loader, optimizer, epoch):
    model.train()
    get_logits = logits_layer(configure)
    criterion = crossentropyloss()
    ddp_loss = torch.zeros(2).to(rank)
    for batch_idx, tokens in enumerate(train_loader):
        tokens = tokens.to(rank)
        output = model(tokens)
        get_logits = get_logits.to(output.device)
        logits = get_logits(output)
        loss = criterion(logits, tokens)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(tokens)
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

def fsdp_model(rank, world_size,dataset_path):

    setup(rank, world_size)
    data_loader = get_dataloader()
    lr = 1e-3
    weight_decay = 1e-2
    epochs = 2


    my_auto_wrap_policy = functools.partial(
       transformer_auto_wrap_policy,
       transformer_layer_class={
           GPT2,
       },
    )
    torch.cuda.set_device(rank)
    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)

    model = GPT2().to(rank)
    model = FSDP(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    init_start_event.record()
    for epoch in range(1, epochs + 1):
        training(model, rank, world_size, data_loader, optimizer, epoch)
    init_end_event.record()
    cleanup()

def run_demo(demo_fn, world_size, dataset_path):
    mp.spawn(demo_fn,
             args=(world_size,dataset_path),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    print(f"total GPUs: {n_gpus}")
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    dataset_path = '/kaggle/input/all-the-news/articles1.csv'
    world_size = n_gpus
    run_demo(fsdp_model, world_size, dataset_path)


Overwriting GPT2.py


In [None]:
!python GPT2.py

total GPUs: 2
[W socket.cpp:601] [c10d] The client socket has failed to connect to [localhost]:5554 (errno: 99 - Cannot assign requested address).
Train Epoch: 1 	Loss: 10.012598
Train Epoch: 2 	Loss: 10.283640
