In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet50
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, roc_auc_score



import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, roc_auc_score
import itertools
import itertools

In [2]:

# Directory configuration.
# All paths hang off a single dataset root. The default is the Google Drive
# location used when this notebook runs on Colab; set the POTHOLE_DATA_ROOT
# environment variable to point the notebook at another copy of the data
# without editing this cell.
data_root = os.environ.get('POTHOLE_DATA_ROOT', '/content/drive/MyDrive/Pothole data')

train_dir = f'{data_root}/organized_data/train'  # training images, one sub-folder per class
val_dir = f'{data_root}/organized_data/val'  # validation images, one sub-folder per class
test_dir = f'{data_root}/organized_data/test'  # held-out test images, one sub-folder per class
model_dir = f'{data_root}/models from updated neural network'  # checkpoints, metrics and plots are written here

In [3]:
class PotholeClassifier:
    """Two-class image classifier (pothole vs. normal road) on a frozen ResNet-50.

    The constructor builds ``ImageFolder`` datasets and loaders for the train,
    validation and test directories, loads an ImageNet-pretrained ResNet-50,
    freezes every backbone parameter and replaces the final fully connected
    layer with a 2-way linear head. ``train()`` runs the optimisation loop with
    early stopping on the validation loss, then evaluates the best checkpoint
    on the test set and writes metrics and plots into ``model_dir``.

    The ranking metrics (PR curve, ROC curve, AP, AUC) treat class index 1 as
    the positive class. ``ImageFolder`` assigns indices by sorted folder name,
    so this assumes the "pothole" folder sorts after the "normal" one.

    Args:
        train_dir: Directory of training images, one sub-folder per class.
        val_dir: Directory of validation images, same layout.
        test_dir: Directory of test images, same layout.
        model_dir: Output directory for checkpoints, metric files and plots.
            Created if it does not exist.
        num_epochs: Maximum number of training epochs.
        batch_size: Mini-batch size used by all three data loaders.
        learning_rate: Learning rate passed to the Adam optimiser.
        patience: Number of consecutive epochs without validation-loss
            improvement after which training stops early.
    """

    def __init__(self, train_dir, val_dir, test_dir, model_dir, num_epochs=10, batch_size=32, learning_rate=0.001, patience=3):
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.test_dir = test_dir
        self.model_dir = model_dir
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.patience = patience
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Fail early on an unusable output location instead of after the first
        # (expensive) epoch, when the first checkpoint is written.
        os.makedirs(self.model_dir, exist_ok=True)

        # Resize to the 224x224 input used here and normalise with the
        # ImageNet channel statistics the pretrained weights expect.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.train_dataset = datasets.ImageFolder(self.train_dir, transform=self.transform)
        self.val_dataset = datasets.ImageFolder(self.val_dir, transform=self.transform)
        self.test_dataset = datasets.ImageFolder(self.test_dir, transform=self.transform)

        # Only the training loader is shuffled; evaluation order is kept fixed.
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

        self.model = self._initialize_model()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        self.val_losses = []  # List to store validation losses

    def _initialize_model(self):
        """Build the network: pretrained ResNet-50, frozen, with a new 2-way head.

        Returns:
            The model, already moved to ``self.device``.
        """
        model = resnet50(pretrained=True)

        # Freeze the backbone; only the replacement head created below has
        # requires_grad=True and therefore receives gradient updates.
        for param in model.parameters():
            param.requires_grad = False

        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, 2)  # Assuming 2 classes: pothole and normal

        model = model.to(self.device)

        return model

    def _output_path(self, filename):
        """Return the path of ``filename`` inside ``model_dir``, creating the directory if needed."""
        os.makedirs(self.model_dir, exist_ok=True)
        return os.path.join(self.model_dir, filename)

    def _evaluate(self, loader):
        """Run the model over ``loader`` in eval mode without tracking gradients.

        Args:
            loader: Iterable yielding ``(images, labels)`` batches.

        Returns:
            A tuple ``(mean_loss, accuracy, true_labels, predictions, scores)``:
            the per-sample mean loss, the accuracy in percent, and three lists
            with the true label, the predicted label and the softmax
            probability of class index 1 for every sample. ``mean_loss`` and
            ``accuracy`` are NaN when the loader yields no samples.
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        true_labels = []
        predictions = []
        scores = []

        with torch.no_grad():
            for images, labels in loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                outputs = self.model(images)
                predicted = outputs.argmax(dim=1)
                n_samples = labels.size(0)

                # The criterion returns a batch mean; weight it by the batch
                # size so a short final batch does not skew the average.
                loss_sum += self.criterion(outputs, labels).item() * n_samples
                total += n_samples
                correct += (predicted == labels).sum().item()

                true_labels.extend(labels.cpu().tolist())
                predictions.extend(predicted.cpu().tolist())
                # Continuous score for the positive class (index 1); the
                # threshold-sweeping metrics need this, not hard 0/1 labels.
                scores.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().tolist())

        if total == 0:
            return float('nan'), float('nan'), true_labels, predictions, scores

        return loss_sum / total, 100 * correct / total, true_labels, predictions, scores

    def train(self):
        """Train with early stopping, then plot the validation loss and run ``test()``.

        After every epoch the model is evaluated on the validation set. A new
        lowest validation loss saves the weights as ``best_model.pth``;
        ``patience`` consecutive epochs without improvement stop training.
        Before testing, the best checkpoint saved during this call is loaded
        again so the reported test metrics describe the saved model rather
        than the weights of the last epoch.
        """
        best_checkpoint = None

        for epoch in range(self.num_epochs):
            # NOTE(review): train() mode also lets the frozen backbone's
            # batch-norm layers update their running statistics -- confirm
            # that this is intended.
            self.model.train()
            running_loss = 0.0
            seen = 0
            for images, labels in self.train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                n_samples = labels.size(0)
                running_loss += loss.item() * n_samples
                seen += n_samples

            # Report the epoch mean, not just the loss of the last mini-batch.
            train_loss = running_loss / seen if seen else float('nan')

            # Evaluate the model on the validation set
            val_loss, accuracy, _, _, _ = self._evaluate(self.val_loader)

            # Append validation loss to the list
            self.val_losses.append(val_loss)

            # Check if validation loss has improved
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.epochs_without_improvement = 0
                # Save the model if it's the best so far
                best_checkpoint = self.save_model('best_model.pth')
            else:
                self.epochs_without_improvement += 1

            # Print training progress and validation metrics
            print(f'Epoch {epoch+1}/{self.num_epochs}, Loss: {train_loss:.4f}, '
                  f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%')

            # Check if training should be stopped based on patience value
            if self.epochs_without_improvement >= self.patience:
                print(f'Early stopping. No improvement in validation loss for {self.patience} epochs.')
                break

        # Plot and save the validation loss graph
        self.plot_validation_loss()

        # Restore the best weights written during this run (if any) so the
        # test metrics match the checkpoint that is kept on disk.
        if best_checkpoint is not None:
            self.model.load_state_dict(torch.load(best_checkpoint, map_location=self.device))

        # Perform testing and save metrics
        self.test()

    def plot_validation_loss(self):
        """Save the per-epoch validation loss curve as ``validation_loss_graph.png``."""
        fig, ax = plt.subplots()
        # Number epochs from 1 so the axis matches the printed progress lines.
        epochs = range(1, len(self.val_losses) + 1)
        ax.plot(epochs, self.val_losses, label='Validation Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Validation Loss over Epochs')
        ax.legend()
        fig.savefig(self._output_path('validation_loss_graph.png'))
        plt.close(fig)

    def test(self):
        """Evaluate the current model on the test set and persist all metrics.

        Writes ``accuracy.txt``, ``specificity.txt``, ``average_precision.txt``
        and ``auc_roc.txt`` plus the confusion-matrix, precision-recall and ROC
        plots into ``model_dir``, and prints the test loss and accuracy.
        """
        test_loss, accuracy, true_labels, predictions, scores = self._evaluate(self.test_loader)

        # Compute other metrics
        # labels=[0, 1] pins the matrix to 2x2 even if a class never occurs.
        confusion = confusion_matrix(true_labels, predictions, labels=[0, 1])
        specificity = self.calculate_specificity(confusion)
        # Threshold-based curves are computed from class-1 probabilities;
        # feeding hard predictions would collapse them to a single point.
        precision, recall, _ = precision_recall_curve(true_labels, scores)
        average_precision = average_precision_score(true_labels, scores)
        fpr, tpr, _ = roc_curve(true_labels, scores)
        # ROC AUC is undefined (scikit-learn raises) with a single true class.
        if len(set(true_labels)) < 2:
            auc_roc = float('nan')
        else:
            auc_roc = roc_auc_score(true_labels, scores)

        # Save metrics
        self.save_metric('accuracy.txt', accuracy)
        self.save_metric('specificity.txt', specificity)
        self.save_metric('average_precision.txt', average_precision)
        self.save_metric('auc_roc.txt', auc_roc)
        self.save_confusion_matrix(confusion)
        self.plot_precision_recall_curve(precision, recall)
        self.plot_roc_curve(fpr, tpr)

        # Print test metrics
        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

    def save_metric(self, filename, metric_value):
        """Write ``str(metric_value)`` to ``filename`` inside ``model_dir``.

        Args:
            filename: Name of the file to (over)write.
            metric_value: Value to store; converted with ``str``.
        """
        save_path = self._output_path(filename)
        with open(save_path, 'w') as f:
            f.write(str(metric_value))

    def save_confusion_matrix(self, cm):
        """Render ``cm`` as an annotated heat map saved to ``confusion_matrix.png``.

        Args:
            cm: Square confusion matrix (rows: true label, columns: predicted).
        """
        # Prefer the class names discovered by ImageFolder so the tick labels
        # follow the real label indices; fall back to the fixed names when the
        # dataset is unavailable or does not match the matrix size.
        dataset = getattr(self, 'test_dataset', None)
        classes = list(getattr(dataset, 'classes', []))
        if len(classes) != cm.shape[0]:
            classes = ['Normal', 'Pothole']

        fig, ax = plt.subplots()
        image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.set_title('Confusion Matrix')
        fig.colorbar(image, ax=ax)
        tick_marks = np.arange(len(classes))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(classes)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(classes)

        # White text on dark cells, black text on light cells.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            ax.text(j, i, cm[i, j], horizontalalignment='center', color='white' if cm[i, j] > thresh else 'black')

        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')
        # Lay out after the axis labels exist so they are not clipped.
        fig.tight_layout()
        fig.savefig(self._output_path('confusion_matrix.png'))
        plt.close(fig)

    def plot_precision_recall_curve(self, precision, recall):
        """Save the precision-recall curve to ``precision_recall_curve.png``.

        Args:
            precision: Precision values, plotted on the y axis.
            recall: Recall values, plotted on the x axis.
        """
        fig, ax = plt.subplots()
        ax.plot(recall, precision, marker='.')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall Curve')
        fig.savefig(self._output_path('precision_recall_curve.png'))
        plt.close(fig)

    def plot_roc_curve(self, fpr, tpr):
        """Save the ROC curve to ``roc_curve.png``.

        Args:
            fpr: False positive rates, plotted on the x axis.
            tpr: True positive rates, plotted on the y axis.
        """
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, marker='.')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('ROC Curve')
        fig.savefig(self._output_path('roc_curve.png'))
        plt.close(fig)

    def calculate_specificity(self, confusion):
        """Return the specificity TN / (TN + FP) read from row 0 of ``confusion``.

        Args:
            confusion: Confusion matrix indexable as ``confusion[row, col]``
                whose row 0 holds the negative class.

        Returns:
            The specificity as a float, or NaN when row 0 is empty (there are
            no negative samples, so the ratio is undefined).
        """
        tn = confusion[0, 0]
        fp = confusion[0, 1]
        negatives = tn + fp
        if negatives == 0:
            return float('nan')
        return float(tn / negatives)

    def save_model(self, filepath):
        """Save the model's ``state_dict`` under ``model_dir``.

        Args:
            filepath: File name (relative to ``model_dir``) for the checkpoint.

        Returns:
            The full path the checkpoint was written to.
        """
        save_path = self._output_path(filepath)
        torch.save(self.model.state_dict(), save_path)
        print(f"Model saved at '{save_path}'")
        return save_path


In [None]:
# Build the classifier from the directories configured earlier in the notebook;
# every other hyper-parameter keeps its constructor default.
classifier = PotholeClassifier(
    train_dir=train_dir,
    val_dir=val_dir,
    test_dir=test_dir,
    model_dir=model_dir,
)

# Run training; train() also saves the validation-loss plot and finishes by
# evaluating on the test set.
classifier.train()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:04<00:00, 24.6MB/s]
