In [None]:
!pip install git+https://github.com/openai/CLIP.git --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
# !pip install tensorboard --quiet

In [None]:
# from torch.utils.tensorboard import SummaryWriter

# # Initialize TensorBoard writer
# writer = SummaryWriter(log_dir='runs/my_model')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
from tqdm import tqdm
import os
import random
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torchvision.transforms import InterpolationMode
from scipy.stats import ks_2samp

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

import math

In [None]:
def seed_everything(seed=0):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and every CUDA device). It is also exported as
            ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not just the current one: the model is later wrapped
    # in DataParallel. This call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run; disable it
    # so that the deterministic flag above is not undermined.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# Instantiate CLIP's SimpleTokenizer (imported above under the alias `_Tokenizer`).
# NOTE(review): `_tokenizer` is not referenced by any other visible cell — confirm it is still needed.
_tokenizer = _Tokenizer()

# **Attention**

In [None]:
class ContextAttention(nn.Module):
    """Cross-attention in which prompt context tokens query image patch tokens."""

    def __init__(self, dim, num_heads):
        super().__init__()
        self.attention_layer = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, image_features, ctx_features):
        """Attend from the context tokens (queries) to the image tokens (keys/values).

        Args:
            image_features: Batch-first tensor ``[b, num_patches, dim]``
                (``[b, 196, 512]`` in this notebook).
            ctx_features: Batch-first tensor ``[b, n_ctx, dim]`` of context tokens.

        Returns:
            Tuple ``(attn_output, attn_weight)``; ``attn_output`` has shape
            ``[b, n_ctx, dim]``.
        """
        return self.attention_layer(
            query=ctx_features,
            key=image_features,
            value=image_features,
        )

# **PromptLearner**

In [None]:
class PromptLearner(nn.Module):
    """Learn image-conditioned context tokens and assemble one prompt per class.

    A shared set of ``n_ctx`` learnable context vectors attends to the image
    patch features. The attended context is averaged over the batch and inserted
    between the frozen start-of-sequence embedding and the class-name/padding
    embeddings of every class prompt.

    Args:
        n_ctx: Number of learnable context tokens.
        classnames: Class names; one prompt is built per name, in this order.
        clip_model: Loaded CLIP model providing ``dtype``, ``ln_final`` and
            ``token_embedding``.
    """

    def __init__(self, n_ctx, classnames, clip_model):
        super().__init__()
        dtype = clip_model.dtype
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        ctx_dim = clip_model.ln_final.weight.shape[0]  # text transformer width (512 for ViT-B/16)
        self.n_cls = len(classnames)

        ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype, device=device)
        nn.init.normal_(ctx_vectors)
        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized --> shape = [n_ctx, ctx_dim]

        # Use ctx_dim instead of a hard-coded 512 so every layer follows the CLIP width.
        self.ctx_attn = ContextAttention(ctx_dim, 8)

        # "X" placeholders reserve the token positions that the learned context replaces.
        prompt_prefix = " ".join(["X"] * n_ctx)
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)  # [n_cls, 77]
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)  # [n_cls, 77, ctx_dim]
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS --> [n_cls, 1, ctx_dim]
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # class name, EOS, padding
        self.tokenized_prompts = tokenized_prompts

        self.img_ln = nn.LayerNorm(ctx_dim)
        self.ctx_ln = nn.LayerNorm(ctx_dim)
        self.ctx_attn_ln_pre = nn.LayerNorm(ctx_dim)
        self.ctx_attn_ln_post = nn.LayerNorm(ctx_dim)

    def forward(self, image_features, global_step=None):
        """Build the prompt embeddings for the current batch.

        Args:
            image_features: Patch features ``[b, num_patches, ctx_dim]``.
            global_step: Unused; kept for the (disabled) TensorBoard logging hooks.

        Returns:
            Tensor ``[n_cls, 1 + n_ctx + suffix_len, ctx_dim]`` of prompt
            embeddings, standardised along the token dimension.
        """
        batch_size = image_features.shape[0]
        ctx_features = self.ctx.unsqueeze(0).repeat(batch_size, 1, 1)  # [b, n_ctx, ctx_dim]

        image_features_standardized = self.img_ln(image_features)
        ctx_features_standardized = self.ctx_ln(ctx_features)

        # Context tokens query the image patches.
        ctx_attn_output, _ = self.ctx_attn(image_features_standardized, ctx_features_standardized)

        # Residual connection around the attention, with layer norm before and after the sum.
        ctx_attn_output = self.ctx_attn_ln_pre(ctx_attn_output) + ctx_features_standardized
        ctx_attn_output = self.ctx_attn_ln_post(ctx_attn_output)

        # One context shared by the whole batch, replicated for every class prompt.
        ctx_shared = ctx_attn_output.mean(dim=0).unsqueeze(0).expand(self.n_cls, -1, -1)

        # Normalise into locals. Writing the result back to the registered buffers
        # would silently change the module's state dict after the first forward pass.
        token_prefix = nn.functional.normalize(self.token_prefix, dim=-1)
        token_suffix = nn.functional.normalize(self.token_suffix, dim=-1)

        prompts = torch.cat([token_prefix, ctx_shared, token_suffix], dim=1)

        prompts_mean = prompts.mean(dim=1, keepdim=True)
        prompts_std = prompts.std(dim=1, keepdim=True)
        # Clamp the divisor so a constant feature column yields 0 instead of NaN/inf.
        prompts = (prompts - prompts_mean) / prompts_std.clamp_min(1e-6)

        return prompts


