In [8]:
%%writefile fused_encoder_train.py

import gc
import os
import sys
import time
import random
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from transformers import AutoConfig
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter('ignore')
from torch.amp import custom_fwd, custom_bwd
from accelerate import Accelerator, DistributedDataParallelKwargs
torch._dynamo.config.capture_dynamic_output_shape_ops = True
def seed_everything(seed):
    """Seed every random number generator used by this script.

    Exports ``PYTHONHASHSEED``, seeds Python's ``random`` module, NumPy and
    PyTorch (``torch.manual_seed`` and ``torch.cuda.manual_seed``), and
    switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The generators are independent of each other, so one loop covers them.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

from dataclasses import dataclass


# TODO: flatten the inputs to 2-D and use a plain matrix multiplication here,
# like the rest of the layers.
@torch.compile(fullgraph=True)
def linear_fwd(input_features, weight, bias):
    """Affine projection ``input_features @ weight^T (+ bias)``.

    Args:
        input_features: Tensor whose last axis is contracted with the last
            axis of ``weight``.
        weight: Tensor with output features on its second-to-last axis and
            input features on its last axis.
        bias: Optional tensor added to the result; ``None`` skips the add.

    Returns:
        The projected tensor.
    """
    # Contract the trailing feature axis of both operands (x @ w^T).
    projected = torch.einsum("...ik, ...jk -> ...ij", input_features, weight)
    if bias is None:
        return projected
    # In-place add keeps the dtype produced by the contraction.
    projected.add_(bias.unsqueeze(0).expand_as(projected))
    return projected

@torch.compile(fullgraph=True)
def linear_bwd(grad_output,input_features, weight, bias):
    """Backward pass matching ``linear_fwd``.

    Args:
        grad_output: Gradient of the loss w.r.t. the output of ``linear_fwd``.
        input_features: The input that was fed to ``linear_fwd``.
        weight: The weight that was fed to ``linear_fwd``.
        bias: The bias that was fed to ``linear_fwd`` (may be ``None``).

    Returns:
        ``(grad_input, grad_weight, grad_bias)``. ``grad_weight`` has the shape
        of ``weight``; ``grad_bias`` is one value per output feature, or
        ``None`` when ``bias`` is ``None`` (autograd rejects a gradient for an
        input that was not a tensor).
    """
    grad_input = torch.einsum("...ij, ...jk -> ...ik", grad_output, weight)

    # Collapse every leading (batch / sequence) axis so the parameter gradients
    # are reduced over all samples instead of keeping a batch axis.
    grad_2d = grad_output.reshape(-1, grad_output.shape[-1])

    if weight.dim() == 2:
        input_2d = input_features.reshape(-1, input_features.shape[-1])
        grad_weight = grad_2d.t().mm(input_2d)
    else:
        # Batched weights: keep the einsum and reduce any broadcast axes.
        grad_weight = torch.einsum(
            "...ji, ...jk -> ...ik", grad_output, input_features
        ).sum_to_size(weight.shape)

    grad_bias = None
    if bias is not None:
        grad_bias = grad_2d.sum(0)

    return grad_input, grad_weight, grad_bias
    
class LinearFunction(torch.autograd.Function):
    """Autograd function pairing ``linear_fwd`` with ``linear_bwd``."""

    @staticmethod
    @custom_fwd( device_type='cuda')
    def forward(ctx, input_features, weight, bias=None):
        """Compute the projection and stash the operands for backward."""
        projected = linear_fwd(input_features, weight, bias)
        ctx.save_for_backward(input_features, weight, bias)
        return projected

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        """Return gradients for ``(input_features, weight, bias)``."""
        saved_input, saved_weight, saved_bias = ctx.saved_tensors
        grads = linear_bwd(grad_output, saved_input, saved_weight, saved_bias)
        return grads

class MyLinear(nn.Module):
    """Linear layer whose forward/backward run through ``LinearFunction``.

    The weight has shape ``(output_features, input_features)`` and is
    initialised once with Xavier-normal; the optional bias starts at zero.

    Args:
        input_features: Size of the last axis of the input.
        output_features: Size of the last axis of the output.
        bias: When ``False`` the ``bias`` parameter is registered as ``None``.
    """

    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # Single initialisation: the previous code drew a Xavier-uniform weight
        # and then immediately overwrote it with a Xavier-normal one.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        nn.init.xavier_normal_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.zeros(output_features))
        else:
            # Always register optional parameters so ``self.bias`` exists.
            self.register_parameter("bias", None)

    def forward(self, x):
        """Apply the projection to ``x``."""
        return LinearFunction.apply(x, self.weight, self.bias)

    def extra_repr(self):
        """Extra information shown when the module is printed."""
        return "input_features={}, output_features={}, bias={}".format(
            self.input_features, self.output_features, self.bias is not None
        )

class LayerNormFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, x, weight, bias, eps):
        var, mean = torch.var_mean(x, dim=-1, keepdim=True)
        
        # mean = torch.mean(x, dim=-1, keepdim=True)
        # var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        
        x_normalized = (x - mean) / torch.sqrt(var + eps)
        y = x_normalized * weight + bias
        ctx.save_for_backward(x, x_normalized, weight, mean, var)
        ctx.eps = eps
        return y

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        x, x_normalized, weight, mean, var = ctx.saved_tensors
        eps = ctx.eps

        grad_weight = torch.sum(grad_output * x_normalized, dim=(0))
        grad_bias = torch.sum(grad_output, dim=(0))

        grad_x_normalized = grad_output * weight
        grad_var = torch.sum(
            grad_x_normalized * (x - mean) * (-0.5) * torch.pow(var + eps, -1.5),
            dim=(-1),
            keepdim=True,
        )
        grad_mean = torch.sum(
            grad_x_normalized * (-1 / torch.sqrt(var + eps)), dim=(-1), keepdim=True
        ) + grad_var * torch.mean(-2 * (x - mean), dim=(-1), keepdim=True)
        grad_x = (
            grad_x_normalized / torch.sqrt(var + eps)
            + grad_var * 2 * (x - mean) / x.shape[-1]
            + grad_mean / x.shape[-1]
        )

        return grad_x, grad_weight, grad_bias, None


class LayerNorm(torch.nn.Module):
    """Module wrapper around ``LayerNormFn``.

    Args:
        normalized_shape: Int or sequence describing the affine parameters'
            shape.
        eps: Value added to the variance before the square root.
        elementwise_affine: When ``True`` a learnable ``weight`` (ones) and
            ``bias`` (zeros) are created; otherwise both are registered as
            ``None``.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = torch.nn.Parameter(torch.ones(self.normalized_shape))
            self.bias = torch.nn.Parameter(torch.zeros(self.normalized_shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Normalise ``x`` over its last axis."""
        if self.elementwise_affine:
            weight, bias = self.weight, self.bias
        else:
            # Identity affine transform built on x's device and dtype.
            # (``torch.ones_like`` cannot be called on a ``torch.Size``.)
            weight = x.new_ones(x.shape[-1:])
            bias = x.new_zeros(x.shape[-1:])
        return LayerNormFn.apply(x, weight, bias, self.eps)



@torch.compile(fullgraph=True)
def rms_forward(x, weight, eps):
    """RMS-normalise ``x`` over its last axis and scale it by ``weight``.

    Args:
        x: Input tensor.
        weight: Gain multiplied onto the normalised values.
        eps: Value added to the mean square before the square root.

    Returns:
        ``weight * (x / rms)`` with the normalised values cast to
        ``weight``'s dtype.
    """
    # The statistic is accumulated in float32.
    mean_square = x.float().pow(2).mean(-1, keepdim=True)
    rms = (mean_square + eps).sqrt()
    normalized = x / rms
    return weight * normalized.type_as(weight)

@torch.compile(fullgraph=True)
def rms_backward(grad_output, x, weight,eps):
    """Backward pass matching ``rms_forward``.

    Args:
        grad_output: Gradient of the loss w.r.t. the output of ``rms_forward``.
        x: The input that was normalised.
        weight: The gain used in the forward pass (expected to be 1-D over the
            last axis of ``x``).
        eps: The same epsilon used in the forward pass.

    Returns:
        ``(grad_x, grad_weight)``; ``grad_weight`` is reduced over every
        leading axis so it has one entry per feature whatever the rank of
        ``x``.
    """
    # Recompute the forward statistics instead of storing them.
    rms = x.float().pow(2).mean(-1, keepdim=True).add(eps).sqrt()
    x_norm = x / rms

    # Sum over all rows (batch, sequence, ...), not just the first axis.
    grad_weight = (grad_output * x_norm).reshape(-1, x_norm.shape[-1]).sum(0)

    # Gradient w.r.t. the normalised activations.
    grad_x_norm = grad_output * weight
    # Back-propagate through x / rms: subtract the projection of the incoming
    # gradient onto the normalised activations, then rescale by 1 / rms.
    projection = (grad_x_norm * x_norm).mean(-1, keepdim=True)
    grad_x = (grad_x_norm - x_norm * projection) / rms

    return grad_x, grad_weight
    
class RMSNormFn(torch.autograd.Function):
    """Autograd function pairing ``rms_forward`` with ``rms_backward``."""

    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, x, weight, eps):
        """Normalise ``x`` and remember the operands needed for backward."""
        ctx.eps = eps
        ctx.save_for_backward(x, weight)
        return rms_forward(x, weight, eps)

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, eps)``; ``eps`` gets ``None``."""
        saved_x, saved_weight = ctx.saved_tensors
        grad_x, grad_weight = rms_backward(
            grad_output, saved_x, saved_weight, ctx.eps
        )
        return grad_x, grad_weight, None


class RMSNorm(torch.nn.Module):
    """RMS normalisation layer with a learnable gain initialised to ones.

    Args:
        dim: Size passed to ``torch.ones`` to build the gain.
        eps: Epsilon forwarded to ``RMSNormFn``.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Apply ``RMSNormFn`` to ``x`` with this layer's gain and epsilon."""
        gain, eps = self.weight, self.eps
        return RMSNormFn.apply(x, gain, eps)


@torch.compile(fullgraph=True)
def linear_fwd_2d(x,weight, bias):
    """2-D affine projection ``x @ weight^T (+ bias)``.

    Args:
        x: Matrix of inputs, one row per sample.
        weight: Matrix with one row per output feature.
        bias: Optional vector added to every row; ``None`` skips the add.
    """
    projected = torch.mm(x, weight.t())
    if bias is None:
        return projected
    projected += bias.unsqueeze(0).expand_as(projected)
    return projected


@torch.compile(fullgraph=True)
def linear_bwd_2d(grad_rms,weight,x,bias):
    """Backward of ``x @ weight^T + bias + prev`` for 2-D operands.

    Args:
        grad_rms: Gradient w.r.t. the sum that was produced in the forward pass.
        weight: Weight matrix used in the forward pass.
        x: Input matrix used in the forward pass.
        bias: Bias used in the forward pass, or ``None``.

    Returns:
        ``(grad_input, grad_prev, grad_weight, grad_bias)``. ``grad_prev`` is
        ``grad_rms`` itself; ``grad_bias`` is ``None`` when ``bias`` is ``None``.
    """
    grad_bias = grad_rms.sum(0) if bias is not None else None
    grad_weight = torch.mm(grad_rms.t(), x)
    grad_input = torch.mm(grad_rms, weight)
    # The residual branch is a plain addition, so its gradient passes through.
    return grad_input, grad_rms, grad_weight, grad_bias
    
class LinearRms(torch.autograd.Function):
    """Fused ``rms_norm(x @ weight^T + bias + prev)`` with a manual backward.

    ``forward(x, prev, weight, bias, weight_rms, eps)`` projects ``x``, adds
    the residual ``prev`` and RMS-normalises the sum with gain ``weight_rms``.
    """

    @staticmethod
    def forward(ctx, x, prev, weight, bias=None, weight_rms=None, eps=None):
        # Keep the pre-normalisation sum in its own variable: backward must
        # differentiate the normalisation at exactly this tensor, not at the
        # normalised output.
        pre_norm = linear_fwd_2d(x, weight, bias) + prev
        ctx.save_for_backward(x, weight, bias, pre_norm, weight_rms)
        ctx.eps = eps
        return rms_forward(pre_norm, weight_rms, eps)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias, pre_norm, weight_rms = ctx.saved_tensors

        # Gradient through the RMS normalisation first ...
        grad_pre_norm, grad_weight_rms = rms_backward(
            grad_output, pre_norm, weight_rms, ctx.eps
        )
        # ... then through the linear projection and the residual add.
        grad_input, grad_prev, grad_weight, grad_bias = linear_bwd_2d(
            grad_pre_norm, weight, x, bias
        )
        # One gradient per forward argument; eps gets None.
        return grad_input, grad_prev, grad_weight, grad_bias, grad_weight_rms, None

class GELUFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, x):
        ctx.save_for_backward(x)
        cdf = 0.5 * (1.0 + torch.erf(x / 2**0.5))
        ctx.save_for_backward(x)
        return x * cdf #0.5 * (1.0 + torch.erf(x / 2**0.5))

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        x = ctx.saved_tensors
        cdf = 0.5 * (1.0 + torch.erf(x / 2**0.5))
        pdf = torch.exp(-0.5 * x**2) / (2 * 3.14159265359) ** 0.5
        return grad_output * (cdf + x * pdf)


class GELU(torch.nn.Module):
    """Module form of ``GELUFunction``."""

    def forward(self, x):
        """Return ``GELUFunction`` applied to ``x``."""
        activated = GELUFunction.apply(x)
        return activated


import torch
from torch.autograd import Function

@torch.compile(fullgraph=True)
def gelu_back_ward(x):
    """Derivative of the erf-based GELU evaluated at ``x``.

    Returns ``Phi(x) + x * phi(x)``, where ``Phi`` and ``phi`` are the standard
    normal CDF and PDF.
    """
    sqrt_two = 2**0.5
    sqrt_two_pi = (2 * 3.14159265359) ** 0.5
    gaussian_cdf = 0.5 * (1.0 + torch.erf(x / sqrt_two))
    gaussian_pdf = torch.exp(-0.5 * x**2) / sqrt_two_pi
    return gaussian_cdf + x * gaussian_pdf
    
class LinearGeLU(Function):
    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, input,prev, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        ctx.save_for_backward(input, weight, bias,output+prev)
        return F.gelu(output+prev)  # output.clamp(min=0)  # ReLU activation
        

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        input, weight, bias,output_prev = ctx.saved_tensors
        grad_weight = grad_bias = None

        grad_gelu = grad_output*gelu_back_ward(output_prev)

        grad_input = grad_gelu.mm(weight)
        grad_prev = grad_gelu

        grad_weight = grad_gelu.t().mm(input)
        grad_bias = None
        if bias is not None:
            grad_bias = grad_gelu.sum(0)

        return grad_input,grad_prev , grad_weight, grad_bias


class LinearGeLUFused(torch.nn.Module):
    """Module owning the parameters used by ``LinearGeLU``.

    Args:
        in_features: Number of columns of the weight matrix.
        out_features: Number of rows of the weight matrix and bias length.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        weight = torch.empty(out_features, in_features)
        torch.nn.init.xavier_uniform_(weight)
        self.weight = torch.nn.Parameter(weight)
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    def forward(self, input,prev):
        """Return ``LinearGeLU`` applied to ``input`` and the residual ``prev``."""
        return LinearGeLU.apply(input, prev, self.weight, self.bias)


class LinearRMSFused(torch.nn.Module):
    """Module owning the parameters used by ``LinearRms``.

    Args:
        in_features: Number of columns of the weight matrix.
        out_features: Number of rows of the weight matrix; also the length of
            the bias and of the RMS gain.
        eps: Epsilon forwarded to ``LinearRms``.
    """

    def __init__(self, in_features, out_features, eps=1e-6):
        super().__init__()
        weight = torch.empty(out_features, in_features)
        torch.nn.init.xavier_uniform_(weight)
        self.weight = torch.nn.Parameter(weight)
        self.bias = torch.nn.Parameter(torch.zeros(out_features))
        self.weight_rms = torch.nn.Parameter(torch.ones(out_features))
        self.eps = eps

    def forward(self, input, prev):
        """Return ``LinearRms`` applied to ``input`` and the residual ``prev``."""
        params = (self.weight, self.bias, self.weight_rms, self.eps)
        return LinearRms.apply(input, prev, *params)

import torch
from torch.autograd import Function


@torch.compile(fullgraph=True)
def gelu_bwd(x):
    """Derivative of the erf-based GELU at ``x``: ``Phi(x) + x * phi(x)``."""
    normal_cdf = 0.5 * (1.0 + torch.erf(x / 2**0.5))
    normal_pdf = torch.exp(-0.5 * x**2) / (2 * 3.14159265359) ** 0.5
    slope = normal_cdf + x * normal_pdf
    return slope

@torch.compile(fullgraph=True)
def gelu_fwd(x):
    """Erf-based GELU: ``x * Phi(x)`` with ``Phi`` the standard normal CDF."""
    normal_cdf = 0.5 * (1.0 + torch.erf(x / 2**0.5))
    activated = x * normal_cdf
    return activated
    
@torch.compile(fullgraph=True)
def ffn_gelu_fwd(x,weight1, bias1, weight2, bias2=None):
    """Two-layer feed-forward block: linear -> GELU -> linear.

    Args:
        x: 2-D input matrix.
        weight1, bias1: Parameters of the first projection (bias optional).
        weight2, bias2: Parameters of the second projection (bias optional).

    Returns:
        ``(output, pre_activation)`` where ``pre_activation`` is the result of
        the first projection, i.e. the tensor that was fed to the GELU.
    """
    pre_activation = torch.mm(x, weight1.t())
    if bias1 is not None:
        pre_activation += bias1.unsqueeze(0).expand_as(pre_activation)

    activated = gelu_fwd(pre_activation)

    projected = torch.mm(activated, weight2.t())
    if bias2 is not None:
        projected += bias2.unsqueeze(0).expand_as(projected)

    return projected, pre_activation
    
@torch.compile(fullgraph=True)
def ffn_gelu_bwd(grad_output,x,weight1, bias1, weight2, bias2,gelu_input):
    """Backward pass matching ``ffn_gelu_fwd``.

    Args:
        grad_output: Gradient w.r.t. the block output.
        x: Input matrix of the block.
        weight1, bias1: First projection parameters (bias may be ``None``).
        weight2, bias2: Second projection parameters (bias may be ``None``).
        gelu_input: Pre-activation returned by ``ffn_gelu_fwd``.

    Returns:
        ``(grad_x, grad_weight1, grad_bias1, grad_weight2, grad_bias2)``; a
        bias gradient is ``None`` when the corresponding bias is ``None``.
    """
    grad_bias1 = grad_bias2 = None

    # The second projection consumed gelu(gelu_input), not gelu_input itself,
    # so its weight gradient must be built from the recomputed activation.
    gelu_output = gelu_fwd(gelu_input)
    grad_weight2 = grad_output.T.mm(gelu_output)
    if bias2 is not None:
        grad_bias2 = grad_output.sum(0)

    # Back through the second projection, then through the activation.
    grad_hidden = grad_output.mm(weight2) * gelu_bwd(gelu_input)

    grad_x1 = grad_hidden.mm(weight1)
    grad_weight1 = grad_hidden.T.mm(x)
    if bias1 is not None:
        grad_bias1 = grad_hidden.sum(0)

    return grad_x1, grad_weight1, grad_bias1, grad_weight2, grad_bias2
    
class FFNGeLU(Function):
    """Autograd function pairing ``ffn_gelu_fwd`` with ``ffn_gelu_bwd``."""

    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, x, weight1, bias1, weight2, bias2=None):
        """Run the feed-forward block and keep what backward needs."""
        block_output, pre_activation = ffn_gelu_fwd(x, weight1, bias1, weight2, bias2)
        ctx.save_for_backward(x, weight1, bias1, weight2, bias2, pre_activation)
        return block_output

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight1, bias1, weight2, bias2)``."""
        saved = ctx.saved_tensors
        grads = ffn_gelu_bwd(grad_output, *saved)
        grad_x1, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = grads
        return grad_x1, grad_weight1, grad_bias1, grad_weight2, grad_bias2


class FFNGeluModule(torch.nn.Module):
    """Module owning the parameters used by ``FFNGeLU``.

    Args:
        in_features: Width of the block input.
        mid_feature: Width of the hidden projection.
        out_features: Width of the block output.
    """

    def __init__(self, in_features, mid_feature, out_features):
        super().__init__()
        # Parameters are registered in the order weight1, bias1, weight2, bias2.
        self.weight1 = torch.nn.Parameter(torch.empty(mid_feature, in_features))
        self.bias1 = torch.nn.Parameter(torch.empty(mid_feature))
        self.weight2 = torch.nn.Parameter(torch.empty(out_features, mid_feature))
        self.bias2 = torch.nn.Parameter(torch.empty(out_features))

        for weight, bias in ((self.weight1, self.bias1), (self.weight2, self.bias2)):
            torch.nn.init.xavier_uniform_(weight)
            torch.nn.init.zeros_(bias)

    def forward(self, input):
        """Return ``FFNGeLU`` applied to ``input``."""
        return FFNGeLU.apply(input, self.weight1, self.bias1, self.weight2, self.bias2)


import torch
from torch.autograd import Function, gradcheck


class MyEmbeddingFunction(Function):
    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, weight, indices,pad_index):
        result = weight[indices]
        ctx.save_for_backward(indices, weight)
        ctx.pad_index = pad_index
        return result

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        indices, weight = ctx.saved_tensors
        pad_index = ctx.pad_index 
        grad_weight = torch.zeros_like(weight)
        #index_add_(dim, index, source)
        
        index = indices.flatten()
        ignore_mask = index!=pad_index
        index = index[ignore_mask]
        grad_output = grad_output.view(-1, grad_output.size(-1))
        grad_output = grad_output[ignore_mask]
        
        grad_weight.index_add_(0, index,grad_output)
        grad_indices = None
        return grad_weight, grad_indices, None

