In [None]:
# function ClickConnect(){
# console.log("Working");
# document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click();
# }
# interval = setInterval(ClickConnect,60000)

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount = True)

driveDir = '/content/drive/Shareddrives/JCC/TCC/'

!mkdir -p '/content/dataset/'
!unzip -qq -o '/content/drive/Shareddrives/JCC/TCC/dataset/real_fake' -d '/content/dataset'

!mkdir -p '/content/drive/Shareddrives/JCC/TCC/Resultados/'
!mkdir -p '/content/drive/Shareddrives/JCC/TCC/Modelos/'

Mounted at /content/drive


In [None]:
#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torch.nn.functional as F


class ConvUnit(nn.Module):
    def __init__(self, in_channels):
        super(ConvUnit, self).__init__()

        self.conv0 = nn.Conv2d(in_channels=in_channels,
                               out_channels=32,  # fixme constant
                               kernel_size=9,  # fixme constant
                               stride=2, # fixme constant
                               bias=True)

    def forward(self, x):
        return self.conv0(x)

class CapsuleLayer(nn.Module):
    def __init__(self, in_units, in_channels, num_units, unit_size, use_routing):
        super(CapsuleLayer, self).__init__()

        self.in_units = in_units
        self.in_channels = in_channels
        self.num_units = num_units
        self.use_routing = use_routing

        if self.use_routing:
            # In the paper, the deeper capsule layer(s) with capsule inputs (DigitCaps) use a special routing algorithm
            # that uses this weight matrix.
            self.W = nn.Parameter(torch.randn(1, in_channels, num_units, unit_size, in_units))
        else:
            # The first convolutional capsule layer (PrimaryCapsules in the paper) does not perform routing.
            # Instead, it is composed of several convolutional units, each of which sees the full input.
            # It is implemented as a normal convolutional layer with a special nonlinearity (squash()).
            def create_conv_unit(unit_idx):
                unit = ConvUnit(in_channels=in_channels)
                self.add_module("unit_" + str(unit_idx), unit)
                return unit
            self.units = [create_conv_unit(i) for i in range(self.num_units)]

    @staticmethod
    def squash(s):
        # This is equation 1 from the paper.
        mag_sq = torch.sum(s**2, dim=2, keepdim=True)
        mag = torch.sqrt(mag_sq)
        s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
        return s

    def forward(self, x):
        if self.use_routing:
            return self.routing(x)
        else:
            return self.no_routing(x)

    def no_routing(self, x):
        # Get output for each unit.
        # Each will be (batch, channels, height, width).
        u = [self.units[i](x) for i in range(self.num_units)]

        # Stack all unit outputs (batch, unit, channels, height, width).
        u = torch.stack(u, dim=1)

        # Flatten to (batch, unit, output).
        u = u.view(x.size(0), self.num_units, -1)

        # Return squashed outputs.
        return CapsuleLayer.squash(u)

    def routing(self, x):
        batch_size = x.size(0)

        # (batch, in_units, features) -> (batch, features, in_units)
        x = x.transpose(1, 2)

        # (batch, features, in_units) -> (batch, features, num_units, in_units, 1)
        x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)

        # (batch, features, in_units, unit_size, num_units)
        W = torch.cat([self.W] * batch_size, dim=0)

        # Transform inputs by weight matrix.
        # (batch_size, features, num_units, unit_size, 1)
        u_hat = torch.matmul(W, x)

        # Initialize routing logits to zero.
        b_ij = Variable(torch.zeros(1, self.in_channels, self.num_units, 1)).cuda()

        # Iterative routing.
        num_iterations = 3
        for iteration in range(num_iterations):
            # Convert routing logits to softmax.
            # (batch, features, num_units, 1, 1)
            c_ij = F.softmax(b_ij, dim=0)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            # Apply routing (c_ij) to weighted inputs (u_hat).
            # (batch_size, 1, num_units, unit_size, 1)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

            # (batch_size, 1, num_units, unit_size, 1)
            v_j = CapsuleLayer.squash(s_j)

            # (batch_size, features, num_units, unit_size, 1)
            v_j1 = torch.cat([v_j] * self.in_channels, dim=1)

            # (1, features, num_units, 1)
            u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)

            # Update b_ij (routing)
            b_ij = b_ij + u_vj1

        return v_j.squeeze(1)


