# Install Requirements

# Test

In [56]:
import torch
import torch.nn as nn
from transformers import AutoConfig, ElectraForMaskedLM, AutoTokenizer


torch.manual_seed(42)

generator_tokenizer = AutoTokenizer.from_pretrained('google/electra-base-generator')
generator_config = AutoConfig.from_pretrained('google/electra-base-generator')


discriminator_tokenizer = AutoTokenizer.from_pretrained('google/electra-base-discriminator')
discriminator_config = AutoConfig.from_pretrained('google/electra-base-discriminator')
text = "Hello, this is [MASK] speaking, how can I help you?"

inputs = generator_tokenizer(text, return_tensors='pt', padding='max_length', max_length=512, truncation=True)
random_input_ids = inputs['input_ids']
random_attention_mask = inputs['attention_mask']
random_token_type_ids = inputs['token_type_ids']



In [57]:

import torch
import torch.nn.functional as F


# reference: efficient-kan by @Blealtan
# CODE: https://github.com/Blealtan/efficient-kan


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features)
        )
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.base_weight, gain=self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(
                        self.grid_size + 1, self.in_features, self.out_features
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (
                    self.scale_spline
                    if not self.enable_standalone_scale_spline
                    else 1.0
                )
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.constant_(self.spline_scaler, self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape \
                (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid  # type: ignore
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given
        points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape \
                (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape \
                (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(
            splines, orig_coeff
        )  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0,
                batch - 1,
                self.grid_size + 1,
                dtype=torch.int64,
                device=x.device,
            )
        ]

        uniform_step = (
            x_sorted[-1] - x_sorted[0] + 2 * margin
        ) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = (
            self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        )
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(
                    self.spline_order, 0, -1, device=x.device
                ).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(
                    1, self.spline_order + 1, device=x.device
                ).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)  # type: ignore
        self.spline_weight.data.copy_(
            self.curve2coeff(x, unreduced_spline_output)
        )

    def regularization_loss(
        self, regularize_activation=1.0, regularize_entropy=1.0
    ):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as
        stated in the paper, since the original one requires computing
        absolutes and entropy from the expanded
        (batch, in_features, out_features) intermediate tensor, which is
        hidden behind the F.linear function if we want an memory
        efficient implementation.

        The L1 regularization is now computed as mean absolute value of the
        spline weights. The authors implementation also includes this term
        in addition to the sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    def __init__(
        self,
        width,
        grid=3,
        k=3,
        noise_scale=0.1,
        noise_scale_base=1.0,
        scale_spline=1.0,
        base_fun=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
        bias_trainable=True,
    ):
        super(KAN, self).__init__()
        self.grid_size = grid
        self.spline_order = k
        self.bias_trainable = bias_trainable  # TODO

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(width, width[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid,
                    spline_order=grid,
                    scale_noise=noise_scale,
                    scale_base=noise_scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_fun,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        B, C, T = x.shape

        x = x.view(-1, T)

        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)

        U = x.shape[1]

        x = x.view(B, C, U)

        return x

    def regularization_loss(
        self, regularize_activation=1.0, regularize_entropy=1.0
    ):
        return sum(
            layer.regularization_loss(
                regularize_activation, regularize_entropy
            )
            for layer in self.layers
        )

In [58]:
import math
from typing import *
import torch
from torch import (
    nn, 
    Tensor, 
    FloatTensor, 
    LongTensor
)
import torch.nn.functional as F


class ElectraGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        vocab_type_size: int,
        embedding_dropout_p: float,
        hidden_dim: int,
        num_heads: int,
        ff_dim: int,
        num_layers: int,
        max_pos_embedding: int
    ):
        super().__init__()
        self.embedding = InputEmbedding(
            vocab_size,
            embedding_dim,
            vocab_type_size,
            embedding_dropout_p,
            max_pos_embedding
        )
        self.encoder = ElectraEncoder(
            hidden_dim,
            num_heads,
            num_layers,
            0.1,
            ff_dim
        )
        self.generator = GeneratorOutput(hidden_dim, vocab_size)
        
    def forward(
        self, 
        input_ids: LongTensor, 
        attention_mask: LongTensor,
        token_type_ids: LongTensor,
    ) -> Tensor:
        embeddings = self.embedding(input_ids, token_type_ids)
        seq_out = self.encoder(embeddings, attention_mask)
        dropouted_seq_output = F.dropout(seq_out, p=0.1)
        return self.generator(dropouted_seq_output)
    

class GeneratorOutput(nn.Module):
    def __init__(self, hidden, vocab_size) :
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim = -1)
        
    def forward(self, x) :
        return self.softmax(self.linear(x))
    
    
class ElectraDiscriminator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        vocab_type_size: int,
        embedding_dropout_p: float,
        hidden_dim: int,
        num_heads: int,
        ff_dim: int,
        num_layers: int,
        max_pos_embedding: int,
        num_labels: int
    ):
        super().__init__()
        self.embedding = InputEmbedding(
            vocab_size,
            embedding_dim,
            vocab_type_size,
            embedding_dropout_p,
            max_pos_embedding
        )
        self.encoder = ElectraEncoder(
            hidden_dim,
            num_heads,
            num_layers,
            0.1,
            ff_dim
        )
        self.classifier = KAN(width=[hidden_dim, num_labels])
        
    def forward(
        self, 
        input_ids: LongTensor, 
        attention_mask: LongTensor,
        token_type_ids: LongTensor,
    ) -> Tensor:
        embeddings = self.embedding(input_ids, token_type_ids)
        seq_out = self.encoder(embeddings, attention_mask)
        dropouted_seq_output = F.dropout(seq_out, p=0.1)
        return self.classifier(dropouted_seq_output)
    

