In [1]:
from sympy import *
import random
import re
import tokenize
from io import StringIO
import torch
from torch import nn
from torch.autograd import Variable
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import math

In [2]:
device = torch.device("cuda" if torch.torch.cuda.is_available() else "cpu")

In [3]:
"""
A class for representing and generating expressions in Polish notation

Expressions are generated with the algorithm descibed in Appendix C of DEEP LEARNING FOR SYMBOLIC MATHEMATICS (Guillaume Lample, Francois Charton 2019)
which aims to weight deep, shallow, left-leaning, and right leaning expression trees all equally

We used polish notation for this project as it can more concisely represent expressions than infix notation as it never needs parenthesis

"""

class Expression:

  """
  dictionary of operations with their corresponding arity as keys
  """
  _ops = {
    'sin' : 1,
    'cos' : 1,
    'tan' : 1,
    'square' : 1,
    'cube' : 1,
    'exp' : 1,
    'log': 1,

    '+' : 2,
    '-' : 2,
    '*' : 2,
    '/' : 2,
    '**' : 2,
  }


  """
  Maps operations to anonymous function that generates an infix expression
  """
  _infix_reps = {
      'sin': lambda args: f'sin({args[0]})',
      'cos': lambda args: f'cos({args[0]})',
      'tan': lambda args: f'tan({args[0]})',
      'square': lambda args: f'({args[0]})**2',
      'cube': lambda args: f'({args[0]})**3',
      'exp': lambda args: f'exp({args[0]})',
      'log': lambda args: f'log({args[0]})',

      '+': lambda args: f'({args[0]})+({args[1]})',
      '-': lambda args: f'({args[0]})-({args[1]})',
      '*': lambda args: f'({args[0]})*({args[1]})',
      '/': lambda args: f'({args[0]})/({args[1]})',
      '**': lambda args: f'({args[0]})**({args[1]})'
  }
  """
  unnormalized probabilities of each unary op
  """
  _unary_op_probs = {
    'sin' : 1,
    'cos' : 1,
    'tan' : 2,
    'square' : 4,
    'cube' : 3,
    'exp' : 2,
    'log' : 1
  }
  """
  unnormalized probabilities of each binary op
  """
  _bin_op_probs = {
    '+' : 3,
    '-' : 3,
    '*' : 2,
    '/' : 2,
    '**' : 1,
  }
  """
  maps sympy functions to ones contained in this class
  """
  _from_sympy = {
      sin : 'sin',
      cos : 'cos',
      tan : 'tan',
      exp : 'exp',
      log : 'log',

      Add : '+',
      Mul : '*',
      Pow : '**',
  }
  """
    Generates a numpy array representing counts of possible trees of n internal nodes generated from e empty nodes
    D(0, n) = 0
    D(e, 0) = L ** e
    D(e, n) = L * D(e - 1, n) + p_1 * D(e, n - 1) + p_2 * D(e + 1, n - 1)

    from Appendix C.2 of DEEP LEARNING FOR SYMBOLIC MATHEMATICS (Guillaume Lample, Francois Charton 2019)
  """
  def _unary_binary_dist(self, size):

    #generating transposed version
    D = np.zeros((size * 2 + 1, size))
    D[:,0] = self._num_leaves ** np.arange(size * 2 + 1)
    D[0,0] = 0

    for n in range(1, size):
      for e in range(1, size * 2):
        D[e, n] = self._num_leaves * D[e - 1, n] + self._num_unary_ops * D[e, n - 1] + self._num_bin_ops * D[e + 1, n - 1]
    return D[:,:size+1]
    
  """
  

  Samples a position of a node and arity

  from Appendix C.3 of DEEP LEARNING FOR SYMBOLIC MATHEMATICS (Guillaume Lample, Francois Charton 2019)

  Parameters
    e -- number of empty nodes to sample from
    n -- number of operations
  """
  def _sample(self, e, n):

    P = np.zeros((e, 2))

    for k in range(e):
      P[k,0] = (self._num_leaves ** k) * self._num_unary_ops * self._unary_binary_dist[e - k][n - 1]
    for k in range(e):
      P[k,1] = (self._num_leaves ** k) * self._num_bin_ops * self._unary_binary_dist[e - k + 1][n - 1]

    P /= self._unary_binary_dist[e,n]
    k = np.random.choice(2*e, p=P.T.flatten())

    arity = 1 if k < e else 2
    k = k % e
    return k , arity

  def _choose_unary_op(self):
    return np.random.choice(tuple(self._unary_op_probs.keys()), p=self._unary_op_norm_prob)

  def _choose_bin_op(self):
    return np.random.choice(tuple(self._bin_op_probs.keys()), p=self._bin_op_norm_prob)
  
  def _choose_leaf(self):
    if(random.random() < 0.3):
      return 'x'
    return random.randrange(0,10)

  def _gen_from_sympy(self, expr):
    self._rep = []
    stack = [expr]
    while(len(stack) != 0):
      expr = stack.pop()
      #print(expr, self._rep)
      if isinstance(expr, Symbol):
        self._rep.append(str(expr))
      elif isinstance(expr, Integer):
        self._rep.append(str(expr))
      elif isinstance(expr, Rational):
        self._rep.append('/')

        args = str(expr).split('/')
        self._rep.append(str(args[0]))
        self._rep.append(str(args[1]))
      elif expr == E:
        self._rep.append('e')
      elif expr == pi:
        self._rep.append('pi')
      elif expr == I:
        self._rep.append('i')


      else:
        for i in range(len(expr.args) - 1):
          self._rep.append(self._from_sympy[type(expr)])

        for item in expr.args:
          stack.append(item)

  def _gen_random(self, num_ops):
    self._num_leaves = 1
    self._num_bin_ops = len(self._bin_op_probs.keys())
    self._num_unary_ops = len(self._unary_op_probs.keys())

    self._unary_binary_dist = self._unary_binary_dist(num_ops + 1)

    self._bin_op_norm_prob = np.fromiter(self._bin_op_probs.values(), dtype=float)
    self._bin_op_norm_prob /= self._bin_op_norm_prob.sum()

    self._unary_op_norm_prob = np.fromiter(self._unary_op_probs.values(), dtype=float)
    self._unary_op_norm_prob /= self._unary_op_norm_prob.sum()

    rep = [None]
    e = 1
    skipped = 0
    for n in range(num_ops, 0, - 1):
      k, arity = self._sample(e, n)
      skipped += k
      if arity == 1:
        op = self._choose_unary_op()

        #O(N) is bad for this. TODO: change to a dynamic programming approach so it is O(1) per iteration of parent loop
        encountered_empty = 0
        pos = 0
        for i in range(len(rep)):
          if(rep[i] == None):
            encountered_empty += 1
          if encountered_empty == skipped + 1:
            pos = i
            break

          
        rep = rep[:pos] + [op] + [None] + rep[pos + 1:]
        e = e - k
      else:
        op = self._choose_bin_op()

        encountered_empty = 0
        pos = 0
        for i in range(len(rep)):
          if(rep[i] == None):
            encountered_empty += 1
          if encountered_empty == skipped + 1:
            pos = i
            break

        rep = rep[:pos] + [op] + [None] + [None] + rep[pos + 1:]
        e = e - k + 1

    for i in range(len(rep)):
      if(rep[i] is None):
        rep[i] = self._choose_leaf()
    self._rep = rep


  def __init__(self, expr=None, num_ops=None):

    if(expr is not None):
      self._gen_from_sympy(expr)
    else:
      self._gen_random(num_ops)



  def to_infix(self):
    stack = []

    for i in range(len(self._rep) - 1, -1, -1):
      token = self._rep[i]

      if token in self._ops:
        arity = self._ops[token]

        args = stack[-arity:]
        stack = stack[:-arity]

        stack.append(self._infix_reps[token](args))
      else:
        stack.append(token)
    return stack.pop()
  def get_rep(self):
    return self._rep


