In [None]:
# Environment setup cell (IPython/Colab magics, not plain Python).
# Clone the SelfBlendedImages repository and copy the contents of its `src`
# directory into the current working directory; later cells import `model`,
# `inference.preprocess` and `utils.funcs`, presumably from these copied files.
!git clone https://github.com/Morteza-24/SelfBlendedImages.git
!cp SelfBlendedImages/src/* ./ -r
# Install the third-party packages named below into the running kernel.
%pip install efficientnet-pytorch retinaface-pytorch
# Mount Google Drive at /content/drive; later cells load the checkpoint
# "drive/MyDrive/FFraw.tar" through a path relative to the working directory.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from model import Detector
import torch

class SBIFeatureExtractor(torch.nn.Module):
    """Feature extractor built from the backbone of a pretrained SBI ``Detector``.

    The ``Detector`` checkpoint is loaded and its ``net`` backbone is kept as a
    registered submodule, so ``.eval()``, ``.to()`` and ``state_dict()`` on this
    wrapper reach the backbone. Storing only the bound ``extract_features``
    method (as before) left the backbone in training mode even after
    ``.eval()`` was called on the wrapper.

    Args:
        weights_path: Path of the checkpoint; its ``"model"`` entry is loaded
            into the ``Detector``.
        device: Device the checkpoint is mapped to and the backbone is moved to.
    """

    def __init__(self, weights_path="drive/MyDrive/FFraw.tar", device='cuda'):
        super(SBIFeatureExtractor, self).__init__()
        sbi_model = Detector()
        # map_location lets a checkpoint saved on another device load cleanly.
        cnn_sd = torch.load(weights_path, map_location=device)["model"]
        sbi_model.load_state_dict(cnn_sd)
        # Assigning an nn.Module attribute registers it as a child module.
        self.net = sbi_model.net.to(device)

    @property
    def features(self):
        """Bound ``extract_features`` of the backbone (kept for compatibility)."""
        return self.net.extract_features

    def forward(self, x):
        """Return ``self.net.extract_features(x)``."""
        return self.net.extract_features(x)

In [None]:
import torch

class ConvClassifier(torch.nn.Module):
    """Convolutional head that turns a feature map into sigmoid scores.

    Pipeline: 3x3 convolution to 512 channels, ReLU, global average pooling,
    flatten, dropout (p=0.5), linear projection to ``num_classes`` and a final
    sigmoid, so every output lies in (0, 1).

    Args:
        input_channels: Number of channels of the incoming feature map.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_channels, num_classes=1):
        super().__init__()
        nn = torch.nn
        hidden_channels = 512
        # Layer order is kept so that state_dict keys (classifier.0 / classifier.5)
        # stay the same.
        layers = [
            nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, num_classes),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the classifier stack to ``x`` and return the sigmoid scores."""
        head = self.classifier
        return head(x)

In [None]:
import cv2
import torch
from inference.preprocess import extract_face
from retinaface.pre_trained_models import get_model

def get_image_features(img, model=None, face_detector=None):
  """Extract SBI backbone features for every face detected in an image file.

  Args:
    img: Path of the image file, read with ``cv2.imread``.
    model: Optional feature extractor to reuse. When ``None`` a new
      ``SBIFeatureExtractor`` is built, which reloads its weights on each call.
    face_detector: Optional RetinaFace model to reuse. When ``None`` the
      "resnet50_2020-07-20" model is created on CUDA.

  Returns:
    The output of ``model`` on the batch of cropped faces, computed under
    ``torch.no_grad()``.

  Raises:
    FileNotFoundError: If ``cv2.imread`` cannot read ``img``.
    ValueError: If ``extract_face`` returns no faces.
  """
  frame = cv2.imread(img)
  # cv2.imread signals failure by returning None instead of raising.
  if frame is None:
    raise FileNotFoundError(f"Could not read image: {img!r}")
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

  if model is None:
    model = SBIFeatureExtractor()
  model.eval()

  if face_detector is None:
    face_detector = get_model("resnet50_2020-07-20", max_size=max(frame.shape),device=torch.device('cuda'))
  face_detector.eval()

  face_list = extract_face(frame,face_detector)
  # Fail with a clear message rather than feeding an empty batch to the network.
  if len(face_list) == 0:
    raise ValueError(f"No face detected in image: {img!r}")

  with torch.no_grad():
      # Stack per-face arrays into one batch and scale pixel values by 1/255.
      faces = torch.stack([torch.as_tensor(face) for face in face_list])
      faces = faces.to(torch.device('cuda')).float()/255
      feat = model(faces)
  return feat

In [None]:
# Smoke test: run the classifier head on the features of a single image.
model = ConvClassifier(input_channels=1792).to('cuda')
# Inference mode: without eval() the Dropout layer stays active and the
# printed scores change from run to run.
model.eval()
features = get_image_features("a.png")
# No gradients are needed just to print a prediction.
with torch.no_grad():
    output = model(features)
print(output)

In [None]:
from utils.funcs import IoUfrom2bboxes, crop_face, RandomDownScale
from torchmetrics.classification import BinaryAUROC
from torch.utils.data import Dataset
from torch.nn import functional as F
from datetime import datetime
import albumentations as alb
from model import Detector
from PIL import Image
import numpy as np
import torch
import json
import os


class SBIFeatureExtractor(torch.nn.Module):
    """Feature extractor built from the backbone of a pretrained SBI ``Detector``.

    The ``Detector`` checkpoint is loaded and its ``net`` backbone is kept as a
    registered submodule, so ``.eval()``, ``.to()`` and ``state_dict()`` on this
    wrapper reach the backbone. Storing only the bound ``extract_features``
    method (as before) left the backbone in training mode even after
    ``.eval()`` was called on the wrapper.

    Args:
        weights_path: Path of the checkpoint; its ``"model"`` entry is loaded
            into the ``Detector``.
        device: Device the checkpoint is mapped to and the backbone is moved to.
    """

    def __init__(self, weights_path="drive/MyDrive/FFraw.tar", device='cuda'):
        super(SBIFeatureExtractor, self).__init__()
        sbi_model = Detector()
        # map_location lets a checkpoint saved on another device load cleanly.
        cnn_sd = torch.load(weights_path, map_location=device)["model"]
        sbi_model.load_state_dict(cnn_sd)
        # Assigning an nn.Module attribute registers it as a child module.
        self.net = sbi_model.net.to(device)

    @property
    def features(self):
        """Bound ``extract_features`` of the backbone (kept for compatibility)."""
        return self.net.extract_features

    def forward(self, x):
        """Return ``self.net.extract_features(x)``."""
        return self.net.extract_features(x)


class FeatureDataset(Dataset):
    """Dataset yielding ``(features, label)`` pairs from face frames.

    Each item loads a frame, its landmarks and RetinaFace boxes, optionally
    augments it, crops and resizes the face, and runs the result through
    ``SBIFeatureExtractor``.

    Args:
        images: Sequence of frame paths. The landmark file is found by replacing
            ``'.png'`` with ``'.npy'`` and ``'/frames/'`` with ``path_lm``; the
            box file by replacing ``'/frames/'`` with ``'/retina/'``.
        labels: Sequence of labels aligned with ``images``.
        is_val: When true, the random flip and colour augmentation are skipped
            and ``crop_face`` is called with ``phase="val"``.
        image_size: ``dsize`` passed to ``cv2.resize``. The default (380, 380)
            is an assumption (it matches the upstream SBI inference crop size);
            the original code read an attribute that was never set. TODO confirm.
        path_lm: Directory fragment substituted for ``'/frames/'`` to locate
            landmark files. ``'/landmarks/'`` is assumed from the upstream SBI
            layout. TODO confirm.

    NOTE(review): the extractor lives on CUDA and is called in ``__getitem__``;
    using this dataset with DataLoader worker processes may fail because CUDA
    cannot be re-initialised in forked workers. Confirm the loader settings.
    """

    def __init__(self, images, labels, is_val=False, image_size=(380, 380),
                 path_lm='/landmarks/'):
        self.images = images
        self.labels = labels
        # These three attributes are read in __getitem__ but were never set.
        self.is_val = is_val
        self.image_size = image_size
        self.path_lm = path_lm
        self.extract_features = SBIFeatureExtractor()
        self.extract_features.eval()
        self.transforms = self.get_transforms()
        self.source_transforms = self.get_source_transforms()

    def __len__(self):
        """Return the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx``.

        If loading or preprocessing raises, the error is printed and a random
        index is tried instead; the returned label belongs to the index that
        finally succeeded.
        """
        while True:
            try:
                img_r = self._load_sample(idx)
                break
            except Exception as e:
                print(e)
                idx = torch.randint(low=0, high=len(self), size=(1,)).item()
        with torch.no_grad():
            # The backbone needs a batched CUDA tensor, not a NumPy array.
            batch = torch.from_numpy(img_r).unsqueeze(0).to('cuda')
            features = self.extract_features(batch)
        # Drop the batch axis and return a CPU tensor so default collation and
        # pin_memory in the DataLoader can handle it.
        return features.squeeze(0).cpu(), self.labels[idx]

    def _load_sample(self, idx):
        """Load, augment, crop and resize one frame to a float32 CHW array in [0, 1]."""
        filename = self.images[idx]
        img = np.array(Image.open(filename))
        landmark = np.load(filename.replace('.png', '.npy').replace('/frames/', self.path_lm))[0]
        bbox_lm = np.array([landmark[:, 0].min(), landmark[:, 1].min(),
                            landmark[:, 0].max(), landmark[:, 1].max()])
        bboxes = np.load(filename.replace('.png', '.npy').replace('/frames/', '/retina/'))[:2]
        # Keep the detection that overlaps the landmark box the most
        # (the first one wins on ties).
        bbox = max(bboxes, key=lambda candidate: IoUfrom2bboxes(bbox_lm, candidate.flatten()))
        landmark = self.reorder_landmark(landmark)
        if not self.is_val:
            if np.random.rand() < 0.5:
                img, _, landmark, bbox = self.hflip(img, None, landmark, bbox)
        img, landmark, bbox, __ = crop_face(img, landmark, bbox, margin=True, crop_by_bbox=False)
        # self_blending returns a single image (no mask).
        img_r = self.self_blending(img.copy(), landmark.copy())
        if not self.is_val:
            # Albumentations returns a dict; the image is under the 'image' key.
            img_r = self.transforms(image=img_r.astype('uint8'))['image']
        img_f, _, __, ___, y0_new, y1_new, x0_new, x1_new = crop_face(
            img_r, landmark, bbox, margin=False, crop_by_bbox=True, abs_coord=True,
            phase="val" if self.is_val else "train")
        img_r = img_r[y0_new:y1_new, x0_new:x1_new]
        img_r = cv2.resize(img_r, self.image_size, interpolation=cv2.INTER_LINEAR).astype('float32') / 255
        # HWC -> CHW
        return img_r.transpose((2, 0, 1))

    def reorder_landmark(self, landmark):
        """Permute rows 68.. of ``landmark`` in place and return it.

        New rows 68..80 are taken from old rows
        77, 75, 76, 68, 69, 70, 71, 80, 72, 73, 79, 74, 78 (in that order).
        """
        landmark_add = np.zeros((13, 2))
        for idx, idx_l in enumerate([77, 75, 76, 68, 69, 70, 71, 80, 72, 73, 79, 74, 78]):
            landmark_add[idx] = landmark[idx_l]
        landmark[68:] = landmark_add
        return landmark

    def hflip(self, img, mask=None, landmark=None, bbox=None):
        """Horizontally flip an image with its optional mask, landmarks and box.

        Landmark rows are re-ordered group by group and their x coordinates are
        mirrored as ``W - x``; the inputs are not modified. ``landmark`` must
        have 68 or 81 rows when given.

        Returns:
            ``(img, mask, landmark_new, bbox_new)``; entries whose input was
            ``None`` are returned as ``None``.

        Raises:
            NotImplementedError: If ``landmark`` has neither 68 nor 81 rows.
        """
        W = img.shape[1]

        if landmark is not None:
            landmark_new = np.zeros_like(landmark)
            landmark_new[:17] = landmark[:17][::-1]
            landmark_new[17:27] = landmark[17:27][::-1]
            landmark_new[27:31] = landmark[27:31]
            landmark_new[31:36] = landmark[31:36][::-1]
            landmark_new[36:40] = landmark[42:46][::-1]
            landmark_new[40:42] = landmark[46:48][::-1]
            landmark_new[42:46] = landmark[36:40][::-1]
            landmark_new[46:48] = landmark[40:42][::-1]
            landmark_new[48:55] = landmark[48:55][::-1]
            landmark_new[55:60] = landmark[55:60][::-1]
            landmark_new[60:65] = landmark[60:65][::-1]
            landmark_new[65:68] = landmark[65:68][::-1]
            if len(landmark) == 81:
                landmark_new[68:81] = landmark[68:81][::-1]
            elif len(landmark) != 68:
                raise NotImplementedError
            landmark_new[:, 0] = W - landmark_new[:, 0]
        else:
            landmark_new = None

        if bbox is not None:
            bbox_new = np.zeros_like(bbox)
            # Swap the x of the two corners, then mirror; y is unchanged.
            bbox_new[0, 0] = bbox[1, 0]
            bbox_new[1, 0] = bbox[0, 0]
            bbox_new[:, 0] = W - bbox_new[:, 0]
            bbox_new[:, 1] = bbox[:, 1].copy()
            if len(bbox) > 2:
                # Extra rows: 2<->3 and 5<->6 swap, row 4 only mirrors.
                bbox_new[2, 0] = W - bbox[3, 0]
                bbox_new[2, 1] = bbox[3, 1]
                bbox_new[3, 0] = W - bbox[2, 0]
                bbox_new[3, 1] = bbox[2, 1]
                bbox_new[4, 0] = W - bbox[4, 0]
                bbox_new[4, 1] = bbox[4, 1]
                bbox_new[5, 0] = W - bbox[6, 0]
                bbox_new[5, 1] = bbox[6, 1]
                bbox_new[6, 0] = W - bbox[5, 0]
                bbox_new[6, 1] = bbox[5, 1]
        else:
            bbox_new = None

        if mask is not None:
            mask = mask[:, ::-1]

        img = img[:, ::-1].copy()
        return img, mask, landmark_new, bbox_new

    def self_blending(self, img, landmark):
        """Return ``img`` as uint8, passed through ``source_transforms`` with probability 0.5.

        ``landmark`` is accepted for interface compatibility but is not used:
        no blending mask is produced here.
        """
        if np.random.rand() < 0.5:
            img = self.source_transforms(image=img.astype(np.uint8))['image']
        img = img.astype(np.uint8)
        return img

    def get_transforms(self):
        """Build the colour/compression augmentation applied to training images."""
        return alb.Compose([
            alb.RGBShift((-20, 20), (-20, 20), (-20, 20), p=0.3),
            alb.HueSaturationValue(hue_shift_limit=(-0.3, 0.3), sat_shift_limit=(-0.3, 0.3), val_shift_limit=(-0.3, 0.3), p=0.3),
            alb.RandomBrightnessContrast(brightness_limit=(-0.3, 0.3), contrast_limit=(-0.3, 0.3), p=0.3),
            alb.ImageCompression(quality_lower=40, quality_upper=100, p=0.5),
        ], p=1.)

    def get_source_transforms(self):
        """Build the augmentation used by ``self_blending``.

        A colour-jitter group followed by one of ``RandomDownScale`` or
        ``Sharpen``.
        """
        return alb.Compose([
            alb.Compose([
                alb.RGBShift((-20, 20), (-20, 20), (-20, 20), p=0.3),
                alb.HueSaturationValue(hue_shift_limit=(-0.3, 0.3), sat_shift_limit=(-0.3, 0.3), val_shift_limit=(-0.3, 0.3), p=1),
                alb.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=1),
            ], p=1),
            alb.OneOf([
                RandomDownScale(p=1),
                alb.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1),
            ], p=1),
        ], p=1.)


class ConvClassifier(torch.nn.Module):
    """Convolutional head that turns a feature map into sigmoid scores.

    Pipeline: 3x3 convolution to 512 channels, ReLU, global average pooling,
    flatten, dropout (p=0.5), linear projection to ``num_classes`` and a final
    sigmoid, so every output lies in (0, 1).

    Args:
        input_channels: Number of channels of the incoming feature map.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_channels, num_classes=1):
        super().__init__()
        nn = torch.nn
        hidden_channels = 512
        # Layer order is kept so that state_dict keys (classifier.0 / classifier.5)
        # stay the same.
        layers = [
            nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, num_classes),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the classifier stack to ``x`` and return the sigmoid scores."""
        head = self.classifier
        return head(x)


# ---------------------------------------------------------------------------
# Training script: fit ConvClassifier on features produced by FeatureDataset,
# track loss / accuracy / AUROC per epoch and keep the checkpoints with the
# best validation AUROC.
# ---------------------------------------------------------------------------
device = 'cuda'
# Maximum number of checkpoints kept on disk (the original condition was "< 6").
max_saved_weights = 6

with open("configs/sbi/base.json", "rt") as f:
    cfg = json.load(f)

# NOTE(review): train_images / train_labels / val_images / val_labels are not
# defined in the visible cells; they must exist before this cell runs.
train_dataset = FeatureDataset(train_images, train_labels)
val_dataset = FeatureDataset(val_images, val_labels, is_val=True)
# NOTE(review): FeatureDataset calls a CUDA model in __getitem__; worker
# processes (num_workers > 0) may not be able to use CUDA when forked. Confirm.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=cfg['batch_size'] // 2,
                                           shuffle=True,
                                           num_workers=8,
                                           pin_memory=True,
                                           drop_last=True,
                                           persistent_workers=True
)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=cfg['batch_size'],
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True,
                                         persistent_workers=True
)
model = ConvClassifier(input_channels=1792).to(device)
# ConvClassifier already ends in Sigmoid, so its outputs are probabilities:
# plain BCELoss is the matching criterion (BCEWithLogitsLoss would apply a
# second sigmoid).
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75)
train_auc_metric = BinaryAUROC(thresholds=None).to(device)
val_auc_metric = BinaryAUROC(thresholds=None).to(device)
now = datetime.now()
save_path = 'output/{}_'.format("SBI_FE_")+'base'+'_'+now.strftime("%m_%d_%H_%M_%S")+'/'
weights_dir = os.path.join(save_path, 'weights')
# torch.save does not create missing directories.
os.makedirs(weights_dir, exist_ok=True)
weight_dict = {}  # checkpoint path -> validation AUROC
last_val_auc = float('-inf')  # lowest AUROC among the kept checkpoints

for epoch in range(cfg["epoch"]):
    # Single quotes inside the f-string keep this valid before Python 3.12.
    print(f"Epoch {epoch + 1}/{cfg['epoch']}")
    print("-" * 30)
    model.train()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    for features, labels in train_loader:
        features = features.to(device)
        # Flatten to (B,) floats so loss and accuracy compare element-wise
        # instead of broadcasting (B, 1) against (B,).
        labels = labels.to(device).float().view(-1)

        # Forward pass
        outputs = model(features).view(-1)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track loss and accuracy
        running_loss += loss.item() * features.size(0)  # Accumulate loss
        preds = (outputs > 0.5).float()  # Threshold on probabilities
        correct_preds += (preds == labels).sum().item()
        total_preds += labels.size(0)
        train_auc_metric.update(outputs.detach(), labels.int())

    # drop_last=True means fewer samples than len(dataset) may be seen, so
    # average over the samples actually processed.
    train_loss = running_loss / total_preds
    train_accuracy = correct_preds / total_preds
    train_auc = train_auc_metric.compute().item()
    train_auc_metric.reset()
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Train AUC: {train_auc:.4f}")

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct_preds = 0
    val_total_preds = 0

    with torch.no_grad():
        for features, labels in val_loader:
            features = features.to(device)
            labels = labels.to(device).float().view(-1)

            # Forward pass
            outputs = model(features).view(-1)
            loss = criterion(outputs, labels)

            # Track loss and accuracy
            val_loss += loss.item() * features.size(0)  # Accumulate loss
            preds = (outputs > 0.5).float()
            val_correct_preds += (preds == labels).sum().item()
            val_total_preds += labels.size(0)
            val_auc_metric.update(outputs, labels.int())

    val_loss = val_loss / val_total_preds
    val_accuracy = val_correct_preds / val_total_preds
    val_auc = val_auc_metric.compute().item()
    val_auc_metric.reset()
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, Validation AUC: {val_auc:.4f}")

    scheduler.step()

    # Keep at most `max_saved_weights` checkpoints: save unconditionally until
    # full, afterwards replace the worst one when the new AUROC is at least as good.
    has_room = len(weight_dict) < max_saved_weights
    if has_room or val_auc >= last_val_auc:
        if not has_room:
            worst_path = min(weight_dict, key=weight_dict.get)
            del weight_dict[worst_path]
            os.remove(worst_path)
        save_model_path = os.path.join(weights_dir, "{}_{:.4f}_val.tar".format(epoch + 1, val_auc))
        weight_dict[save_model_path] = val_auc
        torch.save({
            "model": model.state_dict(),
            # The optimizer is the module-level object, not an attribute of the model.
            "optimizer": optimizer.state_dict(),
            "epoch": epoch
        }, save_model_path)
        last_val_auc = min(weight_dict.values())
    print()

print("Training Complete!")