# Train A Shape Classifier Model



In [None]:
import json
import os

# Root folders of the training and test splits, relative to the notebook's
# working directory. Both are later read with ImageFolder (one sub-folder per class).
train_data_root, test_data_root = (f"../datasets/{split}" for split in ("train", "test"))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing shared by the training and test sets:
#   1. collapse to a single grayscale channel,
#   2. resize to 64x64 (SimpleCNN's first fully connected layer is sized for this),
#   3. convert to a float tensor with values in [0, 1],
#   4. map to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# ImageFolder derives each sample's label from the name of its sub-directory.
train_dataset, test_dataset = (
    datasets.ImageFolder(root=split_root, transform=transform)
    for split_root in (train_data_root, test_data_root)
)

# Shuffle only the training data; a fixed test order keeps evaluation reproducible.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Sanity check: show the class names discovered on disk.
print(f'Classes: {train_dataset.classes}')

# Define a simple CNN model
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for single-channel square images.

    Architecture: [conv3x3 -> ReLU -> 2x2 max-pool] x 2, followed by two fully
    connected layers. The 3x3 convolutions use padding=1 and keep the spatial
    size; each max-pool halves it (rounding down). An ``image_size`` x
    ``image_size`` input therefore reaches the classifier head as a
    32 x (image_size // 4) x (image_size // 4) feature map.

    Args:
        num_classes: number of output logits. Defaults to 3, the shape classes
            (circle, triangle, rectangle) this notebook was written for.
        image_size: side length of the square input images. It must match the
            resize used in preprocessing. Defaults to 64.
    """

    def __init__(self, num_classes=3, image_size=64):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Two floor-halving max-pools shrink each spatial dimension to image_size // 4.
        pooled_side = image_size // 4
        self.fc1 = nn.Linear(32 * pooled_side * pooled_side, 128)
        self.fc2 = nn.Linear(128, num_classes)  # one raw logit per class

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to have shape ``(batch, 1, image_size, image_size)``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # conv block 1: 16 channels, spatial / 2
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # conv block 2: 32 channels, spatial / 2
        x = x.view(x.size(0), -1)                   # flatten all but the batch dimension
        x = F.relu(self.fc1(x))
        # Raw logits: callers apply CrossEntropyLoss or softmax themselves.
        return self.fc2(x)

# Network, loss and optimizer that the training cell below consumes.
model = SimpleCNN()
model = model.to(device)  # nn.Module.to moves parameters in place and returns the module
criterion = nn.CrossEntropyLoss()  # takes raw logits plus integer class labels
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_model(model, train_loader, criterion, optimizer, epochs=10):
    """Optimise ``model`` in place over ``epochs`` full passes of ``train_loader``.

    Each step moves the batch to the module-level ``device``, runs a forward
    pass, back-propagates ``criterion`` and applies one ``optimizer`` update.

    After every epoch, one summary line is printed containing:
    - the mean of the per-batch losses;
    - the percentage of samples classified correctly, using the logits
      computed *before* each update.

    Args:
        model: the network to train; it is switched to training mode first.
        train_loader: sized iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
    """
    model.train()
    for epoch_idx in range(epochs):
        loss_total = 0.0
        n_correct = 0
        n_seen = 0
        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # One optimisation step: clear stale gradients, forward, backward, update.
            optimizer.zero_grad()
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            batch_loss.backward()
            optimizer.step()

            # Bookkeeping for the end-of-epoch summary.
            loss_total += batch_loss.item()
            guesses = logits.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (guesses == batch_labels).sum().item()

        accuracy = 100.0 * n_correct / n_seen
        mean_loss = loss_total / len(train_loader)
        print(f'Epoch [{epoch_idx+1}/{epochs}], Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%')


# Train

In [None]:
# Train for 15 epochs (instead of train_model's default of 10); progress prints once per epoch.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=15,
)

# Test 

In [None]:
def test(model, test_loader):
    """Print the macro-averaged Precision, Recall and F1-score of ``model`` on ``test_loader``.

    Scores are computed per class ``c``, one-vs-rest:

        Precision = TP / (TP + FP)
        Recall    = TP / (TP + FN)
        F1-score  = 2 * (Precision * Recall) / (Precision + Recall)

    Any score whose denominator is zero counts as 0. The per-class scores are
    averaged with equal weight per class, and the result is printed on one line.

    The number of classes is read from the width of the model's output. This
    means the function is not tied to a 3-class head.

    Inference runs in eval mode without gradient tracking. The model's previous
    train/eval mode is restored afterwards.

    Args:
        model: classifier returning logits of shape ``(batch, num_classes)``.
        test_loader: iterable of ``(images, labels)`` batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    batch_preds = []
    batch_labels = []
    num_classes = None
    try:
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images.to(device))
                num_classes = outputs.size(1)
                # The index of the largest logit is the predicted class.
                batch_preds.append(outputs.argmax(dim=1).cpu())
                batch_labels.append(labels.cpu())
    finally:
        model.train(was_training)

    if num_classes is None:
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    # Concatenate the per-batch results into flat 1-D tensors.
    all_preds = torch.cat(batch_preds)
    all_labels = torch.cat(batch_labels)

    precisions = []
    recalls = []
    f1s = []
    for c in range(num_classes):
        is_pred_c = all_preds == c
        is_true_c = all_labels == c
        tp = (is_pred_c & is_true_c).sum().item()
        fp = (is_pred_c & ~is_true_c).sum().item()
        fn = (~is_pred_c & is_true_c).sum().item()

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1_score)

    # Macro average of precision, recall and F1: every class weighs the same
    # regardless of how many samples it has.
    p = sum(precisions) / num_classes
    r = sum(recalls) / num_classes
    f1 = sum(f1s) / num_classes

    print(f'Precision: {p:.4f}, Recall: {r:.4f}, F1-score: {f1:.4f}')


