In [1]:
!pip install git+https://github.com/openai/CLIP.git --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for clip (setup.py) ... [?25l[?25hdone


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
from tqdm import tqdm
import os
import random
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

import math

In [None]:
def seed_everything(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

In [None]:
_tokenizer = _Tokenizer()

# **Attention**

In [None]:
class ContextAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout_rate = 0.3):
        super(ContextAttention, self).__init__()
        self.attention_layer = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first= True)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, image_features, ctx_features):
        """
        Args:
            image_features: Tensor of shape [b, 196, 512] (batch size, seq length, feature size)
            ctx_features: Tensor of shape [b, n_ctx, 512] (context features for 'real' and 'fake')
        Returns:
            Tensor of shape [b, n_ctx, 512] containing attended context features.
        """
        attn_output, attn_weight = self.attention_layer(ctx_features, image_features, image_features)
        attn_output = self.dropout(attn_output)
        return attn_output, attn_weight # attn_output shape is (b, n_ctx, 512)

# **PromptLearner**

In [None]:
class PromptLearner(nn.Module):
    def __init__(self, n_ctx, classnames, clip_model, dropout_rate=0.3):
        super().__init__()
        dtype = torch.float32
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        ctx_dim = clip_model.ln_final.weight.shape[0] # clip_model.ln_final.weight.shape[0] is 512
        self.n_cls = len(classnames)

        ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device = device)
        nn.init.normal_(ctx_vectors, std=0.02)
        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized --> shape = [n_ctx, 512]

        self.ctx_attn = ContextAttention(512, 8)

        prompt_prefix = " ".join(["X"] * n_ctx)
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device) # tokenized_prompts.shape is (2, 77)
        with torch.no_grad():
                embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) # embedding.shape = torch.Size([2, 77, 512])
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS --> shape = [2, 1, 512]
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS -->shape = [2, 61, 512]
        self.tokenized_prompts = tokenized_prompts # tokenized_prompts.shape is (2, 77)
        self.dropout = nn.Dropout(dropout_rate)


    def forward(self, image_features):
        image_features = F.normalize(image_features, dim=1)  # shape = [b, 196, 512]
        ctx_features = F.normalize(self.ctx, dim=1)  # shape = [n_ctx, 512]
        # image_features = image_features / image_features.norm(dim=1, keepdim=True) # shape = [b, 196, 512]
        # ctx_features = self.ctx / self.ctx.norm(dim=1, keepdim=True) # shape = [n_ctx, 512]
        ctx_features = ctx_features[None, :, :] # shape = [1, n_ctx, 512]
        ctx_features = ctx_features.repeat(image_features.shape[0], 1, 1) # shape = [b, n_ctx, 512]
        ctx_features_attn_output, ctx_features_attn_weight = self.ctx_attn(image_features, ctx_features) # shape = [b, n_ctx, 512]
        ctx_features_attn_output = ctx_features_attn_output + ctx_features # shape = [b, n_ctx, 512]
        ctx_features_attn_output = ctx_features_attn_output.mean(dim = 0) # shape = [n_ctx, 512]
        ctx_features_attn_output = ctx_features_attn_output.unsqueeze(0) # shape = [1, n_ctx, 512]
        ctx_features_attn_output = ctx_features_attn_output.expand(self.n_cls, -1, -1) #shape = [2, n_ctx, 512]
        ctx_features_attn_output = self.dropout(ctx_features_attn_output)


        prompts = torch.cat([self.token_prefix, ctx_features_attn_output, self.token_suffix], dim=1) # shape = [2, 1 + n_ctx + * , 512]
        # shape of prompts = [2, 77, 512]
        return prompts


# **TextEncoder**

In [None]:
'''
This TextEncoder class is designed to process text using elements from the CLIP model's architecture.
It uses a transformer module to encode a series of tokenized text prompts and extract features from the "end of text" (EOT) token,
which often represents a summary of the input sequence in CLIP's text encoding.
'''

class TextEncoder(nn.Module):
    def __init__(self, clip_model, dropout_rate=0.3):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final #layer normalization module
        self.text_projection = clip_model.text_projection # A linear projection matrix to map transformer outputs to a feature space used by the CLIP model.
        self.dtype = torch.float32
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype) # shape of positional_embedding = torch.Size([77, 512])
        '''
        x.shape = [ batch_size (N), n_ctx (L), transformer.width (D) ]
        '''
        x = F.normalize(x, dim=-1)  # Normalize along the feature dimension
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        x = self.dropout(x)
        x = F.normalize(x, dim=-1)  # Normalize before projection

        # x.shape is  torch.Size([2, 77, 512])

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        '''
        x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] selects the feature vector corresponding to the EOT token for each prompt in the batch.
        This vector is then linearly projected with self.text_projection, giving the final encoded feature vector for each text prompt.
        '''
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        # self.text_projection.shape =  torch.Size([512, 512])
        # x.shape =  torch.Size([2, 512])

        return x

