In [None]:
import sys
import os

def is_colab_env():
    """Tell whether the ``google.colab`` module is currently loaded.

    Returns:
        bool: True when ``"google.colab"`` is a key of ``sys.modules``,
        False otherwise.
    """
    colab_module_name = "google.colab"
    return colab_module_name in sys.modules

def mount_google_drive(drive_dir="/content/drive/", repo_dir="MyDrive/repositories/deepfake-detection"):
    """Mount Google Drive and switch the working directory to the repository.

    Args:
        drive_dir: Mount point handed to ``google.colab.drive.mount``. May be
            given with or without a trailing slash.
        repo_dir: Repository folder, relative to ``drive_dir``.

    Raises:
        ImportError: If ``google.colab`` cannot be imported.
        OSError: If the joined repository path cannot be entered.
    """
    # Imported lazily so that merely defining this function does not require
    # the google.colab package to be installed.
    from google.colab import drive

    # mount google drive
    drive.mount(drive_dir)

    # change to correct working directory; os.path.join supplies the separator
    # when drive_dir lacks a trailing slash (plain concatenation would glue
    # the two path parts together).
    repo_path = os.path.join(drive_dir, repo_dir)
    os.chdir(repo_path)
    print(os.listdir())  # verify content

def resolve_path(levels_deep=3):
    """Make the project's local packages importable from this notebook.

    When ``is_colab_env()`` is true, Google Drive is mounted and the working
    directory is changed by ``mount_google_drive()``. Otherwise the directory
    ``levels_deep`` levels above the current working directory is added to
    ``sys.path``.

    Args:
        levels_deep: Number of parent directories to climb from the current
            working directory before adding the result to ``sys.path``.
    """
    if is_colab_env():
        mount_google_drive()
        return

    # Anchor on the working directory explicitly. Resolving the quoted string
    # '__file__' with abspath() only ever produced "<cwd>/__file__", so the
    # starting directory was the cwd all along.
    project_root = os.getcwd()

    # Climb to the requested ancestor directory
    for _ in range(levels_deep):
        project_root = os.path.dirname(project_root)

    # Re-running this cell must not keep growing sys.path with duplicates
    if project_root not in sys.path:
        sys.path.append(project_root)
    print(sys.path)

resolve_path()

In [None]:
# import local config
import config

In [None]:
# import library dependencies
import numpy as np

In [None]:
# pytorch
import torch
import pytorch_lightning as L

In [None]:
# import local dependencies
# from src.adapters.datasets.wilddeepfake import WildDeepfakeDataModule
from src.adapters.datasets.sida import SidADataModule
from src.models.freqnet import LitFreqNet

In [None]:
# Experiment identifier; it is also the name of the sub-folder under
# config.CHECKPOINTS_DIR used for this model's run artifacts.
model_id = "frequency_freqnet"
model_checkpoint_dir = "{}/{}".format(config.CHECKPOINTS_DIR, model_id)

In [None]:
# The torchvision module is imported under the alias ``T`` so that the
# ``transforms`` dict defined at the end of this cell does not overwrite it.
from torchvision import transforms as T

# Target (height, width) shared by every pipeline below
IMAGE_SIZE = (256, 256)

# --- common normalization (ImageNet) ---
imagenet_normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def _to_rgb(img):
    """Return ``img`` converted to RGB via its ``convert`` method."""
    return img.convert("RGB")


def _build_eval_transform():
    """Build a deterministic pipeline: resize, force RGB, to tensor, normalize.

    A fresh ``Compose`` is returned on every call so the validation and test
    entries stay independent objects.
    """
    return T.Compose([
        T.Resize(IMAGE_SIZE),  # deterministic resize
        T.Lambda(_to_rgb),
        T.ToTensor(),
        imagenet_normalize,
    ])


# --- training transform ---
train_transform = T.Compose([
    T.Resize(IMAGE_SIZE),          # resize frame
    T.Lambda(_to_rgb),             # force RGB
    T.RandomHorizontalFlip(),      # flip for augmentation
    T.ColorJitter(                 # optional: color variation
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
    ),
    T.ToTensor(),
    imagenet_normalize,
])

