In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2

import numpy as np
from collections import defaultdict

from PIL import Image, ImageFilter
import io
import re
import random
import numpy.random as npr
from skimage import data
from scipy.ndimage import rotate
from kernels import *
import torchvision
import os
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision.transforms as transforms
 
from collections import defaultdict
import my_utils as ut
from transformers import Swinv2ForImageClassification, SwinConfig
from torch.optim import AdamW
from torchvision import transforms, datasets



## Load dataset

In [10]:

class DatasetAI(Dataset):
    def __init__(self, root_dir, transform=None, split='train'):
        self.root_dir = root_dir
        self.transform = transform
        self.split = split  # Can be 'train', 'val', or 'test'
        self.samples = []

        for model_name in sorted(os.listdir(root_dir)):
            model_path = os.path.join(root_dir, model_name)
            if os.path.isdir(model_path):
                # Construct the imagenet directory path
                imagenet_dir = f'imagenet_{model_name}'
                data_dir = os.path.join(model_path, imagenet_dir, split)
                if os.path.isdir(data_dir):
                    for class_label in ['ai', 'nature']:
                        class_path = os.path.join(data_dir, class_label)
                        if os.path.exists(class_path):
                            for img_name in os.listdir(class_path):
                                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                                    img_path = os.path.join(class_path, img_name)
                                    self.samples.append((img_path, class_label, model_name))

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, class_label, model_name = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        rich, poor = ut.smash_n_reconstruct(image) 
        if self.transform:
            rich = self.transform(rich)
            poor = self.transform(poor)
   
        label = 0 if class_label == 'ai' else 1
        # Return the model name along with the other data
        return rich, poor, label, model_name