# **MultiscaleAdapter**

In [None]:
class MultiscaleAdapter(nn.Module):
    def __init__(self, input_dim, dropout_rate=0.3):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )

        # Modify branches to avoid errors
        self.branch_f = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1, padding=0),
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, input_dim, kernel_size=1, padding=0),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )
        self.branch_g = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1, padding=0),
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1, dilation=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )
        self.branch_h = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1, padding=0),
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1, dilation=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )

        # Combine features and project back to original dimension
        self.combine = nn.Sequential(
            nn.Conv2d(input_dim * 3, input_dim, kernel_size=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )
        self.final_projection = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, kernel_size=1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU()
        )
       # Initialize weights
        self._init_weights()

    def _init_weights(self):
      # Kaiming initialization for convolutional layers with ReLU activation
      for m in self.modules():
          if isinstance(m, nn.Conv2d):
              nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
              if m.bias is not None:
                  nn.init.zeros_(m.bias)

    def forward(self, x):
        # x.shape = (32, 768, 196)
        x = x.unsqueeze(2)
        # Now x.shape =  torch.Size([32, 768, 1, 196])
        x = x.reshape(x.shape[0], x.shape[1], int(math.sqrt(x.shape[3])), int(math.sqrt(x.shape[3])))
        # x.shape =  torch.Size([32, 768, 14, 14])
        x_fg_input = self.conv1(x)
        x_h_input = self.conv2(x)
        # Branch processing
        x_f = F.normalize(self.branch_f(x_fg_input), dim=1)  # Normalize after branch processing
        x_g = F.normalize(self.branch_g(x_fg_input), dim=1)
        x_h = F.normalize(self.branch_h(x_h_input), dim=1)

        # Concatenating along channel dimension and projecting back
        x_out = torch.cat([x_f, x_g, x_h], dim=1)
        x_out = self.combine(x_out)

        # Adding the processed multi-scale features to the original input
        x_out = self.final_projection(x_out) + F.normalize(self.conv3(x), dim=1)
        x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], -1)
        x_out = self.dropout(x_out)

        return x_out

# **CustomCLIP**

