In [3]:
from google.colab import drive
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from matplotlib import pyplot as plt
import random
import torch.optim as optim
import time

In [4]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Report the choice in the same two-line format (availability flag, then device).
print(f"CUDA is available? {dev.type == 'cuda'}")
print(dev)

CUDA is available? True
cuda


In [5]:
# Mount Google Drive into the Colab filesystem; the second argument of
# drive.mount is force_remount, spelled out here for readability.
drive.mount("/content/drive", force_remount=True)
# Folder on the mounted drive that holds the "train" and "test" image directories.
root_dir = "/content/drive/My Drive/SB3/"

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [7]:
# Preprocessing shared by every split: resize, then convert to a float tensor.
# Resize is given an explicit (height, width) pair because the network below is
# wired for a fixed 256x256 input (the capsule layers declare fixed output sizes
# and the decoder emits a 256x256 reconstruction that is compared to the input).
# T.Resize(256) would only fix the *shorter* side, so non-square images would
# produce tensors of differing sizes that cannot be batched or fed to the model.
train_transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])

In [8]:
# Build one ImageFolder per on-disk split ("train" first, then "test"); both
# use the same preprocessing pipeline.
train_dataset, test_dataset = (
    ImageFolder(os.path.join(root_dir, split_name), transform=train_transform)
    for split_name in ("train", "test")
)
# The number of classes is the number of class folders found in the training split.
num_classes = len(train_dataset.classes)
print(num_classes)

8


In [9]:
# Hold out a fraction of the training images as a validation set.
#
# Two things make this cell safe to re-run:
#   * the shuffle uses its own seeded RNG, so the split is reproducible and does
#     not depend on (or disturb) the global `random` state;
#   * `train_dataset` is rebound to a Subset at the end of the cell, so on a
#     re-run we unwrap it first instead of splitting the already-split subset
#     (which would shrink the training set and leak old training samples into
#     the new validation set).
SPLIT_SEED = 0
val_frac = 0.2

full_train_dataset = train_dataset.dataset if isinstance(train_dataset, Subset) else train_dataset

num_total = len(full_train_dataset)
shuffled_idx = list(range(num_total))
random.Random(SPLIT_SEED).shuffle(shuffled_idx)

num_val = int(num_total * val_frac)
num_train = num_total - num_val

# The first `num_train` shuffled indices are used for training, the rest for validation.
train_idx = shuffled_idx[:num_train]
val_idx = shuffled_idx[num_train:]

val_dataset = Subset(full_train_dataset, val_idx)
train_dataset = Subset(full_train_dataset, train_idx)

In [10]:
# One DataLoader per split, all with batch size 4 and 4 worker processes.
# Only the training split is reshuffled; validation and test keep dataset order.
train_loader, val_loader, test_loader = (
    DataLoader(dataset, batch_size=4, num_workers=4, shuffle=reshuffle)
    for dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)
# Lookup table used by the training loop to iterate the splits by name.
loaders = dict(train=train_loader, val=val_loader, test=test_loader)

