# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

# MambaPlusPlus model class (copy it here from your own definition)
class MambaPlusPlus(nn.Module):
    """Multi-head gated linear recurrence over token embeddings.

    For every head ``i`` and time step ``t`` the model keeps a state
    ``h_i`` of size ``head_dim`` and updates it as::

        a_t = tanh(W_a[i](emb_t))
        h_i = a_t * h_i + W_b[i](emb_t)
        y_i = C[i](h_i) * W_out[i](emb_t)

    The per-head outputs are concatenated to ``hidden_dim`` features,
    passed through a LayerNorm residual and a GELU feed-forward residual,
    then projected to ``vocab_size`` logits per position.

    Args:
        vocab_size: Number of rows in the embedding table and size of the
            output logit vector. Index 0 is registered as ``padding_idx``.
        embed_dim: Width of the token embeddings.
        hidden_dim: Total width of the concatenated head outputs.
        num_heads: Number of independent recurrent heads.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_heads):
        super().__init__()
        # Raise explicitly instead of `assert`, which is stripped under `python -O`.
        if hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # One set of projections per head; attribute names and creation order
        # are kept stable so existing state_dicts and seeded inits still match.
        self.W_a = nn.ModuleList([nn.Linear(embed_dim, self.head_dim) for _ in range(num_heads)])
        self.W_b = nn.ModuleList([nn.Linear(embed_dim, self.head_dim) for _ in range(num_heads)])
        self.W_out = nn.ModuleList([nn.Linear(embed_dim, self.head_dim) for _ in range(num_heads)])
        self.C = nn.ModuleList([nn.Linear(self.head_dim, self.head_dim) for _ in range(num_heads)])

        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.ffn1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.ffn2 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.act = nn.GELU()
        self.output_fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute per-position logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape ``(B, L)``.

        Returns:
            Logits of shape ``(B, L, vocab_size)``. For ``L == 0`` an empty
            tensor of shape ``(B, 0, vocab_size)`` is returned.
        """
        emb = self.embed(x)  # (B, L, E)
        B, L, _ = emb.shape
        if L == 0:
            # Nothing to scan over; avoid concatenating an empty list.
            return emb.new_zeros(B, 0, self.output_fc.out_features)

        head_seqs = []
        for w_a, w_b, w_out, c in zip(self.W_a, self.W_b, self.W_out, self.C):
            # The input-dependent projections do not depend on the recurrent
            # state, so compute them once for the whole sequence.
            gate = torch.tanh(w_a(emb))  # (B, L, head_dim)
            drive = w_b(emb)             # (B, L, head_dim)
            mod = w_out(emb)             # (B, L, head_dim)

            # Allocate the state from `emb` so it follows the model's dtype
            # and device (a float32 state breaks half/bfloat16 models).
            state = emb.new_zeros(B, self.head_dim)
            states = []
            for t in range(L):
                state = gate[:, t] * state + drive[:, t]
                states.append(state)

            hidden = torch.stack(states, dim=1)  # (B, L, head_dim)
            head_seqs.append(c(hidden) * mod)

        z = torch.cat(head_seqs, dim=-1)  # (B, L, hidden_dim)
        u = z + self.norm(z)
        o = self.ffn2(self.act(self.ffn1(u)))
        h_out = u + o
        return self.output_fc(self.proj(h_out))

# Synthetic data generator
def generate_synthetic_data(batch_size, seq_len, vocab_size, top_k=1):
    """Sample random token sequences and mark their largest entries.

    Args:
        batch_size: Number of sequences to draw.
        seq_len: Length of every sequence.
        vocab_size: Exclusive upper bound for token ids; ids are drawn
            from ``[1, vocab_size)``.
        top_k: How many of the largest tokens per sequence get label 1.

    Returns:
        A pair ``(x, labels)`` of integer tensors, both of shape
        ``(batch_size, seq_len)``. ``labels`` is 1 at the positions chosen
        by ``torch.max`` (``top_k == 1``) or ``torch.topk`` (otherwise)
        and 0 everywhere else.
    """
    tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
    scores = tokens.float()

    # Pick the column indices of the winners as a (batch_size, k) tensor.
    if top_k == 1:
        winners = torch.max(scores, dim=1, keepdim=True).indices
    else:
        winners = torch.topk(scores, top_k, dim=1).indices

    targets = torch.zeros_like(tokens).scatter_(1, winners, 1)
    return tokens, targets

# Hyperparameters
vocab_size, embed_dim = 100, 32
hidden_dim, num_heads = 64, 4
batch_size, seq_len = 64, 20
top_k = 1  # label only the single largest token in each sequence

model = MambaPlusPlus(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# A single training step
model.train()
optimizer.zero_grad()
x, labels = generate_synthetic_data(batch_size, seq_len, vocab_size, top_k=top_k)
logits = model(x)  # (B, L, vocab_size)
# Collapse the vocabulary axis: the largest logit at each position is used
# as that position's binary score against the 0/1 labels.
logits_max, _ = torch.max(logits, dim=2)  # (B, L)
loss = criterion(logits_max, labels.float())
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")


# Output: Loss: 1.3082
