In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image

In [2]:
class RealFakeDataset(Dataset):
    """Binary real-vs-fake image dataset read from a directory tree.

    Two layouts are recognised underneath ``root_dir``:

    * ``<root>/<model>/{0_real,1_fake}/<image>``
    * ``<root>/<model>/<object>/{0_real,1_fake}/<image>``

    Images found in a ``0_real`` folder get label ``0`` and images found in a
    ``1_fake`` folder get label ``1``.

    Args:
        root_dir: Directory that contains one sub-directory per generator model.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.
        extensions: File-name suffixes (case-insensitive) accepted as images.
            Pass ``None`` to accept every regular file.

    Attributes:
        image_paths: Paths of all indexed images, in deterministic (sorted) order.
        labels: Integer labels aligned with ``image_paths``.
    """

    # (folder name, integer label) pairs looked up under every model directory.
    LABEL_DIRS = (('0_real', 0), ('1_fake', 1))
    # Suffixes treated as images; keeps files such as Thumbs.db out of the index.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None, extensions=IMAGE_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        self.extensions = None if extensions is None else tuple(ext.lower() for ext in extensions)
        self.image_paths = []
        self.labels = []

        # Directory listings are sorted so that sample indices are stable across runs;
        # a seeded random_split only reproduces the same split if they are.
        for model_name in sorted(os.listdir(root_dir)):
            model_path = os.path.join(root_dir, model_name)
            if not os.path.isdir(model_path):
                continue
            for label_dir, label in self.LABEL_DIRS:
                direct_path = os.path.join(model_path, label_dir)
                if os.path.isdir(direct_path):
                    self._add_images(direct_path, label)
                    continue
                # No label folder directly under the model: look one level deeper.
                for obj_name in sorted(os.listdir(model_path)):
                    nested_path = os.path.join(model_path, obj_name, label_dir)
                    if os.path.isdir(nested_path):
                        self._add_images(nested_path, label)

    def _add_images(self, label_path, label):
        """Index every accepted image file directly inside ``label_path``."""
        with os.scandir(label_path) as entries:
            for entry in sorted(entries, key=lambda item: item.name):
                if not entry.is_file():
                    continue
                if self.extensions is not None and not entry.name.lower().endswith(self.extensions):
                    continue
                self.image_paths.append(entry.path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB; ``transform`` is applied when one was given.
        """
        # The context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Preprocessing applied to every image: resize to a fixed 256x256 and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Dataset location; set the PROGAN_TRAIN_DIR environment variable to point elsewhere.
TRAIN_ROOT = os.environ.get("PROGAN_TRAIN_DIR", "D:/progan_train/progan_train/")
# Fixed seed so the train/test partition is the same after a kernel restart. Without it a
# checkpoint reloaded in a new session would be "tested" on images it may have trained on.
# (Reproducibility also requires the dataset to enumerate its files in a stable order.)
SPLIT_SEED = 42

full_dataset = RealFakeDataset(root_dir=TRAIN_ROOT, transform=transform)

# 90% of the samples for training, the remainder held out for testing.
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size

train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)
# Evaluation does not benefit from shuffling; keep the order deterministic.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, pin_memory=True)

In [3]:
# Number of samples in the training portion (90%) of the split.
len(train_dataset)


645016

In [4]:
# Second evaluation set, loaded from a separate directory with the same preprocessing.
mytest_dataset = RealFakeDataset(root_dir="C:/Users/Danila/VSU/vsu_common_rep/vsu_common_rep/2year/2term/project/image_classification/content/CNN_synth/test_set/", transform=transform)
# Evaluation-only loader: shuffling adds nothing to an accuracy computation, so keep a fixed order.
mytest_loader = DataLoader(mytest_dataset, batch_size=32, shuffle=False)

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small CNN classifier: two (3x3 conv -> ReLU -> 2x2 max-pool) stages and two linear layers.

    Args:
        input_size: Spatial size of the input images, either a single int for
            square images or a ``(height, width)`` pair. The default of 256
            gives ``fc1`` the original ``64*62*62`` input features, so existing
            checkpoints keep loading.
        num_classes: Number of output logits (default 2).

    Raises:
        ValueError: If ``input_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, input_size=256, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(self.flat_features(input_size), 128)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _pooled_extent(extent):
        """Length of one spatial axis after both conv(3x3, no padding) + max-pool(2) stages."""
        # Each unpadded 3x3 convolution removes 2 pixels; each 2x2 pool halves (floor).
        return ((extent - 2) // 2 - 2) // 2

    @classmethod
    def flat_features(cls, input_size=256):
        """Return the number of features entering ``fc1`` for the given input size."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        out_h = cls._pooled_extent(height)
        out_w = cls._pooled_extent(width)
        if out_h < 1 or out_w < 1:
            raise ValueError(f'input_size {input_size!r} is too small for SimpleCNN')
        return 64 * out_h * out_w

    def forward(self, x):
        """Map a batch of images ``(N, 3, H, W)`` to unnormalised class logits ``(N, num_classes)``."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # keep the batch axis, flatten channels and space
        x = F.relu(self.fc1(x))
        # No softmax here: the logits are fed to CrossEntropyLoss / argmax directly.
        return self.fc2(x)

# Instantiate the network; it is moved to the selected device in a later cell.
model = SimpleCNN()

In [6]:
def get_accuracy(loader, net=None, run_device=None):
    """Print and return the classification accuracy (in percent) over ``loader``.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        net: Model to evaluate. Defaults to the notebook-level ``model``.
        run_device: Device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU for a parameter-less model), so the
            inputs always end up where the weights are.

    Returns:
        Accuracy as a float in ``[0, 100]``, or ``nan`` when the loader yields
        no samples.
    """
    if net is None:
        net = model  # notebook-level model defined in an earlier cell
    if run_device is None:
        first_param = next(net.parameters(), None)
        run_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(run_device), labels.to(run_device)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Put the model back in the mode the caller had it in, so calling this
        # in the middle of training does not silently leave it in eval mode.
        net.train(was_training)

    if total == 0:
        print('Accuracy: n/a (loader produced no samples)')
        return float('nan')

    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy}%')
    return accuracy

In [7]:
import torch
import torch.optim as optim
from tqdm import tqdm

# Select the GPU when one is available and move the model there *before* the optimizer is
# built, which is the order the torch.optim documentation asks for.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Cross-entropy over the two output logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
# NOTE(review): the run recorded in this notebook stayed at loss ~0.694 (= ln 2) and exactly
# 50% accuracy, i.e. the network was predicting a single class. Consider normalising the
# inputs and/or lowering the learning rate before the next full run.
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 1

In [8]:
# Train for `num_epochs` passes over `train_loader`; every `eval_interval` batches the
# accuracy on both evaluation loaders is printed.
eval_interval = 2026
for epoch in range(num_epochs):
    model.train()
    loss_sum = 0.0
    # tqdm wraps the loader to show a per-batch progress bar.
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')
    for batch_idx, (images, labels) in enumerate(progress, start=1):
        # get_accuracy() may have switched the model to eval mode on the previous step.
        model.train()
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        if batch_idx % eval_interval == 0:
            get_accuracy(mytest_loader)
            get_accuracy(test_loader)
    # Mean per-batch loss over the epoch.
    print(f'Epoch {epoch+1}, Loss: {loss_sum/len(train_loader)}')


print('Finished Training')

Epoch 1/1:  10%|█████▉                                                     | 2026/20157 [41:04<168:22:53, 33.43s/batch]

Accuracy: 50.0%


Epoch 1/1:  20%|███████████▍                                             | 4052/20157 [1:23:02<157:08:32, 35.13s/batch]

Accuracy: 50.0%


Epoch 1/1:  30%|█████████████████▏                                       | 6078/20157 [2:03:41<133:06:03, 34.03s/batch]

Accuracy: 50.0%


Epoch 1/1:  40%|██████████████████████▉                                  | 8104/20157 [2:43:10<111:37:05, 33.34s/batch]

Accuracy: 50.0%


Epoch 1/1:  50%|████████████████████████████▋                            | 10130/20157 [3:21:36<91:14:20, 32.76s/batch]

Accuracy: 50.0%


Epoch 1/1:  60%|██████████████████████████████████▎                      | 12156/20157 [3:59:00<73:01:27, 32.86s/batch]

Accuracy: 50.0%


Epoch 1/1:  70%|████████████████████████████████████████                 | 14182/20157 [4:35:13<54:34:47, 32.88s/batch]

Accuracy: 50.0%


Epoch 1/1:  80%|█████████████████████████████████████████████▊           | 16208/20157 [5:10:31<35:42:38, 32.55s/batch]

Accuracy: 50.0%


Epoch 1/1:  90%|███████████████████████████████████████████████████▌     | 18234/20157 [5:45:16<17:34:55, 32.91s/batch]

Accuracy: 50.0%


Epoch 1/1: 100%|████████████████████████████████████████████████████████████| 20157/20157 [6:15:59<00:00,  1.12s/batch]

Epoch 1, Loss: 0.6940584687709501
Finished Training





In [9]:
# Number of images in the secondary evaluation set.
print(len(mytest_dataset))

1398


In [14]:
# model saving
# Only the learned parameters (state_dict) are written, not the module object itself.
# NOTE(review): the file name advertises "76acc" / ".04loss", while the run recorded in this
# notebook ended at loss ~0.694 and 50% accuracy. Confirm the name matches the weights saved.
torch.save(model.state_dict(), 'CNN_simple_classifier_76acc_.04loss')

In [16]:
# model loading
# Rebuild the architecture and restore the weights from disk.
# `map_location` lets a checkpoint saved on a GPU load on a CPU-only machine, and the model is
# then moved to the device evaluation batches are sent to, avoiding a CPU/GPU mismatch.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust (or pass
# weights_only=True on PyTorch versions that support it).
# NOTE(review): an optimizer created earlier still points at the previous model's parameters;
# rebuild it before training this reloaded model further.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN()
model.load_state_dict(torch.load('CNN_classifier_8_epoch', map_location=device))
model.to(device)
model.eval()

SimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=246016, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)

In [9]:
# Accuracy on the held-out 10% split.
# NOTE(review): this cell's execution count (9) is lower than that of the save/load cells above
# it, so the displayed output presumably predates the checkpoint reload. Re-run the notebook
# top to bottom to confirm which model this number belongs to.
get_accuracy(test_loader)

Accuracy: 50.2365039277791%
