<a href="https://colab.research.google.com/github/KHarsh98/Capsule-Forensics-v2/blob/master/Capnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive under /content/gdrive so that later cells can read the
# evaluation dataset from "/content/gdrive/My Drive/...".
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [0]:

# This environment is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os


In [0]:
"""
Copyright (c) 2019, National Institute of Informatics
All rights reserved.
Author: Huy H. Nguyen
-----------------------------------------------------
Script for Capsule-Forensics-v2 model
"""

import sys
sys.setrecursionlimit(15000)
import torch
import torch.nn.functional as F
from torch import nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.models as models

# Number of parallel capsules built by FeatureExtractor; also passed to RoutingLayer as
# num_input_capsules. NOTE(review): an earlier comment here claimed the value had been
# changed from 10 to 15, but the value in effect is 10.
NO_CAPS= 10

class StatsNet(nn.Module):
    """Reduce every channel of a 4-axis feature map to its mean and standard deviation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Compute per-channel statistics of ``x``.

        The last two axes of ``x`` are merged into one and reduced away, so the
        result has shape ``(x.size(0), 2, x.size(1))``: index 0 along dim 1 holds
        the means, index 1 holds the standard deviations.
        """
        batch, channels = x.size(0), x.size(1)
        flat = x.view(batch, channels, x.size(2) * x.size(3))

        per_channel_mean = flat.mean(dim=2)
        per_channel_std = flat.std(dim=2)

        return torch.stack((per_channel_mean, per_channel_std), dim=1)

