In [1]:
from preprocessor import Preprocessor as sp

In [2]:
import re
def clean_sentence(sentence):
  pattern = r'[^A-Za-z#.\'!,\-:;\"? ]'
  return re.sub(pattern, '', sentence)

import numpy as np

def one_hot_encode(text):
    # Define the vocabulary
    vocab = list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ#.\'!,\-:;\"?')
    vocab_size = len(vocab)

    # Create a mapping from character to index
    char_to_index = {char: idx for idx, char in enumerate(vocab)}

    # Initialize the one-hot encoded array
    one_hot_encoded = np.zeros((len(text), vocab_size), dtype=int)

    # Convert each character to one-hot encoded vector
    for i, char in enumerate(text):
        if char in char_to_index:  # Ensure character is in the vocabulary
            one_hot_encoded[i, char_to_index[char]] = 1
        else:
            raise ValueError(f"Character '{char}' not in vocabulary")

    return one_hot_encoded

  vocab = list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ#.\'!,\-:;\"?')


In [3]:
import pandas as pd
num_sentences = 50_000
file_path = '/Users/delmedigo/Dev/SpaceGen/SpaceGen/train.parquet'
sentence_df = pd.read_parquet(file_path)
sentence_df = sentence_df[sentence_df.sentence.apply(lambda bytes_wrong: len(bytes_wrong) <= 500 and len(bytes_wrong) >= 5)]
sentence_df = sentence_df.sample(num_sentences)
sentence_df.drop_duplicates(inplace=True)
sentence_df['sentence'] = sentence_df['sentence'].apply(lambda sentence: clean_sentence(sentence))
text_lists = sentence_df['sentence'].tolist()
sentence_df.shape

(50000, 1)

In [4]:
import numpy as np


data = pd.DataFrame(text_lists, columns=["correct_sentence"])
data['wrong_sentence'] = data['correct_sentence'].apply(lambda text: text.replace(' ',''))
data['bytes_correct'] = data['correct_sentence'].apply(lambda text: sp.to_bytes_list(text))
data['bytes_wrong'] = data['wrong_sentence'].apply(lambda text: sp.to_bytes_list(text))
data['decision'] = data[['bytes_wrong','bytes_correct']].apply(lambda row: sp.create_decision_vector(row['bytes_wrong'], row['bytes_correct']), axis=1)
dec_dict = {'K': 0, 'I': 1}
data['decision'] = data['decision'].apply(lambda dec: [dec_dict[d] for d in dec])
data = data[data.bytes_wrong.apply(lambda bytes_wrong: len(bytes_wrong) <= 500)]
lngths = [len(bytes_wrong) for bytes_wrong in data.bytes_wrong.tolist()]
max_len = max(lngths)
data['bytes_wrong_padded'] = data['bytes_wrong'].apply(lambda bytes_wrong: bytes_wrong + [0]*(max_len-len(bytes_wrong)))
data['decision_padded'] = data['decision'].apply(lambda decision: decision + [0]*(max_len-len(decision)))
data['bytes_wrong_padded'] = data['bytes_wrong_padded'].apply(lambda bytes_wrong: np.array(bytes_wrong))
data['decision_padded'] = data['decision_padded'].apply(lambda decision: np.array(decision))
data['wrong_sentence_padded'] = data['wrong_sentence'].apply(lambda wrong_sentence: wrong_sentence + '#'*(max_len-len(wrong_sentence)))
data['bytes_wrong_one_hot'] = data['wrong_sentence_padded'].apply(one_hot_encode)
data['bytes_wrong_one_hot'] = data['bytes_wrong_one_hot'].apply(lambda bytes_wrong: np.array(bytes_wrong))

In [26]:
import torch
import tensorflow as tf
from torch import device
import torch.nn as nn
X = np.stack(data.bytes_wrong_one_hot)
y = np.stack(data.decision_padded)
X = torch.tensor(X, dtype=torch.short)
y = torch.tensor(y, dtype=torch.short)

num_classes = 2

#y = tf.keras.utils.to_categorical(y, num_classes=num_classes)

print(f'X shape: {X.shape}')
print(f'y shape: {y.shape}')

X shape: torch.Size([50000, 421, 63])
y shape: torch.Size([50000, 421])


In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Create Dataset and DataLoader
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        seq_len = x.size(1)
        encoding = self.encoding[:, :seq_len, :].to(x.device)
        return x + encoding

class TransformerEncoder(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward, max_len=500, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(d_model, 1)

    def forward(self, src):
        src = self.embedding(src.float())  # (batch_size, seq_len, d_model)
        src = self.positional_encoding(src)  # (batch_size, seq_len, d_model)
        src = self.transformer_encoder(src)  # (batch_size, seq_len, d_model)
        output = torch.sigmoid(self.fc(src))  # (batch_size, seq_len, 1)
        return output.squeeze(-1)  # (batch_size, seq_len)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.04, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, outputs, targets):
        # Ensure outputs and targets are float
        outputs = outputs.float()
        targets = targets.float()
        
        # Compute Binary Cross Entropy loss
        bce_loss = nn.functional.binary_cross_entropy(outputs, targets, reduction='none')
        
        # Compute the focal loss
        pt = torch.where(targets == 1, outputs, 1 - outputs)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * bce_loss
        
        # Reduce the loss based on the reduction method
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Model parameters
input_dim = 63
d_model = 64
nhead = 8
num_layers = 2
dim_feedforward = 256
max_len = 421
dropout = 0.1

model = TransformerEncoder(input_dim, d_model, nhead, num_layers, dim_feedforward, max_len, dropout)

# Compile and Train the Model
criterion = FocalLoss(alpha=0.04, gamma=2.0)
optimizer = optim.Adam(model.parameters(), lr=0.0003)

num_epochs = 10
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model.to(device)
print(f"Using device: {device}")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    for i, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        outputs = outputs.view(-1)  # Flatten the output for binary classification
        targets = targets.view(-1)  # Flatten the targets
        
        loss = criterion(outputs, targets)  # Focal Loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Calculate accuracy
        predictions = (outputs >= 0.5).long()  # Binarize predictions
        correct_predictions += (predictions == targets).sum().item()
        total_predictions += targets.numel()  # Total number of elements

    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = correct_predictions / total_predictions  # Accuracy
    print(f"Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

Using device: mps
Epoch 1/10 Loss: 0.0010, Accuracy: 0.9632
Epoch 2/10 Loss: 0.0009, Accuracy: 0.9655
Epoch 3/10 Loss: 0.0009, Accuracy: 0.9656
Epoch 4/10 Loss: 0.0009, Accuracy: 0.9656
Epoch 5/10 Loss: 0.0009, Accuracy: 0.9657
Epoch 6/10 Loss: 0.0009, Accuracy: 0.9657


KeyboardInterrupt: 

In [11]:
torch.save(model.state_dict(), 'model.pth')