In [1]:
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms

### Preparing and Preprocessing the Dataset

In [2]:
class ArmorPlateDataset(Dataset):
    """Image/label dataset read from a single flat directory.

    Every ``*.png`` file in ``folder_path`` is paired with a text file of the
    same base name (``foo.png`` -> ``foo.txt``) whose content is parsed as one
    integer label.

    Args:
        folder_path: Directory containing the ``.png`` images and their
            ``.txt`` label files.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.

    Raises:
        FileNotFoundError: If an image has no matching label file.
        ValueError: If a label file's content cannot be parsed as an integer.
    """
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.data = []  # list of (image path, integer label) pairs

        # os.listdir() returns entries in arbitrary order; sort them so that
        # sample indices are stable across runs and machines.
        for file in sorted(os.listdir(folder_path)):
            if not file.endswith('.png'):
                continue
            # Swap only the trailing extension; str.replace would also rewrite
            # any earlier '.png' that happens to appear inside the file name.
            txt_file = os.path.splitext(file)[0] + '.txt'
            img_path = os.path.join(folder_path, file)
            label_path = os.path.join(folder_path, txt_file)

            with open(label_path, 'r') as f:
                label = int(f.read().strip())
            self.data.append((img_path, label))

    def __len__(self):
        """Return the number of (image, label) pairs found in the folder."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given; the label is returned as a scalar ``float32`` tensor.
        """
        img_path, label = self.data[idx]
        # convert() returns a new image, so the source file can be closed as
        # soon as the context manager exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.float32)

# Preprocessing pipeline: shrink every image to 64x64 (keeps the network
# lightweight), then convert it to a tensor.
transform = transforms.Compose(transforms=[
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
])

# Point this at the directory holding the training images and label files.
folder_path = 'path/to/train/data'

# Wrap the directory in a dataset and serve it in shuffled mini-batches of 32.
dataset = ArmorPlateDataset(folder_path, transform)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

### Defining the Convolutional Neural Network (CNN)

In [3]:
class SimpleCNN(nn.Module):
    """Small two-stage CNN that scores an RGB image for binary classification.

    Layout: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU ->
    2x2 max-pool -> flatten -> linear(32*16*16->64) -> ReLU -> linear(64->1)
    -> sigmoid.

    The first linear layer is sized for 32 channels of 16x16 features, which
    corresponds to 64x64 inputs: the padded 3x3 convolutions keep the spatial
    size and each of the two 2x2 pools halves it.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)  # 16 filters
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # 32 filters
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 16 * 16, 64)  # Fully connected layer
        self.fc2 = nn.Linear(64, 1)  # Output layer for binary classification

    def forward(self, x):
        """Return one sigmoid probability per sample, shape ``(batch, 1)``.

        Args:
            x: Image batch of shape ``(batch, 3, 64, 64)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce
                ``32 * 16 * 16`` features per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample and keep the batch dimension fixed. Reshaping
        # with view(-1, 32 * 16 * 16) would silently fold surplus features of
        # a wrongly sized input into extra "samples" instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # Sigmoid for binary output
        return x

### Training the Model

In [4]:
# Seed PyTorch before the model is built so that weight initialisation and
# the DataLoader's shuffling start from the same random state on every run
# (GPU kernels may still introduce small run-to-run differences).
torch.manual_seed(42)

# Initialize the model, loss function, and optimizer.
# The model is placed on its device first so the optimizer is constructed
# from the parameters exactly as they will be trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss for binary classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
num_batches = len(dataloader)
if num_batches == 0:
    # Fail here with a clear message instead of a ZeroDivisionError when the
    # per-epoch average is computed below.
    raise ValueError('dataloader is empty: no training samples were found')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in dataloader:
        # Labels arrive as shape (batch,); add a trailing axis so they match
        # the (batch, 1) model output expected by BCELoss.
        images, labels = images.to(device), labels.to(device).unsqueeze(1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Reported value is the mean of the per-batch losses for this epoch.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/num_batches:.4f}')

print('Training complete.')

Epoch [1/10], Loss: 0.4444
Epoch [2/10], Loss: 0.2524
Epoch [3/10], Loss: 0.1854
Epoch [4/10], Loss: 0.1561
Epoch [5/10], Loss: 0.1272
Epoch [6/10], Loss: 0.1193
Epoch [7/10], Loss: 0.1063
Epoch [8/10], Loss: 0.0982
Epoch [9/10], Loss: 0.0778
Epoch [10/10], Loss: 0.0826
Training complete.


### Testing the Model

In [None]:
def predict(model, image_path):
    """Classify one image file and return the prediction as 0 or 1.

    The image is loaded as RGB, run through the notebook-level ``transform``
    pipeline, and fed to ``model`` as a batch of one.

    Args:
        model: Network that outputs a single probability for the image. It is
            switched to eval mode and left in that mode.
        image_path: Path of the image file to classify.

    Returns:
        int: The model output rounded with ``torch.round`` (0 or 1).
    """
    model.eval()

    # Use the device the model actually lives on rather than a notebook
    # global, so a model that was loaded or moved elsewhere still works.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # convert() returns a new image, so the file can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = transform(image).unsqueeze(0).to(model_device)

    with torch.no_grad():
        output = model(image)
    return int(torch.round(output).item())  # 0 or 1

# Set this to the image you want to classify, then report the 0/1 result.
test_image = 'path/to/test/image.png'
print('Armor present: {}'.format(predict(model, test_image)))

### Saving and Loading the Model

In [5]:
# Write the trained parameters (the model's state dict) to disk.
torch.save(obj=model.state_dict(), f='armor_plate_classifier.pth')

In [None]:
# Load the model
# Resolve the target device first so the checkpoint can be remapped onto it:
# without map_location, a checkpoint saved from a GPU cannot be loaded on a
# CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load('armor_plate_classifier.pth', map_location=device))
model.to(device)