In [None]:
#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torch.nn.functional as F


class CapsuleConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CapsuleConvLayer, self).__init__()

        self.conv0 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=9, # fixme constant
                               stride=1,
                               bias=True)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv0(x))


In [None]:
#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torchvision.utils as vutils
import torch.nn.functional as F


class CapsuleNetwork(nn.Module):
    def __init__(self,
                 image_width,
                 image_height,
                 image_channels,
                 conv_inputs,
                 conv_outputs,
                 num_primary_units,
                 primary_unit_size,
                 num_output_units,
                 output_unit_size):
        super(CapsuleNetwork, self).__init__()

        self.reconstructed_image_count = 0

        self.image_channels = image_channels
        self.image_width = image_width
        self.image_height = image_height

        self.conv1 = CapsuleConvLayer(in_channels=conv_inputs,
                                      out_channels=conv_outputs)

        self.primary = CapsuleLayer(in_units=0,
                                    in_channels=conv_outputs,
                                    num_units=num_primary_units,
                                    unit_size=primary_unit_size,
                                    use_routing=False)

        self.digits = CapsuleLayer(in_units=num_primary_units,
                                   in_channels=primary_unit_size,
                                   num_units=num_output_units,
                                   unit_size=output_unit_size,
                                   use_routing=True)

        reconstruction_size = image_width * image_height * image_channels
        self.reconstruct0 = nn.Linear(num_output_units*output_unit_size, int((reconstruction_size * 2) / 3))
        self.reconstruct1 = nn.Linear(int((reconstruction_size * 2) / 3), int((reconstruction_size * 3) / 2))
        self.reconstruct2 = nn.Linear(int((reconstruction_size * 3) / 2), reconstruction_size)

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.digits(self.primary(self.conv1(x)))

    def loss(self, images, input, target, size_average=True):
        return self.margin_loss(input, target, size_average) + self.reconstruction_loss(images, input, size_average)

    def margin_loss(self, input, target, size_average=True):
        batch_size = input.size(0)

        # ||vc|| from the paper.
        v_mag = torch.sqrt((input**2).sum(dim=2, keepdim=True))

        # Calculate left and right max() terms from equation 4 in the paper.
        zero = Variable(torch.zeros(1)).cuda()
        m_plus = 0.9
        m_minus = 0.1
        max_l = torch.max(m_plus - v_mag, zero).view(batch_size, -1)**2
        max_r = torch.max(v_mag - m_minus, zero).view(batch_size, -1)**2

        # This is equation 4 from the paper.
        loss_lambda = 0.5
        T_c = target
        L_c = T_c * max_l + loss_lambda * (1.0 - T_c) * max_r
        L_c = L_c.sum(dim=1)

        if size_average:
            L_c = L_c.mean()

        return L_c

    def reconstruction_loss(self, images, input, size_average=True):
        # Get the lengths of capsule outputs.
        v_mag = torch.sqrt((input**2).sum(dim=2))

        # Get index of longest capsule output.
        _, v_max_index = v_mag.max(dim=1)
        v_max_index = v_max_index.data

        # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
        batch_size = input.size(0)
        all_masked = [None] * batch_size
        for batch_idx in range(batch_size):
            # Get one sample from the batch.
            input_batch = input[batch_idx]

            # Copy only the maximum capsule index from this batch sample.
            # This masks out (leaves as zero) the other capsules in this sample.
            batch_masked = Variable(torch.zeros(input_batch.size())).cuda()
            batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
            all_masked[batch_idx] = batch_masked

        # Stack masked capsules over the batch dimension.
        masked = torch.stack(all_masked, dim=0)

        # Reconstruct input image.
        masked = masked.view(input.size(0), -1)
        output = self.relu(self.reconstruct0(masked))
        output = self.relu(self.reconstruct1(output))
        output = self.sigmoid(self.reconstruct2(output))
        output = output.view(-1, self.image_channels, self.image_height, self.image_width)

        # Save reconstructed images occasionally.
        if self.reconstructed_image_count % 10 == 0:
            if output.size(1) == 2:
                # handle two-channel images
                zeros = torch.zeros(output.size(0), 1, output.size(2), output.size(3))
                output_image = torch.cat([zeros, output.data.cpu()], dim=1)
            else:
                # assume RGB or grayscale
                output_image = output.data.cpu()
            vutils.save_image(output_image, "reconstruction.png")
        self.reconstructed_image_count += 1

        # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
        # Multiplied by a small number so it doesn't dominate the margin (class) loss.
        error = (output - images).view(output.size(0), -1)
        error = error**2
        error = torch.sum(error, dim=1) * 0.0005

        # Average over batch
        if size_average:
            error = error.mean()

        return error


