In [11]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import string

# Load Dataset
with open('cipher_dataset.json', 'r') as f:
    dataset = json.load(f)

# Character-level Tokenizer
def char_tokenizer(text):
    vocab = {ch: idx for idx, ch in enumerate(string.ascii_lowercase + ' ')}
    return [vocab[ch] for ch in text]

class CipherDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.vocab_size = 27  # 26 letters + space
        self.max_length = max(len(item['ciphertext']) for item in data)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        ciphertext = char_tokenizer(item['ciphertext'])
        plaintext = char_tokenizer(item['plaintext'])
        cipher_type = item['cipher_type']
        
        # Pad sequences
        ciphertext += [0] * (self.max_length - len(ciphertext))
        plaintext += [0] * (self.max_length - len(plaintext))
        
        # Convert to tensors
        ciphertext = torch.tensor(ciphertext, dtype=torch.long)
        plaintext = torch.tensor(plaintext, dtype=torch.long)
        
        return ciphertext, plaintext, cipher_type

# Split Dataset
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=42)

# Create Dataloaders
train_dataset = CipherDataset(train_data)
val_dataset = CipherDataset(val_data)
test_dataset = CipherDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [12]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


In [13]:
import torch.nn as nn
import torch.optim as optim

class GatingNetwork(nn.Module):
    def __init__(self, embedding_dim, num_experts):
        super(GatingNetwork, self).__init__()
        self.fc = nn.Linear(embedding_dim, num_experts)
    
    def forward(self, x):
        return torch.softmax(self.fc(x.float()), dim=-1)

class Expert(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Expert, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x):
        assert torch.max(x) < self.embedding.num_embeddings and torch.min(x) >= 0, "Index out of bounds in embedding"
        x = self.embedding(x.long())  # Ensure input to embedding is Long
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

class MoE(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_experts):
        super(MoE, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gating_network = GatingNetwork(embedding_dim, num_experts)
        self.experts = nn.ModuleList([Expert(vocab_size, embedding_dim, hidden_dim) for _ in range(num_experts)])
    
    def forward(self, x):
        assert torch.max(x) < self.embedding.num_embeddings and torch.min(x) >= 0, "Index out of bounds in embedding"
        x = self.embedding(x.long())  # Ensure input to embedding is Long
        gate_outputs = self.gating_network(x)
        expert_outputs = [expert(x) for expert in self.experts]
        output = sum(gate_outputs[:, :, i].unsqueeze(2) * expert_outputs[i] for i in range(len(self.experts)))
        return output


# Hyperparameters
vocab_size = 27  # 26 letters + space
embedding_dim = 64
hidden_dim = 128
num_experts = 5

# Instantiate Model
model = MoE(vocab_size, embedding_dim, hidden_dim, num_experts).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [14]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for ciphertext, plaintext, cipher_type in train_loader:
            ciphertext, plaintext = ciphertext.to(device), plaintext.to(device)
            optimizer.zero_grad()
            outputs = model(ciphertext)
            loss = criterion(outputs.view(-1, vocab_size), plaintext.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
        
        validate(model, val_loader, criterion)

def validate(model, val_loader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for ciphertext, plaintext, cipher_type in val_loader:
            ciphertext, plaintext = ciphertext.to(device), plaintext.to(device)
            outputs = model(ciphertext)
            loss = criterion(outputs.view(-1, vocab_size), plaintext.view(-1))
            total_loss += loss.item()
    
    avg_loss = total_loss / len(val_loader)
    print(f"Validation Loss: {avg_loss:.4f}")

# Train the model
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=20)


AssertionError: Index out of bounds in embedding

In [None]:
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for ciphertext, plaintext, cipher_type in test_loader:
            ciphertext, plaintext = ciphertext.to(device), plaintext.to(device)
            outputs = model(ciphertext)
            _, predicted = torch.max(outputs, -1)
            total += plaintext.size(0) * plaintext.size(1)
            correct += (predicted == plaintext).sum().item()
    
    accuracy = correct / total
    print(f'Test Accuracy: {accuracy:.4f}')

# Evaluate the model
evaluate(model, test_loader)


AssertionError: Index out of bounds in embedding