In [11]:
class ConvCapsLayer(nn.Module):
    """Capsule layer operating on a spatial grid of capsules.

    Input and output tensors are laid out as
    ``(batch, vector_length, channel_types, height, width)``.

    The forward pass
      1. sum-pools every capsule component with an all-ones
         ``kernel_size x kernel_size`` kernel (using ``padding`` / ``stride``),
      2. maps each pooled input capsule to one prediction vector per output
         channel type with the learned matrices in ``W``,
      3. combines the predictions across input channel types with
         softmax-weighted sums refined over ``num_iterations`` agreement steps,
      4. squashes the result.

    Args:
        in_channel_types: number of capsule types in the input (dim 2 of ``x``).
        out_channel_types: number of capsule types produced.
        in_vector_length: length of each input capsule vector (dim 1 of ``x``).
        out_vector_length: length of each output capsule vector.
        kernel_size, padding, stride: parameters of the sum-pooling convolution.
        num_iterations: total number of coupling-coefficient evaluations; with 1
            the coupling is uniform and no refinement step runs.
        new_height, new_width: spatial size the pooling step is expected to
            produce; checked at run time.
    """

    def __init__(self, in_channel_types, out_channel_types, in_vector_length, out_vector_length, kernel_size, padding, stride, num_iterations, new_height, new_width):
        super().__init__()

        self.in_channel_types = in_channel_types
        self.out_channel_types = out_channel_types
        self.in_vector_length = in_vector_length
        self.out_vector_length = out_vector_length
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.num_iterations = num_iterations
        self.new_height = new_height
        self.new_width = new_width
        # One (out_vector_length x in_vector_length) matrix per
        # (input type, output type) pair, shared over batch and positions.
        self.W = nn.Parameter(torch.randn(1, in_channel_types, 1, out_channel_types, out_vector_length, in_vector_length))

    def squash(self, in_tensor):
        """Rescale vectors along dim -2 to length ``|s|^2 / (1 + |s|^2)``.

        The epsilon sits inside the square root: ``sqrt`` of an exactly-zero
        squared norm has an infinite derivative, which turns the gradient of
        zero-length capsules into NaN.
        """
        squared_norm = (in_tensor ** 2).sum(-2, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * in_tensor / torch.sqrt(squared_norm + 1e-8)

    def forward(self, x):
        """Apply the layer.

        Args:
            x: tensor of shape
               ``(batch, in_vector_length, in_channel_types, height, width)``.

        Returns:
            Tensor of shape
            ``(batch, out_vector_length, out_channel_types, new_height, new_width)``.

        Raises:
            ValueError: if ``x`` does not have the expected layout, or if the
                pooled spatial size differs from ``(new_height, new_width)``.
        """
        if x.dim() != 5 or x.size(1) != self.in_vector_length or x.size(2) != self.in_channel_types:
            raise ValueError(
                f"expected input of shape (batch, {self.in_vector_length}, "
                f"{self.in_channel_types}, H, W), got {tuple(x.shape)}"
            )

        batch_size = x.size(0)
        num_positions = self.new_height * self.new_width

        # Sum-pool every (vector component, channel type) plane in one batched
        # convolution. Device and dtype follow the input instead of a global.
        kernel = x.new_ones(1, 1, self.kernel_size, self.kernel_size)
        planes = x.reshape(-1, 1, x.size(3), x.size(4))
        pooled = F.conv2d(planes, kernel, padding=self.padding, stride=self.stride)
        if tuple(pooled.shape[-2:]) != (self.new_height, self.new_width):
            raise ValueError(
                f"pooling produced spatial size {tuple(pooled.shape[-2:])}, "
                f"but the layer was declared with ({self.new_height}, {self.new_width})"
            )
        y = pooled.view(batch_size, self.in_vector_length, self.in_channel_types, self.new_height, self.new_width)

        # (B, V_in, C_in, H, W) -> (B, C_in, H*W, V_in). This has to be a
        # permutation: a plain view() would reinterpret memory and build each
        # "capsule vector" out of neighbouring pixels of a single feature map.
        y = y.permute(0, 2, 3, 4, 1).reshape(batch_size, self.in_channel_types, num_positions, self.in_vector_length)
        # (B, C_in, H*W, 1, V_in, 1); the singleton broadcasts over output types.
        y = y.unsqueeze(3).unsqueeze(5)

        # Prediction vectors: (B, C_in, H*W, C_out, V_out, 1).
        c = self.W @ y

        # The refinement steps work on a detached copy, so gradients only flow
        # through the final weighted sum below.
        temp_c = c.detach()

        b_ij = c.new_zeros(batch_size, self.in_channel_types, num_positions, self.out_channel_types, 1)

        for _ in range(self.num_iterations - 1):
            # Coupling coefficients: softmax over the output capsule types.
            c_ij = F.softmax(b_ij, dim=-2).unsqueeze(5)

            s_j = (c_ij * temp_c).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)

            # Agreement = dot product of each prediction with the current
            # output; v_j broadcasts over the input channel types.
            a_ij = torch.matmul(temp_c.transpose(4, 5), v_j)
            # NOTE(review): the agreement is averaged over the batch dimension,
            # so the logits (and the output) depend on which samples share a
            # batch. Kept as in the original; confirm this is intended.
            b_ij = b_ij + a_ij.squeeze(5).mean(dim=0, keepdim=True)

        c_ij = F.softmax(b_ij, dim=-2).unsqueeze(5)

        s_j = (c_ij * c).sum(dim=1, keepdim=True)
        v_j = self.squash(s_j)

        # (B, 1, H*W, C_out, V_out, 1) -> (B, V_out, C_out, H, W)
        v_j = v_j.squeeze(1).squeeze(-1).permute(0, 3, 2, 1)
        return v_j.reshape(batch_size, self.out_vector_length, self.out_channel_types, self.new_height, self.new_width)

