In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm
import math

In [2]:
def squash(input):
  squared_norm=(input**2).sum(dim=-1,keepdim=True)
  output=squared_norm*input/((1.+squared_norm)*torch.sqrt(squared_norm))
  return output

In [3]:
def to_one_hot(x,size):
    x_one_hot=torch.zeros(x.size()[0],size)
    for i in range(x.size()[0]):
        x_one_hot[i,x[i]]=1.0
    return x_one_hot

In [4]:
class Conv(nn.Module):
  def __init__(self,input_num_channels=1,output_num_channels=256,k_size=9):
    super(Conv,self).__init__()
    self.convolution=nn.Conv2d(in_channels=input_num_channels,out_channels=output_num_channels,kernel_size=[k_size,k_size],stride=1,bias=True)
  def forward(self,input):
    #input size is (batch_size, 1, 28, 28)
    output=self.convolution(input)
    output=F.relu(output)
    #output size is (batch_size, 256, 20, 20)
    return output

In [5]:
class Primary_Caps(nn.Module):
  def __init__(self,num_caps=32,input_channels=256,output_channels=8,k_size=9):
    super(Primary_Caps,self).__init__()
    self.primary_capsules=nn.ModuleList([nn.Conv2d(in_channels=input_channels, out_channels=output_channels,kernel_size=[k_size,k_size],stride=2,bias=True)for _ in range(num_caps)])
  def forward(self,input):
    #input size is (batch_size, 256, 20, 20)
    output=[p_c(input) for p_c in self.primary_capsules]
    output=torch.stack(output, dim=1)
    #output size is (batch_size, 32, 8, 6, 6)
    output=torch.transpose(output,1,2)
    #output size is (batch_size, 8, 32, 6, 6)
    output=torch.reshape(output,(batch_size,-1,output.size()[1]))
    #output size is (batch_size, 32*6*6, 8)
    output=squash(output)
    return output

In [6]:
class Digit_Caps(nn.Module):
  def __init__(self,num_caps=10,input_size=8,output_size=16,num_routes=1152):
    super(Digit_Caps,self).__init__()
    self.num_caps=num_caps
    self.num_routes=num_routes
    self.W=nn.Parameter(torch.randn((batch_size,num_routes,num_caps,output_size,input_size)))
  def forward(self,input):
    #input size is (batch_size, 32*6*6, 8)
    #W size is (batch_size, 32*6*6, 10, 16, 8) 
    input=torch.stack([input] * self.num_caps, dim=2).unsqueeze(4)
    #input size is (batch_size, 32*6*6, 10, 8, 1)
    u_hat=torch.matmul(self.W,input)
    #u_hat size is (batch_size, 32*6*6, 10, 16, 1)
    b=Variable(torch.zeros(batch_size,self.num_routes,self.num_caps,1))
    #b size is (batch_size, 32*6*6, 10, 1)
    b=b.cuda()
    for i in range(0,3,1):
      c=F.softmax(b,dim=1).unsqueeze(4)
      #c size is (batch_size, 32*6*6, 10, 1, 1)
      s=c*u_hat
      #s size is (batch_size, 32*6*6, 10, 16, 1)
      s=s.sum(dim=1 ,keepdim=True)
      #s size is (batch_size, 1, 10, 16, 1)
      v=squash(s)
      #v size is (batch_size, 1, 10, 16, 1)
      if i<2:
        a=torch.matmul(u_hat.transpose(3, 4), torch.cat([v] * self.num_routes, dim=1))
        #a size is (batch_size, 32*6*6, 10, 1, 1)
        b=b+a.squeeze(4).mean(dim=0, keepdim=True)
        #b size is (batch_size, 32*6*6, 10, 1)
    return v.squeeze(1)

In [7]:
class Decoder(nn.Module):
  def __init__(self,input_size=160,decoder1_size=512,decoder2_size=1024,output_size=784,img_height=28,img_width=28,img_channels=1):
    super(Decoder,self).__init__()
    self.num_channels=img_channels
    self.height=img_height
    self.width=img_width
    self.linear1=nn.Linear(in_features=input_size,out_features=decoder1_size,bias=True)
    self.linear2=nn.Linear(in_features=decoder1_size,out_features=decoder2_size,bias=True)
    self.linear3=nn.Linear(in_features=decoder2_size,out_features=output_size,bias=True)
  def forward(self,input):
    #input size is (batch_size, 10, 16, 1)
    classes=torch.sqrt((input**2).sum(2))
    classes=F.softmax(classes, dim=0)
    #classes size is (batch_size, 10, 1)
    _,max_length_indices=classes.max(dim=1)
    #max_length_indices size is (batch_size, 1)
    masked=Variable(torch.sparse.torch.eye(10))
    masked=masked.cuda()
    masked=masked.index_select(dim=0,index=Variable(max_length_indices.squeeze(1).data))
    #masked size is (batch_size, 10)
    input=(input*masked[:,:,None,None]).view(input.size(0),-1)
    #input size is (batch_size, 160)
    output=self.linear1(input)
    output=F.relu(output)
    #output size is (batch_size, 512)
    output=self.linear2(output)
    output=F.relu(output)
    #output size is (batch_size, 1024)
    output=self.linear3(output)
    output=F.relu(output)
    #output size is (batch_size, 784)
    output=torch.reshape(output,(batch_size,self.num_channels,self.height,self.width))
    #output size is (batch_size, 1, 28, 28)
    return output, masked

