In [1]:
import torch 
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import os
import torch.nn as nn
import torchvision.transforms.functional as F
import torch.nn.functional as f
from torch.autograd import Variable
import datetime
# Number of target classes (MNIST has ten digit classes).
n_classes = 10


In [2]:
# MNIST train/test splits stored under the current working directory; every
# sample is converted to a tensor by `F.to_tensor`.
# `download=True` fetches the files on first use and is a no-op once they are
# present, so the notebook also runs on a machine that never downloaded MNIST.
trainset = torchvision.datasets.MNIST(os.getcwd(), train=True, download=True, transform=F.to_tensor)
testset = torchvision.datasets.MNIST(os.getcwd(), train=False, download=True, transform=F.to_tensor)
# CIFAR-10 alternative, kept for reference:
# trainset = torchvision.datasets.CIFAR10(os.getcwd(),download=True,train=True,transform=F.to_tensor)
# cifar_testset = torchvision.datasets.CIFAR10(os.getcwd(),download=True,train=False,transform=F.to_tensor)

In [3]:
# Mini-batch size handed to the train and test DataLoaders.
batch_size = 2

In [4]:
# Training batches: reshuffled on every pass, incomplete final batch discarded.
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, drop_last=True
)
# Evaluation batches: fixed order, incomplete final batch discarded.
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, drop_last=True
)

In [5]:
class ConvLayer(nn.Module):
    """Single 2-D convolution whose output is squashed along the channel axis.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self,in_ch=1,out_ch=256,kernel_size=9,stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, stride)

    def squash(self,x):
        """Scale `x` by ``n / (1 + n**2)`` where ``n`` is its L2 norm over dim 1."""
        length = torch.norm(x, dim=1, keepdim=True)
        scale = length / (1 + length ** 2)
        return scale * x

    def forward(self,x):
        """Convolve `x`, then squash the result."""
        features = self.conv1(x)
        return self.squash(features)
    
class PrimaryCapsules(nn.Module):
    """Convolution whose channels are regrouped into capsules and squashed.

    The convolution emits ``out_caps_grids * out_ch_of_each_caps`` channels;
    they are split so that every spatial position holds `out_caps_grids`
    capsules of `out_ch_of_each_caps` components each.

    Args:
        in_ch: number of input channels.
        out_caps_grids: capsules per spatial position.
        out_ch_of_each_caps: components per capsule.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self,in_ch=256,out_caps_grids=32,out_ch_of_each_caps=8,kernel_size=9,stride=2):
        super().__init__()
        self.out_caps_grids = out_caps_grids
        self.out_ch_of_each_caps = out_ch_of_each_caps
        self.prim_caps = nn.Conv2d(
            in_ch, out_caps_grids * out_ch_of_each_caps, kernel_size, stride
        )

    def reshape(self,x):
        """Turn ``(B, C, H, W)`` into ``(B, H, W, out_caps_grids, out_ch_of_each_caps)``."""
        batch, _, height, width = x.shape
        channels_last = x.permute(0, 2, 3, 1)
        return channels_last.reshape(
            batch, height, width, self.out_caps_grids, self.out_ch_of_each_caps
        )

    def squash(self,x):
        """Scale `x` by ``n / (1 + n**2)`` where ``n`` is its L2 norm over the last dim."""
        length = torch.norm(x, dim=-1, keepdim=True)
        scale = length / (1 + length ** 2)
        return scale * x

    def forward(self,x):
        """Convolve, regroup the channels into capsules, then squash each capsule."""
        capsules = self.reshape(self.prim_caps(x))
        return self.squash(capsules)
    
