# Instructions when using Google Colab and Google Cloud Storage

## Authenticate

In [0]:
from google.colab import auth
auth.authenticate_user()

## Load data from Cloud Storage

### Install Cloud Storage FUSE

In [0]:
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100   653  100   653    0     0   9069      0 --:--:-- --:--:-- --:--:--  9069
OK
54 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 54 not upgraded.
Need to get 4,274 kB of archives.
After this operation, 12.8 MB of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 144568 files and directories currently installed.)
Preparing to unpack .../gcsfuse_0.28.1_amd64.deb ...
Unpacking gcsfuse (0.28.1) ...
Setting up gcsfuse (0.28.1) ...


### Create connection between Cloud Storage and Google Colab

In [0]:
# Connecting with dataset buckets
!mkdir data
!mkdir data/celeb-df
!gcsfuse celeb-df data/celeb-df
!ls data/celeb-df | wc

# Connecting with checkpoints for the models
!mkdir checkpoints
!gcsfuse checkpoints_models checkpoints
!ls checkpoints | wc

mkdir: cannot create directory ‘data’: File exists
mkdir: cannot create directory ‘data/celeb-df’: File exists
Using mount point: /content/data/celeb-df
Opening GCS connection...
Opening bucket...
Mounting file system...
File system has been successfully mounted.
      3       3      22
mkdir: cannot create directory ‘checkpoints’: File exists
Using mount point: /content/checkpoints
Opening GCS connection...
Opening bucket...
Mounting file system...
File system has been successfully mounted.
      2       2      18


# Train Capsule model

**Public repository:** 
https://github.com/nii-yamagishilab/Capsule-Forensics-v2

**Reference:**
H. H. Nguyen, J. Yamagishi, and I. Echizen, “Use of a Capsule Network to Detect Fake Images and Videos,” arXiv preprint arXiv:1910.12467. 2019 Oct 29.

## 1. Import dependencies

In [0]:
import os
import random
import torch
import torch.nn.functional as F
from torch import nn
import torch.backends.cudnn as cudnn
import numpy as np
from torch.autograd import Variable
from torch.optim import Adam
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn import metrics

## 2. Capsule Network architecture

In [0]:
NO_CAPS = 10

In [0]:
class StatsNet(nn.Module):
    def __init__(self):
        super(StatsNet, self).__init__()

    def forward(self, x):
        x = x.view(x.data.shape[0], x.data.shape[1], x.data.shape[2]*x.data.shape[3])

        mean = torch.mean(x, 2)
        std = torch.std(x, 2)

        return torch.stack((mean, std), dim=1)

class View(nn.Module):
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        return input.view(self.shape)

class VggExtractor(nn.Module):
    def __init__(self, train=False):
        super(VggExtractor, self).__init__()

        self.vgg_1 = self.Vgg(models.vgg19(pretrained=True), 0, 18)
        if train:
            self.vgg_1.train(mode=True)
            self.freeze_gradient()
        else:
            self.vgg_1.eval()

    def Vgg(self, vgg, begin, end):
        features = nn.Sequential(*list(vgg.features.children())[begin:(end+1)])
        return features

    def freeze_gradient(self, begin=0, end=9):
        for i in range(begin, end+1):
            self.vgg_1[i].requires_grad = False

    def forward(self, input):
        return self.vgg_1(input)

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                StatsNet(),

                nn.Conv1d(2, 8, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm1d(8),
                nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(1),
                View(-1, 8),
                )
                for _ in range(NO_CAPS)]
        )

    def squash(self, tensor, dim):
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / (torch.sqrt(squared_norm))

    def forward(self, x):
        # outputs = [capsule(x.detach()) for capsule in self.capsules]
        # outputs = [capsule(x.clone()) for capsule in self.capsules]
        outputs = [capsule(x) for capsule in self.capsules]
        output = torch.stack(outputs, dim=-1)

        return self.squash(output, dim=-1)

class RoutingLayer(nn.Module):
    def __init__(self, gpu_id, num_input_capsules, num_output_capsules, data_in, data_out, num_iterations):
        super(RoutingLayer, self).__init__()

        self.gpu_id = gpu_id
        self.num_iterations = num_iterations
        self.route_weights = nn.Parameter(torch.randn(num_output_capsules, num_input_capsules, data_out, data_in))


    def squash(self, tensor, dim):
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / (torch.sqrt(squared_norm))

    def forward(self, x, random, dropout):
        # x[b, data, in_caps]

        x = x.transpose(2, 1)
        # x[b, in_caps, data]

        if random:
            noise = Variable(0.01*torch.randn(*self.route_weights.size()))
            if self.gpu_id >= 0:
                noise = noise.cuda(self.gpu_id)
            route_weights = self.route_weights + noise
        else:
            route_weights = self.route_weights

        priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None]

        # route_weights [out_caps , 1 , in_caps , data_out , data_in]
        # x             [   1     , b , in_caps , data_in ,    1    ]
        # priors        [out_caps , b , in_caps , data_out,    1    ]

        priors = priors.transpose(1, 0)
        # priors[b, out_caps, in_caps, data_out, 1]

        if dropout > 0.0:
            drop = Variable(torch.FloatTensor(*priors.size()).bernoulli(1.0- dropout))
            if self.gpu_id >= 0:
                drop = drop.cuda(self.gpu_id)
            priors = priors * drop
            

        logits = Variable(torch.zeros(*priors.size()))
        # logits[b, out_caps, in_caps, data_out, 1]

        if self.gpu_id >= 0:
            logits = logits.cuda(self.gpu_id)

        num_iterations = self.num_iterations

        for i in range(num_iterations):
            probs = F.softmax(logits, dim=2)
            outputs = self.squash((probs * priors).sum(dim=2, keepdim=True), dim=3)

            if i != self.num_iterations - 1:
                delta_logits = priors * outputs
                logits = logits + delta_logits

        # outputs[b, out_caps, 1, data_out, 1]
        outputs = outputs.squeeze()

        if len(outputs.shape) == 3:
            outputs = outputs.transpose(2, 1).contiguous() 
        else:
            outputs = outputs.unsqueeze_(dim=0).transpose(2, 1).contiguous()
        # outputs[b, data_out, out_caps]

        return outputs

class CapsuleNet(nn.Module):
    def __init__(self, num_class, gpu_id):
        super(CapsuleNet, self).__init__()

        self.num_class = num_class
        self.fea_ext = FeatureExtractor()
        self.fea_ext.apply(self.weights_init)

        self.routing_stats = RoutingLayer(gpu_id=gpu_id, num_input_capsules=NO_CAPS, num_output_capsules=num_class, data_in=8, data_out=4, num_iterations=2)

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x, random=False, dropout=0.0):

        z = self.fea_ext(x)
        z = self.routing_stats(z, random, dropout=dropout)
        # z[b, data, out_caps]

        # classes = F.softmax(z, dim=-1)

        # class_ = classes.detach()
        # class_ = class_.mean(dim=1)

        # return classes, class_

        classes = F.softmax(z, dim=-1)
        class_ = classes.detach()
        class_ = class_.mean(dim=1)

        return z, class_

class CapsuleLoss(nn.Module):
    def __init__(self, gpu_id):
        super(CapsuleLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        if gpu_id >= 0:
            self.cross_entropy_loss.cuda(gpu_id)

    def forward(self, classes, labels):
        loss_t = self.cross_entropy_loss(classes[:,0,:], labels)

        for i in range(classes.size(1) - 1):
            loss_t = loss_t + self.cross_entropy_loss(classes[:,i+1,:], labels)

        return loss_t

## 3. Default settings



In [0]:
dataset = os.getcwd() + '/celeb-df'
train_set = '/train'
val_set = '/validation'
test_set = '/test'
workers = 0
batch_size = 32
image_size = 300
learning_rate = 0.0005
beta1_for_adam = 0.9
gpu_id = 0
manual_seed = random.randint(1, 1000)
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)
cudnn.benchmark = True
disable_random_routing_matrix = False
random_setting = not disable_random_routing_matrix
text_writer_train = open(os.path.join(os.getcwd(), 'train.csv'), 'a')
text_writer_test = open(os.path.join(os.getcwd(), 'test.txt'), 'w')