In [12]:
class Reconstruction(nn.Module):
    """Decoder that turns flattened class capsules into a 3-channel image.

    The input is a ``(batch, num_classes * 16)`` tensor. A fully connected layer
    maps it to 64 values, which are reshaped into a single-channel 8x8 map and
    upsampled by two transposed convolutions followed by a final convolution.
    Every stage is followed by a ReLU.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Flattened capsules -> 64 activations (later viewed as an 8x8 map).
        projection = nn.Linear(num_classes*16, 64)
        self.fc_layer = nn.Sequential(projection, nn.ReLU())

        # Upsampling stack; kept in a list so the stage order is easy to read.
        decoder_stages = [
            nn.ConvTranspose2d(1, 128, kernel_size=5, padding=2, stride=9),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, padding=1, stride=4, output_padding=(1,1)),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=5, padding=2, stride=1),
            nn.ReLU(),
        ]
        self.reconstruction_layers = nn.Sequential(*decoder_stages)

    def forward(self, x):
        """Decode ``x`` of shape ``(batch, num_classes * 16)`` into an image batch."""
        hidden = self.fc_layer(x)
        seed_map = hidden.view(hidden.size(0), 1, 8, 8)
        return self.reconstruction_layers(seed_map)

In [16]:
class CapsNet(nn.Module):
    """Capsule network producing per-class scores and an image reconstruction.

    Pipeline: one strided convolution + ReLU, a stack of seven ``ConvCapsLayer``
    modules, spatial averaging of the final capsules, then
      * ``scores``: the Euclidean length of each class capsule, and
      * ``reconstructions``: the ``Reconstruction`` decoder applied to the
        flattened class capsules.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, padding=2, stride=2),
            nn.ReLU()
        )

        # One row per capsule layer:
        # (in types, out types, in length, out length, stride, iterations, output size)
        capsule_specs = [
            (1, 2, 16, 16, 2, 1, 64),
            (2, 4, 16, 16, 1, 3, 64),
            (4, 4, 16, 32, 2, 3, 32),
            (4, 8, 32, 32, 1, 3, 32),
            (8, 8, 32, 64, 2, 3, 16),
            (8, 8, 64, 32, 1, 3, 16),
            (8, num_classes, 32, 16, 2, 3, 8),
        ]
        self.conv_capsule_layers = nn.Sequential(*[
            ConvCapsLayer(
                in_channel_types=in_types,
                out_channel_types=out_types,
                in_vector_length=in_len,
                out_vector_length=out_len,
                kernel_size=5,
                padding=2,
                stride=step,
                num_iterations=iterations,
                new_height=size,
                new_width=size,
            )
            for in_types, out_types, in_len, out_len, step, iterations, size in capsule_specs
        ])

        self.reconstruction_layer = Reconstruction(num_classes)

    def capsule_average_pooling(self, x):
        """Average a 5-D tensor over its last two (spatial) dimensions."""
        num_positions = x.size(3) * x.size(4)
        return x.sum(dim=4).sum(dim=3) / num_positions

    def score(self, x):
        """Return the Euclidean norm of ``x`` along dim 2."""
        return (x ** 2).sum(dim=2).sqrt()

    def forward(self, x):
        """Run the network on an image batch; returns ``(scores, reconstructions)``."""
        # Insert a singleton capsule-type axis: (B, 16, H, W) -> (B, 16, 1, H, W).
        features = self.conv_layers(x).unsqueeze(2)
        # Swap the vector and type axes so that capsule types come first.
        capsules = self.conv_capsule_layers(features).transpose(1, 2)
        pooled = self.capsule_average_pooling(capsules)
        scores = self.score(pooled)
        flat = pooled.view(pooled.size(0), pooled.size(1) * pooled.size(2))
        reconstructions = self.reconstruction_layer(flat)
        return scores, reconstructions