In [8]:
class CapsNet(nn.Module):
  def __init__(self):
    super(CapsNet,self).__init__()
    self.conv=Conv()
    self.p_cap=Primary_Caps()
    self.d_cap=Digit_Caps()
    self.dec=Decoder()
  def forward(self,input):
    out_conv=self.conv(input)
    out_p_c=self.p_cap(out_conv)
    output=self.d_cap(out_p_c)
    rec,mask=self.dec(output)
    return output,rec,mask
  def margin_loss(self,output,true_output):
    Lambda=0.5
    m_plus=0.9
    m_minus=0.1
    squared_norm=(output**2).sum(2,keepdim=True)
    norm=torch.sqrt(squared_norm)
    term1=F.relu(m_plus-norm).view(batch_size,-1)
    term2=F.relu(norm-m_minus).view(batch_size,-1)
    loss=true_output*term1+Lambda*(1.0-true_output)*term2
    loss=loss.sum(-1).mean()
    return loss
  def reconstruction_loss(self,input,rec):
    mse=nn.MSELoss()
    input=torch.reshape(input,(input.size()[0],-1))
    rec=torch.reshape(rec,(rec.size()[0],-1))
    loss=0.005*mse(input,rec)
    return loss
  def loss(self,input,output,true_output,rec):
    loss_m_l=self.margin_loss(output,true_output)
    loss_r_l=self.reconstruction_loss(input,rec)
    total_loss=loss_m_l+loss_r_l
    return total_loss

In [9]:
def train(capsnet,optimizer,train_loader,epoch,save_path='capsnet'):
    capsnet.train()
    num_batches=len(list(enumerate(train_loader)))
    total_loss=0
    for batch_idx,(input,label) in enumerate(tqdm(train_loader)):
        target=torch.sparse.torch.eye(10).index_select(dim=0,index=label)
        input,target=Variable(input),Variable(target)
        input,target=input.cuda(),target.cuda()
        output,rec,masked=capsnet(input)
        loss=capsnet.loss(input,output,target,rec)
        train_loss=loss.item()
        if math.isnan(train_loss):
          capsnet=torch.load(save_path)
        else:
          loss.backward()
          optimizer.step()
          total_loss+=train_loss
          #torch.save(capsnet,save_path)
        correct=sum(np.argmax(masked.data.cpu().numpy(),1)==np.argmax(target.data.cpu().numpy(),1))
        #if math.isnan(train_loss):
          #print("loading previous best model")
          #capsnet=torch.load(save_path)
        if batch_idx%100==0:
            optimizer.zero_grad()
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,num_epochs,batch_idx+1,num_batches,correct/float(batch_size),train_loss/float(batch_size)))
            if not (math.isnan(train_loss)):
              torch.save(capsnet,save_path)
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch,num_epochs,total_loss/len(train_loader.dataset)))
    #if not (math.isnan(total_loss)):
      #torch.save(capsnet,save_path)

In [10]:
def test(capsnet,test_loader,epoch):
    capsnet
    capsnet.eval()
    test_loss=0
    correct=0
    for batch_idx,(input,label) in enumerate(test_loader):
        target=torch.sparse.torch.eye(10).index_select(dim=0,index=label)
        input,target=Variable(input),Variable(target)
        input,target=input.cuda(),target.cuda()
        output,rec,masked=capsnet(input)
        loss=capsnet.loss(input,output,target,rec)
        test_loss+=loss.item()
        correct+=sum(np.argmax(masked.data.cpu().numpy(),1)==np.argmax(target.data.cpu().numpy(),1))
    tqdm.write("Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(epoch,num_epochs,correct/len(test_loader.dataset),test_loss/len(test_loader)))

In [11]:
data_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,),(0.3081,))
                ])
