In [3]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image

In [4]:
class RealFakeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        for model in os.listdir(root_dir):
            model_path = os.path.join(root_dir, model)
            if os.path.isdir(model_path):
                for label in ['0_real', '1_fake']:
                    label_path = os.path.join(model_path, label)
                    if os.path.isdir(label_path):
                        for img_name in os.listdir(label_path):
                            img_path = os.path.join(label_path, img_name)
                            self.image_paths.append(img_path)
                            self.labels.append(0 if '0_real' in label else 1)
                    else:
                        for obj in os.listdir(model_path):
                            obj_path = os.path.join(model_path, obj)
                            label_path = os.path.join(obj_path, label)
                            if os.path.isdir(label_path):
                                for img_name in os.listdir(label_path):
                                    img_path = os.path.join(label_path, img_name)
                                    self.image_paths.append(img_path)
                                    self.labels.append(0 if '0_real' in label else 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


full_dataset = RealFakeDataset(root_dir=r"C:\Users\Danila\VSU\vsu_common_rep\vsu_common_rep\2year\2term\project\image_classification\content\CNN_synth\train_set", transform=transform)

train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [6]:
len(train_dataset)

81296

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64*62*62, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

model = SimpleCNN()

In [12]:
import torch
import torch.optim as optim
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 10

cuda


In [13]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Wrap the dataloader with tqdm for progress bar
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')
    get_accuracy()

print('Finished Training')

Epoch 1/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:19<00:00,  2.96batch/s]


Epoch 1, Loss: 0.48347113623669746
Accuracy: 75.01383814900919%


Epoch 2/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:33<00:00,  2.91batch/s]


Epoch 2, Loss: 0.3899474657341891
Accuracy: 74.75921620724012%


Epoch 3/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:32<00:00,  2.91batch/s]


Epoch 3, Loss: 0.26029671798542653
Accuracy: 76.75190966456327%


Epoch 4/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:11<00:00,  2.98batch/s]


Epoch 4, Loss: 0.15695234667473143
Accuracy: 76.89582641425883%


Epoch 5/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:04<00:00,  3.01batch/s]


Epoch 5, Loss: 0.0989006727675305
Accuracy: 76.22052474261042%


Epoch 6/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:15<00:00,  2.97batch/s]


Epoch 6, Loss: 0.0756470137702749
Accuracy: 76.25373630023248%


Epoch 7/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:21<00:00,  2.95batch/s]


Epoch 7, Loss: 0.05787833225146407
Accuracy: 76.53049928041625%


Epoch 8/10: 100%|███████████████████████████████████████████████████████████████| 2541/2541 [14:30<00:00,  2.92batch/s]


Epoch 8, Loss: 0.04646418478339635
Accuracy: 76.03232591608547%


Epoch 9/10:  71%|████████████████████████████████████████████▍                  | 1794/2541 [10:21<04:18,  2.88batch/s]


KeyboardInterrupt: 

In [10]:
def get_accuracy():
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f'Accuracy: {100 * correct / total}%')

In [18]:
print(len(val_dataset.image_paths))

0


In [14]:
# model saving
torch.save(model.state_dict(), 'CNN_simple_classifier_76acc_.04loss')

In [16]:
# model loading
model = SimpleCNN()
model.load_state_dict(torch.load('CNN_classifier_8_epoch'))
model.eval()

SimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=246016, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)