In [1]:
from pathlib import Path
USE_COLAB: bool = True
dataset_base_path = Path("/content/drive/My Drive/ECE 792 - Advance Topics in Machine Learning/Datasets")
if USE_COLAB:
  from google.colab import drive
  
  # Mount the drive to access google shared docs
  drive.mount('/content/drive/', force_remount=True)

  if dataset_base_path.exists():
    print("Folder exists")
  else:
    print("DOESN'T EXIST. Add desired folder as a shortcut in your 'My Drive'")

Mounted at /content/drive/
Folder exists


In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torchvision.utils as vutils

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from typing import Tuple, Optional, List

import argparse
import os
from tqdm import tqdm
import time
import copy
import math
from zipfile import ZipFile

from PIL import Image
from typing import Dict, List, Union

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# PyTorch's versions:
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
print("NumPy Version: ",np.__version__)
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

PyTorch Version:  1.13.1+cu116
Torchvision Version:  0.14.1+cu116
NumPy Version:  1.22.4
Wed Mar 22 00:16:31 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P0    26W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               

In [3]:
class CelebrityData(torch.utils.data.Dataset):

  def __init__(
    self,
    base_path: Path,
    transform = None,
    seed = None,
    gans_to_skip: Optional[List[str]] = None,
    n_fake_imgs_to_extract: int = 40000,
    *,
    unzip_real_imgs: bool = True,
    ):
    '''
    Folder structure of base_path should be
    base_path -> RealFaces/FakeFaces
    RealFaces -> .zip
    FakeFaces -> GANType -> .zip (e.g., FakeFaces -> PGGAN -> .zip)
    '''
    super().__init__()
    self.rng = np.random.default_rng(seed)
    self.unzip_real_imgs = unzip_real_imgs

    self.fake_images_path = Path(base_path) / "FakeFaces"
    self.fake_image_gan_names = os.listdir(str(self.fake_images_path))
    if gans_to_skip is not None:
      print(f"Not unzipping '{gans_to_skip}'")
      self.fake_image_gan_names = list(filter(lambda x: x not in gans_to_skip, self.fake_image_gan_names))
    self.fake_images: Dict[str, List[Union[str, Path]]] = {}
    for fake_image_gan_name in self.fake_image_gan_names:
      print(f"Extracting imagery for '{fake_image_gan_name}'")
      fake_image_gan_path = self.fake_images_path / fake_image_gan_name
      zip_file = sorted(fake_image_gan_path.glob("*.zip"))
      if len(zip_file) == 1:
        zip_file = zip_file[0]
        with ZipFile(str(zip_file), 'r') as zipObj:
          zipObj.extractall()
        self.fake_images.update({fake_image_gan_name: zipObj.namelist()[:n_fake_imgs_to_extract]})
      else:
        fake_image_paths = sorted(fake_image_gan_path.glob("*.jpg"))
        self.fake_images.update({fake_image_gan_name: fake_image_paths[:n_fake_imgs_to_extract]})

    self.n_fake_gans = len(list(self.fake_images.keys()))

    if unzip_real_imgs:
      self.real_images_path = Path(base_path) / "RealFaces"
      print("Extracting RealFaces imagery")
      real_images_zip_files = sorted(self.real_images_path.glob("*.zip"))
      if len(real_images_zip_files) != 1:
        raise RuntimeError(f"Got more than or less than 1 zip file in '{self.real_images_path}'. Got '{len(real_images_zip_files)}'")
      self.real_images_zip_file = real_images_zip_files[0]
      # Create a ZipFile Object and load sample.zip in it
      with ZipFile(str(self.real_images_zip_file), 'r') as zipObj:
        # Extract all the contents of zip file in current directory
        zipObj.extractall()
      if self.len_of_fake_images >= len(zipObj.namelist()):
        self.real_images = zipObj.namelist()[1:]
      else:
        self.real_images = zipObj.namelist()[1:self.len_of_fake_images+1]
    else:
      self.real_imgs = []

    self.transform = transform

    # according to Deep Fake Image Detection Based on Pairwise Learning, we need to make combinations for all
    # real images with all fake images
    fake_img_list = []
    for fake_imgs in self.fake_images.values():
      fake_img_list.extend(fake_imgs)
    self.fake_img_list = fake_img_list

  def fake_image_rand_selection(self, index) -> str:
    rand_selection = self.rng.uniform(low=-0.499, high=len(self.fake_images) - 0.501)
    gan_selection = self.fake_image_gan_names[int(np.round(rand_selection))]
    
    return self.fake_images.get(gan_selection)[index // len(self.fake_images)]

  @property
  def len_of_fake_images(self) -> int:
    total_len = 0
    for val in self.fake_images.values():
      total_len += len(val)

    return total_len

  def len_of_real_and_fake(self):
    return len(self.real_images) + self.len_of_fake_images

  def 

In [4]:
from dataclasses import dataclass
from typing import List

@dataclass
class Loss:
  loss_vals_per_batch: List[float]
  loss_vals_per_epoch: List[float]
  batch_cnt: int = 0
  previous_batch_cnt: int = 0
  epoch_cnt: int = 0

  @classmethod
  def init(cls) -> "Loss":
    return cls(
        loss_vals_per_batch=[],
        loss_vals_per_epoch=[],
        batch_cnt=0,
        previous_batch_cnt=0,
        epoch_cnt=0,
    )

  def __add__(self, other: Union[float, int]) -> "Loss":
    self.loss_vals_per_batch.append(other)
    self.batch_cnt += 1
    return self

  def __iadd__(self, other: Union[float, int]) -> "Loss":
    return self.__add__(other)

  @property
  def current_loss(self) -> float:
    return np.sum(self.loss_vals_per_batch) / self.batch_cnt

  @property
  def previous_loss(self) -> float:
    if len(self.loss_vals_per_batch) > 1:
      return np.sum(self.loss_vals_per_batch[:-2]) / (self.batch_cnt - 1)
    else:
      return 0

  def update_for_epoch(self):
    self.epoch_cnt += 1
    self.loss_vals_per_epoch.append(
        sum(self.loss_vals_per_batch[self.previous_batch_cnt:self.batch_cnt])
        / (self.batch_cnt - self.previous_batch_cnt)
    )
    self.previous_batch_cnt = self.batch_cnt

In [5]:
import matplotlib.pyplot as plt
from typing import Optional

def plot_accuracy_or_loss(
  train_vals: List[float],
  output_path: Union[str, Path],
  validation_vals: Optional[List[float]] = None,
  test_vals: Optional[List[float]] = None,
  title: Optional[str] = None,
  ylabel: Optional[str] = None,
  xlabel: Optional[str] = None,
  plot_labels: Optional[Union[str, List[str]]] = None,
):
  if plot_labels is None:
    plot_labels = ["train"]
    if validation_vals is not None:
      plot_labels.append("validation")
    if test_vals is not None:
      plot_labels.append("test")
  elif not isinstance(plot_labels, list):
    plot_labels = [plot_labels] * 3

  x_epochs = np.arange(1, len(train_vals) + 1)
  plt.plot(x_epochs, train_vals, label=plot_labels[0])
  if validation_vals is not None:
    x_epochs = np.arange(len(train_vals) - len(validation_vals) + 1, len(train_vals) + 1)
    plt.plot(x_epochs, validation_vals, label=plot_labels[1])
  if test_vals is not None:
    x_epochs = np.arange(len(train_vals) - len(test_vals) + 1, len(train_vals) + 1)
    plt.plot(x_epochs, test_vals, label=plot_labels[2])
  if title is not None:
    plt.title(title)
  if ylabel is not None:
    plt.ylabel(ylabel)
  if xlabel is not None:
    plt.xlabel(xlabel)
  plt.legend()
  plt.savefig(output_path)
  plt.close()


def save_loss_plot(
  train_loss: List[float],
  output_path: Union[str, Path],
  test_loss: Optional[List[float]] = None,
  val_loss: Optional[List[float]] = None,
  title: str = "Loss",
  ylabel: str = "Loss",
  xlabel: str = "Epochs",
  plot_labels: Optional[Union[str, List[str]]] = None,
):
  plot_accuracy_or_loss(
    train_vals=train_loss,
    output_path=output_path,
    validation_vals=val_loss,
    test_vals=test_loss,
    title=title,
    ylabel=ylabel,
    xlabel=xlabel,
    plot_labels=plot_labels,
  )

def save_accuracy_plot(
  train_acc: List[float],
  output_path: Union[str, Path],
  test_acc: Optional[List[float]] = None,
  val_acc: Optional[List[float]] = None,
  plot_labels: Optional[Union[str, List[str]]] = None,
):
  plot_accuracy_or_loss(
    train_vals=train_acc,
    output_path=output_path,
    validation_vals=val_acc,
    test_vals=test_acc,
    title="Accuracy",
    ylabel="Accuracy",
    xlabel="Epochs",
    plot_labels=plot_labels,
  )

In [6]:
import re
def get_latest_model(base_path, suffix: str = ".pth") -> Path:
  epoch_num = []
  all_files = sorted(Path(base_path).glob(suffix))
  for file_ in all_files:
    idx_num = re.search("--", str(file_)).span()
    idx_pt = re.search(suffix, str(file_)).span()
    model_num = str(file_)[idx_num[-1]:idx_pt[0]]
    try:
      epoch_num.append(int(model_num))
    except ValueError:
      idx_num = re.search("--", str(model_num)).span()
      epoch_num.append(int(model_num[idx_num[-1]:]))

  idx = epoch_num.index(np.max(epoch_num))
  return all_files[idx]

In [7]:
import math
def number_of_combinations(n_objs: int, r_at_a_time: int) -> int:
  num = math.factorial(n_objs)
  den = math.factorial(n_objs - r_at_a_time) * math.factorial(r_at_a_time)
  return int(num / den)

def get_n_objs_for_a_number_of_combinations_with_2_at_a_time(combs: int) -> int:
  return int((1 + math.sqrt(1 + (4*combs * 2))) / 2)

Two-step learning policy as employed by 'Deep Fake Image Detection Based on Pairwise Learning'. Therefore, we first train the CFFN network with the contrastive loss. After the CFFN network is learned to minimize the contrastive loss, we then train the classification network using the outputs from the CFFN network, as this better feature representation output will allow the classification network to better classify the images as fake or real. The classification network is trained using the binary cross-entropy loss of predicting whether the image is real or fake [p.6. section 2.4]

**1. Common Fake Feature Network**

Network structure includes a pairwise learning approach. "A fake face image detector based on the novel CFFN, consisting of an improved DenseNet backbone network and Siamese network architecture...The cross-layer features are investigated by the proposed CFFN, which can be used to improve the performance."

The fake and real images are paired together and the pairwise information is used to construct the contrastive loss to learn the discriminative common fake feature (CFF) by the CFFN. The paper states that 2 million pairwise samples are used for training.

"One way to learn both the CFFs and classifier is the join learning strategy incorporating the contrastive loss and cross-entropy loss into the total energy function. In another way, the CFFN is first trained by the proposed contrastive loss and follows by training the classifier based on cross-entropy loss. When the first strategy is applied, it is difficult to observe the impact of both contrastive and cross-entropy loss functions on the performance of the fake image detection tasks. Therefore, we adopt the second strategy to ensure the best performance of the proposed method."

In [None]:
import itertools
from typing import Tuple
class CelebrityDataCFFN(CelebrityData):

  def __init__(self, base_path: Path, transform = None, seed = None, n_combinations: int = 4e6, gans_to_skip: Optional[List[str]] = None):
    super().__init__(base_path=base_path, transform=transform, seed=seed, gans_to_skip=gans_to_skip)

    # for training the CFFN we want fake-fake pairs & real-real pairs
    # self.n_fake_combinations = self.n_fake_gans * number_of_combinations(int(self.len_of_fake_images / self.n_fake_gans), 2)
    self.n_imgs_for_combinations = get_n_objs_for_a_number_of_combinations_with_2_at_a_time(n_combinations)
    # only making combinations between images made by the same GAN
    # could try experimenting with combinations of images between different GANs
    self.fake_image_combos: Dict[str, list] = {}
    for gan_name, img_list in self.fake_images.items():
      self.fake_image_combos.update({gan_name: list(itertools.combinations(img_list[:self.n_imgs_for_combinations], 2))})
    
    self.real_image_combos = list(itertools.combinations(self.real_images[:self.n_imgs_for_combinations], 2))

  def __getitem__(self, index):
    img0_path, img1_path, pair_indicator = self.choose_real_or_fake_pair(index)
    img0 = Image.open(img0_path).convert('RGB')
    if self.transform is not None:
      img0 = self.transform(img0)

    img1 = Image.open(img1_path).convert('RGB')
    if self.transform is not None:
      img1 = self.transform(img1)

    return img0, img1, pair_indicator

  # def choose_real_or_fake_pair(self, index) -> Tuple[str, str, int]:
  #   if self.rng.standard_normal() > 0:
  #     img_pair = next(itertools.islice(self.real_image_combos, index, None))
  #     pair_indicator = 1
  #   else:
  #     img_pair = self.fake_image_rand_selection(index)
  #     pair_indicator = 0

  #   return img_pair[0], img_pair[1], pair_indicator

  def choose_real_or_fake_pair(self, index) -> Tuple[str, str, int]:
    if self.rng.standard_normal() > 0:
      img_pair = self.real_image_combos[index]
      pair_indicator = 1
    else:
      img_pair = self.fake_image_rand_selection(index)
      pair_indicator = 0

    return img_pair[0], img_pair[1], pair_indicator

  # def fake_image_rand_selection(self, index) -> Tuple[str, str]:
  #   rand_selection = self.rng.uniform(low=-0.499, high=len(self.fake_images) - 0.501)
  #   gan_selection = self.fake_image_gan_names[int(np.round(rand_selection))]
  #   iter_combo = self.fake_image_combos.get(gan_selection)
  #   return next(itertools.islice(iter_combo, index // len(self.fake_image_gan_names), None))

  def fake_image_rand_selection(self, index) -> Tuple[str, str]:
    rand_selection = self.rng.uniform(low=-0.499, high=self.n_fake_gans - 0.501)
    gan_selection = self.fake_image_gan_names[int(np.round(rand_selection))]

    return self.fake_image_combos[gan_selection][index]

  def __len__(self):
    return len(self.real_image_combos)

In [8]:
# We will be working with GPU:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device : ' , device)

# Number of GPUs available. 
num_GPU = torch.cuda.device_count()
print('Number of GPU : ', num_GPU)

model_output_path = Path("/content/drive/MyDrive/ECE 792 - Advance Topics in Machine Learning/Code/DeepFakeImageDetection/CFFN/model")
if not model_output_path.exists():
  model_output_path.mkdir(exist_ok=True, parents=True)

config_cffn = { 'batch_size'             : 88,
                'image_size'             : 64,
                'n_channel'              : 3,
                'n_epochs'               : 15,
                'lr'                     : 1e-3,
                'growth_rate'            : 24,
                'transition_layer_theta' : 0.5,
                'device'                 : device,
                'm_th'                   : 0.5,
                'n_combinations'         : 2e6,
                'seed'                   : 999,
                'model_output_path'      : model_output_path,
                'chkp_freq'              : 1,  # number of epochs to save model out
                'n_workers'              : 4,
                'gans_to_skip'           : ["CDCGAN"],
}

Device :  cuda
Number of GPU :  1


In [None]:
celebrity_data_cffn = CelebrityDataCFFN(
  base_path=dataset_base_path,
  transform=transforms.Compose(
    [
      transforms.Resize(int(config_cffn["image_size"] * 1.1)),
      transforms.CenterCrop(config_cffn["image_size"]),
      transforms.ToTensor(),
      transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
  ),
  seed=config_cffn["seed"],
  n_combinations=config_cffn["n_combinations"],
)

dataloader_cffn = torch.utils.data.DataLoader(
  dataset=celebrity_data_cffn,
  shuffle=True,
  batch_size=config_cffn["batch_size"],
  num_workers=config_cffn["n_workers"],
  drop_last=True,  # drop last batch that may not be the same size as the expected batch for the network
  pin_memory=True,
)

NameError: ignored

In [9]:
from typing import Callable, Tuple

class DenseBlock2(nn.Module):
  conv0_0_out = None
  conv0_1_out = None
  batch_norm0_out = None
  concat0_out = None
  activation0_out = None
  conv1_0_out = None
  conv1_1_out = None
  batch_norm1_out = None
  concat1_out = None
  activation1_out = None
  trans_layer_out = None
  def __init__(
    self,
    in_channels: int,
    out_channels: int,
    growth_rate: int,
    transition_layer_theta: float,
    device: torch.device = None,
  ):
    super().__init__()
    if device is None:
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.conv0_0 = nn.Conv2d(
      in_channels=in_channels,
      out_channels=in_channels * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv0_1 = nn.Conv2d(
      in_channels=in_channels * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm0 = nn.BatchNorm2d(in_channels + growth_rate, device=device)

    self.conv1_0 = nn.Conv2d(
      in_channels=in_channels + growth_rate,
      out_channels=(in_channels + growth_rate) * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv1_1 = nn.Conv2d(
      in_channels=(in_channels + growth_rate) * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm1 = nn.BatchNorm2d(in_channels + (2 * growth_rate), device=device)

    trans_kernel_size = int(1/transition_layer_theta)
    self.trans_layer = nn.MaxPool3d(kernel_size=(trans_kernel_size, 1, 1))
    self.activation_func = nn.ReLU()

  def forward(self, x):
    self.conv0_0_out = self.conv0_0(x)
    self.conv0_1_out = self.conv0_1(self.conv0_0_out)
    self.concat0_out = torch.concat((self.conv0_1_out, x), dim=1)
    self.batch_norm0_out = self.batch_norm0(self.concat0_out)
    self.activation0_out = self.activation_func(self.batch_norm0_out)

    self.conv1_0_out = self.conv1_0(self.activation0_out)
    self.conv1_1_out = self.conv1_1(self.conv1_0_out)
    self.concat1_out = torch.concat((self.conv1_1_out, self.activation0_out), dim=1)
    self.batch_norm1_out = self.batch_norm1(self.concat1_out)
    self.activation1_out = self.activation_func(self.batch_norm1_out)

    self.trans_layer_out = self.trans_layer(self.activation1_out)

    return self.trans_layer_out

class DenseBlock3(nn.Module):
  conv0_0_out = None
  conv0_1_out = None
  batch_norm0_out = None
  concat0_out = None
  activation0_out = None
  conv1_0_out = None
  conv1_1_out = None
  batch_norm1_out = None
  concat1_out = None
  activation1_out = None
  conv2_0_out = None
  conv2_1_out = None
  batch_norm2_out = None
  concat2_out = None
  activation2_out = None
  trans_layer_out = None
  def __init__(
    self,
    in_channels: int,
    out_channels: int,
    growth_rate: int,
    transition_layer_theta: float,
    device: torch.device = None,
  ):
    super().__init__()
    if device is None:
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.conv0_0 = nn.Conv2d(
      in_channels=in_channels,
      out_channels=in_channels * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv0_1 = nn.Conv2d(
      in_channels=in_channels * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm0 = nn.BatchNorm2d(in_channels + growth_rate, device=device)

    self.conv1_0 = nn.Conv2d(
      in_channels=in_channels + growth_rate,
      out_channels=(in_channels + growth_rate) * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv1_1 = nn.Conv2d(
      in_channels=(in_channels + growth_rate) * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm1 = nn.BatchNorm2d(in_channels + (2 * growth_rate), device=device)

    self.conv2_0 = nn.Conv2d(
      in_channels=in_channels + (2 * growth_rate),
      out_channels=(in_channels + (2 * growth_rate)) * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv2_1 = nn.Conv2d(
      in_channels=(in_channels + (2 * growth_rate)) * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm2 = nn.BatchNorm2d((in_channels + (3 * growth_rate)), device=device)

    trans_kernel_size = int(1/transition_layer_theta)
    self.trans_layer = nn.MaxPool3d(kernel_size=(trans_kernel_size, 1, 1))
    self.activation_func = nn.ReLU()

  def forward(self, x):
    self.conv0_0_out = self.conv0_0(x)
    self.conv0_1_out = self.conv0_1(self.conv0_0_out)
    self.concat0_out = torch.concat((self.conv0_1_out, x), dim=1)
    self.batch_norm0_out = self.batch_norm0(self.concat0_out)
    self.activation0_out = self.activation_func(self.batch_norm0_out)

    self.conv1_0_out = self.conv1_0(self.activation0_out)
    self.conv1_1_out = self.conv1_1(self.conv1_0_out)
    self.concat1_out = torch.concat((self.conv1_1_out, self.activation0_out), dim=1)
    self.batch_norm1_out = self.batch_norm1(self.concat1_out)
    self.activation1_out = self.activation_func(self.batch_norm1_out)

    self.conv2_0_out = self.conv2_0(self.activation1_out)
    self.conv2_1_out = self.conv2_1(self.conv2_0_out)
    self.concat2_out = torch.concat((self.conv2_1_out, self.activation1_out), dim=1)
    self.batch_norm2_out = self.batch_norm2(self.concat2_out)
    self.activation2_out = self.activation_func(self.batch_norm2_out)

    self.trans_layer_out = self.trans_layer(self.activation2_out)

    return self.trans_layer_out

class DenseBlock4(nn.Module):
  conv0_0_out = None
  conv0_1_out = None
  batch_norm0_out = None
  concat0_out = None
  activation0_out = None
  conv1_0_out = None
  conv1_1_out = None
  batch_norm1_out = None
  concat1_out = None
  activation1_out = None
  conv2_0_out = None
  conv2_1_out = None
  batch_norm2_out = None
  concat2_out = None
  activation2_out = None
  conv3_0_out = None
  conv3_1_out = None
  batch_norm3_out = None
  concat3_out = None
  activation3_out = None
  trans_layer_out = None
  def __init__(
    self,
    in_channels: int,
    out_channels: int,
    growth_rate: int,
    transition_layer_theta: float,
    device: torch.device = None,
  ):
    super().__init__()
    if device is None:
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.conv0_0 = nn.Conv2d(
      in_channels=in_channels,
      out_channels=in_channels * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv0_1 = nn.Conv2d(
      in_channels=in_channels * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm0 = nn.BatchNorm2d(in_channels + growth_rate, device=device)

    self.conv1_0 = nn.Conv2d(
      in_channels=in_channels + growth_rate,
      out_channels=(in_channels + growth_rate) * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv1_1 = nn.Conv2d(
      in_channels=(in_channels + growth_rate) * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm1 = nn.BatchNorm2d(in_channels + (2 * growth_rate), device=device)

    self.conv2_0 = nn.Conv2d(
      in_channels=in_channels + (2 * growth_rate),
      out_channels=(in_channels + (2 * growth_rate)) * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv2_1 = nn.Conv2d(
      in_channels=(in_channels + (2 * growth_rate)) * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm2 = nn.BatchNorm2d((in_channels + (3 * growth_rate)), device=device)

    self.conv3_0 = nn.Conv2d(
      in_channels=in_channels + (3 * growth_rate),
      out_channels=(in_channels + (3 * growth_rate)) * 2,
      kernel_size=(1, 1),
      padding=0,
      stride=(1, 1),
      device=device,
    )
    self.conv3_1 = nn.Conv2d(
      in_channels=(in_channels + (3 * growth_rate)) * 2,
      out_channels=growth_rate,
      kernel_size=(3, 3),
      padding=1,
      stride=(1, 1),
      device=device,
    )
    self.batch_norm3 = nn.BatchNorm2d((in_channels + (4 * growth_rate)), device=device)

    trans_kernel_size = int(1/transition_layer_theta)
    self.trans_layer = nn.MaxPool3d(kernel_size=(trans_kernel_size, 1, 1))
    self.activation_func = nn.ReLU()

  def forward(self, x):
    self.conv0_0_out = self.conv0_0(x)
    self.conv0_1_out = self.conv0_1(self.conv0_0_out)
    self.concat0_out = torch.concat((self.conv0_1_out, x), dim=1)
    self.batch_norm0_out = self.batch_norm0(self.concat0_out)
    self.activation0_out = self.activation_func(self.batch_norm0_out)

    self.conv1_0_out = self.conv1_0(self.activation0_out)
    self.conv1_1_out = self.conv1_1(self.conv1_0_out)
    self.concat1_out = torch.concat((self.conv1_1_out, self.activation0_out), dim=1)
    self.batch_norm1_out = self.batch_norm1(self.concat1_out)
    self.activation1_out = self.activation_func(self.batch_norm1_out)

    self.conv2_0_out = self.conv2_0(self.activation1_out)
    self.conv2_1_out = self.conv2_1(self.conv2_0_out)
    self.concat2_out = torch.concat((self.conv2_1_out, self.activation1_out), dim=1)
    self.batch_norm2_out = self.batch_norm2(self.concat2_out)
    self.activation2_out = self.activation_func(self.batch_norm2_out)

    self.conv3_0_out = self.conv3_0(self.activation2_out)
    self.conv3_1_out = self.conv3_1(self.conv3_0_out)
    self.concat3_out = torch.concat((self.conv3_1_out, self.activation2_out), dim=1)
    self.batch_norm3_out = self.batch_norm3(self.concat3_out)
    self.activation3_out = self.activation_func(self.batch_norm3_out)

    self.trans_layer_out = self.trans_layer(self.activation3_out)

    return self.trans_layer_out

# class DenseBlock(nn.Module):
#   def __init__(
#     self,
#     n_conv: int,
#     in_channels: int,
#     out_channels: int,
#     growth_rate: int,
#     transition_layer_theta: float,
#     device: torch.device = None,
#   ):
#     super().__init__()
#     self.modules = []
#     self.batch_norms = []
#     if device is None:
#       device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     for idx in range(n_conv):
#       in_channels_with_growth = in_channels + (idx * growth_rate)
#       out_channels_with_growth = in_channels + ((idx + 1) * growth_rate)
#       self.modules.append(
#         [
#           nn.Conv2d(
#             in_channels=in_channels_with_growth,
#             out_channels=in_channels_with_growth * 2,
#             kernel_size=(1, 1),
#             padding=0,
#             stride=(1, 1),
#             device=device,
#           ),
#           nn.Conv2d(
#             in_channels=in_channels_with_growth * 2,
#             out_channels=growth_rate,
#             kernel_size=(3, 3),
#             padding=1,
#             stride=(1, 1),
#             device=device,
#           ),
#         ]
#       )
#       self.batch_norms.append(nn.BatchNorm2d(out_channels_with_growth, device=device))
#     trans_kernel_size = int(1 / transition_layer_theta)
#     self.trans_layer = nn.MaxPool3d(kernel_size=(trans_kernel_size, 1, 1))
#     self.activation_func = nn.ReLU()

#   def forward(self, x):
#     layer_outputs = [x]
#     for d_block, batch_norm in zip(self.modules, self.batch_norms):
#       for module in d_block:
#         x = module(x)
#       x = torch.concat((x, layer_outputs[-1]), dim=1)
#       x = batch_norm(x)
#       x = self.activation_func(x)
#       layer_outputs.append(x)

#     x = self.trans_layer(x)
#     return x


class CFFNEnergyFunction(nn.Module):
  loss = None

  def __init__(self, batch_size: int = 88, m_th: float = 0.5, device=device):
    super().__init__()
    self.m_th = torch.empty(batch_size, device=device).fill_(m_th)
    self.zero_tensor = torch.empty(batch_size, device=device).fill_(0)
    self.energy_function = nn.MSELoss(reduction="none")

  def forward(self, img0, img1, pairs_indicator):
    E_w = torch.mean(self.energy_function(img0, img1), dim=1)
    real_pairs = (0.5 * torch.mul(pairs_indicator, torch.pow(E_w, 2)))
    fake_pairs = torch.mul(
        (1 - pairs_indicator),
        torch.max(self.zero_tensor, self.energy_function(self.m_th, E_w))
        )
    self.loss = torch.mean(torch.add(real_pairs, fake_pairs))

    return self.loss

  def item(self):
    return self.loss.item()


class CFFN(nn.Module):
  dense_conv1_out = None
  dense_conv2_out = None
  dense_conv3_out = None
  dense_conv4_out = None
  conv5_out = None
  batch_norm5_out = None
  activation5_out = None

  def __init__(
    self,
    input_image_shape: Tuple[int, int],
    growth_rate: int = 24,
    transition_layer_theta: float = 0.5,
    learning_rate: float = 1e-3,
    m_th: float = 0.5,  # threshold for contrastive loss
    batch_size: int = 88,
    device: torch.device = None,
  ):
    super().__init__()
    self.conv0 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=(7, 7), stride=(4, 4))
    self.batch_norm0 = nn.BatchNorm2d(48)
    self.activation0 = nn.ReLU()
    if device is None:
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # self.dense_conv1 = DenseBlock(
    #   n_conv=2,
    #   in_channels=48,
    #   out_channels=48,
    #   growth_rate=growth_rate,
    #   transition_layer_theta=transition_layer_theta,
    #   device=device,
    # ).to(device)
    self.dense_conv1 = DenseBlock2(
      in_channels=48,
      out_channels=48,
      growth_rate=growth_rate,
      transition_layer_theta=transition_layer_theta,
      device=device,
    )
    self.dense_conv2 = DenseBlock3(
      in_channels=48,
      out_channels=60,
      growth_rate=24,
      transition_layer_theta=transition_layer_theta,
      device=device,
    ).to(device)
    self.dense_conv3 = DenseBlock4(
      in_channels=60,
      out_channels=78,
      growth_rate=24,
      transition_layer_theta=transition_layer_theta,
      device=device,
    ).to(device)
    self.dense_conv4 = DenseBlock2(
      in_channels=78,
      out_channels=126,
      growth_rate=24,
      transition_layer_theta=1,
      device=device,
    ).to(device)
    self.conv5 = nn.Conv2d(in_channels=126, out_channels=128, kernel_size=(3, 3))
    self.batch_norm5 = nn.BatchNorm2d(128)
    self.activation5 = nn.ReLU()
    # self.fully_connected: Callable = lambda in_conv_n_channels, conv_shape0, conv_shape1: nn.Sequential(
    #   nn.Flatten(),
    #   nn.Linear(in_conv_n_channels * conv_shape0 * conv_shape1, 128, device=device),
    #   nn.ReLU(),
    # )
    self.flatten = nn.Flatten()
    self.fully_connected3 = nn.Linear(78 * 15 * 15, 128, device=device)
    self.fully_connected4 = nn.Linear(126 * 15 * 15, 128, device=device)
    self.fully_connected5 = nn.Linear(128 * 13 * 13, 128, device=device)
    self.activation_func = nn.ReLU()
    self.loss = CFFNEnergyFunction(batch_size=batch_size, m_th=m_th, device=device).to(device)
    self.optimizer = torch.optim.Adam(
        self.parameters(), lr=learning_rate,
    )

  def forward(self, x):
    x = self.conv0(x)
    x = self.batch_norm0(x)
    x = self.activation0(x)
    self.dense_conv1_out = self.dense_conv1(x)
    # print(f"dense_conv1_out.shape = '{self.dense_conv1_out.shape}'")
    self.dense_conv2_out = self.dense_conv2(self.dense_conv1_out)
    # print(f"dense_conv2_out.shape = '{self.dense_conv2_out.shape}'")
    self.dense_conv3_out = self.dense_conv3(self.dense_conv2_out)
    # print(f"dense_conv3_out.shape = '{self.dense_conv3_out.shape}'")
    self.dense_conv4_out = self.dense_conv4(self.dense_conv3_out)
    # print(f"dense_conv4_out.shape = '{self.dense_conv4_out.shape}'")
    self.conv5_out = self.conv5(self.dense_conv4_out)
    # print(f"conv5_out.shape = '{self.conv5_out.shape}'")
    self.batch_norm5_out = self.batch_norm5(self.conv5_out)
    self.activation5_out = self.activation5(self.batch_norm5_out)

    # fn5_module = self.fully_connected(*self.activation5_out.shape[1:])
    # fn5 = fn5_module(self.activation5_out)
    activation5_flattened = self.flatten(self.activation5_out)
    fn5_out = self.fully_connected5(activation5_flattened)
    fn5 = self.activation_func(fn5_out)

    # fn4_module = self.fully_connected(*self.dense_conv4_out.shape[1:])
    # fn4 = fn4_module(self.dense_conv4_out)
    dense_conv4_flattened = self.flatten(self.dense_conv4_out)
    fn4_out = self.fully_connected4(dense_conv4_flattened)
    fn4 = self.activation_func(fn4_out)

    # fn3_module = self.fully_connected(*self.dense_conv3_out.shape[1:])
    # fn3 = fn3_module(self.dense_conv3_out)
    dense_conv3_flattened = self.flatten(self.dense_conv3_out)
    fn3_out = self.fully_connected3(dense_conv3_flattened)
    fn3 = self.activation_func(fn3_out)

    x_out = torch.cat((fn5, fn4, fn3), dim=1)

    # return output of convolution 5, which will be input to classification network
    # x_out is discriminative features output used for the pairwise learning for CFFN network
    return self.activation5_out, x_out

  def loss_back_grad(self, img0, img1, pairs_indicator, back_grad: bool = True):
    criterion = self.loss(img0, img1, pairs_indicator)
    if back_grad:
      criterion.backward()
      self.optimizer.step()

In [10]:
cffn = CFFN(
    input_image_shape=(64, 64),
    growth_rate=config_cffn["growth_rate"],
    transition_layer_theta=config_cffn["transition_layer_theta"],
    learning_rate=config_cffn["lr"],
    m_th=config_cffn["m_th"],
    batch_size=config_cffn["batch_size"],
    device=config_cffn["device"],
    ).to(config_cffn["device"])
cffn_models_output_path = config_cffn["model_output_path"] / "models"
cffn_models_output_path.mkdir(exist_ok=True, parents=True)
loss_output_path = config_cffn["model_output_path"] / "loss"
loss_output_path.mkdir(exist_ok=True, parents=True)

In [None]:
from tqdm import tqdm
## CFFN TRAINING
train: bool = False
test: bool = False
if train:
  epoch_tqdm = tqdm(total=config_cffn["n_epochs"], position=0)
  train_loss = Loss.init()
  print("Starting Training Loop...")
  for epoch in range(config_cffn["n_epochs"]):
    for batch_idx, (img0, img1, pair_indicator) in enumerate(dataloader_cffn):
      cffn.optimizer.zero_grad()
      img0 = img0.to(config_cffn["device"])
      img1 = img1.to(config_cffn["device"])
      _, img0_discriminative_features = cffn(img0)
      _, img1_discriminative_features = cffn(img1)
      pair_indicator = torch.tensor(pair_indicator, device=config_cffn["device"])
      cffn.loss_back_grad(
          img0_discriminative_features,
          img1_discriminative_features,
          pair_indicator
          )
      train_loss += cffn.loss.item()
      
      if batch_idx % 50 == 0:
        epoch_tqdm.write(" [%d/%d]\tLoss: %.8f" % (batch_idx, len(dataloader_cffn), train_loss.current_loss))

    epoch_tqdm.update(1)
    train_loss.update_for_epoch()
    if epoch % config_cffn["chkp_freq"] == 0:
      torch.save(
          {
            "CFFN_state_dict": cffn.state_dict(),
            "CFFN_optimizer": cffn.optimizer.state_dict(),      
          },
          str(cffn_models_output_path / f"CFFN--{epoch}.pth")
      )

      loss_epoch_out_path = loss_output_path / f"epoch-loss--{epoch}.png"
      save_loss_plot(train_loss.loss_vals_per_epoch, loss_epoch_out_path)
      loss_batch_out_path = loss_output_path / f"batch-loss--{epoch}.png"
      save_loss_plot(train_loss.loss_vals_per_batch, loss_batch_out_path, xlabel="Batches")
elif test:
  model_file = get_latest_model(cffn_models_output_path)
  print(model_file)
  checkpoint = torch.load(str(model_file))
  cffn.load_state_dict(checkpoint["CFFN_state_dict"])
  cffn.to(device)
  cffn.eval()
  test_loss = Loss.init()
  batch_tqdm = tqdm(total=len(dataloader_cffn), position=0)
  for batch_idx, (img0, img1, pair_indicator) in enumerate(dataloader_cffn):
    img0 = img0.to(config_cffn["device"])
    img1 = img1.to(config_cffn["device"])
    _, img0_discriminative_features = cffn(img0)
    _, img1_discriminative_features = cffn(img1)
    pair_indicator = torch.tensor(pair_indicator, device=config_cffn["device"])
    cffn.loss_back_grad(
        img0_discriminative_features,
        img1_discriminative_features,
        pair_indicator,
        back_grad=False,
        )
    test_loss += cffn.loss.item()
    
    if batch_idx % 50 == 0:
      batch_tqdm.write(" [%d/%d]\tLoss: %.8f" % (batch_idx, len(dataloader_cffn), test_loss.current_loss))

    batch_tqdm.update(1)
else:
  print("PASSING AS 'train' & 'test' ARE BOTH FALSE.")

PASSING AS 'train' & 'test' ARE BOTH FALSE.


**2. Classification Network**

"The classification sub-network consists of a convolution layer with two channels, and a fully connected layer with two neurons."

In [12]:
import pandas as pd
import seaborn as sns

@dataclass
class Accuracy:
  acc_vals_per_batch: List[float]
  acc_vals_per_epoch: List[float]
  precision_per_epoch: List[float]
  recall_per_epoch: List[float]
  f1_score_per_epoch: List[float]
  correct_hits: np.ndarray
  correct_hits_per_epoch: List[np.ndarray]
  incorrect_hits: np.ndarray
  incorrect_hits_per_epoch: List[np.ndarray]
  output_decisions: int
  batch_cnt: int = 0
  previous_batch_cnt: int = 0
  epoch_cnt: int = 0
  onehotencoding: bool = True

  @classmethod
  def from_output_decisions(cls, output_size: int, onehotencoding: bool = True) -> "Accuracy":
    if onehotencoding:
      inc_hits_shape = (output_size, output_size)
    else:
      inc_hits_shape = (output_size,)
    return cls(
        acc_vals_per_batch=[],
        acc_vals_per_epoch=[],
        precision_per_epoch=[],
        recall_per_epoch=[],
        f1_score_per_epoch=[],
        batch_cnt=0,
        previous_batch_cnt=0,
        epoch_cnt=0,
        correct_hits=np.zeros((output_size,)),
        correct_hits_per_epoch=[],
        incorrect_hits=np.zeros(inc_hits_shape),
        incorrect_hits_per_epoch=[],
        output_decisions=output_size,
        onehotencoding=onehotencoding,
    )

  def compare_batch(self, targets: torch.Tensor, outputs: torch.Tensor) -> List[Tuple[int, int]]:
    # determine accuracy between a batch of targets and outputs to update accuracy
    hit: int = 0
    indices = None
    if self.onehotencoding:
      indices = []
      for batch_idx, (target, output) in enumerate(zip(targets, outputs)):
          max_idx = int(torch.argmax(output))
          if bool(
              target[max_idx]
          ):  # see if the one hot encoding scheme of our output neuron layer determined the highest probability to be the same as the true target label
              hit += 1
              self.correct_hits[max_idx] += 1
              indices.append((batch_idx, max_idx))
          else:
              self.incorrect_hits[int(torch.argmax(target)), max_idx] += 1
    else:
      outputs = torch.argmax(outputs, dim=1)
      
      zero_targets_idx = torch.where(targets == 0)
      zero_equality = torch.eq(outputs[zero_targets_idx], targets[zero_targets_idx])
      self.correct_hits[0] += int(sum(zero_equality))
      hit = int(sum(zero_equality))
      self.incorrect_hits[0] += int(sum(~zero_equality))

      ones_targets_idx = torch.where(targets == 1)
      ones_equality = torch.eq(outputs[ones_targets_idx], targets[ones_targets_idx])
      self.correct_hits[1] += int(sum(ones_equality))
      hit += int(sum(ones_equality))
      self.incorrect_hits[1] += int(sum(~ones_equality))

    self.acc_vals_per_batch.append(hit / len(targets))
    self.batch_cnt += 1

    return indices
  
  @property
  def current_accuracy(self) -> float:
    return sum(self.acc_vals_per_batch) / self.batch_cnt

  def update_for_epoch(self):
    # update accuracy for epoch
    self.epoch_cnt += 1
    self.acc_vals_per_epoch.append(
        sum(self.acc_vals_per_batch[self.previous_batch_cnt : self.batch_cnt])
        / (self.batch_cnt - self.previous_batch_cnt)
    )
    if len(self.correct_hits_per_epoch) == 0:
      self.correct_hits_per_epoch.append(self.correct_hits.copy())
      self.incorrect_hits_per_epoch.append(self.incorrect_hits.copy())
    else:
      correct_hits_previous_epoch = self.correct_hits_per_epoch[-1].copy()
      self.correct_hits_per_epoch.append(self.correct_hits - correct_hits_previous_epoch)
      incorrect_hits_previous_epoch = self.incorrect_hits_per_epoch[-1].copy()
      self.incorrect_hits_per_epoch.append(self.incorrect_hits - incorrect_hits_previous_epoch)
    self.previous_batch_cnt = self.batch_cnt

  @property
  def confusion_matrix(self) -> np.ndarray:
    # get confusion matrix to better visualize incorrect hits vs correct hits
    if self.onehotencoding:
      return self.incorrect_hits.copy() + np.diag(self.correct_hits)
    else:  # assuming binary classification
      mat = np.diag(self.correct_hits)
      mat[0, 1] = self.cum_false_positive
      mat[1, 0] = self.cum_false_negative
      return mat

  def save_confusion_matrix(self, output_path: Union[str, Path], categories: str):
    # save confusion matrix
    df = pd.DataFrame(self.confusion_matrix, index=[i for i in categories], columns=[i for i in categories])
    plt.figure(figsize=(10, 7))
    sns.heatmap(df, annot=True)
    plt.savefig(output_path)
    plt.close()

  def roc_curve(self, output_path: Union[str, Path]):
    tp = []
    fp = []
    for epoch in range(self.epoch_cnt):
      tp.append(self.true_positive(epoch))
      fp.append(self.false_positive(epoch))
    tp /= max(tp)
    fp /= max(fp)
    plt.plot(fp, tp)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC curve")
    plt.savefig(output_path)
    plt.close()

  def cum_stats_to_csv(self, output_path: Union[str, Path]):
    df = pd.DataFrame(
        columns=[
            "true_positive",
            "false_positive",
            "true_negative",
            "false_negative",
            "accuracy",
            "recall",
            "precision",
            "f1_score",
        ]
    )
    for epoch in range(self.epoch_cnt):
        tp = self.true_positive(epoch)
        fp = self.false_positive(epoch)
        tn = self.true_negative(epoch)
        fn = self.false_negative(epoch)
        accuracy = self.acc_vals_per_epoch[epoch]
        recall = self.recall(epoch)
        precision = self.precision(epoch)
        f1_score = self.f1_score(epoch)
        df.loc[epoch] = [tp, fp, tn, fn, accuracy, recall, precision, f1_score]
    cum_tp = self.cum_true_positive
    cum_fp = self.cum_false_positive
    cum_tn = self.cum_true_negative
    cum_fn = self.cum_false_negative
    cum_accuracy = np.mean(self.acc_vals_per_epoch)
    cum_recall = self.cum_recall
    cum_precision = self.cum_precision
    cum_f1_score = self.cum_f1_score
    df.loc["cumulative"] = [cum_tp, cum_fp, cum_tn, cum_fn, cum_accuracy, cum_recall, cum_precision, cum_f1_score]
    df.to_csv(output_path)

  def true_positive(self, epoch: int) -> int:
    if len(self.correct_hits) != 2:
        raise RuntimeError(f"True Positive only defined for Binary Classification")
    return self.correct_hits_per_epoch[epoch][1]  # 1 = true label, 0 = false label

  @property
  def cum_true_positive(self) -> int:
    tp = 0
    for epoch in range(self.epoch_cnt):
        tp += self.true_positive(epoch)
    return tp

  def true_negative(self, epoch: int) -> int:
    if len(self.correct_hits) != 2:
        raise RuntimeError(f"True Negative only defined for Binary Classification")
    return self.correct_hits_per_epoch[epoch][0]

  @property
  def cum_true_negative(self) -> int:
    tn = 0
    for epoch in range(self.epoch_cnt):
        tn += self.true_negative(epoch)
    return tn

  def false_positive(self, epoch: int) -> int:
    if len(self.correct_hits) != 2:
        raise RuntimeError(f"False Positive only defined for Binary Classification")
    return self.incorrect_hits_per_epoch[epoch][0]

  @property
  def cum_false_positive(self) -> int:
    fp = 0
    for epoch in range(self.epoch_cnt):
        fp += self.false_positive(epoch)
    return fp

  def false_negative(self, epoch: int) -> int:
    if len(self.correct_hits) != 2:
        raise RuntimeError(f"False Negative only defined for Binary Classification")
    return self.incorrect_hits_per_epoch[epoch][1]

  @property
  def cum_false_negative(self) -> int:
    fn = 0
    for epoch in range(self.epoch_cnt):
        fn += self.false_negative(epoch)
    return fn

  def precision(self, epoch: int) -> float:
    return self.true_positive(epoch) / (self.true_positive(epoch) + self.false_positive(epoch))

  @property
  def cum_precision(self) -> float:
    return self.cum_true_positive / (self.cum_true_positive + self.cum_false_positive)

  def recall(self, epoch: int) -> float:
    return self.true_positive(epoch) / (self.true_positive(epoch) + self.false_negative(epoch))

  @property
  def cum_recall(self) -> float:
    return self.cum_true_positive / (self.cum_true_positive + self.cum_false_negative)

  def f1_score(self, epoch: int) -> float:
    return 2 * ((self.precision(epoch) * self.recall(epoch)) / (self.precision(epoch) + self.recall(epoch)))

  @property
  def cum_f1_score(self) -> float:
    return 2 * ((self.cum_precision * self.cum_recall) / (self.cum_precision + self.cum_recall))

In [13]:
from datetime import datetime

def training_plot_paths(output_path: Path, epoch: int) -> Tuple[Union[str, Path], ...]:
  output_path = Path(output_path)
  if not output_path.exists():
    output_path.mkdir(exist_ok=True, parents=True)
  time_now = datetime.now()
  model_pt_name = time_now.strftime("%Y-%m-%d--%H-%M-%S")
  Path(output_path / "loss").mkdir(exist_ok=True, parents=True)
  Path(output_path / "accuracy").mkdir(exist_ok=True, parents=True)
  Path(output_path / "confusion").mkdir(exist_ok=True, parents=True)
  Path(output_path / "models").mkdir(exist_ok=True, parents=True)
  Path(output_path / "stats").mkdir(exist_ok=True, parents=True)

  confusion_matrix_path = output_path / "confusion" / f"confusion-matrix-{model_pt_name}--{epoch}.png"
  accuracy_plot_path = output_path / "accuracy" / f"accuracy-{model_pt_name}--{epoch}.png"
  roc_curve_plot_path = output_path / "stats" / f"roc-curve--{model_pt_name}--{epoch}.png"
  cum_stats_csv_path = output_path / "stats" / f"cumulative-stats--{model_pt_name}--{epoch}.csv"
  loss_plot_path = output_path / "loss" / f"loss-{model_pt_name}--{epoch}.png"
  model_output_path = output_path / "models" / f"model-{model_pt_name}--{epoch}.pth"

  return (
    model_output_path,
    confusion_matrix_path,
    accuracy_plot_path,
    roc_curve_plot_path,
    cum_stats_csv_path,
    loss_plot_path,
  )

In [14]:
from sklearn.preprocessing import OneHotEncoder

def circular_index(idx: int, upper_bound: int) -> int:
  if idx < upper_bound:
    return idx
  return idx - (upper_bound * (idx // upper_bound))

class CelebrityDataClassificationNetwork(CelebrityData):
  def __init__(
      self,
      base_path: Path,
      transform = None,
      seed = None,
      gans_to_skip: Optional[List[str]] = None,
      unzip_real_imgs: bool = True,
  ):
    super().__init__(
        base_path=base_path,
        transform=transform,
        seed=seed,
        gans_to_skip=gans_to_skip,
        unzip_real_imgs=unzip_real_imgs,
    )
    self.to_tensor = transforms.ToTensor()
    self.enc = OneHotEncoder()
    self.enc.fit([[0], [1]])

  def __getitem__(self, index):
    img_path, label, gan_selection = self.choose_real_or_fake_image(index)

    img = Image.open(img_path).convert('RGB')
    if self.transform is not None:
      img = self.transform(img)
    else:
      img = self.to_tensor(img)

    return img, label, gan_selection

  def choose_real_or_fake_image(self, index) -> Tuple[str, np.ndarray, str]:
    if self.rng.standard_normal() > 0:
      img = self.real_images[circular_index(index // 2, len(self.real_images))]  # divide by 2 b/c we define __len__ as all real & fake images
      gan_selection = "real"
      label = [1]
    else:
      img, gan_selection = self.fake_image_rand_selection(index)
      label = [0]

    label = np.squeeze(self.enc.transform(np.column_stack(label).reshape(-1, 1)).toarray())
    return img, label.astype(np.float32), gan_selection

  def fake_image_rand_selection(self, index) -> str:
    rand_selection = self.rng.uniform(low=-0.499, high=self.n_fake_gans - 0.501)
    gan_selection = self.fake_image_gan_names[int(np.round(rand_selection))]

    return self.fake_images[gan_selection][index // (self.n_fake_gans * 2)], gan_selection  # mult den by 2 b/c of definition of __len__ being all real & fake images

  def names_of_gans(self) -> List[str]:
    return list(self.fake_images.keys())

  def __len__(self):
    return len(self.real_images) + self.len_of_fake_images

In [15]:
model_output_path = Path("/content/drive/MyDrive/ECE 792 - Advance Topics in Machine Learning/Code/DeepFakeImageDetection/CN/model")
if not model_output_path.exists():
  model_output_path.mkdir(exist_ok=True, parents=True)
test_output_path = model_output_path / "test"
if not test_output_path.exists():
  test_output_path.mkdir(exist_ok=True, parents=True)

config_cn = { 'batch_size'             : 88,
              'image_size'             : 64,
              'n_channel'              : 3,
              'n_epochs'               : 25,
              'n_test_epochs'          : 1,
              'lr'                     : 1e-3,
              'device'                 : device,
              'seed'                   : 999,
              'model_output_path'      : model_output_path,
              'test_output_path'       : test_output_path,
              'chkp_freq'              : 1,  # number of epochs to save model out
              'n_workers'              : 4,
              'gans_to_skip'           : None,
              'test_chkp_freq'         : 1,
}

In [16]:
celebrity_data_cn = CelebrityDataClassificationNetwork(
  base_path=dataset_base_path,
  transform=transforms.Compose(
    [
      transforms.Resize(int(config_cn["image_size"] * 1.1)),
      transforms.CenterCrop(config_cn["image_size"]),
      transforms.ToTensor(),
      transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
  ),
  seed=config_cn["seed"],
  gans_to_skip=config_cn["gans_to_skip"],
)

dataloader_cn = torch.utils.data.DataLoader(
  dataset=celebrity_data_cn,
  shuffle=True,
  batch_size=config_cn["batch_size"],
  num_workers=config_cn["n_workers"],
  drop_last=True,  # drop last batch that may not be the same size as the expected batch for the network
  pin_memory=True,
)

Extracting imagery for 'WGAN-CP'
Extracting imagery for 'PGGAN'
Extracting imagery for 'DCGAN'
Extracting imagery for 'LSGAN'
Extracting imagery for 'CDCGAN'
Extracting imagery for 'WGAN-GP'
Extracting RealFaces imagery


In [17]:
class ClassificationNetwork(nn.Module):
  conv_layer_out = None
  activation_out = None
  global_avg_pool_out = None
  flatten_out = None
  fully_connected_out = None
  softmax_out = None

  def __init__(self, learning_rate: float = 1e-3):
    super().__init__()
    self.conv_layer = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=(3, 3))
    self.activation = nn.ReLU()
    self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    self.flatten = nn.Flatten()
    self.fully_connected = nn.Linear(2, 2)
    self.softmax = nn.Softmax(dim=1)

    self.loss = nn.BCELoss()

    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

  def forward(self, x):
    self.conv_layer_out = self.conv_layer(x)
    self.activation_out = self.activation(self.conv_layer_out)
    self.global_avg_pool_out = self.global_avg_pool(self.activation_out)
    self.flatten_out = self.flatten(self.global_avg_pool_out)
    self.fully_connected_out = self.fully_connected(self.flatten_out)
    self.softmax_out = self.softmax(self.fully_connected_out)

    return self.softmax_out

In [18]:
from tqdm import tqdm

## Load pretrained CFFN model and change to eval mode, since the model is already trained
cffn_model_file = get_latest_model(cffn_models_output_path)
print(f"CFFN model file : '{cffn_model_file}'")
checkpoint = torch.load(str(cffn_model_file), map_location=config_cn["device"])
cffn.load_state_dict(checkpoint["CFFN_state_dict"])
cffn.to(device)
cffn.eval()

cn = ClassificationNetwork(
  learning_rate=config_cn["lr"]
).to(device)

train: bool = False
load_pretrained_cn: bool = True
if train:
  epoch_tqdm = tqdm(total=config_cn["n_epochs"], position=0)
  train_loss = {
      "total": Loss.init(),
  }
  train_acc = {
      "total": Accuracy.from_output_decisions(2, onehotencoding=False),
  }
  for gan_name in celebrity_data_cn.names_of_gans():
    train_loss[gan_name] = Loss.init()
    train_acc[gan_name] = Accuracy.from_output_decisions(2, onehotencoding=False)

  bce_loss = nn.BCELoss()
  print("Starting Training Loop for Classification Network...")
  for epoch in range(config_cn["n_epochs"]):
    for batch_idx, (img, label, gan_selection) in enumerate(dataloader_cn):
      # get indices of images that are from the different GANs or real
      vals = np.unique(gan_selection)
      gan_select_idx = {}
      for val in vals:
        idx = np.where(np.array(gan_selection) == val)
        gan_select_idx[val] = torch.Tensor(idx[0]).to(torch.int64)

      # zero classification network gradients
      cn.optimizer.zero_grad()
      # cast img & label to device we are working on
      img = img.to(config_cn["device"])
      label = label.to(config_cn["device"])
      # get discriminative features from last convolution output of CFFN
      img_cffn, img_discriminative_features = cffn(img)
      # classify output from CFFN using classification network
      out_cn = cn(img_cffn)
      # calculate total loss
      criterion = bce_loss(out_cn, label)
      # calculate gradients
      criterion.backward()
      # backpropagate gradients
      cn.optimizer.step()
      # update loss & accuracy for each 
      train_loss["total"] += criterion.item()
      label = torch.argmax(label, dim=1)  # need to unencode one hot encoded labels for accuracy measurements in binary classification task
      train_acc["total"].compare_batch(targets=label, outputs=out_cn)

      labels_real = label[gan_select_idx["real"]]
      out_cn_real = out_cn[gan_select_idx["real"]]
      gan_select_idx.pop("real")
      for gan_name, idx in gan_select_idx.items():
        labels_ = label[idx]
        labels_ = torch.concat([labels_, labels_real], dim=0)
        out_cn_ = out_cn[idx]
        out_cn_ = torch.concat([out_cn_, out_cn_real], dim=0)
        # criterion = bce_loss(out_cn_, labels_)
        # train_loss[gan_name] += criterion.item()
        train_acc[gan_name].compare_batch(targets=labels_, outputs=out_cn_)

      if (batch_idx + 1) % 50 == 0:
        epoch_write_str = " [%d/%d]\tAcc_Total: %.5f\tLoss_Total: %.8f"
        epoch_write_vars = [batch_idx+1, len(dataloader_cn), train_acc["total"].current_accuracy, train_loss["total"].current_loss]
        # for key in celebrity_data_cn.names_of_gans():
        #   epoch_write_str += f"\tLoss_{key}: %.8f"
        #   epoch_write_vars.append(train_loss[key].current_loss)
        
        epoch_write_vars = tuple(epoch_write_vars)
        # epoch_tqdm.write(epoch_write_str % epoch_write_vars)
        print(epoch_write_str % epoch_write_vars)
      
    epoch_tqdm.update(1)
    for key in list(train_loss.keys()):
      # train_loss[key].update_for_epoch()
      train_acc[key].update_for_epoch()
    if (epoch + 1) % config_cn["chkp_freq"] == 0:
      (
        model_output_path,
        confusion_matrix_path,
        accuracy_plot_path,
        roc_curve_plot_path,
        cum_stats_csv_path,
        loss_plot_path,
       ) = training_plot_paths(config_cn["model_output_path"], epoch+1)
      torch.save(
          {
            "CN_state_dict": cn.state_dict(),
            "CN_optimizer": cn.optimizer.state_dict(),
          },
          str(model_output_path),
      )
      for (key, train_loss_), train_acc_ in zip(train_loss.items(), train_acc.values()):
        loss_plot_path_ = loss_plot_path.parent / key / loss_plot_path.name
        loss_plot_path_.parent.mkdir(exist_ok=True, parents=True)
        accuracy_plot_path_ = accuracy_plot_path.parent / key / accuracy_plot_path.name
        accuracy_plot_path_.parent.mkdir(exist_ok=True, parents=True)
        confusion_matrix_path_ = confusion_matrix_path.parent / key / confusion_matrix_path.name
        confusion_matrix_path_.parent.mkdir(exist_ok=True, parents=True)
        roc_curve_plot_path_ = roc_curve_plot_path.parent / key / roc_curve_plot_path.name
        roc_curve_plot_path_.parent.mkdir(exist_ok=True, parents=True)
        cum_stats_csv_path_ = cum_stats_csv_path.parent / key / cum_stats_csv_path.name
        cum_stats_csv_path_.parent.mkdir(exist_ok=True, parents=True)

        if key in ["total"]:
          save_loss_plot(train_loss_.loss_vals_per_epoch, loss_plot_path_)
        save_accuracy_plot(train_acc_.acc_vals_per_epoch, accuracy_plot_path_)
        train_acc_.save_confusion_matrix(confusion_matrix_path_, "01")
        train_acc_.roc_curve(roc_curve_plot_path_)
        train_acc_.cum_stats_to_csv(cum_stats_csv_path_)
elif load_pretrained_cn:
  cn_model_base_path = config_cn["model_output_path"] / "models"
  cn_model_path = get_latest_model(cn_model_base_path)
  print(f"cn_model_path: '{cn_model_path}'")
  
  checkpoint = torch.load(str(cn_model_path), map_location=config_cn["device"])
  cn.load_state_dict(checkpoint["CN_state_dict"])

  epoch_tqdm = tqdm(total=config_cn["n_test_epochs"], position=0, initial=1)
  test_loss = {
      "total": Loss.init(),
  }
  test_acc = {
      "total": Accuracy.from_output_decisions(2, onehotencoding=False),
  }
  for gan_name in celebrity_data_cn.names_of_gans():
    test_loss[gan_name] = Loss.init()
    test_acc[gan_name] = Accuracy.from_output_decisions(2, onehotencoding=False)
  
  cn.eval()
  bce_loss = nn.BCELoss()
  print("Testing Classification Network + CFFN...")
  for epoch in range(config_cn["n_test_epochs"]):
    for batch_idx, (img, label, gan_selection) in enumerate(dataloader_cn):
      vals = np.unique(gan_selection)
      gan_select_idx = {}
      for val in vals:
        idx = np.where(np.array(gan_selection) == val)
        gan_select_idx[val] = torch.Tensor(idx[0]).to(torch.int64)
      
      img = img.to(config_cn["device"])
      label = label.to(config_cn["device"])

      img_cffn, img_discriminative_features = cffn(img)
      out_cn = cn(img_cffn)
      criterion = bce_loss(out_cn, label)
      test_loss["total"] += criterion.item()
      label = torch.argmax(label, dim=1)
      test_acc["total"].compare_batch(targets=label, outputs=out_cn)

      labels_real = label[gan_select_idx["real"]]
      out_cn_real = out_cn[gan_select_idx["real"]]
      gan_select_idx.pop("real")
      for gan_name, idx in gan_select_idx.items():
        labels_ = label[idx]
        labels_ = torch.concat([labels_, labels_real], dim=0)
        out_cn_ = out_cn[idx]
        out_cn_ = torch.concat([out_cn_, out_cn_real], dim=0)
        # criterion = bce_loss(out_cn_, labels_)
        # test_loss[gan_name] += criterion.item()
        test_acc[gan_name].compare_batch(targets=labels_, outputs=out_cn_)
      
      if (batch_idx + 1) % 50 == 0:
        epoch_write_str = " [%d/%d]\tAcc_Total: %.5f\tLoss_Total: %.8f"
        epoch_write_vars = [batch_idx + 1, len(dataloader_cn), test_acc["total"].current_accuracy, test_loss["total"].current_loss]
        for key in celebrity_data_cn.names_of_gans():
          epoch_write_str += f"\tAcc_{key}: %.5f"
          epoch_write_vars.append(test_acc[key].current_accuracy)

        epoch_write_vars = tuple(epoch_write_vars)
        print(epoch_write_str % epoch_write_vars)
    
    epoch_tqdm.update(1)
    for key in list(test_acc.keys()):
      # test_loss[key].update_for_epoch()
      test_acc[key].update_for_epoch()
    if (epoch + 1) % config_cn["test_chkp_freq"] == 0:
      (
        _,
        confusion_matrix_path,
        accuracy_plot_path,
        roc_curve_plot_path,
        cum_stats_csv_path,
        loss_plot_path,
      ) = training_plot_paths(config_cn["test_output_path"], epoch+1)

      for (key, test_loss_), test_acc_ in zip(test_loss.items(), test_acc.values()):
        loss_plot_path_ = loss_plot_path.parent / key / loss_plot_path.name
        loss_plot_path_.parent.mkdir(exist_ok=True, parents=True)
        accuracy_plot_path_ = accuracy_plot_path.parent / key / accuracy_plot_path.name
        accuracy_plot_path_.parent.mkdir(exist_ok=True, parents=True)
        confusion_matrix_path_ = confusion_matrix_path.parent / key / confusion_matrix_path.name
        confusion_matrix_path_.parent.mkdir(exist_ok=True, parents=True)
        roc_curve_plot_path_ = roc_curve_plot_path.parent / key / roc_curve_plot_path.name
        roc_curve_plot_path_.parent.mkdir(exist_ok=True, parents=True)
        cum_stats_csv_path_ = cum_stats_csv_path.parent / key / cum_stats_csv_path.name
        cum_stats_csv_path_.parent.mkdir(exist_ok=True, parents=True)

        if key in ["real"]:
          save_loss_plot(test_loss_.loss_vals_per_epoch, loss_plot_path_)
        save_accuracy_plot(test_acc_.acc_vals_per_epoch, accuracy_plot_path_)
        test_acc_.save_confusion_matrix(confusion_matrix_path_, "01")
        test_acc_.roc_curve(roc_curve_plot_path_)
        test_acc_.cum_stats_to_csv(cum_stats_csv_path_)

CFFN model file : '/content/drive/MyDrive/ECE 792 - Advance Topics in Machine Learning/Code/DeepFakeImageDetection/CFFN/model/models/CFFN--14.pth'
cn_model_path: '/content/drive/MyDrive/ECE 792 - Advance Topics in Machine Learning/Code/DeepFakeImageDetection/CN/model/models/model-2023-03-18--20-06-24--25.pth'


100%|██████████| 1/1 [00:00<?, ?it/s]

Testing Classification Network + CFFN...
 [50/5029]	Acc_Total: 0.91818	Loss_Total: 0.65482322	Acc_WGAN-CP: 0.99406	Acc_PGGAN: 0.99524	Acc_DCGAN: 0.99399	Acc_LSGAN: 0.99560	Acc_CDCGAN: 0.86596	Acc_WGAN-GP: 0.99037
 [100/5029]	Acc_Total: 0.91295	Loss_Total: 0.73146579	Acc_WGAN-CP: 0.99370	Acc_PGGAN: 0.99446	Acc_DCGAN: 0.99379	Acc_LSGAN: 0.99457	Acc_CDCGAN: 0.85798	Acc_WGAN-GP: 0.99025
 [150/5029]	Acc_Total: 0.91152	Loss_Total: 0.72858827	Acc_WGAN-CP: 0.99435	Acc_PGGAN: 0.99466	Acc_DCGAN: 0.99376	Acc_LSGAN: 0.99452	Acc_CDCGAN: 0.85572	Acc_WGAN-GP: 0.99067
 [200/5029]	Acc_Total: 0.91159	Loss_Total: 0.74166871	Acc_WGAN-CP: 0.99411	Acc_PGGAN: 0.99429	Acc_DCGAN: 0.99322	Acc_LSGAN: 0.99391	Acc_CDCGAN: 0.85627	Acc_WGAN-GP: 0.98987
 [250/5029]	Acc_Total: 0.91309	Loss_Total: 0.72368475	Acc_WGAN-CP: 0.99420	Acc_PGGAN: 0.99434	Acc_DCGAN: 0.99363	Acc_LSGAN: 0.99393	Acc_CDCGAN: 0.86037	Acc_WGAN-GP: 0.99002
 [300/5029]	Acc_Total: 0.90989	Loss_Total: 0.73900251	Acc_WGAN-CP: 0.99400	Acc_PGGAN: 0.99403	A

2it [12:30, 750.67s/it]              