# --- validation transform ---
val_transform = _build_eval_transform()

# --- test transform (same steps as validation) ---
test_transform = _build_eval_transform()

# Per-split lookup consumed by the data-loading cell
transforms = {
    "train": train_transform,
    "val": val_transform,
    "test": test_transform
}

In [None]:
# Set seeds for reproducibility
seed = config.SEED

# seed_everything seeds torch, NumPy and Python's built-in `random` in one
# call; seeding only torch and NumPy leaves `random` unseeded.
L.seed_everything(seed)

# Determine device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
from src.adapters.datasets.wilddeepfake import load_streaming_dataset, create_data_loaders

# --- run configuration ---
dataset_name = "xingjunm/WildDeepfake"
max_samples = 994_000  # passed to load_streaming_dataset; lower it for quick development runs
batch_size = 16
num_workers = 0
max_epochs = 30

# --- data ---
# The loaders are built directly from the streaming helpers rather than
# through a DataModule.
datasets = load_streaming_dataset(dataset_name, max_samples=max_samples, seed=seed)

loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    transforms=transforms,  # per-split dict with "train" / "val" / "test" keys
)
train_loader, val_loader, test_loader = create_data_loaders(datasets, **loader_options)

In [None]:
# Early stopping: end training once val_loss stops improving
early_stop_options = {
    "monitor": "val_loss",  # metric to track
    "mode": "min",          # "min" because we want val_loss to decrease
    "patience": 3,          # epochs to wait for improvement
    "verbose": True,
}
early_stop_callback = L.callbacks.EarlyStopping(**early_stop_options)

In [None]:
# define lightning checkpoint: keep only the single best model by val_loss
checkpoint_options = dict(monitor="val_loss", mode="min", save_top_k=1)
best_loss_checkpoint = L.callbacks.ModelCheckpoint(**checkpoint_options)

In [None]:
# define model: LitFreqNet built with its default constructor arguments;
# this is the object handed to trainer.fit / trainer.test further down
deepfake_detector = LitFreqNet()

In [None]:
# Step budget = desired epochs x batches per epoch. Ceiling division keeps the
# result an int (true division would hand Trainer a float) and counts a final
# partial batch as a step.
# NOTE(review): assumes max_samples is the number of *training* samples - confirm.
steps_per_epoch = -(-max_samples // batch_size)
max_steps = max_epochs * steps_per_epoch

trainer = L.Trainer(
    devices=1,
    callbacks=[early_stop_callback, best_loss_checkpoint],
    default_root_dir=model_checkpoint_dir,
    log_every_n_steps=10,
    profiler="simple",  # track time taken
    max_steps=max_steps,
    # limit_train_batches=1000,   # how many batches per "epoch"
    # limit_val_batches=200,      # how many val batches per "epoch"
)

In [None]:
# train model
# Fits on train_loader and validates on val_loader; the callbacks registered
# on the trainer monitor the "val_loss" metric.
# Alternative kept for reference (needs a `sida_data_module`, which the visible
# cells of this notebook do not define):
# trainer.fit(deepfake_detector, datamodule=sida_data_module)
trainer.fit(deepfake_detector, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# test on unseen samples
# ckpt_path="best" makes Lightning load the checkpoint selected by the
# val_loss ModelCheckpoint callback. Without it the weights left in memory
# are evaluated, and early stopping only ends training several validation
# checks after the best one.
# Requires that trainer.fit has run in this session so a best checkpoint exists.
trainer.test(deepfake_detector, test_loader, ckpt_path="best")

In [None]:
# view metrics from previous runs
# Reload the TensorBoard notebook extension, then open TensorBoard on the
# directory used as the trainer's default_root_dir; IPython expands
# `$model_checkpoint_dir` to the Python variable's value.
%reload_ext tensorboard
%tensorboard --logdir=$model_checkpoint_dir

In [None]:

import torch
# Environment report: installed torch build and what CUDA support it sees
for label, value in (
    ('Torch version:', torch.__version__),
    ('CUDA available:', torch.cuda.is_available()),
    ('CUDA version from torch:', torch.version.cuda),
    ('Device count:', torch.cuda.device_count()),
):
    print(label, value)
