In [1]:
import sys
import os

def is_colab_env():
    """Report whether the notebook is running inside Google Colab.

    Returns:
        bool: True when a ``google.colab`` module is present in ``sys.modules``.
    """
    colab_loaded = "google.colab" in sys.modules
    return colab_loaded

def mount_google_drive(drive_dir="/content/drive/", repo_dir="MyDrive/repositories/deepfake-detection"):
    """Mount Google Drive (Colab only) and make the repository the working directory.

    Args:
        drive_dir: Mount point handed to ``google.colab.drive.mount``.
        repo_dir: Repository location relative to ``drive_dir``.

    Returns:
        str: The directory that was made the current working directory.
    """
    # google.colab is only importable inside a Colab runtime.
    from google.colab import drive
    drive.mount(drive_dir)

    # os.path.join works whether or not drive_dir ends with a separator; the
    # previous f"{drive_dir}{repo_dir}" silently produced a wrong path without one.
    repo_path = os.path.join(drive_dir, repo_dir)
    os.chdir(repo_path)
    print(os.listdir())  # verify content
    return repo_path

def resolve_path(levels_deep=3):
    """Make the project's modules importable from this notebook.

    In Colab, mounts Google Drive and changes into the repository via
    ``mount_google_drive()``. Elsewhere, walks ``levels_deep`` directories up
    from the current working directory and appends that directory to
    ``sys.path``. The directory is only appended if it is not already there,
    so re-running this cell does not keep adding duplicate entries.

    Args:
        levels_deep: How many parent directories to climb from the working directory.
    """
    if is_colab_env():
        mount_google_drive()
        return

    # The original os.path.abspath('__file__') passed a *string literal*, which
    # resolves to "<cwd>/__file__"; its dirname is just the working directory.
    project_root = os.getcwd()
    for _ in range(levels_deep):
        project_root = os.path.dirname(project_root)

    # Guard keeps the cell idempotent across re-runs in the same kernel.
    if project_root not in sys.path:
        sys.path.append(project_root)
    print(sys.path)

resolve_path()

['c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\python310.zip', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\DLLs', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\lib', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env', '', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\lib\\site-packages', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\lib\\site-packages\\win32', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\lib\\site-packages\\win32\\lib', 'c:\\Users\\jinxy\\anaconda3\\envs\\df-env\\lib\\site-packages\\Pythonwin', 'g:\\']


In [2]:
# import local config
import config

In [None]:
# import library dependencies
import numpy as np

In [4]:
# pytorch
import torch
import pytorch_lightning as L

In [5]:
# import local dependencies
# from src.adapters.datasets.wilddeepfake import WildDeepfakeDataModule
from src.adapters.datasets.sida import SidADataModule
from src.models.freqnet import LitFreqNet

In [6]:
# Experiment identifier; its sub-directory under config.CHECKPOINTS_DIR becomes
# the Trainer's default_root_dir and the tensorboard log directory.
model_id = "frequency_freqnet"
model_checkpoint_dir = "{}/{}".format(config.CHECKPOINTS_DIR, model_id)

In [7]:
from operator import methodcaller

# Import the torchvision module under an alias: this cell binds the name
# `transforms` to the split->pipeline dict below, and the old code shadowed the
# torchvision module with it.
from torchvision import transforms as T

# Every split is resized to this fixed square.
IMAGE_SIZE = (256, 256)

# --- common normalization (ImageNet channel statistics) ---
imagenet_normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Force RGB as the *first* step. Pillow always resamples mode "P"/"1" images
# with NEAREST regardless of the requested filter, so converting before Resize
# gives properly interpolated frames; ColorJitter also needs RGB input.
# methodcaller is picklable (a lambda is not), so these pipelines keep working
# if num_workers > 0 with spawn-based DataLoader workers (e.g. on Windows).
to_rgb = methodcaller("convert", "RGB")


def build_eval_transform():
    """Build the deterministic validation/test pipeline: RGB -> resize -> tensor -> normalize."""
    return T.Compose([
        T.Lambda(to_rgb),
        T.Resize(IMAGE_SIZE),
        T.ToTensor(),
        imagenet_normalize,
    ])


# --- training transform (adds random augmentation) ---
train_transform = T.Compose([
    T.Lambda(to_rgb),              # force RGB
    T.Resize(IMAGE_SIZE),          # resize frame
    T.RandomHorizontalFlip(),      # flip for augmentation
    T.ColorJitter(                 # optional: color variation
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
    ),
    T.ToTensor(),
    imagenet_normalize,
])

