In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import Counter

# -------------------------------
# Set random seeds for reproducibility
# -------------------------------
# Single seed value shared by every RNG seeded below.
seed = 42
random.seed(seed)        # Python's built-in `random` module
np.random.seed(seed)     # NumPy's global RNG
torch.manual_seed(seed)  # PyTorch's RNG

# -------------------------------
# Define syndrome mappings for a 3-qubit repetition (bit-flip) code.
# -------------------------------
# Noise-free syndrome (bit1, bit2) for each error class.
ideal_syndromes = {
    0: (0, 0),  # class 0: no error
    1: (1, 0),  # class 1: error on qubit 1
    2: (1, 1),  # class 2: error on qubit 2
    3: (0, 1),  # class 3: error on qubit 3
}
# Inverse lookup: syndrome pair -> token id. Derived from ideal_syndromes so
# the two tables cannot drift apart; each token id equals the error class
# whose ideal syndrome is that pair.
token_map = {syndrome: error_class for error_class, syndrome in ideal_syndromes.items()}

# -------------------------------
# Data generation function.
#
# Two noise modes:
#   "independent": each syndrome bit is flipped with probability noise_prob.
#   "biased": first bit is flipped with noise_prob1, second with noise_prob2.
#
# Parameter enhanced_input:
#   - If False, returns a single token per measurement (traditional).
#   - If True, returns the two syndrome bits as a list [bit1, bit2].
# -------------------------------
def generate_sample(sequence_length=20, noise_mode="independent", noise_prob=0.2,
                    noise_prob1=0.5, noise_prob2=0.1, enhanced_input=False):
    """Generate one noisy syndrome-measurement sequence and its true error class.

    A true error class is drawn uniformly from {0, 1, 2, 3}; its ideal
    two-bit syndrome is then "measured" ``sequence_length`` times, with each
    bit flipped independently at random on every measurement.

    Args:
        sequence_length: Number of repeated syndrome measurements.
        noise_mode: ``"independent"`` flips both bits with ``noise_prob``;
            ``"biased"`` flips the first bit with ``noise_prob1`` and the
            second with ``noise_prob2``.
        noise_prob: Flip probability for both bits in ``"independent"`` mode.
        noise_prob1: Flip probability of the first bit in ``"biased"`` mode.
        noise_prob2: Flip probability of the second bit in ``"biased"`` mode.
        enhanced_input: If True each measurement is returned as a list
            ``[bit1, bit2]``; otherwise as a single token id from ``token_map``.

    Returns:
        A tuple ``(sequence, true_class)`` where ``sequence`` is a list of
        length ``sequence_length``.

    Raises:
        ValueError: If ``noise_mode`` is neither ``"independent"`` nor
            ``"biased"``.
    """
    # Resolve the per-bit flip probabilities once, before any randomness is
    # consumed, so an invalid mode is rejected even when sequence_length == 0.
    if noise_mode == "independent":
        flip_probs = (noise_prob, noise_prob)
    elif noise_mode == "biased":
        flip_probs = (noise_prob1, noise_prob2)
    else:
        raise ValueError("Unknown noise_mode")

    true_class = random.choice([0, 1, 2, 3])
    ideal = ideal_syndromes[true_class]
    sequence = []
    for _ in range(sequence_length):
        # One random draw per bit, first bit first, so the RNG stream is
        # consumed in the same order for both noise modes.
        measured = tuple(
            1 - bit if random.random() < flip_prob else bit
            for bit, flip_prob in zip(ideal, flip_probs)
        )
        if enhanced_input:
            # Keep the raw two bits as integers.
            sequence.append(list(measured))
        else:
            sequence.append(token_map[measured])
    return sequence, true_class

