In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

class CustomMultiheadAttention(nn.MultiheadAttention):
    def __init__(self, *args, **kwargs):
        super(CustomMultiheadAttention, self).__init__(*args, **kwargs)
        self.attention_weights = None

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True):
        self.attention_weights = None  # Reset the attention weights before each forward pass
        output, attention_weights = super(CustomMultiheadAttention, self).forward(
            query, key, value, key_padding_mask=key_padding_mask, need_weights=need_weights
        )
        self.attention_weights = attention_weights
        return output, attention_weights

class TransformerModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads):
        super(TransformerModel, self).__init__()

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.custom_multihead_attention = CustomMultiheadAttention(embed_dim=hidden_size, num_heads=num_heads)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, src, tgt):
        src = self.embedding(src)
        tgt = self.embedding(tgt)

        # Assume self-attention for simplicity
        output, attention_weights = self.custom_multihead_attention(tgt, tgt, tgt)

        output = F.relu(output.mean(dim=1))  # Aggregate over sequence length
        output = self.fc(output)
        return output, attention_weights

# Dummy dataset for illustration purposes
def generate_dummy_data(num_samples, seq_length, vocab_size):
    src = torch.randint(0, vocab_size, (num_samples, seq_length))
    tgt = torch.randint(0, vocab_size, (num_samples, seq_length))
    labels = torch.randint(0, 2, (num_samples,))
    return src, tgt, labels

# Set hyperparameters
input_size = 100  # Vocabulary size
hidden_size = 256
num_heads = 4
output_size = 2
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# Create the model, loss function, and optimizer
model = TransformerModel(input_size, hidden_size, num_heads)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Generate dummy data
num_samples = 1000
seq_length = 20
vocab_size = input_size
src, tgt, labels = generate_dummy_data(num_samples, seq_length, vocab_size)

# Create DataLoader
dataset = TensorDataset(src, tgt, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(num_epochs):
    for batch_src, batch_tgt, batch_labels in dataloader:
        optimizer.zero_grad()
        outputs, attention_weights = model(batch_src, batch_tgt)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        # Print or log the development of attention heads during training
        print("Epoch {epoch + 1} Loss", epoch)
        print(loss.item())
        print("one attention pattern developping during training")
        print(attention_weights[0])

# After training, you can use the trained model for prediction on new data
# and analyze attention heads as needed.


Epoch {epoch + 1} Loss 0
0.6909310817718506
one attention pattern developping during training
tensor([[0.0437, 0.0215, 0.0197,  ..., 0.0377, 0.0504, 0.0384],
        [0.0295, 0.0433, 0.0296,  ..., 0.0281, 0.0301, 0.0375],
        [0.0338, 0.0215, 0.0280,  ..., 0.0326, 0.0216, 0.0337],
        ...,
        [0.0175, 0.0524, 0.0405,  ..., 0.0259, 0.0401, 0.0385],
        [0.0327, 0.0349, 0.0289,  ..., 0.0227, 0.0386, 0.0402],
        [0.0297, 0.0388, 0.0364,  ..., 0.0246, 0.0405, 0.0298]],
       grad_fn=<SelectBackward0>)
Epoch {epoch + 1} Loss 0
0.6813976764678955
one attention pattern developping during training
tensor([[0.0359, 0.0359, 0.0320,  ..., 0.0183, 0.0183, 0.0286],
        [0.0322, 0.0281, 0.0254,  ..., 0.0392, 0.0392, 0.0348],
        [0.0395, 0.0303, 0.0341,  ..., 0.0494, 0.0494, 0.0302],
        ...,
        [0.0375, 0.0412, 0.0288,  ..., 0.0202, 0.0202, 0.0224],
        [0.0375, 0.0412, 0.0288,  ..., 0.0202, 0.0202, 0.0224],
        [0.0210, 0.0283, 0.0329,  ..., 0.0390, 