In [None]:
class CustomCLIP(nn.Module):
    def __init__(self, n_ctx, classnames, clip_model, dropout_rate = 0.3):
        super().__init__()
        self.prompt_learner = PromptLearner(n_ctx, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts # tokenized_prompts.shape is (2, 77)

        self.image_encoder = clip_model.visual
        self.patch_embedding = self.image_encoder.conv1
        self.class_embedding = self.image_encoder.class_embedding
        self.positional_embedding = self.image_encoder.positional_embedding
        self.ln_pre = self.image_encoder.ln_pre
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.text_encoder = TextEncoder(clip_model)
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        embed_dim = self.image_encoder.ln_post.weight.shape[0]
        self.multiscale_adapters = nn.ModuleList(
                                    [MultiscaleAdapter(input_dim=embed_dim)
                                    for _ in range(len(self.image_encoder.transformer.resblocks))]
                                  )
        self.image_features_proj = nn.Linear(embed_dim, 512)
        self._initialize_weights()
        self.dtype = torch.float32
        self.dropout = nn.Dropout(dropout_rate)
        self.temperature_scale = nn.Parameter(torch.tensor(10.0))
        self.logit_scale = clip_model.logit_scale

    def _initialize_weights(self):
        # Custom initialization for image_features_proj
        nn.init.xavier_uniform_(self.image_features_proj.weight)
        nn.init.constant_(self.image_features_proj.bias, 0)


    def forward(self, images):
        x = self.patch_embedding(images.type(torch.float32))  # (32, 768, 14, 14)
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # (batch, num_patches, embed_dim) i.e; (32, 196, 768)
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        # Now x.shape is [batch_size, num_patches + 1, 768] i.e; torch.Size([32, 197, 768])
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x) # shape = [32, 197, 768]
        # Now x.shape is torch.Size([32, 197, 768])
        #x = x[:,1:,:] # shape = [32, 196, 768]

        intermediate_outputs = []
        for i, block in enumerate(self.image_encoder.transformer.resblocks):
            x = block(x) # shape = [32, 197, 768]
            x_cls = x[:,0,:] # shape = [32, 768]
            x_cls = x_cls.unsqueeze(1) # shape = [32, 1, 768]
            intermediate_outputs.append(x)
            x = x[:,1:,:] # shape = [32, 196, 768]
            x = x.permute(0, 2, 1) # shape = [32, 768, 196]
            x = self.multiscale_adapters[i](x) + x # output of MSA 's shape = [32, 768, 196]
            x = x.permute(0, 2, 1) # shape = [32, 196, 768]
            x = torch.cat([x_cls, x], dim=1) # shape = [32, 197, 768]
            x = F.normalize(x, p=2, dim=-1)
        image_features = self.image_features_proj(x) # shape = [32, 197, 512]
        image_features_cls = image_features[:,0,:] # shape = [32, 512]
        image_features_cls = image_features_cls.unsqueeze(1) # shape = [32, 1, 512]
        image_features = image_features[:,1:,:] # shape = [32, 196, 512]
        image_features = self.dropout(image_features)
        with torch.no_grad():
          promptlearner_outputs = self.prompt_learner(image_features) # shape = [2, 77, 512]
          text_features = self.text_encoder(promptlearner_outputs, self.tokenized_prompts) # shape = [2, 512]
        text_features = text_features.unsqueeze(0) # shape = [1, 2, 512]

        # Expand text_features to match image_features' batch size
        text_features_expanded = text_features.expand(image_features.size(0), -1, -1)  # Shape: (32, 2, 512)
        #image_features_expanded = image_features_cls.unsqueeze(2)  # Shape: (32, 1, 1, 512)
        text_features_expanded = text_features_expanded.unsqueeze(1)  # Shape: (32, 1, 2, 512)

        #image_features_expanded = F.normalize(image_features_expanded, p=2, dim=-1)
        text_features_expanded = F.normalize(text_features_expanded, p=2, dim=-1)
        # cosine_similarities = F.cosine_similarity(image_features_expanded, text_features_expanded, dim=-1)
        # Cosine similarities shape: torch.Size([32, 1, 2])

        image_features_cls = F.normalize(image_features_cls, p=2, dim=-1)


        # cosine_similarities = cosine_similarities.squeeze(1) # Shape: (32, 2)
        logits = torch.matmul(image_features_cls, text_features_expanded.squeeze(1).transpose(1, 2))
        logits = logits.squeeze(1) # Shape: (32, 2)


        #return cosine_similarities * self.temperature_scale
        logit_scale = self.logit_scale.exp()
        return logits * logit_scale


# **Training**

In [None]:
data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/CELEB'

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])

# Load the dataset
train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform)
train_dataset.class_to_idx = {'real': 0, 'fake': 1}
val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=transform)
val_dataset.class_to_idx = {'real': 0, 'fake': 1}

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers= os.cpu_count())
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    # Lambda function for warmup phase
    def lr_lambda(current_step: int):
      if current_step < num_warmup_steps:
          return float(current_step) / float(max(1, num_warmup_steps))
      return max(
          0.0, 0.5 * (1.0 + torch.cos(torch.tensor(torch.pi * (current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps))))
      )
    return LambdaLR(optimizer, lr_lambda)