class ElectraEncoder(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_layers: int,
        dropout_p: float = 0.1,
        hidden_dim: Optional[int] = None,
    ):
        super().__init__()
        if not hidden_dim:
            hidden_dim = dim * 4 # default hidden_dim on paper
        self.layers = nn.ModuleList([
            EncoderLayer(dim, num_heads, hidden_dim, dropout_p) for i in range(num_layers)
        ])
        
    def forward(
        self,
        hidden_states: Tensor,
        mask: Tensor
    ) -> Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask)
        return hidden_states

    
class InputEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        vocab_type_size: int,
        embedding_dropout_p: float,
        max_pos_embedding: int
        ):
       super().__init__()
       self.embedding = nn.Embedding(vocab_size, embedding_dim)
       self.positional_embedding = nn.Embedding(max_pos_embedding, embedding_dim)
       self.token_type_embedding = nn.Embedding(vocab_type_size, embedding_dim)
       self.dropout = nn.Dropout(embedding_dropout_p)
   
    def forward(
        self, 
        input_ids: LongTensor, 
        token_type_ids: LongTensor,
    ) -> Tensor:
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        
        embeddings = (
            self.embedding(input_ids) +
            self.positional_embedding(position_ids) +
            self.token_type_embedding(token_type_ids)
        )
        return self.dropout(embeddings)


class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout_p: float):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(
        self, 
        query: Tensor, 
        key: Tensor,
        value: Tensor,
        attention_mask: LongTensor
    ) -> Tensor:
        multiplied_kv = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key.shape[-1])
        masked_attention = multiplied_kv.masked_fill(attention_mask == 0, -1e9)
        attention = self.softmax(masked_attention)
        return torch.matmul(attention, value)
        
        
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout_p: float
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.attention = ScaledDotProductAttention(dropout_p)
        self.dropout = nn.Dropout(dropout_p)
        self.fc_q = KAN(width=[dim, dim])
        self.fc_k = KAN(width=[dim, dim])
        self.fc_v = KAN(width=[dim, dim])
        self.fc_out = KAN(width=[dim, dim])
        self.num_heads = num_heads 
        self.dim = dim
               
    def forward(
        self, 
        query: Tensor, 
        key: Tensor,
        value: Tensor,
        attention_mask: LongTensor
    ) -> Tensor:
        batch_size = query.size(0)
        query = self.fc_q(query).view(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2)
        key = self.fc_k(key).view(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2)
        value = self.fc_v(value).view(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2)
        attention_output = self.attention(query, key, value, attention_mask)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * (self.dim // self.num_heads))
        output = self.fc_out(attention_output)
        return self.dropout(output)
 

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        ff_dim: int,
        dropout_p: float
    ):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout_p)
        
    def forward(
        self, 
        x: Tensor
    ) -> Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
    

class EncoderLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        hidden_dim: int,
        dropout_p: float
    ):
        super().__init__()
        self.attn = MultiHeadAttention(dim, num_heads, dropout_p)
        self.ff = FeedForward(dim, hidden_dim, dropout_p)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout_p)
        
    def forward(
        self, 
        x: Tensor, 
        attention_mask: LongTensor
    ) -> Tensor:
        attention_output = self.attn(x, x, x, attention_mask)
        add_norm = self.norm1(x + attention_output)
        output = self.ff(attention_output)
        ff_add_norm = self.norm2(add_norm + output)
        return self.dropout(ff_add_norm)

In [59]:
model = ElectraGenerator(
    vocab_size=len(generator_tokenizer.vocab),
    embedding_dim=768,
    vocab_type_size=2,
    embedding_dropout_p=0.1,
    hidden_dim=768,
    num_heads=12,
    ff_dim=3072,
    num_layers=12,
    max_pos_embedding=512
)

