# Install Requirements

In [1]:
# Install the Python packages this notebook needs into the current environment.
# NOTE(review): only torch is pinned; transformers, ruamel.yaml and lightning
# resolve to whatever is latest at install time — consider pinning them too.
!pip install \
    torch==2.3.0 \
    transformers \
    ruamel.yaml \
    lightning

# pykan is installed straight from its GitHub repository; presumably it is the
# source of the `kan` package imported in the modeling cell — TODO confirm.
!pip install git+https://github.com/KindXiaoming/pykan.git

Collecting git+https://github.com/KindXiaoming/pykan.git
  Cloning https://github.com/KindXiaoming/pykan.git to /tmp/pip-req-build-e6kgh5p3
  Running command git clone --filter=blob:none --quiet https://github.com/KindXiaoming/pykan.git /tmp/pip-req-build-e6kgh5p3
  Resolved https://github.com/KindXiaoming/pykan.git to commit e6078bc88018d7c3289e32d010df04155ffb84be
  Preparing metadata (setup.py) ... done

# Modeling

In [1]:
from typing import *

import torch.nn as nn
from kan import KAN, KANLayer
from transformers import (
    ElectraModel,
    AutoTokenizer,
    PretrainedConfig,
    ElectraForMaskedLM,
    ElectraForPreTraining
)
from transformers.models.electra.modeling_electra import (
    ElectraSelfAttention,
    ElectraIntermediate,
    ElectraOutput,
    ElectraAttention,
    ElectraLayer,
    ElectraGeneratorPredictions,
    ElectraDiscriminatorPredictions,
    ElectraEncoder,
    ElectraSelfOutput
)


