In [1]:
# Installing extra dependencies into drive
!pip install lightning



In [2]:
import sys
import os

def is_colab_env():
    """Return True when the `google.colab` module is already loaded.

    The check looks only at `sys.modules`; it never attempts the import itself.
    """
    colab_module = "google.colab"
    return colab_module in sys.modules

def mount_google_drive(drive_dir="/content/drive/", repo_dir="MyDrive/repositories/deepfake-detection"):
    """Mount Google Drive and make the repository the working directory.

    Args:
        drive_dir: Mount point passed to `google.colab.drive.mount`. A trailing
            slash is optional.
        repo_dir: Repository location relative to `drive_dir`.

    Raises:
        ImportError: If `google.colab` is not importable (i.e. not on Colab).
        OSError: If the repository directory does not exist after mounting.
    """
    # Imported lazily: the package only exists inside a Colab runtime.
    from google.colab import drive
    drive.mount(drive_dir)

    # Join the parts instead of concatenating them, so the result is valid
    # whether or not `drive_dir` ends with a separator. Leading separators on
    # `repo_dir` are dropped so that join() does not discard `drive_dir`.
    repo_path = os.path.join(drive_dir, repo_dir.lstrip("/"))
    os.chdir(repo_path)
    print(os.listdir())  # verify content

def resolve_path(levels_deep=3):
    """Make the project's modules importable from this notebook.

    On Colab, Google Drive is mounted and the working directory is switched to
    the repository. Elsewhere, the directory `levels_deep` levels above the
    current working directory is appended to `sys.path`.

    Args:
        levels_deep: Number of parent directories to climb from the working
            directory when not running on Colab.
    """
    if is_colab_env():
        mount_google_drive()
        return

    # A notebook has no `__file__`, so the working directory is the anchor.
    # (The previous `abspath('__file__')` string literal resolved to exactly this.)
    project_root = os.getcwd()

    # Climb to the requested ancestor directory.
    for _ in range(levels_deep):
        project_root = os.path.dirname(project_root)

    # Skip the append when the entry is already present, so re-running the cell
    # does not pile up duplicates.
    if project_root not in sys.path:
        sys.path.append(project_root)
    print(sys.path)

# Run the path setup now so the `config` and `src` imports in the following
# cells can be resolved.
resolve_path()

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
['README.md', 'src', 'environment.yml', '.git', '.gitignore', 'config.py', 'playground', 'reports', '__pycache__', 'lightning_logs']


In [3]:
# import local config
import config

In [4]:
# pytorch
import torch
import pytorch_lightning as L

In [5]:
from torchvision import transforms

In [6]:
# import local dependencies
from src.adapters.datasets.sida import SidADataModule
# from src.transforms.ela import ela
from src.models.encoder_decoder import DisentanglementDeepfakeDetector

In [7]:
# Seed the random number generators before any weights are created so that a
# Restart & Run All produces the same initialisation. `workers=True` also
# derives seeds for dataloader worker processes.
L.seed_everything(config.SEED, workers=True)

# Initialize model
# NOTE(review): alpha/beta/gamma presumably weight the individual loss terms of
# the model's criterion — confirm the mapping in DisentanglementDeepfakeDetector.
model = DisentanglementDeepfakeDetector(
    input_channels=3,
    hidden_dim=512,
    learning_rate=1e-3,
    loss_weights={'alpha': 1.0, 'beta': 0.1, 'gamma': 0.1},
)
image_size = 224

# Evaluation images are first resized to 1.14x the crop size
# (255 for image_size=224) and then center-cropped.
_eval_resize = int(image_size * 1.14)


def _make_eval_transform():
    """Build the deterministic resize -> center-crop -> tensor pipeline."""
    return transforms.Compose([
        transforms.Resize(_eval_resize),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])


# Per-split preprocessing. Only the training split is randomly augmented.
# ColorJitter and a Normalize step are currently disabled for every split.
DEFAULT_DATA_TRANSFORMS = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    'val': _make_eval_transform(),
    'test': _make_eval_transform(),
}

# Initialize data module (you'll need to implement this)
# sida_data_module = SidADataModule(seed=42, transforms=DEFAULT_DATA_TRANSFORMS)

# Pick the accelerator: GPU when CUDA is available, CPU otherwise.
if torch.cuda.is_available():
    accelerator = 'gpu'
else:
    accelerator = 'cpu'

# Build the Lightning trainer: single device, validation after every epoch.
trainer = L.Trainer(
    accelerator=accelerator,
    devices=1,
    max_epochs=50,
    check_val_every_n_epoch=1,
    log_every_n_steps=10,
)

INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [8]:
from src.adapters.datasets.wilddeepfake import load_streaming_dataset, create_data_loaders

# --- Data configuration ---------------------------------------------------
dataset_name = "xingjunm/WildDeepfake"
seed = config.SEED
max_samples = 50_000  # cap for quick development; remove for the full dataset
batch_size = 16
# NOTE(review): the training log warns that an IterableDataset with __len__
# combined with multiple workers may yield duplicate samples — confirm that
# create_data_loaders configures each worker independently.
num_workers = 2

# --- Stream the dataset and wrap each split in a DataLoader ----------------
datasets = load_streaming_dataset(dataset_name, max_samples=max_samples, seed=seed)

loaders = create_data_loaders(
    datasets,
    batch_size=batch_size,
    num_workers=num_workers,
    transforms=DEFAULT_DATA_TRANSFORMS,
)
train_loader, val_loader, test_loader = loaders

Loading streaming dataset: xingjunm/WildDeepfake


Resolving data files:   0%|          | 0/963 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/157 [00:00<?, ?it/s]

In [10]:
# Display the trainer's repr to confirm the object exists in this kernel.
trainer

<pytorch_lightning.trainer.trainer.Trainer at 0x79a7879c9250>

In [9]:
# Fit the model on the training loader, validating on the validation loader.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type                   | Params | Mode 
-----------------------------------------------------------------
0 | encoder       | DisentanglementEncoder | 40.4 M | train
1 | recombination | RecombinationModule    | 526 K  | train
2 | criterion     | DisentanglementLoss    | 0      | train
-----------------------------------------------------------------
40.9 M    Trainable params
0         Non-trainable params
40.9 M    Total params
163.514   Total estimated model params size (MB)
43        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/pytorch_lightning/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/pytorch_lightning/trainer/call.py", line 49, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pytorch_lightning/trainer/trainer.py", line 598, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.12/dist-packages/pytorch_lightning/trainer/trainer.py", line 1011, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pytorch_lightning/trainer/trainer.py", line 1055, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.12/dist-packages/pytorch_lightning/loops/fit_loop.py", line 216, in run
    self.advance()
  File "/usr/local/lib/python3.12/dist-packages/pytorch_lightning/loops/fit_loop.py", line 458, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.12

TypeError: object of type 'NoneType' has no len()

In [None]:
# Evaluate the model on the held-out test loader.
trainer.test(model, dataloaders=test_loader)

# Closing summary of the model's building blocks, emitted as one write.
summary_lines = (
    "Disentanglement learning model setup complete!",
    "Key components:",
    "- Encoder: Disentangles features into content, specific, and common",
    "- Recombination: Combines features with attention for classification",
    "- Custom loss: Balances classification, orthogonality, and mutual information",
)
print("\n".join(summary_lines))