In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
import numpy as np
import gc
import pandas as pd


In [14]:
# Colour mode of the preprocessed images. It selects the data sub-directory
# (data/processed/64/{train,valid}/<mode>) and the suffix of the results CSV.
mode = "rgb"

In [15]:
class CustomImageDataset(Dataset):
    """Image dataset over every PNG/JPEG file found directly inside one directory.

    Each item is ``(image, label, path)``:
      * ``image`` is the RGB-converted file passed through ``transform``
        (``transforms.ToTensor()`` by default);
      * ``label`` is 1 if the file name contains "fake" (case-insensitive), else 0;
      * ``path`` is the full path of the file.

    Args:
        directory: Folder scanned (non-recursively) for .png/.jpg/.jpeg files.
        transform: Optional callable applied to the RGB ``PIL.Image``.
            Defaults to ``transforms.ToTensor()``.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, directory, transform=None):
        self.directory = directory
        # os.listdir returns entries in arbitrary order. Sort the paths so item order,
        # and therefore an unshuffled DataLoader and any saved results, is reproducible.
        # Match extensions case-insensitively so files like "IMG.JPG" are not dropped.
        self.image_paths = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # Build the transform once here instead of on every __getitem__ call.
        self.transform = transform if transform is not None else transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def label_from_path(img_path):
        """Return 1 (fake) if the file name contains "fake" (case-insensitive), else 0 (real)."""
        return 1 if "fake" in os.path.basename(img_path).lower() else 0

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file handle instead of leaving it to the GC.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        return self.transform(img), self.label_from_path(img_path), img_path

In [None]:
# Load datasets using CustomImageDataset
train_data = CustomImageDataset(directory=f'data/processed/64/train/{mode}')
test_data = CustomImageDataset(directory=f'data/processed/64/valid/{mode}')

train_loader = DataLoader(train_data, batch_size=500, shuffle=True)
test_loader = DataLoader(test_data, batch_size=500, shuffle=False)

train_data.__getitem__(0)[0].shape

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Classifier32(nn.Module):
    """Small CNN that outputs 2-class logits.

    The network has four Conv2d -> BatchNorm2d -> ReLU blocks (kernel 4, stride 2,
    padding 2; channels 3 -> 16 -> 32 -> 64 -> 128). These are followed by a
    512-unit hidden linear layer and a 2-way output layer.

    Args:
        IMAGE_SIZE: Input spatial size as ``(height, width)``. A single int ``n``
            is treated as a square ``(n, n)`` image.
    """

    def __init__(self, IMAGE_SIZE):
        super(Classifier32, self).__init__()
        if isinstance(IMAGE_SIZE, int):
            IMAGE_SIZE = (IMAGE_SIZE, IMAGE_SIZE)
        self.input_size = tuple(IMAGE_SIZE)

        # First convolutional block
        self.conv1 = nn.Conv2d(3, 16, kernel_size=4, padding=2, stride=2)
        self.bn1 = nn.BatchNorm2d(16)

        # Second convolutional block
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, padding=2, stride=2)
        self.bn2 = nn.BatchNorm2d(32)

        # Third convolutional block
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, padding=2, stride=2)
        self.bn3 = nn.BatchNorm2d(64)

        # Fourth convolutional block
        self.conv4 = nn.Conv2d(64, 128, kernel_size=4, padding=2, stride=2)
        self.bn4 = nn.BatchNorm2d(128)

        # Number of features per sample after the conv stack (input size of fc1).
        self._flattened_size = self._compute_flattened_size(self.input_size)

        # Fully connected layers
        self.fc1 = nn.Linear(self._flattened_size, 512)
        self.fc2 = nn.Linear(512, 2)

    def _compute_flattened_size(self, input_size):
        """Return the per-sample feature count produced by the convolutional stack.

        Only the convolutions are run. BatchNorm and ReLU do not change the tensor
        shape. Running BatchNorm here in training mode would fold this dummy zero
        batch into its running statistics, and it raises outright when the final
        feature map is 1x1.
        """
        x = torch.zeros(1, 3, *input_size)
        with torch.no_grad():
            for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
                x = conv(x)
        return x.numel()

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image batch to ``(batch, 2)`` logits."""
        # Convolutional layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        # Flatten everything except the batch dimension for the fully connected layers.
        x = torch.flatten(x, 1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
# Release unused blocks held by PyTorch's CUDA caching allocator (only relevant on a GPU machine).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Report how much GPU memory PyTorch currently holds.
# torch.cuda.memory_stats() mixes byte counters with allocation/event counts, so
# summing all of its values produced a meaningless number. Query the byte totals directly.
if torch.cuda.is_available():
    allocated_bytes = torch.cuda.memory_allocated()   # bytes occupied by live tensors
    reserved_bytes = torch.cuda.memory_reserved()     # bytes held by the caching allocator
else:
    allocated_bytes = reserved_bytes = 0
total_bytes = reserved_bytes
total_gbs = total_bytes / (1024 ** 3)
print(f"Allocated: {allocated_bytes / (1024 ** 3):.2f} GB, reserved (total): {total_gbs:.2f} GB")


In [None]:
print(torch.cuda.is_available())
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Spatial (height, width) the classifier is built for; should match the images in data/processed/64.
IMAGE_SIZE = (64, 64)

def train(train_loader, num_epochs=10, lr=0.02, early_stop_patience=3):
    """Train a fresh Classifier32 on ``train_loader`` and return it.

    Uses cross-entropy loss, Adam, and ReduceLROnPlateau (factor 0.1, patience 2),
    driven by each epoch's mean training loss. Training stops early once the epoch
    loss has been worse than the previous epoch's more than ``early_stop_patience``
    times in a row.

    Relies on the notebook globals ``device`` and ``IMAGE_SIZE``.

    Args:
        train_loader: Iterable of ``(inputs, labels, paths)`` batches.
        num_epochs: Maximum number of training epochs.
        lr: Initial Adam learning rate.
            NOTE(review): 0.02 is far above Adam's usual 1e-3; confirm this is intended.
        early_stop_patience: Consecutive worsening epochs tolerated before stopping.

    Returns:
        The trained model.
    """
    # Initialize the model, loss function, optimizer and LR scheduler.
    model = Classifier32(IMAGE_SIZE).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

    # Early-stopping state.
    prev_loss = float('inf')
    worse_loss_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        batches = 0
        for inputs, labels, _paths in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            print(f"Batches: {batches}", end="\r")
            batches += 1
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean loss over the batches actually seen; max() guards against an empty loader.
        epoch_loss = running_loss / max(batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

        # The scheduler must be stepped with the monitored metric, or the LR never decays.
        scheduler.step(epoch_loss)

        # Compare before overwriting prev_loss. Assigning first made the comparison
        # always False, so early stopping could never trigger.
        if epoch_loss > prev_loss:
            worse_loss_counter += 1
            if worse_loss_counter > early_stop_patience:
                print("Early stopping, results are not improving fast enough")
                break
        else:
            worse_loss_counter = 0
        prev_loss = epoch_loss

    return model

In [None]:
# Test the model


def predict(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, print accuracy, and return per-file predictions.

    Relies on the notebook global ``device``.

    Args:
        model: Module whose output has class scores along dim 1.
        test_loader: Iterable of ``(inputs, labels, paths)`` batches.

    Returns:
        dict mapping each image path to ``[true_label, predicted_label]``.
    """
    results = {}
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels, paths in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for path, true_label, pred_label in zip(paths, labels.cpu().numpy(), predicted.cpu().numpy()):
                results[path] = [true_label, pred_label]

    # An empty loader would otherwise raise ZeroDivisionError here.
    if total == 0:
        print("Test Accuracy: n/a (no samples)")
    else:
        accuracy = 100 * correct / total
        print(f"Test Accuracy: {accuracy}%")
    return results


In [None]:
# Train a fresh classifier on the training split.
model = train(train_loader=train_loader)

In [None]:
# Score the validation split; maps image path -> [true_label, predicted_label].
result = predict(model=model, test_loader=test_loader)

In [None]:
# One row per image: the path (taken from the dict keys) followed by the true and predicted labels.
df_result = pd.DataFrame(result).T.reset_index()

In [None]:
# Persist the per-image predictions, named after the image size and colour mode.
df_result.to_csv(path_or_buf=f"result_64_{mode}.csv", index=False)
