# Advanced NS-NET: Improving Generalizable AI-Generated Image Detection via Learned Semantic Null-Space Projections
This is an improvement to NS-NET (Yan et al. 2025) (https://www.arxiv.org/abs/2508.01248). Training and Testing is done on DALLE Recognition Dataset available on Kaggle (https://www.kaggle.com/datasets/superpotato9/dalle-recognition-dataset).

Note: We have made use of LLMs for debugging of code only.

## Importing the Dataset

In [None]:
import kagglehub

path = kagglehub.dataset_download("superpotato9/dalle-recognition-dataset")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'dalle-recognition-dataset' dataset.
Path to dataset files: /kaggle/input/dalle-recognition-dataset


## Importing BLIP for captioning of images

In [None]:
!pip install -q open-clip-torch==2.24.0 loralib

import sys, types, importlib, open_clip

# Patch both cases so any old code works
if not hasattr(open_clip, "get_tokenizer"):
    def _get_tokenizer(model_name):
        from open_clip import tokenizer
        return tokenizer._tokenizer  # fallback; v2.x internal
    open_clip.get_tokenizer = _get_tokenizer

# Create a fake submodule for "open_clip.tokenizer"
if "open_clip.tokenizer" not in sys.modules:
    tok_mod = types.SimpleNamespace(get_tokenizer=open_clip.get_tokenizer)
    sys.modules["open_clip.tokenizer"] = tok_mod

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/1.5 MB[0m [31m13.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h



## NS-NET Architecture
Individual components have been labelled accordingly in the code. Key implementations include:
- Semantic Mapper Module for estimating the Null Space of Semantic Information from CLIP Based Text Features
- Projection of CLIP Based Visual Features to the Null Space
- Improvement of the Loss Function based on above modules

In [None]:
%%writefile nsnet_cpu.py

import os, math, json, random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration
import open_clip, loralib
from sklearn.metrics import accuracy_score, average_precision_score
import os, json
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# ============== helpers ==============
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============== PATCH SELECTION ==============
def spectral_entropy(patch):
    arr = np.mean(np.array(patch), axis=2)
    mag = np.abs(np.fft.fft2(arr))
    mag = mag.flatten()
    mag = mag / (mag.sum() + 1e-8)
    return float(-(mag * np.log(mag + 1e-8)).sum())

def patch_select(img, patch_size=32, out_size=224):
    img = img.convert("RGB")
    w, h = img.size
    new_w = math.ceil(w / patch_size) * patch_size
    new_h = math.ceil(h / patch_size) * patch_size
    padded = Image.new("RGB", (new_w, new_h))
    padded.paste(img, (0,0))
    patches = []
    for r in range(0, new_h, patch_size):
        for c in range(0, new_w, patch_size):
            patch = padded.crop((c, r, c+patch_size, r+patch_size))
            patches.append(patch)
    ents = [spectral_entropy(p) for p in patches]
    idx = np.argsort(ents)
    topk = idx[-len(idx)//4:]; lowk = idx[:len(idx)//4]
    sel = [patches[i] for i in np.concatenate([topk, lowk])]
    random.shuffle(sel)
    grid = int(out_size / patch_size)
    new = Image.new("RGB", (out_size, out_size))
    for i in range(grid*grid):
        p = sel[i % len(sel)]
        r, c = divmod(i, grid)
        new.paste(p.resize((patch_size, patch_size)), (c*patch_size, r*patch_size))
    return new

# ============== DATASET ==============
class DALLEDataset(Dataset):
    def __init__(self, root, split, captions, transform=None,
                 train_limit_real=1000, train_limit_fake=1000,
                 test_limit_real=100, test_limit_fake=100):

        self.samples = []
        self.transform = transform
        self.captions = captions

        real_folder = os.path.join(root, "real")
        fake_folder = os.path.join(root, "fakeV2", "fake-v2")

        real_imgs, fake_imgs = [], []

        if os.path.exists(real_folder):
            real_imgs = [
                os.path.join(real_folder, f)
                for f in os.listdir(real_folder)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            ]

        if os.path.exists(fake_folder):
            fake_imgs = [
                os.path.join(fake_folder, f)
                for f in os.listdir(fake_folder)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            ]
        
        random.shuffle(real_imgs)
        random.shuffle(fake_imgs)

        if split == "train":
            real_imgs = real_imgs[:train_limit_real]
            fake_imgs = fake_imgs[:train_limit_fake]
        elif split == "test":
            real_imgs = real_imgs[:test_limit_real]
            fake_imgs = fake_imgs[:test_limit_fake]

        for img in real_imgs:
            self.samples.append((img, 0))  # REAL = 0
        for img in fake_imgs:
            self.samples.append((img, 1))  # FAKE = 1

        print(f"[{split.upper()}] Loaded {len(real_imgs)} REAL and {len(fake_imgs)} FAKE images from '{root}'")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, y = self.samples[i]
        img = Image.open(path).convert("RGB")
        img = patch_select(img)
        if self.transform:
            img = self.transform(img)
        rel_folder = os.path.basename(os.path.dirname(path))
        rel_name = os.path.basename(path)
        unique_key = f"{rel_folder}_{rel_name}"
        cap = self.captions.get(unique_key, "")
        return img, y, cap

# ============== Semantic Mapper ============
class SemanticMapper(nn.Module):
    def __init__(self, dim=768, hidden=1024, out_dim=768, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim)
        )
        # Normalization for stability
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, v):
        # residual mapping: f(v) + v
        return self.norm(v + self.net(v))

# ============== NULL SPACE ==============
def build_nullspace(text_feats, null_dim=None, tol=1e-6, use_mapper = False, mapper = None, device = "cpu"):
    #use_mapper = true when we want to use NN Network to find out Semantic Features. Else, it simply uses original text features.
    #If we pretrained mapper, we can use that itself by passing it to mapper argument.
    feats_t = torch.from_numpy(text_feats).float().to(device)

    if use_mapper == True:
      if mapper is None:
        mapper = SemanticMapper(dim=feats_t.shape[1]).to(device)
      mapper.eval()
      with torch.no_grad():
        feats_t = mapper(feats_t)
      text_feats = feats_t.cpu().numpy()

    #Now we perform SVD
    U,S,VT = np.linalg.svd(text_feats, full_matrices=False)
    rank = np.sum(S>tol)
    D = VT.shape[1]
    null_dim = null_dim or max(1, D-rank)

    #Basis of Null Space and its projection
    N = VT.T[:, -null_dim:]
    P = N @ N.T

    return torch.from_numpy(P).float(), N


# ============== MODEL ==============
class NSNetHead(nn.Module):
    def __init__(self, dim, proj_dim=512):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(dim, proj_dim), nn.ReLU(), nn.Linear(proj_dim, proj_dim))
        self.cls = nn.Linear(proj_dim,1)
    def forward(self,x):
        f = self.proj(x)
        logit = self.cls(f).squeeze(-1)
        return F.normalize(f,dim=1), logit

# ============== LOSSES & METRICS ==============
def nt_xent(f,y,T=0.07):
    sim = (f @ f.T)/T
    mask = (y.unsqueeze(1)==y.unsqueeze(0)).float()
    exp_sim = torch.exp(sim)*(1-torch.eye(len(f),device=f.device))
    num = (exp_sim*mask).sum(1)
    denom = exp_sim.sum(1)
    loss = -torch.log((num+1e-8)/(denom+1e-8))
    return loss.mean()

def metrics(y,logit):
    y = np.array(y)
    p = torch.sigmoid(torch.tensor(logit)).numpy()
    pred = (p>=0.5).astype(int)
    acc = accuracy_score(y,pred)

    r_mask = (y==0)
    f_mask = (y==1)
    r_acc = accuracy_score(y[r_mask], pred[r_mask]) if r_mask.sum()>0 else 0
    f_acc = accuracy_score(y[f_mask], pred[f_mask]) if f_mask.sum()>0 else 0


    ap = average_precision_score(y,p)
    return {"acc":acc,"r_acc":r_acc,"f_acc":f_acc,"ap":ap}

# ============== PIPELINE ==============
def gen_captions1(data_root, out_json="captions.json",
                  limit_real=1100, limit_fake=1100):
    device = get_device()
    proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

    imgs = []

    real_folder = os.path.join(data_root, "real")
    fake_folder = os.path.join(data_root, "fakeV2", "fake-v2")

    real_imgs = []
    fake_imgs = []
    if os.path.exists(real_folder):
        real_imgs = [
            os.path.join(real_folder, f)
            for f in os.listdir(real_folder)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        ][:limit_real]

    if os.path.exists(fake_folder):
        fake_imgs = [
            os.path.join(fake_folder, f)
            for f in os.listdir(fake_folder)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        ][:limit_fake]

    imgs.extend(real_imgs)
    imgs.extend(fake_imgs)
    print(f"Total selected images for captioning: {len(imgs)} "
          f"({len(real_imgs)} real + {len(fake_imgs)} fake)")

    caps = {}
    skipped = []
    for p in tqdm(imgs, desc="BLIP captioning"):
        rel_folder = os.path.basename(os.path.dirname(p))
        base_name = os.path.basename(p)
        unique_key = f"{rel_folder}_{base_name}"

        try:
            img = Image.open(p).convert("RGB")
            img.thumbnail((512, 512))
            inputs = proc(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model.generate(**inputs, max_new_tokens=40)
            cap = proc.decode(out[0], skip_special_tokens=True)
            caps[unique_key] = cap

            # Periodic autosave
            if len(caps) % 100 == 0:
                json.dump(caps, open(out_json, "w"))

        except UnidentifiedImageError:
            print(f"⚠️ Skipped unreadable image: {p}")
            skipped.append(p)
            continue
        except Exception as e:
            print(f"⚠️ Error on {p}: {e}")
            skipped.append(p)
            continue

    json.dump(caps, open(out_json, "w"))
    print(f"✅ Saved {len(caps)} captions to {out_json}")
    print(f"⚠️ Skipped {len(skipped)} problematic images.")
    if skipped:
        with open("skipped_images.txt", "w") as f:
            f.write("\n".join(skipped))
        print("Skipped image list saved to skipped_images.txt")

    return caps

def build_null(captions):
    device = get_device()

    model, _, _ = open_clip.create_model_and_transforms("ViT-L-14", pretrained="openai")
    tokenizer = open_clip.get_tokenizer("ViT-L-14")

    model.to(device).eval()

    texts = [str(t) for t in captions.values()]
    print(f"Encoding {len(texts)} captions for NULL-space...")

    feats = []
    for i in tqdm(range(0, len(texts), 32), desc="Text enc"):
        batch = texts[i:i+32]
        tokens = tokenizer(batch).to(device)

        with torch.no_grad():
            f = model.encode_text(tokens)
        feats.append(f.cpu())

    feats = torch.cat(feats).numpy()

    # Build NULL-space (returns torch tensor P)
    #Incase we have pretrained Mapper
    #mapper = SemanticMapper(dim=text_feats.shape[1])
    #mapper.load_state_dict(torch.load("semantic_mapper.pth"))
    #P, _ = build_nullspace(feats, use_mapper = True, mapper=mapper, device = "cuda")

    P, _ = build_nullspace(feats, use_mapper = True) # Pass feats as the first argument
    np.savez("nullspace.npz", P=P.cpu().numpy())  # Save on CPU
    print(f"✅ NULL-space saved. Shape: {tuple(P.shape)}")

    return P.to(device)

def train_nsnet(data_root="/kaggle/input/dalle-recognition-dataset",
                train_limit_real=1000, train_limit_fake=1000,
                test_limit_real=100, test_limit_fake=100):
    device = get_device()
    print("Using device:", device)
    captions = json.load(open("captions.json"))

    model_clip,_,_ = open_clip.create_model_and_transforms("ViT-L-14", pretrained="openai")
    model_clip.to(device).eval()
    mapper = SemanticMapper(dim=768, hidden=1024, out_dim=768).to(device)
    head = NSNetHead(dim=768).to(device)

    opt = torch.optim.Adam(list(head.parameters())+list(mapper.parameters()), lr=2e-4)

    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466,0.4578275,0.40821073),
                             std=(0.26862954,0.26130258,0.27577711))
    ])

    train_ds = DALLEDataset(
        data_root, "train", captions, tfm,
        train_limit_real=train_limit_real, train_limit_fake=train_limit_fake,
        test_limit_real=test_limit_real, test_limit_fake=test_limit_fake
    )
    test_ds = DALLEDataset(
        data_root, "test", captions, tfm,
        train_limit_real=train_limit_real, train_limit_fake=train_limit_fake,
        test_limit_real=test_limit_real, test_limit_fake=test_limit_fake
    )

    train_dl = DataLoader(train_ds,batch_size=8,shuffle=True)

    best_ap = 0
    for ep in range(1, 3):
        head.train()
        mapper.train()
        tot = 0
        for x, y, caps in tqdm(train_dl, desc=f"Epoch {ep}"):
            x, y = x.to(device), y.to(device)

            #Encode the input image
            with torch.no_grad():
              f_img = model_clip.encode_image(x)

            #Build/ train the Null Space from Scratch instead of directly using P via our Semantic Mapper Module
            tokens = open_clip.get_tokenizer("ViT-L-14")(list(caps)).to(device)
            with torch.no_grad():
              f_text = model_clip.encode_text(tokens)

            #Pass the text features through the Mapper
            f_text_mapped = mapper(f_text)

            #Build Per Sample Null Space Projection: P_i = I - vv.T/||v||^2
            B = x.shape[0]
            D = f_text_mapped.shape[1]
            eye = torch.eye(D, device=device)
            f_null = []

            for i in range(B):
                v = f_text_mapped[i].unsqueeze(0)  # (1, D)
                v_norm = (v @ v.T).clamp(min=1e-8)  # scalar norm^2
                P_i = eye - (v.T @ v) / v_norm     # (D, D)
                f_proj = f_img[i].unsqueeze(0) @ P_i  # (1, D)
                f_null.append(f_proj.squeeze(0))

            f_null = torch.stack(f_null)  # (B, D)
            f, logit = head(f_null)

            #Losses
            cls_loss = F.binary_cross_entropy_with_logits(logit, y.float())
            align_loss = F.mse_loss(f_text_mapped, f_img.detach())  #alignment with semantic image features
            contrast_loss = nt_xent(f, y)

            loss = contrast_loss + 0.2 * cls_loss + 0.3 * align_loss


            opt.zero_grad()
            loss.backward()
            opt.step()
            tot += loss.item()
        print(f"Epoch {ep} loss {tot/len(train_dl):.4f}")

        # Eval
        head.eval()
        mapper.eval()
        logits, ys = [], []

        for x, y, caps in DataLoader(test_ds, batch_size=8, drop_last=False):
          x,y = x.to(device), y.to(device)
          tokens = open_clip.get_tokenizer("ViT-L-14")(list(caps)).to(device)

          with torch.no_grad():
            f_img = model_clip.encode_image(x)
            f_text = model_clip.encode_text(tokens)

            f_text_mapped = mapper(f_text)

            # apply per-sample projection
            B = x.shape[0] # Get the actual batch size
            f_img, f_text_mapped = f_img[:B], f_text_mapped[:B]

            eye = torch.eye(f_text_mapped.shape[1], device=device)
            f_null = []

            for i in range(B):
                v = f_text_mapped[i].unsqueeze(0)  # (1, D)
                v_norm = (v @ v.T).clamp(min=1e-8)
                P_i = eye - (v.T @ v) / v_norm
                f_proj = f_img[i].unsqueeze(0) @ P_i
                f_null.append(f_proj.squeeze(0))

            f_null = torch.stack(f_null)
            _, log = head(f_null)

          logits += log.cpu().tolist()
          ys += y.tolist()

        m = metrics(ys, logits)
        print("Val metrics", m)
        if m["ap"] > best_ap:
            best_ap = m["ap"]
            torch.save(head.state_dict(), "best_head.pth")
            torch.save(mapper.state_dict(), "best_mapper.pth")
    print("Training done. Best AP:", best_ap)