class View(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can sit inside ``nn.Sequential``."""

    def __init__(self, *shape):
        super().__init__()
        # Target shape, kept as the tuple of positional arguments given here.
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed with the shape given at construction time."""
        reshaped = input.view(self.shape)
        return reshaped


class VggExtractor(nn.Module):
    """Feature extractor made of the leading layers of a pretrained VGG-19."""

    def __init__(self):
        super().__init__()

        backbone = models.vgg19(pretrained=True)
        # Keep layers 0..18 (inclusive) of ``backbone.features`` and put them in eval mode.
        self.vgg_1 = self.Vgg(backbone, 0, 18)
        self.vgg_1.eval()

    def Vgg(self, vgg, begin, end):
        """Return ``vgg.features`` layers ``begin``..``end`` (both inclusive) as a new ``nn.Sequential``."""
        layers = list(vgg.features.children())
        selected = layers[begin:end + 1]
        return nn.Sequential(*selected)

    def forward(self, input):
        """Run ``input`` through the retained VGG layers."""
        return self.vgg_1(input)

class FeatureExtractor(nn.Module):
    """Bank of ``NO_CAPS`` identical capsule branches applied to the same feature map.

    Each branch is a small 2-D conv stack, followed by ``StatsNet`` (per-channel
    mean/std), a 1-D conv stack, and a final ``View(-1, 8)``.
    """

    def __init__(self):
        super().__init__()

        def build_capsule():
            # One capsule branch; layer order matters because it fixes the state_dict keys.
            return nn.Sequential(
                # 2-D part: 256 -> 64 -> 16 channels, spatial size preserved.
                nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                # Collapse the spatial axes into (mean, std) per channel.
                StatsNet(),

                # 1-D part over the 16 channel statistics: 2 -> 8 -> 1 channels, stride 2 first.
                nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm1d(8),
                nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(1),
                View(-1, 8),
            )

        self.capsules = nn.ModuleList([build_capsule() for _ in range(NO_CAPS)])

    def squash(self, tensor, dim):
        """Rescale ``tensor`` along ``dim`` by ``|v|^2 / (1 + |v|^2)`` times its unit direction."""
        norm_sq = (tensor ** 2).sum(dim=dim, keepdim=True)
        shrink = norm_sq / (1 + norm_sq)
        return shrink * tensor / torch.sqrt(norm_sq)

    def forward(self, x):
        """Run every capsule on a detached ``x``, stack the results on a new last axis, and squash along it."""
        detached = x.detach()
        branch_outputs = []
        for capsule in self.capsules:
            branch_outputs.append(capsule(detached))
        stacked = torch.stack(branch_outputs, dim=-1)

        return self.squash(stacked, dim=-1)

class RoutingLayer(nn.Module):
    """Dynamic-routing layer mapping input capsules to output capsules.

    Args:
        gpu_id: Stored on the instance for backward compatibility. Helper
            tensors created in ``forward`` are placed on the device of the
            tensors they are combined with.
        num_input_capsules: Number of input capsules.
        num_output_capsules: Number of output capsules.
        data_in: Length of each input capsule vector.
        data_out: Length of each output capsule vector.
        num_iterations: Number of routing iterations run in ``forward``.
    """

    def __init__(self, gpu_id, num_input_capsules, num_output_capsules, data_in, data_out, num_iterations):
        super(RoutingLayer, self).__init__()

        self.gpu_id = gpu_id
        self.num_iterations = num_iterations
        # route_weights[out_caps, in_caps, data_out, data_in]
        self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in))

    def squash(self, tensor, dim):
        """Rescale ``tensor`` along ``dim`` by ``|v|^2 / (1 + |v|^2)`` times its unit direction."""
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / (torch.sqrt(squared_norm))

    def forward(self, x, random, dropout):
        """Route ``x`` to the output capsules.

        Args:
            x: Tensor laid out as ``[b, data_in, in_caps]``.
            random: If truthy, add Gaussian noise (std 0.01) to the routing weights.
            dropout: Probability of zeroing each element of the priors;
                values ``<= 0.0`` disable dropout.

        Returns:
            Tensor laid out as ``[b, data_out, out_caps]``. The layout is the
            same for any batch size, number of output capsules or ``data_out``,
            including when any of them is 1.
        """
        # x[b, data, in_caps] -> x[b, in_caps, data]
        x = x.transpose(2, 1)

        route_weights = self.route_weights
        if random:
            # Noise is drawn on the CPU (same random stream as before) and then
            # moved next to the weights.
            noise = 0.01 * torch.randn(*route_weights.size())
            route_weights = route_weights + noise.to(route_weights.device)

        # route_weights [out_caps , 1 , in_caps , data_out , data_in]
        # x             [   1     , b , in_caps , data_in ,    1    ]
        # priors        [out_caps , b , in_caps , data_out,    1    ]
        priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None]

        # priors[b, out_caps, in_caps, data_out, 1]
        priors = priors.transpose(1, 0)

        if dropout > 0.0:
            # Bernoulli keep-mask, drawn on the CPU as float32 and moved to the priors' device.
            keep_mask = torch.empty(priors.size(), dtype=torch.float32).bernoulli_(1.0 - dropout)
            priors = priors * keep_mask.to(priors.device)

        # logits[b, out_caps, in_caps, data_out, 1]
        logits = torch.zeros_like(priors)

        last_iteration = self.num_iterations - 1
        for i in range(self.num_iterations):
            probs = F.softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3)

            # The logits only need updating if another iteration follows.
            if i != last_iteration:
                logits = logits + priors * outputs

        # outputs[b, out_caps, 1, data_out, 1] -> outputs[b, out_caps, data_out]
        # Squeeze only the two known singleton axes: a bare squeeze() would also
        # drop the batch / out_caps / data_out axes whenever one of them is 1.
        outputs = outputs.squeeze(4).squeeze(2)

        # outputs[b, data_out, out_caps]
        return outputs.transpose(2, 1).contiguous()


class CapsuleNet(nn.Module):
    """Capsule network head: a ``FeatureExtractor`` followed by a ``RoutingLayer``."""

    def __init__(self, num_class, gpu_id):
        """Build the network for ``num_class`` output capsules; ``gpu_id`` is forwarded to the routing layer."""
        super().__init__()

        self.num_class = num_class

        # Primary capsules, re-initialised with ``weights_init`` right after construction.
        self.fea_ext = FeatureExtractor()
        self.fea_ext.apply(self.weights_init)

        # Routing from NO_CAPS input capsules (8 values each) to num_class output
        # capsules (4 values each), using 2 routing iterations.
        self.routing_stats = RoutingLayer(
            gpu_id=gpu_id,
            num_input_capsules=NO_CAPS,
            num_output_capsules=num_class,
            data_in=8,
            data_out=4,
            num_iterations=2,
        )

    def weights_init(self, m):
        """Initialise ``m`` in place, based on its class name.

        Classes whose name contains ``Conv`` get weights drawn from N(0, 0.02);
        otherwise, classes whose name contains ``BatchNorm`` get weights drawn
        from N(1, 0.02) and a zero bias. Any other module is left untouched.
        """
        layer_type = type(m).__name__
        if 'Conv' in layer_type:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm' in layer_type:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x, random=False, dropout=0.0):
        """Return ``(classes, class_)`` for the feature map ``x``.

        ``classes`` is the softmax over the last axis of the routing output
        (``[b, data, out_caps]``); ``class_`` is a detached copy of it averaged
        over dim 1.
        """
        capsules = self.fea_ext(x)
        routed = self.routing_stats(capsules, random, dropout=dropout)
        # routed[b, data, out_caps]

        classes = F.softmax(routed, dim=-1)
        class_ = classes.detach().mean(dim=1)

        return classes, class_