batch_size=128
num_epochs=30
capsnet=CapsNet()
capsnet=torch.nn.DataParallel(capsnet)
train_loader=torch.utils.data.DataLoader(datasets.MNIST("mnist",train=True,download=True,transform=data_transform),batch_size=batch_size,shuffle=True,drop_last=True)
test_loader=torch.utils.data.DataLoader(datasets.MNIST("mnist",train=False,download=True,transform=data_transform),batch_size=batch_size,shuffle=True,drop_last=True)
capsnet=capsnet.cuda()
capsnet=capsnet.module
optimizer=torch.optim.Adam(capsnet.parameters(),lr=1e-3)
for e in range(1,num_epochs+1):
  train(capsnet,optimizer,train_loader,e)
  test(capsnet,test_loader,e)
  torch.cuda.empty_cache()

  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [1/30], Batch: [1/468], train accuracy: 0.125000, loss: 0.007060


 21%|██▏       | 100/468 [00:55<03:03,  2.00it/s]

Epoch: [1/30], Batch: [101/468], train accuracy: 0.117188, loss: 0.006719


 43%|████▎     | 200/468 [01:49<02:15,  1.97it/s]

Epoch: [1/30], Batch: [201/468], train accuracy: 0.265625, loss: 0.006535


 64%|██████▍   | 300/468 [02:44<01:28,  1.90it/s]

Epoch: [1/30], Batch: [301/468], train accuracy: 0.320312, loss: 0.006262


 85%|████████▌ | 400/468 [03:40<00:34,  1.94it/s]

Epoch: [1/30], Batch: [401/468], train accuracy: 0.468750, loss: 0.005886


100%|██████████| 468/468 [04:19<00:00,  1.80it/s]


Epoch: [1/30], train loss: 0.006418
Epoch: [1/30], test accuracy: 0.527300, loss: nan


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [2/30], Batch: [1/468], train accuracy: 0.492188, loss: 0.005852


 21%|██▏       | 100/468 [00:56<03:14,  1.89it/s]

Epoch: [2/30], Batch: [101/468], train accuracy: 0.492188, loss: 0.005417


 43%|████▎     | 200/468 [01:52<02:18,  1.93it/s]

Epoch: [2/30], Batch: [201/468], train accuracy: 0.726562, loss: 0.004289


 64%|██████▍   | 300/468 [02:48<01:27,  1.92it/s]

Epoch: [2/30], Batch: [301/468], train accuracy: 0.757812, loss: 0.003555


 85%|████████▌ | 400/468 [03:44<00:35,  1.89it/s]

Epoch: [2/30], Batch: [401/468], train accuracy: 0.773438, loss: 0.003454


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [2/30], train loss: nan
Epoch: [2/30], test accuracy: 0.811800, loss: 0.437063


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [3/30], Batch: [1/468], train accuracy: 0.828125, loss: 0.003393


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [3/30], Batch: [101/468], train accuracy: 0.843750, loss: 0.003386


 43%|████▎     | 200/468 [01:52<02:19,  1.91it/s]

Epoch: [3/30], Batch: [201/468], train accuracy: 0.867188, loss: 0.002887


 64%|██████▍   | 300/468 [02:49<01:27,  1.91it/s]

Epoch: [3/30], Batch: [301/468], train accuracy: 0.820312, loss: 0.002962


 85%|████████▌ | 400/468 [03:44<00:35,  1.92it/s]

Epoch: [3/30], Batch: [401/468], train accuracy: 0.867188, loss: 0.002820


100%|██████████| 468/468 [04:23<00:00,  1.78it/s]


Epoch: [3/30], train loss: nan
Epoch: [3/30], test accuracy: 0.792600, loss: 0.444561


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [4/30], Batch: [1/468], train accuracy: 0.789062, loss: 0.003393


 21%|██▏       | 100/468 [00:57<03:11,  1.92it/s]

Epoch: [4/30], Batch: [101/468], train accuracy: 0.851562, loss: 0.002859


 43%|████▎     | 200/468 [01:52<02:20,  1.91it/s]

Epoch: [4/30], Batch: [201/468], train accuracy: 0.906250, loss: 0.002457


 64%|██████▍   | 300/468 [02:48<01:27,  1.93it/s]

Epoch: [4/30], Batch: [301/468], train accuracy: 0.882812, loss: 0.002232


 85%|████████▌ | 400/468 [03:46<00:36,  1.88it/s]

Epoch: [4/30], Batch: [401/468], train accuracy: 0.921875, loss: 0.001951


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [4/30], train loss: 0.002547
Epoch: [4/30], test accuracy: 0.908500, loss: 0.225806


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [5/30], Batch: [1/468], train accuracy: 0.937500, loss: 0.001537


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [5/30], Batch: [101/468], train accuracy: 0.945312, loss: 0.001602


 43%|████▎     | 200/468 [01:53<02:21,  1.89it/s]

Epoch: [5/30], Batch: [201/468], train accuracy: 0.898438, loss: 0.001475


 64%|██████▍   | 300/468 [02:49<01:27,  1.92it/s]