In [17]:
# Instantiate the network for the detected number of classes and place it on
# the selected device in one step.
model = CapsNet(num_classes=num_classes).to(dev)

In [18]:
# Smoke test: push a single training batch through the model to check that the
# shapes line up before starting a full training run.
batch, labels = next(iter(train_loader))
batch, labels = batch.to(dev), labels.to(dev)
scores, reconstructions = model(batch)

In [19]:
def train(epochs, dev, lr=0.001):
    """Train a fresh ``CapsNet`` and plot loss/accuracy curves for every split.

    Uses the module-level ``loaders`` dict (keys ``"train"``, ``"val"``,
    ``"test"``) and ``num_classes``. The loss is the cross-entropy of the class
    scores plus the MSE between the reconstruction and the input images.
    Only the ``"train"`` split updates the parameters; the other splits are
    evaluated with gradients disabled. Reported metrics are per-sample averages.

    A ``KeyboardInterrupt`` stops training early; the curves collected so far
    are still plotted.

    Args:
        epochs: number of passes over the data.
        dev: ``torch.device`` (or device string) to train on.
        lr: learning rate for Adam.

    Returns:
        ``(model, history_loss, history_accuracy)`` -- the trained model and two
        dicts mapping each split name to its list of per-epoch values.
    """
    splits = ["train", "val", "test"]
    # Created before the try block so the plotting code in `finally` can always
    # read them, even if model construction fails.
    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}
    model = None
    try:
        # Create model
        model = CapsNet(num_classes=num_classes)
        model = model.to(dev)
        # Optimizer
        optimizer = optim.Adam(model.parameters(), lr=lr)
        # Process each epoch
        for epoch in range(epochs):
            # Per-epoch accumulators, weighted by sample count so that a smaller
            # final batch does not skew the averages.
            sum_loss = {split: 0.0 for split in splits}
            sum_correct = {split: 0 for split in splits}
            num_samples = {split: 0 for split in splits}
            # Process each split
            for split in splits:
                is_train = split == "train"
                model.train(is_train)
                # No autograd graph is needed for validation/test batches.
                with torch.set_grad_enabled(is_train):
                    for images, labels in loaders[split]:
                        images = images.to(dev)
                        labels = labels.to(dev)
                        # Compute output
                        pred, reconstructions = model(images)
                        # NOTE(review): `pred` holds capsule lengths, which are
                        # used directly as logits here -- confirm cross-entropy
                        # is the intended classification loss.
                        score_loss = F.cross_entropy(pred, labels)
                        reconstruction_loss = F.mse_loss(reconstructions, images)
                        loss = score_loss + reconstruction_loss
                        if is_train:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                        # Update running totals
                        batch_size = images.size(0)
                        sum_loss[split] += loss.item() * batch_size
                        sum_correct[split] += (pred.argmax(dim=1) == labels).sum().item()
                        num_samples[split] += batch_size
            # Compute epoch loss/accuracy (an empty split reports 0 instead of
            # dividing by zero).
            epoch_loss = {split: sum_loss[split] / max(num_samples[split], 1) for split in splits}
            epoch_accuracy = {split: sum_correct[split] / max(num_samples[split], 1) for split in splits}
            # Update history
            for split in splits:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
            # Print info
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},"
                )
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot whatever history was collected, one figure per metric.
        for title, history in (("Loss", history_loss), ("Accuracy", history_accuracy)):
            plt.title(title)
            for split in splits:
                plt.plot(history[split], label=split)
            plt.legend()
            plt.show()
    # Hand the trained model back to the caller instead of discarding it.
    return model, history_loss, history_accuracy

In [None]:
# Launch the full training run: 100 epochs on the selected device.
train(epochs=100, dev=dev)