class DigitCaps(nn.Module):
    """Capsule layer that routes every input capsule to `out_caps` output capsules.

    Args:
        input_grid: shape of the incoming capsule grid; the product of its
            entries is the number of input capsules (`num_caps`).
        in_cap_dim: number of components of each input capsule.
        out_cap_dim: number of components of each output capsule.
        out_caps: number of output capsules.
        routing_iters: number of routing iterations (defaults to 3, the value
            that used to be hard-coded in `forward`).

    Raises:
        ValueError: if `routing_iters` is smaller than 1.
    """

    def __init__(self,input_grid=( 6, 6, 32),in_cap_dim=8,out_cap_dim=16,out_caps=10,routing_iters=3):
        super(DigitCaps,self).__init__()
        if routing_iters < 1:
            raise ValueError("routing_iters must be >= 1")
        self.in_cap_dim=in_cap_dim
        self.routing_iters=routing_iters
        self.num_caps = 1
        for i in input_grid:
            self.num_caps*=i
        # One (out_cap_dim x in_cap_dim) transform per (output capsule, input capsule) pair.
        self.W = torch.nn.Parameter(torch.randn(1, out_caps, self.num_caps, out_cap_dim, in_cap_dim))
        # Initial routing logits. They are only read by `forward` and stay at
        # zero; kept as a non-trainable Parameter so the attribute and the
        # state-dict key `b` remain available.
        self.b = torch.nn.Parameter(torch.zeros(1, out_caps, self.num_caps,1,1),requires_grad=False)

    def reshape(self,x):
        """Flatten the capsule grid of `x` to ``(B, 1, -1, in_cap_dim, 1)``."""
        batch_size=x.shape[0]
        x=x.reshape(batch_size,1,-1,self.in_cap_dim,1)
        return x

    def squash(self,x):
        """Scale `x` by ``n / (1 + n**2)`` where ``n`` is its L2 norm over the last dim."""
        norms = torch.norm(x,dim=-1,keepdim=True)
        x = ((norms/(1+norms**2))*x)
        return x

    def forward(self,x):
        """Route the input capsules and return squashed output capsules.

        Returns a tensor of shape ``(B, out_caps, 1, 1, out_cap_dim)``.

        The routing logits start from `self.b` on every call and are never
        written back. Previously each training iteration replaced `self.b`
        with the updated logits, so they accumulated across batches and the
        output for a given input depended on every batch seen before it
        (including at eval time).
        """
        u = self.reshape(x)                 # (B, 1, num_caps, in_cap_dim, 1)
        uhat = torch.matmul(self.W,u)       # (B, out_caps, num_caps, out_cap_dim, 1)
        b = self.b                          # (1, out_caps, num_caps, 1, 1)
        last = self.routing_iters - 1
        for i in range(self.routing_iters):
            # Softmax over dim 2, i.e. across the input capsules of each output capsule.
            # NOTE(review): the dynamic-routing paper normalises over the output
            # capsules, routes per sample and uses the squashed vector for the
            # agreement - confirm that this variant is intended.
            c = torch.softmax(b,dim=2)
            a = (uhat*c).sum(dim=2,keepdim=True).permute(0,1,2,4,3)   # (B, out_caps, 1, 1, out_cap_dim)
            if i < last:
                # Agreement between candidate outputs and predictions, averaged
                # over the batch axis so all samples share one set of logits.
                # Skipped on the final iteration because nothing reads it.
                b = b + torch.matmul(a,uhat).mean(dim=0,keepdim=True)
        return self.squash(a)



In [6]:
# model1 = ConvLayer()
# model2 = PrimaryCapsules()
# model3 = DigitCaps()
# for batch_id, (data, target) in enumerate(train_loader,1):
#     data = model1(data)
#     print(data.shape)
#     data = model2(data)
#     print(data.shape)
#     data = model3(data)
#     print(data.shape)
#     break

In [7]:
class Encoder(nn.Module):
    """Chains `ConvLayer`, `PrimaryCapsules` and `DigitCaps`, all with default arguments."""

    def __init__(self):
        super(Encoder,self).__init__()
        self.l1=ConvLayer()
        self.l2=PrimaryCapsules()
        self.l3=DigitCaps()

    def forward(self,x):
        """Return the output capsules as ``(batch, out_caps, out_cap_dim)``."""
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)      # (batch, out_caps, 1, 1, out_cap_dim)
        # Remove exactly the two singleton axes. A bare `squeeze()` would also
        # drop the batch axis when the batch holds a single sample.
        x = x.squeeze(3).squeeze(2)
        return x


