### Universal Fake Detection

In [1]:
%load_ext autoreload
%autoreload 2

### Testing OpenCLIP

In [2]:
import torch
from PIL import Image
import open_clip
import sys
import os

sys.path.append(os.path.abspath(".."))
from models.clip_models import CLIPModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
name = "ViT-L/14"
pretrained = "dfn2b"
model, _, preprocess = open_clip.create_model_and_transforms(name, pretrained=pretrained)
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-L-14')

  return self.fget.__get__(instance, owner)()


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Dummy dataset of random floats
class DummyDataset(Dataset):
    def __init__(self, size=2):
        # random images
        self.data = torch.rand(size, 3, 224, 224)
        # binary labels
        self.labels = torch.randint(0, 2, (size, 1)).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]



# Provided validate function
def validate(model, loader, find_thres=False, test=False):
    with torch.no_grad():
        y_true, y_pred = [], []
        print("Length of dataset: %d" % (len(loader)))
        # for img, label in loader.dataset:
        for batch in loader.dataset:
            if test:
                img, label = batch
            else:
                img, label = batch["image"], batch["label"]
            in_tens = img.cpu()
            y_pred.extend(model(in_tens).sigmoid().flatten().tolist())
            y_true.extend(label.flatten().tolist())
            break
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    print("True labels:", y_true)
    print("Predicted scores:", y_pred)

In [4]:
CHANNELS = {
    "RN50" : 1024,
    "ViT-L/14" : 768,
    "ViT-H/14" : 1024,
    "ViT-g/14" : 1024,
}

class CLIPModel(nn.Module):
    def __init__(self, name, pretrained=None, num_classes=1):
        super(CLIPModel, self).__init__()
        self.name = name
        # self.preprecess will not be used during training, which is handled in Dataset class
        if pretrained:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(name, 
                                                                            pretrained=pretrained,
                                                                            device="cpu")
        else:
            self.model, self.preprocess = clip.load(name, device="cpu")

        # add a linear layer to the model (hard-coded for ViT)
        self.project = nn.Linear(1024, 768)
        self.fc = nn.Linear(768, num_classes)
 

    def forward(self, x, return_feature=False):
        features = self.model.encode_image(x)
        if CHANNELS.get(self.name) == 1024: 
            features = self.project(features)
        if return_feature:
            return features
        return self.fc(features)


In [6]:
from src.dataset import DF40
import yaml

# load the config file
with open("../configs/df40/test_config.yaml", 'r') as f:
    config = yaml.safe_load(f)

config['test_dataset'] = "MidJourney"
dataset = DF40(config=config, mode='test')
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
    collate_fn=dataset.collate_fn,
)

In [8]:
dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=1)

In [7]:
len(loader)

15

In [78]:
for batch in loader:
    print(batch["image"].shape, batch["label"].shape)
    break

Error loading image at index (9, '/scratch-shared/scur0555/datasets/df40/test/MidJourney/fake/13592447708_Clean_East_Asian_male_face_only_face_shown_close-up_3ad609fd-9d87-4242-9d5b-d2776a3c2c7e.png'): cannot identify image file '/scratch-shared/scur0555/datasets/df40/test/MidJourney/fake/13592447708_Clean_East_Asian_male_face_only_face_shown_close-up_3ad609fd-9d87-4242-9d5b-d2776a3c2c7e.png'
torch.Size([2, 3, 224, 224]) torch.Size([2])


In [75]:
for batch in loader:
    print(batch[1].shape)
    break

torch.Size([2, 1])


In [11]:
name = "ViT-L/14"

# Load model
model = CLIPModel(name)
state_dict = torch.load("../pretrained_weights/fc_weights.pth", map_location='cpu')
model.fc.load_state_dict(state_dict)
model = model.cpu()

# Run test
validate(model, loader, test=True)

Length of dataset: 2


RuntimeError: The size of tensor a (1024) must match the size of tensor b (16) at non-singleton dimension 2

In [21]:
name = "ViT-H/14"
pretrained = "laion2b_s32b_b79k"

# Create data loader and model
dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=1)
# Load model
model = CLIPModel(name, pretrained)
state_dict = torch.load("../pretrained_weights/fc_weights.pth", map_location='cpu')
model.fc.load_state_dict(state_dict)
model = model.cpu()

# Run test
validate(model, loader)

Length of dataset: 2
True labels: [0. 1.]
Predicted scores: [0.36211863 0.43808025]


In [None]:
# OpenCLIP test (optional)
image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad(), torch.autocast("cuda"):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

In [3]:
# import open_clip
open_clip.list_pretrained()

[('RN50', 'openai'),
 ('RN50', 'yfcc15m'),
 ('RN50', 'cc12m'),
 ('RN101', 'openai'),
 ('RN101', 'yfcc15m'),
 ('RN50x4', 'openai'),
 ('RN50x16', 'openai'),
 ('RN50x64', 'openai'),
 ('ViT-B-32', 'openai'),
 ('ViT-B-32', 'laion400m_e31'),
 ('ViT-B-32', 'laion400m_e32'),
 ('ViT-B-32', 'laion2b_e16'),
 ('ViT-B-32', 'laion2b_s34b_b79k'),
 ('ViT-B-32', 'datacomp_xl_s13b_b90k'),
 ('ViT-B-32', 'datacomp_m_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_image_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_text_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_s128m_b4k'),
 ('ViT-B-32', 'datacomp_s_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_image_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_text_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_basic_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_s13m_b4k'),
 ('ViT-

In [8]:
dic = {"a": 1}
dic.get("a", 768)

1