In [6]:
import numpy as np
import os
import numpy as np
import random
import albumentations
import torch
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import pandas as pd
import cv2
from tqdm import tqdm

# Load model classes & utilities

In [2]:
# torchvision ImageNet-style checkpoint URLs, keyed by architecture name;
# looked up by FeatExt_MixStyleResCausalModel when pretrained=True.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-f37072fd.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
)

def l2_norm(input, axis=1):
    """Scale ``input`` to unit L2 norm along dimension ``axis``.

    No epsilon is added, so an all-zero slice produces NaN (0 / 0).
    """
    return input / torch.norm(input, p=2, dim=axis, keepdim=True)

def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with padding 1 (spatial size kept when stride == 1)."""
    # Positional order: in_channels, out_channels, kernel_size, stride, padding.
    return nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)

class MixStyle(nn.Module):
    """Mix per-channel feature statistics (mean / std) between samples of a batch.

    Based on MixStyle:
    https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/modeling/ops/mixstyle.py
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.

    Mixing modes (``mix``):
      * "random"      - mix with a random permutation of the whole batch.
      * "crossdomain" - samples in the front half of the batch take statistics
                        from (a shuffle of) the back half, and vice versa.
      * "crosssample" - (project addition) mix only between samples that share
                        the same label; requires ``labels`` in ``forward``.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of applying MixStyle on a forward call.
          alpha (float): both concentration parameters of the Beta distribution
            from which the mixing weight is drawn.
          eps (float): added to the variance before the square root to avoid
            numerical issues.
          mix (str): mixing mode, see the class docstring.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        """Enable (True) or disable (False) mixing independently of train/eval mode."""
        self._activated = status

    def update_mix_method(self, mix="random"):
        """Switch the mixing mode used by subsequent forward calls."""
        self.mix = mix

    def forward(self, x, labels=None):
        """Apply MixStyle to ``x`` with probability ``p`` while training.

        Args:
          x: feature map; statistics are taken over dims 2 and 3
            (presumably (B, C, H, W) - confirm with callers).
          labels: per-sample labels, one per batch element; required when
            ``mix == "crosssample"``. Any integer label values are accepted.

        Returns:
          ``x`` itself when inactive (eval mode, deactivated, or skipped by the
          ``p`` draw), otherwise the re-styled feature map.

        Raises:
          ValueError: ``mix == "crosssample"`` and ``labels`` is None or does
            not contain exactly one label per sample.
          NotImplementedError: unknown ``mix`` mode.
        """
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        # Statistics are detached so no gradient flows through them.
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1)).to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        elif self.mix == "crosssample":
            if labels is None:
                raise ValueError('Label is None')
            perm = self._crosssample_perm(labels, B)

        else:
            raise NotImplementedError

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix

    @staticmethod
    def _crosssample_perm(labels, batch_size):
        """Random permutation of batch indices that only swaps samples with equal labels.

        Works for any integer label values; a sample whose label is unique in
        the batch maps to itself.
        """
        labels = labels.reshape(-1).long()
        if labels.numel() != batch_size:
            raise ValueError(
                f"Expected {batch_size} labels (one per sample), got {labels.numel()}"
            )
        perm = torch.arange(batch_size, device=labels.device)
        # torch.unique returns sorted values, so groups are shuffled in
        # ascending-label order.
        for value in torch.unique(labels):
            idx = (labels == value).nonzero(as_tuple=True)[0]
            perm[idx] = idx[torch.randperm(idx.numel())]
        return perm

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with channel expansion 1."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Submodule registration order (conv1, bn1, relu, conv2, bn2) is kept as-is:
        # it fixes state_dict key order and the module-iteration order used when
        # a parent model initializes weights.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, add the (optionally downsampled) shortcut, relu."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        shortcut = x if self.downsample is None else self.downsample(x)
        # `h` is a fresh tensor, so the in-place add never touches `x`.
        h += shortcut
        return self.relu(h)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        # Registration order is kept as-is (it fixes state_dict keys and the
        # module-iteration order used for weight initialization).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # The 3x3 conv carries the stride.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Three conv-bn stages (relu after the first two), add shortcut, relu."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)


