In [7]:
import torch
import torch.nn as nn


# Subclassing TransformerEncoderLayer to capture attention weights
class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(self, *args, **kwargs):
        super(CustomTransformerEncoderLayer, self).__init__(*args, **kwargs)
        self.attn_weights = None

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        src2, self.attn_weights = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=True,
            is_causal=is_causal,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class TransformerClassifier(nn.Module):
    def __init__(
        self, vocab_size, embedding_dim, num_heads, hidden_dim, num_layers, max_length
    ):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_length, embedding_dim))
        encoder_layers = CustomTransformerEncoderLayer(
            embedding_dim, num_heads, hidden_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(max_length * embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoder
        x = self.transformer_encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.sigmoid(x)

    def get_attention_weights(self, x):
        with torch.no_grad():
            self.forward(x)
            # Retrieve attention weights from the first encoder layer
            return self.transformer_encoder.layers[0].attn_weights


# Instantiate and test the model
vocab_size = 100  # Example vocab size
embedding_dim = 6
num_heads = 2
num_layers = 1
hidden_dim = 2
max_length = 1500

model = TransformerClassifier(
    vocab_size, embedding_dim, num_heads, hidden_dim, num_layers, max_length
)

# Create input string
input_str = "sabe" + "p" * (1500 - 4)
input_indices = torch.tensor([ord(c) % vocab_size for c in input_str]).unsqueeze(0)

# Print the Q, K, V matrices
Q = model.transformer_encoder.layers[0].self_attn.q_proj(input_indices)
K = model.transformer_encoder.layers[0].self_attn.k_proj(input_indices)
V = model.transformer_encoder.layers[0].self_attn.v_proj(input_indices)

AttributeError: 'MultiheadAttention' object has no attribute 'q_proj'