In [8]:
class Decoder(nn.Module):
    """Three fully connected layers: ReLU after the first two, sigmoid after the last.

    Args:
        l0: size of the input vector.
        l1: width of the first hidden layer.
        l2: width of the second hidden layer.
        l3: size of the output vector.
    """

    def __init__(self,l0=16,l1=512,l2=1024,l3=784):
        super().__init__()
        self.fc1 = nn.Linear(l0, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, l3)

    def forward(self,x):
        """Map `x` through the three layers; the sigmoid keeps outputs in (0, 1)."""
        hidden = f.relu(self.fc1(x))
        hidden = f.relu(self.fc2(hidden))
        return torch.sigmoid(self.fc3(hidden))


In [9]:
class CapsNet3D(nn.Module):
    """`Encoder` producing class capsules plus a `Decoder` fed with the winning capsule.

    All helper tensors are created on the device of the tensors they are
    combined with, so the methods work on CPU, on a single GPU and on
    DataParallel replicas without consulting any notebook-level flag.
    """

    def __init__(self):
        super(CapsNet3D,self).__init__()
        self.enc = Encoder()
        self.dec = Decoder()
        # Not used by any method of this class; kept so the attribute and the
        # printed module structure stay as they were.
        self.encloss=nn.CrossEntropyLoss()

    def margin_loss(self,pred_prob,target):
        """Margin loss summed over all samples and classes.

        Args:
            pred_prob: ``(batch, classes)`` capsule lengths.
            target: ``(batch,)`` integer class indices, on the same device.

        Returns:
            Scalar tensor: for the target class ``relu(0.9 - p)**2``, for every
            other class ``0.5 * relu(p - 0.1)**2``.
        """
        # One-hot targets; the class count comes from `pred_prob` itself.
        identity = torch.eye(pred_prob.size(1), dtype=pred_prob.dtype, device=pred_prob.device)
        Tk = identity.index_select(dim=0, index=target)
        left = Tk*f.relu(0.9-pred_prob)**2
        right = 0.5*(1-Tk)*f.relu(pred_prob-0.1)**2
        return (left+right).sum()

    def loss(self,pred_prob,target_class,target_recon=None,decout=None):
        """Return ``(margin_loss, reconstruction_loss)``.

        The reconstruction loss is the summed squared error between
        `target_recon` and `decout`. When either is omitted it is a zero
        scalar instead of raising a TypeError on ``None - None``.
        """
        margin = self.margin_loss(pred_prob,target_class)
        if target_recon is None or decout is None:
            decoder_loss = pred_prob.new_zeros(())
        else:
            decoder_loss = ((target_recon-decout)**2).sum()
        return margin,decoder_loss

    def fetch(self,encout):
        """Compute capsule lengths and pick the longest capsule of each sample.

        Args:
            encout: ``(batch, classes, capsule_dim)`` encoder output.

        Returns:
            ``(pred_prob, decin)``: the ``(batch, classes)`` L2 norms over the
            last axis, and the ``(batch, capsule_dim)`` capsule with the
            largest norm for each sample.
        """
        pred_prob = torch.norm(encout,dim=-1,keepdim=False)
        _, max_length_indices = pred_prob.max(dim=1)
        # Direct indexing selects the same rows as multiplying by a one-hot
        # matrix, without building that matrix.
        # NOTE(review): the decoder always receives the predicted capsule, also
        # during training - confirm masking by the true label is not wanted.
        rows = torch.arange(encout.size(0), device=encout.device)
        decin = encout[rows, max_length_indices].reshape(encout.size(0), -1)
        return pred_prob,decin

    def forward(self,x):
        """Return ``(pred_prob, decout)`` with `decout` reshaped to ``(-1, 1, 28, 28)``."""
        encout=self.enc(x)
        pred_prob,decin=self.fetch(encout)
        decout=self.dec(decin).reshape(-1,1,28,28)
        return pred_prob,decout


