In [1]:
from typing import List
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import transforms

# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


# Define the image transform and dataset

In [2]:
# Preprocessing applied to every image before it is stored in the dataset.
# Compose runs the steps in the order listed.
transform = transforms.Compose(
    transforms=[
        # Convert the image to grayscale.
        transforms.Grayscale(),
        # Bring every image to the same 128x128 spatial size.
        transforms.Resize((128, 128)),
        # Turn the PIL image into a tensor.
        transforms.ToTensor(),
        # Normalize with mean 0.5 and standard deviation 0.5.
        transforms.Normalize(0.5, 0.5),
    ]
)

class ChalklingLoader(Dataset):
    """In-memory dataset of images stored in one sub-folder per class.

    Every sub-directory of ``data_folder_path`` is treated as a class: its name
    is stored in ``self.labels`` and its position in that list is the integer
    class index. Each image inside is opened with PIL, the per-image statistics
    listed in ``STAT_KEYS`` are read from ``img.info``, and the image is passed
    through the module-level ``transform``.

    Each item returned by ``__getitem__`` is a dict with the keys:
        "label": 0-dim tensor holding the class index,
        "image": the transformed image, Tensor[channels, height, width],
        "data":  1-D tensor ``[attack, defense, speed, life, aesthetic]``.
    """

    # Metadata entries read from ``img.info``, in the order they appear in the
    # "data" tensor of every sample.
    STAT_KEYS = ("attack", "defense", "speed", "life", "aesthetic")

    def __init__(self, data_folder_path: str):
        """Eagerly load every image found under ``data_folder_path``."""
        self.data: List = []
        self.labels: List[str] = []
        self.load_data(data_folder_path)

    def load_data(self, data_folder_path: str):
        """Append all samples found under ``data_folder_path`` to ``self.data``.

        Args:
            data_folder_path: Directory whose sub-directories are the classes.

        Raises:
            KeyError: If an image lacks one of the ``STAT_KEYS`` entries in
                its ``info`` dictionary.
            ValueError: If one of those entries cannot be converted to float.
        """
        # ["class", "file", "attack", "defense", "speed", "life", "aesthetic"]
        # NOTE(review): presumably the field layout of the source metadata - confirm.

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # folder -> class-index mapping identical across runs and machines.
        for y_folder in sorted(os.listdir(data_folder_path)):
            image_folder_path = os.path.join(data_folder_path, y_folder)
            # Only directories are class folders; skip stray files such as
            # OS metadata instead of crashing on them.
            if not os.path.isdir(image_folder_path):
                continue

            # Reuse the existing index when the label is already known so that
            # calling load_data again does not duplicate entries in labels.
            if y_folder not in self.labels:
                self.labels.append(y_folder)
            label_idx = self.labels.index(y_folder)

            # All images within the folder, in a deterministic order.
            for image_file in sorted(os.listdir(image_folder_path)):
                file_path = os.path.join(image_folder_path, image_file)
                # The context manager releases the file handle as soon as the
                # metadata is read and the image has been converted to a tensor.
                with Image.open(file_path) as img:
                    stats = [float(img.info[key]) for key in self.STAT_KEYS]
                    image = transform(img)
                self.data.append({
                    "label": torch.tensor(label_idx),
                    # Tensor[image_channels, image_height, image_width]
                    "image": image,
                    "data": torch.tensor(stats)
                })

    def __getitem__(self, index):
        """Return the sample dict stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)


# Load data

In [3]:
from torch.utils.data import DataLoader

# Load every image under ../data (relative to the notebook's working directory).
train_data = ChalklingLoader(os.path.join("..", "data"))

# Drive the shuffle from a dedicated, seeded generator so that a fresh
# "Restart & Run All" visits the samples in the same order every time.
shuffle_generator = torch.Generator().manual_seed(0)
dataloader = DataLoader(
    train_data,
    batch_size=10,
    shuffle=True,
    generator=shuffle_generator,
)

# Pull a single batch as a sanity check on the collated tensor shapes.
data = next(iter(dataloader))
print(f"Feature Image batch shape: {data['image'].size()}")
print(f"Feature Data batch shape: {data['data'].size()}")
print(f"Labels batch shape: {data['label'].size()}")

Feature Image batch shape: torch.Size([10, 1, 128, 128])
Feature Data batch shape: torch.Size([10, 5])
Labels batch shape: torch.Size([10])


# Create the classification model

In [7]:
from torch import nn, optim


class Model(nn.Module):
    """Convolutional classifier for single-channel 128x128 images.

    Two conv -> ReLU -> 2x2 max-pool stages are followed by three fully
    connected layers. The 3x3 convolutions use padding 1 and so keep the
    spatial size; each pooling stage halves it, so a 128x128 input leaves the
    convolutional stack as 16 feature maps of 32x32, which is the size the
    first linear layer is built for.

    Args:
        input_size: Currently unused; kept so existing callers keep working.
        output_size: Number of values produced per sample (one per class).
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        # Shape of the convolutional output that the linear layers expect.
        self._feature_shape = (16, 32, 32)
        self._flat_features = 16 * 32 * 32
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(self._flat_features, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, output_size)
        )

    def forward(self, x):
        """Return one row of ``output_size`` scores per input image.

        Raises:
            ValueError: If the convolutional output is not 16x32x32, i.e. the
                input images do not have the spatial size the network was
                built for.
        """
        x = self.conv(x)
        # Without this check, view(-1, ...) would silently fold the surplus
        # spatial elements of a wrongly sized image into the batch axis and
        # return more rows than there are input samples.
        if tuple(x.shape[-3:]) != self._feature_shape:
            raise ValueError(
                "Expected convolutional features of shape "
                f"{self._feature_shape} per sample, got {tuple(x.shape[-3:])}; "
                "input images must be 128x128"
            )
        x = x.view(-1, self._flat_features)
        x = self.fc(x)
        return x

# Train Model

In [8]:
# One output per class folder discovered by the dataset.
model = Model(128*128, len(train_data.labels)).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train the model
for epoch in range(10):
    running_loss = 0.0
    for data in dataloader:
        # The model lives on `device`, so each batch must be moved there too;
        # otherwise the forward pass fails with a device mismatch on CUDA.
        inputs = data['image'].to(device)
        labels = data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss / len(dataloader)))

print('Finished Training')

[Epoch 1] loss: 0.695
[Epoch 2] loss: 0.693
[Epoch 3] loss: 0.691
[Epoch 4] loss: 0.689
[Epoch 5] loss: 0.688
[Epoch 6] loss: 0.684
[Epoch 7] loss: 0.681
[Epoch 8] loss: 0.678
[Epoch 9] loss: 0.678
[Epoch 10] loss: 0.672
Finished Training
