In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.models import MobileNet_V3_Small_Weights
from tqdm import tqdm

In [2]:
# Pick the compute device: Apple-silicon MPS first, then CUDA, then CPU.
# `is_available()` (not `is_built()`) is required for MPS: a PyTorch build can include
# MPS support while the running machine/OS cannot actually use it.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [71]:
# Standard ImageNet per-channel mean / std, applied by both pipelines below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop, flip, rotation and colour jitter for augmentation,
# then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize + centre crop to the same 224x224 input.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

data_transforms = {'train': train_transform, 'test': test_transform}

In [72]:
# ImageFolder-style training data: one sub-directory per class.
TRAIN_DIR = '../../dataset/ai_art_classification/train'
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms['train'])

# Persist the label mapping so the inference section can decode predictions without
# re-scanning the training directory.
class_to_idx = train_dataset.class_to_idx
print("Class-to-ID mapping:", class_to_idx)
with open('class_to_idx.json', 'w') as f:
    json.dump(class_to_idx, f)

# Pinned host memory only speeds up host->CUDA copies; on MPS/CPU it brings no benefit
# (and some PyTorch versions warn that it is unsupported on MPS), so enable it for CUDA only.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)

Class-to-ID mapping: {'AI_GENERATED': 0, 'NON_AI_GENERATED': 1}


In [83]:
# Transfer-learning setup: ImageNet-pretrained MobileNetV3-Small with a frozen
# feature extractor and a new 2-way output layer.
model = models.mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)

# Freeze every parameter inside `model.features`; parameters outside it stay trainable.
model.features.requires_grad_(False)
# NOTE(review): BatchNorm layers inside `features` still update their running statistics
# whenever the model is in train() mode, even though their weights are frozen.

# Replace the last classifier layer (index 3) with a fresh Linear producing 2 logits.
head_in_features = model.classifier[3].in_features
model.classifier[3] = nn.Linear(head_in_features, 2)

model = model.to(device)

In [84]:
# Cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that still require gradients
# (the feature extractor was frozen when the model was built).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

