In [None]:
# Make the project checkout the working directory so the relative paths and the
# local `data` / `models` imports used below resolve (path looks like a mounted
# Google Drive folder -- adjust when running elsewhere).
%cd /content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/ImagebindDeepfakeDetection

/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/ImagebindDeepfakeDetection


In [None]:
!pip install -r requirements.txt --quiet
!pip install pytorch_lightning --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m86.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m152.6/152.6 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m890.1/890.1 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import logging
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import data
from models import imagebind_model
from models import lora as LoRA
from models.imagebind_model import ModalityType, load_module, save_module
import torchvision
from torchvision import transforms

import pytorch_lightning as L
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import CSVLogger



In [None]:
seed = 0
def seed_everything_func(seed=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects subprocesses started later; the current interpreter has
    # already fixed its hash seed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different kernels from run to run, which defeats
    # the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False
seed_everything_func(seed)
# Seed once more through Lightning's own helper (its log line and return value
# appear in the cell output). NOTE(review): `workers=True` is forwarded to
# Lightning -- see its seed_everything docs for the effect on DataLoader workers.
seed_everything(seed, workers=True)

INFO:lightning_fabric.utilities.seed:Seed set to 0


0

In [None]:
# Notebook-wide configuration. Later cells read these names directly.

# force=True replaces any handlers already installed on the root logger.
logging.basicConfig(level=logging.INFO, force=True)

# One DataLoader worker per CPU core reported by the OS.
num_workers = os.cpu_count()
print(f"Number of CPU cores: {num_workers}")
# Forwarded to ImageBindTrain; in the code visible here it is only stored as a
# hyperparameter.
self_contrast = False
# Used for both DataLoaders and forwarded to ImageBindTrain.
batch_size = 32
# Modality names that a later cell resolves to ModalityType members.
lora_modality_names_123 = ["vision", "text"]
# Read by the self.log(...) calls inside ImageBindTrain.info_nce_loss.
LOG_ON_STEP = False
LOG_ON_EPOCH = True
# `lora` and `linear_probing` are mutually exclusive (ImageBindTrain asserts this).
lora= False
# When True a ModelCheckpoint callback writing to full_model_checkpoint_dir is
# created further down.
full_model_checkpointing = False
full_model_checkpoint_dir="./.checkpoints/full"
# Directory the heads are loaded from / saved to; also used in linear-probing mode.
lora_checkpoint_dir="./.checkpoints/lora"
# Prefer the first CUDA device, fall back to CPU.
device_name="cuda:0" if torch.cuda.is_available() else "cpu"
# Passed to both the Trainer and ImageBindTrain (where it sets the cosine
# schedule's T_max).
max_epochs = 5
gradient_clip_val=1.0
# Not referenced by any code visible in this notebook.
loggers = None
linear_probing = True

Number of CPU cores: 12


In [None]:
class ImageBindTrain(L.LightningModule):
    """Lightning module that tunes ImageBind to tell real images from fake ones.

    Each image embedding is compared (cosine similarity) with the text
    embeddings of two prompts, ``real`` and ``fake``; training minimises a
    two-way InfoNCE loss that pulls an image towards the prompt matching its
    label.

    Training modes (mutually exclusive):
        * ``lora=True``: preprocessors and trunks are frozen, LoRA adapters are
          attached to the trunks, and adapters, postprocessors and heads are
          loaded from ``lora_checkpoint_dir``.
        * ``linear_probing=True``: preprocessors, trunks and postprocessors are
          frozen, heads are loaded from ``lora_checkpoint_dir`` and only the
          last child module of every modality head stays trainable.
        * neither: the pretrained model is left fully trainable.

    Args:
        lr: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        max_epochs: ``T_max`` of the cosine learning-rate schedule.
        batch_size: Stored as a hyperparameter.
        num_workers: Stored as a hyperparameter.
        seed: Stored as a hyperparameter.
        self_contrast: Stored as a hyperparameter; not used by this class.
        temperature: Divisor applied to the cosine similarities.
        momentum_betas: ``betas`` for AdamW.
        lora: Enable LoRA fine-tuning.
        lora_rank: Rank of the LoRA adapters.
        lora_checkpoint_dir: Directory LoRA weights, postprocessors and heads
            are loaded from and saved to (also used for linear probing).
        lora_layer_idxs: Forwarded to ``LoRA.apply_lora_modality_trunks``.
        lora_modality_names: Forwarded to ``LoRA.apply_lora_modality_trunks``.
        linear_probing: Enable linear probing.
        real: Text-modality model input for the "real" prompt.
        fake: Text-modality model input for the "fake" prompt.
    """

    def __init__(self, lr=5e-4, weight_decay=1e-4, max_epochs=500, batch_size=32, num_workers=4, seed=0,
                 self_contrast=False, temperature=0.07,  momentum_betas=(0.9, 0.95),
                 lora=False, lora_rank=4, lora_checkpoint_dir="./.checkpoints/lora",
                 lora_layer_idxs=None, lora_modality_names=None,
                 linear_probing=False, real = None, fake = None
                 ):
        super().__init__()
        assert not (linear_probing and lora), \
            "Linear probing is a subset of LoRA training procedure for ImageBind. " \
            "Cannot set both linear_probing=True and lora=True. " \
            "Linear probing stores params in lora_checkpoint_dir"
        self.save_hyperparameters()

        self.real_label = real
        self.fake_label = fake

        # Load full pretrained ImageBind model
        self.model = imagebind_model.imagebind_huge(pretrained=True)

        # Freeze pre-trained layers
        if lora:
            self._freeze_children(self.model.modality_preprocessors)
            self._freeze_children(self.model.modality_trunks)

            self.model.modality_trunks.update(LoRA.apply_lora_modality_trunks(self.model.modality_trunks, rank=lora_rank,
                                                                              layer_idxs=lora_layer_idxs,
                                                                              modality_names=lora_modality_names))
            LoRA.load_lora_modality_trunks(self.model.modality_trunks, checkpoint_dir=lora_checkpoint_dir)

            # Load postprocessors & heads
            load_module(self.model.modality_postprocessors, module_name="postprocessors",
                        checkpoint_dir=lora_checkpoint_dir)
            load_module(self.model.modality_heads, module_name="heads",
                        checkpoint_dir=lora_checkpoint_dir)
        elif linear_probing:
            self._freeze_children(self.model.modality_preprocessors)
            self._freeze_children(self.model.modality_trunks)
            self._freeze_children(self.model.modality_postprocessors)

            load_module(self.model.modality_heads, module_name="heads",
                        checkpoint_dir=lora_checkpoint_dir)
            for modality_head in self.model.modality_heads.children():
                # Freeze the whole head, then re-enable only its last child.
                modality_head.requires_grad_(False)
                final_layer = list(modality_head.children())[-1]
                final_layer.requires_grad_(True)

    @staticmethod
    def _freeze_children(container):
        """Disable gradients for every direct child module of ``container``."""
        for child in container.children():
            child.requires_grad_(False)

    def configure_optimizers(self):
        """Return AdamW plus a cosine schedule decaying to ``lr / 50`` over ``max_epochs``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay,
                                betas=self.hparams.momentum_betas)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs, eta_min=self.hparams.lr / 50
        )
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode="train"):
        """Compute and log the mean two-way InfoNCE loss for one batch.

        For every image the positive is the prompt matching its label and the
        negative is the other prompt:
        ``-log(exp(pos / T) / (exp(pos / T) + exp(neg / T)))``. This equals
        cross-entropy over the logits ``[sim_real, sim_fake] / T``, which is how
        it is evaluated here. The result is averaged uniformly over the batch.

        Args:
            batch: ``(images, vision_modality, labels, text_modality)`` as
                produced by the DataLoader; ``labels`` holds the strings
                ``'real'`` / ``'fake'``.
            mode: Prefix of the logged metric name (``"<mode>_loss"``).

        Returns:
            Scalar tensor with the batch-mean loss.

        Raises:
            ValueError: If a label is neither ``'real'`` nor ``'fake'``.
        """
        data_a, _, labels, _ = batch
        # The prompts are re-embedded on every call because part of the text
        # head may be trainable.
        real = self.model({ModalityType.TEXT:self.real_label})[ModalityType.TEXT]
        fake = self.model({ModalityType.TEXT:self.fake_label})[ModalityType.TEXT]

        # class_a is always "vision" according to ImageBind
        feats = self.model({ModalityType.VISION: data_a})[ModalityType.VISION]

        # Target index into the [real, fake] logit pair for every sample.
        target_idxs = []
        for label in labels:
            if label == 'real':
                target_idxs.append(0)
            elif label == 'fake':
                target_idxs.append(1)
            else:
                raise ValueError(f"Unexpected label {label!r}; expected 'real' or 'fake'")
        targets = torch.tensor(target_idxs, dtype=torch.long, device=feats.device)

        # One row of features per image; each prompt embedding broadcasts over the rows.
        logits = torch.stack(
            [F.cosine_similarity(feats, real), F.cosine_similarity(feats, fake)],
            dim=1,
        ) / self.hparams.temperature
        loss = F.cross_entropy(logits, targets)

        # Log once per batch, weighted by the real number of samples (the last
        # validation batch can be smaller than hparams.batch_size).
        self.log(mode + "_loss", loss, prog_bar=True,
                 on_step=LOG_ON_STEP, on_epoch=LOG_ON_EPOCH, batch_size=targets.numel())
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self.info_nce_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for ``batch``."""
        self.info_nce_loss(batch, mode="val")

    def on_validation_epoch_end(self):
        """Persist the trainable parts of the model after every validation run."""
        if self.hparams.lora:
            # Save LoRA checkpoint
            LoRA.save_lora_modality_trunks(self.model.modality_trunks, checkpoint_dir=self.hparams.lora_checkpoint_dir)
            # Save postprocessors & heads
            save_module(self.model.modality_postprocessors, module_name="postprocessors",
                        checkpoint_dir=self.hparams.lora_checkpoint_dir)
            save_module(self.model.modality_heads, module_name="heads",
                        checkpoint_dir=self.hparams.lora_checkpoint_dir)
        elif self.hparams.linear_probing:
            # Only the heads are trained in linear-probing mode, so only they are saved
            save_module(self.model.modality_heads, module_name="heads",
                        checkpoint_dir=self.hparams.lora_checkpoint_dir)

In [None]:
class ImageTextDataset(Dataset):
    """Folder-per-class image dataset that yields an image with its class name.

    ``root_dir`` must contain one sub-directory per class; every ``*.png`` file
    inside a class directory becomes one sample labelled with that directory's
    name.

    Args:
        root_dir: Directory holding the class sub-directories.
        split: Accepted for interface compatibility; currently unused.
        random_seed: Accepted for interface compatibility; currently unused.
        device: Device handed to ``data.load_and_transform_vision_data``.
    """

    def __init__(self, root_dir, split='train', random_seed=0, device='cpu'):
        self.root_dir = root_dir
        self.device = device

        # os.listdir returns entries in arbitrary order; sort so that the class
        # indices and the sample order are the same on every machine and run.
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # List of (image path, class name) pairs.
        self.image_paths = []
        for cls in self.classes:
            cls_image_dir = os.path.join(root_dir, cls)
            for filename in sorted(os.listdir(cls_image_dir)):
                if filename.endswith(".png"):
                    img_path = os.path.join(cls_image_dir, filename)
                    self.image_paths.append((img_path, cls))

    def __len__(self):
        """Return the number of image samples."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(images, ModalityType.VISION, class_text, ModalityType.TEXT)``.

        ``images`` is whatever ``data.load_and_transform_vision_data`` returns
        for a single-path list; ``class_text`` is the class directory name.
        """
        img_path, class_text = self.image_paths[index]

        # Load and transform image
        images = data.load_and_transform_vision_data([img_path], self.device, to_tensor=True)

        return images, ModalityType.VISION, class_text, ModalityType.TEXT

In [None]:
# Per-split containers filled by the following cells. `test_datasets` stays
# empty because the cell that would populate it is commented out.
train_datasets = []
val_datasets = []
test_datasets = []

In [None]:
# Register the CELEB training split (path is relative to the working directory).
train_datasets.append(
    ImageTextDataset(root_dir=os.getcwd() + "/dataset/CELEB/train")
)

In [None]:
# Register the CELEB validation split (path is relative to the working directory).
val_datasets.append(
    ImageTextDataset(root_dir=os.getcwd() + "/dataset/CELEB/val")
)

In [None]:
# test_datasets.append(ImageTextDataset(
#             root_dir=os.getcwd()+"/dataset/CELEB/test",
#             transform=ContrastiveTransformations(contrast_transforms_test,
#                                                  n_views=2 if self_contrast else 1)))
# test_datasets.append(ImageTextDataset(
#             root_dir=os.getcwd()+"/dataset/CELEB-M/test",
#             transform=ContrastiveTransformations(contrast_transforms_test,
#                                                  n_views=2 if self_contrast else 1)))
# test_datasets.append(ImageTextDataset(
#             root_dir=os.getcwd()+"/dataset/FS/test",
#             transform=ContrastiveTransformations(contrast_transforms_test,
#                                                  n_views=2 if self_contrast else 1)))
# test_datasets.append(ImageTextDataset(
#             root_dir=os.getcwd()+"/dataset/NT/test",
#             transform=ContrastiveTransformations(contrast_transforms_test,
#                                                  n_views=2 if self_contrast else 1)))
# test_datasets.append(ImageTextDataset(
#             root_dir=os.getcwd()+"/dataset/DF/test",
#             transform=ContrastiveTransformations(contrast_transforms_test,
#                                                  n_views=2 if self_contrast else 1)))
# test_datasets.append(ImageTextDataset(
#             root_dir=os.getcwd()+"/dataset/DFD/test",
#             transform=ContrastiveTransformations(contrast_transforms_test,
#                                                  n_views=2 if self_contrast else 1)))

In [None]:
# Only one dataset per split is registered, so take the first entry of each list.
train_dataset = train_datasets[0]
val_dataset = val_datasets[0]
# test_dataset_celeb = test_datasets[0]
# test_dataset_celeb_m = test_datasets[1]
# test_dataset_fs = test_datasets[2]
# test_dataset_nt = test_datasets[3]
# test_dataset_df = test_datasets[4]
# test_dataset_dfd = test_datasets[5]

In [None]:
# Training: shuffled batches, incomplete trailing batch dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size, num_workers=num_workers,
    shuffle=True, drop_last=True,
)
# Validation: dataset order preserved, every sample kept.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size, num_workers=num_workers,
    shuffle=False, drop_last=False,
)

