In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys
from pathlib import Path
HOME = os.getcwd()

DATA_FOLDER = os.path.join(HOME, 'training', 'data')
current = HOME
while 'src' not in os.listdir(current):
    current = Path(current).parent

sys.path.append(str(current))
sys.path.append(os.path.join(str(current), 'FaceSpoofing'))


In [3]:
import os
import numpy as np
import pytorch_lightning as L
import torchvision
import torch

In [4]:
from torch.utils.data import DataLoader, Dataset
TRAIN_CROPPED = os.path.join(DATA_FOLDER, 'train_cropped')
TEST_CROPPED = os.path.join(DATA_FOLDER, 'test_cropped')

In [5]:
class SpoofDataset(Dataset):
    def __init__(self, images):
        self.images = images
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image_path, label = self.images[idx]
        image = torchvision.io.read_image(image_path)
        image = torchvision.transforms.functional.resize(image, (224, 224))
        image = torch.tensor(image).float()
        return image, label

In [6]:
class FPAD(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.densenet = torchvision.models.densenet121(pretrained=True)
        self.transition = torch.nn.Sequential(
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            torch.nn.Conv2d(1024, 512, 1),
            torch.nn.AdaptiveAvgPool2d(14)
        )
        
        self.fmap = torch.nn.Sequential(
            torch.nn.Conv2d(512, 1, 1),
            torch.nn.Sigmoid()
        )
        
        self.classifier = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(14 * 14, 1),
        )
    
    def forward(self, x):
        x = self.densenet.features(x)
        x = self.transition(x)
        x = self.fmap(x)
        
        return x, self.classifier(x)

In [7]:
class FPADLit(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FPAD()
        self.criterion = torch.nn.BCEWithLogitsLoss()
        
    def forward(self, x):
        return self.model(x)
    
    def __loss(self, feature_map, logits, y: torch.Tensor, alpha=0.5):
        # Assuming that the feature map is Nx512x14x14
        # and the logits is Nx1, we need to expand the
        # labels to apply bce loss
        pixel_loss = self.criterion(
            feature_map,
            y.flatten().unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(feature_map).float()
        )
        
        binary_loss = self.criterion(logits.flatten().float(), y.flatten().float())
        
        return alpha * pixel_loss + (1 - alpha) * binary_loss
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        fmap, logits = self.forward(x)
        loss = self.__loss(fmap, logits, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_fmap_loss', self.__loss(fmap, logits, y, alpha=1), prog_bar=True)
        self.log('train_binary_loss', self.__loss(fmap, logits, y, alpha=0), prog_bar=True)
        
        acc = (logits.flatten().sigmoid() > 0.5).float().eq(y.flatten()).float().mean()
        self.log('train_acc', acc, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        fmap, logits = self.forward(x)
        loss = self.__loss(fmap, logits, y)
        self.log('val_loss', loss)
        self.log('val_fmap_loss', self.__loss(fmap, logits, y, alpha=1))
        self.log('val_binary_loss', self.__loss(fmap, logits, y, alpha=0))
        
        acc = (logits.flatten().sigmoid() > 0.5).float().eq(y.flatten()).float().mean()
        self.log('val_acc', acc)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [8]:
# BASE = os.path.join('CelebData/CelebA_Spoof/Data')
DATA_FOLDER
train = []
for root, dirs, files in os.walk(TRAIN_CROPPED):
    for name in files:
        if name.endswith('.jpg'):
            if 'spoof' in root:
                train.append((os.path.join(root, name), 0))
            else:
                train.append((os.path.join(root, name), 1))

In [20]:
test = []
for root, dirs, files in os.walk(TEST_CROPPED):
    for name in files:
        if name.endswith('.jpg'):
            if 'spoof' in root:
                test.append((os.path.join(root, name), 0))
            else:
                test.append((os.path.join(root, name), 1))

In [10]:
# train = [os.path.join(TRAIN_CROPPED, file_name) for file_name in os.listdir(TRAIN_CROPPED)]
train_dl = DataLoader(SpoofDataset(train[:int(len(train) * 0.8)]), batch_size=4, shuffle=True)
val_dl = DataLoader(SpoofDataset(train[int(len(train) * 0.8):]), batch_size=4, shuffle=False)

In [11]:
from pytorch_lightning.callbacks import EarlyStopping
early = EarlyStopping(patience=3, mode='max', monitor='val_acc')

In [12]:
trainer = L.Trainer(
    max_epochs=25, 
    accelerator='gpu',
    callbacks=[early],
    val_check_interval=.25,
)
lit = FPADLit()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
trainer.fit(lit, train_dl, val_dl)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | FPAD              | 8.5 M 
1 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
8.5 M     Trainable params
0         Non-trainable params
8.5 M     Total params
34.026    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  image = torch.tensor(image).float()
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [14]:
trainer.validate(lit, val_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Validation: 0it [00:00, ?it/s]

  image = torch.tensor(image).float()


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.validating metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc             0.94525545835495
     val_binary_loss        0.2112363576889038
      val_fmap_loss         0.5805960893630981
        val_loss            0.3959161341190338
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.3959161341190338,
  'val_fmap_loss': 0.5805960893630981,
  'val_binary_loss': 0.2112363576889038,
  'val_acc': 0.94525545835495}]

In [22]:
test_dl = DataLoader(SpoofDataset(test), batch_size=4)

In [23]:
trainer.validate(lit, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Validation: 0it [00:00, ?it/s]

  image = torch.tensor(image).float()


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.validating metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.9700099229812622
     val_binary_loss        0.11846015602350235
      val_fmap_loss         0.6761477589607239
        val_loss            0.3973047733306885
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.3973047733306885,
  'val_fmap_loss': 0.6761477589607239,
  'val_binary_loss': 0.11846015602350235,
  'val_acc': 0.9700099229812622}]

In [24]:
torch.save(lit.model.state_dict(), 'weights_v2.pt')

In [25]:
mdl = FPAD()
mdl.load_state_dict(torch.load('weights_v2.pt'))



<All keys matched successfully>

In [26]:
mdl

FPAD(
  (densenet): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, 