In [85]:
def train_model(model, dataloader, criterion, optimizer, num_epochs=10, device=None, progress=True):
    """Train ``model`` in place for ``num_epochs`` passes over ``dataloader``.

    Args:
        model: the ``nn.Module`` to optimise; it is moved to ``device`` and put in train mode.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to average over
            the batch (``reduction='mean'``), since it is re-weighted by batch size below.
        optimizer: optimizer already bound to the trainable parameters of ``model``.
        num_epochs: number of full passes over ``dataloader``.
        device: device to train on. Defaults to the device of the model's first parameter
            (CPU if the model has no parameters), instead of relying on a global.
        progress: wrap each epoch's batches in a ``tqdm`` progress bar when True.

    Returns:
        A list with one ``(epoch_loss, epoch_acc)`` tuple per epoch. Both values are
        averaged over the samples actually seen in that epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples in an epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.to(device)
    model.train()

    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        print(f"Epoch {epoch + 1}/{num_epochs}")

        batches = tqdm(dataloader) if progress else dataloader
        for inputs, labels in batches:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward + backward + optimise
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted statistics for the epoch summary.
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("dataloader yielded no samples")
        # Divide by the samples actually seen (not len(dataset)), so loss and accuracy
        # stay consistent when the loader uses drop_last or a subset sampler.
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        print(f"Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        history.append((epoch_loss, epoch_acc))
    return history

In [86]:
NUM_EPOCHS = 11
train_model(model, train_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)

mps
Epoch 1/11


100%|██████████| 582/582 [00:41<00:00, 13.92it/s]


Loss: 0.3621 Acc: 0.8432
Epoch 2/11


100%|██████████| 582/582 [00:41<00:00, 13.91it/s]


Loss: 0.3080 Acc: 0.8717
Epoch 3/11


100%|██████████| 582/582 [00:41<00:00, 14.02it/s]


Loss: 0.2969 Acc: 0.8746
Epoch 4/11


100%|██████████| 582/582 [00:41<00:00, 13.98it/s]


Loss: 0.2895 Acc: 0.8810
Epoch 5/11


100%|██████████| 582/582 [00:41<00:00, 14.18it/s]


Loss: 0.2790 Acc: 0.8806
Epoch 6/11


100%|██████████| 582/582 [00:41<00:00, 14.18it/s]


Loss: 0.2794 Acc: 0.8802
Epoch 7/11


100%|██████████| 582/582 [00:41<00:00, 13.92it/s]


Loss: 0.2708 Acc: 0.8850
Epoch 8/11


100%|██████████| 582/582 [00:41<00:00, 14.04it/s]


Loss: 0.2645 Acc: 0.8864
Epoch 9/11


100%|██████████| 582/582 [00:41<00:00, 13.96it/s]


Loss: 0.2604 Acc: 0.8906
Epoch 10/11


100%|██████████| 582/582 [00:40<00:00, 14.24it/s]


Loss: 0.2579 Acc: 0.8942
Epoch 11/11


100%|██████████| 582/582 [00:40<00:00, 14.26it/s]

Loss: 0.2519 Acc: 0.8961





In [87]:
# Persist only the learned weights (a state_dict). Reloading requires rebuilding the same
# MobileNetV3-Small architecture with the 2-way head before calling load_state_dict.
trained_weights = model.state_dict()
torch.save(trained_weights, 'model.pth')

# Inference

In [3]:
# Load the fine-tuned classifier for inference.
# The checkpoint holds only a state_dict, so the architecture must be rebuilt exactly as in
# training: MobileNetV3-Small with classifier[3] replaced by a 2-way Linear head.
# weights=None: every tensor is overwritten by the checkpoint, so downloading the
# ImageNet weights first would be wasted work.
model = models.mobilenet_v3_small(weights=None)
num_features = model.classifier[3].in_features
model.classifier[3] = nn.Linear(num_features, 2)

# map_location='cpu': the checkpoint may have been saved from an MPS/CUDA model; loading
# onto CPU first makes it open on any machine. weights_only=True restricts unpickling to
# tensors/containers instead of arbitrary objects.
# NOTE: a missing/unexpected-keys error here means model.pth was written by a different
# architecture (e.g. a ResNet run) -- re-run the training section to regenerate it.
state_dict = torch.load('model.pth', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
# Move to the selected device so the weights match where inference inputs are sent.
model = model.to(device)
model.eval()

# Sanity check: list every tensor name and shape in the loaded state_dict.
print("Model's state_dict:")
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())


  model.load_state_dict(torch.load('model.pth'))


RuntimeError: Error(s) in loading state_dict for MobileNetV3:
	Missing key(s) in state_dict: "features.0.0.weight", "features.0.1.weight", "features.0.1.bias", "features.0.1.running_mean", "features.0.1.running_var", "features.1.block.0.0.weight", "features.1.block.0.1.weight", "features.1.block.0.1.bias", "features.1.block.0.1.running_mean", "features.1.block.0.1.running_var", "features.1.block.1.fc1.weight", "features.1.block.1.fc1.bias", "features.1.block.1.fc2.weight", "features.1.block.1.fc2.bias", "features.1.block.2.0.weight", "features.1.block.2.1.weight", "features.1.block.2.1.bias", "features.1.block.2.1.running_mean", "features.1.block.2.1.running_var", "features.2.block.0.0.weight", "features.2.block.0.1.weight", "features.2.block.0.1.bias", "features.2.block.0.1.running_mean", "features.2.block.0.1.running_var", "features.2.block.1.0.weight", "features.2.block.1.1.weight", "features.2.block.1.1.bias", "features.2.block.1.1.running_mean", "features.2.block.1.1.running_var", "features.2.block.2.0.weight", "features.2.block.2.1.weight", "features.2.block.2.1.bias", "features.2.block.2.1.running_mean", "features.2.block.2.1.running_var", "features.3.block.0.0.weight", "features.3.block.0.1.weight", "features.3.block.0.1.bias", "features.3.block.0.1.running_mean", "features.3.block.0.1.running_var", "features.3.block.1.0.weight", "features.3.block.1.1.weight", "features.3.block.1.1.bias", "features.3.block.1.1.running_mean", "features.3.block.1.1.running_var", "features.3.block.2.0.weight", "features.3.block.2.1.weight", "features.3.block.2.1.bias", "features.3.block.2.1.running_mean", "features.3.block.2.1.running_var", "features.4.block.0.0.weight", "features.4.block.0.1.weight", "features.4.block.0.1.bias", "features.4.block.0.1.running_mean", "features.4.block.0.1.running_var", "features.4.block.1.0.weight", "features.4.block.1.1.weight", "features.4.block.1.1.bias", "features.4.block.1.1.running_mean", "features.4.block.1.1.running_var", "features.4.block.2.fc1.weight", "features.4.block.2.fc1.bias", 
"features.4.block.2.fc2.weight", "features.4.block.2.fc2.bias", "features.4.block.3.0.weight", "features.4.block.3.1.weight", "features.4.block.3.1.bias", "features.4.block.3.1.running_mean", "features.4.block.3.1.running_var", "features.5.block.0.0.weight", "features.5.block.0.1.weight", "features.5.block.0.1.bias", "features.5.block.0.1.running_mean", "features.5.block.0.1.running_var", "features.5.block.1.0.weight", "features.5.block.1.1.weight", "features.5.block.1.1.bias", "features.5.block.1.1.running_mean", "features.5.block.1.1.running_var", "features.5.block.2.fc1.weight", "features.5.block.2.fc1.bias", "features.5.block.2.fc2.weight", "features.5.block.2.fc2.bias", "features.5.block.3.0.weight", "features.5.block.3.1.weight", "features.5.block.3.1.bias", "features.5.block.3.1.running_mean", "features.5.block.3.1.running_var", "features.6.block.0.0.weight", "features.6.block.0.1.weight", "features.6.block.0.1.bias", "features.6.block.0.1.running_mean", "features.6.block.0.1.running_var", "features.6.block.1.0.weight", "features.6.block.1.1.weight", "features.6.block.1.1.bias", "features.6.block.1.1.running_mean", "features.6.block.1.1.running_var", "features.6.block.2.fc1.weight", "features.6.block.2.fc1.bias", "features.6.block.2.fc2.weight", "features.6.block.2.fc2.bias", "features.6.block.3.0.weight", "features.6.block.3.1.weight", "features.6.block.3.1.bias", "features.6.block.3.1.running_mean", "features.6.block.3.1.running_var", "features.7.block.0.0.weight", "features.7.block.0.1.weight", "features.7.block.0.1.bias", "features.7.block.0.1.running_mean", "features.7.block.0.1.running_var", "features.7.block.1.0.weight", "features.7.block.1.1.weight", "features.7.block.1.1.bias", "features.7.block.1.1.running_mean", "features.7.block.1.1.running_var", "features.7.block.2.fc1.weight", "features.7.block.2.fc1.bias", "features.7.block.2.fc2.weight", "features.7.block.2.fc2.bias", "features.7.block.3.0.weight", "features.7.block.3.1.weight", 
"features.7.block.3.1.bias", "features.7.block.3.1.running_mean", "features.7.block.3.1.running_var", "features.8.block.0.0.weight", "features.8.block.0.1.weight", "features.8.block.0.1.bias", "features.8.block.0.1.running_mean", "features.8.block.0.1.running_var", "features.8.block.1.0.weight", "features.8.block.1.1.weight", "features.8.block.1.1.bias", "features.8.block.1.1.running_mean", "features.8.block.1.1.running_var", "features.8.block.2.fc1.weight", "features.8.block.2.fc1.bias", "features.8.block.2.fc2.weight", "features.8.block.2.fc2.bias", "features.8.block.3.0.weight", "features.8.block.3.1.weight", "features.8.block.3.1.bias", "features.8.block.3.1.running_mean", "features.8.block.3.1.running_var", "features.9.block.0.0.weight", "features.9.block.0.1.weight", "features.9.block.0.1.bias", "features.9.block.0.1.running_mean", "features.9.block.0.1.running_var", "features.9.block.1.0.weight", "features.9.block.1.1.weight", "features.9.block.1.1.bias", "features.9.block.1.1.running_mean", "features.9.block.1.1.running_var", "features.9.block.2.fc1.weight", "features.9.block.2.fc1.bias", "features.9.block.2.fc2.weight", "features.9.block.2.fc2.bias", "features.9.block.3.0.weight", "features.9.block.3.1.weight", "features.9.block.3.1.bias", "features.9.block.3.1.running_mean", "features.9.block.3.1.running_var", "features.10.block.0.0.weight", "features.10.block.0.1.weight", "features.10.block.0.1.bias", "features.10.block.0.1.running_mean", "features.10.block.0.1.running_var", "features.10.block.1.0.weight", "features.10.block.1.1.weight", "features.10.block.1.1.bias", "features.10.block.1.1.running_mean", "features.10.block.1.1.running_var", "features.10.block.2.fc1.weight", "features.10.block.2.fc1.bias", "features.10.block.2.fc2.weight", "features.10.block.2.fc2.bias", "features.10.block.3.0.weight", "features.10.block.3.1.weight", "features.10.block.3.1.bias", "features.10.block.3.1.running_mean", "features.10.block.3.1.running_var", 
"features.11.block.0.0.weight", "features.11.block.0.1.weight", "features.11.block.0.1.bias", "features.11.block.0.1.running_mean", "features.11.block.0.1.running_var", "features.11.block.1.0.weight", "features.11.block.1.1.weight", "features.11.block.1.1.bias", "features.11.block.1.1.running_mean", "features.11.block.1.1.running_var", "features.11.block.2.fc1.weight", "features.11.block.2.fc1.bias", "features.11.block.2.fc2.weight", "features.11.block.2.fc2.bias", "features.11.block.3.0.weight", "features.11.block.3.1.weight", "features.11.block.3.1.bias", "features.11.block.3.1.running_mean", "features.11.block.3.1.running_var", "features.12.0.weight", "features.12.1.weight", "features.12.1.bias", "features.12.1.running_mean", "features.12.1.running_var", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn1.num_batches_tracked", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.bn2.num_batches_tracked", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.bn3.num_batches_tracked", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.downsample.1.num_batches_tracked", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn1.num_batches_tracked", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.bn2.num_batches_tracked", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.bn3.num_batches_tracked", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn1.num_batches_tracked", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.bn2.num_batches_tracked", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.bn3.num_batches_tracked", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn1.num_batches_tracked", 
"layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.bn2.num_batches_tracked", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.bn3.num_batches_tracked", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.0.downsample.1.num_batches_tracked", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn1.num_batches_tracked", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.bn2.num_batches_tracked", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.bn3.num_batches_tracked", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn1.num_batches_tracked", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.bn2.num_batches_tracked", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.bn3.num_batches_tracked", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn1.num_batches_tracked", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.bn2.num_batches_tracked", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.bn3.num_batches_tracked", 
"layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn1.num_batches_tracked", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.bn2.num_batches_tracked", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.bn3.num_batches_tracked", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.0.downsample.1.num_batches_tracked", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn1.num_batches_tracked", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.bn2.num_batches_tracked", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.bn3.num_batches_tracked", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn1.num_batches_tracked", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.bn2.num_batches_tracked", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.bn3.num_batches_tracked", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn1.num_batches_tracked", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.bn2.num_batches_tracked", 
"layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.bn3.num_batches_tracked", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn1.num_batches_tracked", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.bn2.num_batches_tracked", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.bn3.num_batches_tracked", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.bn3.num_batches_tracked", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.num_batches_tracked", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.num_batches_tracked", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.bn3.num_batches_tracked", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.num_batches_tracked", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.num_batches_tracked", 
"layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.num_batches_tracked", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.bn3.num_batches_tracked", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.bn1.num_batches_tracked", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.bn2.num_batches_tracked", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.bn3.num_batches_tracked", "fc.weight", "fc.bias". 

In [88]:
def inference_on_test(model, test_folder_path, batch_size=1, device=None, transform=None):
    """Predict a class index for every image under ``test_folder_path``.

    The folder is read with ``datasets.ImageFolder``, so the images must sit inside at
    least one sub-directory (e.g. ``Test/test/*.jpg``). The folder-derived labels are ignored.

    Args:
        model: trained classifier returning one logit per class; switched to eval mode.
        test_folder_path: root directory passed to ``ImageFolder``.
        batch_size: images per forward pass. Any value >= 1 gives the same file order.
        device: device to run on. Defaults to the device of the model's first parameter,
            so the inputs always match the model.
        transform: preprocessing pipeline; defaults to ``data_transforms['test']``.

    Returns:
        ``(file_names, predictions)``: parallel lists of image paths and predicted class
        indices, in ``ImageFolder`` sample order.
    """
    model.eval()
    if transform is None:
        transform = data_transforms['test']
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    test_images = datasets.ImageFolder(root=test_folder_path, transform=transform)
    # shuffle=False keeps batches aligned with test_images.samples, which is how file
    # names are recovered below.
    test_loader = DataLoader(test_images, batch_size=batch_size, shuffle=False)

    file_names = []
    predictions = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1)
            # Offset into samples = number of files already collected, valid for any batch size.
            start = len(file_names)
            file_names.extend(path for path, _label in test_images.samples[start:start + len(preds)])
            predictions.extend(preds.tolist())

    return file_names, predictions

In [89]:
# Reload the class-name -> index mapping written during training.
with open('class_to_idx.json', 'r') as mapping_file:
    raw_mapping = mapping_file.read()
class_to_idx = json.loads(raw_mapping)

In [90]:
# Invert the mapping so a predicted index can be turned back into its class name.
idx_to_class = dict((idx, name) for name, idx in class_to_idx.items())
print("ID-to-Class mapping:", idx_to_class)

ID-to-Class mapping: {0: 'AI_GENERATED', 1: 'NON_AI_GENERATED'}


In [None]:
test_folder_path = 'Test'
files, predictions = inference_on_test(model, test_folder_path)
# Decode each predicted index into its class name.
predicted_labels = [idx_to_class[idx] for idx in predictions]
print("Predictions for test images:", predictions)
print(f"Predicted labels: {predicted_labels}")
print(f"Files: {files}")

Predictions for test images: [0, 1, 0, 1, 0]
Predicted labels: ['AI_GENERATED', 'NON_AI_GENERATED', 'AI_GENERATED', 'NON_AI_GENERATED', 'AI_GENERATED']
Files: ['Test/test/1.jpg', 'Test/test/3.jpg', 'Test/test/DSC_0255.JPG', 'Test/test/Portrait .jpg', 'Test/test/apple.jpeg']
