In [80]:
from google.colab import drive
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision import datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from matplotlib import pyplot as plt
import random
import torch.optim as optim
import time

In [81]:
print(f"CUDA is available? {torch.cuda.is_available()}")
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(dev)

CUDA is available? True
cuda


In [82]:
drive.mount("/content/drive",True)
root_dir = "/content/drive/My Drive/SB3/"

Mounted at /content/drive


In [85]:
train_transform = T.Compose([
    T.Resize(320),
    T.ColorJitter(),
    T.RandomHorizontalFlip(),
    T.ToTensor()
])

test_transform = T.Compose([
    T.Resize(320),
    T.ToTensor()
])

In [86]:
train_dataset = ImageFolder(os.path.join(root_dir, "train"), transform=train_transform)
test_dataset = ImageFolder(os.path.join(root_dir, "test"), transform=test_transform)
num_classes = len(train_dataset.classes)
print(num_classes)

8


In [87]:
print(len(train_dataset))
print(len(test_dataset))

1109
471


In [88]:
counts = torch.zeros(num_classes)

for i in range(len(train_dataset)):
  counts[train_dataset[i][1]] += 1.0

print(counts)

tensor([500.,  28.,  38., 116.,  17.,  19., 339.,  52.])


In [89]:
weights = len(train_dataset) / counts
weights = weights.to(dev)

In [90]:
x = train_dataset[0][0]
print(x.max())
print(x.min())

tensor(0.9961)
tensor(0.)


In [91]:
num_train = len(train_dataset)
train_idx = list(range(num_train))
random.shuffle(train_idx)
val_frac = 0.2
num_val = int(num_train*val_frac)
num_train = num_train - num_val
val_idx = train_idx[num_train:]
train_idx = train_idx[:num_train]
val_dataset = Subset(train_dataset,val_idx)
train_dataset = Subset(train_dataset, train_idx)

In [92]:
train_loader = DataLoader(train_dataset, batch_size=4, num_workers=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, num_workers=4, shuffle=False)
test_loader = DataLoader(test_dataset,   batch_size=4, num_workers=4, shuffle=False)
loaders = {
    "train": train_loader,
    "val" : val_loader,
    "test": test_loader
}

In [93]:
def squash(s, dim=-1):
	'''
	"Squashing" non-linearity that shrunks short vectors to almost zero length and long vectors to a length slightly below 1
	Eq. (1): v_j = ||s_j||^2 / (1 + ||s_j||^2) * s_j / ||s_j||
	
	Args:
		s: 	Vector before activation
		dim:	Dimension along which to calculate the norm
	
	Returns:
		Squashed vector
	'''
	squared_norm = torch.sum(s**2, dim=dim, keepdim=True)
	return squared_norm / (1 + squared_norm) * s / (torch.sqrt(squared_norm) + 1e-8)

In [94]:
class PrimaryCapsules(nn.Module):
  def __init__(self, in_channels, out_channels, vector_length, kernel_size, stride, padding):
    """
    Initialize the layer.
    Args:
      in_channels: 	Number of input channels.
      out_channels: 	Number of output channels.
      vector_length:		Dimensionality, i.e. length, of the output capsule vector.

    """
    super().__init__()
    self.vector_length = vector_length
    self.num_caps_channels = int(out_channels / vector_length)
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

  def forward(self, x):
    x = self.conv(x)
    x = x.view(x.size(0), self.num_caps_channels, x.size(2), x.size(3), self.vector_length)
    x = x.view(x.size(0), -1, self.vector_length)
    return squash(x)


