In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
device = torch.device("cuda")

In [2]:
class ArtifactDataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.classes = []
        for label_type in ['real', 'fake']:
            for cls in ['human_faces', 'animals', 'vehicles']:
                dir_path = os.path.join(root_dir, label_type, cls)
                if not os.path.exists(dir_path):
                    continue
                for img_name in os.listdir(dir_path):
                    self.image_paths.append(os.path.join(dir_path, img_name))
                    self.labels.append(1 if label_type == 'real' else 0)
                    self.classes.append(cls)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        class_name = self.classes[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        class_label = ['human_faces', 'animals', 'vehicles'].index(class_name)
        return image, torch.tensor(label, dtype=torch.float32), torch.tensor(class_label)

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

In [4]:
train_dataset = ArtifactDataset("ArtiFact_240K/train", transform=transform)
val_dataset = ArtifactDataset("ArtiFact_240K/validation", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [5]:
import torch
import torch.nn as nn
import torch.fft
import timm

class KANVisionLSTM_FFT(nn.Module):
    def __init__(self, backbone='xception', lstm_hidden=128, lstm_layers=1, num_classes=3):
        super(KANVisionLSTM_FFT, self).__init__()

        # Combine RGB and FFT channels
        self.rgb_fft_proj = nn.Conv2d(6, 3, kernel_size=1)

        # Load the specified backbone model
        self.backbone_name = backbone
        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)

        # Determine if the backbone is a Vision Transformer
        self.is_vit = 'vit' in backbone.lower()

        # Set embedding dimension based on the backbone
        if self.is_vit:
            self.embed_dim = self.backbone.embed_dim
        else:
            # For CNN backbones like XceptionNet
            dummy_input = torch.randn(1, 3, 299, 299)
            with torch.no_grad():
                features = self.backbone.forward_features(dummy_input)
            self.embed_dim = features.shape[1]

        # LSTM over backbone features
        self.lstm = nn.LSTM(
            input_size=self.embed_dim,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True
        )

        # Attention mechanism
        self.attn = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # KAN-style transformation
        self.kernel_net = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Output heads
        self.real_fake_head = nn.Linear(64, 1)
        self.class_head = nn.Linear(64, num_classes)

    def extract_fft(self, x):
        # Compute FFT magnitude
        fft = torch.fft.fft2(x)
        return torch.abs(fft)

    def forward(self, x):
        # x shape: [B, 3, H, W]
        fft = self.extract_fft(x)
        combined = torch.cat([x, fft], dim=1)           # [B, 6, H, W]
        projected = self.rgb_fft_proj(combined)         # [B, 3, H, W]

        # Extract features using the backbone
        if self.is_vit:
            tokens = self.backbone.forward_features(projected)  # [B, N, embed_dim]
        else:
            features = self.backbone.forward_features(projected)  # [B, embed_dim, H', W']
            tokens = features.flatten(2).transpose(1, 2)          # [B, N, embed_dim]

        lstm_out, _ = self.lstm(tokens)                 # [B, N, 2*lstm_hidden]

        attn_scores = self.attn(lstm_out)               # [B, N, 1]
        attn_weights = torch.softmax(attn_scores, dim=1)
        pooled = (attn_weights * lstm_out).sum(dim=1)   # [B, 2*lstm_hidden]

        fused = self.kernel_net(pooled)                 # [B, 64]

        real_fake = torch.sigmoid(self.real_fake_head(fused))  # [B]
        cls_pred = self.class_head(fused)                      # [B, num_classes]

        return real_fake.squeeze(1), cls_pred

In [6]:
model = KANVisionLSTM_FFT()

In [7]:
model

KANVisionLSTM_FFT(
  (rgb_fft_proj): Conv2d(6, 3, kernel_size=(1, 1), stride=(1, 1))
  (backbone): Xception(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU(inplace=True)
    (block1): Block(
      (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (rep): Sequential(
        (0): SeparableConv2d(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BatchNorm2d(128, eps=1e-05, mom

In [8]:
def train_model(model, train_loader, val_loader, epochs=3, lr=1e-4):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion_real_fake = nn.BCELoss()
    criterion_class = nn.CrossEntropyLoss()
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for images, labels, class_labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            images, labels, class_labels = images.to(device), labels.to(device), class_labels.to(device)
            optimizer.zero_grad()
            out_real_fake, out_class = model(images)
            loss1 = criterion_real_fake(out_real_fake, labels)
            loss2 = criterion_class(out_class, class_labels)
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} Loss: {total_loss / len(train_loader):.4f}")

In [9]:
train_model(model, train_loader, val_loader, epochs=3)

Epoch 1 Training: 100%|██████████| 10500/10500 [1:33:45<00:00,  1.87it/s]


Epoch 1 Loss: 0.5944


Epoch 2 Training: 100%|██████████| 10500/10500 [1:33:23<00:00,  1.87it/s] 


Epoch 2 Loss: 0.4880


Epoch 3 Training: 100%|██████████| 10500/10500 [1:32:46<00:00,  1.84it/s]  

Epoch 3 Loss: 0.3690


In [10]:
def evaluate(model, dataloader):
    model.eval()
    preds_rf, trues_rf = [], []
    preds_cls, trues_cls = [], []
    with torch.no_grad():
        for images, labels, class_labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            class_labels = class_labels.to(device)
            out_real_fake, out_class = model(images)
            preds_rf += (out_real_fake > 0.5).cpu().numpy().tolist()
            trues_rf += labels.cpu().numpy().tolist()
            preds_cls += torch.argmax(out_class, dim=1).cpu().numpy().tolist()
            trues_cls += class_labels.cpu().numpy().tolist()
    acc_rf = accuracy_score(trues_rf, preds_rf)
    acc_cls = accuracy_score(trues_cls, preds_cls)
    print(f"Real/Fake Accuracy: {acc_rf:.4f}, Class Accuracy: {acc_cls:.4f}")

In [11]:
evaluate(model, val_loader)

Evaluating: 100%|██████████| 3750/3750 [13:44<00:00,  4.55it/s]

Real/Fake Accuracy: 0.9440, Class Accuracy: 0.9961





In [12]:
def predict_test(model, test_dir, output_csv="test.csv"):
    model.eval()
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    test_images = [img for img in os.listdir(test_dir) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
    results = []

    for img_name in tqdm(test_images, desc="Predicting Test Images"):
        img_path = os.path.join(test_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        image = transform_test(image).unsqueeze(0).to(device)

        with torch.no_grad():
            out_real_fake, out_class = model(image)
            label = 1 if out_real_fake.item() > 0.5 else 0
            cls_idx = torch.argmax(out_class, dim=1).item()
            cls_name = ['human_faces', 'animals', 'vehicles'][cls_idx]
            results.append([img_name, label, cls_name])

    df = pd.DataFrame(results, columns=['image', 'label', 'class'])
    df.to_csv(output_csv, index=False)
    print(f"Saved predictions to {output_csv}")

In [13]:
predict_test(model, "ArtiFact_240K/test", output_csv="test.csv")

Predicting Test Images: 100%|██████████| 12002/12002 [9:27<00:00,  14.55it/s]

Saved predictions to test.csv