class ElectraModelWithKANs(ElectraModel):
    """ELECTRA backbone whose embedding projection and encoder use KAN layers.

    The parent constructor builds the standard model first; this class then
    replaces the embedding projection (when one is needed) and the encoder
    with their KAN-based counterparts.

    NOTE(review): this assumes ``KANLayer`` can be called like ``nn.Linear``
    (one tensor in, one tensor out) — confirm against the installed pykan.

    Args:
        config: Model configuration; ``embedding_size`` and ``hidden_size``
            are read here.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        if config.embedding_size != config.hidden_size:
            # The parent model registers and looks up this projection as
            # ``embeddings_project``; any other name would leave the original
            # projection in place and this KAN layer unused.
            self.embeddings_project = KANLayer(config.embedding_size, config.hidden_size)
        self.encoder = ElectraKANEncoder(config)
        self.post_init()


class ElectraKANEncoder(ElectraEncoder):
    """Encoder whose layer stack is made of ``ElectraKANLayer`` blocks.

    Args:
        config: Model configuration; ``num_hidden_layers`` sets the number of
            KAN-based layers that are created.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Replace the layer stack built by the parent constructor.
        # ``nn`` is torch.nn and must be imported alongside this cell's other
        # imports so the class does not depend on a later cell having run.
        kan_layers = [ElectraKANLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(kan_layers)


class ElectraKANSelfAttention(ElectraSelfAttention):
    """Self-attention whose query, key and value projections are KAN layers.

    Args:
        config: Model configuration; ``hidden_size`` is the projection input width.
        position_embedding_type: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, config: PretrainedConfig, position_embedding_type: Optional[str] = None) -> None:
        super().__init__(config, position_embedding_type)
        # Rebuild the three projections in the same order (query, key, value),
        # each mapping hidden_size -> all_head_size.
        for projection_name in ("query", "key", "value"):
            setattr(self, projection_name, KANLayer(config.hidden_size, self.all_head_size))


class ElectraKANSelfOutput(ElectraSelfOutput):
    """Attention output block whose dense projection is a KAN layer."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Square projection: hidden_size -> hidden_size.
        width = config.hidden_size
        self.dense = KANLayer(width, width)


class ElectraKANIntermediate(ElectraIntermediate):
    """Feed-forward expansion block whose dense projection is a KAN layer."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # Expansion projection: hidden_size -> intermediate_size.
        in_width, out_width = config.hidden_size, config.intermediate_size
        self.dense = KANLayer(in_width, out_width)


class ElectraKANOutput(ElectraOutput):
    """Feed-forward output block whose dense projection is a KAN layer.

    Args:
        config: Model configuration; ``intermediate_size`` and ``hidden_size``
            give the input and output widths of the projection.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # This layer consumes the output of the intermediate block, which is
        # intermediate_size wide, and projects it back down to hidden_size.
        # A hidden_size -> hidden_size layer would not match its input whenever
        # the two sizes differ.
        self.dense = KANLayer(config.intermediate_size, config.hidden_size)


class ElectraKANAttention(ElectraAttention):
    """Attention block assembled from the KAN self-attention and output modules.

    Args:
        config: Model configuration passed to the parent and both sub-modules.
        position_embedding_type: Forwarded to ``ElectraKANSelfAttention`` only.
    """

    def __init__(self, config: PretrainedConfig, position_embedding_type: Optional[str] = None) -> None:
        super().__init__(config)
        # Build the KAN replacements first, then install them over the
        # sub-modules created by the parent constructor.
        kan_self_attention = ElectraKANSelfAttention(config, position_embedding_type=position_embedding_type)
        kan_self_output = ElectraKANSelfOutput(config)
        self.self, self.output = kan_self_attention, kan_self_output
        self.pruned_heads = set()


class ElectraKANLayer(ElectraLayer):
    """Transformer layer whose attention and feed-forward parts use KAN modules.

    Raises:
        ValueError: If cross attention is requested on a non-decoder layer.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        self.attention = ElectraKANAttention(config)

        # Guard first: cross attention is only valid on a decoder.
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = ElectraKANAttention(config, position_embedding_type="absolute")

        self.intermediate = ElectraKANIntermediate(config)
        self.output = ElectraKANOutput(config)


class ElectraKANGeneratorPrediction(ElectraGeneratorPredictions):
    """Generator prediction head whose dense projection is a KAN layer."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Projection from the encoder width to the embedding width.
        source_width, target_width = config.hidden_size, config.embedding_size
        self.dense = KANLayer(source_width, target_width)
        

class ElectraKANDiscriminatorPrediction(ElectraDiscriminatorPredictions):
    """Discriminator prediction head whose two projections are KAN layers.

    Args:
        config: Model configuration; ``hidden_size`` is the width of both
            projections' input.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        self.dense = KANLayer(config.hidden_size, config.hidden_size)
        # The parent head names its final projection ``dense_prediction``; the
        # attribute must use exactly that name to replace it, otherwise the
        # original layer stays active and the KAN layer is never called.
        self.dense_prediction = KANLayer(config.hidden_size, 1)
        
    

class ElectraGenerator(ElectraForMaskedLM):
    """Masked-LM generator rebuilt on the KAN backbone, prediction head and LM head."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Construct the replacements in the order backbone -> predictions -> LM head.
        kan_backbone = ElectraModelWithKANs(config)
        kan_predictions = ElectraKANGeneratorPrediction(config)
        kan_lm_head = KANLayer(config.embedding_size, config.vocab_size)

        self.electra = kan_backbone
        self.generator_predictions = kan_predictions
        self.generator_lm_head = kan_lm_head


class ElectraDiscriminator(ElectraForPreTraining):
    """Pre-training discriminator rebuilt on the KAN backbone and prediction head."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # Build both KAN replacements, install them, then run post-init once more.
        kan_backbone = ElectraModelWithKANs(config)
        kan_head = ElectraKANDiscriminatorPrediction(config)
        self.electra, self.discriminator_predictions = kan_backbone, kan_head
        self.post_init()
        

## Modeling test

In [None]:
import torch
import torch.nn as nn
from transformers import AutoConfig


# Smoke test: run the stock generator and the KAN generator on the same random batch.
# Seed the RNG so the random inputs and freshly initialised weights are reproducible.
torch.manual_seed(42)

generator_tokenizer = AutoTokenizer.from_pretrained('google/electra-base-generator')
generator_config = AutoConfig.from_pretrained('google/electra-base-generator')


discriminator_tokenizer = AutoTokenizer.from_pretrained('google/electra-base-discriminator')
discriminator_config = AutoConfig.from_pretrained('google/electra-base-discriminator')


# One sequence of 512 random token ids drawn from the generator tokenizer's vocabulary.
random_input_ids = torch.randint(0, len(generator_tokenizer), (1, 512))
# torch.randint excludes its upper bound, so `high` must be 2 to draw both 0 and 1;
# with `high=1` every value is 0 and the attention mask hides the whole sequence.
random_attention_mask = torch.randint(0, 2, (1, 512))
random_token_type_ids = torch.randint(0, 2, (1, 512))

# Both models are built from the config alone (no pretrained weights are loaded).
orig_generator = ElectraForMaskedLM(generator_config)
generator = ElectraGenerator(generator_config)


print("initialzation complete.")

# NOTE(review): the KAN forward pass assumes KANLayer behaves like nn.Linear on
# these inputs — confirm against the installed pykan version.
with torch.no_grad():
    orig_generator.eval()
    generator.eval()

    orig_output = orig_generator(
        input_ids=random_input_ids,
        attention_mask=random_attention_mask,
        token_type_ids=random_token_type_ids,
    )
    output = generator(
        input_ids=random_input_ids,
        attention_mask=random_attention_mask,
        token_type_ids=random_token_type_ids,
    )



In [None]:
# The models return output objects rather than tensors, so compare their logits.
# NOTE(review): both models are initialised independently from the config, so
# this is likely False; comparing `.logits.shape` may be the more useful check.
torch.allclose(orig_output.logits, output.logits)

# Train models