In [None]:
#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#

import torch
import os
import csv
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torch.nn.functional as F

#
# Settings.
#

learning_rate = 0.001

batch_size = 8

# Stop training if loss goes below this threshold.
early_stop_loss = 0.0001

image_size = 32
# Normalization for MNIST dataset.
dataset_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.ImageFolder(
    '/content/dataset/real_vs_fake/real-vs-fake/train/', transform=dataset_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = datasets.ImageFolder(
    '/content/dataset/real_vs_fake/real-vs-fake/valid/', transform=dataset_transform)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=True)

#
# Create capsule network.
#

conv_inputs = 3
conv_outputs = 256
num_primary_units = 8
# primary_unit_size = 32 * 6 * 6  # fixme get from conv2d
primary_unit_size = 2048
output_unit_size = 16

network = CapsuleNetwork(image_width=image_size,
                         image_height=image_size,
                         image_channels=3,
                         conv_inputs=conv_inputs,
                         conv_outputs=conv_outputs,
                         num_primary_units=num_primary_units,
                         primary_unit_size=primary_unit_size,
                         num_output_units=2,  # one for each MNIST digit
                         output_unit_size=output_unit_size).cuda()
# print(network)


# Converts batches of class indices to classes of one-hot vectors.
def to_one_hot(x, length):
    batch_size = x.size(0)
    x_one_hot = torch.zeros(batch_size, length)
    for i in range(batch_size):
        x_one_hot[i, x[i]] = 1.0
    return x_one_hot

# This is the test function from the basic Pytorch MNIST example, but adapted to use the capsule network.
# https://github.com/pytorch/examples/blob/master/mnist/main.py


def test():
    network.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(val_loader):
        target_indices = target
        target_one_hot = to_one_hot(
            target_indices, length=network.digits.num_units)

        data = data.cuda()
        target = target_one_hot.cuda()

        with torch.set_grad_enabled(False):
            output = network(data)

            # sum up batch loss
            test_loss += network.loss(data, output, target).data

        v_mag = torch.sqrt((output**2).sum(dim=2, keepdim=True))

        pred = v_mag.data.max(1, keepdim=True)[1].cpu()

        correct += pred.eq(target_indices.view_as(pred)).sum()

    test_loss /= len(val_loader.dataset)

    print('\nTest Loss: {:.6f}\tAccuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss,
        correct,
        len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))

    return test_loss, 100. * correct / len(val_loader.dataset)


# This is the train function from the basic Pytorch MNIST example, but adapted to use the capsule network.
# https://github.com/pytorch/examples/blob/master/mnist/main.py
def train(optimizer):
    network.train()
    train_loss = 0.0
    for (data, target) in tqdm(train_loader):
        target_one_hot = to_one_hot(target, length=network.digits.num_units)

        data, target = data.cuda(), target_one_hot.cuda()

        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            output = network(data)

            loss = network.loss(data, output, target)
            loss.backward()

            optimizer.step()

            train_loss += loss.data

        # if loss.data < early_stop_loss:
        #     break

    train_loss /= len(train_loader.dataset)

    print('Train Loss: {:.6f}'.format(
                train_loss))

    return train_loss


