In [1]:
from google.colab import drive
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [2]:
print(f"CUDA is available? {torch.cuda.is_available()}")
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(dev)

CUDA is available? True
cuda


In [3]:
drive.mount("/content/drive",True)
root_dir = "/content/drive/My Drive/SB3/"

Mounted at /content/drive


In [4]:
train_transform = T.Compose([
    T.Resize(256),
    T.ToTensor()
])

In [5]:
train_dataset = ImageFolder(os.path.join(root_dir, "train"), transform=train_transform)
test_dataset = ImageFolder(os.path.join(root_dir, "test"), transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size = 4, num_workers=4, shuffle=True)
test_loader   = DataLoader(test_dataset,   batch_size = 4, num_workers=4, shuffle=False)
loaders = {
    "train": train_loader,
    "test": test_loader
}

In [6]:
num_classes = len(train_dataset.classes)
print(num_classes)

8


In [7]:
class ConvCapsLayer(nn.Module):
    def __init__(self, in_channel_types, out_channel_types, in_vector_length, out_vector_length, kernel_size, padding, stride, num_iterations, new_height, new_width):
      super().__init__()
      
      self.in_channel_types = in_channel_types
      self.out_channel_types = out_channel_types
      self.in_vector_length = in_vector_length
      self.out_vector_length = out_vector_length
      self.kernel_size = kernel_size
      self.padding = padding
      self.stride = stride
      self.num_iterations = num_iterations
      self.new_height = new_height
      self.new_width = new_width
      self.W = nn.Parameter(torch.randn(1, in_channel_types, 1, out_channel_types, out_vector_length, in_vector_length))
    
    def squash(self, in_tensor):
      squared_norm = (in_tensor ** 2).sum(-2, keepdim=True)
      out_tensor = squared_norm *  in_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm) + 1e-10)
      return out_tensor

    def forward(self, x):
      batch_size = x.size(0)
      y = torch.zeros(batch_size, self.in_vector_length, self.in_channel_types, self.new_height, self.new_width, dtype=torch.float).to(dev)
      kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size, dtype=torch.float).to(dev)
      for i in range(self.in_vector_length):
        for j in range(self.in_channel_types):
          y[:,i,j,:,:] = F.conv2d(x[:,i,j,:,:].unsqueeze_(1), kernel, padding=self.padding, stride=self.stride).squeeze(1)

      y = y.view(batch_size, self.in_channel_types, self.new_height * self.new_width, -1)
      y = torch.stack([y] * self.out_channel_types, dim=3).unsqueeze(5)

      W = torch.cat([self.W] * batch_size, dim=0)
      c = W @ y

      b_ij = torch.zeros(1, self.in_channel_types, c.size(2), self.out_channel_types, 1).to(dev)

      for iteration in range(self.num_iterations):
        c_ij = F.softmax(b_ij, dim=-2)
        c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(5)

        s_j = (c_ij * c).sum(dim=1, keepdim=True)
        v_j = self.squash(s_j)

        if iteration < self.num_iterations - 1:
          a_ij = torch.matmul(c.transpose(4, 5), torch.cat([v_j] * self.in_channel_types, dim=1))
          b_ij = b_ij + a_ij.squeeze(5).mean(dim=0, keepdim=True)

      v_j = v_j.squeeze(1)
      v_j = v_j.transpose(3,1).squeeze(4)
      v_j = v_j.view(v_j.size(0), v_j.size(1), v_j.size(2), self.new_height, self.new_width)
      return v_j

In [11]:
class Reconstruction(nn.Module):
    def __init__(self, num_classes):
      super().__init__()
      
      self.fc_layer = nn.Sequential(
        nn.Linear(num_classes*16, 80),
        nn.ReLU(inplace=True)
      )

      self.reconstruction_layers = nn.Sequential(
          nn.ConvTranspose2d(1, 128, kernel_size=5, padding=2, stride=8, output_padding = (7,7)),
          nn.ReLU(inplace=True),
          nn.ConvTranspose2d(128, 64, kernel_size=5, padding=2, stride=8, output_padding = (7,7)),
          nn.ReLU(inplace=True),
          nn.Conv2d(64, 3, kernel_size=5, padding=2, stride=1),
          nn.ReLU(inplace=True)
    )
        
    def forward(self, x):
        x = self.fc_layer(x)
        print(x.size())
        x = x.view(x.size(0), 1, 8, 10)
        x = self.reconstruction_layers(x)
        return x

In [12]:
class CapsNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, padding=2, stride=2),
            nn.ReLU(inplace=True)
        )
        
        self.conv_capsule_layers = nn.Sequential(
            ConvCapsLayer(in_channel_types = 1, out_channel_types = 2, in_vector_length = 16, out_vector_length = 16, kernel_size = 5, padding = 2, stride = 2, num_iterations = 1, new_height = 64, new_width = 64),
            ConvCapsLayer(in_channel_types = 2, out_channel_types = 4, in_vector_length = 16, out_vector_length = 16, kernel_size = 5, padding = 2, stride = 1, num_iterations = 3, new_height = 64, new_width = 64),
            ConvCapsLayer(in_channel_types = 4, out_channel_types = 4, in_vector_length = 16, out_vector_length = 32, kernel_size = 5, padding = 2, stride = 2, num_iterations = 3, new_height = 32, new_width = 32),
            ConvCapsLayer(in_channel_types = 4, out_channel_types = 8, in_vector_length = 32, out_vector_length = 32, kernel_size = 5, padding = 2, stride = 1, num_iterations = 3, new_height = 32, new_width = 32),
            ConvCapsLayer(in_channel_types = 8, out_channel_types = 8, in_vector_length = 32, out_vector_length = 64, kernel_size = 5, padding = 2, stride = 2, num_iterations = 3, new_height = 16, new_width = 16),
            ConvCapsLayer(in_channel_types = 8, out_channel_types = 8, in_vector_length = 64, out_vector_length = 32, kernel_size = 5, padding = 2, stride = 1, num_iterations = 3, new_height = 16, new_width = 16),
            ConvCapsLayer(in_channel_types = 8, out_channel_types = num_classes, in_vector_length = 32, out_vector_length = 16, kernel_size = 5, padding = 2, stride = 2, num_iterations = 3, new_height = 8, new_width = 8)
        )

        self.reconstruction_layer = Reconstruction(num_classes=2)

    def capsule_average_pooling(self, x):
      height = x.size(3)
      width = x.size(4)
      x = x.sum(dim=4).sum(dim=3)
      x = x / (height * width)
      return x
    
    def score(self, x):
      return torch.sqrt((x ** 2).sum(dim=2))
    
    def forward(self, x):
      x = self.conv_layers(x)
      x = x.unsqueeze(2)
      x = self.conv_capsule_layers(x)
      x = x.transpose(1, 2)
      x = self.capsule_average_pooling(x)
      scores = self.score(x)
      x = x.view(x.size(0), x.size(1) * x.size(2))
      reconstructions = self.reconstruction_layer(x)
      return scores, reconstructions

In [13]:
batch,labels = next(iter(train_loader))
batch = batch.to(dev)
labels = labels.to(dev)
model = CapsNet(num_classes=2)
model = model.to(dev)
scores, reconstructions = model(batch)
print(reconstructions.size())

torch.Size([4, 80])
torch.Size([4, 3, 512, 640])
