<a href="https://colab.research.google.com/github/ehgus/CS420CompilerDesign/blob/master/CS470_capsnet_revision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Capsnet
- dataset: mnist
- model: capsnet+FCnet

Future usage: decomposition of overlaped images

Future task
1. Visualize the model
2. Check each layers are functioning as intended(freezing some layer, scheduler, etc)
3. stabilize the code on gpu environment
4. check squeeze function also works well on batch==1 case

# 0. Module importing

In [1]:
#For path setting and image loading
import os
from torchvision import transforms, datasets

#For implementing train part
import torch
import torch.nn as nn
from torch.nn.init import constant_
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler


#For image creation
import numpy as np
from matplotlib import pyplot as plt
import math
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device setting: {device}")

device setting: cpu


# 1. Gdrivce mounting
Mount google drive to the program. 

This might change your drive. It's recommended to executing it <b>locally</b>.

In [2]:
from google.colab import drive

drive.mount('/gdrive')
gdrive_root = '/gdrive/My Drive'

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


# 2. Data loading module
Module for loading mnist dataset with batch parsing.

In [0]:
def load_mnist(path, batch_size=100, shift_pixels=2):
    """
    Construct dataloaders for training and test data. Data augmentation is also done here.
    :param path: file path of the dataset
    :param batch_size: batch size
    :param shift_pixels: maximum number of pixels to shift in each direction
    
    :return: train_loader, test_loader
    """
    
    kwargs = {'num_workers': 1, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(
          datasets.MNIST(path, train=True,download=True, transform=transforms.Compose([transforms.RandomCrop(size=28, padding=shift_pixels), transforms.ToTensor()])), batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
          datasets.MNIST(path, train=False,download=True, transform=transforms.ToTensor()), batch_size=batch_size, shuffle=True, **kwargs)

    return train_loader, test_loader

# 3.1. CapsNet Network layer modules
modules of capsnet architecture

<b>def</b> squash
- Vector is squashed with lengh 0~1.

<b>class</b> DenseCapsule
- Core part with dynamic routing

<b>class</b> PrimaryCapsule
- Vector creation

In [0]:
def squash(inputs, axis=-1):
    """
    The non-linear activation used in Capsule. It drives the length of a large vector to near 1 and small vector to 0
    :param inputs: vectors to be squashed
    :param axis: the axis to squash. negative axis means axis=axis+dimension-1.

    :return: a Tensor with same size as inputs
    """
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs


class DenseCapsule(nn.Module):
    """
    The dense capsule layer. It is similar to Dense (FC) layer. Dense layer has `in_num` inputs, each is a scalar, the
    output of the neuron from the former layer, and it has `out_num` output neurons. DenseCapsule just expands the
    output of the neuron from scalar to vector. So its input size = [None, in_num_caps, in_dim_caps] and output size = \
    [None, out_num_caps, out_dim_caps]. For Dense Layer, in_dim_caps = out_dim_caps = 1.
    :param in_num_caps: number of cpasules inputted to this layer
    :param in_dim_caps: dimension of input capsules
    :param out_num_caps: number of capsules outputted from this layer
    :param out_dim_caps: dimension of output capsules
    :param routings: number of iterations for the routing algorithm
    """
    def __init__(self,in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        self.weight = nn.Parameter( 0.01*torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        #assertion and initalization
        assert self.routings >=1, "The 'routings' should be > 0."
        '''
        step1: Affine transforatmion
        '''
        # weight.size  =      [ out_num_caps, in_num_caps, out_dim_caps, in_dim_caps]
        # x.size       =[batch, 1           , in_num_caps,  in_dim_caps, 1          ]
        #rst:x_hat.size=[batch, out_num_caps, in_num_caps, out_dim_caps]

        # torch.matmul: [out_dim_caps, in_dim_caps] x [in_dim_caps, 1] -> [out_dim_caps, 1]
        # squeeze     : [batch,out_num_caps, in_num_caps, out_dim_caps, 1] -> [batch,out_num_caps, in_num_caps, out_dim_caps]

        x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]),dim=-1)
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps).to(device)
        '''
        step2: dynamic routing
        '''
        # For iteration, use detached matrix so that no gradients flow on this path.
        # In backward, no gradient can flow from `x_hat_detached`.
        x_hat_detached = x_hat.detach()
        
        for i in range(self.routings-1):
            # c.size = [batch, out_num_caps, in_num_caps]
            c = F.softmax(b, dim=-1)
            
            # c.size expanded to [batch, out_num_caps, in_num_caps, 1           ]
            # x_hat.size     =   [batch, out_num_caps, in_num_caps, out_dim_caps]
            # => outputs.size=   [batch, out_num_caps, 1          , out_dim_caps]     
            outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached,dim=-2,keepdim=True))
                
            # outputs.size       =[batch, out_num_caps, None,      , out_dim_caps]
            # x_hat_detached.size=[batch, out_num_caps, in_num_caps, out_dim_caps]
            # => b.size          =[batch, out_num_caps, in_num_caps]
            b =b+ torch.sum(outputs * x_hat_detached, dim=-1)

        c = F.softmax(b, dim=-1)
        # => outputs.size=   [batch, out_num_caps, out_dim_caps]
        outputs=squash(torch.sum(c[:, :, :, None] * x_hat,dim=-2))

        return outputs

