In [62]:
#Pytorch
import torch
from torchvision import transforms, datasets
from torch.utils.data import Subset, DataLoader, ConcatDataset
import torchvision.models as models
from sklearn.model_selection import KFold
import torch.optim as optim

Imagesize = 500
data_transform = transforms.Compose([
        transforms.Resize((Imagesize, Imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

#train_dataset = datasets.ImageFolder(root='trainset_500', transform=data_transform)
#validation_dataset = datasets.ImageFolder(root='testset_500', transform=data_transform)
#dataset = ConcatDataset([train_dataset,validation_dataset])

train_dataset = datasets.ImageFolder(root='trainset_500', transform=data_transform)
validation_dataset = datasets.ImageFolder(root='testset_500', transform=data_transform)
dataset = train_dataset
#validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=32, shuffle=True)
#train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)


In [92]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(Imagesize//4 * Imagesize//4 * 64, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, Imagesize//4 * Imagesize//4 * 64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = SimpleCNN()


In [87]:
model = models.resnet18(pretrained=True)

# Replace the final layer to match the number of classes in your dataset
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 2)



In [93]:


# Assuming 'dataset' is your complete dataset
k_folds = 4
kfold = KFold(n_splits=k_folds, shuffle=True)

# This will store the results for each fold
results = {}

# Enumerate splits
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    # Print
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = Subset(dataset, train_ids)
    test_subsampler = Subset(dataset, test_ids)

    # Define data loaders for training and testing data in this fold
    train_loader = DataLoader(train_subsampler, batch_size=60, shuffle=True)
    test_loader = DataLoader(test_subsampler, batch_size=20, shuffle=False)

    #recall
    #weights = torch.ones([Imagesize,Imagesize]) * 3
    #define loss function 
    #criterion = nn.BCEWithLogitsLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(2):
        model.train()
        for data in train_loader:
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()


    print(f'Accuracy on validation set: {100 * correct / total}%')
    #print(f'Accuracy on antagonist detection: {100 * correct_ant / total_ant}%')

# Saving the model
torch.save(model.state_dict(), 'protagonist_antagonist_classifier.pth')


FOLD 0
--------------------------------
Accuracy on validation set: 81.25%
FOLD 1
--------------------------------
Accuracy on validation set: 33.333333333333336%
FOLD 2
--------------------------------
Accuracy on validation set: 86.66666666666667%
FOLD 3
--------------------------------
Accuracy on validation set: 80.0%


In [94]:
validation_loader = DataLoader(validation_dataset, shuffle=False)

total = 0
correct = 0
with torch.no_grad():
    for data in validation_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
print(f'Accuracy on validation set: {100 * correct / total}%')
print(validation_dataset)
print(total)
print(correct)

Accuracy on validation set: 65.0%
Dataset ImageFolder
    Number of datapoints: 20
    Root location: testset_500
    StandardTransform
Transform: Compose(
               Resize(size=(500, 500), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
20
13


In [97]:


# Function to load image and transform
def process_image(image_path):
    image = Image.open(image_path)
    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((Imagesize, Imagesize)),
        transforms.ToTensor(),
        # Normalize using the same values you used for training
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

# Function to get prediction
def classify_image(image_path):
    image = process_image(image_path)
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
    return predicted.item()

# Example usage

import glob
files = glob.glob("testset_500/A/*.png")
#print(files)
for img in files:
    image_path = img
    classification = classify_image(image_path)
    print(img.split('/')[-1]+ f' was classified as: {classification}')

sephiroth.png was classified as: 0
wario.png was classified as: 0
wolf.png was classified as: 0
waluigi.png was classified as: 0
The_Heartless_Phantom.png was classified as: 0