# **TextEncoder**

In [None]:
class TextEncoder(nn.Module):
    """Encode prompt embeddings with the text branch of a CLIP model.

    The prompts are run through CLIP's text transformer and the feature at the
    end-of-text (EOT) position of every prompt is projected with CLIP's text
    projection matrix.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final  # final layer normalisation
        self.text_projection = clip_model.text_projection  # projection applied to the EOT feature
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Return one projected feature vector per prompt.

        Args:
            prompts: Prompt embeddings ``[n_prompts, seq_len, width]``.
            tokenized_prompts: Token ids ``[n_prompts, seq_len]``; only used to
                locate the EOT token, which has the largest id in each row.

        Returns:
            Tensor ``[n_prompts, text_projection.shape[1]]``.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # The transformer is called sequence-first: NLD -> LND and back again.
        hidden = self.transformer(hidden.permute(1, 0, 2))
        hidden = self.ln_final(hidden.permute(1, 0, 2)).type(self.dtype)

        # Pick the feature at the EOT position of every prompt, then project it.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        prompt_indices = torch.arange(hidden.shape[0])
        return hidden[prompt_indices, eot_positions] @ self.text_projection

# **MultiscaleAdapter**

In [None]:
class MultiscaleAdapter(nn.Module):
    """Convolutional adapter applied to a square grid of patch tokens.

    Three parallel branches (1x1 -> 1x1, 1x1 -> 3x3, 1x1 -> 3x3) are concatenated,
    fused back to ``input_dim`` channels and added to a 1x1-projected copy of the
    input. The number of channels and of patches is preserved.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size=1, padding=0):
        """Return the layers of one Conv2d -> BatchNorm2d -> ReLU unit."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def __init__(self, input_dim):
        super().__init__()
        unit = self._conv_bn_relu

        # Layers are created in the same order as before so that seeded
        # initialisation yields the same weights.
        self.conv1 = nn.Sequential(*unit(input_dim, input_dim))
        self.conv2 = nn.Sequential(*unit(input_dim, input_dim))
        self.conv3 = nn.Sequential(*unit(input_dim, input_dim))

        self.branch_f = nn.Sequential(
            *unit(input_dim, input_dim),
            *unit(input_dim, input_dim),
        )
        self.branch_g = nn.Sequential(
            *unit(input_dim, input_dim),
            *unit(input_dim, input_dim, kernel_size=3, padding=1),
        )
        self.branch_h = nn.Sequential(
            *unit(input_dim, input_dim),
            *unit(input_dim, input_dim, kernel_size=3, padding=1),
        )

        # Fuse the three branches back to the input width, then project once more.
        self.combine = nn.Sequential(*unit(input_dim * 3, input_dim))
        self.final_projection = nn.Sequential(*unit(input_dim, input_dim))

    def forward(self, x):
        """Apply the adapter.

        Args:
            x: Tensor ``[batch, channels, num_patches]`` where ``num_patches`` is
                a perfect square (``[32, 768, 196]`` in this notebook).

        Returns:
            Tensor with the same ``[batch, channels, num_patches]`` shape.
        """
        side = int(math.sqrt(x.shape[2]))
        grid = x.reshape(x.shape[0], x.shape[1], side, side)  # e.g. [32, 768, 14, 14]

        shared_fg = self.conv1(grid)  # input shared by branches f and g
        branch_outputs = [
            self.branch_f(shared_fg),
            self.branch_g(shared_fg),
            self.branch_h(self.conv2(grid)),
        ]

        fused = self.combine(torch.cat(branch_outputs, dim=1))
        # Skip path: add a 1x1-projected copy of the input to the fused features.
        result = self.final_projection(fused) + self.conv3(grid)
        return result.reshape(result.shape[0], result.shape[1], -1)

# **CustomCLIP**

In [None]:
class CustomCLIP(nn.Module):
    """CLIP-based real/fake classifier with learned prompts and patch adapters.

    The frozen CLIP image transformer is interleaved with trainable
    ``MultiscaleAdapter`` blocks, the frozen CLIP text transformer encodes the
    prompts produced by ``PromptLearner``, and a small self-attention layer mixes
    the image CLS feature with the class text features before logits are taken.

    Args:
        n_ctx: Number of learnable context tokens in each prompt.
        classnames: Class names; logit column ``i`` belongs to ``classnames[i]``.
        clip_model: Loaded CLIP model (ViT backbone).
    """

    def __init__(self, n_ctx, classnames, clip_model):
        super().__init__()
        self.prompt_learner = PromptLearner(n_ctx, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts  # [n_cls, 77]

        self.image_encoder = clip_model.visual
        self.patch_embedding = self.image_encoder.conv1
        self.class_embedding = self.image_encoder.class_embedding
        self.positional_embedding = self.image_encoder.positional_embedding

        for param in self.patch_embedding.parameters():
            param.requires_grad = False

        self.class_embedding.requires_grad = True
        self.positional_embedding.requires_grad = False

        self.ln_pre = self.image_encoder.ln_pre
        # NOTE(review): this blanket freeze also resets class_embedding.requires_grad
        # to False, so the `= True` assignment above has no lasting effect.
        # Confirm which of the two is intended.
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.text_encoder = TextEncoder(clip_model)
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        embed_dim = self.image_encoder.ln_post.weight.shape[0]
        self.multiscale_adapters = nn.ModuleList(
            [MultiscaleAdapter(input_dim=embed_dim)
             for _ in range(len(self.image_encoder.transformer.resblocks))]
        )
        self.image_features_proj = nn.Linear(embed_dim, 512)

        self.logit_scale = clip_model.logit_scale
        self.logit_scale.requires_grad = True

        self.dtype = clip_model.dtype

        self.dropout = nn.Dropout(p=0.4)

        self.multihead_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)

    def forward(self, images, global_step=None):
        """Compute scaled class logits for a batch of images.

        Args:
            images: Image batch ``[b, 3, H, W]`` preprocessed for CLIP.
            global_step: Forwarded to the prompt learner (logging hook only).

        Returns:
            Tensor ``[b, n_cls]`` of logits multiplied by ``exp(logit_scale)``.
        """
        x = self.patch_embedding(images.type(torch.float32))  # [b, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # [b, num_patches, width]
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # [b, num_patches + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = self.dropout(x)

        for i, block in enumerate(self.image_encoder.transformer.resblocks):
            # CLIP's residual attention blocks are sequence-first (LND). Feeding the
            # batch-first tensor directly makes attention mix different images of
            # the batch instead of the patches of one image, so permute around the block.
            x = block(x.permute(1, 0, 2)).permute(1, 0, 2)  # [b, num_patches + 1, width]

            x_cls = x[:, :1, :]  # [b, 1, width]
            patches = x[:, 1:, :].permute(0, 2, 1)  # [b, width, num_patches]
            msa = self.multiscale_adapters[i](patches)
            msa = nn.functional.layer_norm(msa, [msa.shape[-1]])  # normalise each channel over its patches
            msa = self.dropout(msa)
            patches = (msa + patches).permute(0, 2, 1)  # [b, num_patches, width]
            patches = nn.functional.layer_norm(patches, [patches.shape[-1]])  # normalise each token
            x = torch.cat([x_cls, patches], dim=1)  # [b, num_patches + 1, width]

        image_features = self.image_features_proj(x)  # [b, num_patches + 1, 512]
        image_features_cls = self.dropout(image_features[:, :1, :])  # [b, 1, 512]
        image_features = image_features[:, 1:, :]  # [b, num_patches, 512]

        promptlearner_outputs = self.prompt_learner(image_features, global_step)  # [n_cls, 77, 512]
        text_features = self.text_encoder(promptlearner_outputs, self.tokenized_prompts)  # [n_cls, 512]

        # Give every image its own copy of the class text features.
        text_features_expanded = text_features.unsqueeze(0).expand(image_features.size(0), -1, -1)

        image_features_cls = nn.functional.normalize(image_features_cls, dim=-1)
        text_features_expanded = nn.functional.normalize(text_features_expanded, dim=-1)

        # No `.to(device)` here: all tensors already live on the device of the
        # inputs/parameters, and reading a notebook-global `device` breaks the
        # module outside this notebook and on DataParallel replicas.
        combined_features = torch.cat((image_features_cls, text_features_expanded), dim=1)  # [b, 1 + n_cls, 512]
        # self.multihead_attn is sequence-first: [sequence_length, batch_size, embed_dim].
        combined_features_t = combined_features.transpose(0, 1)
        attn_output, _ = self.multihead_attn(
            query=combined_features_t,
            key=combined_features_t,
            value=combined_features_t
        )
        attn_output = attn_output.transpose(0, 1)  # [b, 1 + n_cls, 512]
        updated_image_features_cls = attn_output[:, :1, :]  # [b, 1, 512]
        updated_text_features_expanded = attn_output[:, 1:, :]  # [b, n_cls, 512]

        logits = torch.matmul(updated_image_features_cls, updated_text_features_expanded.transpose(1, 2))
        logits = logits.squeeze(1)  # [b, n_cls]

        logit_scale = self.logit_scale.exp()
        return logits * logit_scale


# **Training**

In [None]:
data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/CELEB'

# Define image transformations (mean/std are CLIP's normalisation statistics).
# CenterCrop(224) is a no-op after the fixed-size resize and is kept for parity.
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])

# Load the dataset.
# ImageFolder fixes every sample's label when it is constructed, using the sorted
# sub-folder names. Overwriting `class_to_idx` afterwards does not relabel anything
# and only leaves a false mapping on the object, so that assignment was removed.
# Inspect `train_dataset.class_to_idx` for the mapping actually in use.
train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=transform)

# os.cpu_count() may return None; fall back to loading in the main process.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count() or 0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Create a linear-warmup, cosine-decay learning-rate schedule.

    The multiplier rises linearly from 0 to 1 over ``num_warmup_steps`` and then
    follows half a cosine down to 0 at ``num_training_steps``. Past that step it
    stays at 0.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of warm-up steps.
        num_training_steps: Total number of scheduler steps planned.

    Returns:
        A ``LambdaLR`` scheduler; call ``step()`` once per optimizer step.
    """
    # Guard against a zero-length decay phase (training steps == warm-up steps).
    decay_steps = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = min(1.0, (current_step - num_warmup_steps) / decay_steps)
        # Plain floats keep the optimizer's `lr` a Python float instead of a tensor.
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)