# --- validation / test transforms (identical, deterministic) ---
val_transform = build_eval_transform()
test_transform = build_eval_transform()

# Split name -> pipeline; kept under the name `transforms` because downstream
# code passes `transforms=transforms`.
transforms = {
    "train": train_transform,
    "val": val_transform,
    "test": test_transform,
}

In [8]:
# Set seeds for reproducibility.
# seed_everything seeds Python's `random`, NumPy and torch (the previous code
# left `random` unseeded). workers=True additionally gives each DataLoader
# worker its own derived seed.
seed = config.SEED
L.seed_everything(seed, workers=True)

# Report which device is available. Note `device` is not passed to the Trainer
# (which is built with devices=1 and no explicit accelerator); this is informational.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [15]:
from src.adapters.datasets.wilddeepfake import load_streaming_dataset, create_data_loaders

# --- run configuration (max_samples, batch_size, max_epochs also size the Trainer) ---
dataset_name = "xingjunm/WildDeepfake"
# Upper bound on samples streamed from the hub; lower it for quick development
# runs. NOTE(review): presumably None means "no cap" - confirm in
# load_streaming_dataset before relying on that.
max_samples = 994_000
batch_size = 16
# 0 = batches are loaded in the main process (Lightning warns this may
# bottleneck the validation loader).
num_workers = 0
max_epochs = 30

# Stream the dataset and build the train/val/test loaders.
# `seed` and `transforms` are defined in earlier cells.
datasets = load_streaming_dataset(
    dataset_name,
    max_samples=max_samples,
    seed=seed,
)
train_loader, val_loader, test_loader = create_data_loaders(
    datasets,
    batch_size=batch_size,
    num_workers=num_workers,
    transforms=transforms,
)

# FIXME(review): the last fit crashed during sanity checking with
# "Detected the following values in `target`: tensor([-1]) ... expected only
# [0, 1]". Some streamed sample appears to yield label -1; the label mapping in
# src.adapters.datasets.wilddeepfake must be fixed (or such samples filtered)
# before training can succeed.

Loading streaming dataset: xingjunm/WildDeepfake


Resolving data files:   0%|          | 0/963 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/157 [00:00<?, ?it/s]

In [10]:
# Early stopping: halt training after val_loss has gone 3 consecutive
# validation checks without improving (lower is better).
early_stop_callback = L.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=3,
    verbose=True,
)

In [11]:
# Lightning checkpoint callback: retain only the single checkpoint with the
# lowest val_loss seen so far.
best_loss_checkpoint = L.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

In [12]:
# define model
# LitFreqNet is the project's LightningModule around FreqNet (see the model
# summary printed by trainer.fit); built here with its default arguments.
deepfake_detector = LitFreqNet()

In [13]:
import math

# Training length is bounded by optimizer steps: desired_epochs x (dataset_size / batch_size).
# Trainer's max_steps is an int; the previous `max_epochs * max_samples / batch_size`
# used true division and passed a float. ceil() also counts a final partial batch.
# NOTE(review): max_samples presumably covers all splits (train/val/test), so this
# likely overestimates the train steps per epoch - confirm against load_streaming_dataset.
steps_per_epoch = math.ceil(max_samples / batch_size)