# -------------------------------
# PyTorch Dataset for syndrome data.
# -------------------------------
class SyndromeDataset(Dataset):
    """In-memory dataset of generated syndrome sequences and their labels.

    All samples are drawn eagerly with ``generate_sample`` at construction
    time and stored as two tensors: ``samples`` and ``labels``.

    ``samples`` is ``torch.float32`` only when ``enhanced_input`` is True and
    ``separate_bit_embeddings`` is False; in every other configuration it is
    ``torch.long``. ``labels`` is always ``torch.long``.
    """

    def __init__(self, num_samples=5000, sequence_length=20, noise_mode="independent",
                 noise_prob=0.2, noise_prob1=0.5, noise_prob2=0.1,
                 enhanced_input=False, separate_bit_embeddings=False):
        generated = [
            generate_sample(sequence_length, noise_mode, noise_prob,
                            noise_prob1, noise_prob2, enhanced_input)
            for _ in range(num_samples)
        ]
        # Raw bit pairs without separate embeddings are kept as floats;
        # token ids and embedding-indexed bit pairs must be integer typed.
        wants_float = enhanced_input and not separate_bit_embeddings
        sample_dtype = torch.float32 if wants_float else torch.long
        self.samples = torch.tensor([seq for seq, _ in generated], dtype=sample_dtype)
        self.labels = torch.tensor([label for _, label in generated], dtype=torch.long)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.samples[idx], self.labels[idx]

# -------------------------------
# Positional Encoding Module.
# -------------------------------
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to the input.

    The table has shape ``(1, max_len, d_model)`` and is registered as the
    buffer ``pe``, so it follows the module across devices but is not a
    trainable parameter. Even feature columns hold sines, odd columns hold
    cosines.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = positions * frequencies  # one column per sine slot

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so only the first d_model // 2 frequencies are used.
        table[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', table.unsqueeze(0))  # shape: (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` rows of the position table."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

# -------------------------------
# Separate Bit Embedding Module.
# -------------------------------
class SyndromeBitEmbedding(nn.Module):
    """Embed each of the two syndrome bits with its own table and merge them.

    Args:
        d_model: Size of each bit embedding and of the output feature vector.
        bit_vocab_size: Number of rows in each embedding table.
        combine_method: ``"sum"`` adds the two embeddings; ``"concat"``
            concatenates them and projects back to ``d_model`` with a linear
            layer.

    Raises:
        ValueError: If ``combine_method`` is neither ``"concat"`` nor ``"sum"``.
    """

    def __init__(self, d_model, bit_vocab_size=2, combine_method="concat"):
        super().__init__()
        self.embedding1 = nn.Embedding(bit_vocab_size, d_model)
        self.embedding2 = nn.Embedding(bit_vocab_size, d_model)
        self.combine_method = combine_method
        if combine_method not in ("concat", "sum"):
            raise ValueError("Unknown combine_method")
        if combine_method == "concat":
            # Only the concat variant needs a projection back to d_model.
            self.proj = nn.Linear(2 * d_model, d_model)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, 2) to (batch, seq_len, d_model)."""
        bits = x.long()
        first = self.embedding1(bits[:, :, 0])   # (batch, seq_len, d_model)
        second = self.embedding2(bits[:, :, 1])  # (batch, seq_len, d_model)
        if self.combine_method == "sum":
            return first + second
        # "concat": (batch, seq_len, 2*d_model) -> (batch, seq_len, d_model)
        return self.proj(torch.cat([first, second], dim=-1))

# -------------------------------
# Transformer-based Sequence Classifier.
# -------------------------------
class SyndromeTransformer(nn.Module):
    """Transformer encoder that classifies a syndrome sequence into an error class.

    The input front-end is chosen at construction time:

    * ``use_separate_embeddings=True``: ``SyndromeBitEmbedding`` over integer
      bit pairs (takes precedence over ``enhanced_input``).
    * ``enhanced_input=True``: a linear projection of the two-element float
      vectors.
    * otherwise: an embedding table over ``vocab_size`` token ids.

    The encoded sequence is mean-pooled over the sequence dimension and passed
    through a two-layer head that outputs ``num_classes`` logits.
    """

    def __init__(self, enhanced_input=False, use_separate_embeddings=False, vocab_size=4,
                 d_model=128, nhead=8, num_layers=4, dim_feedforward=256, num_classes=4, dropout=0.1):
        super().__init__()
        self.enhanced_input = enhanced_input
        self.use_separate_embeddings = use_separate_embeddings

        if use_separate_embeddings:
            # Input: (batch, seq_len, 2) integer bits.
            self.bit_embedding = SyndromeBitEmbedding(d_model, combine_method="concat")
        elif enhanced_input:
            # Input: (batch, seq_len, 2) float vectors.
            self.input_projection = nn.Linear(2, d_model)
        else:
            # Input: (batch, seq_len) token ids.
            self.embedding = nn.Embedding(vocab_size, d_model)

        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout),
            num_layers,
        )
        # Classification head applied to the pooled representation.
        self.fc1 = nn.Linear(d_model, d_model)
        self.fc2 = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        if self.use_separate_embeddings:
            features = self.bit_embedding(x)
        elif self.enhanced_input:
            features = self.input_projection(x)
        else:
            features = self.embedding(x)
        features = self.pos_encoder(features)  # (batch, seq_len, d_model)

        # The encoder is built with its default sequence-first layout, so swap
        # to (seq_len, batch, d_model) going in and back again coming out.
        encoded = self.transformer_encoder(features.transpose(0, 1)).transpose(0, 1)

        pooled = self.dropout(encoded.mean(dim=1))  # mean over seq_len
        hidden = self.dropout(torch.relu(self.fc1(pooled)))
        return self.fc2(hidden)

# -------------------------------
# Classical decoder: Majority Vote.
# -------------------------------
def classical_decoder(sample, enhanced_input=False, use_separate_embeddings=False):
    """Decode one sequence by returning its most frequent syndrome token.

    Args:
        sample: Tensor holding one measurement sequence, in the layout
            selected by the two flags below.
        enhanced_input: If True (and ``use_separate_embeddings`` is False),
            each measurement is a pair of values that is thresholded at 0.5
            before being mapped to a token.
        use_separate_embeddings: If True, each measurement is a pair of bits
            mapped directly to a token; takes precedence over
            ``enhanced_input``.

    Returns:
        The token id that occurs most often in the sequence.
    """
    measurements = sample.cpu().numpy()
    if use_separate_embeddings:
        tokens = [token_map[tuple(pair)] for pair in measurements]
    elif enhanced_input:
        tokens = [
            token_map[(int(pair[0] >= 0.5), int(pair[1] >= 0.5))]
            for pair in measurements
        ]
    else:
        tokens = measurements.tolist()
    winner, _count = Counter(tokens).most_common(1)[0]
    return winner

# -------------------------------
# Training and evaluation routines.
# -------------------------------
def train_model(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-sample loss.

    Each batch loss is weighted by its batch size and the total is divided by
    ``len(dataloader.dataset)``.
    """
    model.train()
    total_loss = 0.0
    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        total_loss += batch_loss.item() * batch_inputs.size(0)
    return total_loss / len(dataloader.dataset)