In [None]:
# Resolve the configured modality names to ModalityType members. Every selected
# modality gets `None` as its layer-index entry.
modalities = ["vision", "text", "audio", "thermal", "depth", "imu"]
lora_layer_idxs = {}
lora_modality_names = []
for modality_name in lora_modality_names_123:
    # Reject anything that is not a known modality.
    if modality_name not in modalities:
        raise ValueError(f"Unknown modality name: {modality_name}")
    modality_type = getattr(ModalityType, modality_name.upper())
    lora_layer_idxs[modality_type] = None
    lora_modality_names.append(modality_type)

In [None]:
# Prepare the two class prompts as text-modality inputs on the training device;
# they are handed to ImageBindTrain as `real` / `fake`.
real = data.load_and_transform_text(["real"], device_name)
fake = data.load_and_transform_text(["fake"], device_name)

In [None]:
# Sanity check: display the shapes of the two prompt tensors.
real.shape, fake.shape

(torch.Size([1, 77]), torch.Size([1, 77]))

In [None]:
# Build the Lightning module from the notebook configuration. Empty LoRA
# collections are passed on as None.
model = ImageBindTrain(
    real=real,
    fake=fake,
    linear_probing=linear_probing,
    lora=lora,
    lora_checkpoint_dir=lora_checkpoint_dir,
    lora_layer_idxs=lora_layer_idxs or None,
    lora_modality_names=lora_modality_names or None,
    self_contrast=self_contrast,
    max_epochs=max_epochs,
    batch_size=batch_size,
    num_workers=num_workers,
)

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpc8jve96b
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpc8jve96b/_remote_module_non_scriptable.py


