# Authenticate

In [0]:
# Authenticate this Colab session with the user's Google account.
# NOTE(review): presumably required so the gcsfuse mount below can read the bucket — confirm.
from google.colab import auth
auth.authenticate_user()

# Load data

## Install Cloud Storage FUSE

In [2]:
# Register Google's gcsfuse apt repository and signing key, then install gcsfuse so a
# Cloud Storage bucket can be mounted as a local directory.
# NOTE(review): the "gcsfuse-bionic" suite is hard-coded and `apt-key` is deprecated on newer
# Debian/Ubuntu releases — confirm both still match the current Colab base image.
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100   653  100   653    0     0   7867      0 --:--:-- --:--:-- --:--:--  7963
OK
52 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 52 not upgraded.
Need to get 4,274 kB of archives.
After this operation, 12.8 MB of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 144568 files and directories currently installed.)
Preparing to unpack .../gcsfuse_0.28.1_amd64.deb ...
Unpacking gcsfuse (0.28.1) ...
Setting up gcsfuse (0.28.1) ...


## Mount the Cloud Storage bucket in Google Colab (gcsfuse)

In [3]:
!mkdir celeb-df
!gcsfuse celeb-df celeb-df
!ls celeb-df
!ls celeb-df | wc

Using mount point: /content/celeb-df
Opening GCS connection...
Opening bucket...
Mounting file system...
File system has been successfully mounted.
test  train  validation
      3       3      22


# Train Capsule model

**Public repository:** 
https://github.com/nii-yamagishilab/Capsule-Forensics-v2

**Reference:**
H. H. Nguyen, J. Yamagishi, and I. Echizen, “Use of a Capsule Network to Detect Fake Images and Videos,” arXiv preprint arXiv:1910.12467. 2019 Oct 29.

## 1. Import dependencies

In [0]:
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as transforms
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn import metrics
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from tqdm import tqdm

## 2. Capsule Network architecture

In [0]:
# Number of primary capsules: FeatureExtractor builds this many parallel capsule branches,
# and CapsuleNet passes it to RoutingLayer as num_input_capsules.
NO_CAPS = 10

In [0]:
class StatsNet(nn.Module):
    """Summarise every feature map by its spatial mean and standard deviation.

    Input ``[b, c, h, w]`` is flattened over the spatial axes; the output is
    ``[b, 2, c]`` where index 0 along dim 1 is the mean and index 1 the
    (unbiased) standard deviation.
    """

    def __init__(self):
        super(StatsNet, self).__init__()

    def forward(self, x):
        dims = x.data.shape
        flat = x.view(dims[0], dims[1], dims[2] * dims[3])
        per_channel_mean = torch.mean(flat, 2)
        per_channel_std = torch.std(flat, 2)
        return torch.stack((per_channel_mean, per_channel_std), dim=1)

class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with a shape fixed at construction time.

    Example: ``View(-1, 8)`` turns ``[b, 1, 8]`` into ``[b, 8]``.
    """

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        target_shape = self.shape
        return input.view(*target_shape)

class VggExtractor(nn.Module):
    """Front-end feature extractor built from an ImageNet-pretrained VGG-19.

    Keeps ``vgg19().features`` layers 0..18 inclusive. Per torchvision's VGG-19 layout
    this produces 256-channel maps, which matches the ``Conv2d(256, ...)`` input of
    ``FeatureExtractor``.

    Args:
        train: if True, put the extractor in training mode and freeze layers 0..9;
            otherwise put it in eval mode.
    """

    def __init__(self, train=False):
        super(VggExtractor, self).__init__()

        self.vgg_1 = self.Vgg(models.vgg19(pretrained=True), 0, 18)
        if train:
            self.vgg_1.train(mode=True)
            self.freeze_gradient()
        else:
            self.vgg_1.eval()

    def Vgg(self, vgg, begin, end):
        """Return ``vgg.features`` layers ``begin..end`` (inclusive) as an ``nn.Sequential``."""
        features = nn.Sequential(*list(vgg.features.children())[begin:(end+1)])
        return features

    def freeze_gradient(self, begin=0, end=9):
        """Disable gradients for every parameter of layers ``begin..end`` (inclusive).

        ``requires_grad`` only has an effect on tensors/parameters; assigning it as an
        attribute of the layer Module itself would just create an unused attribute.
        """
        for i in range(begin, end + 1):
            for param in self.vgg_1[i].parameters():
                param.requires_grad = False

    def forward(self, input):
        return self.vgg_1(input)

class FeatureExtractor(nn.Module):
    """Bank of NO_CAPS primary capsules applied in parallel to the same feature maps.

    Each capsule: two Conv2d-BN-ReLU stages (256 -> 64 -> 16 channels), StatsNet
    (giving ``[b, 2, 16]``), a stride-2 Conv1d (length 16 -> 8), a Conv1d to one channel,
    and ``View(-1, 8)`` giving an 8-dimensional vector per sample. The capsule outputs
    are stacked on a new last axis (``[b, 8, NO_CAPS]``) and squashed.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                StatsNet(),

                nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm1d(8),
                nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(1),
                View(-1, 8),
                )
                for _ in range(NO_CAPS)]
        )

    def squash(self, tensor, dim, eps=1e-12):
        """Capsule squash: ``|v|^2 / (1 + |v|^2) * v / |v|`` along ``dim``.

        ``eps`` is added under the square root so an all-zero vector maps to zero
        (and has a finite gradient) instead of producing NaN from 0/0.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm + eps)

    def forward(self, x):
        outputs = [capsule(x) for capsule in self.capsules]
        output = torch.stack(outputs, dim=-1)

        # NOTE(review): squash runs along dim=-1 (the capsule index), not along the
        # 8-dim capsule vector axis; kept unchanged so trained models behave the same — confirm intent.
        return self.squash(output, dim=-1)

class RoutingLayer(nn.Module):
    """Dynamic routing from input capsules to output capsules.

    Args:
        gpu_id: kept for API compatibility; helper tensors are now placed on the
            device of the routing weights / priors instead.
        num_input_capsules, num_output_capsules: capsule counts on each side.
        data_in, data_out: vector length of input / output capsules.
        num_iterations: routing iterations (must be >= 1).
    """

    def __init__(self, gpu_id, num_input_capsules, num_output_capsules, data_in, data_out, num_iterations):
        super(RoutingLayer, self).__init__()

        if num_iterations < 1:
            raise ValueError('num_iterations must be >= 1, got %r' % (num_iterations,))

        self.gpu_id = gpu_id
        self.num_iterations = num_iterations
        self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in))

    def squash(self, tensor, dim, eps=1e-12):
        """Capsule squash: ``|v|^2 / (1 + |v|^2) * v / |v|`` along ``dim``.

        ``eps`` under the square root keeps zero vectors (e.g. fully dropped priors)
        from producing NaN in the forward or backward pass.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm + eps)

    def forward(self, x, random, dropout):
        """Route input capsules to output capsules.

        Args:
            x: input capsules, ``[b, data_in, in_caps]``.
            random: if True, add ``0.01 * N(0, 1)`` noise to the routing weights.
            dropout: probability of zeroing each prior (no rescaling); 0 disables it.

        Returns:
            Tensor of shape ``[b, data_out, out_caps]``.
        """
        # x[b, data_in, in_caps] -> x[b, in_caps, data_in]
        x = x.transpose(2, 1)

        route_weights = self.route_weights
        if random:
            # Drawn from the CPU generator, then moved next to the weights.
            noise = 0.01 * torch.randn(*route_weights.size())
            route_weights = route_weights + noise.to(route_weights.device)

        # route_weights [out_caps , 1 , in_caps , data_out , data_in]
        # x             [   1     , b , in_caps , data_in ,    1    ]
        # priors        [out_caps , b , in_caps , data_out,    1    ]
        priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None]

        # priors[b, out_caps, in_caps, data_out, 1]
        priors = priors.transpose(1, 0)

        if dropout > 0.0:
            keep = torch.empty(*priors.size()).bernoulli_(1.0 - dropout)
            priors = priors * keep.to(priors.device)

        # logits[b, out_caps, in_caps, data_out, 1]; no gradient flows into the initial zeros.
        logits = torch.zeros_like(priors)

        outputs = None
        for i in range(self.num_iterations):
            probs = F.softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3)

            if i != self.num_iterations - 1:
                logits = logits + priors * outputs

        # outputs[b, out_caps, 1, data_out, 1] -> [b, out_caps, data_out]. Only the two
        # singleton routing axes are squeezed, so b == 1 or out_caps == 1 keep their axes.
        outputs = outputs.squeeze(4).squeeze(2)

        # -> outputs[b, data_out, out_caps]
        return outputs.transpose(2, 1).contiguous()

class CapsuleNet(nn.Module):
    """Primary capsules (FeatureExtractor) followed by dynamic routing to class capsules.

    ``forward`` returns ``(z, class_)`` where ``z`` is the routing output
    ``[b, data_out, num_class]`` (used as logits by CapsuleLoss) and ``class_`` is the
    softmax over classes averaged across the ``data_out`` axis, detached from the graph.
    """

    def __init__(self, num_class, gpu_id):
        super(CapsuleNet, self).__init__()

        self.num_class = num_class
        self.fea_ext = FeatureExtractor()
        self.fea_ext.apply(self.weights_init)

        self.routing_stats = RoutingLayer(
            gpu_id=gpu_id,
            num_input_capsules=NO_CAPS,
            num_output_capsules=num_class,
            data_in=8,
            data_out=4,
            num_iterations=2,
        )

    def weights_init(self, m):
        """Init hook for ``Module.apply``: conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), bias 0."""
        layer_kind = type(m).__name__
        if 'Conv' in layer_kind:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm' in layer_kind:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x, random=False, dropout=0.0):
        primary = self.fea_ext(x)
        # routed[b, data, out_caps]
        routed = self.routing_stats(primary, random, dropout=dropout)

        mean_probs = F.softmax(routed, dim=-1).detach().mean(dim=1)
        return routed, mean_probs

class CapsuleLoss(nn.Module):
    """Sum of cross-entropy losses, one per slice along dim 1 of ``classes``.

    ``classes`` is ``[b, data_out, num_class]`` logits; ``labels`` holds class indices ``[b]``.
    """

    def __init__(self, gpu_id):
        super(CapsuleLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        if gpu_id >= 0:
            self.cross_entropy_loss.cuda(gpu_id)

    def forward(self, classes, labels):
        per_slice = [self.cross_entropy_loss(classes[:, k, :], labels) for k in range(classes.size(1))]

        total = per_slice[0]
        for term in per_slice[1:]:
            total = total + term
        return total

## 3. Default settings



In [0]:
# Paths: the gcsfuse mount point created above plus the split sub-folders (joined by string concatenation).
dataset = os.getcwd() + '/celeb-df'
train_set = '/train'
val_set = '/validation'
test_set = '/test'
# DataLoader / optimiser hyper-parameters.
workers = 0
batch_size = 32
image_size = 300  # images are resized to image_size x image_size by transform_fwd
learning_rate = 0.0005
beta1_for_adam = 0.9
gpu_id = 0  # CUDA device index; the code checks `gpu_id >= 0` before moving tensors to the GPU
# NOTE(review): a fresh random seed is drawn on every run, so results are not reproducible
# unless the printed seed is hard-coded here — TODO confirm intent.
manual_seed = random.randint(1, 1000)
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)
# cuDNN auto-tuning: faster for fixed input sizes, but not bit-for-bit deterministic.
cudnn.benchmark = True
# random_setting is passed as `random=` to CapsuleNet (noise on the routing matrix).
disable_random_routing_matrix = False
random_setting = not disable_random_routing_matrix
# Log files: train.csv is opened for appending, test.txt is truncated.
text_writer_train = open(os.path.join(os.getcwd(), 'train.csv'), 'a')
text_writer_test = open(os.path.join(os.getcwd(), 'test.txt'), 'w')

Random Seed:  287


## 4. Creating Capsule

In [0]:
# Build the VGG-19 front-end, the 2-class capsule network and its loss.
vgg_ext = VggExtractor()
capnet = CapsuleNet(2, gpu_id)
capsule_loss = CapsuleLoss(gpu_id)

# Move modules to the GPU *before* building the optimizer, as the torch.optim docs
# recommend; a negative gpu_id keeps everything on the CPU.
if gpu_id >= 0:
    capnet.cuda(gpu_id)
    vgg_ext.cuda(gpu_id)
    capsule_loss.cuda(gpu_id)

# The VGG front-end is not given to the optimizer, so tracking its gradients only
# wastes memory; the capsule network's parameters still receive gradients.
vgg_ext.requires_grad_(False)

optimizer = Adam(capnet.parameters(), lr=learning_rate, betas=(beta1_for_adam, 0.999))

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(IntProgress(value=0, max=574673361), HTML(value='')))




CapsuleLoss(
  (cross_entropy_loss): CrossEntropyLoss()
)

## 5. Creating datasets for training, validation, and test

In [0]:
# Shared preprocessing: resize to image_size x image_size, convert to a tensor, and
# normalise with the ImageNet channel statistics used by the pretrained VGG-19.
transform_fwd = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def make_split_loader(split, shuffle):
    """Build the ImageFolder dataset and DataLoader for one split.

    Args:
        split: split sub-path appended to ``dataset`` (e.g. ``'/train'``).
        shuffle: whether the DataLoader shuffles each epoch.

    Returns:
        ``(dataset, dataloader)``. ImageFolder assigns labels from the sub-folder
        names of the split directory.
    """
    root = dataset + split
    split_dataset = dset.ImageFolder(root=root, transform=transform_fwd)
    assert len(split_dataset) > 0, 'no images found under %s' % root
    loader = torch.utils.data.DataLoader(split_dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=int(workers))
    return split_dataset, loader


dataset_train, dataloader_train = make_split_loader(train_set, shuffle=True)
dataset_val, dataloader_val = make_split_loader(val_set, shuffle=False)
dataset_test, dataloader_test = make_split_loader(test_set, shuffle=False)

RuntimeError: ignored

## 6. Start training phase
- A checkpoint of the model and optimizer state (`capsule_<epoch>.pt`, `optim_<epoch>.pt`) is saved after every epoch

In [0]:
# Train for NUM_EPOCHS epochs; after each epoch save a checkpoint, validate, and log
# one CSV row (epoch, train loss, train acc %, val loss, val acc %) to train.csv.
NUM_EPOCHS = 25
ROUTING_DROPOUT = 0.05  # dropout on the routing priors, used only while training


def scores_to_predictions(class_scores):
    """Convert per-class scores ``[b, 2]`` into hard 0/1 labels (class 1 wins ties)."""
    scores = class_scores.detach().cpu().numpy()
    return (scores[:, 1] >= scores[:, 0]).astype(np.float64)


def run_capsule_epoch(dataloader, train):
    """Run one pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the capsule network is updated with routing noise
    (``random_setting``) and routing dropout enabled; otherwise it runs in eval mode
    under ``torch.no_grad()`` without noise. Labels greater than 1 are folded into class 1.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    capnet.train(mode=train)
    total_loss = 0.0
    n_batches = 0
    label_chunks, pred_chunks = [], []

    with torch.set_grad_enabled(train):
        for img_data, labels_data in tqdm(dataloader, desc='train' if train else 'val'):
            labels_data[labels_data > 1] = 1
            label_chunks.append(labels_data.numpy().astype(np.float64))

            if gpu_id >= 0:
                img_data = img_data.cuda(gpu_id)
                labels_data = labels_data.cuda(gpu_id)

            features = vgg_ext(img_data)
            if train:
                optimizer.zero_grad()
                classes, class_ = capnet(features, random=random_setting, dropout=ROUTING_DROPOUT)
            else:
                classes, class_ = capnet(features, random=False)

            loss = capsule_loss(classes, labels_data)
            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            pred_chunks.append(scores_to_predictions(class_))

    if n_batches == 0:
        raise ValueError('dataloader yielded no batches')

    accuracy = metrics.accuracy_score(np.concatenate(label_chunks), np.concatenate(pred_chunks))
    return total_loss / n_batches, accuracy


# Opened here (not in the settings cell) so re-running this cell never writes to a closed file.
with open(os.path.join(os.getcwd(), 'train.csv'), 'a') as train_log:
    for epoch in range(1, NUM_EPOCHS + 1):
        loss_train, acc_train = run_capsule_epoch(dataloader_train, train=True)

        # Checkpoint every epoch so any epoch can be evaluated later.
        torch.save(capnet.state_dict(), os.path.join(os.getcwd(), 'capsule_%d.pt' % epoch))
        torch.save(optimizer.state_dict(), os.path.join(os.getcwd(), 'optim_%d.pt' % epoch))

        loss_test, acc_test = run_capsule_epoch(dataloader_val, train=False)

        print('[Epoch %d] Train loss: %.4f   acc: %.2f | Test loss: %.4f  acc: %.2f'
              % (epoch, loss_train, acc_train*100, loss_test, acc_test*100))

        train_log.write('%d,%.4f,%.2f,%.4f,%.2f\n'
                        % (epoch, loss_train, acc_train*100, loss_test, acc_test*100))
        train_log.flush()

capnet.train(mode=True)

100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 1] Train loss: 2.7654   acc: 100.00 | Test loss: 2.7647  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 2] Train loss: 1.6929   acc: 100.00 | Test loss: 2.7656  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 3] Train loss: 1.6828   acc: 100.00 | Test loss: 2.7663  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 4] Train loss: 1.6399   acc: 100.00 | Test loss: 2.7670  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 5] Train loss: 1.6405   acc: 100.00 | Test loss: 2.7675  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 6] Train loss: 1.6311   acc: 100.00 | Test loss: 2.7680  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.49it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 7] Train loss: 1.6214   acc: 100.00 | Test loss: 2.7684  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 8] Train loss: 1.6075   acc: 100.00 | Test loss: 2.7687  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 9] Train loss: 1.5987   acc: 100.00 | Test loss: 2.7690  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 10] Train loss: 1.5939   acc: 100.00 | Test loss: 2.7692  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 11] Train loss: 1.5998   acc: 100.00 | Test loss: 2.7694  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 12] Train loss: 1.5997   acc: 100.00 | Test loss: 2.7696  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 13] Train loss: 1.5835   acc: 100.00 | Test loss: 2.7698  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 14] Train loss: 1.5941   acc: 100.00 | Test loss: 2.7699  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 15] Train loss: 1.5962   acc: 100.00 | Test loss: 2.7700  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 16] Train loss: 1.5974   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 17] Train loss: 1.6057   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 18] Train loss: 1.5824   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 19] Train loss: 1.5678   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 20] Train loss: 1.5796   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 21] Train loss: 1.5892   acc: 100.00 | Test loss: 2.7700  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 22] Train loss: 1.5883   acc: 100.00 | Test loss: 2.7699  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.42it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 23] Train loss: 1.5779   acc: 100.00 | Test loss: 2.7698  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 24] Train loss: 1.5814   acc: 100.00 | Test loss: 2.7696  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.40it/s]


[Epoch 25] Train loss: 1.5883   acc: 100.00 | Test loss: 2.7694  acc: 100.00


## 7. Start evaluation phase

In [0]:
# Evaluate one saved checkpoint on the held-out test split: accuracy and equal error rate (EER).
EVAL_EPOCH = 21  # TODO: choose from validation results instead of hard-coding

checkpoint_path = os.path.join(os.getcwd(), 'capsule_%d.pt' % EVAL_EPOCH)
map_location = 'cuda:%d' % gpu_id if gpu_id >= 0 else 'cpu'
capnet.load_state_dict(torch.load(checkpoint_path, map_location=map_location))
capnet.eval()

label_chunks, pred_chunks, prob_chunks = [], [], []

with torch.no_grad():
    for img_data, labels_data in tqdm(dataloader_test):
        labels_data[labels_data > 1] = 1
        label_chunks.append(labels_data.numpy().astype(np.float64))

        if gpu_id >= 0:
            img_data = img_data.cuda(gpu_id)

        # random=False: no routing-matrix noise at test time, so the metrics are deterministic.
        _, class_ = capnet(vgg_ext(img_data), random=False)
        scores = class_.cpu().numpy()

        pred_chunks.append((scores[:, 1] >= scores[:, 0]).astype(np.float64))
        # class_ is already an average of softmax outputs, so column 1 is P(class 1);
        # applying softmax again would distort it.
        prob_chunks.append(scores[:, 1])

tol_label = np.concatenate(label_chunks)
tol_pred = np.concatenate(pred_chunks)
tol_pred_prob = np.concatenate(prob_chunks)

acc_test = metrics.accuracy_score(tol_label, tol_pred)

if np.unique(tol_label).size < 2:
    # ROC curve / EER are undefined when the test split contains a single class.
    eer = float('nan')
else:
    fpr, tpr, _ = metrics.roc_curve(tol_label, tol_pred_prob, pos_label=1)
    # EER: the FPR at which FPR == FNR (= 1 - TPR), found as a root on [0, 1].
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

print('[Epoch %d] Test acc: %.2f   EER: %.2f' % (EVAL_EPOCH, acc_test*100, eer*100))
with open(os.path.join(os.getcwd(), 'test.txt'), 'w') as test_log:
    test_log.write('%d,%.2f,%.2f\n' % (EVAL_EPOCH, acc_test*100, eer*100))

# Train ClassNSeg model

**Public repository:** 
https://github.com/nii-yamagishilab/ClassNSeg 

**Reference:**
H. H. Nguyen, F. Fang, J. Yamagishi, and I. Echizen, “Multi-task Learning for Detecting and Segmenting Manipulated Facial Images and Videos,” Proc. of the 10th IEEE International Conference on Biometrics: Theory, Applications and Systems (BTAS), 8 pages, (September 2019)

## 1. Import dependencies

In [0]:
import torch
import torch.backends.cudnn as cudnn
from torch import nn
from torch.optim import Adam
import torchvision.datasets as dset
import torchvision.transforms as transforms
from tqdm import tqdm
import cv2
import numpy as np
import os
import sys
import random

## 2. ClassNSeg Network architecture

In [0]:
class Encoder(nn.Module):
    """Convolutional encoder for ClassNSeg.

    Nine Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU stages take ``depth`` input
    channels to 128. Four stages use stride 2, halving each spatial side four times.
    Conv weights ~ N(0, 0.02); BatchNorm weights ~ N(0.5, 0.02), biases 0.
    """

    # (out_channels, stride) per stage; each stage's input channels are the previous stage's output.
    _STAGES = ((8, 1), (8, 2), (16, 1), (16, 2), (32, 1), (32, 2), (64, 1), (64, 2), (128, 1))

    def __init__(self, depth=3):
        super(Encoder, self).__init__()

        layers = []
        in_channels = depth
        for out_channels, stride in self._STAGES:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            in_channels = out_channels
        self.encoder = nn.Sequential(*layers)

        self.encoder.apply(self.weights_init)

    def weights_init(self, m):
        """Init hook for ``Module.apply``; other module types are left unchanged."""
        layer_kind = type(m).__name__
        if 'Conv' in layer_kind:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm' in layer_kind:
            m.weight.data.normal_(0.5, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    """Shared transposed-conv trunk feeding a segmentation head and a reconstruction head.

    Input is the 128-channel Encoder output. The shared trunk doubles H and W twice
    (ending at 32 channels); each head doubles them twice more. ``forward`` returns
    ``(seg, rect)``: ``seg`` has 2 channels with Softmax over dim 1, ``rect`` has
    ``depth`` channels with Tanh. Only the two heads get the custom init below; the
    shared trunk keeps PyTorch's default initialisation.
    """

    def __init__(self, depth=3):
        super(Decoder, self).__init__()

        self.shared = nn.Sequential(*self._up_stages([(128, 64, 1), (64, 64, 2), (64, 32, 1), (32, 32, 2)]))

        head_plan = [(32, 16, 1), (16, 16, 2), (16, 8, 1), (8, 8, 2)]

        seg_layers = self._up_stages(head_plan)
        seg_layers.append(nn.ConvTranspose2d(8, 2, kernel_size=3, stride=1, padding=1, output_padding=0))
        seg_layers.append(nn.Softmax(dim=1))
        self.segmenter = nn.Sequential(*seg_layers)

        rect_layers = self._up_stages(head_plan)
        rect_layers.append(nn.ConvTranspose2d(8, depth, kernel_size=3, stride=1, padding=1, output_padding=0))
        rect_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*rect_layers)

        for head in (self.segmenter, self.decoder):
            head.apply(self.weights_init)

    @staticmethod
    def _up_stages(plan):
        """Return ConvTranspose2d -> BatchNorm2d -> ReLU layers for each (c_in, c_out, stride).

        Stride-2 stages use output_padding=1 so each exactly doubles H and W.
        """
        layers = []
        for c_in, c_out, stride in plan:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=stride, padding=1,
                                   output_padding=1 if stride == 2 else 0),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        return layers

    def weights_init(self, m):
        """Init hook for ``Module.apply``: conv weights ~ N(0, 0.02); BatchNorm weights ~ N(0.5, 0.02), bias 0."""
        layer_kind = type(m).__name__
        if 'Conv' in layer_kind:
            m.weight.data.normal_(0.0, 0.02)
        elif 'BatchNorm' in layer_kind:
            m.weight.data.normal_(0.5, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        latent = self.shared(x)
        return self.segmenter(latent), self.decoder(latent)

class ActivationLoss(nn.Module):
    """L1 activation loss: ``sum(|one - y| + |zero - (1 - y)|) / batch_size``.

    ``zero`` and ``one`` are activation scores compared against ``1 - labels`` and
    ``labels`` respectively; labels are detached via ``.data``.
    """

    def __init__(self):
        super(ActivationLoss, self).__init__()

    def forward(self, zero, one, labels):
        target = labels.data
        per_item = (one - target).abs() + (zero - (1.0 - target)).abs()
        inv_batch = 1 / labels.shape[0]
        return inv_batch * per_item.sum()
        
class ReconstructionLoss(nn.Module):
    """Mean-squared error between a reconstruction and a detached ground-truth image."""

    def __init__(self):
        super(ReconstructionLoss, self).__init__()
        self.loss = nn.MSELoss()

    def forward(self, reconstruction, groundtruth):
        target = groundtruth.data
        return self.loss(reconstruction, target)

class SegmentationLoss(nn.Module):
    """Pixel-wise cross-entropy between segmentation scores and an index mask.

    ``segment`` is ``[b, C, H, W]`` and ``groundtruth`` is ``[b, H, W]`` class indices;
    both are flattened over the spatial axes before ``nn.CrossEntropyLoss``.
    NOTE(review): the Decoder's segmenter already ends in Softmax, and CrossEntropyLoss
    applies log-softmax again — confirm this double normalisation is intended.
    """

    def __init__(self):
        super(SegmentationLoss, self).__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, segment, groundtruth):
        seg_shape = segment.shape
        gt_shape = groundtruth.shape
        scores = segment.view(seg_shape[0], seg_shape[1], seg_shape[2] * seg_shape[3])
        target = groundtruth.data.view(gt_shape[0], gt_shape[1] * gt_shape[2])
        return self.loss(scores, target)

## 3. Default settings

In [33]:
# ClassNSeg defaults. This cell re-binds globals also used by the Capsule section
# (dataset, batch_size, learning_rate, gpu_id, text_writer_*).
# The dataset root is the gcsfuse mount point created at the top of the notebook
# (reported as /content/celeb-df in Colab), not the filesystem root.
dataset = os.path.join(os.getcwd(), 'celeb-df')
train_set = '/train'
val_set = '/validation'
test_set = '/test'
workers = 0
batch_size = 64
# NOTE(review): a fresh random seed is drawn on every run — results are not reproducible.
manual_seed = random.randint(1, 1000)
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)
cudnn.benchmark = True
learning_rate = 0.001
beta1_for_adam = 0.9
weight_decay = 0.0005
eps = 1e-07
gpu_id = 0
gamma = 1
# NOTE(review): appends to the same train.csv as the Capsule section — logs will be interleaved.
text_writer_train = open(os.path.join(os.getcwd(), 'train.csv'), 'a')
text_writer_test = open(os.path.join(os.getcwd(), 'classification.txt'), 'w')

Random Seed:  115


## 4. Creating ClassNSeg

In [23]:
# Build the ClassNSeg encoder/decoder and the three loss terms.
encoder = Encoder(3)
decoder = Decoder(3)
act_loss_fn = ActivationLoss()
rect_loss_fn = ReconstructionLoss()
seg_loss_fn = SegmentationLoss()

# Move modules to the GPU *before* constructing the optimizers, as the torch.optim docs
# recommend; a negative gpu_id keeps everything on the CPU.
if gpu_id >= 0:
    for module in (encoder, decoder, act_loss_fn, seg_loss_fn, rect_loss_fn):
        module.cuda(gpu_id)

optimizer_encoder = Adam(encoder.parameters(), lr=learning_rate, betas=(beta1_for_adam, 0.999), weight_decay=weight_decay, eps=eps)
optimizer_decoder = Adam(decoder.parameters(), lr=learning_rate, betas=(beta1_for_adam, 0.999), weight_decay=weight_decay, eps=eps)

ReconstructionLoss(
  (loss): MSELoss()
)

## 5. Creating datasets for training, validation, and test

In [31]:
class Normalize_3D(object):
    """In-place per-channel normalisation: ``t[c] = (t[c] - mean[c]) / std[c]``.

    Args:
        mean: per-channel means.
        std: per-channel standard deviations.

    Calling the instance on a (C, H, W) tensor modifies it in place and returns the
    same tensor object. Only the first ``min(C, len(mean), len(std))`` channels are touched.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean)
            channel.div_(channel_std)
        return tensor

class UnNormalize_3D(object):
    """In-place inverse of ``Normalize_3D``: ``t[c] = t[c] * std[c] + mean[c]``.

    Args:
        mean: per-channel means used for normalisation.
        std: per-channel standard deviations used for normalisation.

    Calling the instance on a (C, H, W) tensor modifies it in place and returns the
    same tensor object.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.mul_(channel_std)
            channel.add_(channel_mean)
        return tensor

# Tensor conversion only; presumably normalisation is applied later via transform_norm — confirm.
transform_tns = transforms.Compose([
    transforms.ToTensor(),
])

transform_pil = transforms.Compose([
    transforms.ToPILImage(),
])

# ImageNet channel statistics (same values as the Capsule section).
transform_norm = Normalize_3D((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

transform_unnorm = UnNormalize_3D((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))


def classnseg_split_root(split):
    """Return the directory for one split under ``dataset``.

    The split names start with '/', and ``os.path.join`` discards every component
    before an absolute one (``os.path.join('/root', '/train') == '/train'``), so the
    leading slash is stripped first.
    """
    return os.path.join(dataset, split.lstrip('/'))


def make_classnseg_loader(split, shuffle):
    """Build ``(ImageFolder dataset, DataLoader)`` for one split using ``transform_tns``."""
    root = classnseg_split_root(split)
    split_dataset = dset.ImageFolder(root=root, transform=transform_tns)
    assert len(split_dataset) > 0, 'no images found under %s' % root
    loader = torch.utils.data.DataLoader(split_dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=int(workers))
    return split_dataset, loader


dataset_train, dataloader_train = make_classnseg_loader(train_set, shuffle=True)
dataset_val, dataloader_val = make_classnseg_loader(val_set, shuffle=False)
dataset_test, dataloader_test = make_classnseg_loader(test_set, shuffle=False)

FileNotFoundError: ignored

## 6. Start pre-processing phase

In [0]:
# Inputs/outputs and parameters for extracting paired real/fake face crops from videos.
# NOTE(review): input_fake and mask are empty strings and must be filled in before calling
# extract_face_videos (os.listdir('') raises FileNotFoundError); input_real is an absolute
# path, not the gcsfuse mount used elsewhere — confirm.
input_real = '/celeb-df/'
input_fake = ''
mask = ''
output_real = 'deepfakes/real'
output_fake = 'deepfakes/fake'
# Side length (pixels) of saved crops; re-binds the global image_size used by the Capsule section.
image_size = 256
# NOTE(review): presumably the per-video frame limit, but extract_face_videos reads `opt.limit`, not this.
limit = 10
# Crop size multiplier applied to the mask bounding-box height in extract_face.
scale = 1.3

def to_bw(mask, thresh_binary=10, thresh_otsu=255):
    """Binarise a BGR mask image with Otsu thresholding.

    Args:
        mask: BGR image.
        thresh_binary: passed as the threshold argument; per the OpenCV docs it is
            ignored when THRESH_OTSU is set (Otsu picks the threshold).
        thresh_otsu: value assigned to pixels above the threshold (OpenCV ``maxval``).

    Returns:
        Single-channel binary image.
    """
    gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, thresh_binary, thresh_otsu, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    return binary

def get_bbox(mask, thresh_binary=127, thresh_otsu=255):
    """Locate the dominant region of a mask image.

    The mask is binarised with ``to_bw``; among its contours, the one whose bounding
    rectangle has the largest ``width + height`` is selected (first one wins ties).

    Args:
        mask: BGR mask image.
        thresh_binary, thresh_otsu: forwarded to ``to_bw``.

    Returns:
        ``np.ndarray`` ``[cX, cY, w, h]``: the contour's centroid from image moments
        ((0, 0) for a zero-area contour) and its bounding-rect width/height, or
        ``None`` when the mask contains no contour.
    """
    im_bw = to_bw(mask, thresh_binary, thresh_otsu)

    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in OpenCV 4.x; [-2] is the contour list in both.
    contours = cv2.findContours(im_bw, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        # Callers test `bbox is None`; argmax over an empty array would raise instead.
        return None

    best_bbox = None
    best_extent = -1
    for c in contours:
        M = cv2.moments(c)
        if M["m00"] > 0:
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
        else:
            cX, cY = 0, 0

        _, _, w, h = cv2.boundingRect(c)
        # Strict '>' keeps the first contour on ties, like argmax.
        if w + h > best_extent:
            best_extent = w + h
            best_bbox = np.array([cX, cY, w, h], dtype=np.int64)

    return best_bbox

def extract_face(image, bbox, scale = 1.0):
    """Crop a square around a bbox centre and resize it to ``image_size`` x ``image_size``.

    Args:
        image: H x W (x C) image array.
        bbox: ``[cX, cY, w, h]`` as returned by ``get_bbox``; the square has side
            about ``h * scale`` and is centred on ``(cX, cY)``.
        scale: multiplier on the bbox height.

    Returns:
        The resized crop, or ``None`` if the crop (after clamping to the image) is empty.
    """
    h, w = image.shape[:2]
    radius = int(bbox[3] * scale / 2)

    def clamp(value, upper):
        # Keep slice bounds inside [0, upper]; a negative end would wrap around in numpy.
        return min(max(value, 0), upper)

    x_1 = clamp(bbox[0] - radius, w)
    x_2 = clamp(bbox[0] + radius, w)
    y_1 = clamp(bbox[1] - radius, h)
    y_2 = clamp(bbox[1] + radius, h)

    crop_img = image[y_1:y_2, x_1:x_2]

    # Slicing never returns None; an empty crop is the failure case (cv2.resize rejects it).
    if crop_img.size == 0:
        return None

    return cv2.resize(crop_img, (image_size, image_size))

def extract_face_videos(input_real, input_fake, input_mask, output_real, output_fake):
    """Extract aligned face crops from pristine/fake video pairs.

    For every ``.mp4`` in ``input_fake`` the matching pristine video in
    ``input_real`` is opened. For each frame, the mask
    ``input_mask/<video name>/<frame index, 4 digits>.png`` locates the face.
    Side-by-side JPEGs are written: ``[real face | black]`` to ``output_real`` and
    ``[fake face | binary mask]`` to ``output_fake``, each named
    ``<video name>_<frame index>.jpg``. At most ``limit`` pairs are written per
    video. Frames whose mask has no contour, or whose crop is empty, are skipped.
    Processing of a video stops when either video ends or a mask file is missing.

    Reads the module-level settings ``image_size``, ``scale`` and ``limit``.
    """
    blank_img = np.zeros((image_size, image_size, 3), np.uint8)

    # cv2.imwrite only returns False (no exception) when the directory is missing.
    os.makedirs(output_real, exist_ok=True)
    os.makedirs(output_fake, exist_ok=True)

    for f in os.listdir(input_fake):
        if not os.path.isfile(os.path.join(input_fake, f)):
            continue
        if not f.lower().endswith('.mp4'):
            continue

        print(f)
        filename = os.path.splitext(f)[0]

        # The pristine source is looked up by the first three characters of the fake
        # video's name. NOTE(review): confirm this naming scheme holds for the dataset.
        vidcap_real = cv2.VideoCapture(os.path.join(input_real, filename[0:3] + '.mp4'))
        vidcap_fake = cv2.VideoCapture(os.path.join(input_fake, f))
        try:
            success_real, image_real = vidcap_real.read()
            success_fake, image_fake = vidcap_fake.read()
            frame_idx = 0
            count = 0  # number of frame pairs written for this video

            while success_real and success_fake and count < limit:
                mask_path = os.path.join(input_mask, filename, str(frame_idx).zfill(4) + '.png')
                image_mask = cv2.imread(mask_path)
                if image_mask is None:
                    # cv2.imread returns None for a missing/unreadable file.
                    print('  no mask at %s, stopping this video' % mask_path)
                    break

                try:
                    bbox = get_bbox(image_mask)
                except ValueError:
                    # Mask has no foreground contour in this frame.
                    bbox = None

                if bbox is not None:
                    original_cropped = extract_face(image_real, bbox, scale)
                    altered_cropped = extract_face(image_fake, bbox, scale)
                    mask_face = extract_face(image_mask, bbox, scale)

                    if original_cropped is not None and altered_cropped is not None and mask_face is not None:
                        mask_bw = to_bw(mask_face)
                        mask_cropped = np.stack((mask_bw, mask_bw, mask_bw), axis=2)

                        original_cropped = np.concatenate((original_cropped, blank_img), axis=1)
                        altered_cropped = np.concatenate((altered_cropped, mask_cropped), axis=1)

                        cv2.imwrite(os.path.join(output_real, filename + "_%d.jpg" % frame_idx), original_cropped)
                        cv2.imwrite(os.path.join(output_fake, filename + "_%d.jpg" % frame_idx), altered_cropped)
                        count += 1

                # Always advance both videos and the mask index together (also for
                # skipped frames) so they stay aligned and the loop terminates.
                success_real, image_real = vidcap_real.read()
                success_fake, image_fake = vidcap_fake.read()
                frame_idx += 1
        finally:
            vidcap_real.release()
            vidcap_fake.release()

# Fail early with a clear message if an input directory is unset or missing;
# otherwise os.listdir('') raises an opaque FileNotFoundError deep inside the call.
for _input_dir in (input_real, input_fake, mask):
    if not os.path.isdir(_input_dir):
        raise FileNotFoundError('Pre-processing input directory not found: %r '
                                '(check input_real / input_fake / mask)' % _input_dir)

extract_face_videos(input_real, input_fake, mask, output_real, output_fake)

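## 7. Start training phase
- Trains the encoder/decoder for 100 epochs. After every epoch it checkpoints the encoder, decoder and both optimizer states (`encoder_<epoch>.pt`, `decoder_<epoch>.pt`, `optim_*_<epoch>.pt`). It then logs validation metrics and saves sample segmentations/reconstructions to `image/`.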

In [36]:
# Helpers for the training / validation / visualization passes below. They use the
# notebook-level objects defined in earlier cells (encoder, decoder, transform_norm,
# transform_tns, transform_unnorm, transform_pil, act_loss_fn, seg_loss_fn,
# rect_loss_fn, gamma, gpu_id).

def _prepare_batch(fft_data, labels_data):
    """Split a packed ``[face | mask]`` batch into model inputs.

    Returns ``(fft_label, labels_data, rgb, mask)``:
    - labels as a float64 numpy array and as a float tensor;
    - the normalized left 256 columns (face);
    - channel 0 of the right 256 columns binarized at 0.5 (mask, long).
    Tensors are moved to ``gpu_id`` when it is >= 0.
    """
    fft_label = labels_data.numpy().astype(np.float64)
    labels_data = labels_data.float()

    rgb = transform_norm(fft_data[:, :, :, 0:256])
    mask = (fft_data[:, 0, :, 256:512] >= 0.5).long()

    if gpu_id >= 0:
        rgb = rgb.cuda(gpu_id)
        mask = mask.cuda(gpu_id)
        labels_data = labels_data.cuda(gpu_id)
    return fft_label, labels_data, rgb, mask


def _encode(rgb):
    """Encode a batch; return ``(latent, zero, one)``.

    ``zero`` / ``one`` are the per-sample mean absolute activations of the two
    halves of the latent code and act as the class-0 / class-1 scores.
    """
    latent = encoder(rgb).reshape(-1, 2, 64, 16, 16)
    zero = torch.abs(latent[:, 0]).view(latent.shape[0], -1).mean(dim=1)
    one = torch.abs(latent[:, 1]).view(latent.shape[0], -1).mean(dim=1)
    return latent, zero, one


def _decode_selected(latent, labels_data):
    """Zero the latent half that does not match each sample's label, then decode.

    Returns the decoder's ``(seg, rect)`` outputs.
    """
    y = torch.eye(2)
    if gpu_id >= 0:
        y = y.cuda(gpu_id)
    y = y.index_select(dim=0, index=labels_data.data.long())
    latent = (latent * y[:, :, None, None, None]).reshape(-1, 128, 16, 16)
    return decoder(latent)


def _losses(zero, one, labels_data, seg, mask, rect, rgb):
    """Activation loss plus gamma-weighted segmentation and reconstruction losses."""
    loss_act = act_loss_fn(zero, one, labels_data)
    loss_seg = seg_loss_fn(seg, mask) * gamma
    loss_rect = rect_loss_fn(rect, rgb) * gamma
    return loss_act, loss_seg, loss_rect


def _hard_pred(zero, one):
    """1.0 where the class-1 score is >= the class-0 score, else 0.0 (float64 numpy)."""
    return (one >= zero).cpu().numpy().astype(np.float64)


def _load_test_image(name):
    """Load ``test_img/<name>`` as a 1-image batch, keeping the left 256 columns (face)."""
    with Image.open(os.path.join('test_img', name)) as im:
        return transform_tns(im).unsqueeze(0)[:, :, :, 0:256]


def _save_visualization(epoch, image_dir):
    """Decode the fixed real/fake test images and save segmentations and reconstructions."""
    rgb = transform_norm(torch.cat((_load_test_image('real.jpg'), _load_test_image('fake.jpg')), dim=0))

    # real = 1, fake = 0
    labels_data = torch.FloatTensor([1, 0])
    if gpu_id >= 0:
        rgb = rgb.cuda(gpu_id)
        labels_data = labels_data.cuda(gpu_id)

    latent, _, _ = _encode(rgb)
    seg, rect = _decode_selected(latent, labels_data)

    # Channel 1 of the segmentation output, binarized at 0.5.
    seg = (seg[:, 1, :, :].detach().cpu() >= 0.5).float()
    rect = transform_unnorm(rect).detach().cpu()

    tag = str(epoch).zfill(3) + '.jpg'
    transform_pil(seg[0]).save(os.path.join(image_dir, 'seg_real_' + tag))
    transform_pil(seg[1]).save(os.path.join(image_dir, 'seg_fake_' + tag))
    transform_pil(rect[0]).save(os.path.join(image_dir, 'real_' + tag))
    transform_pil(rect[1]).save(os.path.join(image_dir, 'fake_' + tag))


num_epochs = 100
image_dir = os.path.join(os.getcwd(), 'image')
# PIL's save() fails when the target directory does not exist.
os.makedirs(image_dir, exist_ok=True)

for epoch in range(1, num_epochs + 1):
    # ---------------- training ----------------
    loss_act_train = loss_seg_train = loss_rect_train = 0.0
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    tol_label = np.array([], dtype=np.float64)
    tol_pred = np.array([], dtype=np.float64)
    count = 0

    for fft_data, labels_data in tqdm(dataloader_train):
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()

        fft_label, labels_data, rgb, mask = _prepare_batch(fft_data, labels_data)
        latent, zero, one = _encode(rgb)
        seg, rect = _decode_selected(latent, labels_data)
        loss_act, loss_seg, loss_rect = _losses(zero, one, labels_data, seg, mask, rect, rgb)

        (loss_act + loss_seg + loss_rect).backward()
        optimizer_decoder.step()
        optimizer_encoder.step()

        tol_label = np.concatenate((tol_label, fft_label))
        tol_pred = np.concatenate((tol_pred, _hard_pred(zero, one)))

        loss_act_train += loss_act.item()
        loss_seg_train += loss_seg.item()
        loss_rect_train += loss_rect.item()
        count += 1

    acc_train = metrics.accuracy_score(tol_label, tol_pred)
    loss_act_train /= count
    loss_seg_train /= count
    loss_rect_train /= count

    # ---------------- checkpointing ----------------
    torch.save(encoder.state_dict(), os.path.join(os.getcwd(), 'encoder_%d.pt' % epoch))
    torch.save(optimizer_encoder.state_dict(), os.path.join(os.getcwd(), 'optim_encoder_%d.pt' % epoch))

    torch.save(decoder.state_dict(), os.path.join(os.getcwd(), 'decoder_%d.pt' % epoch))
    torch.save(optimizer_decoder.state_dict(), os.path.join(os.getcwd(), 'optim_decoder_%d.pt' % epoch))

    # ---------------- validation ----------------
    encoder.eval()
    decoder.eval()

    loss_act_test = loss_seg_test = loss_rect_test = 0.0
    tol_label = np.array([], dtype=np.float64)
    tol_pred = np.array([], dtype=np.float64)
    tol_pred_prob = np.array([], dtype=np.float64)
    count = 0

    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for fft_data, labels_data in tqdm(dataloader_val):
            fft_label, labels_data, rgb, mask = _prepare_batch(fft_data, labels_data)
            latent, zero, one = _encode(rgb)
            seg, rect = _decode_selected(latent, labels_data)
            loss_act, loss_seg, loss_rect = _losses(zero, one, labels_data, seg, mask, rect, rgb)

            tol_label = np.concatenate((tol_label, fft_label))
            tol_pred = np.concatenate((tol_pred, _hard_pred(zero, one)))

            # Softmax over the two class scores; column 1 is the score used for ROC/EER.
            pred_prob = torch.softmax(torch.stack((zero, one), dim=1), dim=1)
            tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:, 1].cpu().numpy()))

            loss_act_test += loss_act.item()
            loss_seg_test += loss_seg.item()
            loss_rect_test += loss_rect.item()
            count += 1

    acc_test = metrics.accuracy_score(tol_label, tol_pred)
    loss_act_test /= count
    loss_seg_test /= count
    loss_rect_test /= count

    fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
    # EER: the operating point where FPR equals the miss rate (1 - TPR).
    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

    print('[Epoch %d] Train: act_loss: %.4f  seg_loss: %.4f  rect_loss: %.4f  acc: %.2f | Test: act_loss: %.4f  seg_loss: %.4f  rect_loss: %.4f  acc: %.2f  eer: %.2f'
          % (epoch, loss_act_train, loss_seg_train, loss_rect_train, acc_train*100, loss_act_test, loss_seg_test, loss_rect_test, acc_test*100, eer*100))

    # Write to the same handle that is flushed here and closed after training
    # (the previous `text_writer` name was never defined).
    text_writer_train.write('%d,%.4f,%.4f,%.4f,%.2f,%.4f,%.4f,%.4f,%.2f,%.2f\n'
                            % (epoch, loss_act_train, loss_seg_train, loss_rect_train, acc_train*100, loss_act_test, loss_seg_test, loss_rect_test, acc_test*100, eer*100))
    text_writer_train.flush()

    # ---------------- visualization ----------------
    with torch.no_grad():
        _save_visualization(epoch, image_dir)

    encoder.train(mode=True)
    decoder.train(mode=True)

text_writer_train.close()

NameError: ignored

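## 8. Start evaluation phase
- Loads the encoder checkpoint saved after epoch 46. Reports activation loss, accuracy and EER on the test split.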

In [4]:
# Epoch whose saved encoder checkpoint is evaluated on the test split.
eval_epoch = 46

# map_location='cpu' lets a checkpoint written on a GPU runtime load on a CPU-only one;
# load_state_dict then copies the weights onto the encoder's own device.
encoder.load_state_dict(torch.load(os.path.join(os.getcwd(), 'encoder_%d.pt' % eval_epoch), map_location='cpu'))
encoder.eval()

loss_act_test = 0.0

# np.float was removed in NumPy 1.24; np.float64 is the same dtype.
tol_label = np.array([], dtype=np.float64)
tol_pred = np.array([], dtype=np.float64)
tol_pred_prob = np.array([], dtype=np.float64)

count = 0

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for fft_data, labels_data in tqdm(dataloader_test):
        fft_label = labels_data.numpy().astype(np.float64)
        labels_data = labels_data.float()

        # Left 256 columns hold the face image.
        rgb = transform_norm(fft_data[:, :, :, 0:256])

        if gpu_id >= 0:
            rgb = rgb.cuda(gpu_id)
            labels_data = labels_data.cuda(gpu_id)

        latent = encoder(rgb).reshape(-1, 2, 64, 16, 16)

        # Mean absolute activation of each latent half = class-0 / class-1 score.
        zero = torch.abs(latent[:, 0]).view(latent.shape[0], -1).mean(dim=1)
        one = torch.abs(latent[:, 1]).view(latent.shape[0], -1).mean(dim=1)

        loss_act_test += act_loss_fn(zero, one, labels_data).item()

        # Class 1 wherever its score is at least the class-0 score.
        output_pred = (one >= zero).cpu().numpy().astype(np.float64)

        tol_label = np.concatenate((tol_label, fft_label))
        tol_pred = np.concatenate((tol_pred, output_pred))

        pred_prob = torch.softmax(torch.stack((zero, one), dim=1), dim=1)
        tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:, 1].cpu().numpy()))

        count += 1

acc_test = metrics.accuracy_score(tol_label, tol_pred)
loss_act_test /= count

fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
# EER: the operating point where FPR equals the miss rate (1 - TPR).
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

# eval_epoch replaces the undefined `opt.id` and the duplicated literal 46.
print('[Epoch %d] act_loss: %.4f  acc: %.2f   eer: %.2f' % (eval_epoch, loss_act_test, acc_test*100, eer*100))
text_writer_test.write('%d,%.4f,%.2f,%.2f\n' % (eval_epoch, loss_act_test, acc_test*100, eer*100))

text_writer_test.flush()
text_writer_test.close()

NameError: ignored

# Train DSP-FWA model

In [0]:
import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
import os, math

In [0]:
class ResNet(nn.Module):
    """torchvision ResNet trunk with a fresh ``num_class``-way linear head.

    Args:
        layers: ResNet depth; one of 18, 34, 50, 101, 152.
        num_class: number of outputs of the linear head ``self.fc``.
        pretrained: forwarded to the torchvision constructor.

    Raises:
        ValueError: if ``layers`` is not a supported depth.
    """

    # depth -> (torchvision constructor name, input features of the linear head)
    _ARCHS = {
        18: ('resnet18', 512),
        34: ('resnet34', 512),
        50: ('resnet50', 512 * 4),
        101: ('resnet101', 512 * 4),
        152: ('resnet152', 512 * 4),
    }

    def __init__(self, layers=18, num_class=2, pretrained=True):
        super(ResNet, self).__init__()
        try:
            arch_name, num_features = self._ARCHS[layers]
        except (KeyError, TypeError):
            # The message now lists every supported depth (152 was missing).
            raise ValueError('layers should be 18, 34, 50, 101 or 152.') from None
        self.resnet = getattr(models, arch_name)(pretrained=pretrained)
        self.num_class = num_class
        self.fc = nn.Linear(num_features, num_class)

    def conv_base(self, x):
        """Run the stem (conv1, bn1, relu, maxpool) and return the four residual-stage outputs."""
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        layer1 = self.resnet.layer1(x)
        layer2 = self.resnet.layer2(layer1)
        layer3 = self.resnet.layer3(layer2)
        layer4 = self.resnet.layer4(layer3)
        return layer1, layer2, layer3, layer4

    def forward(self, x):
        """Average-pool the last stage, flatten, and apply ``self.fc`` -> ``(batch, num_class)``."""
        _, _, _, layer4 = self.conv_base(x)
        x = self.resnet.avgpool(layer4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



class SPPNet(nn.Module):
    """ResNet trunk followed by spatial pyramid max-pooling and a linear classifier.

    Args:
        backbone: ResNet depth (18, 34, 50, 101 or 152).
        num_class: number of classifier outputs.
        pool_size: side length of the pooling grid at each pyramid level. Any number
            of levels is supported; the classifier input size is derived from all of them.
        pretrained: forwarded to ``ResNet``.

    Raises:
        ValueError: if ``backbone`` is not a supported depth.
    """

    def __init__(self, backbone=101, num_class=2, pool_size=(1, 2, 6), pretrained=True):
        # Only resnet is supported in this version
        super(SPPNet, self).__init__()
        if backbone not in [18, 34, 50, 101, 152]:
            raise ValueError('Resnet{} is not supported yet.'.format(backbone))
        self.resnet = ResNet(backbone, num_class, pretrained)

        # Channels in the ResNet layer4 feature map.
        self.c = 512 if backbone in [18, 34] else 2048

        self.spp = SpatialPyramidPool2D(out_side=pool_size)
        # Each level n contributes an n x n grid per channel. Summing over every level
        # (not just the first three) keeps the classifier consistent with self.spp.
        num_features = self.c * sum(n ** 2 for n in pool_size)
        self.classifier = nn.Linear(num_features, num_class)

    def forward(self, x):
        """Pool the last ResNet stage with the pyramid and classify -> ``(batch, num_class)``."""
        _, _, _, x = self.resnet.conv_base(x)
        x = self.spp(x)
        x = self.classifier(x)
        return x


class SpatialPyramidPool2D(nn.Module):
    """Spatial pyramid max-pooling flattened into a single feature vector.

    Args:
        out_side (tuple): target side length of the pooling grid at each pyramid level.

    Inputs:
        - ``x``: 4-D tensor ``[batch, channel, height, width]``.

    Returns:
        ``[batch, features]`` tensor holding, level by level, each pooled map flattened
        per sample and concatenated along dim 1 (None if ``out_side`` is empty).
    """

    def __init__(self, out_side):
        super(SpatialPyramidPool2D, self).__init__()
        self.out_side = out_side

    def forward(self, x):
        batch = x.size()[0]
        spatial = x.size()[2:]
        pieces = []
        for side in self.out_side:
            # Per spatial dim: window = ceil(dim / side), step = floor(dim / side).
            # NOTE(review): this gives exactly side x side cells only for some sizes
            # (dim 7 works for sides 1, 2 and 6; dim 10 with side 6 gives 9 cells),
            # while SPPNet's classifier assumes side x side.
            kernel = tuple(math.ceil(dim / side) for dim in spatial)
            stride = tuple(math.floor(dim / side) for dim in spatial)
            pooled = F.max_pool2d(x, kernel_size=kernel, stride=stride)
            pieces.append(pooled.view(batch, -1))
        if not pieces:
            return None
        return torch.cat(pieces, 1)

# Train Ictu Oculi model