In [4]:
def taylor_series(f_str, a, order):
  x = symbols('x')
  f = parse_expr(f_str)
  ret = f.subs(x, a)
  for i in range(1,order + 1):
    #print(i)
    f = diff(f,x)
    ret = ret + (f*(x-a))/factorial(i)
  return ret

In [5]:
def test_expr():
  for i in range(10):
    expr = Expression(num_ops=3)
    print(f"Expression {i+1}:")
    print(f"\tTokenixed prefix: ", expr.get_rep())
    print(f"\tInfix: ", expr.to_infix())
test_expr()

Expression 1:
	Tokenixed prefix:  ['/', 'exp', '+', 9, 3, 'x']
	Infix:  (x)/(exp((3)+(9)))
Expression 2:
	Tokenixed prefix:  ['*', 'square', 'x', 'cube', 4]
	Infix:  ((4)**3)*((x)**2)
Expression 3:
	Tokenixed prefix:  ['square', 'tan', '+', 7, 'x']
	Infix:  (tan((x)+(7)))**2
Expression 4:
	Tokenixed prefix:  ['-', '/', 'cube', 0, 'x', 8]
	Infix:  (8)-((x)/((0)**3))
Expression 5:
	Tokenixed prefix:  ['cube', 'cube', 'exp', 'x']
	Infix:  ((exp(x))**3)**3
Expression 6:
	Tokenixed prefix:  ['exp', '/', '-', 'x', 2, 7]
	Infix:  exp((7)/((2)-(x)))
Expression 7:
	Tokenixed prefix:  ['*', '-', 'x', 'square', 6, 3]
	Infix:  (3)*(((6)**2)-(x))
Expression 8:
	Tokenixed prefix:  ['*', '-', '+', 4, 7, 3, 'x']
	Infix:  (x)*((3)-((7)+(4)))
Expression 9:
	Tokenixed prefix:  ['/', 'tan', 9, '*', 'x', 2]
	Infix:  ((2)*(x))/(tan(9))
Expression 10:
	Tokenixed prefix:  ['-', 'exp', 6, '**', 'x', 4]
	Infix:  ((4)**(x))-(exp(6))


In [6]:
def gen_pair(ops=3):
  expr = Expression(num_ops=ops)
  tay = taylor_series(expr.to_infix(), Symbol('a'), 4)
  tay_rep = Expression(expr=tay)
  return expr, tay_rep


In [7]:
class FunctionDataset(torch.utils.data.Dataset):
  """
  not proud of the specification of both ops and sequence length but it works for now
  """
  def __init__(self, ops=3, max_seq_length=32, num_items=100):

    raw_input = []
    raw_output = []

    while len(raw_input) < num_items:
      expr, tay = gen_pair(ops)
      #we only need the postfix representation
      expr = expr.get_rep()
      tay = tay.get_rep()

      #discards expressions too long and nan values
      if(len(expr) + 2 <= max_seq_length and len(tay) + 2 <= max_seq_length and tay != nan):
        #insert start and end tokens
        expr.insert(0,'<SOS>')
        expr.append('<EOS>')

        tay.insert(0,'<SOS>')
        tay.append('<EOS>')
        
        raw_input.append(expr)
        raw_output.append(tay)
    
    #generate vocab
    self.vocab = set()
    for expr, tay in zip(raw_input, raw_output):
      self.vocab |= set(expr) |(set(tay))

    #token -> idx
    self.token_to_idx = {value : index + 1 for index, value in enumerate(self.vocab)}

    #idx -> token
    self.idx_to_token = {index + 1 : value for index, value in enumerate(self.vocab)}

    self.input = []
    self.output = []
    for raw_expr, raw_tay in zip(raw_input, raw_output):
      expr = [self.token_to_idx[token] for token in raw_expr] + [0] * (max_seq_length - len(raw_expr))
      tay = [self.token_to_idx[token] for token in raw_tay] + [0] * (max_seq_length - len(raw_tay))

      self.input.append(torch.tensor(expr, dtype=torch.long, device=device))
      self.output.append(torch.tensor(tay, dtype=torch.long, device=device))
    
  def __len__(self):
    return len(self.input)

  def __getitem__(self, idx):
    return self.input[idx].to(device), self.output[idx].to(device)

  def get_alphabet(self):
    return self.vocab
      



In [8]:
d = FunctionDataset(num_items=1000)
train_idx = list(range(0, int(9*len(d)/10)))
test_idx = list(range(int(9*len(d)/10), len(d)))
train_dataset = torch.utils.data.Subset(d, train_idx)
test_dataset = torch.utils.data.Subset(d, test_idx)

In [9]:
class Encoder(nn.Module):
  def __init__(self, vocab_size, embedding_dim=512, num_layers=2, hidden_size=512, dropout=0.2):
    super(Encoder, self).__init__()
    self.embedding_dim = embedding_dim
    self.num_layers = num_layers
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=self.embedding_dim
    )
    self.lstm = nn.LSTM(
        input_size=self.embedding_dim,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        dropout=dropout,
    )
  """
  input shape (SEQUENCE_LENGTH, BATCH_SIZE)
  h,c shape (HIDDEN_SIZE)
  """
  def forward(self, x):
    embed = self.embedding(x)
    output, (h,c) = self.lstm(embed)
    return h, c



In [10]:
class Decoder(nn.Module):
  def __init__(self, vocab_size, embedding_dim=512, num_layers=2, hidden_size=512, dropout=0.2):
    super(Decoder, self).__init__()
    self.embedding_dim = embedding_dim
    self.num_layers = num_layers
    self.output_size = vocab_size
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=self.embedding_dim
    )
    self.lstm = nn.LSTM(
        input_size=self.embedding_dim,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        dropout=0.2,
    )
    self.out = nn.Linear(self.hidden_size, self.output_size)
    self.softmax = nn.LogSoftmax(dim=2)
    self.to(device)
  """
  input shape (BATCH_SIZE)
  output shape
  """
  def forward(self, input, h_0, c_0):
    embedded = self.embedding(input.unsqueeze(0))
    output, (h,c) = self.lstm(embedded, (h_0, c_0))
    output = self.out(output)     
    output = self.softmax(output)
    return output.squeeze(0), h , c