class MyEmbedding(torch.nn.Module):
    """Embedding table backed by ``MyEmbeddingFunction``.

    Args:
        num_embeddings: Number of rows in the table.
        embedding_dim: Width of each row.
        pad_index: Index forwarded to ``MyEmbeddingFunction`` on every call.
    """

    def __init__(self, num_embeddings, embedding_dim,pad_index):
        super().__init__()
        self.pad_index = pad_index
        table = torch.randn(num_embeddings, embedding_dim)
        self.weight = torch.nn.Parameter(table)

    def forward(self, indices):
        """Look up ``indices`` in the table."""
        looked_up = MyEmbeddingFunction.apply(self.weight, indices, self.pad_index)
        return looked_up

@torch.compile(fullgraph=True)
def attention_forward(Q, K, V, mask=None):
    """Scaled dot-product attention.

    Args:
        Q, K, V: Query, key and value tensors; the last axis is the feature
            axis and the second-to-last axis is the position axis.
        mask: Optional additive mask that is summed onto the scores.

    Returns:
        ``(context, probabilities)``: the attention output and the softmax
        weights it was computed from.
    """
    head_dim = Q.shape[-1]
    # Built as a tensor on Q's device/dtype instead of going through NumPy.
    scale = 1.0 / torch.sqrt(torch.tensor(head_dim, dtype=Q.dtype, device=Q.device))

    scores = torch.einsum("... i d, ... j d -> ... i j", Q * scale, K)
    if mask is not None:
        scores = scores + mask

    probabilities = F.softmax(scores, dim=-1)
    context = torch.einsum("... i j, ... j d -> ... i d", probabilities, V)
    return context, probabilities
    