Overwriting nsnet_cpu.py


## Loading the dataset and captioning images

In [None]:
data_root = "/kaggle/input/dalle-recognition-dataset"

caps = nsnet_cpu.gen_captions1(
    data_root,
    out_json="captions.json",
    limit_real=1100,
    limit_fake=1100
)

caps = [f"A photo of {c}" for c in caps]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


✅ open_clip patching is active.
✅ open_clip patching is active.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

Total selected images for captioning: 2200 (1100 real + 1100 fake)



BLIP captioning:   0%|          | 0/2200 [00:00<?, ?it/s][A
BLIP captioning:   0%|          | 1/2200 [00:13<8:19:29, 13.63s/it][A
BLIP captioning:   0%|          | 2/2200 [00:17<4:57:26,  8.12s/it][A
BLIP captioning:   0%|          | 3/2200 [00:23<4:20:48,  7.12s/it][A
BLIP captioning:   0%|          | 4/2200 [00:29<4:00:27,  6.57s/it][A
BLIP captioning:   0%|          | 5/2200 [00:33<3:28:18,  5.69s/it][A
BLIP captioning:   0%|          | 6/2200 [00:44<4:33:40,  7.48s/it][A
BLIP captioning:   0%|          | 7/2200 [00:47<3:38:19,  5.97s/it][A
BLIP captioning:   0%|          | 8/2200 [00:50<3:05:55,  5.09s/it][A
BLIP captioning:   0%|          | 9/2200 [00:53<2:43:45,  4.48s/it][A
BLIP captioning:   0%|          | 10/2200 [00:56<2:24:58,  3.97s/it][A
BLIP captioning:   0%|          | 11/2200 [00:59<2:10:22,  3.57s/it][A
BLIP captioning:   1%|          | 12/2200 [01:02<2:03:06,  3.38s/it][A
BLIP captioning:   1%|          | 13/2200 [01:05<2:06:14,  3.46s/it][A
BLIP captio

✅ Saved 2200 captions to captions.json
⚠️ Skipped 0 problematic images.





In [None]:
!mkdir -p "/content/drive/MyDrive/DALLE_NSNet"
!cp /content/captions.json "/content/drive/MyDrive/DALLE_NSNet/captions.json"
print("✅ Captions saved to Drive at: /content/drive/MyDrive/DALLE_NSNet/captions.json")
import json

with open("/content/drive/MyDrive/DALLE_NSNet/captions.json", "r") as f:
    caps = json.load(f)

for k, v in list(caps.items()):
    if isinstance(v, list):
        caps[k] = v[0] if len(v) > 0 else ""
    elif not isinstance(v, str):
        caps[k] = str(v)

json.dump(caps, open("captions.json", "w"))
print(f"✅ Fixed captions: {len(caps)} entries saved to captions_fixed.json")

✅ Captions saved to Drive at: /content/drive/MyDrive/DALLE_NSNet/captions.json
✅ Fixed captions: 2200 entries saved to captions_fixed.json


## Training the model
Due to limited computation resources we used 2000 Real and 2000 Fake images for training with 2 epochs.

Note: Unlike NS-NET, we don't precompute the Null Space Matrix (P). This ensures the semantic removal is tailored to each sample individually and hence improves performance of the model.

In [None]:
data_root = "/kaggle/input/dalle-recognition-dataset"
nsnet_cpu.train_nsnet(
    data_root=data_root,
    train_limit_real=1000,
    train_limit_fake=1000,
    test_limit_real=100,
    test_limit_fake=100
)

Using device: cpu
[TRAIN] Loaded 1000 REAL and 1000 FAKE images from '/kaggle/input/dalle-recognition-dataset'
[TEST] Loaded 100 REAL and 100 FAKE images from '/kaggle/input/dalle-recognition-dataset'


Epoch 1: 100%|██████████| 250/250 [1:54:22<00:00, 27.45s/it]


Epoch 1 loss 0.8458
Val metrics {'acc': 0.91, 'r_acc': 0.9, 'f_acc': 0.92, 'ap': np.float64(0.9664097875239941)}


Epoch 2: 100%|██████████| 250/250 [1:54:16<00:00, 27.43s/it]


Epoch 2 loss 0.7164
Val metrics {'acc': 0.87, 'r_acc': 0.77, 'f_acc': 0.97, 'ap': np.float64(0.9753267424044924)}
Training done. Best AP: 0.9753267424044924


We get significantly good results compared to NS-NET implementation even on limited training.

# Testing our Model

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, average_precision_score, confusion_matrix
from tqdm import tqdm

with open("captions.json", "r") as f:
    caps = json.load(f)


tokenizer = open_clip.get_tokenizer("ViT-L-14")
device = nsnet_cpu.get_device()

model_clip, _, _ = open_clip.create_model_and_transforms("ViT-L-14", pretrained="openai")
model_clip.to(device).eval()

_dummy_tokens = tokenizer(["hello world"]).to(device)
with torch.no_grad():
    _dummy_text_feat = model_clip.encode_text(_dummy_tokens)
text_dim = _dummy_text_feat.shape[1]
print(f"Detected text feature dim = {text_dim}")

head = nsnet_cpu.NSNetHead(dim=text_dim).to(device)
head_path = "best_head.pth"
try:
    state = torch.load(head_path, map_location=device)
    head.load_state_dict(state)
    head.eval()
    print("✅ NSNetHead loaded successfully from", head_path)
except Exception as e:
    print("⚠️ Failed to load NSNetHead:", repr(e))
    raise

print("⚙️ Building nullspace dynamically from captions...")
texts = [str(t) for t in caps.values()]

tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

test_ds = nsnet_cpu.DALLEDataset(
    root=data_root,
    split="test",
    captions=caps,
    transform=tfm,
    train_limit_real=1000,
    train_limit_fake=1000,
    test_limit_real=100,
    test_limit_fake=100
)
test_dl = DataLoader(test_ds, batch_size=8, shuffle=False)

print(f"🧾 Number of test images loaded: {len(test_ds)}")

logits = []
ys = []

for x, y, caps_batch in tqdm(test_dl, desc="Testing"):
    x = x.to(device)
    y = y.to(device)

    tokens = tokenizer(list(caps_batch)).to(device)
    with torch.no_grad():
        f_img = model_clip.encode_image(x)            # (B_img, D_img)
        f_text = model_clip.encode_text(tokens)       # (B_cap, text_dim)
        if mapper is not None:
            f_text_mapped = mapper(f_text)            # (B_cap, text_dim)
        else:
            f_text_mapped = f_text                    # use raw text features

        B_img = f_img.shape[0]
        B_txt = f_text_mapped.shape[0]
        B = min(B_img, B_txt)
        if B < 1:
            continue

        f_img = f_img[:B]
        f_text_mapped = f_text_mapped[:B]

        D = f_text_mapped.shape[1]
        eye = torch.eye(D, device=device)

        f_null_list = []
        for i in range(B):
            v = f_text_mapped[i].unsqueeze(0)      # (1, D)
            v_norm = (v @ v.T).clamp(min=1e-8)     # scalar
            P_i = eye - (v.T @ v) / v_norm         # (D, D)
            f_proj = f_img[i].unsqueeze(0) @ P_i   # (1, D)
            f_null_list.append(f_proj.squeeze(0))

        f_null = torch.stack(f_null_list)          # (B, D)
        _, batch_logits = head(f_null)

    logits.extend(batch_logits.cpu().tolist())
    ys.extend(y[:B].cpu().tolist())

if len(ys) == 0:
    raise RuntimeError("No test outputs were produced (ys empty). Check the dataset and caption tokenization.")

y_true = np.array(ys)
probs = torch.sigmoid(torch.tensor(logits)).numpy()
y_pred = (probs >= 0.5).astype(int)

acc = accuracy_score(y_true, y_pred)
r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0])
f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1])
ap = average_precision_score(y_true, probs)
cm = confusion_matrix(y_true, y_pred)

print("Done. Acc:", acc, "AP:", ap)
print("\n✅ Evaluation Results on DALLE Recognition Test Set:")
print(f"Overall Accuracy : {acc:.4f}")
print(f"Real Accuracy    : {r_acc:.4f}")
print(f"Fake Accuracy    : {f_acc:.4f}")
print(f"Average Precision: {ap:.4f}")

print("\n📊 Confusion Matrix:")
print("        Pred Real | Pred Fake")
print(f"Real | {cm[0][0]:5d}       | {cm[0][1]:5d}")
print(f"Fake | {cm[1][0]:5d}       | {cm[1][1]:5d}")

print("\n✅ Testing completed successfully.")

Detected text feature dim = 768
✅ NSNetHead loaded successfully from best_head.pth
✅ SemanticMapper loaded successfully from best_mapper.pth
⚙️ Building nullspace dynamically from captions...
[TEST] Loaded 100 REAL and 100 FAKE images from '/kaggle/input/dalle-recognition-dataset'
🧾 Number of test images loaded: 200


Testing: 100%|██████████| 25/25 [11:44<00:00, 28.17s/it]

Done. Acc: 0.93 AP: 0.9940644541893723

✅ Evaluation Results on DALLE Recognition Test Set:
Overall Accuracy : 0.9300
Real Accuracy    : 0.8600
Fake Accuracy    : 1.0000
Average Precision: 0.9941

📊 Confusion Matrix:
        Pred Real | Pred Fake
Real |    86       |    14
Fake |     0       |   100

✅ Testing completed successfully.





We get significant improvement to the original NS-NET implementation even with limited training. Key highlight is that the fake accuracy is 1.