# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define a simple Dataset class
class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample with the label at the same index.

    Args:
        data: indexable, sized collection of samples.
        labels: indexable collection of labels, position-aligned with ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The dataset length is taken from the samples; labels are assumed aligned.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# Transformer model definition
class TransformerModel(nn.Module):
    """Transformer-encoder classifier over sequences of feature vectors.

    Inputs are batch-first, shape ``(batch, seq_len, input_dim)``. A 2-D input
    of shape ``(batch, input_dim)`` is treated as a batch of length-1 sequences.
    Each step is linearly projected to ``model_dim``, a learned positional
    encoding is added, the sequence is encoded, mean-pooled over the sequence
    axis, and mapped to ``num_classes`` logits.

    Args:
        input_dim: number of features per sequence step.
        model_dim: transformer hidden size (must be divisible by ``num_heads``).
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_classes: number of output logits.
        max_seq_len: longest sequence the learned positional encoding supports.
    """

    def __init__(self, input_dim, model_dim, num_heads, num_layers, num_classes,
                 max_seq_len=100):
        super().__init__()

        self.embedding = nn.Linear(input_dim, model_dim)
        # Learned positional encoding, zero-initialised, shape (1, max_seq_len, model_dim).
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, model_dim))

        # batch_first=True makes the encoder attend over dim 1 (the sequence axis),
        # the same axis forward() uses for the positional-encoding slice and for
        # pooling. With the default batch_first=False the encoder would treat
        # dim 0 (the batch) as the sequence and mix information across samples.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, batch_first=True
        )
        # NOTE: TransformerEncoder clones encoder_layer num_layers times, so the
        # parameters of self.encoder_layer itself are not used in forward(); the
        # attribute is kept so existing state_dict keys continue to load.
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: tensor of shape ``(batch, seq_len, input_dim)`` or
                ``(batch, input_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the positional-encoding length.
        """
        if x.dim() == 2:
            # A flat feature vector per sample becomes a length-1 sequence.
            x = x.unsqueeze(1)

        seq_len = x.size(1)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        embedded = self.embedding(x) + self.positional_encoding[:, :seq_len, :]
        encoded = self.transformer_encoder(embedded)
        # Global average pooling over the sequence axis.
        return self.fc(encoded.mean(dim=1))

# Example usage
if __name__ == "__main__":
    # Hyper-parameters for a small synthetic demo.
    input_dim = 10
    model_dim = 64
    num_heads = 4
    num_layers = 2
    num_classes = 3
    num_samples = 100
    seq_len = 8
    num_epochs = 10

    # Dummy data: num_samples sequences, each seq_len steps of input_dim features,
    # laid out batch-first as (batch, seq_len, features). TransformerModel.forward
    # indexes the positional encoding and pools on dim 1, i.e. the sequence axis.
    data = np.random.rand(num_samples, seq_len, input_dim).astype(np.float32)
    # CrossEntropyLoss expects int64 class indices; randint's default integer
    # dtype is platform-dependent (e.g. int32 on Windows with NumPy < 2).
    labels = np.random.randint(0, num_classes, size=(num_samples,)).astype(np.int64)

    dataset = CustomDataset(data, labels)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    model = TransformerModel(input_dim, model_dim, num_heads, num_layers, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_data, batch_labels in dataloader:
            # Batches already arrive as (batch, seq_len, features), so no permute
            # is needed (permuting a batch of flat vectors would also crash).
            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the smaller final batch is averaged correctly.
            running_loss += loss.item() * batch_data.size(0)

        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataset):.4f}")