@torch.compile(fullgraph=True)
def attention_backward(Q, K, V, O, dO, softmax):
    """Backward pass matching ``attention_forward``.

    Args:
        Q, K, V: Operands of the forward pass.
        O: Attention output of the forward pass.
        dO: Gradient w.r.t. ``O``.
        softmax: Attention probabilities returned by the forward pass.

    Returns:
        ``(dQ, dK, dV)``.
    """
    scale = 1.0 / torch.sqrt(torch.tensor(Q.shape[-1], dtype=Q.dtype, device=Q.device))
    probs = softmax

    # O = P @ V
    grad_v = torch.einsum("... r c, ... r d -> ... c d", probs, dO)
    grad_probs = torch.einsum("... r d, ... c d -> ... r c", dO, V)

    # Softmax Jacobian: dS = P * (dP - rowsum(dP * P)); the row sum is
    # obtained from dO and O.
    row_dot = (dO * O).sum(dim=-1, keepdim=True)
    grad_scores = probs * (grad_probs - row_dot)

    # S = scale * Q @ K^T
    grad_q = scale * torch.einsum("... r c, ... c d -> ... r d", grad_scores, K)
    grad_k = scale * torch.einsum("... r c, ... r d -> ... c d", grad_scores, Q)
    return grad_q, grad_k, grad_v


