In [1]:
import random
import string
from tqdm import tqdm

# Define constants
MAX_LENGTH = 50
VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings


# Function to generate valid a*b* strings
def generate_valid_string():
    num_a = random.randint(0, MAX_LENGTH - 2)
    num_b = random.randint(0, MAX_LENGTH - 2 - num_a)
    valid_str = "a" * num_a + "b" * num_b
    return (
        START_TOKEN
        + valid_str
        + END_TOKEN
        + PADDING_TOKEN * (MAX_LENGTH - len(valid_str) - 2)
    )


# Function to generate invalid strings
def generate_invalid_string():
    length = random.randint(1, MAX_LENGTH - 2)
    if length == 1:
        return START_TOKEN + "ba" + END_TOKEN + PADDING_TOKEN * (MAX_LENGTH - 2)
    while True:
        # Random string of a's and b's which isn't a valid a*b* string
        invalid_str = "".join(random.choices(MAIN_CHARACTERS, k=length))
        if "ba" in invalid_str:
            break
    return (
        START_TOKEN
        + invalid_str
        + END_TOKEN
        + PADDING_TOKEN * (MAX_LENGTH - len(invalid_str) - 2)
    )


# Generate dataset
dataset = []
num_samples = 10000  # Total number of samples
num_valid_samples = 15000  # int(VALID_RATIO * num_samples)
num_invalid_samples = 1500  # num_samples - num_valid_samples

for _ in range(num_valid_samples):
    while True:
        x = generate_valid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

for _ in range(num_invalid_samples):
    while True:
        x = generate_invalid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
dataset = list(set(dataset))

# Write to file
with open("dataset.txt", "w") as f:
    for data, label in dataset:
        f.write(f"{data} {label}\n")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Define constants for model
VOCAB_SIZE = len(VALID_CHARACTERS)
EMBEDDING_DIM = 6
NUM_HEADS = 2
NUM_LAYERS = 1
HIDDEN_DIM = 2
BATCH_SIZE = 512
EPOCHS = 30

# Mapping characters to indices
char_to_index = {ch: idx for idx, ch in enumerate(VALID_CHARACTERS)}

# Custom dataset class
class StringDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        self.labels = []
        with open(file_path, "r") as f:
            for line in f:
                parts = line.strip().split(" ")
                self.data.append(parts[0])
                self.labels.append(int(parts[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        string = self.data[idx]
        label = self.labels[idx]
        encoded = self.encode_string(string)
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(
            label, dtype=torch.float32
        )

    def encode_string(self, string):
        return [char_to_index[char] for char in string]


# Transformer model
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, hidden_dim, num_layers):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, MAX_LENGTH, embedding_dim))
        encoder_layers = nn.TransformerEncoderLayer(
            embedding_dim, num_heads, hidden_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(MAX_LENGTH * embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoder
        x = self.transformer_encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.sigmoid(x)


# Prepare dataset and dataloader
dataset = StringDataset("dataset.txt")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, loss function, and optimizer
model = TransformerClassifier(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, HIDDEN_DIM, NUM_LAYERS
)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu" and torch.backends.mps.is_available():
    device = torch.device("mps")
model.to(device)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")

# Save the trained model
torch.save(model.state_dict(), "transformer_model.pth")

# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")

Epoch 1/30, Loss: 0.645404577255249
Accuracy: 0.6271896420411271
Epoch 2/30, Loss: 0.5931392312049866
Accuracy: 0.7010662604722011
Epoch 3/30, Loss: 0.61623615026474
Accuracy: 0.7098248286367098
Epoch 4/30, Loss: 0.4925661087036133
Accuracy: 0.7402894135567403
Epoch 5/30, Loss: 0.4830002784729004
Accuracy: 0.7635186595582635
Epoch 6/30, Loss: 0.5012676119804382
Accuracy: 0.7924600152322925
Epoch 7/30, Loss: 0.45602673292160034
Accuracy: 0.801980198019802
Epoch 8/30, Loss: 0.521774172782898
Accuracy: 0.8111195734958111
Epoch 9/30, Loss: 0.3749580681324005
Accuracy: 0.8164508758568164
Epoch 10/30, Loss: 0.4207557439804077
Accuracy: 0.8229246001523229
Epoch 11/30, Loss: 0.39253461360931396
Accuracy: 0.8282559025133283
Epoch 12/30, Loss: 0.35285770893096924
Accuracy: 0.8305407463823306
Epoch 13/30, Loss: 0.3151388466358185
Accuracy: 0.8358720487433359
Epoch 14/30, Loss: 0.4389675557613373
Accuracy: 0.8408225437928408
Epoch 15/30, Loss: 0.3898017406463623
Accuracy: 0.8465346534653465
Epoch 

In [3]:
# Print 20 examples which are wrongly classified
print("20 examples which are wrongly classified")
count = 0
for inputs, labels in dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    for i in range(len(predicted)):
        if count == 20:
            break
        if predicted[i] != labels[i]:
            # Convert back to string of a's and b's
            string = "".join([VALID_CHARACTERS[int(idx)] for idx in inputs[i]])
            print(string, labels[i].item(), predicted[i].item())
            count += 1

20 examples which are wrongly classified
saaaabbbbepppppppppppppppppppppppppppppppppppppppp 1.0 0.0
sabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbepppppppppppppp 1.0 0.0
sbbbbbbbbbbbbbbbbbbbbepppppppppppppppppppppppppppp 1.0 0.0
saabbaabbaabbbbbbabbbbbabaaaabbbbbaabepppppppppppp 0.0 1.0
saaaabababbabbaaeppppppppppppppppppppppppppppppppp 0.0 1.0
saaabbabaabbaaaabbbaababbabbeppppppppppppppppppppp 0.0 1.0
saaabbbbbbbbeppppppppppppppppppppppppppppppppppppp 1.0 0.0
saabbbbbbbbbbbbbbbbbeppppppppppppppppppppppppppppp 1.0 0.0
sababaabbaabbababbbbaabbbabaaabbbbbbeppppppppppppp 0.0 1.0
saaaaaaabbbbbabaaabbaabaaabaeppppppppppppppppppppp 0.0 1.0
saaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbepppp 1.0 0.0
saaaabbbbabbabbbabababbbababaaaaaepppppppppppppppp 0.0 1.0
sbaaaababaaaabbababbaababbbbbabbbbbabaabbaaabbbbep 0.0 1.0
saaaaabbaaaabbabbbabaaaabaeppppppppppppppppppppppp 0.0 1.0
sabbbbbbbbbbbbbbbbbbbbbbeppppppppppppppppppppppppp 1.0 0.0
saaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaeppppppppppp 1.0 0.0
saaabbbbbbbeppp