class CapsuleLoss(nn.Module):
    """Cross-entropy loss summed over every slice along dim 1 of the class tensor."""

    def __init__(self, gpu_id):
        """Create the criterion; it is moved to CUDA device ``gpu_id`` when ``gpu_id >= 0``."""
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        if gpu_id >= 0:
            self.cross_entropy_loss.cuda(gpu_id)

    def forward(self, classes, labels):
        """Return the sum of ``cross_entropy_loss(classes[:, k, :], labels)`` over all ``k``."""
        total = self.cross_entropy_loss(classes[:, 0, :], labels)

        for k in range(1, classes.size(1)):
            total = total + self.cross_entropy_loss(classes[:, k, :], labels)

        return total

In [0]:
import sys

sys.setrecursionlimit(15000)
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from torch.autograd import Variable
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from tqdm import tqdm
import argparse
from sklearn import metrics
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve


# Evaluation script: runs the trained capsule network over an image folder and
# prints the accuracy, the per-image predictions and the number of images
# predicted as class 1.

# Preprocessing applied to every image: resize to 300x300, convert to a tensor,
# then normalise each channel with the mean/std values below.
transform_fwd = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

dataset_test = dset.ImageFolder("/content/gdrive/My Drive/FFHQ Dataset", transform=transform_fwd)
if len(dataset_test) == 0:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("No images found in the test dataset folder")
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=32, shuffle=False,
                                              num_workers=0)

vgg_ext = VggExtractor()
capnet = CapsuleNet(2, 0)

capnet.load_state_dict(torch.load("/content/capsule_21.pt"))
capnet.eval()

vgg_ext.cuda(0)
capnet.cuda(0)

##################################################################################

# Per-batch results are collected in lists and concatenated once after the loop,
# instead of re-copying the growing arrays on every batch.
label_batches = []
pred_batches = []
# `np.float` was removed in NumPy 1.24; it was an alias of the builtin float (float64).
tol_pred_prob = np.array([], dtype=np.float64)

count = 0
# No loss is computed in this script; the variable is kept so that the average
# below stays defined.
loss_test = 0

# Inference only: disabling autograd avoids building graphs that are never used.
with torch.no_grad():
    for img_data, labels_data in tqdm(dataloader_test):
        # Flip the ImageFolder class indices: index 0 becomes label 1 and every
        # other index becomes label 0.
        # NOTE(review): presumably folder 0 holds the positive class - confirm
        # against the dataset's folder layout.
        img_label = (labels_data == 0).numpy().astype(np.float64)

        x = vgg_ext(img_data.cuda(0))
        classes, class_ = capnet(x, random=False)

        output_dis = class_.cpu()
        # Predict class 1 when its score is at least the score of class 0.
        output_pred = (output_dis[:, 1] >= output_dis[:, 0]).numpy().astype(np.float64)

        label_batches.append(img_label)
        pred_batches.append(output_pred)

        #pred_prob = torch.softmax(output_dis, dim=1)
        #tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:,1].data.numpy()))
        count += 1

empty = np.array([], dtype=np.float64)
tol_label = np.concatenate(label_batches) if label_batches else empty
tol_pred = np.concatenate(pred_batches) if pred_batches else empty

acc_test = metrics.accuracy_score(tol_label, tol_pred)

# Guard against an empty loader so the average cannot divide by zero.
if count:
    loss_test /= count

# EER computation, currently disabled because tol_pred_prob is not populated:
#fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
#eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

#fnr = 1 - tpr
#hter = (fpr + fnr)/2
#print('[Epoch %d] Test acc: %.2f   EER: %.2f' % (0, acc_test * 100, eer * 100))
#text_writer.write('%d,%.2f,%.2f\n' % (opt.id, acc_test * 100, eer * 100))

print(acc_test * 100)
print(tol_pred)
print(np.count_nonzero(tol_pred > 0.5))

100%|██████████| 625/625 [3:35:37<00:00, 20.70s/it]

47.925000000000004
[0. 0. 0. ... 0. 0. 0.]
4951