class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel size
    :return: output tensor, size=[batch, num_caps, dim_caps]
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,groups=in_channels//dim_caps,bias=False)
        # Set requires_grad to False
        #conv2 initialization: no routing algorithm
        #Each values of conv filters are supposed to be calculated by softmax function, setting to 1 to all parameter doesn't matter. Denominator than wil be movoed on previous conv network.
        # self.conv2d.weight=torch.ones_like(self.conv2d.weight)
        self.conv2d.weight=nn.Parameter(torch.ones_like(self.conv2d.weight),requires_grad=False)
        

    def forward(self, x):
        # outputs.size= [batch, in_num_caps, in_dim_caps]
        outputs = squash(self.conv2d(x).view(x.size(0), -1, self.dim_caps))
        return outputs

#3.2. CapsNet Network

In [0]:
class CapsuleNet(nn.Module):
    """
    :param input_size: data size = [channels, width, height]
    :param classes: number of classes
    :param routings: number of routing iterations
    Shape:
        - Input: (batch, channels, width, height), optional (batch, classes) .
        - Output:((batch, classes), (batch, channels, width, height))
    """
    def __init__(self ,input_size, classes,num_labels,routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        #encoder network
        self.encoder=nn.Sequential(
            nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0),
            nn.ReLU(),
            PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0),
            DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8, out_num_caps=classes, out_dim_caps=16, routings=routings)
        )
        # Decoder network.
        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )


    def forward(self, x, y=None):
        x = self.encoder(x)
        length = x.norm(p=2, dim=-1)
        # during testing, no label given. create one-hot coding using 'length'
        if y is None: 
            index = length.max(dim=1)[1] #return named indices
            y = torch.zeros(length.size()).scatter_(1, index.view(-1, 1).cpu().data, 1.).to(device)
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return length, reconstruction.view(-1, *self.input_size)

def caps_loss(y_true, y_pred, x, x_recon, lam_recon):
    """
    Capsule loss = Margin loss + lam_recon * reconstruction loss.
    :param y_true: true labels, one-hot coding, size=[batch, classes]
    :param y_pred: predicted labels by CapsNet, size=[batch, classes]
    :param x: input data, size=[batch, channels, width, height]
    :param x_recon: reconstructed data, size is same as `x`
    :param lam_recon: coefficient for reconstruction loss
    :return: variable contains a scalar loss value.
    """
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + \
        0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()

    L_recon = nn.MSELoss()(x_recon, x)
    return L_margin + lam_recon * L_recon


def show_reconstruction(model, test_loader, n_images, args):
    import matplotlib.pyplot as plt
    from utils import combine_images
    from PIL import Image
    import numpy as np

    model.eval()
    for x, _ in test_loader:
        x = x[:min(n_images, x.size(0))].to(device)
        _, x_recon = model(x)
        data = np.concatenate([x.data, x_recon.data])
        img = combine_images(np.transpose(data, [0, 2, 3, 1]))
        image = img * 255
        Image.fromarray(image.astype(np.uint8)).save(args['save_dir'] + "/real_and_recon.png")
        print()
        print('Reconstructed images are saved to %s/real_and_recon.png' % args['save_dir'])
        print('-' * 70)
        plt.imshow(plt.imread(args['save_dir'] + "/real_and_recon.png", ))
        plt.show()
        break

