In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import logging

In [2]:
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DateDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['Red', 'White']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = self._load_images()

    def _load_images(self):
        images = []
        for cls in self.classes:
            class_dir = os.path.join(self.root_dir, cls)
            if not os.path.exists(class_dir):
                logger.warning(f"Directory not found: {class_dir}")
                continue
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(class_dir, img_name)
                    images.append((img_path, self.class_to_idx[cls]))
        if not images:
            logger.warning(f"No images found in {self.root_dir}")
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path, label = self.images[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label, img_path
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {str(e)}")
            placeholder_image = torch.zeros((3, 224, 224))
            return placeholder_image, label, img_path

In [3]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [4]:
# Set up data transforms
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Create datasets
train_dataset = DateDataset(root_dir='/Users/williamrae/Desktop/Lego-Scanner/Dataset/Train', transform=data_transforms)
test_dataset = DateDataset(root_dir='/Users/williamrae/Desktop/Lego-Scanner/Dataset/Test', transform=data_transforms)

# Check if datasets are empty
if len(train_dataset) == 0:
    raise ValueError("Training dataset is empty. Please add images to the 'Train' directory.")

if len(test_dataset) == 0:
    logger.warning("Test dataset is empty. Evaluation will be skipped.")

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0) if len(test_dataset) > 0 else None

# Set up the device
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set up the model
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: Red and White
model = model.to(device)

# Set up loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Reduced learning rate


Using device: mps


In [11]:
# Training loop
num_epochs = 50
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

try:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels, _ in train_loader:  # Ignore img_path during training
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(train_dataset)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        
        # Check for early stopping
        early_stopping(epoch_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

except KeyboardInterrupt:
    print("Training interrupted by user")

finally:
    # Save the model
    torch.save(model.state_dict(), "date_classification_model.pth")
    print("Model saved successfully.")

Epoch 1/50, Loss: 0.0000
Epoch 2/50, Loss: 0.0000
Epoch 3/50, Loss: 0.0000
Epoch 4/50, Loss: 0.0000
Epoch 5/50, Loss: 0.0000
Epoch 6/50, Loss: 0.0000
Early stopping triggered
Model saved successfully.


In [12]:
# Evaluation (if test data is available)
if test_loader is not None:
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels, img_paths in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Log incorrectly classified images
            incorrect = predicted != labels
            for img_path in [img_paths[i] for i in range(len(img_paths)) if incorrect[i]]:
                logger.info(f"Incorrectly classified image: {img_path}")

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
else:
    print("Skipping evaluation due to empty test dataset.")

INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1779.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1786.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1787.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1778.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1785.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1790.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1784.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/Red/IMG_1780.png
INFO:__main__:Incorrectly classified image: /Users/williamrae/Desktop/Lego-Scanner/Dataset/Test/

Test Accuracy: 54.29%


In [10]:
import os

# Path to the saved model
model_path = "date_classification_model.pth"

# Check if the model file exists and delete it
if os.path.exists(model_path):
    os.remove(model_path)
    print(f"Deleted the model file: {model_path}")
else:
    print(f"No model file found at: {model_path}")

Deleted the model file: date_classification_model.pth
