In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision import transforms
from PIL import Image

# =========================
# DEVICE
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# CNN ARCHITECTURE
# =========================
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()


        self.conv_layers = nn.Sequential(

            nn.Conv2d(3,32,kernel_size=3,padding=1),   # 1st layer
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32,64,kernel_size=3,padding=1),   #2nd layer
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64,128,kernel_size=3,padding=1),   #3rd layer
            nn.ReLU(),
            nn.MaxPool2d(2)

        )


        self.fv_layer = nn.Sequential(

            nn.Flatten(),
            nn.Linear(128*28*28,512),  
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512,1)   # inpopu tvalue 512 and ouput value 1
        )


    def forward(self,x):

        x = self.conv_layers(x)
        x = self.fv_layer(x)    # input = (batchsize,features)

        return x

# =========================
# LOAD CNN (WEIGHTS)
# =========================
cnn_model = SimpleCNN()
cnn_model.load_state_dict(torch.load(
    r"E:\pytorch\simple_cnn.pth", map_location=device
))
cnn_model.to(device)
cnn_model.eval()

cnn_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# =========================
# LOAD ViT (FINE-TUNED)
# =========================
vit_path = r"E:\pytorch\content\ai_vs_real_vit_model"

vit_processor = ViTImageProcessor.from_pretrained(vit_path)
vit_model = ViTForImageClassification.from_pretrained(vit_path)
vit_model.to(device)
vit_model.eval()

# =========================
# CNN PREDICTION
# =========================
def predict_cnn(image):
    x = cnn_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(cnn_model(x), dim=1)
    return probs

# =========================
# ViT PREDICTION
# =========================
def predict_vit(image):
    inputs = vit_processor(image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        probs = F.softmax(vit_model(**inputs).logits, dim=1)
    return probs

def ensemble_predict(image_path):
    image = Image.open(image_path).convert("RGB")

    # Get predictions
    cnn_probs = predict_cnn(image)
    vit_probs = predict_vit(image)

    # Check if ViT says REAL with at least 20% probability
    if vit_probs[0, 0] >= 0.2:  # index 0 = REAL
        final_pred = 0  # REAL
        final_probs = vit_probs  # you can also keep weighted average if you want
    else:
        # Weighted average: 30% CNN, 70% ViT
        final_probs = 0.3 * cnn_probs + 0.7 * vit_probs
        final_pred = final_probs.argmax(dim=1).item()

    return {
        "CNN": cnn_probs.squeeze().tolist(),
        "ViT": vit_probs.squeeze().tolist(),
        "Final": final_probs.squeeze().tolist(),
        "Prediction": "REAL" if final_pred == 0 else "FAKE"
    }


# =========================
# TEST
# =========================
result = ensemble_predict(r"E:\pytorch\images.jpg")

print("CNN probs   :", result["CNN"])
print("ViT probs   :", result["ViT"])
print("Final probs :", result["Final"])
print("PREDICTION  :", result["Prediction"])




  cnn_model.load_state_dict(torch.load(


CNN probs   : 1.0
ViT probs   : [0.03786914423108101, 0.9621308445930481]
Final probs : [0.32650840282440186, 0.9734916090965271]
PREDICTION  : FAKE