In [95]:
class RoutingCapsules(nn.Module):
  def __init__(self, in_vector_length, num_in_caps, num_out_caps, out_vector_length, num_routing):
    '''
		Initialize the layer.
		Args:
			in_vector_length: 		Dimensionality (i.e. length) of each capsule vector.
			num_in_caps: 		Number of input capsules if digits layer.
			num_out_caps: 		Number of capsules in the capsule layer
			out_vector_length: 		Dimensionality, i.e. length, of the output capsule vector.
			num_routing:	Number of iterations during routing algorithm		
    '''
    super().__init__()
    self.in_vector_length = in_vector_length
    self.num_in_caps = num_in_caps
    self.num_out_caps = num_out_caps
    self.out_vector_length = out_vector_length
    self.num_routing = num_routing

    self.W = nn.Parameter(torch.randn(1, num_out_caps, num_in_caps, out_vector_length, in_vector_length ) )
  
  def forward(self, x):
    batch_size = x.size(0)
    # (batch_size, num_in_caps, in_vector_length) -> (batch_size, 1, num_in_caps, in_vector_length, 1)
    x = x.unsqueeze(1).unsqueeze(4)
    #
    # W @ x =
    # (1, num_output_caps, num_in_caps, out_vector_length, in_vector_length) @ (batch_size, 1, num_in_caps, in_vector_length, 1) =
    # (batch_size, num_out_caps, num_in_caps, out_vector_length, 1)
    u_hat = torch.matmul(self.W, x)
    # (batch_size, num_out_caps, num_in_caps, out_vector_length)
    u_hat = u_hat.squeeze(-1)
    # detach u_hat during routing iterations to prevent gradients from flowing
    temp_u_hat = u_hat.detach()

    '''
    Procedure 1: Routing algorithm
    '''
    b = torch.zeros(batch_size, self.num_out_caps, self.num_in_caps, 1).to(dev)

    for route_iter in range(self.num_routing-1):
      # (batch_size, num_out_caps, num_in_caps, 1) -> Softmax along num_out_caps
      c = F.softmax(b, dim=1)

      # element-wise multiplication
      # (batch_size, num_out_caps, num_in_caps, 1) * (batch_size, num_in_caps, num_out_caps, out_vector_length) ->
      # (batch_size, num_out_caps, num_in_caps, out_vector_length) sum across num_in_caps ->
      # (batch_size, num_out_caps, out_vector_length)
      s = (c * temp_u_hat).sum(dim=2)
      # apply "squashing" non-linearity along dim_caps
      v = squash(s)
      # dot product agreement between the current output vj and the prediction uj|i
      # (batch_size, num_out_caps, num_in_caps, out_vector_length) @ (batch_size, num_out_caps, out_vector_length, 1)
      # -> (batch_size, num_out_caps, num_in_caps, 1)
      uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
      b += uv

    # last iteration is done on the original u_hat, without the routing weights update
    c = F.softmax(b, dim=1)
    s = (c * u_hat).sum(dim=2)
    # apply "squashing" non-linearity along dim_caps
    v = squash(s)

    return v

In [96]:
class Reconstruction(nn.Module):
    def __init__(self, num_classes, vector_length):
      super().__init__()
      
      self.fc_layer = nn.Sequential(
        nn.Linear(num_classes*vector_length, 64),
        nn.ReLU()
      )

      self.reconstruction_layers = nn.Sequential(
          nn.ConvTranspose2d(1, 128, kernel_size=5, padding=2, stride=11),
          nn.ReLU(),
          nn.ConvTranspose2d(128, 64, kernel_size=5, padding=2, stride=4, output_padding=(1,1)),
          nn.ReLU(),
          nn.Conv2d(64, 3, kernel_size=5, padding=7, stride=1),
          nn.ReLU()
    )
        
    def forward(self, x):
      x = self.fc_layer(x)
      x = x.view(x.size(0), 1, 8, 8)
      x = self.reconstruction_layers(x)
      return x

In [97]:
class CapsNet(nn.Module):
    def __init__(self, num_classes):
      super().__init__()
      self.conv_layer = nn.Sequential(
          nn.Conv2d(3, 64, kernel_size=5, padding=2, stride=8),
          nn.ReLU()
      )
      self.primary_caps = PrimaryCapsules(in_channels=64, out_channels=128, vector_length=8, kernel_size=5, padding=2, stride=1)
      self.digit_caps = nn.Sequential(
        RoutingCapsules(in_vector_length=8, num_in_caps=25600, num_out_caps=num_classes, out_vector_length=16, num_routing=3),
        RoutingCapsules(in_vector_length=16, num_in_caps=num_classes, num_out_caps=num_classes, out_vector_length=16, num_routing=3)
      )
      self.reconstruction_layer = Reconstruction(num_classes=num_classes, vector_length=16)
    
    '''
    def capsule_average_pooling(self, x):
      height = x.size(3)
      width = x.size(4)
      x = x.sum(dim=4).sum(dim=3)
      x = x / (height * width)
      return x
    '''
    
    def score(self, x):
      return torch.sqrt((x ** 2).sum(dim=2))
    
    def forward(self, x):
      x = self.conv_layer(x)
      x = self.primary_caps(x)
      x = self.digit_caps(x)
      scores = self.score(x)
      x = x.view(x.size(0), x.size(1) * x.size(2))
      reconstructions = self.reconstruction_layer(x)
      return scores, reconstructions

In [99]:
model = CapsNet(num_classes=num_classes)
model = model.to(dev)