def train_model(model, train_loader, val_loader, device, num_epochs=10, lr=1e-4, warmup_steps=100, log_interval=100):
    """Train ``model`` with mixed precision and validate after every epoch.

    Uses Nesterov SGD over the parameters that require gradients, with the
    warm-up/cosine schedule from ``get_scheduler`` stepped once per batch.
    Running and per-epoch loss/accuracy are printed.

    Args:
        model: Module called as ``model(images, global_step=...)`` and returning class logits.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        lr: Peak learning rate.
        warmup_steps: Number of warm-up steps of the schedule.
        log_interval: Print running metrics every this many batches.
    """
    model.train()
    scaler = GradScaler()

    optimizer = optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        momentum=0.9,
        weight_decay=1e-4,
        nesterov=True
    )
    total_steps = len(train_loader) * num_epochs
    scheduler = get_scheduler(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # Validation at the end of the previous epoch switches the model to eval
        # mode. Re-enable train mode here, otherwise every epoch after the first
        # would train with dropout off and frozen BatchNorm statistics.
        model.train()

        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}")):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            with autocast():
                logits = model(images, global_step=epoch * len(train_loader) + batch_idx)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()  # Update learning rate after every optimizer step

            # Weight the batch-mean loss by the batch size so a smaller last batch is averaged correctly.
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # Log intermediate loss and accuracy
            if (batch_idx + 1) % log_interval == 0:
                current_loss = running_loss / total
                current_acc = correct / total
                print(f"  Step [{batch_idx + 1}/{len(train_loader)}], "
                      f"Loss: {current_loss:.4f}, Accuracy: {current_acc:.4f}")

        epoch_loss = running_loss / total
        epoch_acc = correct / total

        val_loss, val_acc = validate_model(model, val_loader, criterion, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is put into eval mode for the duration of the call and then
    returned to whatever mode it was in, so calling this in the middle of
    training does not leave dropout/BatchNorm in inference mode.

    Args:
        model: Module called as ``model(images)`` and returning class logits.
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to return a batch mean.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(val_loss, val_acc)``: sample-weighted mean loss and accuracy.
    """
    was_training = model.training
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                with autocast():  # Use autocast in validation as well, if mixed precision
                    logits = model(images)
                    loss = criterion(logits, labels)

                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(logits, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        # Restore the caller's mode even if evaluation fails part-way through.
        model.train(was_training)

    val_loss = running_loss / total
    val_acc = correct / total
    return val_loss, val_acc

In [None]:
# Small subsets (at most 32*10 samples each) for quick smoke tests.
# ImageFolder orders its samples class by class, so taking the first N indices
# would yield a single class. Draw a fixed random selection instead, seeded so the
# subset is the same on every run.
small_train_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(
        train_dataset,
        torch.randperm(len(train_dataset), generator=torch.Generator().manual_seed(0))[:32 * 10].tolist(),
    ),
    batch_size=32, shuffle=True
)
small_val_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(
        val_dataset,
        torch.randperm(len(val_dataset), generator=torch.Generator().manual_seed(0))[:32 * 10].tolist(),
    ),
    batch_size=32, shuffle=False
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, preprocess = clip.load("ViT-B/16", device=device)
clip_model = clip_model.float()

n_ctx = 15  # Number of context tokens
# Prompt i must describe label i. ImageFolder numbers its labels by the sorted
# class-folder names, which need not match a hand-written list such as
# ["real", "fake"]. Take the names from the dataset so prompts and labels cannot
# get out of sync.
# NOTE(review): this assumes the folder names are usable as prompt words — confirm.
classnames = list(train_dataset.classes)

model = CustomCLIP(n_ctx=n_ctx, classnames=classnames, clip_model=clip_model).to(device)
model = torch.nn.DataParallel(model)
train_model(model, train_loader, val_loader, device, num_epochs=10, lr=1e-2, warmup_steps=400, log_interval = 20) #lr = 1e-3, warmup_steps = 250
#train_model(model, small_train_loader, small_val_loader, device, num_epochs=5, lr=1e-2, warmup_steps=125)
# writer.close()

  scaler = GradScaler()
  with autocast():
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Training Epoch 1/10:   4%|▍         | 20/500 [01:26<18:00,  2.25s/it]

  Step [20/500], Loss: 0.6933, Accuracy: 0.5031


Training Epoch 1/10:   8%|▊         | 40/500 [01:35<03:32,  2.16it/s]

  Step [40/500], Loss: 0.6933, Accuracy: 0.5008


Training Epoch 1/10:  12%|█▏        | 60/500 [01:44<03:22,  2.17it/s]

  Step [60/500], Loss: 0.6934, Accuracy: 0.5036


Training Epoch 1/10:  16%|█▌        | 80/500 [01:54<03:12,  2.18it/s]

  Step [80/500], Loss: 0.6933, Accuracy: 0.4980


Training Epoch 1/10:  20%|██        | 100/500 [02:03<03:01,  2.20it/s]

  Step [100/500], Loss: 0.6933, Accuracy: 0.4997


Training Epoch 1/10:  24%|██▍       | 120/500 [02:12<02:53,  2.19it/s]

  Step [120/500], Loss: 0.6934, Accuracy: 0.5018


Training Epoch 1/10:  25%|██▍       | 124/500 [02:14<06:48,  1.09s/it]


KeyboardInterrupt: 

In [None]:
# %load_ext tensorboard

In [None]:
# %tensorboard --logdir=runs

In [None]:
import torch

# Report how many CUDA devices PyTorch can see (0 on a CPU-only runtime).
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

# **Testing**

In [None]:
# Define the testing function
def test_model(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()  # Loss function for testing

    with torch.no_grad():  # Disable gradient tracking
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            with autocast():  # Use autocast in validation as well, if mixed precision
                logits = model(images)
                loss = criterion(logits, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    test_loss = running_loss / total
    test_acc = correct / total
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

# **CELEB-Test**

In [None]:
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the CELEB test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)

# **CELEB-M**

In [None]:
# Named `test_data_dir` rather than `dir` so the builtin `dir()` is not shadowed.
test_data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/CELEB-M'
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(test_data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the CELEB-M test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)

# **FS**

In [None]:
# Named `test_data_dir` rather than `dir` so the builtin `dir()` is not shadowed.
test_data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/FS'
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(test_data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the FS test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)

# **NT**

In [None]:
# Named `test_data_dir` rather than `dir` so the builtin `dir()` is not shadowed.
test_data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/NT'
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(test_data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the NT test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)

# **DF**

In [None]:
# Named `test_data_dir` rather than `dir` so the builtin `dir()` is not shadowed.
test_data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/DF'
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(test_data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the DF test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)

# **DFD**

In [None]:
# Named `test_data_dir` rather than `dir` so the builtin `dir()` is not shadowed.
test_data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/DFD'
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(test_data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the DFD test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)

# **F2F**

In [None]:
# Named `test_data_dir` rather than `dir` so the builtin `dir()` is not shadowed.
test_data_dir = '/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/CLIP_based_deepfake_detection/dataset/F2F'
# Labels come from ImageFolder's sorted class-folder names. Overwriting
# `class_to_idx` after construction does not relabel any sample, so that
# assignment was removed. See `test_dataset.class_to_idx` for the mapping in use.
test_dataset = datasets.ImageFolder(root=os.path.join(test_data_dir, 'test'), transform=transform)
# os.cpu_count() may return None; fall back to loading in the main process.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count() or 0)

In [None]:
# Evaluate the trained model on the F2F test split held by `test_loader`;
# test_model prints the mean loss and the accuracy.
test_model(model, test_loader, device)