## Installs

In [None]:
!pip install numpy
!pip install opencv-python
!pip install scipy
!pip install wandb --quiet
!pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1
!pip install torchsummary -q

## Kaggle Setup

In [None]:
!pip install --upgrade --force-reinstall --no-deps kaggle==1.5.8 -q
!mkdir /root/.kaggle

with open("/root/.kaggle/kaggle.json", "w+") as f:
    f.write('{"username":"KAGGLE_USERNAME","key":"KAGGLE_KEY"}') # TODO: Put your kaggle username & key here

!chmod 600 /root/.kaggle/kaggle.json

In [None]:
# Download the dataset with DALL-E and Midjourney art as well as human art.
# The archive (dalle-recognition-dataset.zip) lands in the current working
# directory, where the next cell unzips it.
!kaggle datasets download -d "superpotato9/dalle-recognition-dataset"

In [None]:
# unzip data
!unzip -q dalle-recognition-dataset.zip

## Imports

In [None]:
import numpy as np
import wandb
from tqdm import tqdm
import cv2
from scipy.ndimage import rotate

import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = None
import random
import concurrent.futures

import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
import torchvision
import os

import torchvision.transforms as transforms
from torchvision.io import read_image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

## Configs

In [None]:
config = { # TODO: Set your configuration values
    "lr"         : 2e-3,
    "epochs"     : 30,
    "batch_size" : 32,
    "crop_size" : 32, #size of image patches
    "num_crops" : 192, #number of crops made from the image
    "crops_used" : 64, #number of crops used to make the patchwork image to run through the model
    "use_subtract" : True, #decides if we will subtract different texture images before or concat them to put through the model
    "keep_medium_texture" : False,
    "model" : "resnet", # can be 'resnet', 'inception', or 'efficientnet'
    "filters" : ["a", "b", "c", "d", "e", "f", "g"], #can have any combination of a,b,c,d,e,f,g,h. h is the gabor filters
}

# Fail fast on settings that would otherwise only break deep inside the data
# pipeline or model construction.
assert config["model"] in ("resnet", "inception", "efficientnet"), config["model"]
assert set(config["filters"]) <= set("abcdefgh"), config["filters"]
assert 0 < config["crops_used"] <= config["num_crops"], "crops_used must be in (0, num_crops]"

# Number of crops per row in the patchwork grid; a perfect-square crops_used
# yields a square grid.
config["nrow"] = int(config["crops_used"]**0.5)

config["run_name"] = "Abl-size-{}-num-{}-subs-{}-med-{}-mod-{}-filters-{}".format(
    config["crop_size"],
    config["num_crops"],
    config["use_subtract"],
    config["keep_medium_texture"],
    config["model"],
    config["filters"]
    )

## FILTERS

In [None]:
def apply_filter_a(src):
    """Mean response of eight first-order "neighbour minus centre" kernels.

    Each 5x5 kernel in ``f1`` is +1 at one of the eight neighbours of the centre
    pixel and -1 at the centre, so it measures the difference between that
    neighbour and the centre.

    Args:
        src: image array (anything ``np.copy`` accepts, e.g. a CPU tensor).

    Returns:
        Array with the shape/dtype of ``cv2.filter2D``'s output for ``src``: the
        floor of the mean of the eight responses.
    """
    src_copy = np.copy(src)
    f1 = np.array([[[ 0,  0,  0,  0,  0],
        [ 0,  1,  0,  0,  0],
        [ 0,  0, -1,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  1,  0,  0],
        [ 0,  0, -1,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  1,  0],
        [ 0,  0, -1,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0, -1,  1,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0, -1,  0,  0],
        [ 0,  0,  0,  1,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0, -1,  0,  0],
        [ 0,  0,  1,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0, -1,  0,  0],
        [ 0,  1,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  1, -1,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]]])

    # ddepth=-1 keeps the source depth, so for uint8 input each response is
    # clipped to [0, 255] (negative differences become 0).
    # NOTE(review): signed residuals may have been intended — confirm.
    responses = [cv2.filter2D(src=src_copy, kernel=kernel, ddepth=-1) for kernel in f1]
    # Sum in float64: cv2.add saturates uint8 at 255, which capped the "mean"
    # at 255 // 8 = 31 instead of averaging the responses.
    total = np.sum(responses, axis=0, dtype=np.float64)
    return (total // len(f1)).astype(responses[0].dtype)

def apply_filter_b(src):
    """Mean response of the eight directional 5x5 kernels in ``f2``.

    Args:
        src: image array (anything ``np.copy`` accepts, e.g. a CPU tensor).

    Returns:
        Array with the shape/dtype of ``cv2.filter2D``'s output for ``src``: the
        floor of the mean of the eight responses.
    """
    src_copy = np.copy(src)
    # NOTE(review): the diagonal kernels (indices 0, 2, 4, 6) sum to 2 rather
    # than 0, so they also respond to flat intensity — confirm intended.
    f2 = np.array([[[ 0,  0,  0,  0,  0],
                    [ 0,  2,  1,  0,  0],
                    [ 0,  1, -3,  0,  0],
                    [ 0,  0,  0,  1,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0, -1,  0,  0],
                    [ 0,  0,  3,  0,  0],
                    [ 0,  0, -3,  0,  0],
                    [ 0,  0,  1,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  1,  2,  0],
                    [ 0,  0, -3,  1,  0],
                    [ 0,  1,  0,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  1, -3,  3, -1],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  1,  0,  0,  0],
                    [ 0,  0, -3,  1,  0],
                    [ 0,  0,  1,  2,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  1,  0,  0],
                    [ 0,  0, -3,  0,  0],
                    [ 0,  0,  3,  0,  0],
                    [ 0,  0, -1,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  1,  0],
                    [ 0,  1, -3,  0,  0],
                    [ 0,  2,  1,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0],
                    [-1,  3, -3,  1,  0],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0]]])

    # ddepth=-1 keeps the source depth: for uint8 input each response is
    # clipped to [0, 255].
    responses = [cv2.filter2D(src=src_copy, kernel=kernel, ddepth=-1) for kernel in f2]
    # Sum in float64 so the mean is not distorted by cv2.add's uint8 saturation.
    total = np.sum(responses, axis=0, dtype=np.float64)
    return (total // len(f2)).astype(responses[0].dtype)



def apply_filter_c(src):
    """Mean response of four second-order ``[1, -2, 1]`` line kernels.

    The kernels in ``f3`` are oriented vertically, horizontally, along the main
    diagonal and along the anti-diagonal.

    Args:
        src: image array (anything ``np.copy`` accepts, e.g. a CPU tensor).

    Returns:
        Array with the shape/dtype of ``cv2.filter2D``'s output for ``src``: the
        floor of the mean of the four responses.
    """
    src_copy=np.copy(src)
    f3 = np.array([[[ 0,  0,  0,  0,  0],
                    [ 0,  0,  1,  0,  0],
                    [ 0,  0, -2,  0,  0],
                    [ 0,  0,  1,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  1, -2,  1,  0],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  1,  0,  0,  0],
                    [ 0,  0, -2,  0,  0],
                    [ 0,  0,  0,  1,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  1,  0],
                    [ 0,  0, -2,  0,  0],
                    [ 0,  1,  0,  0,  0],
                    [ 0,  0,  0,  0,  0]]])

    # ddepth=-1 keeps the source depth: for uint8 input each response is
    # clipped to [0, 255].
    responses = [cv2.filter2D(src=src_copy, kernel=kernel, ddepth=-1) for kernel in f3]
    # Sum in float64 so the mean is not distorted by cv2.add's uint8 saturation.
    total = np.sum(responses, axis=0, dtype=np.float64)
    return (total // len(f3)).astype(responses[0].dtype)


def apply_filter_d(src):
    """Mean response of four 3x3 half-window kernels.

    The kernels in ``f4`` are the top, left, bottom and right halves of the
    ``[[-1, 2, -1], [2, -4, 2], [-1, 2, -1]]`` pattern.

    Args:
        src: image array (anything ``np.copy`` accepts, e.g. a CPU tensor).

    Returns:
        Array with the shape/dtype of ``cv2.filter2D``'s output for ``src``: the
        floor of the mean of the four responses.
    """
    src_copy=np.copy(src)
    f4 = np.array([[[ 0,  0,  0,  0,  0],
                    [ 0, -1,  2, -1,  0],
                    [ 0,  2, -4,  2,  0],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0, -1,  2,  0,  0],
                    [ 0,  2, -4,  0,  0],
                    [ 0, -1,  2,  0,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  0,  0,  0],
                    [ 0,  2, -4,  2,  0],
                    [ 0, -1,  2, -1,  0],
                    [ 0,  0,  0,  0,  0]],

                    [[ 0,  0,  0,  0,  0],
                    [ 0,  0,  2, -1,  0],
                    [ 0,  0, -4,  2,  0],
                    [ 0,  0,  2, -1,  0],
                    [ 0,  0,  0,  0,  0]]])

    # ddepth=-1 keeps the source depth: for uint8 input each response is
    # clipped to [0, 255].
    responses = [cv2.filter2D(src=src_copy, kernel=kernel, ddepth=-1) for kernel in f4]
    # Sum in float64 so the mean is not distorted by cv2.add's uint8 saturation.
    total = np.sum(responses, axis=0, dtype=np.float64)
    return (total // len(f4)).astype(responses[0].dtype)

def apply_filter_e(src):
    """Mean response of four 5x5 half-window kernels (top, left, bottom, right).

    Args:
        src: image array (anything ``np.copy`` accepts, e.g. a CPU tensor).

    Returns:
        Array with the shape/dtype of ``cv2.filter2D``'s output for ``src``: the
        floor of the mean of the four responses.
    """
    src_copy=np.copy(src)
    # NOTE(review): the corner taps here are +1, whereas the full 5x5 kernel in
    # apply_filter_g uses -1. As a result each kernel sums to 4 instead of 0 and
    # responds to flat intensity — confirm intended.
    f5 = np.array([[[  1,   2,  -2,   2,   1],
                    [  2,  -6,   8,  -6,   2],
                    [ -2,   8, -12,   8,  -2],
                    [  0,   0,   0,   0,   0],
                    [  0,   0,   0,   0,   0]],

                [[  1,   2,  -2,   0,   0],
                    [  2,  -6,   8,   0,   0],
                    [ -2,   8, -12,   0,   0],
                    [  2,  -6,   8,   0,   0],
                    [  1,   2,  -2,   0,   0]],

                [[  0,   0,   0,   0,   0],
                    [  0,   0,   0,   0,   0],
                    [ -2,   8, -12,   8,  -2],
                    [  2,  -6,   8,  -6,   2],
                    [  1,   2,  -2,   2,   1]],

                [[  0,   0,  -2,   2,   1],
                    [  0,   0,   8,  -6,   2],
                    [  0,   0, -12,   8,  -2],
                    [  0,   0,   8,  -6,   2],
                    [  0,   0,  -2,   2,   1]]])

    # ddepth=-1 keeps the source depth: for uint8 input each response is
    # clipped to [0, 255].
    responses = [cv2.filter2D(src=src_copy, kernel=kernel, ddepth=-1) for kernel in f5]
    # Sum in float64 so the mean is not distorted by cv2.add's uint8 saturation.
    total = np.sum(responses, axis=0, dtype=np.float64)
    return (total // len(f5)).astype(responses[0].dtype)

def apply_filter_f(src):
    """Filter a copy of ``src`` with a single zero-sum 3x3 high-pass kernel.

    The kernel is ``[[-1, 2, -1], [2, -4, 2], [-1, 2, -1]]`` zero-padded to 5x5.
    The output keeps the source depth (``ddepth=-1``), so for uint8 input
    negative responses are clipped to 0.
    """
    kernel = np.asarray([[ 0,  0,  0,  0,  0],
                         [ 0, -1,  2, -1,  0],
                         [ 0,  2, -4,  2,  0],
                         [ 0, -1,  2, -1,  0],
                         [ 0,  0,  0,  0,  0]])
    return cv2.filter2D(src=np.copy(src), kernel=kernel, ddepth=-1)


def apply_filter_g(src):
    """Filter a copy of ``src`` with a single zero-sum 5x5 high-pass kernel.

    Every row of the kernel sums to 0. The output keeps the source depth
    (``ddepth=-1``), so for uint8 input negative responses are clipped to 0.
    """
    kernel = np.asarray([[ -1,   2,  -2,   2,  -1],
                         [  2,  -6,   8,  -6,   2],
                         [ -2,   8, -12,   8,  -2],
                         [  2,  -6,   8,  -6,   2],
                         [ -1,   2,  -2,   2,  -1]])
    return cv2.filter2D(src=np.copy(src), kernel=kernel, ddepth=-1)

def build_gabor_filters(ksize=31, sigma=4.0, lambd=10.0, gamma=0.5, psi=0, num_theta=4):
    """Build a bank of ``num_theta`` float32 Gabor kernels at evenly spaced orientations.

    Orientations are ``k * pi / num_theta`` for ``k = 0 .. num_theta - 1``,
    covering [0, 180) degrees. The remaining arguments are passed straight to
    ``cv2.getGaborKernel``.

    Returns:
        list of ``ksize x ksize`` float32 kernels.
    """
    # endpoint=False: theta = 0 and theta = pi are the same orientation, so
    # including both wastes a kernel (with num_theta=2 both were horizontal).
    thetas = np.linspace(0, np.pi, num_theta, endpoint=False)
    return [
        cv2.getGaborKernel((ksize, ksize), sigma, theta, lambd, gamma, psi, ktype=cv2.CV_32F)
        for theta in thetas
    ]

def apply_gabor_filters(src, filters):
    """Return the pixel-wise maximum response of ``src`` over a bank of kernels.

    The accumulator starts at zero, so negative responses never win.

    Args:
        src: image array.
        filters: iterable of 2D kernels (e.g. from ``build_gabor_filters``).

    Returns:
        Array shaped like ``src`` holding the maximum response per pixel/channel.
    """
    accum = np.zeros_like(src)
    for kern in filters:
        # ddepth=-1 keeps the source depth. cv2.CV_8UC3 is a full type code
        # rather than a depth; it only behaved like CV_8U because OpenCV masks
        # off the channel bits, and it forced 8-bit output for non-uint8 input.
        fimg = cv2.filter2D(src, -1, kern)
        np.maximum(accum, fimg, out=accum)  # keep the strongest response so far
    return accum

def apply_all_filters(src):
    """Apply the filters selected in ``config["filters"]`` and return a grayscale map.

    Args:
        src: channel-first RGB image of shape ``(3, H, W)``, e.g. the tensor
            produced by ``make_grid`` in ``return_rich_poor_images``.

    Returns:
        float64 array of shape ``(1, H, W)``. This is the grayscale of the summed
        responses of the selected filters, floor-divided by the number of
        filters. When no filter is selected it is the plain grayscale image.
    """
    # OpenCV slides kernels over the first two axes and treats the last axis as
    # channels, so move channels last. Filtering the (3, H, W) layout directly
    # would treat R/G/B as image rows; a reshape (rather than a transpose) back
    # to (H, W, 3) would then scramble pixels.
    image = np.ascontiguousarray(np.transpose(np.asarray(src), (1, 2, 0)))
    height, width = image.shape[:2]
    selected = config["filters"]

    if len(selected) == 0:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return gray.astype(np.float64).reshape((1, height, width))

    def _gabor(img):
        # Built lazily: the Gabor bank is only needed when "h" is selected.
        bank = build_gabor_filters(ksize=20, sigma=4, lambd=10, gamma=0.5, psi=2, num_theta=2)
        return apply_gabor_filters(img, bank)

    filter_fns = {
        "a": apply_filter_a,
        "b": apply_filter_b,
        "c": apply_filter_c,
        "d": apply_filter_d,
        "e": apply_filter_e,
        "f": apply_filter_f,
        "g": apply_filter_g,
        "h": _gabor,
    }

    # Only the selected filters are computed. Accumulate in float64: summing
    # several uint8 responses in uint8 silently wraps around at 256.
    total = np.sum([filter_fns[key](image) for key in selected], axis=0, dtype=np.float64)
    gray = cv2.cvtColor(total.astype(np.float32), cv2.COLOR_RGB2GRAY)
    return (gray // len(selected)).astype(np.float64).reshape((1, height, width))

## Patch Generation

In [None]:
def calculate_l1(patch):
    """Texture-diversity score of a channel-first patch ``(C, h, w)``.

    Sums the channels into a single plane, then adds up the absolute differences
    between every pair of adjacent pixels: horizontal, vertical, main-diagonal
    and anti-diagonal neighbours. Returns a 0-d tensor; higher means more texture.
    """
    plane = patch.sum(dim=0)  # collapse channels; integer inputs sum as int64
    neighbour_pairs = (
        (plane[:, 1:], plane[:, :-1]),     # horizontal
        (plane[1:, :], plane[:-1, :]),     # vertical
        (plane[1:, 1:], plane[:-1, :-1]),  # main diagonal
        (plane[:-1, 1:], plane[1:, :-1]),  # anti-diagonal
    )
    return sum(torch.abs(a - b).sum() for a, b in neighbour_pairs)

In [None]:
def return_rich_poor_images(image, crop_size, num_crops, crops_used, include_mediums):
    """Build filtered patchwork grids from the least and most textured random crops.

    Takes ``num_crops`` random ``crop_size`` crops of ``image`` and ranks them by
    ``calculate_l1`` (the neighbour-difference texture score). The
    ``crops_used`` lowest-scoring ("poor") and highest-scoring ("rich") crops are
    each tiled into a grid with ``config["nrow"]`` columns and passed through
    ``apply_all_filters``.

    Args:
        image: channel-first image tensor (as returned by ``read_image``).
        crop_size: side length of each square crop.
        num_crops: number of random crops to sample.
        crops_used: number of crops tiled into each grid.
        include_mediums: also build a grid from the ``crops_used`` crops centred
            on the middle of the ranking.

    Returns:
        ``(poor_grid, rich_grid)``, or ``(poor_grid, rich_grid, medium_grid)``
        when ``include_mediums`` is truthy.
    """
    transform = transforms.RandomCrop(crop_size)
    crops = [transform(image) for _ in range(num_crops)]

    # Rank crops by texture score (stable sort: ties keep sampling order).
    scored = list(zip(crops, [calculate_l1(crop) for crop in crops]))
    scored.sort(key=lambda pair: pair[1])
    ranked = [crop for crop, _ in scored]

    def _filtered_grid(selection):
        # Tile the crops into one image and run the configured filters over it.
        grid = make_grid(torch.stack(selection), nrow=config["nrow"], padding=0)
        return apply_all_filters(grid)

    poor_grid = _filtered_grid(ranked[:crops_used])
    rich_grid = _filtered_grid(ranked[-crops_used:])

    if include_mediums:
        halfway = num_crops // 2
        medium = ranked[halfway - (crops_used // 2):halfway + (crops_used // 2)]
        # Built from the medium crops themselves; stacking the poor crops here
        # would make the medium grid a duplicate of the poor grid.
        return poor_grid, rich_grid, _filtered_grid(medium)

    return poor_grid, rich_grid

## Data and Dataloaders

In [None]:
class MakeDataset(torch.utils.data.Dataset):
  """In-memory image dataset that yields texture-patch grids per image.

  Every image is decoded once in ``__init__``. Images that do not decode to
  exactly 3 channels are skipped along with their labels. Each item is
  ``(return_rich_poor_images(...) output, label tensor)``; the crops are random,
  so repeated reads of one index differ.
  """

  def __init__(self, img_paths, labels):
    self.labels = []
    self.imgs = []
    for i, path in enumerate(img_paths):
      decoded = read_image(path)
      if decoded.shape[0] != 3:
        continue  # keep only 3-channel images
      self.imgs.append(decoded)
      self.labels.append(labels[i])
    self.length = len(self.labels)

  def __len__(self):
    return self.length

  def __getitem__(self, ind):
    regions = return_rich_poor_images(
        self.imgs[ind],
        config['crop_size'],
        config['num_crops'],
        config['crops_used'],
        config['keep_medium_texture'],
    )
    return regions, torch.tensor(self.labels[ind])

In [None]:
train_dir_ai = '/content/fakeV2/fake-v2'
train_dir_real = '/content/real'
train_frac = 0.7  # fraction of each class used for training
split_seed = 0    # fixed seed so the train/validation split is reproducible

def _list_jpgs(directory):
    # Sorted because os.listdir order is arbitrary; otherwise even a seeded
    # shuffle would give a different split on another machine or run.
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory)) if name[-3:] == "jpg"]

ai_imgs = _list_jpgs(train_dir_ai)
real_imgs = _list_jpgs(train_dir_real)

ai_label = [1] * len(ai_imgs)
real_label = [0] * len(real_imgs)

# Use the same number of images from each class so both splits are balanced.
shorterLen = min(len(ai_imgs), len(real_imgs))
train_split = list(range(shorterLen))
random.Random(split_seed).shuffle(train_split)

cut = int(shorterLen * train_frac)
train_idx, val_idx = train_split[:cut], train_split[cut:]
X_train = [ai_imgs[i] for i in train_idx] + [real_imgs[i] for i in train_idx]
y_train = [ai_label[i] for i in train_idx] + [real_label[i] for i in train_idx]
X_validate = [ai_imgs[i] for i in val_idx] + [real_imgs[i] for i in val_idx]
y_validate = [ai_label[i] for i in val_idx] + [real_label[i] for i in val_idx]
print(len(X_train))

In [None]:
# Decodes every image once, up front; this can take a while.
train_data, val_data = (
    MakeDataset(X_train, y_train),
    MakeDataset(X_validate, y_validate),
)

In [None]:
# Only the training set is shuffled; the validation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# Sanity check: inspect the structure of the first training batch.
for x, y in train_loader:
    print(len(x), y.shape)  # number of texture grids per sample, label shape
    print(x[0].shape)       # batched shape of the first grid
    break

## MODEL

### Fingerprinting Stem

In [None]:
class SplitConv(nn.Module):
    """Shared conv front-end applied to each texture grid, then merged.

    ``forward`` takes the sequence collated from ``return_rich_poor_images``
    outputs. Index 0 holds the poor-texture grids, index 1 the rich ones, and
    index 2 (used only when ``keep_medium_texture``) the medium ones. One conv
    block with shared weights encodes each grid. The results are then either
    subtracted (``use_subtract``) or concatenated along channels.
    """

    def __init__(self, use_subtract, keep_medium_texture, in_channels=3):
        super(SplitConv, self).__init__()
        self.use_subtract = use_subtract
        self.keep_medium_texture = keep_medium_texture
        width = 32  # output channels per encoded grid
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, width, 3, 1, 0),
            torch.nn.BatchNorm2d(width),
            torch.nn.Hardtanh(),
        )

    def _encode(self, grid):
        # Inputs arrive on the CPU from the DataLoader; move them here.
        return self.block(grid.to(device=DEVICE, dtype=torch.float))

    def forward(self, x):
        first = self._encode(x[0])   # poor-texture grid
        second = self._encode(x[1])  # rich-texture grid
        # With use_subtract this computes poor - rich.
        merged = first - second if self.use_subtract else torch.cat((first, second), dim=1)
        if not self.keep_medium_texture:
            return merged
        return torch.cat((merged, self._encode(x[2])), dim=1)

class Stem(nn.Module):
    """Fingerprinting stem: ``SplitConv`` followed by four unpadded 3x3 convs.

    ``stem_output_channels`` must equal the channel count ``SplitConv`` emits:
    32 when subtracting, plus 32 more for concatenation and for the medium grid.
    Each unpadded 3x3 conv trims 2 pixels from each spatial dimension.
    """

    def __init__(self, use_subtract=config['use_subtract'], keep_medium_texture=config['keep_medium_texture'], in_channels=3, stem_output_channels=32):
        super(Stem, self).__init__()
        ch = stem_output_channels
        self.backbone = torch.nn.Sequential(
            SplitConv(use_subtract, keep_medium_texture, in_channels=in_channels))
        self.temp = torch.nn.Sequential(torch.nn.Conv2d(ch, ch, 3, 1, 0))

        # BN -> ReLU, then three (conv -> BN -> ReLU) repeats.
        tail = [torch.nn.BatchNorm2d(ch), torch.nn.ReLU()]
        for _ in range(3):
            tail += [torch.nn.Conv2d(ch, ch, 3, 1, 0), torch.nn.BatchNorm2d(ch), torch.nn.ReLU()]
        self.block = torch.nn.Sequential(*tail)

    def forward(self, x):
        return self.block(self.temp(self.backbone(x)))

### Inceptionv1

In [None]:
class ConvBlock(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride=1):
        super().__init__()
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = torch.nn.BatchNorm2d(out_channels)
        self.layers = torch.nn.Sequential(conv, norm, torch.nn.ReLU())

    def forward(self, x, return_feats=False):
        # return_feats is accepted for signature compatibility and ignored.
        return self.layers(x)

In [None]:
class InceptionBlock(torch.nn.Module):
  """GoogLeNet-style Inception module: four parallel branches concatenated on channels.

  Output channels = block_1x1 + block_3x3_out + block_5x5_out + block_pool_out.
  Every branch keeps the spatial size (stride 1 with "same" padding).
  """

  def __init__(self, in_channels, block_1x1, block_3x3_in, block_3x3_out, block_5x5_in, block_5x5_out, block_pool_out):
    super().__init__()
    # 1x1 projection
    self.block_1x1 = ConvBlock(in_channels, block_1x1, kernel_size=1, padding=0, stride=1)
    # 1x1 reduction -> 3x3
    self.block_3x3 = torch.nn.Sequential(
        ConvBlock(in_channels, block_3x3_in, kernel_size=1, padding=0, stride=1),
        ConvBlock(block_3x3_in, block_3x3_out, kernel_size=3, padding=1, stride=1),
    )
    # 1x1 reduction -> 5x5
    self.block_5x5 = torch.nn.Sequential(
        ConvBlock(in_channels, block_5x5_in, kernel_size=1, padding=0, stride=1),
        ConvBlock(block_5x5_in, block_5x5_out, kernel_size=5, padding=2, stride=1),
    )
    # 3x3 max-pool -> 1x1 projection
    self.block_pool = torch.nn.Sequential(
        torch.nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
        ConvBlock(in_channels, block_pool_out, kernel_size=1, padding=0, stride=1),
    )

  def forward(self, x):
    branches = (self.block_1x1, self.block_3x3, self.block_5x5, self.block_pool)
    return torch.cat([branch(x) for branch in branches], 1)

In [None]:
class InceptionModelv1(torch.nn.Module):
  """GoogLeNet (Inception v1)-style classifier over ``in_channels`` feature maps.

  ``forward`` returns class logits, or the 1024-d pooled features when
  ``return_feats`` is True.
  """

  def __init__(self, in_channels=32, num_classes=2):
    super().__init__()

    self.backbone = torch.nn.Sequential(
        # section 1
        ConvBlock(in_channels, 64, kernel_size=7, padding=3, stride=2),
        torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        # section 2
        ConvBlock(64, 192, kernel_size=3, padding=1, stride=1),
        torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        # section 3
        InceptionBlock(192, 64, 96, 128, 16, 32, 32),
        InceptionBlock(256, 128, 128, 192, 32, 96, 64),
        torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        # Section 4
        InceptionBlock(480, 192, 96, 208, 16, 48, 64),
        InceptionBlock(512, 160, 112, 224, 24, 64, 64),
        InceptionBlock(512, 128, 128, 256, 24, 64, 64),
        InceptionBlock(512, 112, 144, 288, 32, 64, 64),
        InceptionBlock(528, 256, 160, 320, 32, 128, 128),
        torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        # Section 5
        InceptionBlock(832, 256, 160, 320, 32, 128, 128),
        InceptionBlock(832, 384, 192, 384, 48, 128, 128),  # 384+384+128+128 = 1024 channels
        # Global average pool to 1x1 so cls_layer always receives 1024 features.
        # A fixed AvgPool2d(kernel_size=8, stride=1) only does so when this map
        # is exactly 8x8; on an 8x8 map both give the same result.
        torch.nn.AdaptiveAvgPool2d((1, 1)),
        torch.nn.Dropout(p=0.4),
        torch.nn.Flatten(),
        )

    self.cls_layer = torch.nn.Linear(1024, num_classes)

  def forward(self, x, return_feats=False):
    feats = self.backbone(x)
    if return_feats:
        return feats
    return self.cls_layer(feats)

### EfficientNet

Inspired by Aladdin Persson's PyTorch EfficientNet implementation:

https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/CNN_architectures/pytorch_efficientnet.py

#### Conv Block
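Packages a convolution with batch normalization and a SiLU activation; `depthwise=True` makes it a depthwise convolution (`groups=in_channels`).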

In [None]:
class EnetConvBlock(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU; ``depthwise=True`` sets ``groups=in_channels``."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, stride=1, depthwise=False):
        super().__init__()
        conv = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            groups=in_channels if depthwise else 1,
        )
        self.layers = torch.nn.Sequential(conv, torch.nn.BatchNorm2d(out_channels), torch.nn.SiLU())

    def forward(self, x):
        return self.layers(x)

#### Squeeze Excitation
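Learns how important each channel is: a sigmoid weight in (0, 1) per channel, computed from globally pooled features, that rescales the feature map.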

In [None]:
class SqueezeExcitation(torch.nn.Module):
    """Squeeze-and-Excitation channel gating.

    Global-average-pools each channel, passes the pooled vector through an
    ``upsized -> squeezed -> upsized`` 1x1-conv bottleneck (SiLU in between),
    and multiplies every input channel by the resulting sigmoid weight in (0, 1).
    """

    def __init__(self, upsized, squeezed):
        super().__init__()
        pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        reduce = torch.nn.Conv2d(upsized, squeezed, 1)
        expand = torch.nn.Conv2d(squeezed, upsized, 1)
        self.layers = torch.nn.Sequential(pool, reduce, torch.nn.SiLU(), expand, torch.nn.Sigmoid())

    def forward(self, x):
        gate = self.layers(x)  # (N, upsized, 1, 1), broadcast over H and W
        return x * gate


#### MBConv Block (mobile inverted bottleneck)

In [None]:
class MBBlock(torch.nn.Module):
  """Mobile inverted-bottleneck block.

  The path is: 1x1 expansion (skipped in ``forward`` when ``expand_ratio == 1``)
  -> depthwise ``kernel_size`` conv -> squeeze-excitation -> 1x1 projection +
  BatchNorm. When stride is 1 and the channel count is unchanged, the result is
  added to the input. While training, that residual branch is randomly dropped
  per sample.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, expand_ratio, reduction):
    super().__init__()
    self.keep_prob = 0.8  # probability that the residual branch is kept while training
    self.use_residual = stride == 1 and in_channels == out_channels
    self.stride = stride
    self.expand_ratio = expand_ratio
    hidden = int(in_channels * expand_ratio)    # widened channels of the inverted bottleneck
    se_width = int(in_channels / reduction)     # bottleneck width inside squeeze-excitation

    # Created even when expand_ratio == 1, in which case forward never calls it.
    self.expand = EnetConvBlock(in_channels, hidden, 1, 0, 1)

    self.bottleneck = torch.nn.Sequential(
        EnetConvBlock(hidden, hidden, kernel_size, kernel_size // 2, stride, True),  # depthwise conv
        SqueezeExcitation(hidden, se_width),
        torch.nn.Conv2d(hidden, out_channels, 1, bias=False),
        torch.nn.BatchNorm2d(out_channels),
    )

  def randomlyDrop(self, x):
    # One Bernoulli(keep_prob) draw per sample; survivors are rescaled by
    # 1 / keep_prob so the expected value is unchanged.
    keep = torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.keep_prob
    return (x / self.keep_prob) * keep

  def forward(self, x):
    hidden = self.expand(x) if self.expand_ratio != 1 else x
    out = self.bottleneck(hidden)
    if not self.use_residual:
      return out
    if self.training:
      out = self.randomlyDrop(out)
    return x + out

#### EfficientNet Network Itself

In [None]:
from math import ceil
class EfficientNet(torch.nn.Module):
  """EfficientNet-style classifier over ``in_channels`` input feature maps.

  Scales the stage table below by ``depth_factor = 1.2 ** phi`` (blocks per
  stage) and ``width_factor = 1.1 ** phi`` (channels). ``forward`` returns
  logits, or the pooled embedding when ``return_feats`` is True.
  """

  def __init__(self, in_channels=32, num_classes=2):
    super().__init__()
    # expand_ratio, channels, num_blocks, stride, kernel_size (base values from the paper)
    stages = [
      (1, 16, 1, 1, 3),
      (6, 24, 2, 2, 3),
      (6, 40, 2, 2, 5),
      (6, 80, 3, 2, 3),
      (6, 112, 3, 1, 5),
      (6, 192, 4, 2, 5),
      (6, 320, 1, 1, 3),
    ]

    phi = 1.08  # compound scaling coefficient; change to scale the network
    depth_factor = 1.2 ** phi
    width_factor = 1.1 ** phi
    dropout = 0.5
    reduction = 4  # squeeze ratio used inside each block's squeeze-excitation

    embedding_size = ceil(1280 * width_factor)
    stem_channels = int(32 * width_factor)
    layers = [EnetConvBlock(in_channels, stem_channels, 3, 1, 2)]  # stride-2 initial conv

    current = stem_channels
    for expand_ratio, channels, num_blocks, stride, kernel_size in stages:
      # Round up to a multiple of `reduction` so the squeeze width divides evenly.
      out_channels = reduction * ceil(int(channels * width_factor) / reduction)
      for block_num in range(ceil(num_blocks * depth_factor)):
        block_stride = stride if block_num == 0 else 1  # only a stage's first block downsamples
        layers.append(MBBlock(current, out_channels, kernel_size, block_stride, expand_ratio, reduction))
        current = out_channels

    layers += [
      EnetConvBlock(current, embedding_size, 1, 0, 1),  # 1x1 conv to the embedding width
      torch.nn.AdaptiveAvgPool2d((1, 1)),
      torch.nn.Flatten(),
      torch.nn.Dropout(p=dropout),
    ]
    self.backbone = torch.nn.Sequential(*layers)
    self.cls_layer = torch.nn.Linear(embedding_size, num_classes)

  def forward(self, x, return_feats=False):
    feats = self.backbone(x)
    out = self.cls_layer(feats)
    return feats if return_feats else out

### Resnet

In [None]:
class basicConvBlock(torch.nn.Module):
    """Bias-free Conv2d -> BatchNorm2d, plus SiLU when ``activation_function`` is truthy."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation_function=True):
        super(basicConvBlock, self).__init__()
        modules = [
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                            kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            torch.nn.BatchNorm2d(out_channels),
        ]
        if activation_function:
            modules.append(torch.nn.SiLU())
        self.model = torch.nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)

In [None]:
class ResNetBasicBlock(torch.nn.Module):
  """Two-conv residual block: ``SiLU(conv-bn-SiLU -> conv-bn  +  shortcut)``.

  The shortcut is ``downsample_block(x)`` when one is given, otherwise the
  identity; in that case the input and output shapes must match.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, downsample_block=None):
    super(ResNetBasicBlock, self).__init__()
    self.downsample_block = downsample_block
    first = basicConvBlock(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, stride=stride, padding=padding)
    second = basicConvBlock(in_channels=out_channels, out_channels=out_channels,
                            kernel_size=kernel_size, stride=1, padding=padding,
                            activation_function=None)  # no activation before the residual add
    self.model = torch.nn.Sequential(first, second)
    self.activation_function = torch.nn.SiLU()

  def forward(self, x):
    residual = self.model(x)
    shortcut = self.downsample_block(x) if self.downsample_block else x
    return self.activation_function(residual + shortcut)

In [None]:
class ResNetModel(torch.nn.Module):
  """ResNet-18-style classifier (4 stages x 2 basic blocks) over ``in_channels`` maps.

  A stride-2 7x7 conv widens the input to ``2 * in_channels`` channels, followed
  by a stride-2 max-pool. The four stages then use 1x, 2x, 4x and 8x that width,
  and every stage after the first halves the resolution. ``forward`` returns
  logits, or the pooled features when ``return_feats`` is True.
  """

  def __init__(self, in_channels=32, num_classes=2):
    super(ResNetModel, self).__init__()
    base = in_channels * 2
    self.in_channels = base  # running channel count, advanced by build_blocks
    layers = [
        basicConvBlock(in_channels=in_channels, out_channels=base,
                       kernel_size=7, stride=2, padding=3, activation_function=True),
        torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]

    # (out_channels, number of blocks, first-block stride) for each stage
    stages = [(base * mult, 2, 1 if mult == 1 else 2) for mult in (1, 2, 4, 8)]

    # Created before the stage blocks; this fixes the order in which the
    # modules' parameters are initialised.
    self.cls_layer = torch.nn.Linear(base * 8, num_classes)
    for out_channels, number, stride in stages:
      self.build_blocks(out_channels, number, stride, layers=layers)
    layers += [torch.nn.AdaptiveAvgPool2d((1, 1)), torch.nn.Flatten()]
    self.backbone = torch.nn.Sequential(*layers)

  def build_blocks(self, out_channels, number, stride, layers):
    """Append one stage of ``number`` ResNetBasicBlocks to ``layers`` in place.

    When ``stride != 1``, the first block gets a 1x1 projection shortcut.
    Advances ``self.in_channels`` to ``out_channels``. Returns None.
    """
    shortcut = None
    if stride != 1:
      shortcut = basicConvBlock(in_channels=self.in_channels, out_channels=out_channels,
                                kernel_size=1, stride=stride, padding=0,
                                activation_function=False)
    layers.append(ResNetBasicBlock(in_channels=self.in_channels, out_channels=out_channels,
                                   stride=stride, downsample_block=shortcut))
    self.in_channels = out_channels
    layers.extend(
        ResNetBasicBlock(in_channels=self.in_channels, out_channels=out_channels, stride=1)
        for _ in range(number - 1)
    )

  def forward(self, x, return_feats=False):
    feats = self.backbone(x)
    out = self.cls_layer(feats)
    return feats if return_feats else out

### Model Init

In [None]:
# set classifier
match config['model']:
  case 'resnet':
    classifier = ResNetModel
  case 'inception':
    classifier = InceptionModelv1
  case 'efficientnet':
    classifier = EfficientNet

In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

filter_output_channels = 1  # apply_all_filters yields one grayscale plane per grid
stem_output_channels = 32   # channels SplitConv produces per encoded grid

# Subtraction merges the poor/rich encodings into one 32-channel block. Each of
# concatenation (use_subtract=False) and the medium grid adds another block.
class_in_channels = stem_output_channels * (
    1 + (not config['use_subtract']) + bool(config['keep_medium_texture'])
)

In [None]:
model = torch.nn.Sequential(
    Stem(in_channels=filter_output_channels, stem_output_channels=class_in_channels),
    classifier(in_channels=class_in_channels),
).to(DEVICE)

# Patchwork grid size as produced by make_grid: ceil(crops_used / nrow) rows and
# nrow columns of crop_size x crop_size tiles (256 x 256 with the default config).
grid_h = -(-config['crops_used'] // config['nrow']) * config['crop_size']
grid_w = config['nrow'] * config['crop_size']
# NOTE(review): torchsummary appears to prepend its own batch dim of 2. The
# model then receives a (2, batch, 1, H, W) tensor, which SplitConv indexes as
# x[0] / x[1], i.e. two grids. This dummy input does not cover
# keep_medium_texture, which needs x[2].
# device=DEVICE: torchsummary defaults to "cuda", which fails on CPU-only machines.
summary(model, (config['batch_size'], filter_output_channels, grid_h, grid_w), device=DEVICE)

## Loss, Optimizer, and Scheduler

In [None]:
# Defining Loss function
criterion = torch.nn.CrossEntropyLoss()

# Defining Optimizer. Pass the configured learning rate explicitly: without it
# Adam silently uses its default (1e-3) and config['lr'] is ignored.
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

# Defining Scheduler: cosine decay with T_max = epochs. This presumably expects
# one scheduler.step() per epoch — confirm in the training loop.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config['epochs'])
# Loss scaling only applies on CUDA. Enabling it only there avoids GradScaler's
# "CUDA is not available" warning on CPU, where it would disable itself anyway.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

## Train and Validate

### Train

In [None]:
def train(model, dataloader, optimizer, criterion):
    """Run one training epoch.

    Labels are moved to the module-level ``DEVICE``; the model's stem moves the
    inputs. Backward and optimizer steps go through the module-level ``scaler``
    (a GradScaler).

    Returns:
        ``(accuracy_percent, mean_batch_loss)`` over the epoch.
    """
    model.train()

    # Progress Bar
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)

    num_correct = 0
    num_seen = 0
    total_loss = 0.0
    for i, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        labels = labels.to(DEVICE)

        # Mixed precision is off; wrap these two lines in
        # torch.cuda.amp.autocast() to enable it.
        outputs = model(images)
        loss = criterion(outputs, labels)

        num_correct += int((torch.argmax(outputs, axis=1) == labels).sum())
        # Count real samples: the last batch can be smaller than batch_size, so
        # batch_size * batches over-counts and under-reports accuracy.
        num_seen += labels.shape[0]
        total_loss += loss.item()

        batch_bar.set_postfix(
            acc="{:.04f}%".format(100 * num_correct / num_seen),
            loss="{:.04f}".format(total_loss / (i + 1)),
            num_correct=num_correct,
            lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])),
        )

        scaler.scale(loss).backward()  # replaces loss.backward()
        scaler.step(optimizer)         # replaces optimizer.step()
        scaler.update()

        batch_bar.update()

    batch_bar.close()

    acc = 100 * num_correct / max(num_seen, 1)
    total_loss = total_loss / max(len(dataloader), 1)
    return acc, total_loss

### Validation

In [None]:
def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Labels are moved to the global ``DEVICE``. Images are passed to ``model``
    unchanged. NOTE(review): presumably the model or the dataset handles
    device placement of the image inputs; confirm.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: sized iterable of ``(images, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.

    Returns:
        (acc, loss): accuracy in percent over all samples actually seen, and the
        mean of the per-batch losses.
    """
    model.eval()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val', ncols=5)

    num_correct = 0
    num_samples = 0  # counted explicitly: the last batch may be smaller than batch_size
    total_loss = 0.0

    for i, (images, labels) in enumerate(dataloader):

        # Move labels to device (images are passed through as-is)
        labels = labels.to(DEVICE)

        # Forward pass only; inference_mode disables autograd tracking.
        with torch.inference_mode():
            outputs = model(images)
            loss = criterion(outputs, labels)

        num_correct += int((torch.argmax(outputs, axis=1) == labels).sum())
        num_samples += labels.size(0)
        total_loss += float(loss.item())

        batch_bar.set_postfix(
            acc="{:.04f}%".format(100 * num_correct / max(num_samples, 1)),
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            num_correct=num_correct)

        batch_bar.update()

    batch_bar.close()
    # Guard against an empty dataloader to avoid ZeroDivisionError.
    acc = 100 * num_correct / max(num_samples, 1)
    total_loss = float(total_loss / max(len(dataloader), 1))
    return acc, total_loss

## WandB

In [None]:
use_wandb = False # TODO: set True/False depending on if you are logging results to wandb
if use_wandb:
  # Read the API key from the environment (export WANDB_API_KEY=...) instead of
  # pasting a secret into the notebook. If it is unset, key=None lets wandb.login
  # fall back to its own credential lookup / interactive prompt.
  wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Encode the ablation settings in the run name so runs are easy to tell apart.
run_name = "Abl-size-{}-num-{}-subs-{}-med-{}-mod-{}".format(
    *(config[key] for key in ("crop_size", "num_crops", "use_subtract", "keep_medium_texture", "model"))
)

if use_wandb:
  run = wandb.init(
      project="idl-project",  # the project should exist in your wandb account
      name=run_name,          # wandb picks a random run name if this is omitted
      reinit=True,            # allow re-initialising when this cell is re-run
      config=config,          # attach the hyper-parameters to the run
      # id="n37w794d",        # set a specific run id to resume a previous run
      # resume="must",        # required for resuming; drop reinit=True when using it
  )

In [None]:
if use_wandb:
  # Write the model architecture and the run configuration to text files and
  # attach both to the current wandb run. `with` guarantees the files are closed
  # (and flushed) before wandb.save picks them up, even if a write fails.
  model_arch = str(model)
  with open("model_arch.txt", "w") as arch_file:
    arch_file.write(model_arch)
  wandb.save('model_arch.txt')

  with open("config.txt", "w") as config_file:
    config_file.write(str(config))
  wandb.save('config.txt')

## Run

In [None]:
# Highest validation accuracy (in percent) seen so far; the epoch loop saves a
# checkpoint whenever an epoch matches or beats it.
best_val_acc: float = 0.0

In [None]:
# Single source of truth for the checkpoint location, so the file written by
# torch.save is the same file uploaded with wandb.save.
CHECKPOINT_PATH = 'checkpoint_classification.pth'

for epoch in range(config['epochs']):
    # LR in effect for this epoch; the scheduler is stepped at the end of the epoch.
    curr_lr = float(optimizer.param_groups[0]['lr'])

    train_acc, train_loss = train(model, train_loader, optimizer, criterion)

    print("\nEpoch {}/{}: \nTrain Acc {:.04f}%\t Train Loss {:.04f}\t Learning Rate {:.04f}".format(
        epoch + 1, config['epochs'], train_acc, train_loss, curr_lr))

    val_acc, val_loss = validate(model, val_loader, criterion)
    print("Val Acc {:.04f}%\t Val Loss {:.04f}".format(val_acc, val_loss))

    if use_wandb:
      wandb.log({"train_acc": train_acc,
                "train_loss":train_loss,
                "val_acc": val_acc,
                "val_loss": val_loss,
                "learning_rate": curr_lr})

    # `>=` so a tie with the best accuracy still refreshes the checkpoint.
    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        torch.save({'model_state_dict':model.state_dict(),
                    'optimizer_state_dict':optimizer.state_dict(),
                    'scheduler_state_dict':scheduler.state_dict(),
                    'val_acc': val_acc,
                    'epoch': epoch}, CHECKPOINT_PATH)
        if use_wandb:
          # Upload the checkpoint file that was just written.
          wandb.save(CHECKPOINT_PATH)
        print("Saved best model")

    scheduler.step()

In [None]:
# `run` is only created when wandb logging is enabled; calling it unconditionally
# raises NameError when use_wandb is False.
if use_wandb:
  run.finish()