In [37]:
import random
import string
from tqdm import tqdm

# Define constants
MAX_LENGTH = 250
OOD_MAX_LENGTH = 500
VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings
num_p = 0

# Function to generate valid a*b* strings
def generate_valid_string(min_length = 0, max_length = MAX_LENGTH):
    num_a = random.randint(0, max_length - 2)
    num_b = random.randint(min(0, min_length - num_a), max_length - 2 - num_a)
    valid_str = "a" * num_a + "b" * num_b
    # num_p = random.randint(0, max_length - len(valid_str) - 2)
    return (
        START_TOKEN
        + PADDING_TOKEN * num_p
        + valid_str
        + END_TOKEN
        + PADDING_TOKEN * ((max_length - len(valid_str) - 2) - num_p)
    )

# Function to generate invalid strings
def generate_invalid_string(min_length = 1, max_length = MAX_LENGTH):
    length = random.randint(min_length, max_length - 2)
    if length == 1:
        # num_p = random.randint(0, max_length - 2)
        return START_TOKEN + PADDING_TOKEN * num_p + "ba" + END_TOKEN + PADDING_TOKEN * (max_length - 2 - num_p)
    while True:
        # Random string of a's and b's which isn't a valid a*b* string
        invalid_str = "".join(random.choices(MAIN_CHARACTERS, k=length))
        if "ba" in invalid_str:
            break
    # num_p = random.randint(0, max_length - len(invalid_str) - 2)
    return (
        START_TOKEN
        + PADDING_TOKEN * num_p
        + invalid_str
        + END_TOKEN
        + PADDING_TOKEN * (max_length - len(invalid_str) - 2)
    )


# Generate dataset
dataset = []
num_samples = 200000  # Total number of samples

for _ in range(num_samples):
    while True:
        x = generate_valid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
valid_dataset = list(set(dataset))
dataset = []

for _ in range(num_samples):
    while True:
        x = generate_invalid_string()
        if not (len(x) == MAX_LENGTH):
            continue
        x += PADDING_TOKEN * (OOD_MAX_LENGTH - MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
invalid_dataset = list(set(dataset))[:len(valid_dataset)]
print(len(valid_dataset))
print(len(invalid_dataset))

split = len(valid_dataset) * 4 // 5;
train_dataset = valid_dataset[:split] + invalid_dataset[:split]
test_dataset = valid_dataset[split:] + invalid_dataset[split:]

num_ood_samples = 10000
dataset = []
for _ in range(num_ood_samples):
    while True:
        x = generate_valid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 1))
        break

# remove duplicates
ood_valid_dataset = list(set(dataset))
dataset = []

for _ in range(num_ood_samples):
    while True:
        x = generate_invalid_string(min_length=MAX_LENGTH + 2, max_length=OOD_MAX_LENGTH)
        if not (len(x) == OOD_MAX_LENGTH):
            continue
        dataset.append((x, 0))
        break

# Remove all duplicates
ood_invalid_dataset = list(set(dataset))[:len(ood_valid_dataset)]
print(len(ood_valid_dataset), len(ood_invalid_dataset))

ood_dataset = ood_valid_dataset + ood_invalid_dataset
print(len(train_dataset))
print(len(test_dataset))
print(len(ood_dataset))


# Write to file
with open("train_dataset_nopadded.txt", "w") as f:
    for data, label in train_dataset:
        f.write(f"{data} {label}\n")

with open("test_dataset_nopadded.txt", "w") as f:
    for data, label in test_dataset:
        f.write(f"{data} {label}\n")

with open("ood_dataset_nopadded.txt", "w") as f:
    for data, label in ood_dataset:
        f.write(f"{data} {label}\n")

29855
29855
7471 7471
47768
11942
14942


In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Define constants for model
VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings
VOCAB_SIZE = len(VALID_CHARACTERS)
EMBEDDING_DIM = 6
NUM_HEADS = 2
NUM_LAYERS = 1
HIDDEN_DIM = 1
BATCH_SIZE = 512
EPOCHS = 5

# Mapping characters to indices
char_to_index = {ch: idx for idx, ch in enumerate(VALID_CHARACTERS)}

# Custom dataset class
class StringDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        self.labels = []
        with open(file_path, "r") as f:
            for line in f:
                parts = line.strip().split(" ")
                self.data.append(parts[0])
                self.labels.append(int(parts[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        string = self.data[idx]
        label = self.labels[idx]
        encoded = self.encode_string(string)
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(
            label, dtype=torch.float32
        )

    def encode_string(self, string):
        return [char_to_index[char] for char in string]


# Transformer model
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, hidden_dim, num_layers):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, OOD_MAX_LENGTH, embedding_dim))
        encoder_layers = nn.TransformerEncoderLayer(
            embedding_dim, num_heads, hidden_dim
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(OOD_MAX_LENGTH * embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoder
        x = self.transformer_encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.sigmoid(x)

# Prepare dataset and dataloader
dataset = StringDataset("train_dataset_nopadded.txt")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, loss function, and optimizer
model = TransformerClassifier(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, HIDDEN_DIM, NUM_LAYERS
)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu" and torch.backends.mps.is_available():
    device = torch.device("mps")
model.to(device)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Training loop
for epoch in range(EPOCHS):
    model.train()
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs.squeeze(), labels.to(device))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}")
    # Evaluate the model
    model.eval()
    correct = 0
    total = 0
    for inputs, labels in dataloader:
        outputs = model(inputs.to(device))
        predicted = torch.round(outputs)
        total += labels.size(0)
        correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
    print(f"Accuracy: {correct/total}")