def train_model(model, train_loader, val_loader, device, num_epochs=10, lr=1e-4, warmup_steps=100):
    model.train()
    scaler = GradScaler()

    #optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay= 1e-5)
    optimizer = optim.SGD(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=lr,
                    momentum=0.9,
                    weight_decay = 1e-4
                    )
    total_steps = len(train_loader) * num_epochs
    scheduler = get_scheduler(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    criterion = nn.CrossEntropyLoss()
    #criterion = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            #labels_one_hot = F.one_hot(labels, num_classes=2).float()
            optimizer.zero_grad()

            with autocast():
                logits = model(images)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            # Apply gradient clipping before optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()  # Update learning rate after every optimizer step

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = correct / total

        val_loss, val_acc = validate_model(model, val_loader, criterion, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

def validate_model(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            #labels_one_hot = F.one_hot(labels, num_classes=2).float()

            with autocast():  # Use autocast in validation as well, if mixed precision
                logits = model(images)
                loss = criterion(logits, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_loss = running_loss / total
    val_acc = correct / total
    return val_loss, val_acc

In [None]:
# small_train_loader = torch.utils.data.DataLoader(
#     torch.utils.data.Subset(train_dataset, range(32)),
#     batch_size=32, shuffle=True
# )
# small_val_loader = torch.utils.data.DataLoader(
#     torch.utils.data.Subset(val_dataset, range(32)),
#     batch_size=32, shuffle=False
# )

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, preprocess = clip.load("ViT-B/16", device=device)
clip_model = clip_model.float()

n_ctx = 15  # Number of context tokens
classnames = ["real", "fake"]

model = CustomCLIP(n_ctx=n_ctx, classnames=classnames, clip_model=clip_model).to(device)
train_model(model, train_loader, val_loader, device, num_epochs=20, lr=1e-1, warmup_steps=500)
#train_model(model, small_train_loader, small_val_loader, device, num_epochs=5, lr=1e-2, warmup_steps=10)

  scaler = GradScaler()
  with autocast():
Training Epoch 1/20: 100%|██████████| 500/500 [03:10<00:00,  2.62it/s]
  with autocast():  # Use autocast in validation as well, if mixed precision


Epoch [1/20], Train Loss: 1.5025, Train Accuracy: 0.5038, Val Loss: 0.7054, Val Accuracy: 0.5000


Training Epoch 2/20: 100%|██████████| 500/500 [01:46<00:00,  4.70it/s]


Epoch [2/20], Train Loss: 0.6956, Train Accuracy: 0.4994, Val Loss: 0.6917, Val Accuracy: 0.5160


Training Epoch 3/20: 100%|██████████| 500/500 [01:45<00:00,  4.73it/s]


Epoch [3/20], Train Loss: 0.6934, Train Accuracy: 0.4969, Val Loss: 0.6930, Val Accuracy: 0.5240


Training Epoch 4/20: 100%|██████████| 500/500 [01:46<00:00,  4.71it/s]


Epoch [4/20], Train Loss: 0.6934, Train Accuracy: 0.4998, Val Loss: 0.6938, Val Accuracy: 0.4360


Training Epoch 5/20: 100%|██████████| 500/500 [01:45<00:00,  4.74it/s]


Epoch [5/20], Train Loss: 0.6934, Train Accuracy: 0.5001, Val Loss: 0.6930, Val Accuracy: 0.5960


Training Epoch 6/20: 100%|██████████| 500/500 [01:45<00:00,  4.74it/s]


Epoch [6/20], Train Loss: 0.6935, Train Accuracy: 0.4976, Val Loss: 0.6941, Val Accuracy: 0.3960


Training Epoch 7/20: 100%|██████████| 500/500 [01:45<00:00,  4.72it/s]


Epoch [7/20], Train Loss: 0.6934, Train Accuracy: 0.4936, Val Loss: 0.6936, Val Accuracy: 0.4600


Training Epoch 8/20: 100%|██████████| 500/500 [01:45<00:00,  4.72it/s]


Epoch [8/20], Train Loss: 0.6934, Train Accuracy: 0.5011, Val Loss: 0.6931, Val Accuracy: 0.5000


Training Epoch 9/20: 100%|██████████| 500/500 [01:45<00:00,  4.73it/s]


Epoch [9/20], Train Loss: 0.6934, Train Accuracy: 0.5000, Val Loss: 0.6933, Val Accuracy: 0.5000


Training Epoch 10/20: 100%|██████████| 500/500 [01:45<00:00,  4.75it/s]


KeyboardInterrupt: 

# **Testing**

In [None]:
# Define the testing function
def test_model(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()  # Loss function for testing

    with torch.no_grad():  # Disable gradient tracking
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            with autocast():  # Use autocast in validation as well, if mixed precision
                logits = model(images)
                loss = criterion(logits, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    test_loss = running_loss / total
    test_acc = correct / total
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

# **CELEB-Test**

In [None]:
test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)

# **CELEB-M**

In [None]:
dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/CELEB-M'
test_dataset = datasets.ImageFolder(root=os.path.join(dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)

# **FS**

In [None]:
dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/FS'
test_dataset = datasets.ImageFolder(root=os.path.join(dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)

# **NT**

In [None]:
dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/NT'
test_dataset = datasets.ImageFolder(root=os.path.join(dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)

# **DF**

In [None]:
dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/DF'
test_dataset = datasets.ImageFolder(root=os.path.join(dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)

# **DFD**

In [None]:
dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/DFD'
test_dataset = datasets.ImageFolder(root=os.path.join(dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)

# **F2F**

In [None]:
dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/F2F'
test_dataset = datasets.ImageFolder(root=os.path.join(dir, 'test'), transform=transform)
test_dataset.class_to_idx = {'real': 0, 'fake': 1}
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers= os.cpu_count())

In [None]:
# Test the model after training
test_model(model, test_loader, device)