In [60]:
output = model(
    random_input_ids,
    random_attention_mask,
    random_token_type_ids
)
output = torch.argmax(output, axis=2).squeeze().tolist()

In [61]:
output

[5973,
 21371,
 17665,
 2757,
 2209,
 19749,
 7959,
 17917,
 12997,
 4322,
 29952,
 21913,
 22973,
 30240,
 2771,
 16439,
 5478,
 21305,
 21371,
 987,
 22077,
 27747,
 9538,
 17917,
 26847,
 17917,
 13317,
 9410,
 29989,
 22268,
 15036,
 11362,
 21371,
 2548,
 21371,
 7068,
 10746,
 26194,
 9489,
 17919,
 22566,
 2757,
 25107,
 6098,
 21371,
 2757,
 26566,
 21371,
 2757,
 2222,
 2757,
 12319,
 29551,
 12118,
 13117,
 16683,
 27747,
 22547,
 21371,
 29643,
 5344,
 2757,
 26401,
 15036,
 7418,
 18631,
 8118,
 18454,
 29643,
 17917,
 17917,
 3785,
 20191,
 2757,
 16490,
 22056,
 21371,
 7788,
 26315,
 23795,
 2499,
 13394,
 23429,
 2929,
 21371,
 7418,
 7760,
 23149,
 2757,
 2757,
 21371,
 16852,
 4497,
 2757,
 23514,
 2757,
 29230,
 29230,
 12118,
 8714,
 7965,
 23821,
 14089,
 12802,
 23643,
 5149,
 21371,
 20403,
 24995,
 22095,
 5310,
 2222,
 15036,
 17256,
 11355,
 21371,
 17917,
 29230,
 5051,
 4451,
 29230,
 2209,
 17917,
 17421,
 10746,
 20132,
 2757,
 17421,
 4036,
 4497,
 7418,


In [67]:
import torch.nn as nn
import torch.nn.init as init

def initialize_weights(module):
    if isinstance(module, nn.Linear):
        init.xavier_uniform_(module.weight)
        if module.bias is not None:
            init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        init.kaiming_uniform_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        init.normal_(module.weight, mean=0, std=0.01)
    elif isinstance(module, nn.LayerNorm):
        init.ones_(module.weight)
        init.zeros_(module.bias)


In [69]:
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('google/electra-base-generator')
model.apply(initialize_weights)
output_orig = model(input_ids=random_input_ids, attention_mask=random_attention_mask, token_type_ids=random_token_type_ids)

In [70]:
output_orig.logits.shape

torch.Size([1, 512, 30522])

In [71]:
torch.argmax(output_orig.logits, dim=2)

tensor([[ 4367,  4367,  4367,  4367,  4367,  4367,  4367,  4367,  4367,  4367,
          4367,  4367,  4367,  4367,  4367, 14611, 14611, 14611,  4367, 14611,
         14611,  4367,  4367, 14611,  4367, 14611,  4367,  4367,  4367,  4367,
          4367, 14611,  4367, 14611,  4367,  4367,  4367,  4367, 14611,  4367,
          4367, 14611,  4367,  4367,  4367,  4367, 14611,  4367,  4367,  4367,
          4367,  4367,  4367, 14611, 14611, 14611,  4367,  4367, 14611,  4367,
          4367, 14611,  4367, 14611,  4367, 14611,  4367,  4367,  4367, 14611,
          4367,  4367,  4367,  4367,  4367, 14611, 14611, 14611,  4367,  4367,
         14611,  4367,  4367,  4367,  4367, 14611, 14611,  4367, 14611, 14611,
          4367,  4367, 14611, 14611,  4367,  4367,  4367,  4367,  4367, 14611,
         14611, 14611,  4367, 14611,  4367,  4367, 14611, 14611,  4367,  4367,
         14611, 14611,  4367, 14611,  4367,  4367,  4367,  4367, 14611,  4367,
         14611,  4367,  4367,  4367,  4367, 14611, 1

In [5]:
output = model(random_input_ids, random_attention_mask, random_token_type_ids)

In [73]:
generator_tokenizer.convert_ids_to_tokens(torch.argmax(output, axis=2).squeeze().tolist())

['motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'infrared',
 'infrared',
 'motion',
 'infrared',
 'infrared',
 'motion',
 'motion',
 'infrared',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'motion',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'infrared',
 'infrared',
 'motion',
 'motion',
 'infrared',
 'motion',
 'motion',
 'infrared',
 'motion',
 'infrared',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'infrared',
 'infrared',
 'motion',
 'motion',
 'infrared',
 'motion',
 'motion',
 'motion',
 'motion',
 'infrared',
 'in

In [74]:
generator_tokenizer.convert_ids_to_tokens(torch.argmax(output, axis=2).squeeze().tolist())

TypeError: argmax(): argument 'input' (position 1) must be Tensor, not list