trainer = L.Trainer(
    devices=1,
    callbacks=[early_stop_callback, best_loss_checkpoint],
    default_root_dir=model_checkpoint_dir,
    log_every_n_steps=10,
    profiler="simple",  # track time taken
    max_steps=max_epochs * steps_per_epoch,
    # limit_train_batches=1000,   # how many batches per "epoch"
    # limit_val_batches=200,      # how many val batches per "epoch"
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\jinxy\anaconda3\envs\df-env\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [16]:
# Train on the streamed train split; the val loader feeds val_loss to the
# EarlyStopping and ModelCheckpoint callbacks.
trainer.fit(
    model=deepfake_detector,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)


   | Name       | Type              | Params | Mode 
----------------------------------------------------------
0  | model      | FreqNet           | 1.9 M  | train
1  | loss_fn    | BCEWithLogitsLoss | 0      | train
2  | train_acc  | BinaryAccuracy    | 0      | train
3  | val_acc    | BinaryAccuracy    | 0      | train
4  | test_acc   | BinaryAccuracy    | 0      | train
5  | train_prec | BinaryPrecision   | 0      | train
6  | val_prec   | BinaryPrecision   | 0      | train
7  | test_prec  | BinaryPrecision   | 0      | train
8  | train_rec  | BinaryRecall      | 0      | train
9  | val_rec    | BinaryRecall      | 0      | train
10 | test_rec   | BinaryRecall      | 0      | train
11 | train_f1   | BinaryF1Score     | 0      | train
12 | val_f1     | BinaryF1Score     | 0      | train
13 | test_f1    | BinaryF1Score     | 0      | train
14 | train_auc  | BinaryAUROC       | 0      | train
15 | val_auc    | BinaryAUROC       | 0      | train
16 | test_auc   | BinaryAUROC       | 0

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\jinxy\anaconda3\envs\df-env\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Loading sample 1: {'png': {'path': None, 'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xe0\x00\x00\x00\xe0\x08\x02\x00\x00\x00\x95O\xfd\xb6\x00\x00\xd1"IDATx\x9c|\xfd\xdd\x9a\xe4J\x8e\x04\x88\x19\x00\'#2\xabN\xf7\xac\xf4\xbc\xba\xd0\xb5\x9eu\xb53\xdd\xa7*3H\x07\xa0\x0b\x03\x9c\x8c\xeaY\xc5W]\x9d\'+"H\xba\xc3\xf1c0\x00\xf2\xff\xf9\x7f\xff\xbf\xfe\xaf\xff\xfe\xef\xff\xf3\xff\xfc?\xff\xfb\xbf\xff\xfb\xf5zefD|\xbdN\x00\xaf\xf3\x9c\xd3=rF\xce9Ed\xdb\xb6\xfd\xf1\x04p\x9c\xf3\xeb\xfb\xfb\xdf\xbf\xfe\xfe\xfa\xfe\xf6\x88\x84\x00\x10\x11\x00\nA\xbf"\xdc\xe7\xf4pww\x8f\xcc\x00 \x99\x02@ \x02S\xdd\x14*\xf2\x90T\xd5Ms\x1b\xe3s\x1f\xcfm\xdb\x87=\x86n\xdb\xf6\x1c\xa6j\x9b\xc86\xb6}\xd3}\xd8\x18\xc3\xccTE\x90""\x92\x00\x80\x14\x01\x10\x89\x04\xa0\nU\x1dc\x8c1653\x13@ \x00R "\xaa*"\xa6jj"*b\xda\xbf\xcd~\x04\x15U\x88\x8a\x88"3""\x93\xd7\n\x00c\x0c~\x89\xf6\xab>$2\xe7\xe4\r%\x92K\x9a\xfd\x92D\xb8#3\x13\x19\x01df\xba{D\xf0o\xbe\x7f\x9e\xee\x1e\x91\xe9\x01wwdD\x042"\x1d\x19\t\x8ft\xcf\xe9qL|}}

RuntimeError: Detected the following values in `target`: tensor([-1], dtype=torch.int32) but expected only the following values [0, 1].

In [None]:
# Evaluate on the held-out test split.
# When a model instance is passed with ckpt_path=None, Lightning tests the
# current in-memory weights, i.e. those from the last epoch trained. Those may
# be later (and worse) than the best val_loss epoch kept by best_loss_checkpoint.
# ckpt_path="best" reloads that best checkpoint. This raises if fit never saved one.
trainer.test(deepfake_detector, dataloaders=test_loader, ckpt_path="best")

In [None]:
# view metrics from previous runs
# NOTE: if neither `tensorboard` nor `tensorboardX` is installed, Lightning's
# default logger is CSVLogger (it warns about this when the Trainer is built).
# In that case there are no event files under this logdir for TensorBoard to show.
%reload_ext tensorboard
%tensorboard --logdir=$model_checkpoint_dir

In [None]:

import torch

# Environment report: torch build and CUDA visibility. Each value is read
# lazily, just before its line is printed.
for label, read_value in (
    ('Torch version:', lambda: torch.__version__),
    ('CUDA available:', lambda: torch.cuda.is_available()),
    ('CUDA version from torch:', lambda: torch.version.cuda),
    ('Device count:', lambda: torch.cuda.device_count()),
):
    print(label, read_value())