# Save the trained model
torch.save(model.state_dict(), "1head_1layer_embed6batch512hidden1_100max200ood0pad_total_transformer_model.pth")

test_dataset = StringDataset("test_dataset_nopadded.txt")
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in test_dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")

Epoch 1/5, Loss: 0.38084810972213745
Accuracy: 0.9324652487020599
Epoch 2/5, Loss: 0.2495858073234558
Accuracy: 0.9465960475632222
Epoch 3/5, Loss: 0.18593522906303406
Accuracy: 0.955870038519511
Epoch 4/5, Loss: 0.2120962142944336
Accuracy: 0.9533369619829174
Epoch 5/5, Loss: 0.27028077840805054
Accuracy: 0.9480405292245855
Epoch 1/5, Loss: 0.15211986005306244
Accuracy: 0.9499665047730699
Epoch 2/5, Loss: 0.13904444873332977
Accuracy: 0.9487104337631888
Epoch 3/5, Loss: 0.08156761527061462
Accuracy: 0.9491291240998158
Epoch 4/5, Loss: 0.18180182576179504
Accuracy: 0.9518087422542287
Epoch 5/5, Loss: 0.22958149015903473
Accuracy: 0.950217718975046
Accuracy: 0.9450678278345336


In [39]:
ood_dataset = StringDataset("ood_dataset_nopadded.txt")
ood_dataloader = DataLoader(ood_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model.eval()
correct = 0
total = 0
for inputs, labels in ood_dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")

Accuracy: 0.6351893990095034


In [None]:
# Print 20 examples which are wrongly classified
print("20 examples which are wrongly classified")
count = 0
for inputs, labels in dataloader:
    outputs = model(inputs.to(device))
    predicted = torch.round(outputs)
    for i in range(len(predicted)):
        if count == 20:
            break
        if predicted[i] != labels[i]:
            # Convert back to string of a's and b's
            string = "".join([VALID_CHARACTERS[int(idx)] for idx in inputs[i]])
            print(string, labels[i].item(), predicted[i].item())
            count += 1

20 examples which are wrongly classified
saabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbepppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppppp

In [64]:
import numpy as np
model_2 = TransformerClassifier(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, HIDDEN_DIM, NUM_LAYERS
)

model_2.load_state_dict(torch.load('1head_1layer_embed6batch512hidden1_100max500ood_total_transformer_model.pth'))
model_2.to(device)

# Print 20 examples which are wrongly classified
# print("20 examples which are wrongly classified")
count = 0
matrix = np.zeros((2, 2))
num_pad_zeros = np.zeros((500))
for inputs, labels in dataloader:
    outputs = model_2(inputs.to(device))
    predicted = torch.round(outputs)
    for i in range(len(predicted)):
        matrix[int(predicted[i]), int(labels[i])] += 1
        if predicted[i] != labels[i]:
            # Convert back to string of a's and b's
            num_start_pad = 0
            for j in range(len(inputs[i])):
                if VALID_CHARACTERS[int(inputs[i][j])] != 'p':
                    break
                num_start_pad += 1
            num_pad_zeros[num_start_pad] += 1
print(matrix)
print(num_pad_zeros)

[[2.3878e+04 1.2457e+04]
 [6.0000e+00 1.1427e+04]]
[12463.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
     0.     0.     0.     

In [59]:
ood_dataset = StringDataset("ood_dataset_gappadded.txt")
ood_dataloader = DataLoader(ood_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Print accuracy of the model
model_2.eval()
correct = 0
total = 0
for inputs, labels in ood_dataloader:
    outputs = model_2(inputs.to(device))
    predicted = torch.round(outputs)
    total += labels.size(0)
    correct += (predicted.squeeze().to(device) == labels.to(device)).sum().item()
print(f"Accuracy: {correct/total}")

Accuracy: 0.9748629848229342


In [50]:
print(model_2.transformer_encoder.layers[0].self_attn.in_proj_weight)

Parameter containing:
tensor([[-0.1173, -0.4868, -0.2787, -0.3836,  0.4869,  0.4462],
        [-0.3077,  0.1080, -0.0703,  0.2785,  0.0020,  0.3271],
        [-0.3017, -0.1836, -0.0864, -0.2321, -0.2330, -0.4434],
        [-0.0861,  0.6321, -0.4590, -0.3281, -0.5029, -0.4216],
        [ 0.4099, -0.1211, -0.1174, -0.1267, -0.3700, -0.1985],
        [ 0.3000,  0.4385, -0.3269,  0.3167,  0.4170, -0.3065],
        [ 0.1610, -0.0720,  0.3529,  0.4052,  0.0425,  0.2735],
        [-0.3191, -0.0640,  0.2417, -0.0950,  0.0050, -0.3506],
        [-0.3651, -0.1423, -0.1854, -0.2698,  0.1397,  0.0741],
        [ 0.4333,  0.2350, -0.0254,  0.4441,  0.2188,  0.1552],
        [ 0.1111, -0.1373, -0.0707,  0.4105,  0.1355,  0.3309],
        [-0.2359, -0.2420, -0.0904,  0.2508,  0.3944, -0.0200],
        [-0.2139,  0.3799, -0.1082, -0.3886,  0.0216, -0.0505],
        [-0.2125,  0.3423, -0.3185,  0.3012, -0.3671,  0.2281],
        [-0.4219,  0.1814,  0.2821,  0.0127,  0.1133,  0.3603],
        [-0.2622, 

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

VALID_CHARACTERS = ["s", "a", "b", "e", "p"]
MAIN_CHARACTERS = ["a", "b"]
START_TOKEN = "s"
END_TOKEN = "e"
PADDING_TOKEN = "p"
VALID_RATIO = 0.5  # Half of the dataset should be valid a*b* strings
VOCAB_SIZE = len(VALID_CHARACTERS)
EMBEDDING_DIM = 6
NUM_HEADS = 2
NUM_LAYERS = 1
HIDDEN_DIM = 1
BATCH_SIZE = 512
EPOCHS = 5
MAX_LENGTH = 200
OOD_MAX_LENGTH = 400

class TransformerEncoderLayerWithAttention(nn.TransformerEncoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super(TransformerEncoderLayerWithAttention, self).__init__(d_model, nhead, dim_feedforward, dropout, activation)

    def forward(
            self,
            src,
            src_mask = None,
            src_key_padding_mask = None,
            is_causal: bool = False):
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif self.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = "self_attn was passed bias=False"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.device.type in _supported_device_type) for x in tensor_args):
                why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
                                              f"{_supported_device_type}")
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )


        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

class TransformerEncoderWithAttention(nn.TransformerEncoder):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoderWithAttention, self).__init__(encoder_layer, num_layers, norm)

    def forward(self, src, mask=None, src_key_padding_mask=None):
        output = src

        attentions = []  # List to store attention scores from each layer

        for mod in self.layers:
            output, attention_scores = mod(output)
            attentions.append(attention_scores)

        if self.norm is not None:
            output = self.norm(output)

        return output, attentions

# Transformer model
class TransformerDebugClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, hidden_dim, num_layers):
        super(TransformerDebugClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = nn.Parameter(torch.zeros(1, OOD_MAX_LENGTH, embedding_dim))
        encoder_layer = TransformerEncoderLayerWithAttention(
            embedding_dim, num_heads, hidden_dim
        )
        self.transformer_encoder = TransformerEncoderWithAttention(encoder_layer, num_layers)
        self.fc = nn.Linear(OOD_MAX_LENGTH * embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        print(x.shape)
        x = self.embedding(x) + self.pos_encoder
        print(x.shape)
        x, attn = self.transformer_encoder(x)
        print(x.shape)
        print(attn)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.sigmoid(x)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Mapping characters to indices
char_to_index = {ch: idx for idx, ch in enumerate(VALID_CHARACTERS)}

# Custom dataset class
class StringDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        self.labels = []
        with open(file_path, "r") as f:
            for line in f:
                parts = line.strip().split(" ")
                self.data.append(parts[0])
                self.labels.append(int(parts[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        string = self.data[idx]
        label = self.labels[idx]
        encoded = self.encode_string(string)
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(
            label, dtype=torch.float32
        )

    def encode_string(self, string):
        return [char_to_index[char] for char in string]

model_3 = TransformerDebugClassifier(
    VOCAB_SIZE, EMBEDDING_DIM, NUM_HEADS, HIDDEN_DIM, NUM_LAYERS
)

model_3.load_state_dict(torch.load('1head_1layer_embed6batch512hidden1_200max400ood_total_transformer_model.pth', map_location=torch.device('cpu')))

ood_dataset = StringDataset("ood_dataset_padded.txt")
ood_dataloader = DataLoader(ood_dataset, batch_size=BATCH_SIZE, shuffle=True)
for inputs, labels in ood_dataloader:
    outputs = model_3(inputs)
    break


torch.Size([512, 400])
torch.Size([512, 400, 6])


AttributeError: 'MultiheadAttention' object has no attribute 'data'