In [11]:
class Model(nn.Module):
  def __init__(self, encoder, decoder):
    super(Model, self).__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.to(device)
  """
  Input tensor of shape (SEQUENCE_LENGTH, BATCH_SIZE)
  Output tensor of shape (SEQUENCE_LENGTH, BATCH_SIZE, VOCAB_SIZE)

  if tgt is none use teacher forecasting
  """
  def forward(self, input, tgt=None):
    if len(input.shape) < 2:
        input = input.unsqueeze(1)
    batch_size = input.shape[1]
    h, c = enc(input)
    target = torch.zeros(batch_size, dtype=torch.long).to(device)
    if tgt is None:
      max_seq_length = input.shape[0]
      target[:] = d.token_to_idx['<SOS>']
    else:
      max_seq_length = tgt.shape[1]
      target[:] = tgt[:,0]
    outputs = torch.zeros(max_seq_length, batch_size, dec.output_size, dtype=torch.float).to(device)
    for i in range(max_seq_length):
        prediction, h, c = dec(target, h, c)
        outputs[i] = prediction
        if tgt is None:
          target = prediction.argmax(dim=1)
        else:
          target = tgt[:,i]
    return outputs


In [12]:
enc = Encoder(len(d.get_alphabet()) + 1)
dec = Decoder(len(d.get_alphabet()) + 1)
m = Model(enc,dec).to(device)

In [13]:
def test_epoch_LSTM(model, test_loader, criterion, batch_size=4):
  model.eval()
  total_loss = 0
  total_items = 0
  num_correct = 0
  for src, tgt in tqdm(test_loader):
    src = src.to(device)
    tgt = tgt.to(device)

    pred = model(src.squeeze().T, tgt=tgt[:,:-1])
    pred = pred.permute((1,2,0))
    tgt_out = tgt[:,1:]

    loss = criterion(pred, tgt_out)

    total_loss += loss.item()
    total_items += (tgt_out != 0).sum(dim=(0,1))

    num_correct += (torch.logical_and((logits.argmax(dim=2) == tgt_out), (tgt_out != 0))).sum(dim=(0,1))
  return total_loss, num_correct / total_items

In [14]:
def train_epoch_LSTM(model, train_loader, optimizer, criterion, batch_size=4):
  model.train()
  total_loss = 0
  total_items = 0
  num_correct = 0
  for src, tgt in tqdm(train_loader):
    src = src.to(device)
    tgt = tgt.to(device)

    pred = model(src.squeeze().T,tgt=tgt[:,:-1])

    pred = pred.permute((1,2,0))
    tgt_out = tgt[:,1:]
    loss = criterion(pred, tgt_out)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    total_items += (tgt_out != 0).sum(dim=(0,1))

    num_correct += (torch.logical_and((pred.argmax(dim=1) == tgt_out), (tgt_out != 0))).sum(dim=(0,1))
  return total_loss, num_correct / total_items

In [15]:
def train_LSTM(model, train_dataset,  test_dataset, batch_size=32, epochs=50):
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
  criterion = nn.CrossEntropyLoss()


  optim = torch.optim.Adam(model.parameters(), lr=1e-3)
  for e in range(epochs):
    train_loss, train_acc = train_epoch_LSTM(model, train_loader, optim, criterion, batch_size=batch_size)
    test_loss, test_acc = train_epoch_LSTM(model, test_loader, optim, criterion, batch_size=batch_size)
    print(f'Epoch: {e + 1} Training Loss: {train_loss} Training Accuracy: {train_acc} Test Loss: {test_loss} Test Accuracy: {test_acc}')

In [16]:
train_LSTM(m, train_dataset, test_dataset, batch_size=32)

100%|██████████| 29/29 [00:02<00:00, 12.40it/s]
100%|██████████| 4/4 [00:00<00:00, 12.54it/s]


Epoch: 1 Training Loss: 33.54634836316109 Training Accuracy: 0.11032736301422119 Test Loss: 2.403767615556717 Test Accuracy: 0.2083333283662796


100%|██████████| 29/29 [00:02<00:00, 12.17it/s]
100%|██████████| 4/4 [00:00<00:00, 12.72it/s]


Epoch: 2 Training Loss: 16.89692533016205 Training Accuracy: 0.3069271147251129 Test Loss: 2.0362119674682617 Test Accuracy: 0.3550724685192108


100%|██████████| 29/29 [00:02<00:00, 12.78it/s]
100%|██████████| 4/4 [00:00<00:00, 13.01it/s]


Epoch: 3 Training Loss: 14.930061250925064 Training Accuracy: 0.3930186331272125 Test Loss: 1.7664760649204254 Test Accuracy: 0.4003623127937317


100%|██████████| 29/29 [00:02<00:00, 13.37it/s]
100%|██████████| 4/4 [00:00<00:00, 13.92it/s]


Epoch: 4 Training Loss: 13.757144123315811 Training Accuracy: 0.4315427839756012 Test Loss: 1.6039934158325195 Test Accuracy: 0.45652174949645996


100%|██████████| 29/29 [00:02<00:00, 12.93it/s]
100%|██████████| 4/4 [00:00<00:00, 11.98it/s]


Epoch: 5 Training Loss: 13.655325770378113 Training Accuracy: 0.461566299200058 Test Loss: 1.4803729951381683 Test Accuracy: 0.4547101557254791


100%|██████████| 29/29 [00:02<00:00, 12.92it/s]
100%|██████████| 4/4 [00:00<00:00, 11.31it/s]


Epoch: 6 Training Loss: 12.540652394294739 Training Accuracy: 0.4760354459285736 Test Loss: 1.3665111362934113 Test Accuracy: 0.48731884360313416


100%|██████████| 29/29 [00:02<00:00, 12.35it/s]
100%|██████████| 4/4 [00:00<00:00, 11.27it/s]


Epoch: 7 Training Loss: 11.688833743333817 Training Accuracy: 0.500813901424408 Test Loss: 1.4113059043884277 Test Accuracy: 0.5072463750839233


100%|██████████| 29/29 [00:02<00:00, 11.68it/s]
100%|██████████| 4/4 [00:00<00:00,  9.46it/s]


Epoch: 8 Training Loss: 11.284646302461624 Training Accuracy: 0.5116657614707947 Test Loss: 1.495109736919403 Test Accuracy: 0.49637681245803833


100%|██████████| 29/29 [00:02<00:00, 12.06it/s]
100%|██████████| 4/4 [00:00<00:00, 13.42it/s]


Epoch: 9 Training Loss: 10.787398785352707 Training Accuracy: 0.5210707187652588 Test Loss: 1.2415013760328293 Test Accuracy: 0.5398550629615784


100%|██████████| 29/29 [00:02<00:00, 12.77it/s]
100%|██████████| 4/4 [00:00<00:00, 12.80it/s]


Epoch: 10 Training Loss: 10.22880694270134 Training Accuracy: 0.5395188927650452 Test Loss: 1.2313932180404663 Test Accuracy: 0.5525362491607666


100%|██████████| 29/29 [00:02<00:00, 12.96it/s]
100%|██████████| 4/4 [00:00<00:00, 12.56it/s]


Epoch: 11 Training Loss: 9.765019744634628 Training Accuracy: 0.5608609318733215 Test Loss: 1.1571755409240723 Test Accuracy: 0.5615941882133484


100%|██████████| 29/29 [00:02<00:00, 12.60it/s]
100%|██████████| 4/4 [00:00<00:00, 13.84it/s]


Epoch: 12 Training Loss: 9.277985289692879 Training Accuracy: 0.5653825402259827 Test Loss: 1.1532021164894104 Test Accuracy: 0.5670289993286133


100%|██████████| 29/29 [00:02<00:00, 13.19it/s]
100%|██████████| 4/4 [00:00<00:00, 11.74it/s]