In [10]:
# Build the network, then move its parameters to the current CUDA device
# (`Module.cuda()` moves in place and returns the module itself).
caps = CapsNet3D()
caps.cuda()
# model =caps

In [11]:
# `model` is the DataParallel wrapper used for forward passes; the loss is the
# bound method of the unwrapped network and is called outside the wrapper.
model = nn.DataParallel(caps)
loss_fn = caps.loss
# Adam only receives the tensors that require gradients.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,
)


In [12]:
# Show the nested module structure of the DataParallel-wrapped network.
print(model)

DataParallel(
  (module): CapsNet3D(
    (enc): Encoder(
      (l1): ConvLayer(
        (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
      )
      (l2): PrimaryCapsules(
        (prim_caps): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
      )
      (l3): DigitCaps()
    )
    (dec): Decoder(
      (fc1): Linear(in_features=16, out_features=512, bias=True)
      (fc2): Linear(in_features=512, out_features=1024, bias=True)
      (fc3): Linear(in_features=1024, out_features=784, bias=True)
    )
    (encloss): CrossEntropyLoss()
  )
)


In [13]:
# Tensors that receive gradients. Stored as a list so the name stays usable
# after the sum below (a `filter` object would be exhausted by it).
model_parameters = [p for p in model.parameters() if p.requires_grad]
# Total number of trainable scalars, as a plain Python int.
params = sum(p.numel() for p in model_parameters)
params

8141840

In [14]:
k1 = 1            # weight of the margin loss in the optimised objective
k2 = 0            # weight of the reconstruction loss (0 removes it from the objective)
USE_CUDA = True   # move every batch to the GPU before the forward pass
n_epochs = 100
start = 0         # number printed for the first epoch
l = 3             # history rows: margin loss, reconstruction loss, accuracy
train_loss_accuracy_hist = np.zeros((l, n_epochs))
test_loss_accuracy_hist = np.zeros((l, n_epochs))
ltr = len(train_loader)
lts = len(test_loader)


def _run_epoch(loader, training, progress_prefix=()):
    """Run `model` once over `loader`.

    Reads the notebook globals `model`, `optimizer`, `loss_fn`, `k1`, `k2`
    and `USE_CUDA`.

    Args:
        loader: sized iterable of ``(data, target)`` batches.
        training: True puts the model in train mode and takes one optimizer
            step per batch; False uses eval mode with autograd disabled, so
            evaluation builds no graph.
        progress_prefix: values printed in front of the batch counter.

    Returns:
        ``(enc_loss, dec_loss, accuracy, predictions, targets)``: the two
        weighted losses averaged over the number of batches, the fraction of
        correct predictions, and the per-sample prediction/target arrays.
    """
    model.train(training)
    n_batches = len(loader)
    loss_enc = 0.0
    loss_dec = 0.0
    # Collect per-batch arrays and join them once; growing an array with
    # np.append copies everything gathered so far on every batch.
    pred_chunks = []
    target_chunks = []
    with torch.set_grad_enabled(training):
        for batch_id, (data, target) in enumerate(loader, 1):
            print(*progress_prefix, batch_id, '/', n_batches, end='\r')
            if USE_CUDA:
                data, target = data.cuda(), target.cuda()
            if training:
                optimizer.zero_grad()
            pred_prob, dec = model(data)
            loss1, loss2 = loss_fn(pred_prob, target, data, dec)
            loss1, loss2 = k1 * loss1, k2 * loss2
            if training:
                (loss1 + loss2).backward()
                optimizer.step()
            loss_enc += loss1.item()
            loss_dec += loss2.item()
            pred_chunks.append(np.argmax(pred_prob.detach().cpu().numpy(), 1))
            target_chunks.append(target.cpu().numpy())
    predictions = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)
    accuracy = float(np.mean(predictions == targets))
    return loss_enc / n_batches, loss_dec / n_batches, accuracy, predictions, targets


