In [1]:
# Installing extra dependencies into drive
!pip install lightning

Collecting lightning
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting torchmetrics<3.0,>0.7.0 (from lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.5-py3-none-any.whl.metadata (20 kB)
Downloading lightning-2.5.5-py3-none-any.whl (828 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m828.5/828.5 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pytorch_lightning-2.5.5-py3-none-any.whl (832 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import sys
import os

def is_colab_env():
    """Return True when the `google.colab` module is already loaded in this interpreter."""
    loaded_modules = sys.modules
    return "google.colab" in loaded_modules

def mount_google_drive(drive_dir="/content/drive/", repo_dir="MyDrive/repositories/deepfake-detection"):
    """Mount Google Drive and make the repository checkout the working directory.

    Args:
        drive_dir: Mount point handed to `google.colab.drive.mount`.
        repo_dir: Repository location, relative to `drive_dir`.

    Raises:
        ImportError: if `google.colab` cannot be imported.
        OSError: if the joined repository path cannot be entered or listed.
    """
    # Imported lazily so the function can be defined in environments where
    # `google.colab` is not installed.
    from google.colab import drive
    drive.mount(drive_dir)

    # Join with os.path.join so the path is valid whether or not `drive_dir`
    # ends with a separator; stripping leading separators keeps `repo_dir`
    # relative to the mount point.
    repo_path = os.path.join(drive_dir, repo_dir.lstrip("/"))
    os.chdir(repo_path)
    print(os.listdir())  # verify content

def resolve_path(levels_deep=3):
    """Prepare the working directory / import path for this notebook.

    When `is_colab_env()` is true, `mount_google_drive()` is called with its
    defaults (mounts Drive and changes into the repository directory).
    Otherwise the directory `levels_deep` levels above the current working
    directory is added to `sys.path`, and `sys.path` is printed.

    Args:
        levels_deep: Number of parent-directory hops from the working directory.
    """
    if is_colab_env():
        mount_google_drive()
        return

    # Anchor on the working directory: `__file__` is not defined inside a
    # notebook kernel, so there is no script path to start from.
    ancestor_dir = os.getcwd()

    # Walk up the requested number of parent directories.
    for _ in range(levels_deep):
        ancestor_dir = os.path.dirname(ancestor_dir)

    # Only append when missing so re-running the cell does not pile up
    # duplicate entries in sys.path.
    if ancestor_dir not in sys.path:
        sys.path.append(ancestor_dir)
    print(sys.path)

# Run the environment setup once when this cell executes, using the default `levels_deep`.
resolve_path()

Mounted at /content/drive/
['src', '.git', 'playground', 'setup', '.gitignore', 'README.md', 'config.py', 'train.py', 'run_jupyter.sh', 'sfiad_sanity_check.sh', '__pycache__', 'freqnet_image.ipynb', 'reports', 'environment.yml', 'environment-updated.yml']


In [3]:
# import local config
import config

In [4]:
# import library dependencies
import numpy as np

In [5]:
# pytorch
import torch
import pytorch_lightning as L

In [None]:
# import local dependencies
from src.adapters.datasets.wilddeepfake import WildDeepfakeDataModule
from src.transforms.frequency import fft_magnitude
from src.models.resnet import ResNetClassifier, DEFAULT_DATA_TRANSFORMS

In [None]:
# Identifier for this experiment; it also names the sub-directory used for checkpoints/logs.
model_id = "fft_magnitude_resnet18"
model_checkpoint_dir = "{}/{}".format(config.CHECKPOINTS_DIR, model_id)

In [None]:
# Run configuration
max_samples = 500_000  # For quick development, remove for full dataset
batch_size = 16
num_workers = 2
seed = config.SEED

# Set seeds for reproducibility: a single call seeds Python's `random`, NumPy
# and PyTorch; `workers=True` additionally asks Lightning to derive seeds for
# the DataLoader worker processes (relevant because num_workers > 0).
L.seed_everything(seed, workers=True)

# Determine device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# define datamodule
wilddeepfake_data_module = WildDeepfakeDataModule(
    batch_size=batch_size,
    num_workers=num_workers,
    max_samples=max_samples,
    seed=seed,
    transforms=DEFAULT_DATA_TRANSFORMS,
    additional_transforms=fft_magnitude
)
# from src.adapters.datasets.wilddeepfake import load_streaming_dataset, create_data_loaders
# dataset_name = "xingjunm/WildDeepfake"
# datasets = load_streaming_dataset(
#     dataset_name,
#     max_samples=max_samples,
#     seed=seed
# )
# train_loader, val_loader, test_loader = create_data_loaders(
#     datasets,
#     batch_size=batch_size,
#     num_workers=num_workers,
#     transforms=DEFAULT_DATA_TRANSFORMS,
#     # additional_transforms=fft
# )

Using device: cuda
Loading streaming dataset: xingjunm/WildDeepfake


README.md:   0%|          | 0.00/31.0 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/963 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/157 [00:00<?, ?it/s]

In [9]:
# Early stopping: halt training once the validation loss stops improving.
early_stop_callback = L.callbacks.EarlyStopping(
    monitor="val_loss",  # logged quantity to watch
    mode="min",          # lower is better for a loss
    patience=3,          # checks without improvement tolerated before stopping
    verbose=True,
)

In [10]:
# define Lightning checkpoint callback: keep only the single best model by val_loss
best_loss_checkpoint = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    mode="min",
    monitor="val_loss",
)

In [None]:
# Build the classifier that will be trained.
# NOTE(review): in_channels=1 presumably matches a single-channel FFT-magnitude
# input, and freeze_features=False presumably leaves the backbone trainable —
# confirm against ResNetClassifier.
deepfake_detector = ResNetClassifier(
    in_channels=1,
    freeze_features=False,
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 200MB/s]


In [12]:
# define trainer
# Training length is bounded in steps rather than epochs:
# max_steps ~ (num_samples // batch_size) * num_epochs
max_epochs = 20
# Floor division keeps this an int; true division (`/`) would hand the Trainer a float.
max_steps = (max_epochs * max_samples) // batch_size
trainer = L.Trainer(
    devices=1,
    callbacks=[early_stop_callback, best_loss_checkpoint],
    default_root_dir=model_checkpoint_dir,
    log_every_n_steps=10,
    profiler="simple",  # track time taken
    # max_steps=31_250,
    max_steps=max_steps,
    # limit_train_batches=1000,   # how many batches per "epoch"
    # limit_val_batches=200,      # how many val batches per "epoch"
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# train model
# The datamodule supplies the data; the run is bounded by the trainer's
# `max_steps` and by the early-stopping callback registered on the trainer.
trainer.fit(deepfake_detector, datamodule=wilddeepfake_data_module)
# trainer.fit(deepfake_detector, train_dataloaders=train_loader, val_dataloaders=val_loader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
   | Name       | Type              | Params | Mode 
----------------------------------------------------------
0  | model      | ResNet            | 11.2 M | train
1  | criterion  | BCEWithLogitsLoss | 0      | train
2  | train_acc  | BinaryAccuracy    | 0      | train
3  | val_acc    | BinaryAccuracy    | 0      | train
4  | test_acc   | BinaryAccuracy    | 0      | train
5  | train_prec | BinaryPrecision   | 0      | train
6  | val_prec   | BinaryPrecision   | 0      | train
7  | test_prec  | BinaryPrecision   | 0      | train
8  | train_rec  | BinaryRecall      | 0      | train
9  | val_rec    | BinaryRecall      | 0      | train
10 | test_rec   | BinaryRecall      | 0      | train
11 | train_f1   | BinaryF1Score     | 0      | train
12 | val_f1     | BinaryF1Score     | 0      | train
13 | test_f1    | BinaryF1Score     | 0      | train
14 | train_auc 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/pytorch_lightning/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: d90a7719-34aa-4a6f-8ab6-3633b62f8d4b)')' thrown while requesting GET https://huggingface.co/datasets/xingjunm/WildDeepfake/resolve/f3835aaf281dd9f8d79b51c4e02f050d3f7af0b4/deepfake_in_the_wild/fake_train/65.tar.gz
Retrying in 1s [Retry 1/5].


In [None]:
# Evaluate the model on the datamodule's test split.
# Prefer the best-val_loss checkpoint tracked by `best_loss_checkpoint` over the
# weights left in memory when training ended (early stopping waits `patience`
# checks past the best one). Fall back to the in-memory weights when no
# checkpoint has been written yet (`best_model_path` is then an empty string).
best_ckpt_path = best_loss_checkpoint.best_model_path or None
trainer.test(deepfake_detector, datamodule=wilddeepfake_data_module, ckpt_path=best_ckpt_path)
# trainer.test(deepfake_detector, test_loader)

In [None]:
# view metrics from previous runs
%reload_ext tensorboard
%tensorboard --logdir=$model_checkpoint_dir