Epoch: 13 Training Loss: 8.85831581056118 Training Accuracy: 0.5939591526985168 Test Loss: 1.073831096291542 Test Accuracy: 0.5960144996643066


100%|██████████| 29/29 [00:02<00:00, 12.93it/s]
100%|██████████| 4/4 [00:00<00:00, 12.35it/s]


Epoch: 14 Training Loss: 8.514876186847687 Training Accuracy: 0.6080665588378906 Test Loss: 1.0374141335487366 Test Accuracy: 0.6123188138008118


100%|██████████| 29/29 [00:02<00:00, 12.77it/s]
100%|██████████| 4/4 [00:00<00:00, 13.36it/s]


Epoch: 15 Training Loss: 8.178459718823433 Training Accuracy: 0.6214505434036255 Test Loss: 0.9112862944602966 Test Accuracy: 0.6286231875419617


100%|██████████| 29/29 [00:02<00:00, 12.98it/s]
100%|██████████| 4/4 [00:00<00:00, 13.02it/s]


Epoch: 16 Training Loss: 7.890336871147156 Training Accuracy: 0.633387565612793 Test Loss: 0.9736251980066299 Test Accuracy: 0.6557971239089966


100%|██████████| 29/29 [00:02<00:00, 12.95it/s]
100%|██████████| 4/4 [00:00<00:00, 13.32it/s]


Epoch: 17 Training Loss: 7.656392619013786 Training Accuracy: 0.6404412984848022 Test Loss: 0.8182681947946548 Test Accuracy: 0.6666666865348816


100%|██████████| 29/29 [00:02<00:00, 12.55it/s]
100%|██████████| 4/4 [00:00<00:00, 14.10it/s]


Epoch: 18 Training Loss: 7.175373286008835 Training Accuracy: 0.6648580431938171 Test Loss: 0.8907041102647781 Test Accuracy: 0.6666666865348816


100%|██████████| 29/29 [00:02<00:00, 12.98it/s]
100%|██████████| 4/4 [00:00<00:00, 13.10it/s]


Epoch: 19 Training Loss: 7.184923782944679 Training Accuracy: 0.6594321131706238 Test Loss: 0.9066859036684036 Test Accuracy: 0.6757246255874634


100%|██████████| 29/29 [00:02<00:00, 12.70it/s]
100%|██████████| 4/4 [00:00<00:00, 13.05it/s]


Epoch: 20 Training Loss: 7.240866914391518 Training Accuracy: 0.6606981158256531 Test Loss: 1.0580588430166245 Test Accuracy: 0.679347813129425


100%|██████████| 29/29 [00:02<00:00, 12.76it/s]
100%|██████████| 4/4 [00:00<00:00, 12.48it/s]


Epoch: 21 Training Loss: 6.729033187031746 Training Accuracy: 0.685114860534668 Test Loss: 0.8202158361673355 Test Accuracy: 0.7137681245803833


100%|██████████| 29/29 [00:02<00:00, 12.74it/s]
100%|██████████| 4/4 [00:00<00:00, 12.41it/s]


Epoch: 22 Training Loss: 6.196915924549103 Training Accuracy: 0.7095315456390381 Test Loss: 0.748314768075943 Test Accuracy: 0.7246376872062683


100%|██████████| 29/29 [00:02<00:00, 12.95it/s]
100%|██████████| 4/4 [00:00<00:00, 12.61it/s]


Epoch: 23 Training Loss: 5.835441127419472 Training Accuracy: 0.7268945574760437 Test Loss: 0.7451856583356857 Test Accuracy: 0.739130437374115


100%|██████████| 29/29 [00:02<00:00, 12.87it/s]
100%|██████████| 4/4 [00:00<00:00, 13.07it/s]


Epoch: 24 Training Loss: 5.731698855757713 Training Accuracy: 0.7301501035690308 Test Loss: 0.7365870028734207 Test Accuracy: 0.7228260636329651


100%|██████████| 29/29 [00:02<00:00, 12.78it/s]
100%|██████████| 4/4 [00:00<00:00, 12.84it/s]


Epoch: 25 Training Loss: 5.526354745030403 Training Accuracy: 0.7296075224876404 Test Loss: 0.7384300380945206 Test Accuracy: 0.7481883764266968


100%|██████████| 29/29 [00:02<00:00, 12.67it/s]
100%|██████████| 4/4 [00:00<00:00, 13.45it/s]


Epoch: 26 Training Loss: 5.3005111664533615 Training Accuracy: 0.7451618909835815 Test Loss: 0.5821344628930092 Test Accuracy: 0.77173912525177


100%|██████████| 29/29 [00:02<00:00, 12.85it/s]
100%|██████████| 4/4 [00:00<00:00, 13.75it/s]


Epoch: 27 Training Loss: 5.153714179992676 Training Accuracy: 0.7536625266075134 Test Loss: 0.6297977864742279 Test Accuracy: 0.7572463750839233


100%|██████████| 29/29 [00:02<00:00, 12.73it/s]
100%|██████████| 4/4 [00:00<00:00, 13.95it/s]


Epoch: 28 Training Loss: 4.818244755268097 Training Accuracy: 0.7645143866539001 Test Loss: 0.5325523167848587 Test Accuracy: 0.7771739363670349


100%|██████████| 29/29 [00:02<00:00, 12.85it/s]
100%|██████████| 4/4 [00:00<00:00, 13.12it/s]


Epoch: 29 Training Loss: 4.497244276106358 Training Accuracy: 0.7822390794754028 Test Loss: 0.4644264504313469 Test Accuracy: 0.8061594367027283


100%|██████████| 29/29 [00:02<00:00, 12.78it/s]
100%|██████████| 4/4 [00:00<00:00, 11.82it/s]


Epoch: 30 Training Loss: 4.192795678973198 Training Accuracy: 0.7990595102310181 Test Loss: 0.4667748138308525 Test Accuracy: 0.8170289993286133


100%|██████████| 29/29 [00:02<00:00, 12.95it/s]
100%|██████████| 4/4 [00:00<00:00, 13.18it/s]


Epoch: 31 Training Loss: 4.004155166447163 Training Accuracy: 0.8100922703742981 Test Loss: 0.5621817260980606 Test Accuracy: 0.8297101259231567


100%|██████████| 29/29 [00:02<00:00, 13.17it/s]
100%|██████████| 4/4 [00:00<00:00, 12.97it/s]


Epoch: 32 Training Loss: 3.766948379576206 Training Accuracy: 0.8167842030525208 Test Loss: 0.5449516102671623 Test Accuracy: 0.8442028760910034


100%|██████████| 29/29 [00:02<00:00, 13.18it/s]
100%|██████████| 4/4 [00:00<00:00, 13.14it/s]


Epoch: 33 Training Loss: 3.745029255747795 Training Accuracy: 0.8135286569595337 Test Loss: 0.43658923357725143 Test Accuracy: 0.8496376872062683


100%|██████████| 29/29 [00:02<00:00, 13.18it/s]
100%|██████████| 4/4 [00:00<00:00, 12.85it/s]


Epoch: 34 Training Loss: 3.4139687307178974 Training Accuracy: 0.8289021253585815 Test Loss: 0.43198559433221817 Test Accuracy: 0.8550724387168884


100%|██████████| 29/29 [00:02<00:00, 13.15it/s]
100%|██████████| 4/4 [00:00<00:00, 13.02it/s]