def subset_train(dataset, desired_size, seed=42):
    rng = np.random.default_rng(seed)
    # Organize indices by both model and class
    model_class_indices = defaultdict(list)
    for idx, (_, class_label, model_name) in enumerate(dataset.samples):
        model_class_indices[(model_name, class_label)].append(idx)
    
    # Determine the minimum size across all model-class combinations to ensure balance
    min_group_size = min(len(indices) for indices in model_class_indices.values())
    # Calculate the number of samples to select per model-class combination
    samples_per_group = min(min_group_size, desired_size // len(model_class_indices))
    
    balanced_indices = []
    for indices in model_class_indices.values():
        selected_indices = rng.choice(indices, samples_per_group, replace=False)
        balanced_indices.extend(selected_indices)
    
    # Shuffle the indices to ensure the dataset order does not introduce bias
    rng.shuffle(balanced_indices)
    
    # Create the balanced training set
    balanced_train_set = Subset(dataset, balanced_indices)
    return balanced_train_set




def subset_val_test(dataset, val_size, test_size, seed=42):
    rng = random.Random(seed)
    model_class_indices = defaultdict(list)
    for idx, (_, class_label, model_name) in enumerate(dataset.samples):
        model_class_indices[(model_name, class_label)].append(idx)
        

    min_group_size = min(len(indices) for indices in model_class_indices.values())
    val_samples_per_group = min(min_group_size, val_size // len(model_class_indices))

    val_indices = []
    test_indices = []

    for indices in model_class_indices.values():
        val_indices.extend(rng.sample(indices, val_samples_per_group))

    all_val_indices = set(val_indices)

    remaining_indices = [idx for idx in range(len(dataset)) if idx not in all_val_indices]

    test_group_size = test_size // len(model_class_indices)

    for indices in model_class_indices.values():
        test_indices.extend(rng.sample([idx for idx in indices if idx in remaining_indices], test_group_size))

    val_subset = Subset(dataset, val_indices)
    test_subset = Subset(dataset, test_indices)

    return val_subset, test_subset



transform = transforms.Compose([
     transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
   
])

# Create dataset instances
train_dataset = DatasetAI(root_dir='/mnt/e/GenImage', transform=transform, split='train')
val_test_dataset = DatasetAI(root_dir='/mnt/e/GenImage', transform=transform, split='val')
test_dalle = DatasetAI(root_dir='/mnt/e/GenImage', transform=transform, split='val')
# Balance the training dataset
val_subset, test_subset = subset_val_test(val_test_dataset, 10000,10000)

train_subset = subset_train(train_dataset, 30000)

# Create DataLoader for each dataset
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_subset, batch_size=32, shuffle=False, num_workers=4)

 

train_set = set(train_subset.indices)
val_set = set(val_subset.indices)
test_set = set(test_subset.indices)


# Find the intersection of these three sets
intersection_trian_test = train_set.intersection(train_set, test_set)
intersection_valid_train = train_set.intersection(train_set, val_set)
intersection_val_test = val_set.intersection(val_set, test_set)


print(len(intersection_trian_test))
print(len(intersection_valid_train))
print(len(intersection_val_test))

206
247
0


In [3]:


class HighPassFilters(nn.Module):
    def __init__(self, kernels):
        super(HighPassFilters, self).__init__()
        # Kernels are a parameter but not trained
        self.kernels = nn.Parameter(kernels, requires_grad=False)

    def forward(self, x):
        # Apply convolution with padding to maintain output size equal to input size
        return F.conv2d(x, self.kernels, padding =2)  # Padding set to 2 to maintain output size




## CNN

In [4]:

class CNNBlock(nn.Module):
   def __init__(self, kernals):
       super(CNNBlock, self).__init__()
       self.conv = nn.Conv2d(30, 3, kernel_size=1,padding=0)
       self.filters = HighPassFilters(kernals)
       self.bn = nn.BatchNorm2d(3)
       self.htanh = nn.Hardtanh()
   def forward(self, x):
       x = self.filters(x)
       x = self.conv(x)
       x = self.bn(x)
       x = self.htanh(x)
       return x
  

## Model

In [5]:
class ImageClassificationModel(nn.Module):
    def __init__(self,kernels):
        super(ImageClassificationModel, self).__init__()
        self.feature_combiner = CNNBlock(kernels)
        self.feature_combiner2 = CNNBlock(kernels)
        config = SwinConfig.from_pretrained('microsoft/swinv2-tiny-patch4-window8-256',num_classes=2)
        self.transformer = Swinv2ForImageClassification.from_pretrained(
            "microsoft/swinv2-tiny-patch4-window8-256",
            config=config
        )
        
        self.transformer.classifier = nn.Linear(config.hidden_size, 2) 

 
    def forward(self, rich, poor):
       
        x = self.feature_combiner(rich)
        y = self.feature_combiner2(poor)   
        feature_difference = x - y
        outputs = self.transformer(feature_difference)

        return outputs.logits


## Train & Validation

In [8]:


kernels = ut.apply_high_pass_filter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageClassificationModel(kernels).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW([
    {'params': model.feature_combiner.parameters(), 'lr': 1e-4,},
    {'params': model.feature_combiner2.parameters(), 'lr': 1e-4,},
    {'params': model.transformer.parameters(), 'lr': 1e-4,}
])
# #freeze the transformer
# for param in model.transformer.parameters():
#     param.requires_grad = False
# #unfreeze classifier
# for param in model.transformer.classifier.parameters():
#     param.requires_grad = True
    
best_val_accuracy = 0.0
best_model_path = '/home/kosta/code/School/SentryAI/pth/best_model_newPatching_Crazy.pth'

#
# Try to load previous best model and its best validation accuracy
try:
    checkpoint = torch.load(best_model_path)
    best_val_accuracy = checkpoint['best_val_accuracy']
    print("Loaded previous best model with accuracy:", best_val_accuracy)
except FileNotFoundError:
    best_val_accuracy = float('-inf')
    print("No saved model found. Starting fresh!")
from collections import defaultdict

def train_and_validate(model, train_loader, valid_loader, optimizer, device, num_epochs, best_val_accuracy):
    
    best_val_accuracy_general = best_val_accuracy  # Use this to track the overall best accuracy

    for epoch in range(num_epochs):
        # # Training Phase
        model.train()
        total_train_loss, total_train, correct_train = 0, 0, 0
        for batch in train_loader:
            rich, poor, labels, model_names = batch  # Unpack model_names as well
            rich = rich.to(device)
            poor = poor.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(rich, poor)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = total_train_loss / total_train
        train_accuracy = correct_train / total_train

        # Validation Phase
        model.eval()
        val_accuracy_per_model = defaultdict(lambda: {'correct': 0, 'total': 0})
        total_val_loss, total_val, correct_val = 0, 0, 0
        with torch.no_grad():
            for batch in valid_loader:
                rich, poor, labels, model_names = batch  # Unpack model_names as well
                rich = rich.to(device)
                poor = poor.to(device)
                labels = labels.to(device)

                outputs = model(rich, poor)
                loss = criterion(outputs, labels)

                total_val_loss += loss.item() * labels.size(0)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                # Collect stats per model
                for model_name, pred, true in zip(model_names, predicted, labels):
                    val_accuracy_per_model[model_name]['total'] += 1
                    if pred == true:
                        val_accuracy_per_model[model_name]['correct'] += 1

        val_loss = total_val_loss / total_val
        val_accuracy_general = correct_val / total_val

        # Print overall validation accuracy
        print(f'Epoch {epoch+1}/{num_epochs}\n,'
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy_general:.4f}\n')
        



        # Check if general accuracy is the best and save
        if val_accuracy_general > best_val_accuracy_general:
            best_val_accuracy_general = val_accuracy_general
            torch.save({'model_state': model.state_dict(),
                        'best_val_accuracy': best_val_accuracy_general},
                       best_model_path)
            print(f"Saved new best general model with accuracy: {best_val_accuracy_general:.4f}")

train_and_validate(model, train_loader, val_loader, optimizer, device, num_epochs=10, best_val_accuracy=best_val_accuracy)


NameError: name 'apply_high_pass_filter' is not defined

## Test

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageClassificationModel(kernels).to(device)

def test(model, test_loader, device):
    # Load the best model
    checkpoint = torch.load("/home/kosta/code/School/SentryAI/pth/best_model_newPatching.pth")
    model.load_state_dict(checkpoint['model_state'])
    
    model.eval()
    total_test, correct_test = 0, 0
    test_accuracy_per_model = defaultdict(lambda: {'correct': 0, 'total': 0})

    with torch.no_grad():
        for batch in test_loader:
            rich, poor, labels, model_names = batch  # Assuming you have model_names
            rich = rich.to(device)
            poor = poor.to(device)
            labels = labels.to(device)

            outputs = model(rich, poor)
            _, predicted = torch.max(outputs, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

            # Collect stats per model just like validation phase
            for model_name, pred, true in zip(model_names, predicted, labels):
                test_accuracy_per_model[model_name]['total'] += 1
                if pred == true:
                    test_accuracy_per_model[model_name]['correct'] += 1

    test_accuracy = correct_test / total_test
    print(f'Test Accuracy: {test_accuracy:.4f}')

    # Print per model accuracy
    print("-------------------------------------------------------------------------")
    print("Test Accuracy per model:")
    for model_name, stats in test_accuracy_per_model.items():
        model_accuracy = stats['correct'] / stats['total']
        print(f"Test Accuracy for model {model_name}: {model_accuracy:.4f}")

test(model, test_loader, device)
from PIL import Image
import torchvision.transforms as transforms
import torch

# Load the model
checkpoint = torch.load("/home/kosta/code/School/SentryAI/pth/best_model_newPatching.pth")
model.load_state_dict(checkpoint['model_state'])
model.to(device)
model.eval()

# Define the image path
img_path = '/mnt/c/Users/kosta/Downloads/Screenshot 2024-04-28 002208.png'

# Load the image
rich, poor  = ut.smash_n_reconstruct(Image.open(img_path).convert('RGB'))

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
rich_tensor = transform(rich)
poor_tensor = transform(poor)

# Predict
with torch.no_grad():
    output = model(rich_tensor.unsqueeze(0).to(device), poor_tensor.unsqueeze(0).to(device))
    _, predicted = torch.max(output, 1)

# Print the predicted class
print("Predicted class:", predicted.item())