class Backbone(nn.Module):
    """Minimal base class for feature-extractor backbones.

    Subclasses override ``forward`` and may set ``self._out_features`` to
    advertise their output feature dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self):
        """Placeholder; subclasses implement the real forward pass."""
        return None

    @property
    def out_features(self):
        """Output feature dimension, or None when ``_out_features`` was never set."""
        return vars(self).get("_out_features")


class ResNet(Backbone):
    """ResNet trunk (no global pooling / fc head) with optional MixStyle.

    Args:
      block: residual block class exposing an ``expansion`` attribute
        (``BasicBlock`` or ``Bottleneck`` in this file).
      layers: number of blocks in each of the four stages.
      ms_class: factory for the style-mixing module, called as
        ``ms_class(p=ms_p, alpha=ms_a, mix=mix)``; required when
        ``ms_layers`` is non-empty.
      ms_layers: stage names (subset of "layer1".."layer4") after which the
        style-mixing module is applied; ``None`` or empty disables it.
      ms_p: forwarded to ``ms_class`` as ``p``.
      ms_a: forwarded to ``ms_class`` as ``alpha``.
      mix: mixing mode forwarded to ``ms_class``.
      **kwargs: accepted for compatibility and ignored.

    Raises:
      ValueError: ``ms_layers`` names an unknown stage, or is non-empty while
        ``ms_class`` is None.
    """

    _STAGES = ("layer1", "layer2", "layer3", "layer4")

    def __init__(
        self,
        block,
        layers,
        ms_class=None,
        ms_layers=None,
        ms_p=0.5,
        ms_a=0.1,
        mix="crosssample",
        **kwargs
    ):
        super().__init__()
        # Running channel count, read and updated by _make_layer.
        self.inplanes = 64

        # Validate the MixStyle configuration up front. A fresh list also avoids
        # sharing a mutable default between instances.
        ms_layers = list(ms_layers) if ms_layers else []
        unknown = [name for name in ms_layers if name not in self._STAGES]
        if unknown:
            raise ValueError(
                f"Unknown MixStyle layer(s) {unknown}; expected names from {list(self._STAGES)}"
            )
        if ms_layers and ms_class is None:
            raise ValueError("ms_class must be provided when ms_layers is non-empty")

        # backbone network. Registration order matters: _init_params walks
        # self.modules() in this order and draws random weights for each conv.
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self._out_features = 512 * block.expansion

        self.mixstyle = None
        if ms_layers:
            self.mixstyle = ms_class(p=ms_p, alpha=ms_a, mix=mix)
            print(
                f"Insert {self.mixstyle.__class__.__name__} after {ms_layers}"
            )
            print(f'Using {mix}')
        else:
            print('No MixStyle')
        self.ms_layers = ms_layers

        self._init_params()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first may stride.

        A 1x1 conv + BN shortcut is added when the spatial size or channel
        count changes. Updates ``self.inplanes`` for the next stage.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _init_params(self):
        """Kaiming-normal convs, unit/zero batch norms, N(0, 0.01) linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x, labels=None):
        """Run the stem and the four stages, applying MixStyle after each stage in ``ms_layers``.

        ``labels`` is passed through to MixStyle (needed for "crosssample").
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for name in self._STAGES:
            x = getattr(self, name)(x)
            if name in self.ms_layers:
                x = self.mixstyle(x, labels)
        return x

    def forward(self, x, labels=None):
        """Return the final-stage feature map (no pooling applied)."""
        return self.featuremaps(x, labels)


def init_pretrained_weights(model, model_url):
    """Download a state_dict from ``model_url`` and load it non-strictly into ``model``.

    Keys present in only one of the two are skipped (``strict=False``). The
    skipped keys are returned, and missing ones are printed, so that partially
    loaded weights are not silently accepted.

    Returns:
      The ``(missing_keys, unexpected_keys)`` named tuple from ``load_state_dict``.
    """
    # Load onto CPU; load_state_dict copies values into the model's own tensors.
    pretrain_dict = torch.hub.load_state_dict_from_url(model_url, map_location="cpu")
    result = model.load_state_dict(pretrain_dict, strict=False)
    if result.missing_keys:
        print(
            f"init_pretrained_weights: {len(result.missing_keys)} model keys not in "
            f"checkpoint, e.g. {result.missing_keys[:3]}"
        )
    return result

class Classifier(nn.Module):
    """Single linear layer mapping features to class logits."""

    def __init__(self, in_channels=512, num_classes=2):
        super().__init__()
        layer = nn.Linear(in_channels, num_classes)
        # Small-variance weights and zero bias at start.
        layer.weight.data.normal_(0, 0.01)
        layer.bias.data.fill_(0.0)
        self.classifier_layer = layer

    def forward(self, input, norm_flag=False):
        """Return logits for ``input``.

        With ``norm_flag`` the layer's weight is first replaced, in place and
        persistently, by its L2-normalized version along axis 0.
        """
        layer = self.classifier_layer
        if norm_flag:
            layer.weight.data = l2_norm(layer.weight, axis=0)
        return layer(input)

class FeatExt_MixStyleResCausalModel(nn.Module):
    """ResNet feature extractor + average pooling + linear classifier.

    Args:
      model_name: torchvision ResNet variant; selects BOTH the architecture
        and, when ``pretrained``, the checkpoint URL from ``model_urls``.
      pretrained: load the checkpoint for ``model_name`` into the backbone.
      num_classes: number of classifier outputs.
      prob: accepted for compatibility; not used here.
      ms_class: MixStyle factory forwarded to ``ResNet``.
      ms_layers: stages after which MixStyle is applied; ``None`` means
        ``["layer1", "layer2"]``, an empty list disables MixStyle.
      mix: MixStyle mixing mode.

    Raises:
      ValueError: unknown ``model_name``.
    """

    def __init__(self, model_name='resnet18', pretrained=False, num_classes=2, prob=0.2, ms_class=MixStyle, ms_layers=None, mix="crosssample"):
        super().__init__()
        # Block type and per-stage depth for each torchvision ResNet variant, so
        # the built architecture always matches the checkpoint in model_urls.
        arch_specs = {
            'resnet18': (BasicBlock, [2, 2, 2, 2]),
            'resnet34': (BasicBlock, [3, 4, 6, 3]),
            'resnet50': (Bottleneck, [3, 4, 6, 3]),
            'resnet101': (Bottleneck, [3, 4, 23, 3]),
            'resnet152': (Bottleneck, [3, 8, 36, 3]),
        }
        if model_name not in arch_specs:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {sorted(arch_specs)}"
            )
        if ms_layers is None:
            ms_layers = ["layer1", "layer2"]

        block, layers = arch_specs[model_name]
        self.feature_extractor = ResNet(
            block=block,
            layers=layers,
            ms_class=ms_class,
            ms_layers=ms_layers,
            mix=mix
        )
        if pretrained:
            print('-------load model----------')
            init_pretrained_weights(self.feature_extractor, model_urls[model_name])

        # Width comes from the backbone (512 for BasicBlock, 2048 for Bottleneck).
        self.classifier = Classifier(
            in_channels=self.feature_extractor.out_features, num_classes=num_classes
        )
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, input, labels=None, cf=('cs', 'dropout', 'replace'), norm=True):
        """Return ``(logits, pooled_features)`` for a batch of images.

        ``labels`` is forwarded to MixStyle (used in "crosssample" mode while
        training). ``cf`` and ``norm`` are accepted for compatibility; this
        feature-extraction model implements no counterfactual branch, so the
        same pair is returned in every mode.
        """
        feature = self.feature_extractor(input, labels)  # (B, C, h, w) final-stage map
        cls_feature = self.avgpool(feature).flatten(1)     # (B, C)
        cls = self.classifier(cls_feature)
        return cls, cls_feature

# Image preprocessing & feature helpers

In [None]:
# Per-channel (RGB order, after BGR->RGB conversion) statistics for
# albumentations.Normalize.
PRE__MEAN = [0.485, 0.456, 0.406]
PRE__STD = [0.229, 0.224, 0.225]

# Side length, in pixels, that every face image is resized to.
INPUT__FACE__SIZE = 256

# NOTE(review): PADDING is not referenced anywhere in this notebook.
PADDING = 0

def preprocess_frame_pipe():
    """Build the inference transform: square resize, mean/std normalize, to tensor."""
    steps = [
        albumentations.Resize(height=INPUT__FACE__SIZE, width=INPUT__FACE__SIZE),
        albumentations.Normalize(PRE__MEAN, PRE__STD, always_apply=True),
        ToTensorV2(),
    ]
    return albumentations.Compose(steps)

def preprocess_frame(frame):
    """Convert a BGR image array (e.g. from ``cv2.imread``) into a model-ready batch of one.

    Returns:
      A tensor of shape (1, C, H, W), given ToTensorV2's channel-first output.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    tensor = preprocess_frame_pipe()(image=rgb)['image']
    # ToTensorV2 already yields a torch.Tensor; wrapping it in torch.tensor()
    # would make a needless copy and emit a copy-construct UserWarning.
    return tensor.unsqueeze(0)