#4. model initialization

In [14]:
# setting the hyper parameters
args={'epochs': 50,
      'batch_size':100,
      'lr':0.001,                         #Initial learning rate
      'lr_decay':0.9,                     #The value multiplied by lr at each epoch. Set a larger value for larger epochs
      'lam_recon': 0.0005 * 784,          #The coefficient for the loss of decoder
      'data_dir':'/my_data',                #Directory of data. If no data, use \'--download\' flag to download it
      'save_dir':'/my_data'
      }

model = CapsuleNet(input_size=[1, 28, 28], classes=10,num_labels=10, routings=3)
model = model.to(device)
#optimizer setting
optimizer = Adam(model.parameters(), lr=args['lr'])
lr_decay = lr_scheduler.ExponentialLR(optimizer, gamma=args['lr_decay'])

# Print your neural network structure
print(model)

CapsuleNet(
  (encoder): Sequential(
    (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (1): ReLU()
    (2): PrimaryCapsule(
      (conv2d): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2), groups=32, bias=False)
    )
    (3): DenseCapsule()
  )
  (decoder): Sequential(
    (0): Linear(in_features=160, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=1024, out_features=784, bias=True)
    (5): Sigmoid()
  )
)


# 5. Loading pre-trained model if exist



In [15]:
#check point 
ckpt_dir = os.path.join(gdrive_root, 'checkpoints')
if not os.path.exists(ckpt_dir):
  os.makedirs(ckpt_dir)
  
best_acc = 0.
ckpt_path = os.path.join(ckpt_dir, 'lastest.pt')
if os.path.exists(ckpt_path):
  ckpt = torch.load(ckpt_path)
  try:
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    best_acc = ckpt['best_acc']
  except RuntimeError as e:
      print('wrong checkpoint')
  else:    
    print('checkpoint is loaded !')
    print('current best accuracy : %.2f' % best_acc)

checkpoint is loaded !
current best accuracy : 0.00


# 6. Training and testing model

In [0]:
def test(model, test_loader, args):
    model.eval()
    test_loss = 0
    correct = 0
    for x, y in test_loader:

        # change to one-hot coding
        # convert input data to device
        y = torch.zeros(y.size(0), 10).scatter(1, y.view(-1, 1), 1.)
        x = x.to(device)
        y = y.to(device)

        #test
        y_pred, x_recon = model(x)
        # sum up batch loss
        test_loss += caps_loss(y, y_pred, x, x_recon, args['lam_recon']).data * x.size(0)  
        y_pred = y_pred.data.max(1)[1]
        y_true = y.data.max(1)[1]
        correct += torch.sum(y_pred.eq(y_true)) 

    test_loss /= len(test_loader.dataset)
    return test_loss, correct / len(test_loader.dataset)