def evaluate_model(model, dataloader, device):
    """Return the fraction of samples in ``dataloader`` the model classifies correctly.

    The model is put in eval mode and gradients are disabled for the pass.
    """
    model.eval()
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            # Index 1 of the (values, indices) result is the predicted class.
            predictions = model(batch_inputs).max(dim=1)[1]
            num_correct += int((predictions == batch_labels).sum())
            num_seen += batch_labels.size(0)
    return num_correct / num_seen

def evaluate_classical(dataset, enhanced_input=False, use_separate_embeddings=False):
    """Return the accuracy of ``classical_decoder`` over every item in ``dataset``.

    ``dataset`` must yield ``(sample, label)`` pairs where ``label`` supports
    ``.item()``; the two flags are forwarded to ``classical_decoder``.
    """
    hits = sum(
        1
        for sample, label in dataset
        if classical_decoder(sample, enhanced_input, use_separate_embeddings) == label.item()
    )
    return hits / len(dataset)

# -------------------------------
# Run an experiment with early stopping and weight decay.
# -------------------------------
def run_experiment(noise_mode, enhanced_input, use_separate_embeddings, noise_prob=0.2, noise_prob1=0.5, noise_prob2=0.1,
                   num_train=5000, num_val=1000, sequence_length=20, num_epochs=20,
                   d_model=128, nhead=8, num_layers=4, dim_feedforward=256, dropout=0.1, weight_decay=1e-5, patience=5):
    """Train a SyndromeTransformer on generated data and compare it to majority vote.

    Fresh training and validation sets are generated, the model is trained
    with Adam (lr=0.001) and early stopping on validation accuracy, and the
    weights from the best validation epoch are restored before the final
    evaluation. Progress, final accuracies and up to five decoded validation
    examples are printed.

    Args:
        noise_mode: ``"independent"`` or ``"biased"``; forwarded to the datasets.
        enhanced_input: Whether measurements are raw bit pairs instead of tokens.
        use_separate_embeddings: Whether each bit gets its own embedding table.
        noise_prob, noise_prob1, noise_prob2: Flip probabilities forwarded to
            the datasets.
        num_train, num_val: Number of training / validation samples.
        sequence_length: Number of measurements per sample.
        num_epochs: Maximum number of training epochs.
        d_model, nhead, num_layers, dim_feedforward, dropout: Model
            hyper-parameters forwarded to ``SyndromeTransformer``.
        weight_decay: Weight decay passed to Adam.
        patience: Number of consecutive non-improving epochs that triggers
            early stopping.

    Returns:
        ``(transformer_acc, classical_acc)``: validation accuracy of the
        restored best model and of the majority-vote decoder.
    """
    mode_str = "Separate Bit Embeddings" if use_separate_embeddings else ("Enhanced" if enhanced_input else "Traditional")
    print(f"=== Experiment ({mode_str}): noise_mode = {noise_mode} ===")

    # Both splits share every generation setting except their size.
    dataset_kwargs = dict(sequence_length=sequence_length,
                          noise_mode=noise_mode, noise_prob=noise_prob,
                          noise_prob1=noise_prob1, noise_prob2=noise_prob2,
                          enhanced_input=enhanced_input, separate_bit_embeddings=use_separate_embeddings)
    train_dataset = SyndromeDataset(num_samples=num_train, **dataset_kwargs)
    val_dataset = SyndromeDataset(num_samples=num_val, **dataset_kwargs)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SyndromeTransformer(enhanced_input=enhanced_input, use_separate_embeddings=use_separate_embeddings,
                                vocab_size=4, d_model=d_model, nhead=nhead, num_layers=num_layers,
                                dim_feedforward=dim_feedforward, num_classes=4, dropout=dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=weight_decay)

    # Start below any attainable accuracy so the first epoch always records a
    # snapshot, even if its accuracy is exactly 0.
    best_val_acc = float("-inf")
    best_epoch = 0
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(num_epochs):
        train_loss = train_model(model, train_loader, criterion, optimizer, device)
        val_acc = evaluate_model(model, val_loader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Transformer Val Accuracy: {val_acc*100:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors, which the
            # optimizer keeps updating in place; clone them so the snapshot
            # really holds this epoch's weights.
            best_model_state = {name: tensor.detach().clone()
                                for name, tensor in model.state_dict().items()}
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} (best epoch: {best_epoch+1}, best accuracy: {best_val_acc*100:.2f}%)")
            break

    # Restore the best weights whether training stopped early or ran through
    # every epoch; there is no snapshot only when num_epochs == 0.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    transformer_acc = evaluate_model(model, val_loader, device)
    classical_acc = evaluate_classical(val_dataset, enhanced_input, use_separate_embeddings)
    print(f"\nFinal Transformer Val Accuracy: {transformer_acc*100:.2f}%")
    print(f"Classical Decoder (Majority Vote) Val Accuracy: {classical_acc*100:.2f}%")

    print("\nTesting examples:")
    inv_token_map = {v: k for k, v in token_map.items()}
    model.eval()  # single-sample inference below must not apply dropout
    # Never index past the end of a validation set smaller than five samples.
    for i in range(min(5, len(val_dataset))):
        sample, true_label = val_dataset[i]
        sample_tensor = sample.unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(sample_tensor)
            _, pred_transformer = torch.max(output, 1)
        pred_classical = classical_decoder(sample, enhanced_input, use_separate_embeddings)
        if use_separate_embeddings:
            decoded_seq = ["".join(str(x) for x in vec.tolist()) for vec in sample]
        elif enhanced_input:
            decoded_seq = ["".join(str(int(b)) for b in vec.tolist()) for vec in sample]
        else:
            decoded_seq = ["".join(str(b) for b in inv_token_map[int(token)]) for token in sample.tolist()]
        print(f"Sample {i+1}:")
        print(" Syndrome Sequence:", decoded_seq)
        print(" True Correction:", true_label.item(), " (", ideal_syndromes[true_label.item()],")")
        print(" Transformer Predicted Correction:", pred_transformer.item(), " (", ideal_syndromes[pred_transformer.item()],")")
        print(" Classical Predicted Correction:", pred_classical, " (", ideal_syndromes[pred_classical],")")
        print("------")
    print("\n")
    return transformer_acc, classical_acc

# -------------------------------
# Main routine: Run experiments.
# -------------------------------
def main():
    """Run the independent-noise and biased-noise experiments and print a summary."""
    # Settings common to both experiments: enhanced input with separate bit
    # embeddings, identical data sizes, model size and training schedule.
    shared_settings = dict(
        enhanced_input=True,
        use_separate_embeddings=True,
        num_train=5000,
        num_val=1000,
        sequence_length=20,
        num_epochs=20,
        d_model=128, nhead=8, num_layers=4, dim_feedforward=256, dropout=0.1,
        weight_decay=1e-5,
        patience=5,
    )

    # Experiment 1: both syndrome bits flip with the same probability.
    transformer_acc_ind, classical_acc_ind = run_experiment(
        noise_mode="independent",
        noise_prob=0.4,
        **shared_settings,
    )

    # Experiment 2: the first bit is noisier than the second.
    transformer_acc_bias, classical_acc_bias = run_experiment(
        noise_mode="biased",
        noise_prob=0.2,      # not used in biased mode
        noise_prob1=0.3,     # higher flip probability for the first bit
        noise_prob2=0.1,     # lower flip probability for the second bit
        **shared_settings,
    )

    print("Summary:")
    print(f"Independent Noise -> Transformer: {transformer_acc_ind*100:.2f}%, Classical: {classical_acc_ind*100:.2f}%")
    print(f"Biased Noise      -> Transformer: {transformer_acc_bias*100:.2f}%, Classical: {classical_acc_bias*100:.2f}%")

if __name__ == "__main__":
    main()


=== Experiment (Separate Bit Embeddings): noise_mode = independent ===
Epoch 1/20, Loss: 0.9463, Transformer Val Accuracy: 62.30%
Epoch 2/20, Loss: 0.8473, Transformer Val Accuracy: 62.50%
Epoch 3/20, Loss: 0.8275, Transformer Val Accuracy: 65.00%
Epoch 4/20, Loss: 0.8327, Transformer Val Accuracy: 68.00%
Epoch 5/20, Loss: 0.8337, Transformer Val Accuracy: 65.10%
Epoch 6/20, Loss: 0.8453, Transformer Val Accuracy: 67.40%
Epoch 7/20, Loss: 0.8142, Transformer Val Accuracy: 67.20%
Epoch 8/20, Loss: 0.8127, Transformer Val Accuracy: 66.90%
Epoch 9/20, Loss: 0.8108, Transformer Val Accuracy: 68.00%
Early stopping triggered at epoch 9 (best epoch: 4, best accuracy: 68.00%)

Final Transformer Val Accuracy: 68.00%
Classical Decoder (Majority Vote) Val Accuracy: 60.40%

Testing examples:
Sample 1:
 Syndrome Sequence: ['01', '11', '11', '01', '01', '11', '10', '11', '01', '01', '00', '00', '00', '00', '11', '11', '00', '01', '01', '00']
 True Correction: 3  ( (0, 1) )
 Transformer Predicted Cor