In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from torchvision import transforms, models
from torchvision.transforms.v2 import JPEG, RandomApply,GaussianBlur
from torchvision.transforms import InterpolationMode

from PIL import Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import timm

import os
import random

from sklearn.metrics import classification_report, precision_recall_curve, average_precision_score
from sklearn.metrics import precision_score, recall_score

### Set up seed for reproducibility

In [None]:
DEFAULT_RANDOM_SEED = 2003


def seedBasic(seed=DEFAULT_RANDOM_SEED):
    """
    Seed the non-torch sources of randomness.

    Exports PYTHONHASHSEED and seeds both Python's `random` module and
    NumPy's global random state with the same value.
    """
    # NOTE(review): presumably the env var only matters for processes started later — confirm
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def seedTorch(seed=DEFAULT_RANDOM_SEED):
    """
    Seed torch (CPU and CUDA) and switch cuDNN to deterministic mode.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seedBasic(seed=DEFAULT_RANDOM_SEED)
seedTorch(seed=DEFAULT_RANDOM_SEED)

### Set up configurations for training and/or evaluation

In [None]:
SYNCLR_PATH = "" # if this path points to an existing file, SynCLR is used as the backbone; otherwise DINOv2

PRETRAINED_CPT = "foundation_backbone_dinov2_sd14.pth" # classifier checkpoint; if this file is missing or fails to load, the initial weights are kept
TEST = True # Should the model be tested? If PRETRAINED_CPT is invalid and TRAIN is False, the initial weights are evaluated
TRAIN = True # Should the model be trained? Disable if you are interested only in evaluation

# Maximum number of test images taken from each class (ai / nature) of every test set.
# If a class folder holds fewer than TEST_SIZE images, all of them are used.
TEST_SIZE = 6_000

# Enable or disable training augmentations
TRAIN_AUG = True
# JPEG quality range passed to the training JPEG augmentation
JPEG_INTERVAL_TRAIN = (90, 100)
# Gaussian blur sigma range passed to the training blur augmentation
BLUR_INTERVAL_TRAIN = (0.1, 1)
# Probability of applying each training augmentation (drawn independently for JPEG and blur)
AUG_PROB = 0.1

# Apply JPEG compression to every test image or not
TEST_AUG_JPEG = False
# Apply Gaussian blur to every test image or not
TEST_AUG_BLUR = False
# JPEG compression quality for the test set; applied only if TEST_AUG_JPEG = True
JPEG_LEVEL_TEST = 90
# Gaussian blur sigma level for the test set; applied only if TEST_AUG_BLUR = True
BLUR_LEVEL_TEST = 1
# Whether precision-recall curves should be plotted and saved in the curves/ folder
VISUALISE_CURVES = True

# training setup
BATCH_SIZE = 64
EPOCH_NUM = 1
LEARNING_RATE = 0.0001

# Dataset directory, with all generator sets in GenImage format. Example:
# - dataset
#
# -- imagenet-ai-0419-sdv4
#
# --- val
# ---- ai
# ---- nature
# --- train
# ---- ai
# ---- nature
#
# -- imagen3
# --- val
# ---- ai
# ---- nature
DATASET_DIR = "dataset"

# folder (inside DATASET_DIR) whose train/ split is used for training
TRAIN_DATASET = "imagenet-ai-0419-sdv4"

# Dictionary whose keys are folder names inside DATASET_DIR and whose values are True or False
# (whether the val/ split of that folder should be evaluated).
TEST_DATASETS = {
    "imagen3": True,
    "SDXL1": True,
    "FLUX1-dev": True,
    "PixArt-XL-2-1024-MS": True,
    "imagenet-midjourney": True,
    "imagenet-ai-0419-sdv4": True,
    "imagenet-ai-0424-sdv5": True,
    "adm-genimage-test": True,
    "imagenet-glide": True,
    "wukong-dataset-test": True,
    "vqdm-test-dataset": True,
    "bigger-dataset-test": True
}

# Run on the GPU when CUDA is available, otherwise on the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Datasets

In [None]:
class TrainDataset(Dataset):
    """
    Dataset object for training set

    Each item is an ``(image, label)`` pair: ``image`` is the transformed
    tensor and ``label`` is 1 for files under ``train/ai`` and 0 for files
    under ``train/nature`` of the selected generator folder.
    """
    def __init__(self, model_name):
        """
        Args:
            model_name: name of the generator folder inside DATASET_DIR
                whose ``train`` split is read.
        """
        self.paths = []
        self.model_name = model_name
        self.get_paths()

        self.image_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.jpeg_augmentation = JPEG(quality=JPEG_INTERVAL_TRAIN)
        self.blur = GaussianBlur(kernel_size=5, sigma=BLUR_INTERVAL_TRAIN)

    def __len__(self):
        """Number of (path, label) samples collected by get_paths()."""
        return len(self.paths)

    def __getitem__(self, index):
        """
        Load, optionally augment and transform the sample at ``index``.

        Returns:
            Tuple ``(image, image_label)``.
        """
        path, image_label = self.paths[index]

        # Context manager releases the underlying file handle as soon as the
        # pixel data has been converted.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        if TRAIN_AUG:
            # Each augmentation is drawn independently with probability AUG_PROB.
            if random.random() < AUG_PROB:
                image = self.jpeg_augmentation(image)
            if random.random() < AUG_PROB:
                image = self.blur(image)

        if self.image_transform:
            image = self.image_transform(image)

        return image, image_label

    @staticmethod
    def _list_labelled(directory, label):
        """Return ``(path, label)`` for every entry of ``directory``, sorted by name."""
        # os.listdir gives no ordering guarantee; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        return [(os.path.join(directory, name), label) for name in sorted(os.listdir(directory))]

    def get_paths(self):
        """
        Extract paths from the generator's training set

        Fills ``self.paths`` with the AI-generated samples (label 1) followed
        by the real samples (label 0).
        """
        ai_paths_train = self._list_labelled(
            os.path.join(DATASET_DIR, self.model_name, "train/ai"), 1)
        real_paths_train = self._list_labelled(
            os.path.join(DATASET_DIR, self.model_name, "train/nature"), 0)

        self.paths = ai_paths_train + real_paths_train

In [None]:
class TestDataset(Dataset):
    """
    Dataset object for test set

    Each item is an ``(image, label)`` pair: ``label`` is 1 for files under
    ``val/ai`` and 0 for files under ``val/nature`` of the selected generator
    folder. At most TEST_SIZE files are kept per class.
    """
    def __init__(self, model_name):
        """
        Args:
            model_name: name of the generator folder inside DATASET_DIR
                whose ``val`` split is read.
        """
        self.paths = []
        self.model_name = model_name
        self.get_paths()

        # NOTE(review): RandomCrop makes the evaluated crop depend on the RNG
        # state; a CenterCrop would be deterministic. Left unchanged here to
        # avoid altering evaluation behaviour.
        self.image_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.jpeg_augmentation = JPEG(quality=JPEG_LEVEL_TEST)
        self.blur = GaussianBlur(kernel_size=5, sigma=BLUR_LEVEL_TEST)

    def __len__(self):
        """Number of (path, label) samples collected by get_paths()."""
        return len(self.paths)

    def __getitem__(self, index):
        """
        Load, optionally degrade and transform the sample at ``index``.

        Returns:
            Tuple ``(image, image_label)``.
        """
        path, image_label = self.paths[index]

        # Context manager releases the underlying file handle as soon as the
        # pixel data has been converted.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        # Unlike training, test-time degradations are applied to every image
        # whenever their flag is enabled.
        if TEST_AUG_JPEG:
            image = self.jpeg_augmentation(image)
        if TEST_AUG_BLUR:
            image = self.blur(image)

        if self.image_transform:
            image = self.image_transform(image)

        return image, image_label

    @staticmethod
    def _list_labelled(directory, label):
        """Return ``(path, label)`` for every entry of ``directory``, sorted by name."""
        # os.listdir gives no ordering guarantee; without sorting, the
        # TEST_SIZE truncation below would select an arbitrary subset.
        return [(os.path.join(directory, name), label) for name in sorted(os.listdir(directory))]

    def get_paths(self):
        """
        Extract paths from the generator's test set

        Fills ``self.paths`` with the first TEST_SIZE AI-generated samples
        (label 1) followed by the first TEST_SIZE real samples (label 0),
        both taken in file-name order.
        """
        ai_paths_test = self._list_labelled(
            os.path.join(DATASET_DIR, self.model_name, "val/ai"), 1)
        real_paths_test = self._list_labelled(
            os.path.join(DATASET_DIR, self.model_name, "val/nature"), 0)

        self.paths = ai_paths_test[:TEST_SIZE] + real_paths_test[:TEST_SIZE]

## Create datasets and loaders

In [None]:
train_dataset = TrainDataset(TRAIN_DATASET)
# Build one (dataset, name) pair for every folder flagged True in TEST_DATASETS.
test_datasets = [
    (TestDataset(set_name), set_name)
    for set_name, enabled in TEST_DATASETS.items()
    if enabled
]

In [None]:
# Settings shared by the training loader and every test loader.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=4)

# Only the training data is shuffled; test sets are read in their stored order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loaders = [
    (DataLoader(dataset, shuffle=False, **loader_options), set_name)
    for dataset, set_name in test_datasets
]

## Set up models

In [None]:
def load_dinoV2():
    """
    Fetch the pre-trained DINOv2 (vit-l14) backbone through torch.hub.

    Every parameter is frozen and the model is moved to DEVICE before it is
    returned.
    """
    dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14", pretrained=True)

    # The backbone is used as a fixed feature extractor: no gradients needed.
    for weight in dinov2.parameters():
        weight.requires_grad = False

    dinov2 = dinov2.to(DEVICE)
    print("DinoV2 loaded")
    return dinov2

In [None]:
def load_synclr():
    """
    Load pre-trained SynCLR model's weights based on vit-l14 architecture from local file

    The checkpoint at SYNCLR_PATH is read, the ``module.visual.`` part is
    stripped from its keys and the result is loaded non-strictly into a
    timm ``vit_large_patch14_224`` created without a classification head.
    All parameters are frozen and the model is moved to DEVICE.

    Returns:
        The frozen backbone module.
    """
    backbone_model_syn = timm.create_model('vit_large_patch14_224', pretrained=False, num_classes=0)

    checkpoint = torch.load(SYNCLR_PATH, map_location=DEVICE)
    # Weights may be nested under a 'model' entry or stored at the top level.
    state_dict = checkpoint.get('model', checkpoint)
    state_dict = {k.replace("module.visual.", ""): v for k, v in state_dict.items()}

    # strict=False tolerates key mismatches, so report them: otherwise a
    # checkpoint with entirely different keys would load nothing silently.
    load_result = backbone_model_syn.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Synclr checkpoint key mismatch: {len(load_result.missing_keys)} missing, "
            f"{len(load_result.unexpected_keys)} unexpected"
        )

    for param in backbone_model_syn.parameters():
        param.requires_grad = False

    backbone_model_syn.to(DEVICE)

    print("Synclr loaded")

    return backbone_model_syn

### Load foundation model based on config setup

In [None]:
import os

# SynCLR is only attempted when SYNCLR_PATH points to an existing file.
use_synclr = os.path.isfile(SYNCLR_PATH)
if use_synclr:
    try:
        print("Use SynClr")
        backbone_foundation_model = load_synclr()
    except Exception as e:
        # Best-effort: fall back to DINOv2, but say why SynCLR could not be loaded.
        print(f"While loading SynClr something went wrong: {type(e).__name__}: {e}")
        use_synclr = False

# Single fallback path for both "no SynCLR file" and "SynCLR failed to load".
if not use_synclr:
    print("Use DinoV2")
    backbone_foundation_model = load_dinoV2()


### Define standard implementation for MLP head with GELU activation

In [None]:
class Mlp(nn.Module):
    """
    Two-layer perceptron head: Linear -> BatchNorm1d -> activation -> Linear -> Sigmoid.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        output_dim = out_features if out_features else in_features

        # Attribute names and creation order are kept so that state-dict keys
        # and seeded initialisation stay the same.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Sequential(nn.Linear(hidden_dim, output_dim), nn.Sigmoid())

    def forward(self, x):
        """Return the sigmoid-activated output for the feature batch ``x``."""
        hidden = self.act(self.bn1(self.fc1(x)))
        return self.fc2(hidden)

