In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2

import numpy as np

from PIL import Image, ImageFilter
import io
from preprocessing import *
import random
import numpy.random as npr
from skimage import data
from scipy.ndimage import rotate
from kernels import *
import torchvision
import os
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from preprocessing import *
from transformers import Swinv2ForImageClassification, SwinConfig
from torch.optim import AdamW
from torchvision import transforms, datasets



## Load dataset

In [None]:

def preprocess_image(image):    
    rich, poor = smash_n_reconstruct(image)
    rich = apply_high_pass_filter(rich)
    poor = apply_high_pass_filter(poor)
    return rich, poor


class DatasetAI(Dataset):
    def __init__(self, root_dir, transform, split='train'):
        self.root_dir = root_dir
        self.transform = transform
        self.split = split  # This can be 'train', 'val', or 'test'
        self.samples = []
        self.label_count = {'ai': 0, 'nature': 0}

        for model in sorted(os.listdir(root_dir)):
            model_path = os.path.join(root_dir, model)
            if os.path.isdir(model_path):
                # Depending on the split, choose the appropriate subdirectory
                split_folder = 'train' if split == 'train' else 'val'
                data_dir = os.path.join(model_path, f'imagenet_{model.split("_")[0]}', split_folder)
                for class_label in ['ai', 'nature']:
                    class_path = os.path.join(data_dir, class_label)
                    if os.path.exists(class_path):
                        for img_name in os.listdir(class_path):
                            if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                                img_path = os.path.join(class_path, img_name)
                                self.samples.append((img_path, class_label))
                                self.label_count[class_label] += 1

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        rich, poor = smash_n_reconstruct(image)  # Assume this function is defined elsewhere
        if self.transform:
            rich = self.transform(rich)
            poor = self.transform(poor)
   
        label = 0 if class_label == 'ai' else 1
        return rich, poor, label

def split_val_test_train(dataset_test_valid, dataset_train, train_size, val_size, test_size, seed=42):
    rng = npr.default_rng(seed)
    total_size_test_valid = len(dataset_test_valid)
    total_size_train = len(dataset_train)

    indices_test_valid = np.arange(total_size_test_valid)
    indices_train = np.arange(total_size_train)

    rng.shuffle(indices_test_valid)
    rng.shuffle(indices_train)

    if val_size + test_size > total_size_test_valid:
        raise ValueError("Requested sizes for validation and test exceed available data")
    if train_size > total_size_train:
        raise ValueError("Requested size for train exceeds available data")

    val_indices = indices_test_valid[:val_size]
    test_indices = indices_test_valid[val_size:val_size + test_size]
    train_indices = indices_train[:train_size]

    val_subset = Subset(dataset_test_valid, val_indices)
    test_subset = Subset(dataset_test_valid, test_indices)
    train_subset = Subset(dataset_train, train_indices)

    return train_subset, val_subset, test_subset



transform = transforms.Compose([
     transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
   
])

# Create dataset instances
train_dataset = DatasetAI(root_dir='/mnt/d/GenImage', transform=transform, split='train')
val_test_dataset = DatasetAI(root_dir='/mnt/d/GenImage', transform=transform, split='val')


val_dataset, test_dataset ,train_dataset = split_val_test_train(val_test_dataset, train_dataset, 10, 2, 20)

# Create DataLoader for each dataset
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)



In [None]:


class HighPassFilters(nn.Module):
    def __init__(self, kernels):
        super(HighPassFilters, self).__init__()
        # Kernels are a parameter but not trained
        self.kernels = nn.Parameter(kernels, requires_grad=False)

    def forward(self, x):
        # Apply convolution with padding to maintain output size equal to input size
        return F.conv2d(x, self.kernels, padding =2)  # Padding set to 2 to maintain output size
    
kernels = apply_high_pass_filter()    
print("Kernel shape:", kernels.shape)  
model = HighPassFilters(kernels)

# Initialize the transformation
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts the image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Optional: Normalize
])

# Load and transform the image
image_path = '/home/kosta/code/School/SentryAI/sample_images/img1.jpeg'
image = Image.open(image_path).convert('RGB')  # Convert image to RGB
image_tensor = transform(image).unsqueeze(0)
output = model(image_tensor)
print("Output shape:", output.shape)  # Should be [1, 30, height, width]



## CNN

In [None]:

class CNNBlock(nn.Module):
   def __init__(self, kernals):
       super(CNNBlock, self).__init__()
       self.conv = nn.Conv2d(30, 3, kernel_size=1,padding=0)
       self.filters = HighPassFilters(kernals)
       self.bn = nn.BatchNorm2d(3)
       self.htanh = nn.Hardtanh()
   def forward(self, x):
       x = self.filters(x)
       x = self.conv(x)
       x = self.bn(x)
       x = self.htanh(x)
       return x

