In [1]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm 
from timm import create_model
import pandas as pd

# Select the compute device once for the whole notebook: "cuda" when torch
# reports CUDA as available, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ===============================
# Test Dataset
# ===============================
class TestDataset(Dataset):
    """Image dataset used at inference time, with optional ground-truth labels.

    Two modes are supported:

    * Validation mode (``csv_path`` given): the CSV is read with pandas. It must
      provide an ``ID`` column (read as strings) and a ``label`` column.
    * Submission mode (``csv_path`` is ``None``): every file in ``test_dir`` whose
      name ends in .jpg/.jpeg/.png (case-insensitive) becomes one sample; its ID
      is the file name without the extension. Files are taken in sorted order so
      the resulting row order is reproducible across runs and file systems.

    Every image is converted to RGB, resized to 224x224, converted to a tensor
    and normalized channel-wise.

    Args:
        test_dir: Directory that contains the image files.
        csv_path: Optional path to a CSV with ``ID`` and ``label`` columns.
    """

    # Extensions accepted when scanning a directory or probing for a CSV-listed ID.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, test_dir, csv_path=None):
        self.test_dir = test_dir

        if csv_path is not None:
            # Validation mode (with labels)
            self.data = pd.read_csv(csv_path, dtype={'ID': str})
            self.has_labels = True
            # File names are not known up front; they are resolved per sample.
            self._filenames = None
        else:
            # Submission mode (no labels)
            img_files = sorted(
                f for f in os.listdir(test_dir)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            self.data = pd.DataFrame({'ID': [os.path.splitext(f)[0] for f in img_files]})
            self.has_labels = False
            # Keep the real file names (parallel to self.data) so that __getitem__
            # opens exactly the file that was listed, whatever its extension/case.
            self._filenames = img_files

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # These values match the commonly used ImageNet channel statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def _resolve_path(self, idx, img_id):
        """Return the path of the image file for sample ``idx``.

        In submission mode the file name recorded while scanning the directory is
        used. In validation mode the accepted extensions are probed in order
        (".jpg" first); if none exists the ".jpg" path is returned so that the
        subsequent open fails with an error naming the expected file.
        """
        if self._filenames is not None:
            return os.path.join(self.test_dir, self._filenames[idx])

        for ext in self.IMAGE_EXTENSIONS:
            candidate = os.path.join(self.test_dir, img_id + ext)
            if os.path.exists(candidate):
                return candidate
        return os.path.join(self.test_dir, img_id + ".jpg")

    def __getitem__(self, idx):
        """Load and preprocess one sample.

        Returns:
            ``(image, label, img_id)`` in validation mode, ``(image, img_id)`` in
            submission mode, where ``image`` is the output of ``self.transform``.
        """
        row = self.data.iloc[idx]
        img_id = str(row['ID'])
        img_path = self._resolve_path(idx, img_id)

        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted RGB image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        if self.has_labels:
            return image, int(row['label']), img_id
        return image, img_id

In [None]:
# ===============================
# Neural Network
# ===============================
class CNN(nn.Module):
    """Two-class image classifier: a timm backbone followed by an MLP head.

    Despite the class name, the default backbone is a Vision Transformer
    (``vit_large_patch16_224``), not a convolutional network. The backbone is
    created without its own classifier and its feature vector is fed to
    ``self.fusion``, which outputs 2 logits (real vs AIGC).

    Args:
        pretrained: Passed to ``timm.create_model``; whether to load pretrained
            backbone weights.
        freeze_backbone: If true, ``requires_grad`` is set to ``False`` on every
            backbone parameter so only the head receives gradients.
        dropout: Dropout probability used after each hidden layer of the head.
        backbone_name: timm model name of the backbone. Checkpoints saved with
            one backbone can only be loaded into a model built with the same one.
    """

    def __init__(self, pretrained=True, freeze_backbone=True, dropout=0.3,
                 backbone_name='vit_large_patch16_224'):
        super().__init__()

        # === ViT-L/16 by default (1024-dim features) ===
        # num_classes=0 makes timm build the backbone without a classification
        # layer, so forward() yields features instead of class logits.
        self.vit = create_model(backbone_name, pretrained=pretrained, num_classes=0)

        # Freeze backbone
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False

        # Read the feature width from the backbone instead of hard-coding it, so
        # the head always matches the chosen model; 1024 is the ViT-L value and
        # serves as the fallback if the attribute is missing.
        feat_dim = getattr(self.vit, 'num_features', 1024)

        # Layer order/indices are kept stable (0..8) so that existing
        # state_dict keys such as "fusion.0.weight" continue to load.
        self.fusion = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(128, 2)  # Exactly 2 classes: real vs AIGC
        )

    def forward(self, x):
        """Return the 2-class logits for a batch of images ``x``."""
        feat = self.vit(x)
        out = self.fusion(feat)
        return out

In [4]:
def test_model(data_root="./", model_path="model.pth", output_csv="submission.csv", batch_size=32):
    """Run inference on ``<data_root>/test`` and write the predictions to a CSV.

    The images are loaded with ``TestDataset`` in submission mode (no labels),
    classified with ``CNN`` using the weights stored at ``model_path``, and the
    predicted class index of every image is written to ``output_csv`` with the
    columns ``ID`` and ``label``.

    Args:
        data_root: Directory that contains the ``test`` image folder.
        model_path: Path to the saved ``state_dict`` of the model.
        output_csv: Destination path of the submission file.
        batch_size: Number of images per inference batch.
    """
    test_dir = os.path.join(data_root, "test")

    print(f"Loading images from: {test_dir}")
    print(f"Loading model from: {model_path}")

    # Load Dataset
    test_dataset = TestDataset(test_dir, csv_path=None)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep predictions in dataset order
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=str(device).startswith("cuda"),
    )

    # Load model (pretrained=False: the weights come from the checkpoint below)
    model = CNN(pretrained=False, freeze_backbone=True).to(device)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # inference behaviour for Dropout and BatchNorm layers
    print("Model loaded")

    results = []

    with torch.no_grad():
        # The dataset was built without labels, so every batch is (images, ids).
        for imgs, img_ids in tqdm(test_loader, desc="Predicting"):
            outputs = model(imgs.to(device))
            preds = outputs.argmax(dim=1).cpu().tolist()  # 0 or 1

            results.extend(
                {"ID": img_id, "label": pred} for img_id, pred in zip(img_ids, preds)
            )

    # Save to CSV. Explicit columns keep the header even when no image was found.
    submission_df = pd.DataFrame(results, columns=["ID", "label"])
    submission_df.to_csv(output_csv, index=False)
    print(f"\nSubmission saved to '{output_csv}'")
    print(f"Total predictions: {len(submission_df)}")
    print(submission_df.head())

# Run inference only when this code executes as the main module (which is the
# case when the cell is run in a notebook kernel), not when it is imported.
if __name__ == "__main__":
    test_model()

Loading images from: ./test
Loading model from: model.pth
Model loaded


Predicting: 100%|██████████| 79/79 [00:32<00:00,  2.42it/s]


Submission saved to 'submission.csv'
Total predictions: 2500
              ID  label
0  image_0000001      0
1  image_0000002      1
2  image_0000003      1
3  image_0000004      0
4  image_0000005      0





In [5]:
# NOTE(review): placeholder values. test_model() uses its own data_root/model_path
# rather than these module-level names, so this cell currently has no effect on
# inference — confirm whether it is still needed.
data_root = "path_to_dataset"
model_path = "saved_model.pth"