In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights, resnet18
from PIL import Image
from tqdm import tqdm
import re
import json
import torch.nn.functional as F
import shutil
import pathlib
from pathlib import Path

In [2]:
# Inference configuration: checkpoint file, folder of images to classify, and
# the compute device.
weights = "./DLA_corrected_weights.pth"
testpath = "./testdataset"
# Fall back to the CPU when no CUDA device is available so the notebook still
# runs end to end on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def get_sorted_image_paths(folder_path):
    """Return the paths of the regular files in ``folder_path`` in numeric order.

    Files are ordered by the first run of digits in their name, read as an
    integer; files whose name contains no digits come last. Ties are broken by
    file name so the result does not depend on the arbitrary order in which
    ``os.listdir`` returns entries.

    Args:
        folder_path: Directory to scan. Sub-directories are skipped and not
            descended into.

    Returns:
        A NumPy array of path strings (``folder_path`` joined with each file
        name). If the directory cannot be read, the error is printed and an
        empty array is returned. Errors other than ``OSError`` propagate.
    """

    def sort_key(filename):
        # (first number in the name, name): the name is the deterministic tie-break.
        match = re.search(r'\d+', filename)
        number = int(match.group()) if match else float('inf')
        return (number, filename)

    try:
        files = [
            name for name in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, name))
        ]
    except OSError as e:
        print(f"An error occurred: {e}")
        return np.array([])

    ordered = sorted(files, key=sort_key)
    return np.array([os.path.join(folder_path, name) for name in ordered])

In [4]:
sorted_image_paths = get_sorted_image_paths(testpath)
# get_sorted_image_paths reports a read failure by returning an empty array, so
# stop here instead of silently producing an empty prediction file later on.
if len(sorted_image_paths) == 0:
    raise FileNotFoundError(f"No files found in {testpath!r}")

In [5]:
# Test-time preprocessing: per-channel normalisation with the statistics below,
# followed by ToTensorV2. No resizing or augmentation is applied.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

_test_steps = [
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
transform_test = A.Compose(_test_steps)

In [6]:
class ImageDataset(Dataset):
    """Dataset that loads images from a sequence of file paths.

    Each item is an ``(image, img_name)`` pair, where ``img_name`` is the base
    name of the source file so a prediction can be matched back to its image.
    """

    def __init__(self, image_paths, transform=None):
        """Store the paths and the optional transform.

        Args:
            image_paths: Indexable, sized sequence of image file paths.
            transform: Optional callable invoked as ``transform(image=array)``
                that returns a mapping with an ``"image"`` entry.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB, apply the transform if one is set.

        Returns:
            Tuple of the (possibly transformed) image and the file's base name.
        """
        img_path = self.image_paths[idx]
        # Open through a context manager so the file handle is released as soon
        # as the pixels are read, rather than whenever the object is collected.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        if self.transform is not None:
            image = self.transform(image=image)["image"]

        return image, os.path.basename(img_path)

In [7]:
# Wrap the ordered paths in a dataset and iterate them in a fixed order
# (no shuffling), loading in the main process.
dataset = ImageDataset(sorted_image_paths, transform_test)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

In [8]:
"""
This module implements a Deep Layer Aggregation (DLA) model with components like `BasicBlock`, `Root`, `Tree`, and `DLA`.

Classes:
--------
BasicBlock:
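    A basic residual block (two 3x3 conv + batch-norm stages) whose shortcut is the identity, or a 1x1 conv + batch-norm projection when the stride or channel count changes.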

    Methods:
        __init__(in_planes, planes, stride=1): Initializes the block.
        forward(x): Performs the forward pass.

Root:
    Combines feature maps via concatenation, applies convolution, batch normalization, and ReLU activation.

    Methods:
        __init__(in_channels, out_channels, kernel_size=1): Initializes the Root module.
        forward(xs): Combines and processes feature maps.

Tree:
    A recursive structure for hierarchical feature aggregation.

    Methods:
        __init__(block, in_channels, out_channels, level=1, stride=1): Initializes the Tree.
        forward(x): Aggregates and processes features.

DLA:
    The main Deep Layer Aggregation model for classification tasks.

    Methods:
        __init__(block=BasicBlock, num_classes=2): Initializes the DLA model.
        forward(x): Passes input through DLA layers for classification.

Usage:
------
- `BasicBlock`: Core block for residual operations.
- `Root`: Combines hierarchical features.
- `Tree`: Builds hierarchical aggregation.
- `DLA`: Full model for feature extraction and classification.
"""

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the input channel count differs from ``expansion * planes``; in
    that case it is a 1x1 convolution followed by batch normalisation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of channels produced by both 3x3 convolutions.
            stride: Stride of the first convolution (and of the projection).
        """
        super().__init__()
        out_planes = self.expansion * planes

        # Attribute names are part of the checkpoint key names; keep them as is.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)

class Root(nn.Module):
    """Fuse several feature maps: concatenate on dim 1, then conv + BN + ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=1):
        """Build the fusion layer.

        Args:
            in_channels: Total channel count of the concatenated inputs.
            out_channels: Channel count of the fused output.
            kernel_size: Convolution kernel size; padding is ``(kernel_size - 1) // 2``.
        """
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=1, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, xs):
        """Concatenate the tensors in ``xs`` along dim 1 and fuse them."""
        merged = torch.cat(xs, 1)
        return F.relu(self.bn(self.conv(merged)))

class Tree(nn.Module):
    """Recursive aggregation node.

    A level-1 tree runs two blocks in sequence and fuses both outputs with a
    ``Root``. A deeper tree additionally feeds the root with ``prev_root(x)``
    and with the output of each ``level_<i>`` subtree, so the root receives
    ``level + 2`` feature maps of ``out_channels`` channels each.

    NOTE(review): every subtree is built with ``in_channels`` and ``stride``
    but the subtrees are chained in ``forward``; with ``level > 2`` the second
    subtree would receive ``out_channels`` channels. Only levels 1 and 2 are
    used by ``DLA`` in this notebook - confirm before using deeper trees.
    """

    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        """Build the tree.

        Args:
            block: Callable ``block(in_channels, out_channels, stride=...)``
                returning a module.
            in_channels: Channel count of the tree input.
            out_channels: Channel count of every node output and of the result.
            level: Tree depth; 1 is the leaf configuration.
            stride: Stride applied by the first node(s) that see the input.
        """
        super().__init__()
        self.level = level
        if level != 1:
            self.root = Root((level + 2) * out_channels, out_channels)
            # Subtrees are registered from the deepest level down to level 1.
            for depth in range(level - 1, 0, -1):
                child = Tree(block, in_channels, out_channels,
                             level=depth, stride=stride)
                setattr(self, 'level_%d' % depth, child)
            self.prev_root = block(in_channels, out_channels, stride=stride)
            self.left_node = block(out_channels, out_channels, stride=1)
            self.right_node = block(out_channels, out_channels, stride=1)
        else:
            self.root = Root(2 * out_channels, out_channels)
            self.left_node = block(in_channels, out_channels, stride=stride)
            self.right_node = block(out_channels, out_channels, stride=1)

    def forward(self, x):
        """Run the nodes in sequence and fuse every intermediate output."""
        collected = []
        if self.level > 1:
            collected.append(self.prev_root(x))
        for depth in range(self.level - 1, 0, -1):
            x = getattr(self, 'level_%d' % depth)(x)
            collected.append(x)
        for node in (self.left_node, self.right_node):
            x = node(x)
            collected.append(x)
        return self.root(collected)

class DLA(nn.Module):
    """Deep Layer Aggregation classifier.

    Three conv + BN + ReLU stems (``base``, ``layer1``, ``layer2``) are followed
    by four ``Tree`` stages (``layer3`` to ``layer6``) and a linear head.

    The head is ``nn.Linear(512, num_classes)`` and ``layer6`` outputs 512
    channels, so the feature map must be 1x1 after ``F.avg_pool2d(out, 4)``.
    """

    def __init__(self, block=BasicBlock, num_classes=2):
        """Build the network.

        Args:
            block: Block factory passed to every ``Tree`` stage.
            num_classes: Number of output logits.
        """
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # 3x3, stride-1 stem; positions 0/1/2 match the checkpoint key names.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        self.base = conv_bn_relu(3, 16)
        self.layer1 = conv_bn_relu(16, 16)
        self.layer2 = conv_bn_relu(16, 32)

        self.layer3 = Tree(block, 32, 64, level=1, stride=1)
        self.layer4 = Tree(block, 64, 128, level=2, stride=2)
        self.layer5 = Tree(block, 128, 256, level=2, stride=2)
        self.layer6 = Tree(block, 256, 512, level=1, stride=2)
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        stages = (
            self.base, self.layer1, self.layer2,
            self.layer3, self.layer4, self.layer5, self.layer6,
        )
        out = x
        for stage in stages:
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        batch = out.size(0)
        out = out.view(batch, -1)
        return self.linear(out)

In [9]:
# Instantiate the network with its default block type and two output classes.
model = DLA(block=BasicBlock, num_classes=2)

In [10]:
 # PUT WEIGHTS HERE
# Load onto the CPU first so a checkpoint saved from a GPU process can still be
# read on a machine without CUDA; the model is moved to `device` after loading.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint contents allow it).
checkpoint = torch.load(weights, map_location="cpu")

  checkpoint = torch.load(weights)


In [11]:
# The checkpoint is either a wrapper dict holding the weights under "net" or
# "state_dict", or the state dict itself.
if "net" in checkpoint:
    state_dict = checkpoint["net"]
    print("net")
elif "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint

# Strip only a *leading* "module." (as added by wrappers such as
# nn.DataParallel). A plain str.replace would also mangle keys that merely
# contain that substring, e.g. "submodule.weight" -> "subweight".
_prefix = "module."
state_dict = {
    (k[len(_prefix):] if k.startswith(_prefix) else k): v
    for k, v in state_dict.items()
}

net


In [12]:
# strict=False tolerates key mismatches, so report them explicitly: otherwise
# a checkpoint that matches nothing would leave the model at its random
# initialisation without any visible sign.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")
model.to(device)

DLA(
  (base): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer3): Tree(
    (root): Root(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (left_node): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1

In [13]:
# Images predicted as fake are copied into <current working dir>/Fake_Folder;
# create the folder up front (a no-op when it already exists).
root_dir = Path.cwd()
Fake_dir = os.path.join(root_dir, "Fake_Folder")
os.makedirs(name=Fake_dir, exist_ok=True)

In [28]:
# Collects one {"index": ..., "prediction": ...} record per classified image.
results1 = list()

In [29]:
# Switch to evaluation mode: otherwise the BatchNorm layers keep normalising
# with per-batch statistics (and updating their running averages), so a
# prediction would depend on which images share its batch.
model.eval()
with torch.no_grad():
    for images, img_names in dataloader:
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        # One host transfer per batch instead of one .item() call per use.
        for img_name, pred in zip(img_names, preds.tolist()):
            is_fake = pred == 0  # class 0 is reported as "fake", anything else as "real"

            results1.append({
                # splitext drops the extension whatever its length
                # (".jpeg" would leave a trailing dot with a fixed [:-4] slice).
                "index": os.path.splitext(img_name)[0],
                "prediction": "fake" if is_fake else "real"
            })
            if is_fake:
                shutil.copy(os.path.join(testpath, img_name), Fake_dir)

In [31]:
# Persist the collected prediction records as indented JSON.
json_filename = "./DLA.json"
serialised = json.dumps(results1, indent=4)
with open(json_filename, "w") as json_file:
    json_file.write(serialised)