num_epochs = 25
list_train = []
list_test = []
list_accuracy = []
optimizer = optim.Adam(network.parameters(), lr=learning_rate)

for epoch in range(1, num_epochs + 1):
    print('{} - Epoch: '.format(epoch))
    list_train.append(train(optimizer).item())
    test_loss, accuracy = test()
    list_test.append(test_loss.item())
    list_accuracy.append(accuracy.item())

    with open(f'{driveDir}/Resultados/losses_{epoch}', 'w') as csvFile:
        writer = csv.writer(csvFile)

        for i in range(len(list_train)):
            writer.writerow([list_train[i], list_test[i], list_accuracy[i]])

        csvFile.close()

    torch.save({
        'model': network.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, f'{driveDir}/Modelos/model_{epoch}.pth')

    if list_train[-1] < early_stop_loss:
        break

1 - Epoch: 


100%|██████████| 12500/12500 [22:22<00:00,  9.31it/s]


Train Loss: 0.034893


100%|██████████| 2500/2500 [02:48<00:00, 14.84it/s]



Test Loss: 0.025715	Accuracy: 14633/20000 (73%)

2 - Epoch: 


100%|██████████| 12500/12500 [22:05<00:00,  9.43it/s]


Train Loss: 0.023691


100%|██████████| 2500/2500 [02:44<00:00, 15.20it/s]



Test Loss: 0.021564	Accuracy: 15879/20000 (79%)

3 - Epoch: 


100%|██████████| 12500/12500 [22:03<00:00,  9.44it/s]


Train Loss: 0.020297


100%|██████████| 2500/2500 [02:44<00:00, 15.19it/s]



Test Loss: 0.019648	Accuracy: 16402/20000 (82%)

4 - Epoch: 


100%|██████████| 12500/12500 [22:01<00:00,  9.46it/s]


Train Loss: 0.018219


100%|██████████| 2500/2500 [02:50<00:00, 14.68it/s]



Test Loss: 0.018981	Accuracy: 16609/20000 (83%)

5 - Epoch: 


100%|██████████| 12500/12500 [22:02<00:00,  9.45it/s]


Train Loss: 0.016696


100%|██████████| 2500/2500 [02:45<00:00, 15.14it/s]



Test Loss: 0.016451	Accuracy: 17206/20000 (86%)

6 - Epoch: 


100%|██████████| 12500/12500 [21:57<00:00,  9.49it/s]


Train Loss: 0.015584


100%|██████████| 2500/2500 [02:45<00:00, 15.12it/s]



Test Loss: 0.015408	Accuracy: 17491/20000 (87%)

7 - Epoch: 


100%|██████████| 12500/12500 [21:43<00:00,  9.59it/s]


Train Loss: 0.014517


100%|██████████| 2500/2500 [02:49<00:00, 14.79it/s]



Test Loss: 0.015100	Accuracy: 17598/20000 (88%)

8 - Epoch: 


100%|██████████| 12500/12500 [21:47<00:00,  9.56it/s]


Train Loss: 0.013752


100%|██████████| 2500/2500 [02:47<00:00, 14.94it/s]



Test Loss: 0.017445	Accuracy: 16762/20000 (84%)

9 - Epoch: 


100%|██████████| 12500/12500 [21:32<00:00,  9.67it/s]


Train Loss: 0.013042


100%|██████████| 2500/2500 [02:44<00:00, 15.19it/s]



Test Loss: 0.013890	Accuracy: 17904/20000 (90%)

10 - Epoch: 


100%|██████████| 12500/12500 [21:27<00:00,  9.71it/s]


Train Loss: 0.012464


100%|██████████| 2500/2500 [02:46<00:00, 15.02it/s]



Test Loss: 0.013922	Accuracy: 17836/20000 (89%)

11 - Epoch: 


100%|██████████| 12500/12500 [21:23<00:00,  9.74it/s]


Train Loss: 0.011941


100%|██████████| 2500/2500 [02:45<00:00, 15.12it/s]



Test Loss: 0.013641	Accuracy: 17944/20000 (90%)

12 - Epoch: 


100%|██████████| 12500/12500 [21:30<00:00,  9.69it/s]


Train Loss: 0.011408


100%|██████████| 2500/2500 [02:49<00:00, 14.79it/s]



Test Loss: 0.012792	Accuracy: 18128/20000 (91%)

13 - Epoch: 


100%|██████████| 12500/12500 [21:25<00:00,  9.73it/s]


Train Loss: 0.011022


100%|██████████| 2500/2500 [02:45<00:00, 15.10it/s]



Test Loss: 0.012874	Accuracy: 18145/20000 (91%)

14 - Epoch: 


100%|██████████| 12500/12500 [21:29<00:00,  9.70it/s]


Train Loss: 0.010690


100%|██████████| 2500/2500 [02:46<00:00, 14.99it/s]



Test Loss: 0.012464	Accuracy: 18197/20000 (91%)

15 - Epoch: 


100%|██████████| 12500/12500 [21:24<00:00,  9.73it/s]


Train Loss: 0.010273


100%|██████████| 2500/2500 [02:47<00:00, 14.90it/s]



Test Loss: 0.012800	Accuracy: 18184/20000 (91%)

16 - Epoch: 


100%|██████████| 12500/12500 [21:23<00:00,  9.74it/s]


Train Loss: 0.009985


100%|██████████| 2500/2500 [02:47<00:00, 14.94it/s]



Test Loss: 0.012960	Accuracy: 18111/20000 (91%)

17 - Epoch: 


100%|██████████| 12500/12500 [21:29<00:00,  9.69it/s]


Train Loss: 0.009702


100%|██████████| 2500/2500 [02:48<00:00, 14.81it/s]



Test Loss: 0.012649	Accuracy: 18229/20000 (91%)

18 - Epoch: 


100%|██████████| 12500/12500 [21:22<00:00,  9.74it/s]


Train Loss: 0.009423


100%|██████████| 2500/2500 [02:47<00:00, 14.94it/s]



Test Loss: 0.012004	Accuracy: 18203/20000 (91%)

19 - Epoch: 


100%|██████████| 12500/12500 [21:25<00:00,  9.73it/s]


Train Loss: 0.009228


100%|██████████| 2500/2500 [02:41<00:00, 15.44it/s]



Test Loss: 0.011741	Accuracy: 18341/20000 (92%)

20 - Epoch: 


100%|██████████| 12500/12500 [21:20<00:00,  9.76it/s]


Train Loss: 0.009009


100%|██████████| 2500/2500 [02:44<00:00, 15.16it/s]



Test Loss: 0.011744	Accuracy: 18320/20000 (92%)

21 - Epoch: 


100%|██████████| 12500/12500 [21:19<00:00,  9.77it/s]


Train Loss: 0.008746


100%|██████████| 2500/2500 [02:43<00:00, 15.24it/s]



Test Loss: 0.012046	Accuracy: 18246/20000 (91%)

22 - Epoch: 


100%|██████████| 12500/12500 [21:17<00:00,  9.79it/s]


Train Loss: 0.008557


100%|██████████| 2500/2500 [02:43<00:00, 15.29it/s]



Test Loss: 0.011718	Accuracy: 18333/20000 (92%)

23 - Epoch: 


100%|██████████| 12500/12500 [21:18<00:00,  9.78it/s]


Train Loss: 0.008330


100%|██████████| 2500/2500 [02:43<00:00, 15.31it/s]



Test Loss: 0.011651	Accuracy: 18289/20000 (91%)

24 - Epoch: 


100%|██████████| 12500/12500 [21:21<00:00,  9.75it/s]


Train Loss: 0.008194


100%|██████████| 2500/2500 [02:44<00:00, 15.21it/s]



Test Loss: 0.011427	Accuracy: 18415/20000 (92%)

25 - Epoch: 


100%|██████████| 12500/12500 [21:14<00:00,  9.80it/s]


Train Loss: 0.008015


100%|██████████| 2500/2500 [02:45<00:00, 15.08it/s]



Test Loss: 0.011699	Accuracy: 18329/20000 (92%)