## Model

In [None]:
class ImageClassificationModel(nn.Module):
    def __init__(self,kernels):
        super(ImageClassificationModel, self).__init__()
        self.feature_combiner = CNNBlock(kernels)
        self.feature_combiner2 = CNNBlock(kernels)
        config = SwinConfig.from_pretrained('microsoft/swinv2-tiny-patch4-window8-256',num_classes=2)
        self.transformer = Swinv2ForImageClassification.from_pretrained(
            "microsoft/swinv2-tiny-patch4-window8-256",
            config=config
        )
        
        self.transformer.classifier = nn.Linear(config.hidden_size, 2) 

 
    def forward(self, rich, poor):
       
        x = self.feature_combiner(rich)
        y = self.feature_combiner2(poor)   
        feature_difference = x - y
        outputs = self.transformer(feature_difference)

        return outputs.logits


## Train & Validation

In [11]:


kernels = apply_high_pass_filter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageClassificationModel(kernels).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW([
    {'params': model.feature_combiner.parameters(), 'lr': 1e-5,},
    {'params': model.feature_combiner2.parameters(), 'lr': 1e-5,},
    {'params': model.transformer.parameters(), 'lr': 1e-6,}
])
#freeze the transformer

# Initialize the best_val_accuracy variable
best_val_accuracy = 0.0
best_model_path = 'best_model2.pth'

# Try to load previous best model and its best validation accuracy
try:
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state'])
    best_val_accuracy = checkpoint['best_val_accuracy']
    print("Loaded previous best model with accuracy:", best_val_accuracy)
except FileNotFoundError:
    best_val_accuracy = float('-inf')
    print("No saved model found. Starting fresh!")

def train_and_validate(model, train_loader, valid_loader, optimizer, device, num_epochs, best_val_accuracy):
    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        total_train_loss, total_train, correct_train = 0, 0, 0
        for batch in train_loader:
            rich, poor, labels = batch
            rich = rich.to(device)
            poor = poor.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(rich, poor)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = total_train_loss / total_train
        train_accuracy = correct_train / total_train

        # Validation Phase
        model.eval()
        total_val_loss, total_val, correct_val = 0, 0, 0
        with torch.no_grad():
            for rich, poor, labels in valid_loader:
                rich = rich.to(device)
                poor = poor.to(device)
                labels = labels.to(device)

                outputs = model(rich, poor)
                loss = criterion(outputs, labels)

                total_val_loss += loss.item() * labels.size(0)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_loss = total_val_loss / total_val
        val_accuracy = correct_val / total_val

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, '
              f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

        # Update the best model if current model is better
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save({'model_state': model.state_dict(),
                        'best_val_accuracy': best_val_accuracy},
                       best_model_path)
            print(f"Saved new best model with accuracy: {best_val_accuracy:.4f}")
        

# Assuming the datasets and loaders are correctly set up
train_and_validate(model, train_loader, val_loader, optimizer, device, num_epochs=10, best_val_accuracy=best_val_accuracy)


You are using a model of type swinv2 to instantiate a model of type swin. This is not supported for all configurations of models and can yield errors.


Final stacked kernels shape: torch.Size([30, 3, 5, 5])
ImageClassificationModel(
  (feature_combiner): CNNBlock(
    (conv): Conv2d(30, 3, kernel_size=(1, 1), stride=(1, 1))
    (filters): HighPassFilters()
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (htanh): Hardtanh(min_val=-1.0, max_val=1.0)
  )
  (feature_combiner2): CNNBlock(
    (conv): Conv2d(30, 3, kernel_size=(1, 1), stride=(1, 1))
    (filters): HighPassFilters()
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (htanh): Hardtanh(min_val=-1.0, max_val=1.0)
  )
  (transformer): Swinv2ForImageClassification(
    (swinv2): Swinv2Model(
      (embeddings): Swinv2Embeddings(
        (patch_embeddings): Swinv2PatchEmbeddings(
          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (enc

## Test

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageClassificationModel().to(device)
def test(model, test_loader, device):
    #load the best model
    checkpoint = torch.load("best_model.pth")
    model.load_state_dict(checkpoint['model_state'])
    
    model.eval()
    total_test, correct_test = 0, 0
    with torch.no_grad():
        for rich, poor, labels in test_loader:
            rich = rich.to(device)
            poor = poor.to(device)
            labels = labels.to(device)

            outputs = model(rich, poor)
            _, predicted = torch.max(outputs, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = correct_test / total_test
    print(f'Test Accuracy: {test_accuracy:.4f}')
test(model, test_loader, device)