# Report macro-averaged precision / recall / F1 of the trained model on the test split.
test(model=model, test_loader=test_loader)

# Show Predictions


In [None]:
import torchvision
import matplotlib.pyplot as plt
def show_prediction(model, image):
    """Classify one image file and display it titled with the predicted class and confidence.

    Preprocessing mirrors the training pipeline:
    - convert to grayscale;
    - resize to 64x64;
    - scale pixel values to [0, 1], as ``ToTensor`` does;
    - normalize with mean=0.5, std=0.5.

    The image is then run through ``model`` in eval mode without gradient
    tracking, and shown with matplotlib. The title carries the class name
    (looked up in ``train_dataset.classes``) and its softmax probability. The
    model's previous train/eval mode is restored afterwards.

    Args:
        model: trained classifier returning one logit per class.
        image: path of the image file to classify; it is passed to
            ``torchvision.io.read_image``.
    """
    # Decode as 3-channel RGB so grayscale, RGBA and palette files all reach
    # Grayscale with a channel count it accepts (1 or 3). Training images were
    # loaded through ImageFolder's default loader, which also converts to RGB.
    pixels = torchvision.io.read_image(image, mode=torchvision.io.ImageReadMode.RGB)
    # read_image yields uint8 values in [0, 255]. Scale to [0, 1] as ToTensor()
    # did during training; without this, Normalize would produce values far
    # outside the [-1, 1] range the model was trained on.
    pixels = pixels.float() / 255.0

    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # -> [1, H, W]
        transforms.Resize((64, 64)),                  # -> [1, 64, 64]
        transforms.Normalize((0.5,), (0.5,)),         # [0, 1] -> [-1, 1]
    ])
    image_tensor = preprocess(pixels)
    input_tensor = image_tensor.unsqueeze(0).to(device)  # add batch dim: model expects [N, 1, 64, 64]

    # Model inference
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)  # logits -> class probabilities
            confidence, predicted_class = torch.max(probabilities, 1)
    finally:
        model.train(was_training)

    # Show the preprocessed image with the prediction and its confidence as the title.
    plt.imshow(image_tensor.squeeze(0).cpu(), cmap='gray')
    plt.title(f'Predicted: {train_dataset.classes[predicted_class.item()]}, Confidence: {confidence.item():.4f}')
    plt.axis('off')
    plt.show()