Epoch: 35 Training Loss: 3.3060493394732475 Training Accuracy: 0.833604633808136 Test Loss: 0.44968750327825546 Test Accuracy: 0.8297101259231567


100%|██████████| 29/29 [00:02<00:00, 13.01it/s]
100%|██████████| 4/4 [00:00<00:00, 13.42it/s]


Epoch: 36 Training Loss: 3.31807928532362 Training Accuracy: 0.8317959904670715 Test Loss: 0.40729691833257675 Test Accuracy: 0.8442028760910034


100%|██████████| 29/29 [00:02<00:00, 12.62it/s]
100%|██████████| 4/4 [00:00<00:00, 13.55it/s]


Epoch: 37 Training Loss: 3.0408864244818687 Training Accuracy: 0.846626877784729 Test Loss: 0.34821728616952896 Test Accuracy: 0.8550724387168884


100%|██████████| 29/29 [00:02<00:00, 13.16it/s]
100%|██████████| 4/4 [00:00<00:00, 13.76it/s]


Epoch: 38 Training Loss: 2.7950876764953136 Training Accuracy: 0.8594682812690735 Test Loss: 0.3298104703426361 Test Accuracy: 0.8677536249160767


100%|██████████| 29/29 [00:02<00:00, 12.88it/s]
100%|██████████| 4/4 [00:00<00:00, 13.47it/s]


Epoch: 39 Training Loss: 2.6594934165477753 Training Accuracy: 0.8634472489356995 Test Loss: 0.30419545248150826 Test Accuracy: 0.8713768124580383


100%|██████████| 29/29 [00:02<00:00, 12.89it/s]
100%|██████████| 4/4 [00:00<00:00, 12.14it/s]


Epoch: 40 Training Loss: 2.746246851980686 Training Accuracy: 0.8632664084434509 Test Loss: 0.3538799397647381 Test Accuracy: 0.8731883764266968


100%|██████████| 29/29 [00:02<00:00, 13.22it/s]
100%|██████████| 4/4 [00:00<00:00, 13.87it/s]


Epoch: 41 Training Loss: 2.62623093649745 Training Accuracy: 0.8697775602340698 Test Loss: 0.25195237621665 Test Accuracy: 0.89673912525177


100%|██████████| 29/29 [00:02<00:00, 12.38it/s]
100%|██████████| 4/4 [00:00<00:00, 13.14it/s]


Epoch: 42 Training Loss: 2.1809401847422123 Training Accuracy: 0.892204761505127 Test Loss: 0.23134178668260574 Test Accuracy: 0.8931159377098083


100%|██████████| 29/29 [00:02<00:00, 12.97it/s]
100%|██████████| 4/4 [00:00<00:00, 12.86it/s]


Epoch: 43 Training Loss: 1.9221416637301445 Training Accuracy: 0.9130041599273682 Test Loss: 0.2636130861938 Test Accuracy: 0.9166666865348816


100%|██████████| 29/29 [00:02<00:00, 12.71it/s]
100%|██████████| 4/4 [00:00<00:00, 12.36it/s]


Epoch: 44 Training Loss: 1.7475877851247787 Training Accuracy: 0.9169831871986389 Test Loss: 0.18385760113596916 Test Accuracy: 0.945652186870575


100%|██████████| 29/29 [00:02<00:00, 12.96it/s]
100%|██████████| 4/4 [00:00<00:00, 13.38it/s]


Epoch: 45 Training Loss: 1.552013697102666 Training Accuracy: 0.9336227178573608 Test Loss: 0.17230243608355522 Test Accuracy: 0.9438405632972717


100%|██████████| 29/29 [00:02<00:00, 12.91it/s]
100%|██████████| 4/4 [00:00<00:00, 12.62it/s]


Epoch: 46 Training Loss: 1.3388654943555593 Training Accuracy: 0.9455597996711731 Test Loss: 0.17177767306566238 Test Accuracy: 0.9420289993286133


100%|██████████| 29/29 [00:02<00:00, 13.05it/s]
100%|██████████| 4/4 [00:00<00:00, 12.98it/s]


Epoch: 47 Training Loss: 1.3090003356337547 Training Accuracy: 0.9509857296943665 Test Loss: 0.1702333316206932 Test Accuracy: 0.9420289993286133


100%|██████████| 29/29 [00:02<00:00, 12.89it/s]
100%|██████████| 4/4 [00:00<00:00, 13.23it/s]


Epoch: 48 Training Loss: 1.3711021598428488 Training Accuracy: 0.9399529695510864 Test Loss: 0.1349223591387272 Test Accuracy: 0.9619565010070801


100%|██████████| 29/29 [00:02<00:00, 12.95it/s]
100%|██████████| 4/4 [00:00<00:00, 13.15it/s]


Epoch: 49 Training Loss: 1.1047004032880068 Training Accuracy: 0.958039402961731 Test Loss: 0.10266529954969883 Test Accuracy: 0.97826087474823


100%|██████████| 29/29 [00:02<00:00, 13.14it/s]
100%|██████████| 4/4 [00:00<00:00, 12.99it/s]

Epoch: 50 Training Loss: 0.8827230539172888 Training Accuracy: 0.9717851281166077 Test Loss: 0.12204395607113838 Test Accuracy: 0.9710144996643066





In [17]:
#adapted from https://torchtutorialstaging.z5.web.core.windows.net/beginner/translation_transformer.html
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
  src_seq_len = src.shape[0]
  tgt_seq_len = tgt.shape[0]

  tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
  src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)

  src_padding_mask = (src == 0).transpose(0, 1)
  tgt_padding_mask = (tgt == 0).transpose(0, 1)
  return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

class TransformerModel(nn.Module):
    def __init__(self, num_encoder_layers, nhead, num_decoder_layers,
                 emb_size, src_vocab_size, tgt_vocab_size,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(TransformerModel, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.emb_size = emb_size
        self.src_tok_emb = self.embedding = nn.Embedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = self.embedding = nn.Embedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src, trg, src_mask,
                tgt_mask, src_padding_mask,
                tgt_padding_mask, memory_key_padding_mask):
        src_emb = self.positional_encoding(self.src_tok_emb(src)* math.sqrt(self.emb_size))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg)* math.sqrt(self.emb_size))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

In [29]:
model = TransformerModel(num_encoder_layers=6, nhead=8, num_decoder_layers=6,
                 emb_size=512, src_vocab_size=(len(d.get_alphabet()) + 1), tgt_vocab_size=(len(d.get_alphabet()) + 1),
                 dim_feedforward = 512, dropout = 0.2).to(device)

In [19]:
def train_epoch_transformer(model, train_loader, optimizer, criterion, batch_size):
  model.train()
  total_loss = 0
  num_correct = 0
  total_items = 0
  for src, tgt in tqdm(train_loader):
      src = src.to(device).T
      tgt = tgt.to(device).T

      tgt_input = tgt[:-1, :]

      src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

      logits = model(src, tgt_input, src_mask, tgt_mask,
                                src_padding_mask, tgt_padding_mask, src_padding_mask)

      optimizer.zero_grad()

      tgt_out = tgt[1:,:]
      loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

      loss.backward()
      optimizer.step()

      total_loss += loss.item()
      total_items += (tgt_out != 0).sum(dim=(0,1))

      num_correct += (torch.logical_and((logits.argmax(dim=2) == tgt_out), (tgt_out != 0))).sum(dim=(0,1))
  return total_loss / len(train_loader), num_correct / total_items