Random Seed:  287


## 4. Creating Capsule

In [0]:
vgg_ext = VggExtractor()
capnet = CapsuleNet(2, gpu_id)
capsule_loss = CapsuleLoss(gpu_id)
optimizer = Adam(capnet.parameters(), lr=learning_rate, betas=(beta1_for_adam, 0.999))

capnet.cuda(gpu_id)
vgg_ext.cuda(gpu_id)
capsule_loss.cuda(gpu_id)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(IntProgress(value=0, max=574673361), HTML(value='')))




CapsuleLoss(
  (cross_entropy_loss): CrossEntropyLoss()
)

## 5. Creating datasets for training, validation, and test

In [0]:
transform_fwd = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

dataset_train = dset.ImageFolder(root=dataset + train_set, transform=transform_fwd)
assert dataset_train
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=int(workers))

dataset_val = dset.ImageFolder(root=dataset + val_set, transform=transform_fwd)
assert dataset_val
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=int(workers))

dataset_test = dset.ImageFolder(root=dataset + test_set, transform=transform_fwd)
assert dataset_test
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=int(workers))

RuntimeError: ignored

## 6. Start training phase
- Save the fitted model at the end

In [0]:
for epoch in range(1, 26):
        count = 0
        loss_train = 0
        loss_test = 0

        tol_label = np.array([], dtype=np.float)
        tol_pred = np.array([], dtype=np.float)

        for img_data, labels_data in tqdm(dataloader_train):

            labels_data[labels_data > 1] = 1
            img_label = labels_data.numpy().astype(np.float)
            optimizer.zero_grad()

            if gpu_id >= 0:
              img_data = img_data.cuda(gpu_id)
              labels_data = labels_data.cuda(gpu_id)

            input_v = Variable(img_data)
            x = vgg_ext(input_v)
            classes, class_ = capnet(x, random=optrandom, dropout=0.05)

            loss_dis = capsule_loss(classes, Variable(labels_data, requires_grad=False))
            loss_dis_data = loss_dis.item()

            loss_dis.backward()
            optimizer.step()

            output_dis = class_.data.cpu().numpy()
            output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)

            for i in range(output_dis.shape[0]):
                if output_dis[i,1] >= output_dis[i,0]:
                    output_pred[i] = 1.0
                else:
                    output_pred[i] = 0.0

            tol_label = np.concatenate((tol_label, img_label))
            tol_pred = np.concatenate((tol_pred, output_pred))

            loss_train += loss_dis_data
            count += 1


        acc_train = metrics.accuracy_score(tol_label, tol_pred)
        loss_train /= count

        ########################################################################

        # do checkpointing & validation
        torch.save(capnet.state_dict(), os.path.join(os.getcwd(), 'capsule_%d.pt' % epoch))
        torch.save(optimizer.state_dict(), os.path.join(os.getcwd(), 'optim_%d.pt' % epoch))

        capnet.eval()

        tol_label = np.array([], dtype=np.float)
        tol_pred = np.array([], dtype=np.float)

        count = 0

        for img_data, labels_data in dataloader_val:

            labels_data[labels_data > 1] = 1
            img_label = labels_data.numpy().astype(np.float)

            if gpu_id >= 0:
              img_data = img_data.cuda(gpu_id)
              labels_data = labels_data.cuda(gpu_id)

            input_v = Variable(img_data)

            x = vgg_ext(input_v)
            classes, class_ = capnet(x, random=False)

            loss_dis = capsule_loss(classes, Variable(labels_data, requires_grad=False))
            loss_dis_data = loss_dis.item()
            output_dis = class_.data.cpu().numpy()

            output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)

            for i in range(output_dis.shape[0]):
                if output_dis[i,1] >= output_dis[i,0]:
                    output_pred[i] = 1.0
                else:
                    output_pred[i] = 0.0

            tol_label = np.concatenate((tol_label, img_label))
            tol_pred = np.concatenate((tol_pred, output_pred))

            loss_test += loss_dis_data
            count += 1

        acc_test = metrics.accuracy_score(tol_label, tol_pred)
        loss_test /= count

        print('[Epoch %d] Train loss: %.4f   acc: %.2f | Test loss: %.4f  acc: %.2f'
        % (epoch, loss_train, acc_train*100, loss_test, acc_test*100))

        text_writer.write('%d,%.4f,%.2f,%.4f,%.2f\n'
        % (epoch, loss_train, acc_train*100, loss_test, acc_test*100))

        text_writer.flush()
        capnet.train(mode=True)

text_writer.close()

100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 1] Train loss: 2.7654   acc: 100.00 | Test loss: 2.7647  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 2] Train loss: 1.6929   acc: 100.00 | Test loss: 2.7656  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 3] Train loss: 1.6828   acc: 100.00 | Test loss: 2.7663  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 4] Train loss: 1.6399   acc: 100.00 | Test loss: 2.7670  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 5] Train loss: 1.6405   acc: 100.00 | Test loss: 2.7675  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 6] Train loss: 1.6311   acc: 100.00 | Test loss: 2.7680  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.49it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 7] Train loss: 1.6214   acc: 100.00 | Test loss: 2.7684  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 8] Train loss: 1.6075   acc: 100.00 | Test loss: 2.7687  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 9] Train loss: 1.5987   acc: 100.00 | Test loss: 2.7690  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 10] Train loss: 1.5939   acc: 100.00 | Test loss: 2.7692  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 11] Train loss: 1.5998   acc: 100.00 | Test loss: 2.7694  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 12] Train loss: 1.5997   acc: 100.00 | Test loss: 2.7696  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 13] Train loss: 1.5835   acc: 100.00 | Test loss: 2.7698  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 14] Train loss: 1.5941   acc: 100.00 | Test loss: 2.7699  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 15] Train loss: 1.5962   acc: 100.00 | Test loss: 2.7700  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 16] Train loss: 1.5974   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 17] Train loss: 1.6057   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 18] Train loss: 1.5824   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.46it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 19] Train loss: 1.5678   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 20] Train loss: 1.5796   acc: 100.00 | Test loss: 2.7701  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 21] Train loss: 1.5892   acc: 100.00 | Test loss: 2.7700  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 22] Train loss: 1.5883   acc: 100.00 | Test loss: 2.7699  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.42it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 23] Train loss: 1.5779   acc: 100.00 | Test loss: 2.7698  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

[Epoch 24] Train loss: 1.5814   acc: 100.00 | Test loss: 2.7696  acc: 100.00


100%|██████████| 1/1 [00:00<00:00,  1.40it/s]


[Epoch 25] Train loss: 1.5883   acc: 100.00 | Test loss: 2.7694  acc: 100.00


## 7. Start evaluation phase

In [0]:
capnet.load_state_dict(torch.load(os.path.join(os.getcwd(),'capsule_' + str(21) + '.pt')))
capnet.eval()

tol_label = np.array([], dtype=np.float)
tol_pred = np.array([], dtype=np.float)
tol_pred_prob = np.array([], dtype=np.float)

count = 0
loss_test = 0

for img_data, labels_data in tqdm(dataloader_test):

        labels_data[labels_data > 1] = 1
        img_label = labels_data.numpy().astype(np.float)

        if gpu_id >= 0:
          img_data = img_data.cuda(gpu_id)
          labels_data = labels_data.cuda(gpu_id)

        input_v = Variable(img_data)

        x = vgg_ext(input_v)
        classes, class_ = capnet(x, random=random_setting)

        output_dis = class_.data.cpu()
        output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)

        for i in range(output_dis.shape[0]):
            if output_dis[i,1] >= output_dis[i,0]:
                output_pred[i] = 1.0
            else:
                output_pred[i] = 0.0

        tol_label = np.concatenate((tol_label, img_label))
        tol_pred = np.concatenate((tol_pred, output_pred))
        
        pred_prob = torch.softmax(output_dis, dim=1)
        tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:,1].data.numpy()))

        count += 1