class ScaledDotProductAttention(torch.autograd.Function):
    """Autograd function pairing ``attention_forward`` with ``attention_backward``."""

    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, q, k, v, mask=None):
        """Return the attention output for ``q``, ``k``, ``v`` and ``mask``."""
        context, attn_weights = attention_forward(q, k, v, mask)
        ctx.save_for_backward(q, k, v, mask, context, attn_weights)
        return context

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, do):
        """Return gradients for ``(q, k, v, mask)``; the mask gets ``None``."""
        q, k, v, _mask, context, attn_weights = ctx.saved_tensors
        grad_q, grad_k, grad_v = attention_backward(q, k, v, context, do, attn_weights)
        return grad_q, grad_k, grad_v, None


@torch.compile(fullgraph=True)
def rotate_half(x):
    """Split the last axis in two halves ``(a, b)`` and return ``(-b, a)``."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)

@torch.compile(fullgraph=True)
def rope_fwd(q, k, freqs):
    """Apply rotary position embedding to ``q`` and ``k``.

    Args:
        q, k: Query and key tensors. NOTE(review): the original comment gives
            their layout as (batch_size, num_heads, seq_len, head_dim) —
            confirm against the caller.
        freqs: Rotation angles covering half of the last axis of ``q``/``k``.

    Returns:
        ``(q_rotated, k_rotated, cos, sin)``; ``cos`` and ``sin`` are returned
        so the backward pass can reuse them.
    """
    # Duplicate the angles so they span both halves of the last axis.
    angles = torch.cat((freqs, freqs), dim=-1)
    # Insert a singleton axis at position 1 so cos/sin broadcast over it.
    cos = angles.cos().unsqueeze(1)
    sin = angles.sin().unsqueeze(1)

    # With x = (x1, x2): out = (x1*cos - x2*sin, x2*cos + x1*sin),
    # written as x*cos + rotate_half(x)*sin.
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)

    return q_rotated, k_rotated, cos, sin

@torch.compile(fullgraph=True)
def rope_bwd(grad_q_rotated, grad_k_rotated, cos, sin):
    """Backward pass matching ``rope_fwd``.

    The forward map ``x -> x*cos + rotate_half(x)*sin`` is a rotation, so the
    gradient is obtained by rotating the incoming gradient the opposite way:
    ``g*cos - rotate_half(g)*sin``.

    Args:
        grad_q_rotated: Gradient w.r.t. the rotated queries.
        grad_k_rotated: Gradient w.r.t. the rotated keys.
        cos, sin: The tensors returned by ``rope_fwd``.

    Returns:
        ``(grad_q, grad_k)``. No gradient is produced for the angles; one would
        have to be added here if they ever became learnable.
    """
    q_inverse_term = rotate_half(grad_q_rotated) * sin
    grad_q = (grad_q_rotated * cos) - q_inverse_term

    k_inverse_term = rotate_half(grad_k_rotated) * sin
    grad_k = (grad_k_rotated * cos) - k_inverse_term

    return grad_q, grad_k
    
class RotaryEmbeddingFunction(torch.autograd.Function):
    """Autograd function pairing ``rope_fwd`` with ``rope_bwd``."""

    @staticmethod
    @custom_fwd(device_type='cuda',cast_inputs=torch.float32)
    def forward(ctx, q, k, freqs):
        """Rotate ``q`` and ``k`` by the angles in ``freqs``."""
        rotated_q, rotated_k, cos, sin = rope_fwd(q, k, freqs)
        # Only cos/sin are needed to undo the rotation in backward.
        ctx.save_for_backward(cos, sin)
        return rotated_q, rotated_k

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_q_rotated, grad_k_rotated):
        """Return gradients for ``(q, k, freqs)``; ``freqs`` gets ``None``."""
        cos, sin = ctx.saved_tensors
        grads = rope_bwd(grad_q_rotated, grad_k_rotated, cos, sin)
        return grads[0], grads[1], None


class RotaryEmbedding(nn.Module):
    """Produces the rotation angles used by rotary position embedding.

    Args:
        config: Object exposing ``hidden_size`` and ``num_attention_heads``;
            the per-head size ``dim`` is their integer quotient.

    Attributes:
        inv_freq: Buffer of ``dim / 2`` inverse frequencies,
            ``1 / 10000 ** (2i / dim)``.
    """

    def __init__(self,config):
        super().__init__()
        dim = int(config.hidden_size // config.num_attention_heads)
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, seq_len):
        """Return the angles for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions.

        Returns:
            Tensor of shape ``(1, seq_len, dim / 2)`` holding
            ``position * inv_freq``.
        """
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        angles = torch.einsum("i, j -> i j", positions, self.inv_freq)
        return angles.unsqueeze(0)

import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from typing import Optional, Tuple

        
class EncoderAttention(nn.Module):
    """Multi-head self-attention followed by a fused projection + residual + RMSNorm.

    Args:
        config: Object exposing ``hidden_size`` and ``num_attention_heads``
            (and optionally ``attention_bias``, default ``True``).
        layer_idx: Index of the layer this module belongs to.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config, layer_idx: int) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.head_size = int(config.hidden_size // config.num_attention_heads)
        self.attention_bias = getattr(config, "attention_bias", True)
        self.layer_idx = layer_idx
        self.q = MyLinear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.k = MyLinear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        self.v = MyLinear(
            config.hidden_size, config.hidden_size, bias=self.attention_bias
        )
        # Output projection fused with the residual add and RMSNorm. It is
        # built exactly once; the earlier code first created a MyLinear here
        # and then immediately replaced it.
        self.out = LinearRMSFused(config.hidden_size, config.hidden_size)
        self.num_attention_heads = config.num_attention_heads
        self.dot_attn = ScaledDotProductAttention.apply
        self.apply_rotary_pos_emb = RotaryEmbeddingFunction.apply

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_state`` and return a tensor of the same shape.

        Args:
            hidden_state: Tensor laid out as ``(b, l, hidden)``.
            attention_mask: Mask forwarded unchanged to the attention function.
            freqs: Optional rotary angles; when given, queries and keys are
                rotated before attention.
        """
        q = self.q(hidden_state)
        k = self.k(hidden_state)
        v = self.v(hidden_state)

        # Split the hidden axis into heads.
        q = rearrange(q, "b l (h d) -> b h l d", h=self.num_attention_heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.num_attention_heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.num_attention_heads)
        if freqs is not None:
            q, k = self.apply_rotary_pos_emb(q, k, freqs)

        out = self.dot_attn(q, k, v, attention_mask)
        out = rearrange(out, "b h l d -> b l (h d)")

        b, l, d = out.size()

        # The fused layer works on 2-D matrices; reshape also copes with
        # non-contiguous inputs, which view would reject.
        out = out.reshape(-1, d)
        residual = hidden_state.reshape(-1, d)

        out = self.out(out, residual)

        return out.view(b, l, d).contiguous()

import torch
import torch.nn as nn
from typing import Optional, Tuple

from dataclasses import dataclass



@dataclass
class EncoderOutput:
    """Return type of ``EncoderModel.forward``; holds one tensor."""

    # Tensor produced by the last encoder layer.
    logits: torch.Tensor


@dataclass
class MLMOutput:
    """Container pairing a hidden state tensor with a logits tensor."""

    # Field order matters for positional construction: hidden state first.
    hidden_state: torch.Tensor
    logits: torch.Tensor


class EncoderLayer(nn.Module):
    """One encoder block: attention, feed-forward, residual add and RMSNorm.

    Args:
        config: Object exposing ``hidden_size`` (and optionally
            ``layer_norm_eps``, default ``1e-6``).
        layer_idx: Index of this layer.
        attention_type: Accepted for interface compatibility; not used here.
    """

    def __init__(self, config, layer_idx: int, attention_type: str = None) -> None:
        super().__init__()
        hidden = config.hidden_size
        self.attention = EncoderAttention(config, layer_idx=layer_idx)
        # Feed-forward block with a 4x wider hidden projection.
        self.feed_forward = FFNGeluModule(hidden, 4*hidden, hidden)
        self.layer_idx = layer_idx
        self.layernorm = RMSNorm(hidden, eps=getattr(config, "layer_norm_eps", 1e-6))

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the block on ``hidden_state`` and return the normalised result."""
        attn_out = self.attention(
            hidden_state=hidden_state, attention_mask=attention_mask, freqs=freqs
        )
        b, l, d = attn_out.size()
        # The feed-forward module operates on 2-D matrices.
        ffn_in = attn_out.view(-1, d).contiguous()
        ffn_out = self.feed_forward(ffn_in).view(b, l, d).contiguous()
        # Residual connection around the feed-forward block, then RMSNorm.
        return self.layernorm(ffn_out + attn_out)


class EncoderModel(nn.Module):
    """Stack of ``EncoderLayer`` blocks on top of a token embedding.

    Args:
        config: Object exposing ``vocab_size``, ``hidden_size``,
            ``pad_token_id``, ``max_position_embeddings`` and
            ``num_hidden_layers`` (plus whatever the sub-modules read).
        pos_embedding_type: Accepted for interface compatibility; rotary
            angles are always built.
        attention_type: Forwarded to every ``EncoderLayer``.
    """

    def __init__(
        self,
        config,
        pos_embedding_type: Optional[str] = "rope",
        attention_type: str = None,
    ) -> None:
        super().__init__()
        self.word_embeddings = MyEmbedding(config.vocab_size, config.hidden_size,config.pad_token_id)

        # Rotary angles for every position up to the configured maximum,
        # computed once and kept as a plain tensor attribute.
        rotary = RotaryEmbedding(config)
        self.emb_freq = rotary(config.max_position_embeddings)

        layers = [
            EncoderLayer(config, layer_idx, attention_type)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.all_layer = nn.ModuleList(layers)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encode ``input_ids`` and return an ``EncoderOutput``.

        Args:
            input_ids: 2-D tensor of token ids.
            attention_mask: Tensor with 1 for positions to attend to and 0
                for positions to ignore.
        """
        _, seqlen = input_ids.shape
        hidden_state = self.word_embeddings(input_ids)
        freqs = self.emb_freq[:, :seqlen].to(input_ids.device)

        # Turn the 0/1 mask into an additive one: 0 where attention is
        # allowed, the most negative representable value elsewhere.
        keep = attention_mask.unsqueeze(1).unsqueeze(2).type_as(hidden_state)
        additive_mask = (1.0 - keep) * torch.finfo(hidden_state.dtype).min

        for layer in self.all_layer:
            hidden_state = layer(hidden_state, additive_mask, freqs)
        return EncoderOutput(hidden_state)

    @classmethod
    def from_config(
        cls,
        config,
        pos_embedding_type: Optional[str] = "absolute",
        attention_type: str = None,
    ) -> nn.Module:
        """Alternate constructor that forwards its arguments to ``__init__``."""
        return cls(config, pos_embedding_type, attention_type)


def chunked_cross_entropy(logits,targets,chunk_size = 32,ignore_index=-100):
    """Mean cross-entropy computed chunk by chunk along the sample axis.

    Args:
        logits: Tensor whose last axis holds the class scores; any leading
            axes are flattened into one sample axis.
        targets: Class indices, one per sample (flattened the same way).
        chunk_size: Number of samples per chunk.
        ignore_index: Target value excluded from the loss.

    Returns:
        Scalar loss: the summed per-sample loss divided by the number of
        non-ignored targets (or by 1 when every target is ignored).
    """
    targets = targets.reshape(-1)
    # Flatten logits the same way as targets so the chunks stay aligned.
    logits = logits.reshape(-1, logits.size(-1))

    logit_chunks = logits.split(chunk_size)
    target_chunks = targets.split(chunk_size)
    loss_chunks = [
        torch.nn.functional.cross_entropy(
            logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
        )
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    non_masked_elems = (targets != ignore_index).sum()
    # Clamp the divisor to 1 to avoid 0 / 0 when everything is ignored.
    loss = torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
        torch.ones_like(non_masked_elems)
    )
    return loss
    
@torch.compile
def linear_cross_entropy_fwd(inputs, weight, bias, targets, ignore_index=-100):
    """Linear projection followed by mean cross-entropy over valid targets.

    Args:
        inputs: Features whose last axis matches ``weight``'s last axis.
        weight, bias: Parameters of the projection.
        targets: Class indices; entries equal to ``ignore_index`` are skipped.
        ignore_index: Target value excluded from the loss.

    Returns:
        ``(loss, softmax, valid_targets, valid_mask, logits_flat, shape)``:
        the scalar loss, the probabilities of the valid rows, their targets,
        the boolean row mask, the flattened logits and the logits' shape.
        The loss is 0 when every target is ignored.
    """
    logits = F.linear(inputs, weight, bias)
    shape = logits.size()
    logits_flat = logits.view(-1, shape[-1])
    targets_flat = targets.view(-1)
    valid_mask = targets_flat != ignore_index
    valid_logits = logits_flat[valid_mask]
    valid_targets = targets_flat[valid_mask]

    # log_softmax is exact for very small probabilities, unlike
    # log(softmax + 1e-12), which saturates and biases the loss.
    log_probs = F.log_softmax(valid_logits, dim=-1)
    softmax = log_probs.exp()
    target_log_probs = log_probs.gather(1, valid_targets.unsqueeze(1)).squeeze(1)
    # Divide by the valid count, clamped to 1 so an all-ignored batch gives 0
    # instead of NaN.
    loss = -target_log_probs.sum() / valid_mask.sum().clamp(min=1)
    return loss, softmax, valid_targets, valid_mask, logits_flat,shape

@torch.compile
def linear_cross_entropy_bwd(grad_outputs,inputs, weight, bias, softmax, valid_targets, valid_mask, logits_flat,shape):
    """Backward pass matching ``linear_cross_entropy_fwd``.

    Args:
        grad_outputs: Gradient w.r.t. the scalar loss.
        inputs, weight, bias: Operands of the forward projection.
        softmax: Probabilities of the valid rows.
        valid_targets: Targets of the valid rows.
        valid_mask: Boolean mask selecting the valid rows of ``logits_flat``.
        logits_flat: Flattened logits from the forward pass.
        shape: Shape of the unflattened logits (kept for interface
            compatibility; the result is reshaped from ``inputs`` instead).

    Returns:
        ``(grad_input, grad_weight, grad_bias)`` shaped like ``inputs``,
        ``weight`` and ``bias``; ``grad_bias`` is ``None`` when ``bias`` is.
    """
    # d(loss)/d(logits) for the valid rows: (softmax - one_hot) / n_valid.
    valid_grad_logits = softmax.clone()
    rows = torch.arange(valid_grad_logits.size(0), device=valid_grad_logits.device)
    valid_grad_logits[rows, valid_targets] -= 1
    valid_grad_logits /= valid_mask.sum().clamp(min=1)

    # Ignored rows keep a zero gradient. Allocate in softmax's dtype so the
    # masked assignment never mixes dtypes.
    grad_logits = torch.zeros_like(logits_flat, dtype=valid_grad_logits.dtype)
    grad_logits[valid_mask] = valid_grad_logits
    grad_logits = grad_logits * grad_outputs

    # Work on flattened 2-D matrices so the parameter gradients are reduced
    # over every sample, whatever the rank of ``inputs``.
    inputs_flat = inputs.reshape(-1, inputs.size(-1))
    grad_input = grad_logits.matmul(weight).reshape(inputs.shape)
    grad_weight = grad_logits.t().matmul(inputs_flat)
    grad_bias = grad_logits.sum(dim=0) if bias is not None else None

    return grad_input, grad_weight, grad_bias
    
class LinearCrossEntropyIgnoreIndex(torch.autograd.Function):
    """Fused linear + cross-entropy loss that skips ``ignore_index`` targets."""

    @staticmethod
    def forward(ctx, inputs, weight, bias, targets, ignore_index=-100):
        """Return the scalar loss for ``inputs`` projected by ``weight``/``bias``."""
        # Forward the caller's ignore_index; it used to be hard-coded to -100.
        loss, softmax, valid_targets, valid_mask, logits_flat, shape = linear_cross_entropy_fwd(
            inputs, weight, bias, targets, ignore_index=ignore_index
        )

        ctx.save_for_backward(inputs, weight, bias, softmax, valid_targets, valid_mask, logits_flat)
        ctx.shape = shape
        ctx.ignore_index = ignore_index
        return loss

    @staticmethod
    def backward(ctx, grad_outputs):
        """Return gradients for ``(inputs, weight, bias, targets, ignore_index)``."""
        inputs, weight, bias, softmax, valid_targets, valid_mask, logits_flat = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = linear_cross_entropy_bwd(
            grad_outputs, inputs, weight, bias, softmax,
            valid_targets, valid_mask, logits_flat, ctx.shape,
        )
        # targets and ignore_index are not differentiable.
        return grad_input, grad_weight, grad_bias, None, None


class MyLinearCrossEntropy(torch.nn.Module):
    """Classification head: returns the loss when targets are given, logits otherwise.

    Args:
        in_features: Width of the input features.
        out_features: Number of classes.
        ignore_index: Target value forwarded to the fused loss.
    """

    def __init__(self, in_features, out_features, ignore_index=-100):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.empty(out_features))
        self.ignore_index = ignore_index

        torch.nn.init.xavier_uniform_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, target=None):
        """Return the fused loss if ``target`` is given, else 2-D logits."""
        if target is not None:
            return LinearCrossEntropyIgnoreIndex.apply(
                x, self.weight, self.bias, target, self.ignore_index
            )

        # Inference path: flatten a 3-D input to a matrix before projecting.
        if x.dim() != 2:
            B, l, d = x.shape
            x = x.view(-1, d).contiguous()
        logits = x.mm(self.weight.t())
        logits += self.bias.unsqueeze(0).expand_as(logits)
        return logits

class callback:
    """Collects (model, loss) pairs and hands back the model with the lowest loss."""

    def __init__(self):
        # Parallel lists: self.model[i] was recorded with self.loss[i].
        self.loss = []
        self.model = []

    def put(self, model, loss):
        """Record ``model`` together with the ``loss`` it achieved."""
        self.model.append(model)
        self.loss.append(loss)

    def get_model(self):
        """Return the recorded model whose loss is smallest (first one on ties)."""
        best_position = np.argmin(self.loss)
        return self.model[best_position]

class ClinicModel(nn.Module):
    """Encoder backbone followed by a ``MyLinearCrossEntropy`` classification head.

    Args:
        config: configuration object handed to ``EncoderModel``.
        n_classes: number of output classes of the head.
    """

    def __init__(self, config,n_classes):
        super().__init__()
        self.model = EncoderModel(config, pos_embedding_type='rope')
        # Head input width is fixed at 768.
        self.output = MyLinearCrossEntropy(768, n_classes)

    def forward(self, ids, mask,labels=None):
        """Run the encoder and apply the head to position 0 along dim 1.

        Returns whatever the head returns: a loss when ``labels`` is given,
        otherwise the logits.
        """
        encoded = self.model(ids, mask)
        first_position = encoded.logits[:, 0, :]
        return self.output(first_position, labels)


class ClinicDataset(Dataset):
    """Tokenises texts on the fly and pads every sample to ``MAX_LEN`` tokens.

    Args:
        data: mapping/DataFrame with a ``'Text'`` column and, unless
            ``is_test`` is set, a ``'Target'`` column.
        is_test: when True, items carry no ``'targets'`` entry.

    NOTE(review): ``model_path`` is not defined in the visible part of this
    file; it must exist as a module-level name for construction to succeed.
    """

    def __init__(self, data,is_test=False):
        self.X = data['Text'].values
        # Test data does not need labels; only require the column for training.
        if is_test and 'Target' not in data:
            self.Y = None
        else:
            self.Y = data['Target'].values
        self.is_test = is_test
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.MAX_LEN = 128

    def __getitem__(self, idx):
        """Return ``{'ids', 'mask'}`` (plus ``'targets'`` when not in test mode)."""
        inputs = self.tokenizer.encode_plus(self.X[idx],
            add_special_tokens=True,
            truncation=True,
            max_length=self.MAX_LEN
        )['input_ids']

        # A non-positive pad count yields empty padding lists.
        n_pad = self.MAX_LEN - len(inputs)
        mask = torch.tensor([1] * len(inputs) + [0] * n_pad, dtype=torch.long)
        ids = torch.tensor(inputs + [self.tokenizer.pad_token_id] * n_pad, dtype=torch.long)

        if self.is_test:
            return {
                'ids': ids,
                'mask': mask,
            }

        # torch.tensor keeps the label *value*; the legacy torch.FloatTensor
        # constructor would interpret an integer label as a tensor size and
        # return uninitialised memory.
        targets = torch.tensor(self.Y[idx], dtype=torch.float32)
        return {
            'ids': ids,
            'mask': mask,
            'targets': targets
        }

    def __len__(self):
        # Texts are always present, labels may be absent in test mode.
        return len(self.X)



class ClinicDatasetV2(torch.utils.data.Dataset):
    """Dataset over pre-tokenised encodings and their labels.

    Args:
        encodings: mapping from field name to an indexable of per-sample values.
        labels: indexable of per-sample labels.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'ids', 'mask', 'labels'}`` tensors."""
        item = {}
        for field, values in self.encodings.items():
            item[field] = torch.tensor(values[idx])
        item['labels'] = torch.tensor(self.labels[idx])

        # Missing 'input_ids' / 'attention_mask' fields come back as None.
        return {
            'ids': item.get('input_ids'),
            'mask': item.get('attention_mask'),
            'labels': item['labels'],
        }

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.labels)



def valid_func(model, val_loader, accelerator):
    """Evaluate ``model`` on ``val_loader`` and return ``(mean loss, accuracy)``.

    Args:
        model: callable as ``model(ids, mask)`` returning class logits.
        val_loader: iterable of batches with ``'ids'``, ``'mask'`` and ``'labels'``.
        accelerator: object providing ``gather_for_metrics``.

    Returns:
        ``(loss_valid, accuracy)`` where the loss is averaged per sample over
        the whole validation set and accuracy is the fraction of correct
        argmax predictions.
    """
    model.eval()
    # Sum per batch and divide once at the end: averaging per-batch means
    # would over-weight a smaller final batch.
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    pred_chunks = []
    target_chunks = []
    loss_sum = 0.0
    n_scored = 0

    with torch.no_grad():
        for data in val_loader:
            input_ids = data["ids"]
            input_masks = data["mask"]
            targets = data["labels"].long().view(-1)

            logits = model(input_ids, input_masks)
            logits, targets = accelerator.gather_for_metrics((logits, targets))

            pred_chunks.append(torch.argmax(logits, 1).cpu())
            target_chunks.append(targets.cpu())

            # Accumulate in float32 regardless of the dtype the model emitted.
            loss_sum += loss_fn(logits.float(), targets).item()
            n_scored += int((targets != loss_fn.ignore_index).sum())

    preds = torch.cat(pred_chunks).numpy()
    all_targets = torch.cat(target_chunks).numpy()
    accuracy = (preds == all_targets).mean()

    loss_valid = loss_sum / n_scored if n_scored else float('nan')
    return loss_valid, accuracy
@dataclass
class Config:
    """Default hyper-parameters for the encoder built in ``ClinicModel``.

    An instance is passed to ``EncoderModel`` (defined outside this chunk), so
    the exact meaning of each field should be confirmed against that class;
    the notes below only record what is visible here.
    """
    # ClinicModel builds its head with a hard-coded input width of 768,
    # which matches this default.
    hidden_size: int = 768
    num_attention_heads: int = 12
    # NOTE(review): smaller than num_attention_heads -- presumably grouped
    # key/value heads; confirm against the attention implementation.
    num_key_value_heads: int = 4
    max_position_embeddings: int = 514
    # main() overrides this to 12 before constructing the model.
    num_hidden_layers: int = 4
    vocab_size: int = 50265
    hidden_dropout_prob: float = 0.1
    initializer_range: float = 0.02
    intermediate_size: int = 3072
    layer_norm_eps: float = 1e-05
    pad_token_id: int = 1
    hidden_act: str = "gelu"
    
def main():
    """Train ``ClinicModel`` on the Clinc CSV data with Hugging Face Accelerate.

    Steps: set up the accelerator, read the train/validation CSV files,
    tokenise them with the ``roberta-base`` tokenizer, train for a fixed number
    of epochs with AdamW and a cosine schedule, validate after every epoch and
    save a checkpoint from the main process every 5th epoch and on the last one.
    """
    ddp_kwargs = DistributedDataParallelKwargs() #find_unused_parameters=True
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    model_ckpt = "roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

    device = accelerator.device
    is_main_process = accelerator.is_main_process

    if is_main_process:
        print(f"Using device: {device}")
        print(f"Number of processes: {accelerator.num_processes}")
        print(f"Distributed type: {accelerator.distributed_type}")

    epochs = 3
    learning_rate = 1e-4
    weight_decay = 0.0
    BATCH_SIZE = 32

    base_dir = "path"
    save_dir = os.path.join(base_dir, "checkpoints")
    samples_dir = os.path.join(base_dir, "samples")

    # Only the main process touches the filesystem for output directories.
    if is_main_process:
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(samples_dir, exist_ok=True)
        print(f"Saving checkpoints to: {save_dir}")
        print(f"Saving samples to: {samples_dir}")

    # Build both dataset paths from one directory so it only needs changing once.
    data_path = "../input/data-for-distilation"
    train = pd.read_csv(os.path.join(data_path, "Clinc_Train.csv"))
    valid = pd.read_csv(os.path.join(data_path, "Clinc_valid.csv"))
    n_classes = np.unique(train.Target).shape[0]

    if is_main_process:
        print("Initializing Encoder Model...")
    config = Config()
    config.num_hidden_layers = 12
    model = ClinicModel(config,n_classes)

    if is_main_process:
        print("Loading data...")

    train_texts = train['Text'].values.tolist()
    val_texts = valid['Text'].values.tolist()
    train_labels = train['Target'].values.tolist()
    val_labels = valid['Target'].values.tolist()

    train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=128)

    train_loader = torch.utils.data.DataLoader(ClinicDatasetV2(train_encodings, train_labels),batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(ClinicDatasetV2(val_encodings, val_labels),batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate,weight_decay=weight_decay)

    # T_max is computed from the loader length *before* accelerator.prepare().
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(train_loader))

    model, optimizer, train_loader,val_loader, scheduler = accelerator.prepare(model, optimizer, train_loader,val_loader, scheduler)

    if is_main_process:
        print("Starting training...")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", disable=not is_main_process)
        for data in progress_bar:
            input_ids = data['ids']
            input_masks = data['mask']
            targets = data['labels'].long().view(-1)

            with accelerator.accumulate(model):
                # With labels supplied the model returns the loss directly.
                loss = model(input_ids,input_masks,targets)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            loss_value = loss.detach().float().item()
            epoch_loss += loss_value

            if is_main_process:
                progress_bar.set_postfix({"loss": loss_value, "lr": optimizer.param_groups[0]['lr']})

        # Sum the per-process running losses, then average over all batches seen.
        gathered_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device)).sum().item()
        avg_loss = gathered_epoch_loss / (len(train_loader) * accelerator.num_processes)
        vloss, vaccuracy = valid_func(model, val_loader,accelerator)

        if is_main_process:
            accelerator.print(f"Epoch {epoch+1} : loss = {avg_loss}, accuracy =  {vaccuracy}")

        if is_main_process and ((epoch + 1) % 5 == 0 or epoch == epochs - 1):
            unwrapped_model = accelerator.unwrap_model(model)
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': unwrapped_model.state_dict(),
                'loss': avg_loss,
            },
            f"{save_dir}/stable_dit_model_epoch_{epoch+1}.pt")
            accelerator.print(f"Checkpoint saved at epoch {epoch+1}")

        # Keep the other ranks from racing ahead while rank 0 writes the checkpoint.
        accelerator.wait_for_everyone()

    # Every process must run the end-of-training teardown, not just the main
    # one; only the final message is restricted to the main process.
    accelerator.end_training()
    if is_main_process:
        print("Training completed!")

# Script entry point: start training when this file is executed directly
# (e.g. via `accelerate launch`), but not when it is imported as a module.
if __name__ == "__main__":
    main()

Overwriting fused_encoder_train.py


In [9]:
! accelerate launch --num_processes=2 ../working/fused_encoder_train.py

Using device: cuda:0
Number of processes: 2
Distributed type: MULTI_GPU
Saving checkpoints to: path/checkpoints
Saving samples to: path/samples
Initializing Encoder Model...
Loading data...
Starting training...
Epoch 1/3:   0%|                                        | 0/239 [00:00<?, ?it/s][rank1]:W0920 15:05:36.368000 1336 torch/fx/experimental/symbolic_shapes.py:5124] [6/0] failed during evaluate_expr(Eq(151*u0, 0), hint=None, size_oblivious=False, forcing_spec=False
[rank1]:E0920 15:05:36.369000 1336 torch/fx/experimental/recording.py:298] [6/0] failed while running evaluate_expr(*(Eq(151*u0, 0), None), **{'fx_node': False})
[rank0]:W0920 15:05:36.375000 1335 torch/fx/experimental/symbolic_shapes.py:5124] [6/0] failed during evaluate_expr(Eq(151*u0, 0), hint=None, size_oblivious=False, forcing_spec=False
[rank0]:E0920 15:05:36.375000 1335 torch/fx/experimental/recording.py:298] [6/0] failed while running evaluate_expr(*(Eq(151*u0, 0), None), **{'fx_node': False})
Epoch 1/3: 100%|███

In [1]:
 !pip install -q -U triton --no-index --find-links ../input/triton-3-0-0/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl