In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTForImageClassification, ViTImageProcessor
from torchvision import transforms
from PIL import Image
import cv2
from collections import Counter

# =========================
# DEVICE
# =========================
# Use the GPU when PyTorch reports one available, otherwise fall back to the CPU.
# Both models and every input tensor below are moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =========================
# CNN ARCHITECTURE
# =========================
class SimpleCNN(nn.Module):
    """Three-stage convolutional network that emits a single logit per image.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), taking the
    channel count 3 -> 32 -> 64 -> 128. The head flattens the feature map and
    applies Linear(128*28*28, 512) -> ReLU -> Dropout(0.5) -> Linear(512, 1).
    Because the head's input size is fixed, inputs must be 224x224: three 2x
    poolings reduce 224 to 28.

    The attribute names (`conv_layers`, `fv_layer`) and the position of every
    layer inside each Sequential determine the state_dict keys, which the
    strict `load_state_dict` call in this file relies on. Do not rename or
    reorder them.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) for each conv stage, in order.
        channel_plan = ((3, 32), (32, 64), (64, 128))
        stage_layers = []
        for in_ch, out_ch in channel_plan:
            stage_layers.extend(
                [
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                ]
            )
        self.conv_layers = nn.Sequential(*stage_layers)

        # 128 channels on a 28x28 grid after the three poolings.
        flat_features = 128 * 28 * 28
        self.fv_layer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),  # 512 hidden features -> 1 output logit
        )

    def forward(self, x):
        """Map a (batch, 3, 224, 224) image batch to (batch, 1) raw logits."""
        features = self.conv_layers(x)
        return self.fv_layer(features)

# =========================
# LOAD CNN (WEIGHTS)
# =========================
cnn_model = SimpleCNN()
# The checkpoint is a plain state_dict (tensors in a dict), so restrict
# unpickling with weights_only=True. Without it, torch.load runs the full
# pickle machinery and will execute arbitrary code embedded in the file.
cnn_model.load_state_dict(
    torch.load(r"E:\pytorch\simple_cnn.pth", map_location=device, weights_only=True)
)
cnn_model.to(device)
cnn_model.eval()  # inference mode: disables the Dropout in the classifier head

# SimpleCNN's head is sized for 224x224 inputs (128*28*28 features after three
# 2x poolings). ToTensor yields a float tensor scaled to [0, 1]; no further
# normalization is applied here.
cnn_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# =========================
# LOAD ViT (FINE-TUNED)
# =========================
# Directory holding the fine-tuned ViT checkpoint; both the image processor and
# the classification model are loaded from it.
vit_path = r"E:\pytorch\content\ai_vs_real_vit_model"
vit_processor = ViTImageProcessor.from_pretrained(vit_path)
# Module.to() and Module.eval() both return the module itself, so the
# device move and switch to inference mode can be chained onto the load.
vit_model = ViTForImageClassification.from_pretrained(vit_path).to(device).eval()

# =========================
# CNN PREDICTION
# =========================
def _cnn_logits_to_probs(logits, logit_class=1):
    """Convert raw CNN outputs of shape (batch, C) into class probabilities.

    A single output column (as produced by SimpleCNN's final Linear(512, 1))
    is a binary logit. Softmax over a size-1 dimension always returns 1.0, so
    a sigmoid is used instead. It gives the probability of class
    `logit_class`, its complement fills the other column, and the result is a
    (batch, 2) tensor. With C > 1 columns, a plain softmax over dim 1 is used.

    Raises:
        ValueError: if `logit_class` is not 0 or 1.
    """
    if logit_class not in (0, 1):
        raise ValueError(f"logit_class must be 0 or 1, got {logit_class!r}")
    if logits.shape[1] != 1:
        return F.softmax(logits, dim=1)
    p_logit_class = torch.sigmoid(logits)
    p_other = 1.0 - p_logit_class
    columns = [p_other, p_logit_class] if logit_class == 1 else [p_logit_class, p_other]
    return torch.cat(columns, dim=1)


def predict_cnn(image, *, logit_class=1):
    """Return CNN class probabilities for one PIL image as a (1, 2) tensor.

    Columns follow the ordering this file uses for the ViT and the final
    label: index 0 = REAL, index 1 = FAKE. That lets the tensor be blended
    element-wise with the ViT probabilities.

    Args:
        image: PIL image; resized to 224x224 by `cnn_transform`.
        logit_class: Which column (0 = REAL, 1 = FAKE) the CNN's single logit
            scores. NOTE(review): the default assumes the CNN was trained with
            a BCE-style objective where target 1 means FAKE. TODO confirm this
            against the training label mapping, and pass 0 if its positive
            class was REAL.
    """
    x = cnn_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = cnn_model(x)
    return _cnn_logits_to_probs(logits, logit_class)

# =========================
# ViT PREDICTION
# =========================
def predict_vit(image):
    """Return the fine-tuned ViT's class probabilities for one image.

    `vit_processor` turns the image into PyTorch tensors. Each tensor is moved
    to `device`, and the model's logits are softmaxed over dim 1, giving a
    (batch, num_labels) tensor with batch 1 for a single image. Elsewhere in
    this file, column 0 is read as the REAL probability.
    """
    encoded = vit_processor(image, return_tensors="pt")
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)
    with torch.no_grad():
        logits = vit_model(**model_inputs).logits
        class_probs = F.softmax(logits, dim=1)
    return class_probs

# =========================
# ENSEMBLE PREDICTION FOR SINGLE IMAGE
# =========================
def ensemble_predict(image, real_threshold=0.2, cnn_weight=0.3):
    """Classify one image as REAL or FAKE by combining the CNN and the ViT.

    Decision rule: if the ViT's REAL probability (column 0) is at least
    `real_threshold`, the image is labelled REAL and the ViT probabilities are
    reported as the final ones. Otherwise the two models' probability tensors
    are blended as `cnn_weight * cnn + (1 - cnn_weight) * vit`, and the
    argmax column decides (0 = REAL, anything else = FAKE).

    Args:
        image: A PIL image, or a path string to an image file, which is
            opened and converted to RGB.
        real_threshold: ViT REAL-probability cut-off for the REAL override.
        cnn_weight: Weight of the CNN in the blend, in [0, 1]; the ViT gets
            the remainder.

    Returns:
        dict with "CNN", "ViT" and "Final" (the probability tensors converted
        with `squeeze().tolist()`) and "Prediction" ("REAL" or "FAKE").

    Raises:
        ValueError: if `cnn_weight` is outside [0, 1].
    """
    if not 0.0 <= cnn_weight <= 1.0:
        raise ValueError(f"cnn_weight must be in [0, 1], got {cnn_weight!r}")

    if isinstance(image, str):
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image, so it stays usable after the file closes.
        with Image.open(image) as img:
            image = img.convert("RGB")

    cnn_probs = predict_cnn(image)
    vit_probs = predict_vit(image)

    # A ViT REAL probability at or above the threshold overrides everything -> REAL.
    if vit_probs[0, 0].item() >= real_threshold:
        final_pred = 0
        final_probs = vit_probs
    else:
        final_probs = cnn_weight * cnn_probs + (1.0 - cnn_weight) * vit_probs
        final_pred = final_probs.argmax(dim=1).item()

    return {
        "CNN": cnn_probs.squeeze().tolist(),
        "ViT": vit_probs.squeeze().tolist(),
        "Final": final_probs.squeeze().tolist(),
        "Prediction": "REAL" if final_pred == 0 else "FAKE"
    }

# =========================
# VIDEO PREDICTION
# =========================
def predict_video(video_path, frame_interval=10):
    """Classify a video as REAL or FAKE from a sample of its frames.

    Every `frame_interval`-th frame, starting with the first, is converted
    from OpenCV's BGR layout to an RGB PIL image and classified with
    `ensemble_predict`. The video is labelled REAL if any sampled frame is
    REAL, otherwise FAKE. A readable video that yields no frames is therefore
    labelled FAKE.

    Args:
        video_path: Path to a video file readable by `cv2.VideoCapture`.
        frame_interval: Sampling stride in frames; must be >= 1.

    Returns:
        dict with "Frame_predictions" (one "REAL"/"FAKE" per sampled frame)
        and "Final_prediction".

    Raises:
        ValueError: if `frame_interval` < 1.
        OSError: if OpenCV cannot open `video_path`.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval!r}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Without this check, a bad path reads zero frames and is silently
        # reported as "FAKE".
        cap.release()
        raise OSError(f"Could not open video: {video_path}")

    frame_preds = []
    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frame_preds.append(ensemble_predict(frame_pil)["Prediction"])
            frame_count += 1
    finally:
        # Release the capture even if inference raises partway through.
        cap.release()

    # Any REAL frame makes the whole video REAL. This mirrors the ViT
    # REAL-threshold override in ensemble_predict.
    final_video_pred = "REAL" if "REAL" in frame_preds else "FAKE"

    return {
        "Frame_predictions": frame_preds,
        "Final_prediction": final_video_pred
    }


# Video
# Demo run: classify a sample clip, checking every 10th frame.
demo_video = r"E:\pytorch\mountain-top-temple-clouds-smoke-4k-live-wallpaper.mp4"
video_result = predict_video(demo_video, frame_interval=10)
print("Video prediction:", video_result)


  from .autonotebook import tqdm as notebook_tqdm
  cnn_model.load_state_dict(torch.load(r"E:\pytorch\simple_cnn.pth", map_location=device))


Image prediction: {'CNN': 1.0, 'ViT': [0.03786914423108101, 0.9621308445930481], 'Final': [0.32650840282440186, 0.9734916090965271], 'Prediction': 'FAKE'}
Video prediction: {'Frame_predictions': ['FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE', 'FAKE'], 'Final_prediction': 'FAKE'}