### Define one-branch architecture

In [None]:
class FoundationModelClassifier(nn.Module):
    """
    One-branch classifier: a foundation backbone followed by an Mlp head.

    The head is built as ``Mlp(1024, 512, 1)``, i.e. it consumes
    1024-dimensional backbone features and emits a single value per sample.
    """

    def __init__(self, backbone_foundation_model):
        super().__init__()
        self.backbone_foundation_model = backbone_foundation_model
        self.mlp = Mlp(in_features=1024, hidden_features=512, out_features=1)

    def forward(self, x):
        """Run the backbone on ``x`` and pass its features through the head."""
        return self.mlp(self.backbone_foundation_model(x))

In [None]:
def save(model, name):
    """
    Write the model's weights to ``foundation_backbone_{name}.pth``,
    leaving out every entry that belongs to the foundation backbone.
    """
    backbone_prefix = "backbone_foundation_model"

    head_weights = {}
    for key, tensor in model.state_dict().items():
        if key.startswith(backbone_prefix):
            continue
        head_weights[key] = tensor

    checkpoint = {'backbone_foundation_model_classifier': head_weights}
    torch.save(checkpoint, f"foundation_backbone_{name}.pth")

In [None]:
def validate(model, test_loaders, visualise_curves=True):
    """
    Validate model on test sets, plot precision-recall curves if needed

    For every ``(loader, set_name)`` pair the model outputs are thresholded
    at 0.5, the accuracy and a classification report are printed, and, when
    ``visualise_curves`` is true, a precision-recall curve is written to
    ``curves/pr_curve_{set_name}.png``.

    Args:
        model: module whose output is compared against 0.5 to obtain the
            predicted class.
        test_loaders: iterable of ``(loader, set_name)`` pairs; each loader
            yields ``(images, labels)`` batches.
        visualise_curves: whether to save the precision-recall plots.
    """
    model.eval()
    for test, set_name in test_loaders:
        correct = 0
        total = 0

        true_labels = []
        predicted_labels = []
        probabilities = []

        with torch.no_grad():
            for images, labels in test:
                images, labels = images.to(DEVICE), labels.float().to(DEVICE)
                # Column shape so the labels line up with the model output.
                labels = labels.view(-1, 1)
                probs = model(images)
                predicted = (probs >= 0.5).float()

                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                true_labels.append(labels.cpu())
                predicted_labels.append(predicted.cpu())
                probabilities.append(probs.cpu())

        if total == 0:
            # An empty set would otherwise fail on the accuracy division and
            # on concatenating empty lists below.
            print(f"Validation on {set_name}: no samples found, skipping")
            continue

        accuracy = 100 * correct / total
        print(f"Validation on {set_name}: Test Accuracy: {accuracy:.2f}%")

        true_labels = torch.cat(true_labels).numpy()
        predicted_labels = torch.cat(predicted_labels).numpy()
        probabilities = torch.cat(probabilities).numpy()

        print(classification_report(true_labels, predicted_labels))

        # Curve metrics are only needed for the plot, so compute them on demand.
        if visualise_curves:
            _save_pr_curve(true_labels, probabilities, set_name)


