In [1]:
import random
import string
from tqdm import tqdm

# Define constants
MAX_LENGTH = 200
OOD_MAX_LENGTH = 400
VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings


# Function to generate valid a*b* strings
def generate_valid_string(min_length = 0, max_length = MAX_LENGTH):
    num_a = random.randint(0, max_length - 2)
    num_b = random.randint(min(0, min_length - num_a), max_length - 2 - num_a)
    valid_str = "a" * num_a + "b" * num_b
    num_p = random.randint(0, max_length - len(valid_str) - 2)
    return (
        START_TOKEN
        + PADDING_TOKEN * num_p
        + valid_str
        + END_TOKEN
        + PADDING_TOKEN * ((max_length - len(valid_str) - 2) - num_p)
    )

# Function to generate invalid strings
def generate_invalid_string(min_length = 1, max_length = MAX_LENGTH):
    length = random.randint(min_length, max_length - 2)
    if length == 1:
        num_p = random.randint(0, max_length - 2)
        return START_TOKEN + PADDING_TOKEN * num_p + "ba" + END_TOKEN + PADDING_TOKEN * (max_length - 2 - num_p)
    while True:
        # Random string of a's and b's which isn't a valid a*b* string
        invalid_str = "".join(random.choices(MAIN_CHARACTERS, k=length))
        if "ba" in invalid_str:
            break
    num_p = random.randint(0, max_length - len(invalid_str) - 2)
    return (
        START_TOKEN
        + PADDING_TOKEN * num_p
        + invalid_str
        + END_TOKEN
        + PADDING_TOKEN * (max_length - len(invalid_str) - 2)
    )


# Generate dataset
dataset = []
num_samples = 100000  # Total number of samples

for _ in range(num_samples):
    while True:
        x = generate_valid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
valid_dataset = list(set(dataset))
dataset = []

for _ in range(num_samples):
    while True:
        x = generate_invalid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
invalid_dataset = list(set(dataset))[:len(valid_dataset)]
print(len(valid_dataset))
print(len(invalid_dataset))

split = len(valid_dataset) * 4 // 5;
train_dataset = valid_dataset[:split] + invalid_dataset[:split]
test_dataset = valid_dataset[split:] + invalid_dataset[split:]

num_ood_samples = 10000
dataset = []
for _ in range(num_ood_samples):
    while True:
        x = generate_valid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
ood_valid_dataset = list(set(dataset))
dataset = []

for _ in range(num_ood_samples):
    while True:
        x = generate_invalid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
ood_invalid_dataset = list(set(dataset))[:len(ood_valid_dataset)]
print(len(ood_valid_dataset), len(ood_invalid_dataset))

ood_dataset = ood_valid_dataset + ood_invalid_dataset
print(len(train_dataset))
print(len(test_dataset))
print(len(ood_dataset))


# Write to file
with open("train_dataset.txt", "w") as f:
    for data, label in train_dataset:
        f.write(f"{data} {label}\n")

with open("test_dataset.txt", "w") as f:
    for data, label in test_dataset:
        f.write(f"{data} {label}\n")

with open("ood_dataset.txt", "w") as f:
    for data, label in ood_dataset:
        f.write(f"{data} {label}\n")

59440
59440
9438 9438
95104
23776
18876


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Define constants for model
VOCAB_SIZE = len(VALID_CHARACTERS)
EMBEDDING_DIM = 6
NUM_HEADS = 2
NUM_LAYERS = 1
HIDDEN_DIM = 2
BATCH_SIZE = 512
EPOCHS = 5

# Mapping characters to indices
char_to_index = {ch: idx for idx, ch in enumerate(VALID_CHARACTERS)}

# Custom dataset class
class StringDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        self.labels = []
        with open(file_path, "r") as f:
            for line in f:
                parts = line.strip().split(" ")
                self.data.append(parts[0])
                self.labels.append(int(parts[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        string = self.data[idx]
        label = self.labels[idx]
        encoded = self.encode_string(string)
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(
            label, dtype=torch.float32
        )

    def encode_string(self, string):
        return [char_to_index[char] for char in string]


# Transformer model
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, hidden_dim, num_layers):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, OOD_MAX_LENGTH, embedding_dim))
        encoder_layers = nn.TransformerEncoderLayer(
            embedding_dim, num_heads, hidden_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(OOD_MAX_LENGTH * embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoder
        x = self.transformer_encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.sigmoid(x)

# Prepare dataset and dataloader
dataset = StringDataset("train_dataset.txt")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, loss function, and optimizer
model = TransformerClassifier(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, HIDDEN_DIM, NUM_LAYERS
)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu" and torch.backends.mps.is_available():
    device = torch.device("mps")
model.to(device)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")

# Save the trained model
torch.save(model.state_dict(), "shortrandloc200_400total_transformer_model.pth")

test_dataset = StringDataset("test_dataset.txt")
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in test_dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")



Epoch 1/5, Loss: 0.07164396345615387
Accuracy: 0.9886124663526245
Epoch 2/5, Loss: 0.03961213678121567
Accuracy: 0.9935228802153432
Epoch 3/5, Loss: 0.01588316820561886
Accuracy: 0.9962882738896366
Epoch 4/5, Loss: 0.0081285759806633
Accuracy: 0.9956468707940781
Epoch 5/5, Loss: 0.014825839549303055
Accuracy: 0.996687836473755
Epoch 1/5, Loss: 0.0060550919733941555
Accuracy: 0.9971925471063257
Epoch 2/5, Loss: 0.012896250933408737
Accuracy: 0.9967614401076716
Epoch 3/5, Loss: 0.014983288012444973
Accuracy: 0.9972135767160162
Epoch 4/5, Loss: 0.027989424765110016
Accuracy: 0.9968140141318977
Epoch 5/5, Loss: 0.007835699245333672
Accuracy: 0.9973607839838493
Accuracy: 0.9968455585464334


In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [4]:
ood_dataset = StringDataset("ood_dataset.txt")
ood_dataloader = DataLoader(ood_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in ood_dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")

Accuracy: 0.9812460267005721


In [5]:
# Print 20 examples which are wrongly classified
print("20 examples which are wrongly classified")
count = 0
for inputs, labels in dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    for i in range(len(predicted)):
        if count == 20:
            break
        if predicted[i] != labels[i]:
            # Convert back to string of a's and b's
            string = "".join([VALID_CHARACTERS[int(idx)] for idx in inputs[i]])
            print(string, labels[i].item(), predicted[i].item())
            count += 1

20 examples which are wrongly classified
sbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbeppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp 1.0 0.0
saaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaeppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp 1.0 0.0
sppbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb