In [1]:
import sys
import os

def is_colab_env():
    """Report whether this kernel looks like a Google Colab runtime.

    Returns:
        bool: True if a module named ``google.colab`` is registered in
        ``sys.modules``, False otherwise.
    """
    colab_module_name = "google.colab"
    return colab_module_name in sys.modules.keys()

def mount_google_drive(drive_dir="/content/drive/", repo_dir="MyDrive/repositories/deepfake-detection"):
    """Mount Google Drive (Colab only) and switch into the project repository.

    Args:
        drive_dir: Mount point handed to ``google.colab.drive.mount``.
        repo_dir: Repository location relative to ``drive_dir``. A leading
            "/" is ignored so the path always stays under the mount point.

    Side effects:
        Mounts the drive, changes the process working directory with
        ``os.chdir`` and prints the new directory's listing for verification.
    """
    # Imported lazily so this function can be defined outside Colab.
    from google.colab import drive

    drive.mount(drive_dir)

    # os.path.join works whether or not drive_dir ends with a separator;
    # plain string concatenation produced a wrong path when it did not.
    repo_path = os.path.join(drive_dir, repo_dir.lstrip("/"))
    os.chdir(repo_path)
    print(os.listdir())  # verify content

def resolve_path(levels_deep=3):
    """Make the project root importable from this notebook.

    On Colab this delegates to ``mount_google_drive()``. Elsewhere it climbs
    ``levels_deep`` parent directories up from the current working directory,
    adds that directory to ``sys.path`` if it is not already present, and
    prints ``sys.path``.

    Args:
        levels_deep: Number of parent directories between the working
            directory and the directory that should be importable.
    """
    if is_colab_env():
        mount_google_drive()
        return

    # Notebook kernels have no __file__, so start from the working directory
    # (presumably the notebook's own folder - TODO confirm for your launcher).
    project_root = os.getcwd()
    for _ in range(levels_deep):
        project_root = os.path.dirname(project_root)

    # Avoid accumulating duplicate entries when the cell is re-executed.
    if project_root not in sys.path:
        sys.path.append(project_root)
    print(sys.path)

# Put the project root on sys.path (or mount Drive on Colab) before importing local modules.
resolve_path(levels_deep=3)

['/home/FYP/tanj0303/.conda/envs/df-env/lib/python310.zip', '/home/FYP/tanj0303/.conda/envs/df-env/lib/python3.10', '/home/FYP/tanj0303/.conda/envs/df-env/lib/python3.10/lib-dynload', '', '/home/FYP/tanj0303/.conda/envs/df-env/lib/python3.10/site-packages', '/home/FYP/tanj0303']


In [9]:
# Dataset cache location. Set the DEEPFAKE_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
data_dir = os.environ.get("DEEPFAKE_DATA_DIR", "/scratch-shared/TANJ0303/datasets/")

# import local config
import config

# Experiment to evaluate and the specific checkpoint file inside its log directory.
model_id = "fft_magnitude_phase_resnet18"
model_checkpoint_dir = f"{config.CHECKPOINTS_DIR}/{model_id}"
checkpoint = "lightning_logs/version_0/checkpoints/epoch=3-step=202884.ckpt"

In [3]:
# import local dependencies
from src.adapters.datasets.wilddeepfake import WildDeepfakeDataModule
from src.models.resnet import ResNetClassifier

In [12]:
import pytorch_lightning as L

In [16]:
import torch

In [10]:
# Restore the trained classifier from the selected Lightning checkpoint file.
checkpoint_path = f"{model_checkpoint_dir}/{checkpoint}"
deepfake_detector = ResNetClassifier.load_from_checkpoint(checkpoint_path)

In [14]:
# Lightning trainer. In this notebook it is only used for trainer.test(),
# so max_epochs has no effect on the evaluation results.
max_epochs = 100
trainer = L.Trainer(
    default_root_dir=model_checkpoint_dir,  # loggers write under the experiment directory
    devices=1,
    max_epochs=max_epochs,
    log_every_n_steps=10,
    profiler="simple",  # print a per-action timing report when the run finishes
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/FYP/tanj0303/.conda/envs/df-env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [19]:
from src.transforms.frequency import get_transforms
# Build the frequency-domain input transforms passed to the data module below.
# NOTE(review): model_id is "fft_magnitude_phase_resnet18" and this variable is named
# fft_mag_phase_transforms, but the preset requested is "fft_ri" (presumably the
# real/imaginary FFT parts). Confirm this matches the representation the checkpoint
# was trained on; a mismatch would silently degrade the reported test metrics.
fft_mag_phase_transforms = get_transforms("fft_ri")

In [20]:
# Evaluation data settings. For quick development runs the data module can be
# given max_samples (e.g. 800_000); omit it to use the full dataset.
batch_size = 16
num_workers = 4
seed = config.SEED

# Informational only: the Trainer is configured separately via devices=1.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# WildDeepfake data module fed with the frequency-domain transforms.
wilddeepfake_data_module = WildDeepfakeDataModule(
    dataset_cache_dir=data_dir,
    transforms=fft_mag_phase_transforms,
    seed=seed,
    num_workers=num_workers,
    batch_size=batch_size,
)

Using device: cuda


In [21]:
# Evaluate the restored checkpoint on the data module's test split and display
# the returned list of metric dictionaries.
test_results = trainer.test(model=deepfake_detector, datamodule=wilddeepfake_data_module)
test_results

Resolving data files:   0%|          | 0/963 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/157 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/124 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/20 [00:00<?, ?it/s]

Dataset loaded. Processing samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Extracted labels. Generating train/val split...
Train samples: 811549, Val samples: 202888, Test samples: 165662


Testing: |          | 0/? [00:00<?, ?it/s]

TEST Profiler Report

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                   	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                    	|  -              	|  144987         	|  410.61         	|  100 % 

────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────
        test_acc            0.6806449294090271
        test_auc            0.6912497282028198
         test_f1            0.7633298635482788
        test_loss           0.6073537468910217
        test_prec           0.7321085333824158

[{'test_acc': 0.6806449294090271,
  'test_prec': 0.7321085333824158,
  'test_rec': 0.797332763671875,
  'test_f1': 0.7633298635482788,
  'test_auc': 0.6912497282028198,
  'test_loss': 0.6073537468910217}]