In [None]:
# Extra Trainer keyword arguments: the checkpointing switch, plus a
# ModelCheckpoint callback monitoring "val_loss" when it is enabled.
checkpointing = {"enable_checkpointing": full_model_checkpointing}
if full_model_checkpointing:
    checkpointing["callbacks"] = [
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_last=True,
            dirpath=full_model_checkpoint_dir,
            filename="imagebind-{epoch:02d}-{val_loss:.2f}",
        )
    ]

In [None]:
# Metrics are written through a CSV logger rooted at "logs" under the name "my_model".
csv_logger = CSVLogger("logs", name="my_model")

# Accelerator and device index are derived from `device_name` ("cuda:<idx>" or "cpu").
trainer = Trainer(
    logger=csv_logger,
    accelerator="cpu" if "cuda" not in device_name else "gpu",
    devices=[int(device_name.split(":")[1])] if ":" in device_name else 1,
    deterministic=True,
    max_epochs=max_epochs,
    gradient_clip_val=gradient_clip_val,
    **checkpointing,
)

trainer.fit(model, train_loader, val_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type           | Params
-----------------------------------------
0 | model | ImageBindModel | 1.2 B 
--------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

INFO:root:Saved parameters for module heads to ./.checkpoints/lora.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:root:Saved parameters for module heads to ./.checkpoints/lora.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:root:Saved parameters for module heads to ./.checkpoints/lora.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:root:Saved parameters for module heads to ./.checkpoints/lora.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:root:Saved parameters for module heads to ./.checkpoints/lora.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:root:Saved parameters for module heads to ./.checkpoints/lora.
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
