In [1]:
import torch
from model import *
import torch.nn as nn
from torchinfo import summary

In [3]:
class BuildTransformer(nn.Module):
    """Token embedding followed by a ``Decoder`` stack.

    All constructor arguments are keyword-only.

    Args:
        num_layers: Forwarded to ``Decoder``.
        d_model: Forwarded to both ``TokenEmbedding`` and ``Decoder``.
        num_heads: Forwarded to ``Decoder``.
        input_vocab_size: Forwarded to ``TokenEmbedding`` (as ``vocab_size``)
            and to ``Decoder``.
        dropout_rate: Forwarded to ``Decoder``. Defaults to 0.1.
    """

    def __init__(self, *, num_layers, d_model, num_heads, input_vocab_size, dropout_rate=0.1):
        super().__init__()
        # Registration order (embedding first, decoder second) determines the
        # order in which parameters are listed by named_parameters().
        self.token_embedding = TokenEmbedding(
            vocab_size=input_vocab_size,
            d_model=d_model,
        )
        self.decoder = Decoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            input_vocab_size=input_vocab_size,
            dropout_rate=dropout_rate,
        )

    def forward(self, inputs):
        """Embed ``inputs`` and run them through the decoder.

        Args:
            inputs: Token ids. Positions equal to 0 are flagged as padding.

        Returns:
            Whatever ``self.decoder`` returns for the embedded inputs.
        """
        # True where the token id is 0, False elsewhere.
        # Assumes 0 is also TokenEmbedding's padding_idx -- TODO confirm in model.py.
        padding_positions = inputs == 0
        embedded = self.token_embedding(inputs)
        return self.decoder(embedded, pad_mask=padding_positions)

# Example instantiation: hyperparameters for a small demo model.
num_layers = 2
d_model = 256
num_heads = 8  # d_model / num_heads = 32 dims per head
dropout_rate = 0.1
VOCAB_SIZE = 4000  # Example value, replace with actual VOCAB_SIZE
MAX_LEN = 180      # Example value, replace with actual MAX_LEN

model = BuildTransformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    input_vocab_size=VOCAB_SIZE,
    dropout_rate=dropout_rate
)

# Smoke-test the forward pass on a (1, MAX_LEN) batch of int64 token ids.
# Unless the model uses lazy modules, PyTorch parameters already exist after
# construction, so this pass only checks that the model runs end to end.
# no_grad() avoids building an autograd graph that would be discarded anyway.
dummy = torch.zeros((1, MAX_LEN), dtype=torch.int64)
with torch.no_grad():
    _ = model(dummy)

# Model summary (last expression of the cell, so the notebook displays it).
summary(model, input_size=(1, MAX_LEN), dtypes=[torch.int64])

Layer (type:depth-idx)                             Output Shape              Param #
BuildTransformer                                   [1, 180, 4000]            --
├─TokenEmbedding: 1-1                              [1, 180, 256]             --
│    └─Embedding: 2-1                              [1, 180, 256]             1,024,000
├─Decoder: 1-2                                     [1, 180, 4000]            --
│    └─Dropout: 2-2                                [1, 180, 256]             --
│    └─ModuleList: 2-3                             --                        --
│    │    └─DecoderLayer: 3-1                      [1, 180, 256]             674,304
│    │    └─DecoderLayer: 3-2                      [1, 180, 256]             674,304
│    └─RMSNorm: 2-4                                [1, 180, 256]             256
│    └─Linear: 2-5                                 [1, 180, 4000]            1,028,000
Total params: 3,400,864
Trainable params: 3,400,864
Non-trainable params: 0
Total mult-add

In [6]:
# List every parameter that currently has requires_grad=True, then print the
# combined element count.
print("Trainable Parameters:")
print("-" * 60)
trainable = [
    (pname, tensor)
    for pname, tensor in model.named_parameters()
    if tensor.requires_grad
]
for pname, tensor in trainable:
    print(f"Parameter: {pname:<50} Shape: {str(tensor.shape):<20} Parameters: {tensor.numel()}")
total_params = sum(tensor.numel() for _pname, tensor in trainable)
print("-" * 60)
print(f"Total Trainable Parameters: {total_params:,}")