acc_test = metrics.accuracy_score(tol_label, tol_pred)
loss_test /= count

fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

# fnr = 1 - tpr
# hter = (fpr + fnr)/2

print('[Epoch %d] Test acc: %.2f   EER: %.2f' % (21, acc_test*100, eer*100))
text_writer_test.write('%d,%.2f,%.2f\n'% (21, acc_test*100, eer*100))

text_writer_test.flush()
text_writer_test.close()

# Train ClassNSeg model

**Public repository:** 
https://github.com/nii-yamagishilab/ClassNSeg 

**Reference:**
H. H. Nguyen, F. Fang, J. Yamagishi, and I. Echizen, “Multi-task Learning for Detecting and Segmenting Manipulated Facial Images and Videos,” Proc. of the 10th IEEE International Conference on Biometrics: Theory, Applications and Systems (BTAS), 8 pages, (September 2019)

## 1. Import dependencies

In [0]:
import torch
import torch.backends.cudnn as cudnn
from torch import nn
from torch.optim import Adam
import torchvision.datasets as dset
import torchvision.transforms as transforms
from tqdm import tqdm
import cv2
import numpy as np
import os
import sys
import random

## 2. ClassNSeg Network architecture

In [0]:
class Encoder(nn.Module):
    def __init__(self, depth=3):
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(depth, 8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.encoder.apply(self.weights_init)

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(0.5, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    def __init__(self, depth=3):
        super(Decoder, self).__init__()

        self.shared = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.segmenter = nn.Sequential(

            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.ConvTranspose2d(8, 2, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.Softmax(dim=1)
        )

        self.decoder = nn.Sequential(

            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.ConvTranspose2d(8, depth, kernel_size=3, stride=1, padding=1, output_padding=0),
            nn.Tanh()
        )

        self.segmenter.apply(self.weights_init)
        self.decoder.apply(self.weights_init)

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(0.5, 0.02)
            m.bias.data.fill_(0)

    def forward(self, x):
        latent = self.shared(x)
        seg = self.segmenter(latent)
        rect = self.decoder(latent)

        return seg, rect

class ActivationLoss(nn.Module):
    def __init__(self):
        super(ActivationLoss, self).__init__()

    def forward(self, zero, one, labels):

        loss_act = torch.abs(one - labels.data) + torch.abs(zero - (1.0 - labels.data))
        return 1 / labels.shape[0] * loss_act.sum()
        
class ReconstructionLoss(nn.Module):
    def __init__(self):
        super(ReconstructionLoss, self).__init__()
        self.loss = nn.MSELoss()

    def forward(self, reconstruction, groundtruth):

        return self.loss(reconstruction, groundtruth.data)

class SegmentationLoss(nn.Module):
    def __init__(self):
        super(SegmentationLoss, self).__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, segment, groundtruth):

        return self.loss(segment.view(segment.shape[0], segment.shape[1], segment.shape[2] * segment.shape[3]), 
            groundtruth.data.view(groundtruth.shape[0], groundtruth.shape[1] * groundtruth.shape[2]))

## 3. Default settings

In [0]:
dataset = '/celeb-df'
train_set = '/train'
val_set = '/validation'
test_set = '/test'
workers = 0
batch_size = 64
manual_seed = random.randint(1, 1000)
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)
cudnn.benchmark = True
learning_rate = 0.001
beta1_for_adam = 0.9
weight_decay = 0.0005
eps = 1e-07
gpu_id = 0
gamma = 1
text_writer_train = open(os.path.join(os.getcwd(), 'train.csv'), 'a')
text_writer_test = open(os.path.join(os.getcwd(), 'classification.txt'), 'w')

Random Seed:  371


## 4. Creating ClassNSeg

In [0]:
encoder = Encoder(3)
decoder = Decoder(3)
act_loss_fn = ActivationLoss()
rect_loss_fn = ReconstructionLoss()
seg_loss_fn = SegmentationLoss()

optimizer_encoder = Adam(encoder.parameters(), lr=learning_rate, betas=(beta1_for_adam, 0.999), weight_decay=weight_decay, eps=eps)
optimizer_decoder = Adam(decoder.parameters(), lr=learning_rate, betas=(beta1_for_adam, 0.999), weight_decay=weight_decay, eps=eps)

encoder.cuda(gpu_id)
decoder.cuda(gpu_id)
act_loss_fn.cuda(gpu_id)
seg_loss_fn.cuda(gpu_id)
rect_loss_fn.cuda(gpu_id)

ReconstructionLoss(
  (loss): MSELoss()
)

## 5. Start pre-processing phase

In [0]:
input_real = os.getcwd() + '/celeb-df/validation/0_real_videos/'
input_fake = os.getcwd() + '/celeb-df/validation/1_fake_videos/'
mask = ''
output_real = os.getcwd() + '/deepfakes/real'
output_fake = os.getcwd() + '/deepfakes/fake'
image_size = 256
limit = 10
scale = 1.3

def to_bw(mask, thresh_binary=10, thresh_otsu=255):
    im_gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    (thresh, im_bw) = cv2.threshold(im_gray, thresh_binary, thresh_otsu, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    return im_bw

def get_bbox(mask, thresh_binary=127, thresh_otsu=255):
    im_bw = to_bw(mask, thresh_binary, thresh_otsu)

    # im2, contours, hierarchy = cv2.findContours(im_bw,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    contours, hierarchy = cv2.findContours(im_bw,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)

    locations = np.array([], dtype=np.int).reshape(0, 5)

    for c in contours:
        # calculate moments for each contour
        M = cv2.moments(c)
        if M["m00"] > 0:
            cX = int(M["m10"] / M["m00"])
        else:
            cX = 0
        if M["m00"] > 0:
            cY = int(M["m01"] / M["m00"])
        else:
            cY = 0

        # calculate the rectangle bounding box
        x,y,w,h = cv2.boundingRect(c)
        locations = np.concatenate((locations, np.array([[cX, cY, w, h, w + h]])), axis=0)

    max_idex = locations[:,4].argmax()
    bbox = locations[max_idex, 0:4].reshape(4)
    return bbox

def extract_face(image, bbox, scale = 1.0):
    h, w, d = image.shape
    radius = int(bbox[3] * scale / 2)

    y_1 = bbox[1] - radius
    y_2 = bbox[1] + radius
    x_1 = bbox[0] - radius
    x_2 = bbox[0] + radius

    if x_1 < 0:
        x_1 = 0
    if y_1 < 0:
        y_1 = 0
    if x_2 > w:
        x_2 = w
    if y_2 > h:
        y_2 = h

    crop_img = image[y_1:y_2, x_1:x_2]

    if crop_img is not None:
        crop_img = cv2.resize(crop_img, (image_size, image_size))

    return crop_img

def extract_face_videos(input_real, input_fake, input_mask, output_real, output_fake):

    blank_img = np.zeros((image_size,image_size,3), np.uint8)

    for f in os.listdir(input_fake):
        if os.path.isfile(os.path.join(input_fake, f)):
            if f.lower().endswith(('mp4')):
                print(f)
                filename = os.path.splitext(f)[0]

                vidcap_fake = cv2.VideoCapture(os.path.join(input_fake, f))
                success_fake, image_fake = vidcap_fake.read()
                print(success_fake)

                image_mask = cv2.imread(os.path.join(input_mask, filename, '0000.png'))

                count = 0

                while (success_fake):

                    bbox = get_bbox(image_mask)

                    altered_cropped = extract_face(image_fake, bbox, scale)

                    mask_cropped = to_bw(extract_face(image_mask, bbox, scale))
                    mask_cropped = np.stack((mask_cropped,mask_cropped, mask_cropped), axis=2)

                    altered_cropped = np.concatenate((altered_cropped, mask_cropped), axis=1)

                    cv2.imwrite(os.path.join(output_fake, filename + "_%d.jpg" % count), altered_cropped)

                    count += 1

                    if count >= limit:
                        break

                    print("still")

                    success_fake, image_fake = vidcap_fake.read()
                    image_mask = cv2.imread(os.path.join(input_mask, filename, str(count).zfill(4) + '.png'))

extract_face_videos(input_real, input_fake, mask, output_real, output_fake)

id0_id16_0005.mp4
True


error: ignored

## 6. Creating datasets for training, validation, and test

In [0]:
class Normalize_3D(object):
        def __init__(self, mean, std):
            self.mean = mean
            self.std = std

        def __call__(self, tensor):
            """
                Tensor: Normalized image.
            Args:
                tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
            Returns:        """
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
            return tensor

class UnNormalize_3D(object):
        def __init__(self, mean, std):
            self.mean = mean
            self.std = std

        def __call__(self, tensor):
            """
                Tensor: Normalized image.
            Args:
                tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
            Returns:        """
            for t, m, s in zip(tensor, self.mean, self.std):
                t.mul_(s).add_(m)
            return tensor

transform_tns = transforms.Compose([
        transforms.ToTensor(),
    ])

transform_pil = transforms.Compose([
        transforms.ToPILImage(),
    ])
    
transform_norm = Normalize_3D((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

transform_unnorm = UnNormalize_3D((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

dataset_train = dset.ImageFolder(root=os.path.join(dataset, train_set), transform=transform_tns)
assert dataset_train
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=int(workers))

dataset_val = dset.ImageFolder(root=os.path.join(dataset, val_set), transform=transform_tns)
assert dataset_val
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=int(workers))

dataset_test = dset.ImageFolder(root=os.path.join(dataset, test_set), transform=transform_tns)
assert dataset_test
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=int(workers))

FileNotFoundError: ignored

## 7. Start training phase
- Save the fitted model at the end

In [0]:
for epoch in range(1, 101):
        count = 0
        loss_act_train = 0.0
        loss_seg_train = 0.0
        loss_rect_train = 0.0
        loss_act_test = 0.0
        loss_seg_test = 0.0
        loss_rect_test = 0.0

        tol_label = np.array([], dtype=np.float)
        tol_pred = np.array([], dtype=np.float)

        for fft_data, labels_data in tqdm(dataloader_train):

            optimizer_encoder.zero_grad()
            optimizer_decoder.zero_grad()

            fft_label = labels_data.numpy().astype(np.float)
            labels_data = labels_data.float()

            rgb = transform_norm(fft_data[:,:,:,0:256])
            mask = fft_data[:,0,:,256:512]
            mask[mask >= 0.5] = 1.0
            mask[mask < 0.5] = 0.0
            mask = mask.long()

            if gpu_id >= 0:
              rgb = rgb.cuda(gpu_id)
              mask = mask.cuda(gpu_id)
              labels_data = labels_data.cuda(gpu_id)

            latent = encoder(rgb).reshape(-1, 2, 64, 16, 16)

            zero_abs = torch.abs(latent[:,0]).view(latent.shape[0], -1)
            zero = zero_abs.mean(dim=1)

            one_abs = torch.abs(latent[:,1]).view(latent.shape[0], -1)
            one = one_abs.mean(dim=1)

            loss_act = act_loss_fn(zero, one, labels_data)
            loss_act_data = loss_act.item()

            y = torch.eye(2)
            if gpu_id >= 0:
              y = y.cuda(gpu_id)

            y = y.index_select(dim=0, index=labels_data.data.long())

            latent = (latent * y[:,:,None, None, None]).reshape(-1, 128, 16, 16)

            seg, rect = decoder(latent)

            loss_seg = seg_loss_fn(seg, mask)
            loss_seg = loss_seg * gamma
            loss_seg_data = loss_seg.item()

            loss_rect = rect_loss_fn(rect, rgb)
            loss_rect = loss_rect * gamma
            loss_rect_data = loss_rect.item()

            loss_total = loss_act + loss_seg + loss_rect
            loss_total.backward()

            optimizer_decoder.step()
            optimizer_encoder.step()

            output_pred = np.zeros((fft_data.shape[0]), dtype=np.float)

            for i in range(fft_data.shape[0]):
                if one[i] >= zero[i]:
                    output_pred[i] = 1.0
                else:
                    output_pred[i] = 0.0

            tol_label = np.concatenate((tol_label, fft_label))
            tol_pred = np.concatenate((tol_pred, output_pred))

            loss_act_train += loss_act_data
            loss_seg_train += loss_seg_data
            loss_rect_train += loss_rect_data
            count += 1

        acc_train = metrics.accuracy_score(tol_label, tol_pred)
        loss_act_train /= count
        loss_seg_train /= count
        loss_rect_train /= count

        ########################################################################
        # do checkpointing & validation

        torch.save(encoder.state_dict(), os.path.join(os.getcwd(), 'encoder_%d.pt' % epoch))
        torch.save(optimizer_encoder.state_dict(), os.path.join(os.getcwd(), 'optim_encoder_%d.pt' % epoch))

        torch.save(decoder.state_dict(), os.path.join(os.getcwd(), 'decoder_%d.pt' % epoch))
        torch.save(optimizer_decoder.state_dict(), os.path.join(os.getcwd(), 'optim_decoder_%d.pt' % epoch))

        encoder.eval()
        decoder.eval()

        tol_label = np.array([], dtype=np.float)
        tol_pred = np.array([], dtype=np.float)
        tol_pred_prob = np.array([], dtype=np.float)

        count = 0

        for fft_data, labels_data in tqdm(dataloader_val):

            fft_label = labels_data.numpy().astype(np.float)
            labels_data = labels_data.float()

            rgb = transform_norm(fft_data[:,:,:,0:256])
            mask = fft_data[:,0,:,256:512]
            mask[mask >= 0.5] = 1.0
            mask[mask < 0.5] = 0.0
            mask = mask.long()

            if gpu_id >= 0:
                rgb = rgb.cuda(gpu_id)
                mask = mask.cuda(gpu_id)
                labels_data = labels_data.cuda(gpu_id)

            latent = encoder(rgb).reshape(-1, 2, 64, 16, 16)

            zero_abs = torch.abs(latent[:,0]).view(latent.shape[0], -1)
            zero = zero_abs.mean(dim=1)

            one_abs = torch.abs(latent[:,1]).view(latent.shape[0], -1)
            one = one_abs.mean(dim=1)

            loss_act = act_loss_fn(zero, one, labels_data)
            loss_act_data = loss_act.item()

            y = torch.eye(2)
            if gpu_id >= 0:
                y = y.cuda(gpu_id)

            y = y.index_select(dim=0, index=labels_data.data.long())

            latent = (latent * y[:,:,None, None, None]).reshape(-1, 128, 16, 16)

            seg, rect = decoder(latent)

            loss_seg = seg_loss_fn(seg, mask)
            loss_seg = loss_seg * gamma
            loss_seg_data = loss_seg.item()

            loss_rect = rect_loss_fn(rect, rgb)
            loss_rect = loss_rect * gamma
            loss_rect_data = loss_rect.item()

            output_pred = np.zeros((fft_data.shape[0]), dtype=np.float)

            for i in range(fft_data.shape[0]):
                if one[i] >= zero[i]:
                    output_pred[i] = 1.0
                else:
                    output_pred[i] = 0.0

            tol_label = np.concatenate((tol_label, fft_label))
            tol_pred = np.concatenate((tol_pred, output_pred))
            
            pred_prob = torch.softmax(torch.cat((zero.reshape(zero.shape[0],1), one.reshape(one.shape[0],1)), dim=1), dim=1)
            tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:,1].data.cpu().numpy()))

            loss_act_test += loss_act_data
            loss_seg_test += loss_seg_data
            loss_rect_test += loss_rect_data
            count += 1

        acc_test = metrics.accuracy_score(tol_label, tol_pred)
        loss_act_test /= count
        loss_seg_test /= count
        loss_rect_test /= count

        fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
        eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        print('[Epoch %d] Train: act_loss: %.4f  seg_loss: %.4f  rect_loss: %.4f  acc: %.2f | Test: act_loss: %.4f  seg_loss: %.4f  rect_loss: %.4f  acc: %.2f  eer: %.2f'
        % (epoch, loss_act_train, loss_seg_train, loss_rect_train, acc_train*100, loss_act_test, loss_seg_test, loss_rect_test, acc_test*100, eer*100))

        text_writer.write('%d,%.4f,%.4f,%.4f,%.2f,%.4f,%.4f,%.4f,%.2f,%.2f\n'
        % (epoch, loss_act_train, loss_seg_train, loss_rect_train, acc_train*100, loss_act_test, loss_seg_test, loss_rect_test, acc_test*100, eer*100))

        text_writer_train.flush()

        ########################################################################

        real_img = transform_tns(Image.open(os.path.join('test_img', 'real.jpg'))).unsqueeze(0)[:,:,:,0:256]
        real_mask = transform_tns(Image.open(os.path.join('test_img', 'real.jpg'))).unsqueeze(0)[:,:,:,256:512]
        fake_img = transform_tns(Image.open(os.path.join('test_img', 'fake.jpg'))).unsqueeze(0)[:,:,:,0:256]
        fake_mask = transform_tns(Image.open(os.path.join('test_img', 'fake.jpg'))).unsqueeze(0)[:,:,:,256:512]

        rgb = torch.cat((real_img, fake_img), dim=0)
        rgb = transform_norm(rgb)

        real_mask[real_mask >= 0.5] = 1.0
        real_mask[real_mask < 0.5] = 0.0
        real_mask = real_mask.long()

        fake_mask[fake_mask >= 0.5] = 1.0
        fake_mask[fake_mask < 0.5] = 0.0
        fake_mask = fake_mask.long()

        # real = 1, fake = 0
        labels_data = torch.FloatTensor([1,0])

        if gpu_id >= 0:
          rgb = rgb.cuda(gpu_id)
          labels_data = labels_data.cuda(gpu_id)

        latent = encoder(rgb).reshape(-1, 2, 64, 16, 16)

        zero_abs = torch.abs(latent[:,0]).view(latent.shape[0], -1)
        zero = zero_abs.mean(dim=1)

        one_abs = torch.abs(latent[:,1]).view(latent.shape[0], -1)
        one = one_abs.mean(dim=1)

        y = torch.eye(2)

        if gpu_id >= 0:
          y = y.cuda(gpu_id)

        y = y.index_select(dim=0, index=labels_data.data.long())

        latent = (latent * y[:,:,None, None, None]).reshape(-1, 128, 16, 16)

        seg, rect = decoder(latent)

        seg = seg[:,1,:,:].detach().cpu()
        seg[seg >= 0.5] = 1.0
        seg[seg < 0.5] = 0.0

        rect = transform_unnorm(rect).detach().cpu()

        real_seg = transform_pil(seg[0])
        fake_seg = transform_pil(seg[1])

        real_img = transform_pil(rect[0])
        fake_img = transform_pil(rect[1])

        real_seg.save(os.path.join(os.getcwd(), 'image', 'seg_real_' + str(epoch).zfill(3) + '.jpg'))
        fake_seg.save(os.path.join(os.getcwd(), 'image', 'seg_fake_' + str(epoch).zfill(3) + '.jpg'))

        real_img.save(os.path.join(os.getcwd(), 'image', 'real_' + str(epoch).zfill(3) + '.jpg'))
        fake_img.save(os.path.join(os.getcwd(), 'image', 'fake_' + str(epoch).zfill(3) + '.jpg'))

        encoder.train(mode=True)
        decoder.train(mode=True)