In [20]:
def test_epoch_transformer(model, test_loader, criterion, batch_size):
  model.eval()
  total_loss = 0
  num_correct = 0
  total_items = 0
  for src, tgt in tqdm(train_loader):
      src = src.to(device).T
      tgt = tgt.to(device).T

      tgt_input = tgt[:-1, :]

      src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

      logits = model(src, tgt_input, src_mask, tgt_mask,
                                src_padding_mask, tgt_padding_mask, src_padding_mask)


      tgt_out = tgt[1:,:]
      loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))


      total_loss += loss.item()
      total_items += (tgt_out != 0).sum(dim=(0,1))

      num_correct += (torch.logical_and((logits.argmax(dim=2) == tgt_out), (tgt_out != 0))).sum(dim=(0,1))
  return total_loss / len(train_loader), num_correct / total_items

In [27]:
def train_transformer(model, train_dataset, test_dataset, batch_size=32, epochs=60):
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
  criterion = nn.CrossEntropyLoss()
  
  optim = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
  for e in range(epochs):
    train_loss, train_acc = train_epoch_transformer(model, train_loader, optim, criterion, batch_size=batch_size)
    test_loss, test_acc = train_epoch_transformer(model, test_loader, optim, criterion, batch_size=batch_size)
    print(f'Epoch: {e + 1} Training Loss: {train_loss} Training Accuracy: {train_acc} Test Loss: {test_loss} Test Accuracy: {test_acc}')

In [30]:
train_transformer(model,train_dataset, test_dataset)

100%|██████████| 29/29 [00:03<00:00,  7.36it/s]
100%|██████████| 4/4 [00:00<00:00,  8.20it/s]


Epoch: 1 Training Loss: 1.455214222957348 Training Accuracy: 0.0005425935960374773 Test Loss: 0.826636791229248 Test Accuracy: 0.0


100%|██████████| 29/29 [00:03<00:00,  7.84it/s]
100%|██████████| 4/4 [00:00<00:00,  6.78it/s]


Epoch: 2 Training Loss: 0.6331408054664217 Training Accuracy: 0.22065472602844238 Test Loss: 0.6608844324946404 Test Accuracy: 0.36775362491607666


100%|██████████| 29/29 [00:04<00:00,  7.17it/s]
100%|██████████| 4/4 [00:00<00:00,  6.68it/s]


Epoch: 3 Training Loss: 0.5160817555312452 Training Accuracy: 0.3798155188560486 Test Loss: 0.4616780951619148 Test Accuracy: 0.4384058117866516


100%|██████████| 29/29 [00:04<00:00,  6.97it/s]
100%|██████████| 4/4 [00:00<00:00,  8.48it/s]


Epoch: 4 Training Loss: 0.4689414357316905 Training Accuracy: 0.45216134190559387 Test Loss: 0.358040913939476 Test Accuracy: 0.4637681245803833


100%|██████████| 29/29 [00:03<00:00,  7.36it/s]
100%|██████████| 4/4 [00:00<00:00,  7.24it/s]


Epoch: 5 Training Loss: 0.42675841471244547 Training Accuracy: 0.48272743821144104 Test Loss: 0.38209368288517 Test Accuracy: 0.5163043737411499


100%|██████████| 29/29 [00:04<00:00,  7.23it/s]
100%|██████████| 4/4 [00:00<00:00,  4.95it/s]


Epoch: 6 Training Loss: 0.39837100382508905 Training Accuracy: 0.5132935643196106 Test Loss: 0.3516445979475975 Test Accuracy: 0.5362318754196167


100%|██████████| 29/29 [00:03<00:00,  7.64it/s]
100%|██████████| 4/4 [00:00<00:00, 11.06it/s]


Epoch: 7 Training Loss: 0.3765170019248436 Training Accuracy: 0.5201663970947266 Test Loss: 0.3149130195379257 Test Accuracy: 0.5452898740768433


100%|██████████| 29/29 [00:04<00:00,  6.76it/s]
100%|██████████| 4/4 [00:00<00:00,  8.09it/s]


Epoch: 8 Training Loss: 0.3486768377238306 Training Accuracy: 0.5532646179199219 Test Loss: 0.29547226801514626 Test Accuracy: 0.5561594367027283


100%|██████████| 29/29 [00:03<00:00,  7.45it/s]
100%|██████████| 4/4 [00:00<00:00,  6.82it/s]


Epoch: 9 Training Loss: 0.33421368948344526 Training Accuracy: 0.5650207996368408 Test Loss: 0.26647259667515755 Test Accuracy: 0.5942028760910034


100%|██████████| 29/29 [00:03<00:00,  7.31it/s]
100%|██████████| 4/4 [00:00<00:00,  7.34it/s]


Epoch: 10 Training Loss: 0.3208088052683863 Training Accuracy: 0.5883523225784302 Test Loss: 0.238606795668602 Test Accuracy: 0.6068840622901917


100%|██████████| 29/29 [00:03<00:00,  7.51it/s]
100%|██████████| 4/4 [00:00<00:00,  8.01it/s]


Epoch: 11 Training Loss: 0.3076090309126624 Training Accuracy: 0.5903418064117432 Test Loss: 0.26060666143894196 Test Accuracy: 0.6177536249160767


100%|██████████| 29/29 [00:04<00:00,  6.93it/s]
100%|██████████| 4/4 [00:00<00:00,  7.94it/s]


Epoch: 12 Training Loss: 0.29449746526520826 Training Accuracy: 0.6082473993301392 Test Loss: 0.29906148463487625 Test Accuracy: 0.5905796885490417


100%|██████████| 29/29 [00:03<00:00,  7.51it/s]
100%|██████████| 4/4 [00:00<00:00,  8.56it/s]


Epoch: 13 Training Loss: 0.2833210836196768 Training Accuracy: 0.6180140972137451 Test Loss: 0.2384313829243183 Test Accuracy: 0.635869562625885


100%|██████████| 29/29 [00:03<00:00,  7.33it/s]
100%|██████████| 4/4 [00:00<00:00,  7.30it/s]


Epoch: 14 Training Loss: 0.27015075704147073 Training Accuracy: 0.6256104111671448 Test Loss: 0.21439559757709503 Test Accuracy: 0.6539855003356934


100%|██████████| 29/29 [00:04<00:00,  6.39it/s]
100%|██████████| 4/4 [00:00<00:00,  7.02it/s]


Epoch: 15 Training Loss: 0.2604855751169139 Training Accuracy: 0.6386326551437378 Test Loss: 0.22344471514225006 Test Accuracy: 0.657608687877655


100%|██████████| 29/29 [00:04<00:00,  7.24it/s]
100%|██████████| 4/4 [00:00<00:00,  7.38it/s]


Epoch: 16 Training Loss: 0.25168907693747816 Training Accuracy: 0.6449629068374634 Test Loss: 0.18574578501284122 Test Accuracy: 0.679347813129425


100%|██████████| 29/29 [00:03<00:00,  7.76it/s]
100%|██████████| 4/4 [00:00<00:00,  7.64it/s]


