In [1]:
# data_dir = "/scratch-shared/TANJ0303/datasets/"

In [2]:
# %env HF_DATASETS_CACHE={data_dir}  # (`!export` would only set it in a throwaway subshell, and would not expand `data_dir`)

In [3]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.5-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.5.5-py3-none-any.whl (832 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m832.4/832.4 kB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m59.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.15.2 pytorch_lightning-2.5.5 torchmetrics-1.8.2


In [4]:
import sys
import os

def is_colab_env():
    """Return True when the kernel appears to be running inside Google Colab.

    Detection is purely by module registration: the result is True iff a
    module named ``google.colab`` is present in ``sys.modules`` (presumably
    pre-loaded by the Colab runtime -- not verified here).
    """
    colab_module_name = "google.colab"
    return colab_module_name in sys.modules

def mount_google_drive(drive_dir="/content/drive/", repo_dir="MyDrive/repositories/deepfake-detection"):
    """Mount Google Drive and make the repository on it the working directory.

    Args:
        drive_dir: mount point passed to ``google.colab.drive.mount``.
        repo_dir: repository location relative to ``drive_dir``.

    Returns:
        The repository path that was ``chdir``-ed into.

    Raises:
        ImportError: if ``google.colab`` cannot be imported (i.e. outside Colab).
        FileNotFoundError: if the repository directory does not exist.
    """
    # mount google drive (google.colab is only importable inside Colab)
    from google.colab import drive
    drive.mount(drive_dir)

    # os.path.join works whether or not drive_dir ends with "/"; plain string
    # concatenation would silently produce e.g. "/content/driveMyDrive/...".
    repo_path = os.path.join(drive_dir, repo_dir)
    os.chdir(repo_path)
    print(os.listdir())  # verify content
    return repo_path

def resolve_path(levels_deep=3):
    """Locate the project root directory and return it.

    On Colab this delegates to ``mount_google_drive`` (which mounts Drive and
    ``chdir``s into the repository). Elsewhere it climbs ``levels_deep``
    parent directories from the current working directory, appends the result
    to ``sys.path`` (only if not already present, so re-running the cell does
    not accumulate duplicates), and prints ``sys.path``.

    Args:
        levels_deep: number of parent directories to climb from the current
            working directory (non-Colab only).

    Returns:
        The resolved project directory as a string.
    """
    if is_colab_env():
        return mount_google_drive()

    # A notebook kernel has no ``__file__``, so start from the working directory.
    current_dir = os.getcwd()

    # Construct the path to the ancestor directory
    for _ in range(levels_deep):
        current_dir = os.path.dirname(current_dir)

    # Add the project directory to sys.path once; keeps the cell idempotent.
    if current_dir not in sys.path:
        sys.path.append(current_dir)
    print(sys.path)
    return current_dir

# Project root: the Drive repository on Colab, otherwise three levels above the cwd
proj_dir = resolve_path(levels_deep=3)

Mounted at /content/drive/
['src', '.git', 'playground', 'setup', 'README.md', 'hf_wdf.sh', 'freqnet_image.ipynb', 'reports', '__pycache__', 'faceforensics_download.py', 'environment.yml', 'run_jupyter.sh', 'config.py', 'dct_mean_real_fake.png', 'analysis', '.gitignore']


In [5]:
# import local config
import config

In [6]:
# import library dependencies
import numpy as np

In [7]:
# pytorch
import torch
import pytorch_lightning as L

In [8]:
# import local dependencies
from src.adapters.datasets.sida import SidADataModule
from src.models.resnet import ResNetClassifier

In [9]:
model_id = "rgb_resnet18"
# Root directory for this model's checkpoints/logs: <proj_dir>/<CHECKPOINTS_DIR>/<model_id>
model_checkpoint_dir = "{}/{}/{}".format(proj_dir, config.CHECKPOINTS_DIR, model_id)

In [10]:
from src.transforms.frequency import get_transforms
# RGB-mode preprocessing, handed to the data module as `transforms`
rgb_transforms = get_transforms(
    "rgb",
)

In [11]:
seed = config.SEED

# Set seeds for reproducibility. Per the Lightning docs, seed_everything seeds
# Python's `random`, NumPy and torch in one call. Manual torch/NumPy seeding
# alone would leave `random` unseeded. `workers=True` also derives distinct,
# reproducible seeds for DataLoader worker processes (the data module below
# uses num_workers > 0).
L.seed_everything(seed, workers=True)

# Determine device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [12]:
batch_size = 32
num_workers = 2

# Data module for the SidA dataset, sharing the notebook seed and RGB transforms.
# Optionally add `dataset_cache_dir=data_dir` (see the commented-out `data_dir`
# at the top), which presumably relocates the dataset cache.
sida_data_module = SidADataModule(
    transforms=rgb_transforms,
    seed=seed,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [13]:
# Early stopping on validation loss (stop after 3 epochs without a decrease).
# NOTE(review): this only acts during `trainer.fit`; the visible cells only run
# `trainer.test`, so it is inert here.
early_stop_callback = L.callbacks.EarlyStopping(
    monitor="val_loss", mode="min", patience=3, verbose=True
)

In [14]:
# Lightning checkpoint callback: keep only the single checkpoint with the lowest val_loss
best_loss_checkpoint = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    mode="min",
    monitor="val_loss",
)

In [15]:
# Restore the trained classifier from a fixed checkpoint file.
checkpoint = "lightning_logs/version_0/checkpoints/epoch=10-step=278960.ckpt"
# NOTE(review): `[1:]` strips the leading "/" so the path is resolved relative to
# the current working directory, unlike the trainer's absolute `default_root_dir`.
# The saved output shows this path loads, so the checkpoint presumably lives under
# a nested relative directory -- confirm before "fixing" it.
deepfake_detector = ResNetClassifier.load_from_checkpoint(
    "/".join((model_checkpoint_dir[1:], checkpoint))
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 169MB/s]


In [16]:
# Single-device trainer. max_epochs and the early-stop/checkpoint callbacks only
# matter for fitting; they do not affect `trainer.test`.
max_epochs = 100
trainer = L.Trainer(
    max_epochs=max_epochs,
    devices=1,
    default_root_dir=model_checkpoint_dir,
    callbacks=[early_stop_callback, best_loss_checkpoint],
    profiler="simple",  # report time spent per action
    log_every_n_steps=100,
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [17]:
# Evaluate on the unseen test split; the returned list of metric dicts is the cell output.
# TODO(review): the saved run reports test_auc ~0.49 (chance level) alongside
# test_f1 ~0.68 -- possibly a score/label orientation problem; investigate.
trainer.test(model=deepfake_detector, datamodule=sida_data_module)

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/249 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/249 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.profilers.profiler:TEST Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                       

[{'test_acc': 0.5619333386421204,
  'test_prec': 0.6647290587425232,
  'test_rec': 0.6918500065803528,
  'test_f1': 0.6780184507369995,
  'test_auc': 0.48592329025268555,
  'test_loss': 0.9904940724372864}]