## Set Up Libraries and Config

In [None]:
import os
import glob
import random
from tqdm import tqdm
import cv2
import sklearn.metrics
import pandas as pd
import re
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from transformers import get_linear_schedule_with_warmup

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm

import gc

import albumentations as A

import warnings
warnings.filterwarnings("ignore")

In [None]:
# timm model names whose classification head is exposed as `.classifier`
# (checked when the custom head is attached after model creation).
eff_names = [
    'efficientnet_b1',
    'efficientnet_b3',
    'efficientnet_b5',
]

# Transformer-style timm model names.
# Every entry needs its own trailing comma: adjacent string literals are
# silently concatenated into a single (non-existent) model name otherwise.
vit_names = [
    'vit_base_patch16_224',
    'swin_base_patch4_window7_224',
    'swin_large_patch4_window7_224',
]


class CFG:
    """Experiment-wide settings read by the data, model and training cells."""

    # Reproducibility
    seed = 42

    # Model: timm identifier and the square side length images are resized to
    model_name = 'vit_base_patch16_224'
    input_size = 224

    # Optimisation
    lr = 5e-5
    batch_size = 32
    num_epochs = 20

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators with `seed`.

    The seed is also exported as the ``PYTHONHASHSEED`` environment variable.
    When CUDA is available, ``torch.cuda.manual_seed`` is called as well and
    cuDNN is switched to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Seed every random generator once, using the seed from the config class.
seed_everything(CFG.seed)

## Dataset and Dataloader

In [None]:
# Gather the metadata CSVs of the pre-extracted face datasets and order them
# numerically by the digits that appear in their paths.
meta = glob.glob('../input/deepfake-detection-faces-*/*.csv')
# Raw string: '\D' in a plain literal is an invalid escape sequence.
meta.sort(key=lambda f: int(re.sub(r'\D', '', f)))

dfs = []
for path in meta:
    df = pd.read_csv(path)
    # Each row's folder sits next to its CSV and is named after `filename`
    # with the last four characters (the extension) removed.
    folder = os.path.dirname(path)
    df['path'] = [f'{folder}/{name[:-4]}' for name in df['filename']]
    dfs.append(df)

train_df = pd.concat(dfs)
train_df = train_df.reset_index(drop=True)
len(train_df)

In [None]:
# Extend the training frame with DFDC parts 16..39.
# A video is kept only if its frame folder exists and holds at least 10 files.
part = 16
kept_frames = [train_df]
for part_id in range(part, 39 + 1):
    image_root = f'../input/dfdc-part-{part_id}/images'
    if part_id != 17:
        meta = pd.read_csv(f'{image_root}/metadata{part_id}.csv')
    else:
        # Part 17's metadata file carries a .json extension but is parsed with
        # read_csv, using its first column as the index.
        meta = pd.read_csv(f'{image_root}/metadata{part_id}.json', index_col=0)

    # Folder name = video filename without its 4-character extension.
    frame_dirs = [f'{image_root}/{name[:-4]}' for name in meta['filename']]
    keep = [os.path.isdir(d) and len(os.listdir(d)) >= 10 for d in frame_dirs]

    meta['path'] = frame_dirs
    # One boolean selection instead of dropping rows one at a time.
    kept_frames.append(meta.loc[keep])

# Concatenate once rather than re-copying the growing frame on every part.
train_df = pd.concat(kept_frames).reset_index(drop=True)
len(train_df)

In [None]:
# Build the validation frame from DFDC parts 40..49.
# A video is kept only if its frame folder exists and holds at least 10 files.
dfs = []
part = 40
for part_id in range(part, part + 10):
    image_root = f'../input/dfdc-part-{part_id}/images'
    # The part-17 special case of the training cell can never trigger for
    # parts 40..49, so every part is read from its CSV.
    meta = pd.read_csv(f'{image_root}/metadata{part_id}.csv')

    # Folder name = video filename without its 4-character extension.
    frame_dirs = [f'{image_root}/{name[:-4]}' for name in meta['filename']]
    keep = [os.path.isdir(d) and len(os.listdir(d)) >= 10 for d in frame_dirs]

    meta['path'] = frame_dirs
    # One boolean selection instead of dropping rows one at a time.
    dfs.append(meta.loc[keep])

val_df = pd.concat(dfs)
val_df = val_df.reset_index(drop=True)
print(len(val_df))

In [None]:
# Short aliases used by the cells below (same objects, no copy is made).
tr_df, te_df = train_df, val_df

In [None]:
# Split the training rows by label: fakes first, then reals.
train_f, train_r = (
    tr_df.loc[tr_df['label'] == tag] for tag in ('FAKE', 'REAL')
)

