# Import required packages

In [1]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
import math
import random
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
# Define Hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('running on device:', device)
batch_size = 5000
epochs = 50
learning_rate = 0.001
MODEL_PATH = 'math_transformer.pth'

running on device: cuda


In [3]:
# Define characters
chars = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         '+', '-', '*', '/', '=', '<SOS>', '<EOS>', '<PAD>']
vocab = {c: i for i, c in enumerate(chars)}
vocab_size = len(vocab)
SOS_token = vocab['<SOS>']
EOS_token = vocab['<EOS>']
PAD_token = vocab['<PAD>']

In [4]:
# Generate Mathematical Samples
def generate_math_dataset(num_samples, max_digits=1, operations=('+', '-', '*', '/')):
    samples = []
    for _ in range(num_samples):
        op = random.choice(operations)
        if op in ['+', '-', '*']:
            a = random.randint(10 ** (max_digits - 1), 10 ** max_digits - 1)
            b = random.randint(10 ** (max_digits - 1), 10 ** max_digits - 1)
            if op == '-' and a < b:
                a, b = b, a
            res = eval(f"{a}{op}{b}")
        elif op == '/':
            b = random.randint(10 ** (max_digits - 1), 10 ** max_digits - 1)
            res = random.randint(10 ** (max_digits - 1), 10 ** max_digits - 1)
            a = res * b
        question = f"{a}{op}{b}="
        answer = f"{res}"
        samples.append((question, answer))
    return samples

In [5]:
# Math Dataset Class
class MathDataset(Dataset):
    def __init__(self, samples, vocab, max_question_len=8, max_answer_len=6):
        self.samples = samples
        self.vocab = vocab
        self.max_q = max_question_len
        self.max_a = max_answer_len + 2  # Include SOS and EOS

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        q, a = self.samples[idx]
        # Wrap questions
        q_ids = [vocab[c] for c in q][:self.max_q]
        q_pad = [PAD_token] * (self.max_q - len(q_ids))
        q_ids += q_pad

        # Wrap answers
        a_ids = [SOS_token] + [vocab[c] for c in a][:self.max_a]
        a_ids += [EOS_token]
        a_ids = a_ids[:self.max_a]
        a_pad = [PAD_token] * (self.max_a - len(a_ids))
        a_ids += a_pad

        return torch.LongTensor(q_ids), torch.LongTensor(a_ids)

In [6]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(1)]
        return self.dropout(x)

In [7]:
# Math Transformer Model
class MathTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=3, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_padding_mask=None, tgt_padding_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_encoder(tgt)

        output = self.transformer(
            src, tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask
        )
        return self.fc_out(output)

In [8]:
# Generate mask
def generate_mask(src, tgt, PAD_token):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len)).bool()

    src_padding_mask = (src == PAD_token).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_token).transpose(0, 1)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [9]:
# Generate Dataset
train_data = generate_math_dataset(20000)
val_data = generate_math_dataset(1000)
train_dataset = MathDataset(train_data, vocab)
val_dataset = MathDataset(val_data, vocab)

# Create Data Loader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [10]:
# Initiate model
model = MathTransformer(vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_token)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
t0 = time.time()

# Train the model
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        src = src.transpose(0, 1).to(device)  # (seq_len, batch)
        tgt = tgt.transpose(0, 1).to(device)

        tgt_input = tgt[:-1, :]  # decoder input
        tgt_output = tgt[1:, :]  # decoder target

        optimizer.zero_grad()

        src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = generate_mask(src, tgt_input, PAD_token)
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)
        output = model(src, tgt_input, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask)

        loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    if epoch % 10 == 0:
      print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")
print('time elapsed:', time.time() - t0)

Epoch 1, Loss: 2.4317
Epoch 11, Loss: 1.0089
Epoch 21, Loss: 0.2543
Epoch 31, Loss: 0.0735
Epoch 41, Loss: 0.0485
time elapsed: 91.75194668769836


In [11]:
# Save the trained model
torch.save(model.state_dict(), MODEL_PATH)

In [12]:
# Training models in GPU and Predicting in CPU
# loaded_model.load_state_dict(torch.load('math_transformer.pth', map_location=torch.device('cpu')))

In [13]:
# Load model
loaded_model = MathTransformer(vocab_size).to(device)
loaded_model.load_state_dict(torch.load(MODEL_PATH))
loaded_model.eval();

In [14]:
# Test the model
def predict(model, question, vocab, max_length=10):
    model.eval()
    src = torch.LongTensor([vocab[c] for c in question]).unsqueeze(1).to(device)
    memory = model.transformer.encoder(model.pos_encoder(model.embedding(src) * math.sqrt(model.d_model)))
    ys = torch.LongTensor([[SOS_token]]).to(device)

    for _ in range(max_length):
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
        out = model.transformer.decoder(model.pos_encoder(model.embedding(ys) * math.sqrt(model.d_model)),
                                        memory, tgt_mask)
        out = model.fc_out(out[-1, :])
        next_token = out.argmax().item()
        ys = torch.cat([ys, torch.LongTensor([[next_token]]).to(device)], dim=0)

        if next_token == EOS_token:
            break

    return ''.join([chars[i] for i in ys.squeeze().tolist() if i not in [SOS_token, EOS_token, PAD_token]])

In [15]:
# Test the model
test_samples = generate_math_dataset(20)
for q, a in test_samples:
    pred = predict(loaded_model, q, vocab)
    print(f"Question: {q} Correct Answer: {a} Predicted Result: {pred}")


Question: 7-1= Correct Answer: 6 Predicted Result: 6
Question: 6+3= Correct Answer: 9 Predicted Result: 9
Question: 9-2= Correct Answer: 7 Predicted Result: 7
Question: 7-3= Correct Answer: 4 Predicted Result: 4
Question: 3+5= Correct Answer: 8 Predicted Result: 8
Question: 5+1= Correct Answer: 6 Predicted Result: 6
Question: 6-5= Correct Answer: 1 Predicted Result: 1
Question: 9*8= Correct Answer: 72 Predicted Result: 72
Question: 4+7= Correct Answer: 11 Predicted Result: 11
Question: 3+8= Correct Answer: 11 Predicted Result: 11
Question: 16/8= Correct Answer: 2 Predicted Result: 3
Question: 9+4= Correct Answer: 13 Predicted Result: 13
Question: 7+7= Correct Answer: 14 Predicted Result: 14
Question: 32/8= Correct Answer: 4 Predicted Result: 4
Question: 10/2= Correct Answer: 5 Predicted Result: 5
Question: 1*9= Correct Answer: 9 Predicted Result: 9
Question: 9+1= Correct Answer: 10 Predicted Result: 10
Question: 1/1= Correct Answer: 1 Predicted Result: 1
Question: 8+3= Correct Answer: 

In [16]:
loaded_model

MathTransformer(
  (embedding): Embedding(18, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDe