In [76]:
import torch
import torch.nn as nn
import torch.optim as optim
import string

# 1. Prepare the dataset
# Vocabulary: the 26 lowercase ASCII letters, with lookups in both directions.
alphabet = string.ascii_lowercase
idx_to_char = dict(enumerate(alphabet))
char_to_idx = {letter: position for position, letter in idx_to_char.items()}

# Training pairs: each letter is paired with the letter that follows it,
# so the inputs run 'a'..'y' and the targets run 'b'..'z'.
input_chars, output_chars = alphabet[:-1], alphabet[1:]

# Integer class indices for every input letter and every target letter.
input_indices = torch.tensor([char_to_idx[letter] for letter in input_chars])
output_indices = torch.tensor([char_to_idx[letter] for letter in output_chars])

# 2. Define the neural network
class SimpleNN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the single hidden layer.
        output_size: Number of scores produced per input row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # fc1 is created before fc2 so seeded initialisation is reproducible.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw output scores for ``x`` (no softmax is applied)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class SimpleNNWithFunction(nn.Module):
    """Linear -> ReLU -> Linear network whose input is multiplied by 5 first.

    The layers are the same as in a plain two-layer network; the only
    difference is the fixed ``5 * x`` scaling applied to the input before the
    first linear layer.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the single hidden layer.
        output_size: Number of scores produced per input row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # fc1 is created before fc2 so seeded initialisation is reproducible.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw output scores for ``x`` after scaling it by 5."""
        # Fixed input function f(x) = 5x, applied ahead of the first layer.
        scaled = x * 5
        hidden = self.relu(self.fc1(scaled))
        return self.fc2(hidden)

# Hyperparameters
# One input feature and one output score per letter of the alphabet.
input_size = output_size = len(alphabet)
hidden_size = 26
learning_rate = 0.01
num_epochs = 50

# 3. Initialize the neural network, loss function, and optimizer
# The baseline model is constructed before the scaled-input variant; each
# model gets its own Adam optimizer so the two are trained independently.
model = SimpleNN(input_size, hidden_size, output_size)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

modelF = SimpleNNWithFunction(input_size, hidden_size, output_size)
optimizerF = optim.Adam(modelF.parameters(), lr=learning_rate)

# A single cross-entropy criterion is shared by both models.
criterion = nn.CrossEntropyLoss()

# 4. Train the neural network
# Every training example is a fixed one-hot row, so build them all once here
# instead of allocating a fresh tensor on every step of every epoch.
# Row k of the identity matrix is the one-hot vector for class k.
one_hot_inputs = torch.eye(input_size)[input_indices].unsqueeze(1)  # (pairs, 1, input_size)
targets = output_indices.unsqueeze(1)  # (pairs, 1)

for epoch in range(num_epochs):
    # One optimisation step per (letter, next letter) pair, i.e. batch size 1.
    for input_one_hot, target in zip(one_hot_inputs, targets):
        # Baseline model: forward pass, then backward pass and update.
        outputs = model(input_one_hot)
        loss = criterion(outputs, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Scaled-input model: same example, separate optimizer.
        outputsF = modelF(input_one_hot)
        lossF = criterion(outputsF, target)
        optimizerF.zero_grad()
        lossF.backward()
        optimizerF.step()

    # NOTE: `loss` / `lossF` hold the loss of the last pair processed in this
    # epoch, not an epoch average. Left column is `model`, right is `modelF`.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}\t\t  Epoch [{epoch+1}/{num_epochs}], Loss: {lossF.item():.4f}')

# 5. Test the neural network
def predict_next_char(model, char):
    """Return the letter that ``model`` scores highest as the successor of ``char``.

    Args:
        model: Network called on a ``(1, input_size)`` one-hot tensor. It is
            switched to evaluation mode and left in that mode.
        char: Character to look up in ``char_to_idx``; a character missing
            from that mapping raises ``KeyError``.

    Returns:
        The ``idx_to_char`` entry for the highest-scoring output index.
    """
    model.eval()
    # Build the one-hot row with its batch dimension already in place.
    one_hot = torch.zeros(1, input_size)
    one_hot[0, char_to_idx[char]] = 1.0
    with torch.no_grad():  # inference only, so no autograd graph is needed
        scores = model(one_hot)
        best_idx = scores.max(dim=1).indices
    return idx_to_char[best_idx.item()]

# Test the model
# Smoke check on the baseline `model` only (modelF is not queried here):
# ask which letter it predicts after 'g' and print the answer.
test_char = 'g'
predicted_char = predict_next_char(model=model, char=test_char)
print(f'The next character after {test_char} is {predicted_char}')


Epoch [1/50], Loss: 3.5775		  Epoch [1/50], Loss: 4.1247
Epoch [2/50], Loss: 3.3575		  Epoch [2/50], Loss: 3.2118
Epoch [3/50], Loss: 2.9391		  Epoch [3/50], Loss: 2.3076
Epoch [4/50], Loss: 2.3615		  Epoch [4/50], Loss: 1.0725
Epoch [5/50], Loss: 1.6834		  Epoch [5/50], Loss: 0.2672
Epoch [6/50], Loss: 0.9105		  Epoch [6/50], Loss: 0.0872
Epoch [7/50], Loss: 0.4247		  Epoch [7/50], Loss: 0.0474
Epoch [8/50], Loss: 0.2073		  Epoch [8/50], Loss: 0.0319
Epoch [9/50], Loss: 0.1303		  Epoch [9/50], Loss: 0.0236
Epoch [10/50], Loss: 0.0853		  Epoch [10/50], Loss: 0.0184
Epoch [11/50], Loss: 0.0644		  Epoch [11/50], Loss: 0.0149
Epoch [12/50], Loss: 0.0508		  Epoch [12/50], Loss: 0.0123
Epoch [13/50], Loss: 0.0402		  Epoch [13/50], Loss: 0.0104
Epoch [14/50], Loss: 0.0331		  Epoch [14/50], Loss: 0.0089
Epoch [15/50], Loss: 0.0277		  Epoch [15/50], Loss: 0.0077
Epoch [16/50], Loss: 0.0240		  Epoch [16/50], Loss: 0.0068
Epoch [17/50], Loss: 0.0204		  Epoch [17/50], Loss: 0.0060
Epoch [18/50], 