In [100]:
batch,labels = next(iter(train_loader))
batch = batch.to(dev)
labels = labels.to(dev)
#print(batch.size())
scores, reconstructions = model(batch)
print(scores.size())
print(reconstructions.size())

torch.Size([4, 8])
torch.Size([4, 3, 320, 320])


In [101]:
batch,labels = next(iter(train_loader))
batch = batch.to(dev)
labels = labels.to(dev)
#print(batch.size())
out = model(batch)

In [102]:
def train(epochs, dev, lr=0.001):
    try:
        # Create model
        model = CapsNet(num_classes=num_classes)
        model = model.to(dev)
        # Optimizer
        optimizer = optim.Adam(model.parameters(), lr=lr)
        # Initialize history
        history_loss = {"train": [], "val": [], "test": []}
        history_accuracy = {"train": [], "val": [], "test": []}
        # Process each epoch
        for epoch in range(epochs):
            # Initialize epoch variables
            sum_loss = {"train": 0, "val": 0, "test": 0}
            sum_accuracy = {"train": 0, "val": 0, "test": 0}
            
            # Process each split
            for split in ["train", "val", "test"]:
                if split == "train":
                  model.train()
                else:
                  model.eval()
                # Process each batch
                for (input, labels) in loaders[split]:
                    # Move to CUDA
                    input = input.to(dev)
                    labels = labels.to(dev)
                    # Reset gradients
                    optimizer.zero_grad()
                    # Compute output
                    pred, reconstructions = model(input)
                    score_loss = F.cross_entropy(pred, labels, weight=weights)
                    reconstruction_loss = F.mse_loss(input, reconstructions)
                    loss = score_loss + reconstruction_loss
                    # Update loss
                    sum_loss[split] += loss.item()
                    # Check parameter update
                    if split == "train":
                        # Compute gradients
                        loss.backward()
                        # Optimize
                        optimizer.step()
                    # Compute accuracy
                    _,pred_labels = pred.max(1)
                    batch_accuracy = (pred_labels == labels).sum().item()/input.size(0)
                    # Update accuracy
                    sum_accuracy[split] += batch_accuracy
            # Compute epoch loss/accuracy
            epoch_loss = {split: sum_loss[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            epoch_accuracy = {split: sum_accuracy[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            # Update history
            for split in ["train", "val", "test"]:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
            # Print info
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},"
                )
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot loss
        plt.title("Loss")
        for split in ["train", "val", "test"]:
            plt.plot(history_loss[split], label=split)
        plt.legend()
        plt.show()
        # Plot accuracy
        plt.title("Accuracy")
        for split in ["train", "val", "test"]:
            plt.plot(history_accuracy[split], label=split)
        plt.legend()
        plt.show()

In [None]:
train(100, dev, lr=0.001)

Epoch 1: TrL=2.0871, TrA=0.4200, VL=2.0186, VA=0.4509, TeL=1.8672, TeA=0.4534,
Epoch 2: TrL=2.0579, TrA=0.4493, VL=2.0434, VA=0.4509, TeL=1.8886, TeA=0.4534,
Epoch 3: TrL=1.9971, TrA=0.4493, VL=1.9907, VA=0.4375, TeL=1.8102, TeA=0.4492,
Epoch 4: TrL=2.0089, TrA=0.4234, VL=1.9575, VA=0.5000, TeL=1.8028, TeA=0.5148,
Epoch 5: TrL=1.9971, TrA=0.4505, VL=1.9762, VA=0.4509, TeL=1.8141, TeA=0.4534,
Epoch 6: TrL=1.9887, TrA=0.4493, VL=1.9336, VA=0.4509, TeL=1.7591, TeA=0.4534,
Epoch 7: TrL=1.9609, TrA=0.4493, VL=1.9235, VA=0.4509, TeL=1.7519, TeA=0.4534,
Epoch 8: TrL=1.9663, TrA=0.4493, VL=1.9733, VA=0.4509, TeL=1.8273, TeA=0.4534,
Epoch 9: TrL=1.9300, TrA=0.4493, VL=2.0337, VA=0.4509, TeL=1.8032, TeA=0.4534,
Epoch 10: TrL=1.9318, TrA=0.4493, VL=1.9244, VA=0.4509, TeL=1.7283, TeA=0.4534,
Epoch 11: TrL=1.9429, TrA=0.4606, VL=1.9631, VA=0.4866, TeL=1.7992, TeA=0.5212,
Epoch 12: TrL=1.9256, TrA=0.4550, VL=1.9366, VA=0.4955, TeL=1.7682, TeA=0.5148,
Epoch 13: TrL=1.9177, TrA=0.4583, VL=1.8934, VA=0