In [None]:
class ImageDataset(Dataset):
    """Dataset that yields one randomly chosen frame per video row.

    Args:
        df: DataFrame with a ``path`` column (folder holding the video's
            frame images) and a ``label`` column (``'FAKE'`` maps to 1.0,
            anything else to 0.0).
        transform: optional callable invoked as ``transform(image=img)`` that
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``[image, label]`` for the video at position `idx`.

        The image is channel-first after resizing to ``CFG.input_size``; the
        label is a 0-d ``float32`` array.

        Raises:
            ValueError: if the folder has no usable frame or the chosen file
                cannot be decoded.
        """
        video = self.df.iloc[idx]

        # Skip files whose *name* contains an underscore. Only the basename is
        # tested so an underscore in a parent directory cannot discard every
        # frame of the video.
        imgs = [
            im for im in glob.glob(f'{video["path"]}/*')
            if '_' not in os.path.basename(im)
        ]
        if not imgs:
            raise ValueError(f'no usable frames found in {video["path"]!r}')

        img_path = random.choice(imgs)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a failed read by returning None, not by raising.
            raise ValueError(f'could not read image {img_path!r}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (CFG.input_size, CFG.input_size))

        if self.transform is not None:
            img = self.transform(image=img)['image']

        # Channels-last -> channels-first.
        img = np.moveaxis(img, -1, 0)

        label = np.array(1 if video['label'] == 'FAKE' else 0).astype(np.float32)
        return [img, label]

## Model

In [None]:
class Classifier(nn.Module):
    """Two-layer head: ``in_f`` -> 512 -> ReLU -> ``out_f``."""

    def __init__(self, in_f, out_f):
        super().__init__()

        # Attribute names are part of the state-dict keys; keep them stable.
        self.linear1 = nn.Linear(in_f, 512)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(512, out_f)

    def forward(self, x):
        """Apply linear1, ReLU and linear2 in sequence and return the result."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

In [None]:
# Pretrained backbone configured by CFG, created with a single output unit.
model = timm.create_model(CFG.model_name, pretrained=True, num_classes=1)

# Models listed in eff_names get the custom two-layer head in place of the
# classifier timm created, reusing that classifier's input width.
uses_custom_head = CFG.model_name in eff_names
if uses_custom_head:
    model.classifier = Classifier(model.classifier.in_features, 1)

print(model)

## Loss, Training and Evaluation Functions

In [None]:
def criterion(pred, targets):
    """Binary cross-entropy between raw logits and 0/1 targets.

    Args:
        pred: logits of any shape; flattened to 1-D before the loss.
        targets: float tensor with one entry per flattened logit.

    Returns:
        Scalar tensor holding the mean loss.
    """
    # The fused logits version avoids the saturation of sigmoid followed by
    # log for large-magnitude logits.
    return F.binary_cross_entropy_with_logits(pred.reshape(-1), targets)

In [None]:
def _epoch_metrics(probs, targets):
    """Return ``(accuracy, log_loss)`` for flat probability / target arrays.

    Accuracy uses ``np.round`` on the probabilities; the log-loss is computed
    on probabilities clipped to [0.01, 0.99].
    """
    probs = np.asarray(probs).reshape(-1)
    targets = np.asarray(targets).reshape(-1).astype(int)
    acc = sklearn.metrics.accuracy_score(targets, np.round(probs).astype(int))
    # Explicit labels keep log_loss defined when only one class is present.
    kaggle = sklearn.metrics.log_loss(targets, probs.clip(0.01, 0.99), labels=[0, 1])
    return acc, kaggle


def train_model(epoch, optimizer, scheduler=None, history=None):
    """Run one training pass over the global `train_loader`.

    Uses the module-level `model`, `train_loader` and `criterion`.

    Args:
        epoch: epoch index; not used inside the function.
        optimizer: optimizer stepped once per batch.
        scheduler: accepted for signature compatibility; not used here.
        history: optional list that receives the mean epoch loss as a float.
    """
    model.train()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    loss_sum = 0.0
    probs, targets = [], []

    for img_batch, y_batch in tqdm(train_loader):
        img_batch = img_batch.to(device).float()
        y_batch = y_batch.to(device).float()

        optimizer.zero_grad()

        out = model(img_batch)
        loss = criterion(out, y_batch)

        loss.backward()
        optimizer.step()

        # Store plain numbers/arrays so no autograd graph is kept alive
        # across the epoch.
        loss_sum += loss.item()
        probs.append(torch.sigmoid(out.detach()).reshape(-1).cpu().numpy())
        targets.append(y_batch.cpu().numpy())

    mean_loss = loss_sum / len(train_loader)
    acc, kaggle = _epoch_metrics(np.concatenate(probs), np.concatenate(targets))

    if history is not None:
        history.append(mean_loss)

    print('Train loss: %.4f, Accuracy: %.5f, LogLoss: %.6f' % (mean_loss, acc, kaggle))


def evaluate_model(epoch, scheduler=None, history=None):
    """Evaluate the global `model` on the global `val_loader`.

    Args:
        epoch: epoch index; not used inside the function.
        scheduler: accepted for signature compatibility; not used here.
        history: optional list that receives the mean validation loss as a float.

    Returns:
        The clipped log-loss over the whole validation pass.
    """
    model.eval()
    device = next(model.parameters()).device
    loss_sum = 0.0
    probs, targets = [], []

    with torch.no_grad():
        for img_batch, y_batch in tqdm(val_loader):
            img_batch = img_batch.to(device).float()
            y_batch = y_batch.to(device).float()

            out = model(img_batch)
            loss_sum += criterion(out, y_batch).item()

            probs.append(torch.sigmoid(out).reshape(-1).cpu().numpy())
            targets.append(y_batch.cpu().numpy())

    mean_loss = loss_sum / len(val_loader)
    acc, kaggle = _epoch_metrics(np.concatenate(probs), np.concatenate(targets))

    if history is not None:
        history.append(mean_loss)

    print('Valid loss: %.4f, Accuracy: %.5f, LogLoss: %.6f' % (mean_loss, acc, kaggle))

    return kaggle

## Augmentation

In [None]:
# Training-time augmentation pipeline; transforms are listed in application order.
train_transform = A.Compose([
    # Image-compression artifacts at quality 60-100, applied with probability 0.5.
    A.ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
    # Gaussian noise with probability 0.1.
    A.GaussNoise(p=0.1),
    # Gaussian blur (blur_limit=3) with probability 0.05.
    A.GaussianBlur(blur_limit=3, p=0.05),
    # Horizontal flip with the library's default probability.
    A.HorizontalFlip(),
    # Always normalize, using the library's default mean/std.
    A.Normalize(always_apply=True)
])

# Validation applies only the normalization step, so its inputs are scaled
# the same way as the training inputs.
val_transform = A.Compose([
    A.Normalize(always_apply=True)
])
# Validation dataset over the held-out frame `te_df`.
val_dataset = ImageDataset(te_df, transform=val_transform)

In [None]:
# Balance the classes: draw as many fakes as there are reals, then shuffle.
fake_batch = train_f.sample(len(train_r)).reset_index(drop=True)
train_ = pd.concat([fake_batch, train_r])
train_ = train_.sample(frac=1).reset_index(drop=True)
print(train_['label'].value_counts())

train_dataset = ImageDataset(train_, transform=train_transform)

# Preview grid of augmented training samples.
nrow, ncol = 3, 5
fig, axes = plt.subplots(nrow, ncol, figsize=(20, 8))
axes = axes.flatten()
for i, ax in enumerate(axes):
    if i >= len(train_dataset):
        break  # fewer samples than grid cells
    images, label = train_dataset[i]
    image = np.rollaxis(images, 0, 3).astype(np.float32)
    # The dataset returns normalized floats; imshow clips float RGB input to
    # [0, 1], so rescale each image to that range for display only.
    lo, hi = image.min(), image.max()
    image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
    ax.imshow(image)
    ax.set_title(f'label: {label}')
plt.tight_layout()

In [None]:
# del val_loader, model, train_loader
# import gc
# gc.collect()
# torch.cuda.empty_cache()

## Training Loop

In [None]:
# Free cached GPU memory and collect garbage before training starts.
torch.cuda.empty_cache()
gc.collect()

train_history, val_history = [], []

# Lowest validation log-loss seen so far (initial threshold: 10).
best = 10

val_loader = DataLoader(dataset=val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=0)

model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

for epoch in range(CFG.num_epochs):
    torch.cuda.empty_cache()
    gc.collect()

    # Each epoch draws a fresh random set of fakes, as many as there are
    # reals, and shuffles the combined frame.
    fake_batch = train_f.sample(len(train_r)).reset_index(drop=True)
    train_ = pd.concat([fake_batch, train_r]).sample(frac=1).reset_index(drop=True)

    train_dataset = ImageDataset(train_, transform=train_transform)
    train_loader = DataLoader(dataset=train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=4)

    print(f'Epoch {epoch + 1}')

    train_model(epoch, optimizer, scheduler=None, history=train_history)

    loss = evaluate_model(epoch, scheduler=None, history=val_history)

    # Checkpoint whenever the validation log-loss improves.
    if loss < best:
        best = loss
        print('Saving best model...')
        torch.save(model.state_dict(), f'model_{epoch+1}.pth')

## Plot

In [None]:
import matplotlib.pyplot as plt

# Build each epoch axis from the number of recorded losses rather than from
# CFG.num_epochs, so the plot still works when training stopped early.
data_train = {'epoch': np.arange(len(train_history)) + 1,
              'loss': [float(v) for v in train_history]}

data_valid = {'epoch': np.arange(len(val_history)) + 1,
              'loss': [float(v) for v in val_history]}

df_train = pd.DataFrame(data_train)
df_valid = pd.DataFrame(data_valid)

# Plotting
plt.figure(figsize=(10, 6))

plt.plot(df_train['epoch'], df_train['loss'], label='Train Loss', marker='o')
plt.plot(df_valid['epoch'], df_valid['loss'], label='Validation Loss', marker='o')

plt.title('Training and Validation Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()