text_writer_train.close()

NameError: ignored

## 8. Start evaluation phase

In [0]:
encoder.load_state_dict(torch.load(os.path.join(os.getcwd(),'encoder_' + str(46) + '.pt')))
encoder.eval()

loss_act_test = 0.0

tol_label = np.array([], dtype=np.float)
tol_pred = np.array([], dtype=np.float)
tol_pred_prob = np.array([], dtype=np.float)

count = 0

for fft_data, labels_data in tqdm(dataloader_test):

        fft_label = labels_data.numpy().astype(np.float)
        labels_data = labels_data.float()

        rgb = transform_norm(fft_data[:,:,:,0:256])

        if gpu_id >= 0:
          rgb = rgb.cuda(gpu_id)
          labels_data = labels_data.cuda(gpu_id)

        latent = encoder(rgb).reshape(-1, 2, 64, 16, 16)

        zero_abs = torch.abs(latent[:,0]).view(latent.shape[0], -1)
        zero = zero_abs.mean(dim=1)

        one_abs = torch.abs(latent[:,1]).view(latent.shape[0], -1)
        one = one_abs.mean(dim=1)

        loss_act = act_loss_fn(zero, one, labels_data)
        loss_act_data = loss_act.item()

        output_pred = np.zeros((fft_data.shape[0]), dtype=np.float)

        for i in range(fft_data.shape[0]):
            if one[i] >= zero[i]:
                output_pred[i] = 1.0
            else:
                output_pred[i] = 0.0

        tol_label = np.concatenate((tol_label, fft_label))
        tol_pred = np.concatenate((tol_pred, output_pred))
        
        pred_prob = torch.softmax(torch.cat((zero.reshape(zero.shape[0],1), one.reshape(one.shape[0],1)), dim=1), dim=1)
        tol_pred_prob = np.concatenate((tol_pred_prob, pred_prob[:,1].data.cpu().numpy()))

        loss_act_test += loss_act_data
        count += 1

