In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd

# ----------------------------
# 1) Датасет для задачи копирования
# ----------------------------
class CopyDataset(Dataset):
    """Synthetic copy-task dataset: each sample's target is its own input.

    Args:
        seq_len: Number of tokens in every generated sequence.
        vocab_size: Exclusive upper bound for token ids. Ids are drawn from
            ``[1, vocab_size)``, so id 0 never occurs in the generated data.
        num_samples: How many sequences are generated up front.
    """

    def __init__(self, seq_len, vocab_size=100, num_samples=2000):
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        # All sequences are sampled once, here, as a single
        # (num_samples, seq_len) integer tensor.
        shape = (num_samples, seq_len)
        self.data = torch.randint(low=1, high=vocab_size, size=shape)

    def __len__(self):
        """Return the number of stored sequences."""
        return self.data.size(0)

    def __getitem__(self, idx):
        """Return ``(input, target)``; both are the same stored row."""
        sequence = self.data[idx]
        return sequence, sequence

class NoisyCopyDataset(CopyDataset):
    """Copy-task variant whose input is padded with noise tokens (id 0).

    The stored sequence is wrapped in 10 zeros on each side and the result is
    cut back to ``seq_len`` entries, so the returned input consists of the
    leading zeros followed by a prefix of the sequence. The target is the
    full, unpadded sequence.

    NOTE(review): because the slice keeps only the first ``seq_len`` entries,
    the trailing pad is always cut off and the input tokens sit 10 positions
    later than the same tokens in the target. Confirm this offset is intended.
    """

    def __getitem__(self, idx):
        """Return ``(noisy_input, clean_target)`` for sample ``idx``."""
        clean, _ = super().__getitem__(idx)
        # Token 0 plays the role of noise.
        pad = torch.zeros(10, dtype=clean.dtype)
        padded = torch.cat((pad, clean, pad), dim=0)
        noisy_input = padded[:self.seq_len]
        return noisy_input, clean
# ----------------------------
# 2) Модели
# ----------------------------
class LSTMCopy(nn.Module):
    """Embedding -> single-layer LSTM -> per-position linear classifier.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Width of the token embeddings (index 0 is the padding row).
        hidden_dim: Width of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, L) to logits of shape (B, L, V)."""
        # (B, L) -> (B, L, E) -> (B, L, H); the final LSTM state is not needed.
        hidden_states, _final_state = self.lstm(self.embed(x))
        return self.fc(hidden_states)

class TransformerCopy(nn.Module):
    """Embedding -> Transformer encoder stack -> per-position linear classifier.

    No positional encoding is added to the embeddings before the encoder.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Model width (embedding size and encoder ``d_model``).
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, embed_dim, nhead, num_layers):
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0
        )
        # Feed-forward width is four times the model width.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=nhead, dim_feedforward=4 * embed_dim
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(in_features=embed_dim, out_features=vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, L) to logits of shape (B, L, V)."""
        # The encoder is built sequence-first, so swap the batch and sequence
        # axes on the way in and swap them back on the way out.
        seq_first = self.embed(x).transpose(0, 1)   # (L, B, E)
        encoded = self.encoder(seq_first)           # (L, B, E)
        batch_first = encoded.transpose(0, 1)       # (B, L, E)
        return self.fc(batch_first)                 # (B, L, V)

class MambaPlusPlus(nn.Module):
    """Minimal gated linear recurrence ("SSM-style") sequence model.

    For every position ``t`` the hidden state is updated as
    ``h_t = sigmoid(A e_t) * h_{t-1} + B e_t`` and the logits are
    ``fc(C h_t)``, where ``e_t`` is the token embedding and ``A``, ``B``,
    ``C`` are the ``linear_a``, ``linear_b``, ``linear_c`` layers.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_dim: Width of the token embeddings (index 0 is the padding row).
        hidden_dim: Width of the recurrent state.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.linear_a = nn.Linear(embed_dim, hidden_dim)
        self.linear_b = nn.Linear(embed_dim, hidden_dim)
        self.linear_c = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (B, L) to logits of shape (B, L, V).

        A zero-length sequence yields an empty ``(B, 0, V)`` tensor.
        """
        emb = self.embed(x)  # (B, L, E)
        batch_size, seq_len, _ = emb.shape
        hidden_dim = self.linear_c.in_features

        # The gate and the input signal depend only on the embeddings, so they
        # are computed for all positions at once instead of once per step.
        gates = torch.sigmoid(self.linear_a(emb))  # (B, L, H)
        signals = self.linear_b(emb)               # (B, L, H)

        # new_zeros makes the initial state follow the embeddings' dtype and
        # device, so the model also runs with non-float32 weights.
        h = emb.new_zeros(batch_size, hidden_dim)
        states = []
        for t in range(seq_len):
            h = gates[:, t] * h + signals[:, t]  # SSM-style state update
            states.append(h)

        if states:
            hidden = torch.stack(states, dim=1)  # (B, L, H)
        else:
            # torch.stack rejects an empty list; build the empty result directly.
            hidden = emb.new_zeros(batch_size, 0, hidden_dim)

        # The output projection is position-wise, so apply it once to the
        # whole sequence of states.
        return self.fc(self.linear_c(hidden))  # (B, L, V)

# ----------------------------
# 3) Функция обучения и оценки
# ----------------------------
def train_and_eval(model, train_loader, test_seq_lens, device, epochs=5,
                   lr=1e-3, eval_batch_size=64):
    """Train ``model`` on ``train_loader``, then measure copy accuracy per length.

    Training minimises token-level cross-entropy with Adam; targets equal to 0
    are ignored by the loss. After training, for every length in
    ``test_seq_lens`` a fresh random batch of token ids in
    ``[1, train_loader.dataset.vocab_size)`` is generated and the model is
    scored on reproducing that batch position by position.

    NOTE(review): evaluation always uses clean sequences whose target equals
    the input. If ``train_loader`` yields inputs that differ from their
    targets, the reported accuracy is measured on a different distribution
    than the one the model was trained on.

    Args:
        model: Module mapping (B, L) token ids to (B, L, V) logits.
        train_loader: Iterable of ``(x, y)`` batches; its ``dataset`` must
            expose ``vocab_size`` when ``test_seq_lens`` is non-empty.
        test_seq_lens: Sequence lengths to evaluate on.
        device: Device used for both training and evaluation.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        eval_batch_size: Number of random sequences scored per test length.

    Returns:
        Dict mapping each test length to the mean token accuracy (a float).
        The model is left on ``device`` and in eval mode.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    # Training
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)  # (B, L, V)
            # reshape (not view) so non-contiguous tensors are accepted too.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Batches are counted while iterating, so loaders without __len__ work
        # and an empty loader reports nan instead of dividing by zero.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss={mean_loss:.4f}")

    # Accuracy on each requested sequence length
    model.eval()
    results = {}
    with torch.no_grad():
        for seq_len in test_seq_lens:
            x_test = torch.randint(
                1,
                train_loader.dataset.vocab_size,
                (eval_batch_size, seq_len),
                device=device,
            )
            preds = model(x_test).argmax(dim=-1)
            # The target of the clean copy task is the input itself.
            results[seq_len] = (preds == x_test).float().mean().item()
    return results

# ----------------------------
# 4) Запуск всех моделей
# ----------------------------
def run_experiment(seed=42):
    """Train every model on the noisy copy task and tabulate test accuracy.

    Builds a ``NoisyCopyDataset`` of length-50 sequences, trains the LSTM,
    Transformer and Mamba++ models on it for 5 epochs each, prints the
    accuracy table and writes it to ``copy_task_results.csv``.

    NOTE(review): training uses ``NoisyCopyDataset`` while ``train_and_eval``
    scores the models on clean sequences whose target equals the input.
    Confirm that this train/test mismatch is the intended experiment.

    Args:
        seed: Seed passed to ``torch.manual_seed`` before any data or weights
            are created, so repeated runs are comparable. Pass ``None`` to
            leave the random state untouched.

    Returns:
        DataFrame with one row per model and one ``len=<L>`` column per
        tested sequence length.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = 100
    embed_dim = 32
    hidden_dim = 64

    # Training data: sequences of length 50
    train_ds = NoisyCopyDataset(seq_len=50, vocab_size=vocab_size, num_samples=2000)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)

    test_seq_lens = [50, 100, 200, 300, 500, 1000]

    models = {
        "LSTM": LSTMCopy(vocab_size, embed_dim, hidden_dim),
        "Transformer": TransformerCopy(vocab_size, embed_dim, nhead=4, num_layers=2),
        "Mamba++": MambaPlusPlus(vocab_size, embed_dim, hidden_dim)
    }

    all_results = {}
    for name, model in models.items():
        print(f"\n=== Training {name} ===")
        all_results[name] = train_and_eval(
            model, train_loader, test_seq_lens, device, epochs=5
        )

    # Results table: rows are models, columns are test lengths
    df = pd.DataFrame(all_results).T
    df.columns = [f"len={seq_len}" for seq_len in df.columns]
    print("\n=== Final Accuracy ===")
    print(df)
    df.to_csv("copy_task_results.csv")
    print("Saved results to copy_task_results.csv")
    return df

# Entry point: run the full experiment when this code is executed as the main module.
if __name__ == "__main__":
    run_experiment()

In [None]:
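# Disabled throughput benchmark (sketch). `model` and `x_test` are not bound
# at notebook scope by the cell above (there they exist only inside
# functions), so define both before uncommenting.
# import time
# start = time.time()
# for _ in range(100):
#     _ = model(x_test)
# print("Throughput:", 100 * x_test.numel() / (time.time() - start), "tokens/sec")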