Epoch: [5/30], Batch: [301/468], train accuracy: 0.906250, loss: 0.001683


 85%|████████▌ | 400/468 [03:45<00:35,  1.93it/s]

Epoch: [5/30], Batch: [401/468], train accuracy: 0.945312, loss: 0.001428


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [5/30], train loss: 0.001533
Epoch: [5/30], test accuracy: 0.935500, loss: 0.175530


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [6/30], Batch: [1/468], train accuracy: 0.937500, loss: 0.001355


 21%|██▏       | 100/468 [00:56<03:11,  1.93it/s]

Epoch: [6/30], Batch: [101/468], train accuracy: 0.890625, loss: 0.001776


 43%|████▎     | 200/468 [01:52<02:18,  1.93it/s]

Epoch: [6/30], Batch: [201/468], train accuracy: 0.921875, loss: 0.001348


 64%|██████▍   | 300/468 [02:48<01:30,  1.86it/s]

Epoch: [6/30], Batch: [301/468], train accuracy: 0.914062, loss: 0.001430


 85%|████████▌ | 400/468 [03:45<00:35,  1.89it/s]

Epoch: [6/30], Batch: [401/468], train accuracy: 0.898438, loss: 0.001468


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [6/30], train loss: 0.001297
Epoch: [6/30], test accuracy: 0.948600, loss: 0.126327


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [7/30], Batch: [1/468], train accuracy: 0.945312, loss: 0.000989


 21%|██▏       | 100/468 [00:57<03:16,  1.88it/s]

Epoch: [7/30], Batch: [101/468], train accuracy: 0.992188, loss: 0.000772


 43%|████▎     | 200/468 [01:54<02:21,  1.89it/s]

Epoch: [7/30], Batch: [201/468], train accuracy: 0.914062, loss: 0.001443


 64%|██████▍   | 300/468 [02:51<01:28,  1.90it/s]

Epoch: [7/30], Batch: [301/468], train accuracy: 0.953125, loss: 0.001010


 85%|████████▌ | 400/468 [03:47<00:35,  1.91it/s]

Epoch: [7/30], Batch: [401/468], train accuracy: 0.953125, loss: 0.000849


100%|██████████| 468/468 [04:26<00:00,  1.76it/s]


Epoch: [7/30], train loss: 0.000975
Epoch: [7/30], test accuracy: 0.956800, loss: 0.122947


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [8/30], Batch: [1/468], train accuracy: 0.921875, loss: 0.001226


 21%|██▏       | 100/468 [00:57<03:18,  1.85it/s]

Epoch: [8/30], Batch: [101/468], train accuracy: 0.953125, loss: 0.000900


 43%|████▎     | 200/468 [01:53<02:18,  1.93it/s]

Epoch: [8/30], Batch: [201/468], train accuracy: 0.960938, loss: 0.000987


 64%|██████▍   | 300/468 [02:49<01:27,  1.92it/s]

Epoch: [8/30], Batch: [301/468], train accuracy: 0.953125, loss: 0.000951


 85%|████████▌ | 400/468 [03:45<00:36,  1.85it/s]