for epoch in range(start, start + n_epochs):
    print('epoch/n_epochs', epoch, '/', n_epochs)
    # The history arrays have `n_epochs` columns, so index them relative to
    # `start`; indexing by `epoch` overflows as soon as `start` is non-zero.
    col = epoch - start

    *train_stats, train_pred, train_targets = _run_epoch(train_loader, training=True)
    train_loss_accuracy_hist[:, col] = train_stats
    print ('\ntrain loss enc',train_loss_accuracy_hist[0,col],'train loss dec',train_loss_accuracy_hist[1,col],'\ntrain accuracy',train_loss_accuracy_hist[-1,col],"\n")

    *test_stats, test_pred, test_targets = _run_epoch(test_loader, training=False, progress_prefix=('\t\t\t',))
    test_loss_accuracy_hist[:, col] = test_stats
    print ('\n\t\t\ttest loss enc',test_loss_accuracy_hist[0,col],'test loss dec',test_loss_accuracy_hist[1,col],'\n\t\t\ttest accuracy',test_loss_accuracy_hist[-1,col],"\n")

epoch/n_epochs 0 / 100
870 / 30000

KeyboardInterrupt: 

In [None]:
# Train/test curves for the three history rows, one panel per row.
plt.figure(figsize=(20,5))
rang = list(range(n_epochs))

# (title, history row, fixed y-limits or None for automatic scaling)
panels = (
    ("Encoder Loss", 0, None),
    ("Decoder Loss", 1, None),
    ("Accuracy", -1, (-0.1, 1.1)),
)
for position, (title, row, y_limits) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.plot(rang, train_loss_accuracy_hist[row, :n_epochs], label='train')
    plt.plot(rang, test_loss_accuracy_hist[row, :n_epochs], label='test')
    # Tick labels on both sides; `tick_params` expects booleans here.
    plt.tick_params(axis='y', which='both', labelleft=True, labelright=True)
    plt.xlabel("epoch")
    plt.title(title)
    plt.legend()

# Timestamped file name; `now` is reused by the cell that writes the info file.
now = datetime.datetime.now()
fname= f"results/{now.year}-{now.month}-{now.day}-{now.hour}-{now.minute}-results-mnist.png"
# plt.savefig(fname)
plt.show()
print(fname)

In [None]:
# Write the printed model and optimizer descriptions to a timestamped text file.
# Relies on `now` from the plotting cell and `params` from the parameter-count cell.
info = str(model) + '\n' + str(optimizer)
# Create the output directory first so `open` cannot fail on a missing folder.
os.makedirs("results", exist_ok=True)
# The context manager closes the file even if the write raises.
with open(f"results/{now.year}-{now.month}-{now.day}-{now.hour}-{now.minute}-{params}-info-mnist.txt", "w") as file:
    file.write(info)

In [None]:
# Show every misclassified sample of the first test batch together with its
# reconstruction. Only one batch is inspected, so it is fetched directly
# rather than through an epoch loop; `n_epochs` and `start` are left untouched
# so re-running the plotting cell afterwards still covers the full history.
USE_CUDA = True
model.eval()

data, target = next(iter(test_loader))
print('\t\t\t', 1, '/', len(test_loader), end='\r')
if USE_CUDA:
    data, target = data.cuda(), target.cuda()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    pred_prob, dec = model(data)

print(pred_prob.shape)
predicted = pred_prob.max(dim=1).indices
print(predicted.shape)
print(target.shape)

images = data.cpu()
reconstructions = dec.cpu().numpy()
for idx, correct in enumerate(predicted == target):
    if correct:
        continue
    print('target ',target[idx].item(),' prediction ',predicted[idx].item())
    # Input image first, reconstruction second.
    plt.imshow(images[idx, 0])
    plt.show()
    plt.imshow(reconstructions[idx, 0])
    plt.show()