Trainable Parameters:
------------------------------------------------------------
Parameter: token_embedding.embedding.weight                   Shape: torch.Size([4000, 256]) Parameters: 1024000
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wq.weight Shape: torch.Size([256, 256]) Parameters: 65536
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wk.weight Shape: torch.Size([256, 256]) Parameters: 65536
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wv.weight Shape: torch.Size([256, 256]) Parameters: 65536
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wo.weight Shape: torch.Size([256, 256]) Parameters: 65536
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wq.A.weight Shape: torch.Size([8, 256]) Parameters: 2048
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wq.B.weight Shape: torch.Size([256, 8]) Parameters: 2048
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wv.A.weight Shape: torch.Size([8, 256

In [8]:
# Keep only the LoRA weights trainable: every parameter whose qualified name
# does not contain "lora" (case-insensitive) gets requires_grad switched off.
for name, param in model.named_parameters():
    if 'lora' in name.lower():
        continue
    param.requires_grad = False

# Report what is still trainable after freezing.
print("Trainable Parameters:")
print("-" * 60)
element_counts = []
for pname, weight in model.named_parameters():
    if not weight.requires_grad:
        continue
    n_elements = weight.numel()
    element_counts.append(n_elements)
    print(f"Parameter: {pname:<50} Shape: {str(weight.shape):<20} Parameters: {n_elements}")
total_params = sum(element_counts)
print("-" * 60)
print(f"Total Trainable Parameters: {total_params:,}")

Trainable Parameters:
------------------------------------------------------------
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wq.A.weight Shape: torch.Size([8, 256]) Parameters: 2048
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wq.B.weight Shape: torch.Size([256, 8]) Parameters: 2048
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wv.A.weight Shape: torch.Size([8, 256]) Parameters: 2048
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wv.B.weight Shape: torch.Size([256, 8]) Parameters: 2048
Parameter: decoder.dec_layers.0.ffn.swiglu.lora_lin1.A.weight Shape: torch.Size([8, 256]) Parameters: 2048
Parameter: decoder.dec_layers.0.ffn.swiglu.lora_lin1.B.weight Shape: torch.Size([1024, 8]) Parameters: 8192
Parameter: decoder.dec_layers.1.causal_self_attention.mha.lora_wq.A.weight Shape: torch.Size([8, 256]) Parameters: 2048
Parameter: decoder.dec_layers.1.causal_self_attention.mha.lora_wq.B.weight Shape: torch.Size([256, 8])

In [10]:
# Create a dummy input and target for gradient computation.
# NOTE(review): BuildTransformer.forward builds its padding mask from
# `inputs == 0`, so this all-zero input is flagged as padding at every
# position -- confirm that is acceptable for this check.
dummy_input = torch.zeros((1, MAX_LEN), dtype=torch.int64)
dummy_target = torch.zeros((1, MAX_LEN), dtype=torch.int64)  # Dummy target for loss

# Ensure model is in training mode
model.train()

# Clear gradients left over from any earlier backward pass. The report below
# tests `param.grad is not None`, which is only meaningful if every .grad
# starts out as None; it also makes re-running this cell idempotent.
model.zero_grad(set_to_none=True)

# Forward pass
output = model(dummy_input)  # Shape: [1, MAX_LEN, VOCAB_SIZE]

# Compute a dummy loss (e.g., cross-entropy loss).
# reshape() flattens like view() but also works if the output is non-contiguous.
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output.reshape(-1, VOCAB_SIZE), dummy_target.reshape(-1))

# Backward pass to compute gradients
loss.backward()

# Check gradients for all parameters
print("Parameter Gradients (LoRA should have gradients, others should not):")
print("-" * 80)

total_trainable_params = 0
for name, param in model.named_parameters():
    grad_status = "Has Gradient" if param.grad is not None else "No Gradient"
    param_count = param.numel()
    if param.requires_grad:
        total_trainable_params += param_count
    # str() is required: a bool formatted with a width spec goes through int
    # formatting and would print 0/1 instead of False/True.
    print(f"Parameter: {name:<60} Shape: {str(param.shape):<20} Trainable: {str(param.requires_grad):<10} {grad_status}")
print("-" * 80)
print(f"Total Trainable Parameters: {total_trainable_params:,}")

Parameter Gradients (LoRA should have gradients, others should not):
--------------------------------------------------------------------------------
Parameter: token_embedding.embedding.weight                             Shape: torch.Size([4000, 256]) Trainable: 0          No Gradient
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wq.weight     Shape: torch.Size([256, 256]) Trainable: 0          No Gradient
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wk.weight     Shape: torch.Size([256, 256]) Trainable: 0          No Gradient
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wv.weight     Shape: torch.Size([256, 256]) Trainable: 0          No Gradient
Parameter: decoder.dec_layers.0.causal_self_attention.mha.wo.weight     Shape: torch.Size([256, 256]) Trainable: 0          No Gradient
Parameter: decoder.dec_layers.0.causal_self_attention.mha.lora_wq.A.weight Shape: torch.Size([8, 256]) Trainable: 1          Has Gradient
Parameter: decoder.dec_layers.0