acc_test = metrics.accuracy_score(tol_label, tol_pred)
loss_act_test /= count

fpr, tpr, thresholds = roc_curve(tol_label, tol_pred_prob, pos_label=1)
eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

print('[Epoch %d] act_loss: %.4f  acc: %.2f   eer: %.2f' % (opt.id, loss_act_test, acc_test*100, eer*100))
text_writer_test.write('%d,%.4f,%.2f,%.2f\n'% (46, loss_act_test, acc_test*100, eer*100))

text_writer_test.flush()
text_writer_test.close()

NameError: ignored

# Train DSP-FWA model

**Public repository:** 
https://github.com/danmohaha/DSP-FWA 

**Citation:**
@inproceedings{li2019exposing,
  title={Exposing DeepFake Videos By Detecting Face Warping Artifacts},
  author={Li, Yuezun and Lyu, Siwei},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)},
  year={2019}
}

## 1. Import dependencies

In [0]:
import torch
from torch import nn
from torchvision import models
import torch.nn.functional as F
import cv2, os, math, dlib
import numpy as np

## 2. DSP-FWA Network architecture

In [0]:
class ResNet(nn.Module):
    def __init__(self, layers=18, num_class=2, pretrained=True):
        super(ResNet, self).__init__()
        if layers == 18:
            self.resnet = models.resnet18(pretrained=pretrained)
        elif layers == 34:
            self.resnet = models.resnet34(pretrained=pretrained)
        elif layers == 50:
            self.resnet = models.resnet50(pretrained=pretrained)
        elif layers == 101:
            self.resnet = models.resnet101(pretrained=pretrained)
        elif layers == 152:
            self.resnet = models.resnet152(pretrained=pretrained)
        else:
            raise ValueError('layers should be 18, 34, 50, 101.')
        self.num_class = num_class
        if layers in [18, 34]:
            self.fc = nn.Linear(512, num_class)
        if layers in [50, 101, 152]:
            self.fc = nn.Linear(512 * 4, num_class)

    def conv_base(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        layer1 = self.resnet.layer1(x)
        layer2 = self.resnet.layer2(layer1)
        layer3 = self.resnet.layer3(layer2)
        layer4 = self.resnet.layer4(layer3)
        return layer1, layer2, layer3, layer4

    def forward(self, x):
        layer1, layer2, layer3, layer4 = self.conv_base(x)
        x = self.resnet.avgpool(layer4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



class SPPNet(nn.Module):
    def __init__(self, backbone=101, num_class=2, pool_size=(1, 2, 6), pretrained=True):
        # Only resnet is supported in this version
        super(SPPNet, self).__init__()
        if backbone in [18, 34, 50, 101, 152]:
            self.resnet = ResNet(backbone, num_class, pretrained)
        else:
            raise ValueError('Resnet{} is not supported yet.'.format(backbone))

        if backbone in [18, 34]:
            self.c = 512
        if backbone in [50, 101, 152]:
            self.c = 2048

        self.spp = SpatialPyramidPool2D(out_side=pool_size)
        num_features = self.c * (pool_size[0] ** 2 + pool_size[1] ** 2 + pool_size[2] ** 2)
        self.classifier = nn.Linear(num_features, num_class)

    def forward(self, x):
        _, _, _, x = self.resnet.conv_base(x)
        x = self.spp(x)
        x = self.classifier(x)
        return x


class SpatialPyramidPool2D(nn.Module):
    """
    Args:
        out_side (tuple): Length of side in the pooling results of each pyramid layer.

    Inputs:
        - `input`: the input Tensor to invert ([batch, channel, width, height])
    """

    def __init__(self, out_side):
        super(SpatialPyramidPool2D, self).__init__()
        self.out_side = out_side

    def forward(self, x):
        # batch_size, c, h, w = x.size()
        out = None
        for n in self.out_side:
            w_r, h_r = map(lambda s: math.ceil(s / n), x.size()[2:])  # Receptive Field Size
            s_w, s_h = map(lambda s: math.floor(s / n), x.size()[2:])  # Stride
            max_pool = nn.MaxPool2d(kernel_size=(w_r, h_r), stride=(s_w, s_h))
            y = max_pool(x)
            if out is None:
                out = y.view(y.size()[0], -1)
            else:
                out = torch.cat((out, y.view(y.size()[0], -1)), 1)
        return out

## 3. Default settings

In [0]:
ckpt_name = 'SPP-res50.pth'
model_path = os.path.join(os.getcwd(), ckpt_name)
f_path = os.getcwd() + '/celeb-df/validation/1_fake_videos/id0_id16_0005.mp4'
print('Testing: ' + f_path)
suffix = f_path.split('.')[-1]
num_class = 2
layers = 50

Testing: /content/celeb-df/validation/1_fake_videos/id0_id16_0005.mp4


## Start pre-processing phase

In [0]:
class pv():
    def crop_video(pathIn, pathOut, pos, size):
        """
        Crop video
        :param pathIn:
        :param pathOut:
        :param pos: (left, top, right, bottom)
        :return:
        """

        imgs, frame_num, fps, width, height = parse_vid(pathIn)

        for i, image in enumerate(imgs):
            y1 = np.int32(pos[0])
            x1 = np.int32(pos[1])
            y2 = np.int32(pos[2])
            x2 = np.int32(pos[3])
            roi = image[y1:y2, x1:x2, :]
            if size is not 'None':
                roi = cv2.resize(roi, (size[1], size[0]))
            imgs[i] = roi

        gen_vid(pathOut, imgs, fps, width, height)


    def get_video_dims(video_path):
        vidcap = cv2.VideoCapture(video_path)
        width = np.int32(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))  # float
        height = np.int32(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # float
        vidcap.release()
        return width, height


    def get_video_frame_nums(video_path):
        vidcap = cv2.VideoCapture(video_path)
        frame_num = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        vidcap.release()
        return frame_num


    def get_video_fps(video_path):
        vidcap = cv2.VideoCapture(video_path)
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        vidcap.release()
        return fps


    def parse_vid(video_path):
        vidcap = cv2.VideoCapture(video_path)
        frame_num = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        width = np.int32(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) # float
        height = np.int32(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # float
        imgs = []
        while True:
            success, image = vidcap.read()
            if success:
                imgs.append(image)
            else:
                break

        vidcap.release()
        if len(imgs) != frame_num:
            frame_num = len(imgs)
        return imgs, frame_num, fps, width, height


    def parse_vid_into_imgs(video_path, folder, im_name='{:05d}.jpg'):
        imgs, frame_num, fps, width, height = parse_vid(video_path)
        for id, im in enumerate(imgs):
            im_name = im_name.format(id)
            cv2.imwrite(folder + '/' + im_name, im)
        print('Save original images to folder {}'.format(folder))


    def gen_vid(video_path, imgs, fps, width=None, height=None):
        # Combine video
        ext = Path(video_path).suffix
        if ext == '.mp4':
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Be sure to use lower case
        elif ext == '.avi':
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')  #*'XVID')
        else:
            # if not .mp4 or avi, we force it to mp4
            video_path = video_path.replace(ext, '.mp4')
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Be sure to use lower case
        if width is None or height is None:
            height, width= imgs[0].shape[:2]
        else:
            imgs_ = [cv2.resize(img, (width, height)) for img in imgs]
            imgs = imgs_

        out = cv2.VideoWriter(video_path, fourcc, fps, (np.int32(width), np.int32(height)))

        for image in imgs:
            out.write(np.uint8(image))  # Write out frame to video

        # Release everything if job is finished
        out.release()
        print('The output video is ' + video_path)


    def gen_vid_from_folder(video_path, img_dir, ext, fps, width=None, height=None):
        imgs_path = sorted(Path(img_dir).glob('*' + ext))
        imgs = [cv2.imread(str(p)) for p in imgs_path]
        gen_vid(video_path, imgs, fps, width, height)


    def resize_video(video_path, w=None, h=None, scale=1., out_path=None):
        imgs, frame_num, fps, width, height = parse_vid(video_path)
        # Resize imgs
        if w is None or h is None:
            width, height = int(width * scale), int(height * scale)
            for i, im in enumerate(imgs):
                im = cv2.resize(im, None, None, fx=scale, fy=scale)
                imgs[i] = im
        else:
            width, height = w, h
            for i, im in enumerate(imgs):
                im = cv2.resize(im, (w, h))
                imgs[i] = im
        if out_path:
            gen_vid(out_path, imgs, fps, width, height)
        return imgs, frame_num, fps, width, height


    def extract_key_frames(video_path, len_window=50):
        """
        The frames which the average interframe difference are local maximum are
        considered to be key frames.
        It should be noted that smoothing the average difference value before
        calculating the local maximum can effectively remove noise to avoid
        repeated extraction of frames of similar scenes.
        """
        imgs, frame_num, fps, width, height = parse_vid(video_path)
        frame_diffs = []
        for i in range(1, len(imgs)):
            curr_frame = cv2.cvtColor(imgs[i], cv2.COLOR_BGR2LUV)
            prev_frame = cv2.cvtColor(imgs[i - 1], cv2.COLOR_BGR2LUV)
            # logic here
            diff = cv2.absdiff(curr_frame, prev_frame)
            diff_sum = np.sum(diff)
            diff_sum_mean = diff_sum / (diff.shape[0] * diff.shape[1])
            frame_diffs.append(diff_sum_mean)
        from scipy.signal import argrelextrema
        # compute keyframe
        diff_array = np.array(frame_diffs)

        def smooth(x, window_len=13, window='hanning'):
            s = np.r_[2 * x[0] - x[window_len:1:-1],
                      x, 2 * x[-1] - x[-1:-window_len:-1]]
            if window == 'flat':  # moving average
                w = np.ones(window_len, 'd')
            else:
                w = getattr(np, window)(window_len)
            y = np.convolve(w / w.sum(), s, mode='same')
            return y[window_len - 1:-window_len + 1]

        sm_diff_array = smooth(diff_array, len_window)
        frame_indexes = np.asarray(argrelextrema(sm_diff_array, np.greater))[0]
        key_frames = []
        for i in frame_indexes:
            key_frames.append(imgs[i])
        return key_frames

## Start training phase

In [0]:
net = SPPNet(backbone=layers, num_class=num_class)
net = net.cuda()
net.eval()

if os.path.isfile(model_path):
  print("=> loading checkpoint '{}'".format(model_path))
  checkpoint = torch.load(model_path)
  start_epoch = checkpoint['epoch']
  net.load_state_dict(checkpoint['net'])
  print("=> loaded checkpoint '{}' (epoch {})".format(model_path, start_epoch))
else:
  raise ValueError("=> no checkpoint found at '{}'".format(model_path))

def im_test(net, im, args):
    face_info = lib.align(im[:, :, (2,1,0)], front_face_detector, lmark_predictor)
    # Samples
    if len(face_info) != 1:
        prob = -1
    else:
        _, point = face_info[0]
        rois = []
        for i in range(sample_num):
            roi, _ = lib.cut_head([im], point, i)
            rois.append(cv2.resize(roi[0], (args.input_size, args.input_size)))

        # vis_ = np.concatenate(rois, 1)
        # cv2.imwrite('vis.jpg', vis_)

        bgr_mean = np.array([103.939, 116.779, 123.68])
        bgr_mean = bgr_mean[np.newaxis, :, np.newaxis, np.newaxis]
        bgr_mean = torch.from_numpy(bgr_mean).float().cuda()

        rois = torch.from_numpy(np.array(rois)).float().cuda()
        rois = rois.permute((0, 3, 1, 2))
        prob = net(rois - bgr_mean)
        prob = F.softmax(prob, dim=1)
        prob = prob.data.cpu().numpy()
        prob = 1 - np.mean(np.sort(prob[:, 0])[np.round(sample_num / 2).astype(int):])
    return prob, face_info

# Parse video
imgs, frame_num, fps, width, height = pv.parse_vid(f_path)
probs = []
for fid, im in enumerate(imgs):
  print('Frame: ' + str(fid))
  prob, face_info = im_test(net, im, args)
  probs.append(prob)
print(probs)

=> loading checkpoint '/content/SPP-res50.pth'
=> loaded checkpoint '/content/SPP-res50.pth' (epoch 49)
Frame: 0


NameError: ignored

## Start evaluation phase

# Train Ictu Oculi model

**Public repository:** 
https://github.com/danmohaha/WIFS2018_In_Ictu_Oculi 

**Citation:**
@inproceedings{li2018ictu,
  title={In Ictu Oculi: Exposing AI Generated Fake Face Videos by Detecting Eye Blinking},
  author={Li, Yuezun and Chang, Ming-Ching and Lyu, Siwei},
  Booktitle={IEEE International Workshop on Information Forensics and Security (WIFS)},
  year={2018}
}

## 1. Import dependencies

In [0]:
import tensorflow as tf
import os

## 2. Ictu Oculi Network architecture

In [0]:
# =============================
# VGG16 network structure
# =============================

class vgg16():
    def get_vgg16_conv5(input, params):
        layers = edict()

        layers.conv1_1 = ops.conv2D(input=input, shape=(3, 3, 64), name='conv1_1', params=params)
        layers.conv1_1_relu = ops.activate(input=layers.conv1_1, name='conv1_1_relu', act_type='relu')
        layers.conv1_2 = ops.conv2D(input=layers.conv1_1_relu, shape=(3, 3, 64), name='conv1_2', params=params)
        layers.conv1_2_relu = ops.activate(input=layers.conv1_2, name='conv1_2_relu', act_type='relu')
        layers.pool1 = ops.max_pool(input=layers.conv1_2_relu, name='pool1')

        layers.conv2_1 = ops.conv2D(input=layers.pool1, shape=(3, 3, 128), name='conv2_1', params=params)
        layers.conv2_1_relu = ops.activate(input=layers.conv2_1, name='conv2_1_relu', act_type='relu')
        layers.conv2_2 = ops.conv2D(input=layers.conv2_1_relu, shape=(3, 3, 128), name='conv2_2', params=params)
        layers.conv2_2_relu = ops.activate(input=layers.conv2_2, name='conv2_2_relu', act_type='relu')
        layers.pool2 = ops.max_pool(input=layers.conv2_2_relu, name='pool2')

        layers.conv3_1 = ops.conv2D(input=layers.pool2, shape=(3, 3, 256), name='conv3_1', params=params)
        layers.conv3_1_relu = ops.activate(input=layers.conv3_1, name='conv3_1_relu', act_type='relu')
        layers.conv3_2 = ops.conv2D(input=layers.conv3_1_relu, shape=(3, 3, 256), name='conv3_2', params=params)
        layers.conv3_2_relu = ops.activate(input=layers.conv3_2, name='conv3_2_relu', act_type='relu')
        layers.conv3_3 = ops.conv2D(input=layers.conv3_2_relu, shape=(3, 3, 256), name='conv3_3', params=params)
        layers.conv3_3_relu = ops.activate(input=layers.conv3_3, name='conv3_3_relu', act_type='relu')
        layers.pool3 = ops.max_pool(input=layers.conv3_3_relu, name='pool3')

        layers.conv4_1 = ops.conv2D(input=layers.pool3, shape=(3, 3, 512), name='conv4_1', params=params)
        layers.conv4_1_relu = ops.activate(input=layers.conv4_1, name='conv4_1_relu', act_type='relu')
        layers.conv4_2 = ops.conv2D(input=layers.conv4_1_relu, shape=(3, 3, 512), name='conv4_2', params=params)
        layers.conv4_2_relu = ops.activate(input=layers.conv4_2, name='conv4_2_relu', act_type='relu')
        layers.conv4_3 = ops.conv2D(input=layers.conv4_2_relu, shape=(3, 3, 512), name='conv4_3', params=params)
        layers.conv4_3_relu = ops.activate(input=layers.conv4_3, name='conv4_3_relu', act_type='relu')
        layers.pool4 = ops.max_pool(input=layers.conv4_3_relu, name='pool4')

        layers.conv5_1 = ops.conv2D(input=layers.pool4, shape=(3, 3, 512), name='conv5_1', params=params)
        layers.conv5_1_relu = ops.activate(input=layers.conv5_1, name='conv5_1_relu', act_type='relu')
        layers.conv5_2 = ops.conv2D(input=layers.conv5_1_relu, shape=(3, 3, 512), name='conv5_2', params=params)
        layers.conv5_2_relu = ops.activate(input=layers.conv5_2, name='conv5_2_relu', act_type='relu')
        layers.conv5_3 = ops.conv2D(input=layers.conv5_2_relu, shape=(3, 3, 512), name='conv5_3', params=params)
        layers.conv5_3_relu = ops.activate(input=layers.conv5_3, name='conv5_3_relu', act_type='relu')

        return layers

    def get_vgg16_pool5(input, params):
        layers = get_vgg16_conv5(input, params)
        layers.pool5 = ops.max_pool(input=layers.conv5_3_relu, name='pool5')

        return layers

    def get_prob(input, params, num_class=1000, is_train=True):
        # Get pool5
        layers = get_vgg16_pool5(input, params)
        layers.fc6 = ops.fully_connected(input=layers.pool5, num_neuron=4096, name='fc6', params=params)
        if is_train:
            layers.fc6 = tf.nn.dropout(layers.fc6, keep_prob=0.5)
        layers.fc6_relu = ops.activate(input=layers.fc6, act_type='relu', name='fc6_relu')
        layers.fc7 = ops.fully_connected(input=layers.fc6_relu, num_neuron=4096, name='fc7', params=params)
        if is_train:
            layers.fc7 = tf.nn.dropout(layers.fc7, keep_prob=0.5)
        layers.fc7_relu = ops.activate(input=layers.fc7, act_type='relu', name='fc7_relu')
        layers.fc8 = ops.fully_connected(input=layers.fc7_relu, num_neuron=num_class, name='fc8', params=params)
        layers.prob = tf.nn.softmax(layers.fc8)
        return layers

class BlinkCNN(object):
    """
    CNN for eye blinking detection
    """

    def __init__(self,
                 is_train
                 ):

        self.img_size = [224, 224, 3]
        self.num_classes = 2
        self.is_train = is_train

        self.layers = {}
        self.params = {}

        base = vgg16()

    def build(self):
        # Input
        self.input = tf.placeholder(dtype=tf.float32, shape=[None, self.img_size[0], self.img_size[1], self.img_size[2]])
        self.layers = base.get_prob(self.input, self.params, self.num_classes, self.is_train)
        self.prob = self.layers.prob
        self.gt = tf.placeholder(dtype=tf.int32, shape=[None])
        self.var_list = tf.trainable_variables()

    def loss(self):
        self.net_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.gt, logits=self.layers.fc8)
        self.net_loss = tf.reduce_mean(self.net_loss)
        tf.losses.add_loss(self.net_loss)
        # L2 weight regularize
        self.L2_loss = tf.reduce_mean([self.cfg.TRAIN.BETA * tf.nn.l2_loss(v)
                     for v in tf.trainable_variables() if 'weights' in v.name])
        tf.losses.add_loss(self.L2_loss)
        self.total_loss = tf.losses.get_total_loss()

class BlinkLRCN(object):
    """
    LRCN for eye blinking detection
    """

    def __init__(self,
                 is_train
                 ):

        cfg_file = os.path.join(pwd, 'blink_lrcn.yml')
        with open(cfg_file, 'r') as f:
            cfg = edict(yaml.load(f))

        self.cfg = cfg
        self.img_size = cfg.IMG_SIZE
        self.num_classes = cfg.NUM_CLASS
        self.is_train = is_train

        self.rnn_type = cfg.RNN_TYPE
        self.max_time = cfg.MAX_TIME
        self.hidden_unit = cfg.HIDDEN_UNIT

        if self.is_train:
            self.batch_size = cfg.TRAIN.BATCH_SIZE
        else:
            self.batch_size = cfg.TEST.BATCH_SIZE
        self.layers = {}
        self.params = {}

    def build(self):
        self.input = tf.placeholder(dtype=tf.float32,
                                    shape=[self.batch_size, self.max_time, self.img_size[0], self.img_size[1], self.img_size[2]])
        self.blined_gt = tf.placeholder(dtype=tf.int32, shape=[self.batch_size])
        self.eye_state_gt = tf.placeholder(dtype=tf.int32, shape=[self.batch_size, self.max_time])
        self.seq_len = tf.placeholder(dtype=tf.int32, shape=[self.batch_size])

        self.vgg16_fc6 = self._vgg16(self.input)
        self.rnn_out = self._rnn_cell(self.vgg16_fc6)
        self.out = self._fc(self.rnn_out)
        self.prob = tf.nn.softmax(self.out, dim=-1)

    def _vgg16(self, input):
        # Reshape from NxTxHxWxC to (NxT)xHxWxC
        input = tf.reshape(input, [-1, self.img_size[0], self.img_size[1], self.img_size[2]])
        layers = base.get_vgg16_pool5(input, self.params)
        layers.fc6 = net_ops.fully_connected(input=layers.pool5, num_neuron=4096, name='fc6', params=self.params)
        if self.is_train:
            layers.fc6 = tf.nn.dropout(layers.fc6, keep_prob=0.5)
        layers.fc6_relu = net_ops.activate(input=layers.fc6, act_type='relu', name='fc6_relu')
        cnn_out = tf.reshape(layers.fc6_relu, [-1, self.max_time, 4096])
        return cnn_out

    def _rnn_cell(self, input):
        with tf.variable_scope('rnn_cell'):
            size = np.prod(input.get_shape().as_list()[2:])
            rnn_inputs = tf.reshape(input, (-1, self.max_time, size))
            if self.rnn_type == 'LSTM':
                cell = tf.contrib.rnn.LSTMCell(self.hidden_unit)
            elif self.rnn_type == 'GRU':
                cell = tf.contrib.rnn.GRUCell(self.hidden_unit)
            else:
                raise ValueError('We only support LSTM or GRU...')
            rnn_outputs, _ = tf.nn.dynamic_rnn(
                cell,
                rnn_inputs,
                sequence_length=self.seq_len,
                dtype = tf.float32
            )
            return rnn_outputs

    def _avg_rnn_out(self, rnn_out):
        seq_len = tf.cast(self.seq_len, dtype=tf.float32)
        avg = tf.reduce_sum(rnn_out, axis=1) / tf.expand_dims(seq_len, axis=-1)
        return avg

    def _fc(self, input):
        # Reshape from NxTx256 to (NxT)x256
        input = tf.reshape(input, [-1, self.hidden_unit])
        out = net_ops.fully_connected(input=input, num_neuron=self.num_classes, name='fc_after_rnn', params=self.params)
        out = tf.reshape(out, [-1, self.max_time, self.num_classes])
        return out

    def loss(self):
        self.net_loss = []
        for batch_id in range(self.batch_size):
            out_cur = self.out[batch_id, :, :]
            eye_state_cur = self.eye_state_gt[batch_id, :]
            weights = tf.gather(tf.constant(self.cfg.TRAIN.CLASS_WEIGHTS, dtype=tf.float32), eye_state_cur)
            loss_per_batch = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=eye_state_cur, logits=out_cur)  # T x num_class
            loss_per_batch = loss_per_batch * weights
            # Select loss by real len
            seq_len = tf.cast(self.seq_len[batch_id], dtype=tf.float32)
            tf_idx = tf.range(0, self.seq_len[batch_id])
            loss_per_batch = tf.reduce_sum(tf.gather(loss_per_batch, tf_idx, axis=0)) / seq_len
            self.net_loss.append(loss_per_batch)
        self.net_loss = tf.reduce_mean(self.net_loss)
        tf.losses.add_loss(self.net_loss)
        # L2 weight regularize
        self.L2_loss = tf.reduce_mean([self.cfg.TRAIN.BETA * tf.nn.l2_loss(v)
                                       for v in tf.trainable_variables() if 'weights' in v.name or 'kernel' in v.name])
        tf.losses.add_loss(self.L2_loss)
        self.total_loss = tf.losses.get_total_loss()

## 3. Default settings

## Start training phase

In [0]:
with tf.Session() as sess:
        # Build network
        net = BlinkCNN(is_train=True)
        net.build()

        # Init solver
        solver = Solver(sess=sess, net=net)
        solver.init()

        # Eye state data generator
        data_gen = EyeData(
            anno_path='sample_eye_data/train.p',
            data_dir='sample_eye_data/',
            batch_size=net.cfg.TRAIN.BATCH_SIZE,
            is_augment=True,
            is_shuffle=True
        )

        print('Training...')
        # Training
        batch_num = data_gen.batch_num
        summary_idx = 0
        for epoch in range(solver.cfg.TRAIN.NUM_EPOCH):
            for i in range(batch_num):
                im_list, label_list, im_name_list \
                    = data_gen.get_batch(i, size=net.cfg.IMG_SIZE[:2])
                uvis.vis_im(im_list, 'tmp')
                _, summary, prob, net_loss = solver.train(im_list, label_list)
                solver.writer.add_summary(summary, summary_idx)
                summary_idx += 1
                pred_label = np.argmax(prob, axis=-1)
                print('====================================')
                print('Net loss: {}'.format(net_loss))
                print('Real label: {}'.format(label_list))
                print('Pred label: {}'.format(pred_label))
                print('Epoch: {}'.format(epoch))
            if epoch % solver.cfg.TRAIN.SAVE_INTERVAL == 0:
                solver.save(epoch)


## Start evaluating phase