def train(model, train_loader, test_loader, args, best_acc):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param train_loader: torch.utils.data.DataLoader for training data
    :param test_loader: torch.utils.data.DataLoader for test data
    :param args: arguments
    :return: The trained model
    """
    print('Begin Training' + '-'*70)
    from time import time
    
    t0 = time()

    
    for epoch in range(args['epochs']):
        #train phase    
        model.train()

        # decrease the learning rate by multiplying a factor `gamma`
        lr_decay.step()  
        ti = time()
        
        training_loss = 0.0
        for x, y in train_loader:
            
            # change to one-hot coding
            # convert input data to GPU Variable
            y = torch.zeros(y.size(0), 10).scatter_(1, y.view(-1, 1), 1.)

            # 수정 전: x, y = Variable(x.cuda()), Variable(y.cuda())
            x = x.to(device)
            y = y.to(device)

            # set gradients of optimizer to zero
            optimizer.zero_grad()  

            # forward
            y_pred, x_recon = model(x, y)  
            
            #compute loss
            loss = caps_loss(y, y_pred, x, x_recon, args['lam_recon'])

            # backward
            loss.backward() 

            #수정 전 loss.data[0] * x.size(0)
            training_loss += loss.data * x.size(0)  # record the batch loss
            optimizer.step()  # update the trainable parameters with computed gradients

        # compute validation loss and acc
        test_loss, test_acc = test(model, test_loader, args)
        
        print((f"==> Epoch {epoch}: "
               f" training loss={training_loss / len(train_loader.dataset):.5f}, "
               f" test loss={test_loss:.5f}, "
               f" test acc={test_acc:.4f}, "
               f" iteration time={time()-ti:.1f}s"))
        if test_acc >= best_acc:  # update best validation acc and save model
            best_acc = test_acc
            # Note: optimizer also has states ! don't forget to save them as well.
            ckpt = {'model':model.state_dict(),
                    'optimizer':optimizer.state_dict(),
                    'best_acc':best_acc}
            torch.save(ckpt, ckpt_path)
            print('checkpoint is saved !')
    print("Total time = %ds" % (time() - t0))
    print('End Training' + '-' * 70)
    return model

In [9]:
##problem: util module cannot be imported: use other methods
##parser makes error: redisign it




if True:
    # load data
    train_loader, test_loader = load_mnist(path=gdrive_root+args['data_dir'], batch_size=args['batch_size'])
    # train model
    train(model, train_loader, test_loader, args, best_acc)

Begin Training----------------------------------------------------------------------




==> Epoch 0:  training loss=0.03913,  test loss=0.03168,  test acc=0.0000,  iteration time=938.4s
checkpoint is saved !
==> Epoch 1:  training loss=0.03803,  test loss=0.03147,  test acc=0.0000,  iteration time=913.2s
checkpoint is saved !
==> Epoch 2:  training loss=0.03681,  test loss=0.03041,  test acc=0.0000,  iteration time=860.0s
checkpoint is saved !
==> Epoch 3:  training loss=0.03608,  test loss=0.03020,  test acc=0.0000,  iteration time=824.9s
checkpoint is saved !
==> Epoch 4:  training loss=0.03529,  test loss=0.02993,  test acc=0.0000,  iteration time=825.4s
checkpoint is saved !
==> Epoch 5:  training loss=0.03482,  test loss=0.02903,  test acc=0.0000,  iteration time=818.5s
checkpoint is saved !
==> Epoch 6:  training loss=0.03375,  test loss=0.02967,  test acc=0.0000,  iteration time=828.6s
checkpoint is saved !
==> Epoch 7:  training loss=0.03335,  test loss=0.03013,  test acc=0.0000,  iteration time=826.9s
checkpoint is saved !
==> Epoch 8:  training loss=0.03269,  te

KeyboardInterrupt: ignored

# 7. Plotting the result

In [0]:
def show_reconstruction(model, test_loader, n_images):

    model.eval()
    for x, _ in test_loader:
        #수정 전: x = Variable(x[:min(n_images, x.size(0))].cuda(), volatile=True)
        x = x[:min(n_images, x.size(0))].to(device)
        _, x_recon = model(x)
        data = np.concatenate([x.data, x_recon.data])
        img = combine_images(np.transpose(data, [0, 2, 3, 1]))
        image = img * 255
        Image.fromarray(image.astype(np.uint8)).save(gdrive+ "my_data/real_and_recon.png")
        print()
        print('Reconstructed images are saved to my_data/real_and_recon.png')
        print('-' * 70)
        plt.imshow(plt.imread(gdrive + "my_data/real_and_recon.png", ))
        plt.show()
        break

def combine_images(generated_images):
    num = generated_images.shape[0]
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num)/width))
    shape = generated_images.shape[1:3]
    image = np.zeros((height*shape[0], width*shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = int(index/width)
        j = index % width
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \
            img[:, :, 0]
    return image


if __name__=="__main__":
    plot_log('result/log.csv')

In [0]:
if False:
  test_loss, test_acc = test(model, test_loader,args)
  print(f'test acc = {test_acc:.4f}, test loss = {test_loss:.5f}')
  show_reconstruction(model, test_loader, 50)