In [None]:
# Load a pretrained GPT-2 model and stack one extra trainable layer per GPT-2 block on top of its output.

import torch
from torch import nn
from transformers import GPT2Model

class TransformerLayer(nn.Module):
    """Placeholder per-block layer that only holds trainable parameters.

    Only ``config.n_embd`` is read from *config* (as the input width of
    ``expand_A``); every other dimension is a fixed constant.  ``forward``
    currently performs no computation and returns its input unchanged, so
    the parameters are registered (and trainable) but unused.

    NOTE(review): 768 / 2304 / 3072 equal GPT-2 small's hidden, 3x-hidden
    and 4x-hidden widths; presumably they should follow *config* for other
    model sizes -- confirm before relying on non-'gpt2' checkpoints.
    """

    # (attribute name, in_features, out_features) for the bias-free linear
    # maps, listed in the order they are created and registered.
    _LINEAR_SPECS = None  # filled per-instance below because expand_A reads config

    def __init__(self, config):
        super().__init__()
        linear_specs = (
            ("expand_A", config.n_embd, 2048),
            ("attn_weight_transform", 2304, 4096),
            ("proj_weight_transform", 768, 2048),
            ("mlp_fc_transform", 3072, 2048),
            ("mlp_proj_transform", 2048, 2048),
        )
        # nn.Module.__setattr__ registers each Linear as a submodule.
        for attr_name, fan_in, fan_out in linear_specs:
            setattr(self, attr_name, nn.Linear(fan_in, fan_out, bias=False))

        # Standalone vectors: *_bias start at zero, ln*_weight start at one.
        self.attn_bias = nn.Parameter(torch.zeros(4096))
        self.proj_bias = nn.Parameter(torch.zeros(2048))
        self.ln1_weight = nn.Parameter(torch.ones(2048))
        self.ln1_bias = nn.Parameter(torch.zeros(2048))
        self.ln2_weight = nn.Parameter(torch.ones(2048))
        self.ln2_bias = nn.Parameter(torch.zeros(2048))

    def forward(self, x):
        """Identity placeholder: return *x* untouched."""
        return x

class CustomGPT2(nn.Module):
    """Frozen pretrained GPT-2 followed by a stack of trainable layers.

    Every parameter of the pretrained ``GPT2Model`` is frozen, so only the
    parameters of ``custom_layers`` require gradients.  One
    ``TransformerLayer`` is built per block in ``base_model.h``.  The layers
    are applied one after another to the base model's final hidden state;
    they are not interleaved with the GPT-2 blocks.

    Args:
        model_name: Name or path handed to ``GPT2Model.from_pretrained``.
    """

    def __init__(self, model_name='gpt2'):
        super().__init__()
        self.base_model = GPT2Model.from_pretrained(model_name)
        # Freeze all original GPT-2 parameters in a single call.
        self.base_model.requires_grad_(False)

        gpt2_config = self.base_model.config
        block_count = len(self.base_model.h)
        self.custom_layers = nn.ModuleList(
            TransformerLayer(gpt2_config) for _ in range(block_count)
        )

    def forward(self, input_ids, attention_mask=None):
        """Run GPT-2, then pass its last hidden state through each custom layer.

        Args:
            input_ids: Token ids forwarded to ``GPT2Model``.
            attention_mask: Optional mask forwarded to ``GPT2Model``.

        Returns:
            The hidden state produced by the final custom layer.
        """
        base_outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = base_outputs.last_hidden_state
        for custom_layer in self.custom_layers:
            hidden = custom_layer(hidden)
        return hidden

# Instantiate the wrapper; this loads pretrained weights through
# GPT2Model.from_pretrained using the default model_name ('gpt2').
model = CustomGPT2()

# Report every parameter that still requires gradients (the pretrained
# GPT-2 weights are frozen inside CustomGPT2.__init__).
print("Trainable Parameters:")
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen pretrained weight; nothing to report
    print(f"{name}: {param.shape}")