In [1]:
#mlp.py
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from Data.Library import Library
%load_ext autoreload
%autoreload 2

In [2]:
class MLP(nn.Module):
    def __init__(self, vocab_size, n_gram, hidden_size, num_layers, device):
        super(MLP, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.device = device
        self.n_gram = n_gram

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size).to(self.device)

        # Define the fully connected layers
        self.fc_layers = nn.ModuleList()
        for i in range(num_layers):
            input_size = hidden_size * n_gram if i == 0 else hidden_size * n_gram
            output_size = hidden_size * n_gram
            self.fc_layers.append(nn.Linear(input_size, output_size))
        self.fc_layers = self.fc_layers.to(self.device)
        
        # Output layer
        self.output_layer = nn.Linear(hidden_size * n_gram, vocab_size).to(self.device)

    def forward(self, x):
        x = x.to(self.device)
        # Shape: [batch_size, seq_length, n_gram]
        x = torch.flatten(self.embedding(x), 2)
        #print(f"Shape after embedding: {x.shape}")
        
        for layer in self.fc_layers:
            x = F.relu(layer(x))  # Apply the fully connected layers with ReLU
        x = self.output_layer(x)
        return F.log_softmax(x, dim=-1).to('cpu').permute(0, 2, 1).to('cpu')

In [12]:
# Hyperparameters
epochs = 64
lr = 0.001
seq_length = 256
batch_size = 64
n_gram = 1
hidden_size = 64
num_layers = 1
train_size = 2**20
test_size=2**16

# Setup
device = torch.device('mps') 
print(f"Using device: {device}")
library = Library(encoding=76, train_size=train_size, test_size=test_size)
#print(f"Dataset size: {len(library.dataset)}")
#dataloader = library.get_train_dataloader(seq_length + 1)
#print(f"Number of batches in train dataloader: {len(dataloader)}")

model = MLP(
    vocab_size=library.encoding.max_token_value,
    n_gram=n_gram,
    hidden_size=hidden_size,
    num_layers=num_layers,
    device=device
)

loss_fn = nn.NLLLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr)

x_batch = torch.zeros([batch_size, seq_length - n_gram, n_gram])
y_batch = torch.zeros([batch_size, seq_length - n_gram])
losses = torch.zeros(epochs)
perplexities = torch.zeros(epochs)

tic = time.time()
print('Training')
# Training Loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    dataloader = library.get_train_dataloader(seq_length+1)
    #print(f"Epoch {epoch + 1}: Checking dataloader...")
    #for batch in dataloader:
    #    print(batch)  # Ensure data is being yielded
    #    break
    for idx, data in enumerate(dataloader):
        mod_idx = idx % batch_size
        if data.shape[0] != seq_length+1:
            break # End of usable dataloader
        # Generate n-grams
        ngrams = library.ngramify(data[:-1].unsqueeze(0), n=n_gram)  # Shape: [num_ngrams, n_gram]

        # Pad ngrams to match [511, 2] if needed
        #if ngrams.shape[0] < seq_length - n_gram + 1:  # Target size: [511, 2]
        #    padding_size = seq_length - n_gram + 1 - ngrams.shape[0]
        #    ngrams = F.pad(ngrams, (0, 0, 0, padding_size))  # Pad to target size

        # Assign to batch
        x_batch[mod_idx] = ngrams
        target = data[n_gram+1:]
        #if target.shape[0] < seq_length - n_gram + 1:
        #    padding_size = seq_length - n_gram + 1 - target.shape[0]
        #    target = F.pad(target, (0, padding_size))
        y_batch[mod_idx] = target
        
        # Process the batch when it's full
        if mod_idx == batch_size - 1:
            # Update weights
            optim.zero_grad()
            y_pred = model(x_batch.long())
            loss = loss_fn(y_pred, y_batch.long())
            total_loss += loss.item()
            loss.backward()
            print(f"Samples Trained = {idx}: Loss = {loss.item():.4f}", end = '\r')
            optim.step()

    num_batches = idx + 1 if idx else 1  # Count batches processed
    avg_loss = total_loss / num_batches
    losses[epoch] = avg_loss
    perplexities[epoch] = library.calc_perplexity(model, n_gram=n_gram)
    print(f'Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f}, Perplexity: {perplexities[epoch]:.4f}')
print(time.time()-tic)

Using device: mps
Training
Epoch 1/64 - Loss: 0.0534, Perplexity: 17.6052
Epoch 2/64 - Loss: 0.0430, Perplexity: 15.8612
Epoch 3/64 - Loss: 0.0423, Perplexity: 15.5654
Epoch 4/64 - Loss: 0.0421, Perplexity: 15.4673
Epoch 5/64 - Loss: 0.0420, Perplexity: 15.4240
Epoch 6/64 - Loss: 0.0420, Perplexity: 15.4010
Epoch 7/64 - Loss: 0.0420, Perplexity: 15.3849
Epoch 8/64 - Loss: 0.0419, Perplexity: 15.3742
Epoch 9/64 - Loss: 0.0419, Perplexity: 15.3643
Epoch 10/64 - Loss: 0.0419, Perplexity: 15.3563
Epoch 11/64 - Loss: 0.0419, Perplexity: 15.3508
Epoch 12/64 - Loss: 0.0419, Perplexity: 15.3466
Epoch 13/64 - Loss: 0.0419, Perplexity: 15.3413
Epoch 14/64 - Loss: 0.0419, Perplexity: 15.3370
Epoch 15/64 - Loss: 0.0419, Perplexity: 15.3334
Epoch 16/64 - Loss: 0.0419, Perplexity: 15.3304
Epoch 17/64 - Loss: 0.0419, Perplexity: 15.3276
Epoch 18/64 - Loss: 0.0419, Perplexity: 15.3243
Epoch 19/64 - Loss: 0.0419, Perplexity: 15.3225
Epoch 20/64 - Loss: 0.0419, Perplexity: 15.3208
Epoch 21/64 - Loss: 0.

In [4]:
torch.save(perplexities, 'MLP14Kperplexities.pt')
torch.save(model.state_dict(), f'Models/MLP14K.pkl')