Epoch: 17 Training Loss: 0.23915083449462365 Training Accuracy: 0.6605172753334045 Test Loss: 0.19017183408141136 Test Accuracy: 0.679347813129425


100%|██████████| 29/29 [00:04<00:00,  7.09it/s]
100%|██████████| 4/4 [00:00<00:00,  6.88it/s]


Epoch: 18 Training Loss: 0.2322501242160797 Training Accuracy: 0.6625067591667175 Test Loss: 0.22895560413599014 Test Accuracy: 0.6865941882133484


100%|██████████| 29/29 [00:03<00:00,  7.66it/s]
100%|██████████| 4/4 [00:00<00:00, 10.46it/s]


Epoch: 19 Training Loss: 0.2318629432341148 Training Accuracy: 0.6688370704650879 Test Loss: 0.1929892711341381 Test Accuracy: 0.66847825050354


100%|██████████| 29/29 [00:03<00:00,  8.56it/s]
100%|██████████| 4/4 [00:00<00:00,  6.71it/s]


Epoch: 20 Training Loss: 0.2221613649664254 Training Accuracy: 0.6852957010269165 Test Loss: 0.22097247838974 Test Accuracy: 0.7101449370384216


100%|██████████| 29/29 [00:03<00:00,  8.19it/s]
100%|██████████| 4/4 [00:00<00:00,  7.62it/s]


Epoch: 21 Training Loss: 0.2170614029826789 Training Accuracy: 0.6847531199455261 Test Loss: 0.21383285894989967 Test Accuracy: 0.7119565010070801


100%|██████████| 29/29 [00:03<00:00,  7.44it/s]
100%|██████████| 4/4 [00:00<00:00,  7.98it/s]


Epoch: 22 Training Loss: 0.20584901653487106 Training Accuracy: 0.7060951590538025 Test Loss: 0.20903633907437325 Test Accuracy: 0.717391312122345


100%|██████████| 29/29 [00:03<00:00,  7.61it/s]
100%|██████████| 4/4 [00:00<00:00,  8.36it/s]


Epoch: 23 Training Loss: 0.19918878119567346 Training Accuracy: 0.7088081240653992 Test Loss: 0.16393477842211723 Test Accuracy: 0.739130437374115


100%|██████████| 29/29 [00:03<00:00,  7.63it/s]
100%|██████████| 4/4 [00:00<00:00,  8.79it/s]


Epoch: 24 Training Loss: 0.18781256213270384 Training Accuracy: 0.72617107629776 Test Loss: 0.16667009517550468 Test Accuracy: 0.760869562625885


100%|██████████| 29/29 [00:04<00:00,  7.22it/s]
100%|██████████| 4/4 [00:00<00:00,  9.05it/s]


Epoch: 25 Training Loss: 0.18187563301160417 Training Accuracy: 0.7353951930999756 Test Loss: 0.14913938380777836 Test Accuracy: 0.7536231875419617


100%|██████████| 29/29 [00:03<00:00,  7.62it/s]
100%|██████████| 4/4 [00:00<00:00,  7.48it/s]


Epoch: 26 Training Loss: 0.17704206448176812 Training Accuracy: 0.7346717119216919 Test Loss: 0.14165825210511684 Test Accuracy: 0.7626811861991882


100%|██████████| 29/29 [00:03<00:00,  7.66it/s]
100%|██████████| 4/4 [00:00<00:00,  9.02it/s]


Epoch: 27 Training Loss: 0.1701118180464054 Training Accuracy: 0.7379273176193237 Test Loss: 0.15204482339322567 Test Accuracy: 0.7626811861991882


100%|██████████| 29/29 [00:03<00:00,  8.14it/s]
100%|██████████| 4/4 [00:00<00:00,  7.62it/s]


Epoch: 28 Training Loss: 0.16411014903208304 Training Accuracy: 0.7527582049369812 Test Loss: 0.19691462256014347 Test Accuracy: 0.7916666865348816


100%|██████████| 29/29 [00:04<00:00,  6.75it/s]
100%|██████████| 4/4 [00:00<00:00,  9.68it/s]


Epoch: 29 Training Loss: 0.16234498650863252 Training Accuracy: 0.755109429359436 Test Loss: 0.12147603370249271 Test Accuracy: 0.79347825050354


100%|██████████| 29/29 [00:04<00:00,  7.02it/s]
100%|██████████| 4/4 [00:00<00:00,  7.04it/s]


Epoch: 30 Training Loss: 0.15305199319946355 Training Accuracy: 0.7661421298980713 Test Loss: 0.11716478504240513 Test Accuracy: 0.8115941882133484


100%|██████████| 29/29 [00:03<00:00,  7.69it/s]
100%|██████████| 4/4 [00:00<00:00,  7.84it/s]


Epoch: 31 Training Loss: 0.1488134534708385 Training Accuracy: 0.7703020572662354 Test Loss: 0.12563841044902802 Test Accuracy: 0.8134058117866516


100%|██████████| 29/29 [00:03<00:00,  7.55it/s]
100%|██████████| 4/4 [00:00<00:00,  8.45it/s]


Epoch: 32 Training Loss: 0.1488193928681571 Training Accuracy: 0.7741001844406128 Test Loss: 0.11440369673073292 Test Accuracy: 0.8351449370384216


100%|██████████| 29/29 [00:03<00:00,  8.70it/s]
100%|██████████| 4/4 [00:00<00:00,  8.58it/s]


Epoch: 33 Training Loss: 0.13788407655625506 Training Accuracy: 0.7883884906768799 Test Loss: 0.11826138012111187 Test Accuracy: 0.7989130616188049


100%|██████████| 29/29 [00:04<00:00,  6.60it/s]
100%|██████████| 4/4 [00:00<00:00,  7.34it/s]


Epoch: 34 Training Loss: 0.1401812518978941 Training Accuracy: 0.7869415879249573 Test Loss: 0.11853856965899467 Test Accuracy: 0.820652186870575


100%|██████████| 29/29 [00:04<00:00,  7.19it/s]
100%|██████████| 4/4 [00:00<00:00,  7.04it/s]


Epoch: 35 Training Loss: 0.13890256064719167 Training Accuracy: 0.7999638319015503 Test Loss: 0.103473836556077 Test Accuracy: 0.820652186870575


100%|██████████| 29/29 [00:03<00:00,  8.12it/s]
100%|██████████| 4/4 [00:00<00:00, 12.26it/s]


Epoch: 36 Training Loss: 0.14504997586381846 Training Accuracy: 0.7853137850761414 Test Loss: 0.10095429420471191 Test Accuracy: 0.7916666865348816


100%|██████████| 29/29 [00:02<00:00, 11.83it/s]
100%|██████████| 4/4 [00:00<00:00, 11.34it/s]


Epoch: 37 Training Loss: 0.1256734006877603 Training Accuracy: 0.8043045997619629 Test Loss: 0.09001941047608852 Test Accuracy: 0.8568840622901917


100%|██████████| 29/29 [00:02<00:00, 12.05it/s]
100%|██████████| 4/4 [00:00<00:00, 10.68it/s]


Epoch: 38 Training Loss: 0.118595277440959 Training Accuracy: 0.82383793592453 Test Loss: 0.10015655495226383 Test Accuracy: 0.8297101259231567


100%|██████████| 29/29 [00:02<00:00, 11.69it/s]
100%|██████████| 4/4 [00:00<00:00, 11.54it/s]