Epoch: [8/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000817


100%|██████████| 468/468 [04:25<00:00,  1.77it/s]


Epoch: [8/30], train loss: 0.000888
Epoch: [8/30], test accuracy: 0.961000, loss: 0.098817


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [9/30], Batch: [1/468], train accuracy: 0.992188, loss: 0.000583


 21%|██▏       | 100/468 [00:56<03:09,  1.94it/s]

Epoch: [9/30], Batch: [101/468], train accuracy: 0.937500, loss: 0.001119


 43%|████▎     | 200/468 [01:52<02:23,  1.86it/s]

Epoch: [9/30], Batch: [201/468], train accuracy: 0.953125, loss: 0.001028


 64%|██████▍   | 300/468 [02:48<01:26,  1.94it/s]

Epoch: [9/30], Batch: [301/468], train accuracy: 0.937500, loss: 0.001021


 85%|████████▌ | 400/468 [03:44<00:35,  1.93it/s]

Epoch: [9/30], Batch: [401/468], train accuracy: 0.960938, loss: 0.001053


100%|██████████| 468/468 [04:22<00:00,  1.78it/s]


Epoch: [9/30], train loss: 0.000794
Epoch: [9/30], test accuracy: 0.968200, loss: 0.089165


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [10/30], Batch: [1/468], train accuracy: 0.953125, loss: 0.000906


 21%|██▏       | 100/468 [00:56<03:11,  1.92it/s]

Epoch: [10/30], Batch: [101/468], train accuracy: 0.960938, loss: 0.000889


 43%|████▎     | 200/468 [01:52<02:19,  1.93it/s]

Epoch: [10/30], Batch: [201/468], train accuracy: 0.984375, loss: 0.000579


 64%|██████▍   | 300/468 [02:48<01:28,  1.89it/s]

Epoch: [10/30], Batch: [301/468], train accuracy: 0.968750, loss: 0.000654


 85%|████████▌ | 400/468 [03:45<00:35,  1.93it/s]

Epoch: [10/30], Batch: [401/468], train accuracy: 0.968750, loss: 0.000866


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [10/30], train loss: nan
Epoch: [10/30], test accuracy: 0.957600, loss: 0.124632


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [11/30], Batch: [1/468], train accuracy: 0.937500, loss: 0.001081


 21%|██▏       | 100/468 [00:56<03:12,  1.91it/s]

Epoch: [11/30], Batch: [101/468], train accuracy: 0.968750, loss: 0.000954


 43%|████▎     | 200/468 [01:52<02:18,  1.93it/s]

Epoch: [11/30], Batch: [201/468], train accuracy: 0.960938, loss: 0.001105


 64%|██████▍   | 300/468 [02:48<01:27,  1.92it/s]

Epoch: [11/30], Batch: [301/468], train accuracy: 0.968750, loss: 0.000868


 85%|████████▌ | 400/468 [03:44<00:35,  1.92it/s]

Epoch: [11/30], Batch: [401/468], train accuracy: 0.960938, loss: 0.000656


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [11/30], train loss: 0.000816
Epoch: [11/30], test accuracy: 0.969600, loss: 0.084167


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [12/30], Batch: [1/468], train accuracy: 0.992188, loss: 0.000503


 21%|██▏       | 100/468 [00:56<03:11,  1.92it/s]

Epoch: [12/30], Batch: [101/468], train accuracy: 0.984375, loss: 0.000572


 43%|████▎     | 200/468 [01:51<02:19,  1.93it/s]

Epoch: [12/30], Batch: [201/468], train accuracy: 0.921875, loss: 0.001224


 64%|██████▍   | 300/468 [02:49<01:28,  1.91it/s]

Epoch: [12/30], Batch: [301/468], train accuracy: 0.992188, loss: 0.000402


 85%|████████▌ | 400/468 [03:44<00:35,  1.92it/s]

Epoch: [12/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000617


100%|██████████| 468/468 [04:23<00:00,  1.78it/s]


Epoch: [12/30], train loss: nan
Epoch: [12/30], test accuracy: 0.966300, loss: 0.106453


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [13/30], Batch: [1/468], train accuracy: 0.945312, loss: 0.000968


 21%|██▏       | 100/468 [00:57<03:17,  1.87it/s]

Epoch: [13/30], Batch: [101/468], train accuracy: 0.984375, loss: 0.000748


 43%|████▎     | 200/468 [01:54<02:21,  1.89it/s]

Epoch: [13/30], Batch: [201/468], train accuracy: 0.906250, loss: 0.001370


 64%|██████▍   | 300/468 [02:50<01:28,  1.90it/s]

Epoch: [13/30], Batch: [301/468], train accuracy: 0.968750, loss: 0.000699


 85%|████████▌ | 400/468 [03:47<00:36,  1.89it/s]

Epoch: [13/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000751


100%|██████████| 468/468 [04:26<00:00,  1.75it/s]


Epoch: [13/30], train loss: nan
Epoch: [13/30], test accuracy: 0.966300, loss: 0.121188


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [14/30], Batch: [1/468], train accuracy: 0.968750, loss: 0.000788


 21%|██▏       | 100/468 [00:57<03:14,  1.89it/s]

Epoch: [14/30], Batch: [101/468], train accuracy: 0.953125, loss: 0.001123


 43%|████▎     | 200/468 [01:54<02:21,  1.89it/s]

Epoch: [14/30], Batch: [201/468], train accuracy: 0.937500, loss: 0.000991


 64%|██████▍   | 300/468 [02:51<01:28,  1.90it/s]

Epoch: [14/30], Batch: [301/468], train accuracy: 0.937500, loss: 0.001042


 85%|████████▌ | 400/468 [03:47<00:35,  1.90it/s]

Epoch: [14/30], Batch: [401/468], train accuracy: 0.968750, loss: 0.000624


100%|██████████| 468/468 [04:27<00:00,  1.75it/s]


Epoch: [14/30], train loss: 0.000806
Epoch: [14/30], test accuracy: 0.974900, loss: 0.066330


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [15/30], Batch: [1/468], train accuracy: 0.976562, loss: 0.000518


 21%|██▏       | 100/468 [00:57<03:15,  1.88it/s]

Epoch: [15/30], Batch: [101/468], train accuracy: 0.960938, loss: 0.000492


 43%|████▎     | 200/468 [01:54<02:22,  1.89it/s]

Epoch: [15/30], Batch: [201/468], train accuracy: 0.976562, loss: 0.000622


 64%|██████▍   | 300/468 [02:51<01:28,  1.90it/s]

Epoch: [15/30], Batch: [301/468], train accuracy: 0.984375, loss: 0.000428


 85%|████████▌ | 400/468 [03:47<00:35,  1.89it/s]

Epoch: [15/30], Batch: [401/468], train accuracy: 0.984375, loss: 0.000589


100%|██████████| 468/468 [04:27<00:00,  1.75it/s]


Epoch: [15/30], train loss: 0.000575
Epoch: [15/30], test accuracy: 0.974900, loss: 0.069751


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [16/30], Batch: [1/468], train accuracy: 1.000000, loss: 0.000242


 21%|██▏       | 100/468 [00:57<03:14,  1.89it/s]

Epoch: [16/30], Batch: [101/468], train accuracy: 0.976562, loss: 0.000579


 43%|████▎     | 200/468 [01:54<02:20,  1.91it/s]

Epoch: [16/30], Batch: [201/468], train accuracy: 0.968750, loss: 0.000578


 64%|██████▍   | 300/468 [02:50<01:27,  1.93it/s]

Epoch: [16/30], Batch: [301/468], train accuracy: 0.984375, loss: 0.000501


 85%|████████▌ | 400/468 [03:45<00:35,  1.93it/s]

Epoch: [16/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000527


100%|██████████| 468/468 [04:25<00:00,  1.77it/s]


Epoch: [16/30], train loss: nan
Epoch: [16/30], test accuracy: 0.976600, loss: 0.065796


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [17/30], Batch: [1/468], train accuracy: 0.976562, loss: 0.000525


 21%|██▏       | 100/468 [00:56<03:12,  1.91it/s]

Epoch: [17/30], Batch: [101/468], train accuracy: 0.984375, loss: 0.000516


 43%|████▎     | 200/468 [01:52<02:18,  1.93it/s]

Epoch: [17/30], Batch: [201/468], train accuracy: 0.984375, loss: 0.000475


 64%|██████▍   | 300/468 [02:48<01:26,  1.93it/s]

Epoch: [17/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000473


 85%|████████▌ | 400/468 [03:44<00:35,  1.89it/s]

Epoch: [17/30], Batch: [401/468], train accuracy: 0.984375, loss: 0.000652


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [17/30], train loss: 0.000551
Epoch: [17/30], test accuracy: 0.973000, loss: 0.074616


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [18/30], Batch: [1/468], train accuracy: 0.984375, loss: 0.000520


 21%|██▏       | 100/468 [00:56<03:12,  1.92it/s]

Epoch: [18/30], Batch: [101/468], train accuracy: 0.953125, loss: 0.000928


 43%|████▎     | 200/468 [01:52<02:19,  1.93it/s]

Epoch: [18/30], Batch: [201/468], train accuracy: 0.945312, loss: 0.000896


 64%|██████▍   | 300/468 [02:47<01:27,  1.92it/s]

Epoch: [18/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000682


 85%|████████▌ | 400/468 [03:44<00:36,  1.87it/s]

Epoch: [18/30], Batch: [401/468], train accuracy: 0.992188, loss: 0.000639


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [18/30], train loss: nan
Epoch: [18/30], test accuracy: 0.968300, loss: 0.105283


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [19/30], Batch: [1/468], train accuracy: 0.984375, loss: 0.000697


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [19/30], Batch: [101/468], train accuracy: 0.976562, loss: 0.000499


 43%|████▎     | 200/468 [01:52<02:18,  1.93it/s]

Epoch: [19/30], Batch: [201/468], train accuracy: 0.968750, loss: 0.000593


 64%|██████▍   | 300/468 [02:47<01:27,  1.92it/s]

Epoch: [19/30], Batch: [301/468], train accuracy: 0.968750, loss: 0.000570


 85%|████████▌ | 400/468 [03:45<00:35,  1.90it/s]

Epoch: [19/30], Batch: [401/468], train accuracy: 0.953125, loss: 0.000792


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [19/30], train loss: 0.000586
Epoch: [19/30], test accuracy: 0.976600, loss: 0.060945


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [20/30], Batch: [1/468], train accuracy: 1.000000, loss: 0.000336


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [20/30], Batch: [101/468], train accuracy: 0.976562, loss: 0.000558


 43%|████▎     | 200/468 [01:51<02:18,  1.93it/s]

Epoch: [20/30], Batch: [201/468], train accuracy: 0.953125, loss: 0.000811


 64%|██████▍   | 300/468 [02:48<01:29,  1.87it/s]

Epoch: [20/30], Batch: [301/468], train accuracy: 0.984375, loss: 0.000664


 85%|████████▌ | 400/468 [03:44<00:35,  1.91it/s]

Epoch: [20/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000531


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [20/30], train loss: 0.000557
Epoch: [20/30], test accuracy: 0.976900, loss: 0.060046


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [21/30], Batch: [1/468], train accuracy: 0.984375, loss: 0.000351


 21%|██▏       | 100/468 [00:56<03:10,  1.94it/s]

Epoch: [21/30], Batch: [101/468], train accuracy: 0.976562, loss: 0.000470


 43%|████▎     | 200/468 [01:51<02:18,  1.94it/s]

Epoch: [21/30], Batch: [201/468], train accuracy: 0.984375, loss: 0.000471


 64%|██████▍   | 300/468 [02:49<01:29,  1.88it/s]

Epoch: [21/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000531


 85%|████████▌ | 400/468 [03:45<00:35,  1.91it/s]

Epoch: [21/30], Batch: [401/468], train accuracy: 0.992188, loss: 0.000539


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [21/30], train loss: 0.000470
Epoch: [21/30], test accuracy: 0.981200, loss: 0.051693


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [22/30], Batch: [1/468], train accuracy: 0.992188, loss: 0.000270


 21%|██▏       | 100/468 [00:56<03:09,  1.94it/s]

Epoch: [22/30], Batch: [101/468], train accuracy: 0.992188, loss: 0.000396


 43%|████▎     | 200/468 [01:51<02:20,  1.90it/s]

Epoch: [22/30], Batch: [201/468], train accuracy: 0.968750, loss: 0.000690


 64%|██████▍   | 300/468 [02:49<01:28,  1.90it/s]

Epoch: [22/30], Batch: [301/468], train accuracy: 1.000000, loss: 0.000328


 85%|████████▌ | 400/468 [03:45<00:35,  1.91it/s]

Epoch: [22/30], Batch: [401/468], train accuracy: 0.984375, loss: 0.000419


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [22/30], train loss: 0.000454
Epoch: [22/30], test accuracy: 0.980900, loss: 0.052623


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [23/30], Batch: [1/468], train accuracy: 0.984375, loss: 0.000374


 21%|██▏       | 100/468 [00:55<03:09,  1.94it/s]

Epoch: [23/30], Batch: [101/468], train accuracy: 0.953125, loss: 0.000622


 43%|████▎     | 200/468 [01:52<02:23,  1.86it/s]

Epoch: [23/30], Batch: [201/468], train accuracy: 0.984375, loss: 0.000465


 64%|██████▍   | 300/468 [02:48<01:27,  1.92it/s]

Epoch: [23/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000445


 85%|████████▌ | 400/468 [03:44<00:35,  1.92it/s]

Epoch: [23/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000465


100%|██████████| 468/468 [04:23<00:00,  1.78it/s]


Epoch: [23/30], train loss: nan
Epoch: [23/30], test accuracy: 0.973500, loss: 0.111077


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [24/30], Batch: [1/468], train accuracy: 0.945312, loss: 0.001078


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [24/30], Batch: [101/468], train accuracy: 0.976562, loss: 0.000675


 43%|████▎     | 200/468 [01:53<02:21,  1.89it/s]

Epoch: [24/30], Batch: [201/468], train accuracy: 0.960938, loss: 0.000656


 64%|██████▍   | 300/468 [02:49<01:28,  1.89it/s]

Epoch: [24/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000574


 85%|████████▌ | 400/468 [03:46<00:35,  1.90it/s]

Epoch: [24/30], Batch: [401/468], train accuracy: 0.984375, loss: 0.000492


100%|██████████| 468/468 [04:25<00:00,  1.76it/s]


Epoch: [24/30], train loss: 0.000582
Epoch: [24/30], test accuracy: 0.981200, loss: 0.050320


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [25/30], Batch: [1/468], train accuracy: 0.992188, loss: 0.000422


 21%|██▏       | 100/468 [00:57<03:11,  1.92it/s]

Epoch: [25/30], Batch: [101/468], train accuracy: 0.992188, loss: 0.000456


 43%|████▎     | 200/468 [01:52<02:19,  1.93it/s]

Epoch: [25/30], Batch: [201/468], train accuracy: 0.960938, loss: 0.000619


 64%|██████▍   | 300/468 [02:48<01:27,  1.91it/s]

Epoch: [25/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000429


 85%|████████▌ | 400/468 [03:45<00:35,  1.92it/s]

Epoch: [25/30], Batch: [401/468], train accuracy: 1.000000, loss: 0.000345


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [25/30], train loss: nan
Epoch: [25/30], test accuracy: 0.981900, loss: 0.053330


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [26/30], Batch: [1/468], train accuracy: 0.968750, loss: 0.000611


 21%|██▏       | 100/468 [00:56<03:09,  1.94it/s]

Epoch: [26/30], Batch: [101/468], train accuracy: 0.968750, loss: 0.000625


 43%|████▎     | 200/468 [01:51<02:17,  1.94it/s]

Epoch: [26/30], Batch: [201/468], train accuracy: 0.976562, loss: 0.000474


 64%|██████▍   | 300/468 [02:47<01:29,  1.87it/s]

Epoch: [26/30], Batch: [301/468], train accuracy: 0.984375, loss: 0.000389


 85%|████████▌ | 400/468 [03:44<00:35,  1.91it/s]

Epoch: [26/30], Batch: [401/468], train accuracy: 0.992188, loss: 0.000452


100%|██████████| 468/468 [04:23<00:00,  1.78it/s]


Epoch: [26/30], train loss: 0.000419
Epoch: [26/30], test accuracy: 0.981100, loss: 0.053034


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [27/30], Batch: [1/468], train accuracy: 0.984375, loss: 0.000349


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [27/30], Batch: [101/468], train accuracy: 0.984375, loss: 0.000442


 43%|████▎     | 200/468 [01:51<02:19,  1.93it/s]

Epoch: [27/30], Batch: [201/468], train accuracy: 0.984375, loss: 0.000436


 64%|██████▍   | 300/468 [02:49<01:28,  1.89it/s]

Epoch: [27/30], Batch: [301/468], train accuracy: 0.968750, loss: 0.000708


 85%|████████▌ | 400/468 [03:45<00:35,  1.90it/s]

Epoch: [27/30], Batch: [401/468], train accuracy: 0.968750, loss: 0.000664


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [27/30], train loss: 0.000456
Epoch: [27/30], test accuracy: 0.982600, loss: 0.052301


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [28/30], Batch: [1/468], train accuracy: 0.984375, loss: 0.000489


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [28/30], Batch: [101/468], train accuracy: 0.960938, loss: 0.000470


 43%|████▎     | 200/468 [01:52<02:23,  1.87it/s]

Epoch: [28/30], Batch: [201/468], train accuracy: 0.984375, loss: 0.000511


 64%|██████▍   | 300/468 [02:49<01:27,  1.92it/s]

Epoch: [28/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000431


 85%|████████▌ | 400/468 [03:45<00:35,  1.92it/s]

Epoch: [28/30], Batch: [401/468], train accuracy: 0.984375, loss: 0.000374


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [28/30], train loss: 0.000397
Epoch: [28/30], test accuracy: 0.982600, loss: 0.049967


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [29/30], Batch: [1/468], train accuracy: 0.968750, loss: 0.000397


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [29/30], Batch: [101/468], train accuracy: 0.968750, loss: 0.000515


 43%|████▎     | 200/468 [01:53<02:22,  1.88it/s]

Epoch: [29/30], Batch: [201/468], train accuracy: 0.976562, loss: 0.000448


 64%|██████▍   | 300/468 [02:49<01:28,  1.90it/s]

Epoch: [29/30], Batch: [301/468], train accuracy: 1.000000, loss: 0.000400


 85%|████████▌ | 400/468 [03:45<00:35,  1.92it/s]

Epoch: [29/30], Batch: [401/468], train accuracy: 0.976562, loss: 0.000599


100%|██████████| 468/468 [04:24<00:00,  1.77it/s]


Epoch: [29/30], train loss: nan
Epoch: [29/30], test accuracy: 0.982100, loss: 0.048559


  0%|          | 0/468 [00:00<?, ?it/s]

Epoch: [30/30], Batch: [1/468], train accuracy: 0.960938, loss: 0.000543


 21%|██▏       | 100/468 [00:56<03:10,  1.93it/s]

Epoch: [30/30], Batch: [101/468], train accuracy: 0.976562, loss: 0.000458


 43%|████▎     | 200/468 [01:53<02:20,  1.91it/s]

Epoch: [30/30], Batch: [201/468], train accuracy: 0.976562, loss: 0.000439


 64%|██████▍   | 300/468 [02:49<01:27,  1.91it/s]

Epoch: [30/30], Batch: [301/468], train accuracy: 0.976562, loss: 0.000431


 85%|████████▌ | 400/468 [03:45<00:35,  1.93it/s]

Epoch: [30/30], Batch: [401/468], train accuracy: 0.992188, loss: 0.000359


100%|██████████| 468/468 [04:23<00:00,  1.77it/s]


Epoch: [30/30], train loss: nan
Epoch: [30/30], test accuracy: 0.980900, loss: 0.060096
