In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image, ImageFilter
import io
from preprocessing import *
import random
from skimage import data
from scipy.ndimage import rotate
from kernels import *
import torchvision
import os
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from preprocessing import *
from transformers import Swinv2ForImageClassification, SwinConfig
from torch.optim import AdamW
from torchvision import transforms, datasets



## Load dataset

In [22]:

def preprocess_image(image):    
    rich, poor = smash_n_reconstruct(image)
    rich = apply_high_pass_filter(rich)
    poor = apply_high_pass_filter(poor)
    return rich, poor

class ProcessedPairDataset(Dataset):
    def __init__(self, root_dir, transform=None, split='train'):
        self.root_dir = root_dir
        self.transform = transform
        self.split = split  # 'train' or 'val'
        self.models = sorted(entry.name for entry in os.scandir(root_dir) if entry.is_dir())
        self.samples = []
        
        for model in self.models:
            model_dir = os.path.join(root_dir, model, f'imagenet_{model}', split)
            for class_label in ['ai', 'nature']:
                class_dir = os.path.join(model_dir, class_label)
                for img_name in os.listdir(class_dir):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        img_path = os.path.join(class_dir, img_name)
                        self.samples.append((img_path, class_label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        # Apply preprocessing
        rich, poor = preprocess_image(image)

        # Apply transformations if any
        if self.transform:
            rich = self.transform(rich)
            poor = self.transform(poor)

        label = 0 if class_label == 'ai' else 1
        return rich, poor, label

def sample_data(dataset, num_samples_per_class_per_model):
    random.seed(42)  # For reproducibility
    sampled_indices = {}
    
    # Initialize the dictionary to hold indices for each model and class
    for model in dataset.models:
        sampled_indices[model] = {'ai': [], 'nature': []}

    # Collect indices for each class and model
    for idx, (img_path, class_label) in enumerate(dataset.samples):
        for model in dataset.models:
            if model in img_path:
                sampled_indices[model][class_label].append(idx)
                break

    # Sample indices for each model and class
    final_indices = []
    for model in sampled_indices:
        for class_label in ['ai', 'nature']:
            final_indices.extend(random.sample(sampled_indices[model][class_label], num_samples_per_class_per_model))

    # Shuffle the indices to mix classes and models
    random.shuffle(final_indices)
    return Subset(dataset, final_indices)

# Define transformations
transform = transforms.Compose([
  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create balanced subsets for both training and validation sets
train_dataset = ProcessedPairDataset(root_dir='/mnt/d/GenImage', transform=transform, split='train')
val_dataset = ProcessedPairDataset(root_dir='/mnt/d/GenImage', transform=transform, split='val')


# and 625 images per class per model for validation
train_subset = sample_data(train_dataset, num_samples_per_class_per_model=1875 // 2)  # Dividing by number of classes (2)
val_subset = sample_data(val_dataset, num_samples_per_class_per_model=625 // 2)

# Create data loaders
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=4)
valid_loader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=4)


## CNN

In [23]:

class CNNBlock(nn.Module):
   def __init__(self, num_input_channels):
       super(CNNBlock, self).__init__()
       self.conv = nn.Conv2d(num_input_channels, 3, kernel_size=3, padding=1)
       self.bn = nn.BatchNorm2d(3)
       self.relu = nn.ReLU()
   def forward(self, x):
       x = self.conv(x)
       x = self.bn(x)
       x = self.relu(x)
      
       return x

## Model

In [24]:
class ImageClassificationModel(nn.Module):
    def __init__(self):
        super(ImageClassificationModel, self).__init__()
        self.feature_combiner = CNNBlock(num_input_channels=3)
        self.feature_combiner2 = CNNBlock(num_input_channels=3)
        self.transformer = Swinv2ForImageClassification.from_pretrained(
            "microsoft/swinv2-tiny-patch4-window8-256",
            config=SwinConfig.from_pretrained('microsoft/swinv2-tiny-patch4-window8-256', num_classes=2)
        )

    def forward(self, rich, poor):
       
        x = self.feature_combiner(rich)
        y = self.feature_combiner2(poor)   
        feature_difference = x - y
        outputs = self.transformer(feature_difference)

        return outputs.logits


## Train & Validation

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ImageClassificationModel().to(device)
criterion = nn.CrossEntropyLoss()

optimizer = AdamW([
    {'params': model.feature_combiner.parameters(), 'lr': 1e-3},
    {'params': model.feature_combiner2.parameters(), 'lr': 1e-3},
    {'params': model.transformer.parameters(), 'lr': 1e-5}  # Lower lr for fine-tuning
])

#unfreeze the transformer

for param in model.transformer.parameters():
    param.requires_grad = True
    
for param in model.transformer.parameters():
    param.requires_grad = True
def train(model, train_loader, optimizer, device, num_epochs):
    model.train()  # Set model to training mode
    for epoch in range(num_epochs):
        total_loss, total, correct = 0, 0, 0
        for batch in train_loader:
            rich, poor, labels = batch[0], batch[1], batch[2]
            rich = rich.to(device)
            poor = poor.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            
            # Assuming your model takes two inputs, rich and poor
            outputs = model(rich, poor)  
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = total_loss / total
        epoch_acc = correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')


def validate(model, batches, device):
    model.eval()
    total_loss, total, correct = 0, 0, 0
    with torch.no_grad():
        for rich, poor, labels in batches:
            rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)
            
            outputs = model(rich, poor)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item() * labels.size(0)  # Correct loss calculation
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)  # Total number of labels processed
            correct += (predicted == labels).sum().item()  # Correct predictions
        
    avg_loss = total_loss / total  # Correct average loss
    avg_acc = correct / total  # Correct accuracy
    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}')



# Assuming the datasets and loaders are set up correctly
num_epochs = 10
train(model, train_loader, optimizer, device, num_epochs)
validate(model, valid_loader, device)


You are using a model of type swinv2 to instantiate a model of type swin. This is not supported for all configurations of models and can yield errors.


Epoch 1/10, Loss: 0.7570, Accuracy: 0.6690
Epoch 2/10, Loss: 0.4858, Accuracy: 0.7621
Epoch 3/10, Loss: 0.4080, Accuracy: 0.8122
Epoch 4/10, Loss: 0.3439, Accuracy: 0.8452
Epoch 5/10, Loss: 0.2983, Accuracy: 0.8693
Epoch 6/10, Loss: 0.2651, Accuracy: 0.8856
Epoch 7/10, Loss: 0.2343, Accuracy: 0.8991
Epoch 8/10, Loss: 0.2141, Accuracy: 0.9110
Epoch 9/10, Loss: 0.1928, Accuracy: 0.9200
Epoch 10/10, Loss: 0.1732, Accuracy: 0.9311


TypeError: ImageClassificationModel.forward() missing 1 required positional argument: 'poor'