Epoch: 39 Training Loss: 0.10955288697933328 Training Accuracy: 0.8245614171028137 Test Loss: 0.08481330052018166 Test Accuracy: 0.8623188138008118


100%|██████████| 29/29 [00:02<00:00, 12.00it/s]
100%|██████████| 4/4 [00:00<00:00, 11.48it/s]


Epoch: 40 Training Loss: 0.10408659818871267 Training Accuracy: 0.8435521721839905 Test Loss: 0.08129930961877108 Test Accuracy: 0.8659420013427734


100%|██████████| 29/29 [00:02<00:00, 11.51it/s]
100%|██████████| 4/4 [00:00<00:00, 11.55it/s]


Epoch: 41 Training Loss: 0.10132470971037602 Training Accuracy: 0.8482546806335449 Test Loss: 0.08962256461381912 Test Accuracy: 0.8822463750839233


100%|██████████| 29/29 [00:03<00:00,  8.60it/s]
100%|██████████| 4/4 [00:00<00:00,  9.70it/s]


Epoch: 42 Training Loss: 0.10397960245609283 Training Accuracy: 0.8442756533622742 Test Loss: 0.08576568961143494 Test Accuracy: 0.8568840622901917


100%|██████████| 29/29 [00:02<00:00, 10.96it/s]
100%|██████████| 4/4 [00:00<00:00, 10.78it/s]


Epoch: 43 Training Loss: 0.10139972793644872 Training Accuracy: 0.8495206832885742 Test Loss: 0.08576498925685883 Test Accuracy: 0.8586956262588501


100%|██████████| 29/29 [00:02<00:00, 10.66it/s]
100%|██████████| 4/4 [00:00<00:00, 11.15it/s]


Epoch: 44 Training Loss: 0.09904626654139881 Training Accuracy: 0.8506059050559998 Test Loss: 0.07154860347509384 Test Accuracy: 0.89673912525177


100%|██████████| 29/29 [00:03<00:00,  9.22it/s]
100%|██████████| 4/4 [00:00<00:00,  9.37it/s]


Epoch: 45 Training Loss: 0.08841892704367638 Training Accuracy: 0.8686923384666443 Test Loss: 0.10076631791889668 Test Accuracy: 0.9166666865348816


100%|██████████| 29/29 [00:02<00:00, 10.91it/s]
100%|██████████| 4/4 [00:00<00:00, 11.94it/s]


Epoch: 46 Training Loss: 0.08973115194460442 Training Accuracy: 0.8601917028427124 Test Loss: 0.07064413372427225 Test Accuracy: 0.9003623127937317


100%|██████████| 29/29 [00:02<00:00, 11.22it/s]
100%|██████████| 4/4 [00:00<00:00, 10.57it/s]


Epoch: 47 Training Loss: 0.0850652033655808 Training Accuracy: 0.8641707301139832 Test Loss: 0.0659559490159154 Test Accuracy: 0.9039855003356934


100%|██████████| 29/29 [00:02<00:00, 11.64it/s]
100%|██████████| 4/4 [00:00<00:00, 11.26it/s]


Epoch: 48 Training Loss: 0.0819415944660532 Training Accuracy: 0.8705009818077087 Test Loss: 0.05828547477722168 Test Accuracy: 0.8949275612831116


100%|██████████| 29/29 [00:02<00:00, 11.99it/s]
100%|██████████| 4/4 [00:00<00:00, 11.94it/s]


Epoch: 49 Training Loss: 0.07542850411143796 Training Accuracy: 0.887321412563324 Test Loss: 0.05291798524558544 Test Accuracy: 0.9130434989929199


100%|██████████| 29/29 [00:02<00:00, 11.78it/s]
100%|██████████| 4/4 [00:00<00:00, 12.19it/s]


Epoch: 50 Training Loss: 0.07204312381559405 Training Accuracy: 0.8934707641601562 Test Loss: 0.07510412763804197 Test Accuracy: 0.9057971239089966


100%|██████████| 29/29 [00:02<00:00, 11.83it/s]
100%|██████████| 4/4 [00:00<00:00, 11.81it/s]


Epoch: 51 Training Loss: 0.07371213120119326 Training Accuracy: 0.8900343775749207 Test Loss: 0.062376356683671474 Test Accuracy: 0.9202898740768433


100%|██████████| 29/29 [00:02<00:00, 12.14it/s]
100%|██████████| 4/4 [00:00<00:00, 11.42it/s]


Epoch: 52 Training Loss: 0.06792061287781288 Training Accuracy: 0.8956411480903625 Test Loss: 0.04890283849090338 Test Accuracy: 0.9148550629615784


100%|██████████| 29/29 [00:02<00:00, 11.78it/s]
100%|██████████| 4/4 [00:00<00:00, 10.38it/s]


Epoch: 53 Training Loss: 0.06084464850096867 Training Accuracy: 0.9139084815979004 Test Loss: 0.042735776863992214 Test Accuracy: 0.9311594367027283


100%|██████████| 29/29 [00:02<00:00, 11.98it/s]
100%|██████████| 4/4 [00:00<00:00, 12.19it/s]


Epoch: 54 Training Loss: 0.05910650589342775 Training Accuracy: 0.9128233194351196 Test Loss: 0.06917157303541899 Test Accuracy: 0.9202898740768433


100%|██████████| 29/29 [00:02<00:00, 11.79it/s]
100%|██████████| 4/4 [00:00<00:00, 11.35it/s]


Epoch: 55 Training Loss: 0.06835682073543811 Training Accuracy: 0.9012479782104492 Test Loss: 0.06457455828785896 Test Accuracy: 0.9112318754196167


100%|██████████| 29/29 [00:02<00:00, 11.95it/s]
100%|██████████| 4/4 [00:00<00:00, 11.76it/s]


Epoch: 56 Training Loss: 0.068217765282968 Training Accuracy: 0.8969072103500366 Test Loss: 0.056218622252345085 Test Accuracy: 0.9221014380455017


100%|██████████| 29/29 [00:02<00:00, 12.08it/s]
100%|██████████| 4/4 [00:00<00:00, 11.83it/s]


Epoch: 57 Training Loss: 0.057248392816761445 Training Accuracy: 0.9146319627761841 Test Loss: 0.03502130974084139 Test Accuracy: 0.9510869383811951


100%|██████████| 29/29 [00:02<00:00, 12.26it/s]
100%|██████████| 4/4 [00:00<00:00, 12.27it/s]


Epoch: 58 Training Loss: 0.049999234234464576 Training Accuracy: 0.926388144493103 Test Loss: 0.03485229657962918 Test Accuracy: 0.9384058117866516


100%|██████████| 29/29 [00:02<00:00, 12.37it/s]
100%|██████████| 4/4 [00:00<00:00, 12.38it/s]


Epoch: 59 Training Loss: 0.04878363674827691 Training Accuracy: 0.9218665361404419 Test Loss: 0.036520480178296566 Test Accuracy: 0.9365941882133484


100%|██████████| 29/29 [00:02<00:00, 12.31it/s]
100%|██████████| 4/4 [00:00<00:00, 12.66it/s]

Epoch: 60 Training Loss: 0.046162863743716274 Training Accuracy: 0.9314523339271545 Test Loss: 0.02858730871230364 Test Accuracy: 0.945652186870575