def _save_pr_curve(true_labels, probabilities, set_name):
    """Plot the precision-recall curve of one test set and save it under curves/."""
    precision, recall, thresholds = precision_recall_curve(true_labels, probabilities)
    avg_precision = average_precision_score(true_labels, probabilities)

    # Operating point of the 0.5 decision threshold used for the accuracy.
    threshold_point = 0.5
    preds_at_05 = (probabilities >= threshold_point).astype(float)
    p_05 = precision_score(true_labels, preds_at_05)
    r_05 = recall_score(true_labels, preds_at_05)

    plt.figure()
    plt.plot(recall, precision)

    plt.axvline(x=r_05, color='red', linestyle='--')

    plt.scatter(r_05, p_05, color='red', zorder=5)

    plt.text(0.02, 0.02, f'AP = {avg_precision:.2f}', transform=plt.gca().transAxes,
             fontsize=18, verticalalignment='bottom', horizontalalignment='left',
             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))

    plt.xlabel('Recall', fontsize=16)
    plt.ylabel('Precision', fontsize=16)
    plt.title(f'Precision-Recall curve ({set_name})', fontsize=18)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # No legend: none of the artists carries a label, so a legend call would
    # draw nothing.
    plt.grid(True)

    os.makedirs('curves', exist_ok=True)
    filename = f'curves/pr_curve_{set_name}.png'
    plt.savefig(filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()

In [None]:
def train(model, train_loader, optimizer, criterion, scheduler):
    """
    Training code

    Runs EPOCH_NUM passes over ``train_loader``. The optimizer and the
    scheduler are both stepped once per batch, and after every epoch the
    model is switched to eval mode and its weights are stored through
    ``save()``.

    Args:
        model: module whose output is fed to ``criterion`` and thresholded
            at 0.5 for the running accuracy.
        train_loader: sized iterable yielding ``(images, labels)`` batches.
        optimizer: optimizer updating the model parameters.
        criterion: loss called as ``criterion(outputs, labels)`` with labels
            reshaped to a float column.
        scheduler: learning-rate scheduler, stepped after every batch.
    """
    for epoch in range(EPOCH_NUM):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCH_NUM}", leave=True)

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            labels = labels.float().view(-1, 1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predictions = (outputs >= 0.5).float()

            correct_train += (predictions == labels).sum().item()
            total_train += labels.size(0)

            # Average over the batches seen so far. Deriving the batch count
            # from the current batch size is wrong whenever a batch (e.g. the
            # last one) is smaller than the previous ones.
            batches_seen = batch_idx + 1
            progress_bar.set_postfix(loss=f"{running_loss / batches_seen:.4f}",
                                     acc=f"{100 * correct_train / total_train:.2f}%")

            scheduler.step()

        model.eval()
        save(model, f"synclr_epoch_{epoch}" if use_synclr else f"dinov2_epoch_{epoch}")

        # max(..., 1) keeps the summary from dividing by zero on an empty loader.
        print(f"Epoch [{epoch+1}/{EPOCH_NUM}], Loss: {running_loss/max(len(train_loader), 1):.4f}")

    print("Training Complete!")

In [None]:
def load_cpt(model):
    """
    Load pre-trained weights for proposed architecture

    Reads PRETRAINED_CPT and loads its ``backbone_foundation_model_classifier``
    entry into ``model`` non-strictly (the checkpoints written by ``save()``
    do not contain the backbone weights). Loading is best-effort: on any
    problem a message is printed and the model is returned as it is.

    Args:
        model: module to receive the weights.

    Returns:
        The same ``model`` object.
    """
    if not os.path.isfile(PRETRAINED_CPT):
        print("Pretrained checkpoint not found, keeping initial weights")
        return model

    try:
        print("Use pretrained checkpoint")
        checkpoint = torch.load(PRETRAINED_CPT, map_location=DEVICE)
        load_result = model.load_state_dict(checkpoint["backbone_foundation_model_classifier"], strict=False)
    except Exception as e:
        # Keep going with the current weights, but say what failed.
        print(f"While loading pretrained checkpoint something went wrong: {type(e).__name__}: {e}")
        return model

    # Backbone keys are expected to be absent; anything else that is missing
    # or unexpected means the checkpoint does not match the classifier head.
    missing_head_keys = [
        key for key in load_result.missing_keys
        if not key.startswith("backbone_foundation_model")
    ]
    if missing_head_keys or load_result.unexpected_keys:
        print(
            f"Pretrained checkpoint key mismatch: missing {missing_head_keys}, "
            f"unexpected {list(load_result.unexpected_keys)}"
        )
    return model

### Run train and/or evaluation process

In [None]:
model = FoundationModelClassifier(backbone_foundation_model)
model = load_cpt(model)
model.to(DEVICE)

if TRAIN:
    criterion = nn.BCELoss()

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # train() steps the scheduler once per batch, so this step size halves the
    # learning rate every len(train_dataset) // BATCH_SIZE batches. Clamp to 1:
    # a step size of 0 (dataset smaller than one batch) makes StepLR fail with
    # a modulo-by-zero on the first scheduler step.
    scheduler_step_size = max(1, len(train_dataset) // BATCH_SIZE)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=0.5)
    train(model, train_loader, optimizer, criterion, scheduler)

if TEST:
    model.eval()
    validate(model, test_loaders, visualise_curves=VISUALISE_CURVES)