def get_feature(img_path, model):
    """Compute the pooled feature vector for one image file.

    Args:
      img_path: path readable by ``cv2.imread``.
      model: module returning ``(logits, features)`` when called with
        ``cf=None`` (as FeatExt_MixStyleResCausalModel does); caller is
        expected to have put it in eval mode.

    Returns:
      The features tensor for the single image, on the model's device.

    Raises:
      FileNotFoundError: the image could not be read (``cv2.imread`` returns None).
    """
    device = next(model.parameters()).device
    face = cv2.imread(img_path)
    if face is None:
        raise FileNotFoundError(f"Could not read image: {img_path!r}")
    with torch.no_grad():
        input_frame = preprocess_frame(face).to(device)
        _, cls_feature = model(input_frame, cf=None)
    return cls_feature

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Notes:
      * ``cudnn.benchmark`` is turned off: autotuning may pick different
        algorithms run-to-run, which defeats reproducibility.
      * ``CUBLAS_WORKSPACE_CONFIG`` is read when cuBLAS initializes, so it has
        no effect if CUDA work already ran in this kernel.
      * ``PYTHONHASHSEED`` only affects Python processes started after this
        call; it cannot change hashing in the running kernel.
    """
    random.seed(seed)  # Python random module.
    np.random.seed(seed)  # Numpy module.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all GPUs; lazily applied if CUDA is unused.

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)

# Feature Extraction

In [None]:
# Inputs to fill in before running the feature-extraction cells.
model_path = ''  # checkpoint file passed to torch.load for the model's state_dict
df_path = ''     # CSV of images to process (first column: image file paths)

In [None]:
# NOTE(review): the original referenced `args.num_classes`, but `args` is never
# defined in this notebook. Set this to the num_classes the checkpoint was
# trained with; a wrong value fails loudly in load_state_dict (size mismatch).
NUM_CLASSES = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.cuda.empty_cache()

set_seed(seed=777)

model = torch.nn.DataParallel(
    FeatExt_MixStyleResCausalModel(
        model_name='resnet18', pretrained=False, num_classes=NUM_CLASSES, ms_layers=[]
    )
)
model = model.to(device)
# torch.load unpickles the file: only load checkpoints from trusted sources.
# The model is wrapped in DataParallel first, so keys are expected to carry
# the "module." prefix.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();  # trailing ';' suppresses the full module repr in the cell output

In [None]:
df = pd.read_csv(df_path)
# Placeholder object column (all None); the extraction cell below writes the
# per-image feature vectors into it. Assigning an empty list here would raise
# a length-mismatch error on any non-empty frame.
df['vector'] = None

In [None]:
# One feature vector per image listed in the CSV's first column. Vectors are
# collected first and assigned once, instead of overwriting the whole column
# on every iteration.
vectors = []
for img_path in tqdm(df.iloc[:, 0], desc="extracting features"):
    feature = get_feature(img_path, model)  # (1, feature_dim), on the model's device
    vectors.append(feature.squeeze(0).cpu().tolist())

# Object dtype keeps exactly one list per row, aligned to df's own index.
df['vector'] = pd.Series(vectors, index=df.index, dtype=object)