In [None]:
import random
import string
from tqdm import tqdm

# Define constants
MAX_LENGTH = 500
OOD_MAX_LENGTH = 1500
VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings


# Function to generate valid a*b* strings
def generate_valid_string(min_length = 0, max_length = MAX_LENGTH):
    num_a = random.randint(0, max_length - 2)
    num_b = random.randint(min(0, min_length - num_a), max_length - 2 - num_a)
    valid_str = "a" * num_a + "b" * num_b
    return (
        START_TOKEN
        + valid_str
        + END_TOKEN
        + PADDING_TOKEN * (max_length - len(valid_str) - 2)
    )

# Function to generate invalid strings
def generate_invalid_string(min_length = 1, max_length = MAX_LENGTH):
    length = random.randint(min_length, max_length - 2)
    if length == 1:
        return START_TOKEN + "ba" + END_TOKEN + PADDING_TOKEN * (max_length - 2)
    while True:
        # Random string of a's and b's which isn't a valid a*b* string
        invalid_str = "".join(random.choices(MAIN_CHARACTERS, k=length))
        if "ba" in invalid_str:
            break
    return (
        START_TOKEN
        + invalid_str
        + END_TOKEN
        + PADDING_TOKEN * (max_length - len(invalid_str) - 2)
    )


# Generate dataset
dataset = []
num_samples = 2000000  # Total number of samples

for _ in range(num_samples):
    while True:
        x = generate_valid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
valid_dataset = list(set(dataset))
dataset = []

for _ in range(num_samples):
    while True:
        x = generate_invalid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
invalid_dataset = list(set(dataset))[:len(valid_dataset)]
print(len(valid_dataset))
print(len(invalid_dataset))

split = len(valid_dataset) * 4 // 5;
train_dataset = valid_dataset[:split] + invalid_dataset[:split]
test_dataset = valid_dataset[split:] + invalid_dataset[split:]

num_ood_samples = 50000
dataset = []
for _ in range(num_ood_samples):
    while True:
        x = generate_valid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
ood_valid_dataset = list(set(dataset))
dataset = []

for _ in range(num_ood_samples):
    while True:
        x = generate_invalid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
ood_invalid_dataset = list(set(dataset))[:len(ood_valid_dataset)]
print(len(ood_valid_dataset), len(ood_invalid_dataset))

ood_dataset = ood_valid_dataset + ood_invalid_dataset
print(len(train_dataset))
print(len(test_dataset))
print(len(ood_dataset))


# Write to file
with open("train_dataset.txt", "w") as f:
    for data, label in train_dataset:
        f.write(f"{data} {label}\n")

with open("test_dataset.txt", "w") as f:
    for data, label in test_dataset:
        f.write(f"{data} {label}\n")

with open("ood_dataset.txt", "w") as f:
    for data, label in ood_dataset:
        f.write(f"{data} {label}\n")

124710
124710
33821 33821
199536
49884
67642


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Define constants for model
VOCAB_SIZE = len(VALID_CHARACTERS)
EMBEDDING_DIM = 6
NUM_HEADS = 2
NUM_LAYERS = 1
HIDDEN_DIM = 2
BATCH_SIZE = 512
EPOCHS = 15

# Mapping characters to indices
char_to_index = {ch: idx for idx, ch in enumerate(VALID_CHARACTERS)}

# Custom dataset class
class StringDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        self.labels = []
        with open(file_path, "r") as f:
            for line in f:
                parts = line.strip().split(" ")
                self.data.append(parts[0])
                self.labels.append(int(parts[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        string = self.data[idx]
        label = self.labels[idx]
        encoded = self.encode_string(string)
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(
            label, dtype=torch.float32
        )

    def encode_string(self, string):
        return [char_to_index[char] for char in string]


# Transformer model
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, hidden_dim, num_layers):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, OOD_MAX_LENGTH, embedding_dim))
        encoder_layers = nn.TransformerEncoderLayer(
            embedding_dim, num_heads, hidden_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(OOD_MAX_LENGTH * embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoder
        x = self.transformer_encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.sigmoid(x)

# Prepare dataset and dataloader
dataset = StringDataset("train_dataset.txt")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, loss function, and optimizer
model = TransformerClassifier(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, HIDDEN_DIM, NUM_LAYERS
)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu" and torch.backends.mps.is_available():
    device = torch.device("mps")
model.to(device)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")

# Save the trained model
torch.save(model.state_dict(), "short70_200total_transformer_model.pth")

test_dataset = StringDataset("test_dataset.txt")
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in test_dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")



Epoch 1/15, Loss: 0.11898428946733475
Accuracy: 0.9723608772351856
Epoch 2/15, Loss: 0.12191472202539444
Accuracy: 0.9675296688316896
Epoch 3/15, Loss: 0.12137658149003983
Accuracy: 0.9611949723358191
Epoch 4/15, Loss: 0.14996561408042908
Accuracy: 0.9738944350894074
Epoch 5/15, Loss: 0.20408190786838531
Accuracy: 0.9731827840590169
Epoch 6/15, Loss: 0.15650595724582672
Accuracy: 0.9643673322107289
Epoch 7/15, Loss: 0.1393161565065384
Accuracy: 0.9728720631865929
Epoch 8/15, Loss: 0.12338759005069733
Accuracy: 0.9652343436773314
Epoch 9/15, Loss: 0.1447402685880661
Accuracy: 0.9661264132788069
Epoch 10/15, Loss: 0.2224826067686081
Accuracy: 0.9658507737952049
Epoch 11/15, Loss: 0.082332082092762
Accuracy: 0.9660362039932644
Epoch 12/15, Loss: 0.1062888503074646
Accuracy: 0.9678905059738594
Epoch 13/15, Loss: 0.14177195727825165
Accuracy: 0.9686422500200466
Epoch 14/15, Loss: 0.10272542387247086
Accuracy: 0.9703963194611499
Epoch 15/15, Loss: 0.09556952863931656
Accuracy: 0.972315772592

In [None]:
ood_dataset = StringDataset("ood_dataset.txt")
ood_dataloader = DataLoader(ood_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in ood_dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")

Accuracy: 0.6400756926170131


In [None]:
# Print 20 examples which are wrongly classified
print("20 examples which are wrongly classified")
count = 0
for inputs, labels in dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    for i in range(len(predicted)):
        if count == 20:
            break
        if predicted[i] != labels[i]:
            # Convert back to string of a's and b's
            string = "".join([VALID_CHARACTERS[int(idx)] for idx in inputs[i]])
            print(string, labels[i].item(), predicted[i].item())
            count += 1

20 examples which are wrongly classified
saabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbepppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp