
# Convolutional Neural Networks

## Setup

In [None]:
!pip install tensorboardx scipy matplotlib torchinfo torch-summary tensorflow torch torchvision


# Eyes and Algorithms

`Visual interpretation is essential to living, yet it long remained elusive to machines.`

Our ability to interpret images is not just a cognitive skill; it's core to
survival. It allows us to navigate our world, recognize emotions, and
communicate across cultures. From detecting danger to understanding complex
social cues, visual interpretation is fundamental to human existence.

In recent years, computers have achieved remarkable success in various domains
by applying mathematical methods developed over centuries. However, tasks
involving images present unique challenges. Unlike problems with well-defined
mathematical descriptions, complex shapes and visual patterns have eluded
precise mathematical characterization.

But with the advent of neural networks, a new door has opened. Convolutional
Neural Networks (CNNs), a special type of neural network, have enabled
computers to process images in ways previously thought impossible. By capturing
the intricate patterns and structures within images, CNNs have revolutionized
the field.

The ability to automate image interpretation has far-reaching impacts across various sectors:

- **Healthcare**: Enhancing diagnostics (e.g., detecting tumors in X-rays),
  impacting millions of patients.
- **Transportation**: Guiding autonomous vehicles, with thousands of
  self-driving cars on the road.
- **Security**: Powering facial recognition systems for airport security and
  for organizing photo albums.
- **Personalization**: Secure access and image search.
- **Environmental/Climate Monitoring**: Analyzing satellite images for climate
  studies, tracking changes over thousands of square kilometers.
- **Agriculture**: Precision farming through satellite imagery, optimizing crop
  yields.
- **Disaster Response**: Analyzing aerial images for disaster relief planning.
- **Manufacturing**: Quality control through image inspection in manufacturing
  plants.
- **Archaeology**: Discovering historical sites through satellite imagery, leading
  to the discovery of over 200 unknown sites.


# What Computers See

`Human: Look at all the shapes, colors, sounds, and patterns! Computer: So... is it 1, or is it 0?`

For a computer, an image is an array of numbers. It's how computers "read"
pictures.

- **Pixels**: An image is made of tiny parts called pixels. A pixel is a small
  dot in an image.
- **Black & White Images**: In black & white images, pixels are either black
  (0) or white (1).
- **Color Images (RGB)**: In color images, each pixel has three numbers, one
  each for red, green, and blue (commonly in the range 0 to 255).
- **RGB Channels**: The three parts of a color image are the red, green, and
  blue channels. By mixing different amounts of each, you can make many colors.



In [None]:
# =============================================================================
#                               PLOT IMAGE MATRICES
# =============================================================================
import matplotlib.pyplot as plt
import numpy as np


def plot_matrix(
    matrix,
    title,
    cmap=None,
    text_color="red",
    text_size=12,
    figsize=(2, 2),
    show_text=True,
):
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(matrix, cmap=cmap)
    if show_text:
        for i in range(matrix.shape[0]):
            for j in range(matrix.shape[1]):
                ax.text(
                    j,
                    i,
                    str(matrix[i, j]),
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=text_size,
                )
    ax.set_title(title)
    ax.axis("off")
    plt.tight_layout()
    plt.show()


# Example usage for a black and white image
cross = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
plot_matrix(cross, "Cross", cmap="gray")

# Example usage for an RGB image
red_channel = np.array([[255, 0, 0], [0, 128, 0], [128, 0, 255]])
green_channel = np.array([[0, 255, 0], [255, 0, 255], [0, 255, 0]])
blue_channel = np.array([[0, 0, 255], [0, 128, 0], [255, 0, 128]])

color_image = np.stack((red_channel, green_channel, blue_channel), axis=-1)

# Plot the individual color channels
plot_matrix(red_channel, "Red Channel", cmap="Reds")
plot_matrix(green_channel, "Green Channel", cmap="Greens")
plot_matrix(blue_channel, "Blue Channel", cmap="Blues")
# Plot the full color image
plot_matrix(color_image, "Color Image", text_size=8, text_color="black", figsize=(3, 3))






# Image processing

Take an `image` with a `resolution` of `3x3`. How do we `blur` this image?

$$
\text{Image} = \begin{bmatrix}
0 & 1 & 0 \\
1 & 0 & 1 \\
0 & 1 & 0 \\
\end{bmatrix}
$$

## Kernel

Since the number in each pixel represents the intensity of the color, a `blurring` operation simply means `averaging` the intensity of a pixel with that of its neighbors. To represent this operation as a matrix, we can use a `kernel` or `filter` matrix: a `3x3` matrix with all values equal to `1`, since we combine the pixel itself with its 8 neighbors. (Note: an all-ones kernel computes a sum. To get a true average, normalize the kernel by dividing it by its number of elements, in this case 9.)

$$
\text{Kernel/Filter} = \begin{bmatrix}
1 & 1 & 1 \\
1 & 1 & 1 \\
1 & 1 & 1 \\
\end{bmatrix}
$$

## Padding

But the pixels at the corners and edges do not have neighbors on all sides. So
we `pad` the image with zeros, which turns it into a `5x5` matrix:

$$
\text{Padded Image} = \begin{bmatrix}
0 & 0 & 0 & 0 & 0 \\
0 & 0 & 1 & 0 & 0 \\
0 & 1 & 0 & 1 & 0 \\
0 & 0 & 1 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 \\
\end{bmatrix}
$$

## Convolution Operation

Now we replace each pixel with a combination of itself and its neighbors: at
each position we multiply the `kernel` element-wise with the patch of the
padded image underneath it and sum the result. We `move/stride` the kernel over
the image and perform the same operation at each position.

This operation is important in many tasks in image processing and beyond, so
it has a special name: `convolution`. (Strictly speaking, sliding the kernel
without flipping it is cross-correlation; for symmetric kernels like this one
the two coincide, and deep learning libraries use the name convolution for
both.) In the context of our example, with zero-based indices, it can be
written as:

$$
\text{Blurred Image}[i, j] = \sum_{m=0}^{2} \sum_{n=0}^{2} \text{Padded Image}[i+m, j+n] \times \text{Kernel}[m, n]
$$

where $\text{Blurred Image}[i, j]$ represents the value of the pixel at position $(i, j)$ in the resulting blurred image, for $i, j \in \{0, 1, 2\}$.

With the unnormalized all-ones kernel, each output value is the sum over the $3 \times 3$ neighborhood (divide by 9 for the average). After performing the convolution operation, we get the following blurred image:

$$
\text{Blurred Image} = \begin{bmatrix}
2 & 3 & 2 \\
3 & 4 & 3 \\
2 & 3 & 2 \\
\end{bmatrix}
$$
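
The worked example can be checked with a few lines of NumPy. This is a minimal sketch that assumes zero padding and the unnormalized all-ones kernel from the matrices above; the variable names are illustrative.

```python
import numpy as np

image = np.array([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]])
kernel = np.ones((3, 3), dtype=int)

# Zero-pad by one pixel on every side: 3x3 -> 5x5.
padded = np.pad(image, 1, mode="constant", constant_values=0)

# Slide the kernel over the padded image. Each output pixel is the sum of
# the element-wise product of the kernel and the 3x3 window under it.
blurred = np.zeros_like(image)
for i in range(image.shape[0]):
    for j in range(image.shape[1]):
        window = padded[i:i + 3, j:j + 3]
        blurred[i, j] = np.sum(window * kernel)

print(blurred)  # rows: [2 3 2], [3 4 3], [2 3 2]

# Dividing by the number of kernel elements turns the sums into averages.
averaged = blurred / kernel.size
```

The printed sums match the blurred image above; `averaged` holds the same result scaled by 1/9.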

### TLDR: Convolution Operation

<!-- ![cnn_animation](imgs/cnn_animation.gif) -->

<p align="center">
  <img src="imgs/cnn_animation.gif" />
</p>

[source](https://en.wikipedia.org/wiki/File:2D_Convolution_Animation.gif)

## More Kernel Examples

$$
\begin{align*}
\text{Edge Detection} & = \begin{bmatrix} -1 & -1 & -1 \\ -1 & 8 & -1 \\ -1 & -1 & -1 \end{bmatrix}, &
\text{Horizontal Edge} & = \begin{bmatrix} 1 & 2 & 1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix}, &
\text{Vertical Edge} & = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix}.
\end{align*}
$$
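
To see what the two directional kernels respond to, the sketch below applies them to a tiny synthetic image with a single vertical edge. The helper `correlate_valid` and the test image are assumptions made for illustration; the helper slides the kernel without padding and without flipping it.

```python
import numpy as np


def correlate_valid(image, kernel):
    """Slide `kernel` over `image` without padding (the output shrinks)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.result_type(image, kernel))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


# Dark left half, bright right half: one vertical edge, no horizontal edge.
step = np.array([[0, 0, 0, 1, 1, 1]] * 4)

vertical_edge = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
horizontal_edge = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])

vertical_response = correlate_valid(step, vertical_edge)
horizontal_response = correlate_valid(step, horizontal_edge)

print(vertical_response)    # rows: [0 -4 -4 0], [0 -4 -4 0]
print(horizontal_response)  # all zeros
```

The vertical-edge kernel is non-zero only where its window straddles the edge, while the horizontal-edge kernel stays silent because every row of the image is identical.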





In [None]:


# =============================================================================
#                               VISUALIZE FILTERS
# =============================================================================

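def convolution(image, kernel):
    """Slide `kernel` over a zero-padded `image` (no kernel flip).

    Assumes a square kernel with odd side length; the output has the same
    shape as `image`.
    """
    pad_size = kernel.shape[0] // 2
    padded_image = np.pad(image, pad_size, mode="constant", constant_values=0)
    # Pick a dtype wide enough for both inputs, so that a float kernel (e.g.
    # the normalized blur) is not truncated when the image has an integer dtype.
    result_image = np.zeros(image.shape, dtype=np.result_type(image, kernel))
    for i in range(pad_size, padded_image.shape[0] - pad_size):
        for j in range(pad_size, padded_image.shape[1] - pad_size):
            result_image[i - pad_size, j - pad_size] = np.sum(
                padded_image[
                    i - pad_size : i + pad_size + 1, j - pad_size : j + pad_size + 1
                ]
                * kernel
            )
    return result_image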


def visualize_convolution(image, kernel, show_text=True):
    # Original Image
    plot_matrix(image, "Original Image", cmap="gray", text_size=16, show_text=show_text)

    # Kernel in Matrix Form
    # plot_matrix(kernel, "Kernel in Matrix Form", cmap="gray", text_size=16)

    # Result of the convolution (a blurred image when the blur kernel is used)
    post_conv_image = convolution(image, kernel)
    plot_matrix(
        post_conv_image,
        "Image result of Convolution",
        cmap="gray",
        text_size=16,
        show_text=show_text,
    )

In [None]:

# Example usage
# original_image = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
blur_kernel = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
# randint's upper bound is exclusive, so 256 covers the full 0-255 range
original_image = np.random.randint(0, 256, size=(10, 10))
visualize_convolution(original_image, blur_kernel, show_text=False)

In [None]:
# =============================================================================
#                              VISUALIZE COMPLEX FILTERS
# =============================================================================


import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torch


def get_filters():
    return {
        "Blur": np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) / 9,
        "Edge Detection": np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]),
        "Horizontal Edge": np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]),
        "Vertical Edge": np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]),
    }


class DatasetLoader:
    def __init__(self, dataset_name="MNIST", transform=None):
        self.dataset_name = dataset_name
        self.dataset = self.load_dataset(transform)

    def load_dataset(self, transform):
        if self.dataset_name == "MNIST":
            return torchvision.datasets.MNIST(
                root="./data", train=True, download=True, transform=transform
            )
        # You can add other datasets here, e.g.
        elif self.dataset_name == "CIFAR10":
            return torchvision.datasets.CIFAR10(
                root="./data", train=True, download=True, transform=transform
            )
        else:
            raise ValueError(f"Unknown dataset: {self.dataset_name}")

    def get_sample_image(self, label=None, num_samples=1, seed=None, grayscale=False):
        if seed is not None:
            np.random.seed(seed)
        if label is None:  # If no specific label is provided, choose one randomly
            label = np.random.choice(self.get_unique_labels())

        # Find matching indices from the stored targets instead of loading
        # (and transforming) every image in the dataset.
        targets = np.asarray(self.dataset.targets)
        indices = np.flatnonzero(targets == label)
        np.random.shuffle(indices)
        images = [self.dataset[int(i)][0] for i in indices[:num_samples]]
        if grayscale:
            to_grayscale_transform = transforms.Compose(
                [transforms.ToPILImage(), transforms.Grayscale(), transforms.ToTensor()]
            )
            images = [to_grayscale_transform(image) for image in images]
        return images

    def get_unique_labels(self):
        # MNIST and CIFAR10 both expose their labels as `targets`.
        return np.unique(np.asarray(self.dataset.targets))

    def get_dataloader(self, batch_size=64, shuffle=True, fraction=1.0):
        if fraction < 0.0 or fraction > 1.0:
            raise ValueError("Fraction must be between 0.0 and 1.0.")

        dataset_size = len(self.dataset)
        subset_size = int(fraction * dataset_size)

        indices = torch.randperm(dataset_size)[:subset_size]

        subset = torch.utils.data.Subset(self.dataset, indices)
        return torch.utils.data.DataLoader(
            subset, batch_size=batch_size, shuffle=shuffle
        )


def apply_filters(image, filters):
    plt.figure(figsize=(15, 5))
    plt.subplot(1, len(filters) + 1, 1)
    plt.imshow(image, cmap="gray")
    plt.title("Original Image")
    plt.axis("off")
    for idx, (name, kernel) in enumerate(filters.items()):
        filtered_image = convolution(image, kernel)
        plt.subplot(1, len(filters) + 1, idx + 2)
        plt.imshow(filtered_image, cmap="gray")
        plt.title(name)
        plt.axis("off")
    plt.show()


def demo_filter(dataset_name, count=2):
    # Build the dataset once; only the sampled image changes per iteration.
    transform = transforms.Compose([transforms.ToTensor()])
    dataset_loader = DatasetLoader(dataset_name=dataset_name, transform=transform)
    for _ in range(count):
        sample_image_tensor = dataset_loader.get_sample_image(
            num_samples=1, grayscale=True
        )[0]
        sample_image = sample_image_tensor.squeeze().numpy()
        apply_filters(sample_image, get_filters())


demo_filter("MNIST")
demo_filter("CIFAR10")

# Complex Problems, Simple Filters

`Real-world problems are too complicated for simple filters to be effective:`

<p align="center">
  <img src="imgs/obj_detection.jpg" alt="obj_detection" width="300"/>
  <img src="imgs/handwriting_meme.jpg" alt="handwriting_meme" width="300"/>
  <!-- <img src="imgs/brain_scan.jpg" alt="brain_scan" width="200"/> -->
</p>

Sources: [1](https://www.echelon.health/wp-content/uploads/2020/05/brain-scan.jpg),
[2](https://www.meme-arsenal.com/memes/8cd3a0c0fbca1d452927c33cd2638030.jpg),
[3](https://en.wikipedia.org/wiki/Object_detection#/media/File:Detected-with-YOLO--Schreibtisch-mit-Objekten.jpg)

- **Object Recognition**: Filters can't identify specific objects like cars or
  dogs, especially in varying shapes and orientations.
- **Face Recognition**: They can't recognize individual faces with different
  expressions or angles.
- **Semantic Segmentation**: Filters can't divide an image into meaningful parts or
  understand relationships between pixels.
- **Handling Occlusions**: They struggle with partially hidden or overlapped
  objects.
- **Multiclass Classification**: Filters can't identify multiple classes within an
  image.
- **Anomaly Detection**: They can't detect subtle abnormalities, e.g., in medical
  imaging.
- **Texture Recognition**: Recognizing complex textures or patterns is challenging
  for simple filters.


# Need for Neural Networks

`NNs are learning-driven, flexible, able to recognize complex patterns, and capable of generalizing across various tasks.`

Traditional methods use fixed filters for tasks like edge detection, requiring
expert knowledge and often being limited in flexibility. Neural networks
surpass traditional methods in several ways:

- **Adaptability**: They learn and adapt from data, accommodating a variety of
  tasks.
- **Hierarchical Learning**: They can recognize simple patterns and combine
  them into more complex structures, enabling the capture of intricate
  relationships in the data.
- **Automatic Feature Learning**: They automatically find important features,
  removing the need for manual engineering.
- **Robustness**: Neural networks can handle variations in input that might
  challenge rigid traditional filters.
- **End-to-End Training**: They provide a direct mapping from raw input to
  desired output, simplifying the process.
- **Generalization**: With proper training, they can effectively work on new,
  unseen data.


# Neural Networks

## Fully Connected and Bloated

`Neural networks consider everything, and that's the problem...`

<p align="center">
<img src="imgs/nn.png" alt="nn" width="300"/>
</p>

Assuming a `12 MP` resolution colored image with dimensions `4000x3000` pixels
and three color channels (Red, Green, Blue), the total number of features for
each image would be `4000x3000x3`. If this image is fed into a fully connected
layer with 1000 hidden units (neurons), the number of weights (parameters) just
for this single layer would be:

$$
\text{FC size} = 4000 \times 3000 \times 3 \times 1000 = 36 \text{ billion!!}
$$

And that is for just one layer! If the neural network consists of several fully connected layers, the number of parameters grows even further.
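
The arithmetic is easy to reproduce. The sketch below also contrasts it with a layer built from small shared filters, of the kind introduced in the CNN sections further down. The filter size, the filter count, and the 4-bytes-per-weight (float32) storage figure are illustrative assumptions, not values from the example above.

```python
# Parameters of one fully connected layer on a 12 MP RGB image.
height, width, channels = 3000, 4000, 3
hidden_units = 1000

fc_weights = height * width * channels * hidden_units
fc_biases = hidden_units
print(f"Fully connected: {fc_weights + fc_biases:,} parameters")

# Memory needed just to store the weights as 32-bit floats (assumption).
fc_gigabytes = fc_weights * 4 / 1e9
print(f"Weight storage:  {fc_gigabytes:.0f} GB")

# For contrast: 1000 filters of size 3x3 spanning the 3 input channels.
kernel_size, num_filters = 3, 1000
conv_weights = kernel_size * kernel_size * channels * num_filters
conv_biases = num_filters
print(f"Convolutional:   {conv_weights + conv_biases:,} parameters")
```

The filter-based layer does not depend on the image resolution at all, which is why its parameter count stays in the tens of thousands.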

## Fully Connected and Blind

`Human: Take this 2D image data. FC: Oh a single list of numbers!`

$$
\begin{align*}
\text{What was shown} &=
\begin{bmatrix}
1 & 0 & 1 \\
0 & 1 & 0 \\
1 & 0 & 1 \\
\end{bmatrix}
&
\text{What is seen} &= [1, 0, 1, 0, 1, 0, 1, 0, 1]
\end{align*}
$$

Fully connected layers are somewhat "blind" to certain inherent properties of
images that are essential for effective image analysis. Here's what they tend
to overlook:

- **Pixel Correlation**: Fully connected layers in traditional DNNs ignore the
  spatial relationships between neighboring pixels, treating each pixel
  independently. (Patterns formed by pixels at opposite corners of the image
  are rarely useful for determining what the object in the image is, yet a
  fully connected layer weighs them just like neighboring pixels.)

<p align="center">
  <img src="imgs/teddy_img.png" alt="teddy_img" width="300"/>
</p>

[source](https://www.youtube.com/watch?v=HGwBXDKFk9I)

<!-- TODO: add iamges for each invariances -->

- **Translation Invariance**: Unlike CNNs, fully connected layers can't readily
  recognize the same object in different positions; they lack translation
  invariance.

- **Scale and Rotational Sensitivity**: Fully connected layers are sensitive to
  variations in the scale and rotation of objects, whereas CNNs can be more
  robust to such variations.
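
The points above can be made concrete with NumPy. This is a small sketch with made-up arrays: it flattens the cross image, shows how far apart vertical neighbors end up, and checks how little two shifted copies of the same blob overlap once they are flattened.

```python
import numpy as np

cross = np.array([[1, 0, 1],
                  [0, 1, 0],
                  [1, 0, 1]])
flat = cross.flatten()
print(flat)  # [1 0 1 0 1 0 1 0 1]

# Vertical neighbors in the grid end up `width` positions apart in the list.
idx_center = np.ravel_multi_index((1, 1), cross.shape)  # pixel (1, 1)
idx_below = np.ravel_multi_index((2, 1), cross.shape)   # pixel (2, 1)
print(idx_below - idx_center)  # 3

# The same 2x2 blob, once at the left and once shifted one pixel right.
blob_a = np.zeros((4, 4), dtype=int)
blob_a[1:3, 0:2] = 1
blob_b = np.zeros((4, 4), dtype=int)
blob_b[1:3, 1:3] = 1

# Count the input positions that are active in both flattened vectors.
overlap = int(np.sum(blob_a.flatten() * blob_b.flatten()))
print(overlap, "of", int(blob_a.sum()))  # 2 of 4
```

A one-pixel shift already changes half of the active inputs, so a fully connected layer has to learn each position of the blob separately.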


# Convolutional Neural Networks

`Bring back the filters!`

<p align="center">
<img src="imgs/cnn_layout.png" alt="cnn_layout" style="width:80%;" />
</p>

## CNN: learns filters

`Don't worry about which filter to use, let the network learn it.`

In traditional image processing, filters are manually designed based on expert
knowledge for specific tasks, such as edge detection or noise reduction. These
filters remain fixed and perform the same operation regardless of the problem
at hand.

Convolutional Neural Networks (CNNs), on the other hand, automatically learn
filters from the data during the training process. Rather than relying on
pre-defined filters, CNNs initialize their filters randomly and iteratively
update them based on the training data. The learning process is guided by
optimization algorithms like gradient descent, which adjust the filter weights
to minimize the difference between the network's predictions and the actual
targets. The adaptive nature of CNNs allows them to learn filters that are
specifically tailored to the unique characteristics of the data and the
specific task at hand.
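
As a toy illustration of this idea, the sketch below lets a single randomly initialized `3x3` filter recover the vertical-edge kernel purely from input/output examples. The random training images, the Adam optimizer, the learning rate, and the number of steps are all assumptions chosen for the demonstration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Target: the fixed vertical-edge kernel. The learnable filter never sees it
# directly, only the outputs it produces.
target_kernel = torch.tensor([[1.0, 0.0, -1.0],
                              [2.0, 0.0, -2.0],
                              [1.0, 0.0, -1.0]]).reshape(1, 1, 3, 3)

images = torch.rand(64, 1, 16, 16)                    # random training "images"
targets = F.conv2d(images, target_kernel, padding=1)  # desired outputs

# One learnable 3x3 filter, randomly initialized.
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
optimizer = torch.optim.Adam(conv.parameters(), lr=0.05)

losses = []
for step in range(300):
    optimizer.zero_grad()
    loss = F.mse_loss(conv(images), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.6f}")
print(conv.weight.detach().squeeze())  # compare with the target kernel
```

In a real CNN the targets are class labels rather than the output of a known filter, so the network is free to settle on whatever filters reduce the loss.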

## CNN: filters and features

`Filters can go way beyond edge detection.`

In the context of CNNs, a feature refers to a specific visual pattern or
characteristic of an image that the network has learned to recognize. Features
can range from simple attributes like edges or colors to more complex and
abstract representations of objects or shapes. At any stage, multiple filters
can be applied so that enough information is captured for inference in deeper
layers.

For example, the CNN in the picture above has `four filters in the first layer`.
When these filters slide over the input image, each one generates a
corresponding feature map. These feature maps, when combined, create a new
representation of the data coming from the previous layer.
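
In PyTorch terms, a layer with four filters looks roughly like the sketch below. The `28x28` grayscale input and the `3x3` filter size are assumptions for illustration, not properties of the network in the picture.

```python
import torch
import torch.nn as nn

# One grayscale 28x28 image, laid out as (batch, channels, height, width).
image = torch.rand(1, 1, 28, 28)

# Four 3x3 filters; padding=1 keeps the spatial size unchanged.
layer = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1)

feature_maps = layer(image)
print(layer.weight.shape)  # torch.Size([4, 1, 3, 3]): four filters
print(feature_maps.shape)  # torch.Size([1, 4, 28, 28]): four feature maps
```

The four feature maps are stacked along the channel dimension, and that stack is what the next layer receives as its input.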

## CNN: depth

`Increase depth to learn complex features with hierarchies.`

Images contain a hierarchy of features, with more complex features composed of
simpler ones. For example, learning to detect a face requires learning to
detect eyes, a nose, a mouth, etc., and those features need to be combined to
detect a face. The learned notion of a face can in turn be used to detect a
person, and so on.

When we add more convolutional layers, the network learns to recognize more
complex patterns by combining the features learned in the previous layers. The
first convolutional layers learn simple features like edges and colors, while
deeper layers combine them to learn more abstract and complex concepts like
facial features or objects.

## CNN: Pooling

`What is in the image has little to do with how big the image is.`

Often we do not care how big the image is. Pooling is a technique used in CNNs
to reduce the spatial dimensions (width and height) of the feature maps
generated by the convolutional layers. The pooling operation aggregates
information from neighboring pixels and outputs a single value, effectively
downsampling the feature map. By reducing the spatial resolution, the network
becomes less sensitive to small spatial translations, making it more robust to
slight changes in the position of features in the image.

Pooling can be done using different techniques, such as `max pooling` (selecting
the maximum value from a small neighborhood) or `average pooling` (taking the
average value).
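
Both variants fit in a few lines of NumPy. The sketch below pools a made-up `4x4` feature map with non-overlapping `2x2` windows; the reshape trick assumes the height and width are divisible by the window size.

```python
import numpy as np

feature_map = np.array([[1, 3, 2, 0],
                        [4, 2, 1, 1],
                        [0, 1, 5, 6],
                        [2, 2, 7, 8]])

# Split the 4x4 map into non-overlapping 2x2 blocks. The axes become
# (block_row, row_in_block, block_col, col_in_block).
blocks = feature_map.reshape(2, 2, 2, 2)

max_pooled = blocks.max(axis=(1, 3))   # strongest activation per block
avg_pooled = blocks.mean(axis=(1, 3))  # mean activation per block

print(max_pooled)  # rows: [4 2], [2 8]
print(avg_pooled)  # rows: [2.5 1.0], [1.25 6.5]
```

Either way, the `4x4` map shrinks to `2x2`, so later layers see a coarser summary of where features occurred.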

## CNN: Final layers

`Reason using features and not pixels`

After several convolutional and pooling layers, the high-level reasoning in the
neural network is done via fully connected layers. Neurons in a fully connected
layer have full connections to all activations in the previous layer, as seen
in regular neural networks, allowing the CNN to decide how to combine the
features to make the final decision (e.g., classifying the image).
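
The size of the first fully connected layer follows from how the convolution and pooling layers change the spatial dimensions. The sketch below traces those sizes for an assumed small network (a `28x28` input, two `3x3` convolutions with padding 1, each followed by `2x2` pooling, and 32 filters in the last convolution); `conv_output_size` is a helper defined here for illustration.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28                                                # assumed input size
size = conv_output_size(size, kernel_size=3, padding=1)  # conv 3x3, pad 1: 28
size = conv_output_size(size, kernel_size=2, stride=2)   # pool 2x2: 14
size = conv_output_size(size, kernel_size=3, padding=1)  # conv 3x3, pad 1: 14
size = conv_output_size(size, kernel_size=2, stride=2)   # pool 2x2: 7

channels = 32  # assumed number of filters in the last convolutional layer
flattened = size * size * channels
print(size, flattened)  # 7 1568
```

Under these assumptions the fully connected layer receives 1568 feature values instead of raw pixels.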

## CNN: What did we get out of it?

Convolutional Neural Networks (CNNs) were designed with a particular
understanding of the properties of images that traditional Deep Neural Networks
(DNNs) often ignored. Here are the key ideas behind CNNs that make them so
effective for image analysis:

- **Local Connectivity**: Recognizing that pixels in images are related to
  their immediate neighbors, CNNs employ local receptive fields. Each neuron is
  only connected to a small region of the input, which enables the network to
  focus on local features. (Kernels to the rescue!)

- **Parameter Sharing**: Unlike traditional DNNs, where each weight is tied
  to a single input, CNNs use parameter sharing. The same filter is applied
  across different parts of the image, so a feature is detected in the same
  way wherever it appears (translation equivariance). Combined with pooling,
  this makes the network largely translation invariant: it can recognize a
  feature anywhere in the input space.

- **Spatial Hierarchies**: CNNs build a hierarchy of features from lower to
  higher levels of spatial abstraction. Early layers detect simple patterns
  like edges and textures, while deeper layers identify complex shapes and
  objects. This hierarchical representation ensures a deeper understanding of
  image content. (This is present in DNNs as well, but CNNs do it better for
  images.)

- **Scale and Rotational Invariance**: CNNs can be designed to be robust to
  changes in scale and rotation. Pooling layers summarize neighboring neuron
  outputs, imparting a limited degree of scale and rotational invariance and
  enhancing feature detection across small variations.

- **Reduced Computation**: Utilizing local connectivity and parameter sharing,
  CNNs significantly reduce the number of trainable parameters compared to
  fully connected DNNs. This computational efficiency makes training and
  inference more manageable.
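
The effect of parameter sharing can be checked numerically. In the sketch below, `correlate_same` is an illustrative helper (zero padding, no kernel flip, square odd-sized kernel) and the blob image is made up. Shifting the input shifts the filter response by the same amount, and the strongest response stays the same.

```python
import numpy as np


def correlate_same(image, kernel):
    """Zero-padded sliding-window filter; the output has the input's shape."""
    k = kernel.shape[0]
    pad = k // 2
    padded = np.pad(image, pad, mode="constant", constant_values=0)
    out = np.zeros(image.shape, dtype=np.result_type(image, kernel))
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            out[i, j] = np.sum(padded[i:i + k, j:j + k] * kernel)
    return out


kernel = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])

image = np.zeros((8, 8), dtype=int)
image[2:4, 2:4] = 1  # a small blob away from the borders
shifted = np.roll(image, shift=(2, 2), axis=(0, 1))  # same blob, moved by (2, 2)

response = correlate_same(image, kernel)
response_shifted = correlate_same(shifted, kernel)

# Filtering the shifted image gives the shifted response (equivariance).
is_equivariant = np.array_equal(
    np.roll(response, shift=(2, 2), axis=(0, 1)), response_shifted
)
print(is_equivariant)  # True

# A global maximum over the map does not care where the blob was.
print(response.max() == response_shifted.max())  # True
```

The blob is kept away from the image borders so that neither the zero padding nor the wrap-around behavior of `np.roll` affects the comparison.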


# MNIST CNN

In [None]:

# =============================================================================
#                             CNN CODE
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from tensorboardX import SummaryWriter

class CNN_MNIST(nn.Module):
    def __init__(self):
        super(CNN_MNIST, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(7 * 7 * 32, 256)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)  # 10 output classes for MNIST

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 7 * 7 * 32)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x



def train_model(model, train_loader, criterion, optimizer, scheduler, num_epochs=5):
    # Create a SummaryWriter to log training information
    writer = SummaryWriter()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0  # loss accumulated since the last progress print
        epoch_loss = 0.0  # loss accumulated over the whole epoch
        correct_predictions = 0
        total_samples = 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            epoch_loss += loss.item()

            _, predicted = torch.max(outputs.detach(), 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

            if i % 100 == 99:
                print(f"[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 100:.3f}")
                running_loss = 0.0

        # Learning rate scheduler step
        scheduler.step()

        accuracy = 100 * correct_predictions / total_samples
        # Average over all batches, independent of the periodic reset above
        average_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch + 1}: loss {average_loss:.3f}, accuracy {accuracy:.2f}%")

        # Log metrics to TensorBoard
        writer.add_scalar("Loss/train", average_loss, epoch + 1)
        writer.add_scalar("Accuracy/train", accuracy, epoch + 1)
        writer.add_scalar("Learning_rate", optimizer.param_groups[0]["lr"], epoch + 1)

        # Log histogram of model parameters (optional)
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, epoch + 1)

    print("Training finished!")

    # Close the SummaryWriter
    writer.close()

    return model

def train_custom_model(model, dataset_name, transform=None, batch_size=64, fraction=1.0, num_epochs=5):
    if transform is None:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # Normalize inputs to range [-1, 1]
        ])
    # Create an instance of the DatasetLoader class
    dataset_loader = DatasetLoader(dataset_name=dataset_name, transform=transform)

    # Get the data loader for training set with specified fraction of the dataset and batch size
    train_loader = dataset_loader.get_dataloader(batch_size=batch_size, shuffle=True, fraction=fraction)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Define learning rate scheduler
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    # Train the model
    trained_model = train_model(model, train_loader, criterion, optimizer, scheduler, num_epochs=num_epochs)

    # Capture correctly and incorrectly classified examples after training.
    # Note: these come from the training subset, not from held-out data.
    correct_examples = []
    incorrect_examples = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trained_model.to(device)
    trained_model.eval()  # switch to inference mode
    with torch.no_grad():
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = trained_model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct_mask = predicted == labels
            for image, true_label, pred_label, is_correct in zip(inputs, labels, predicted, correct_mask):
                example_image = image.cpu().numpy()
                true_label = true_label.item()
                predicted_label = pred_label.item()
                if is_correct:
                    correct_examples.append((example_image, true_label, predicted_label))
                else:
                    incorrect_examples.append((example_image, true_label, predicted_label))
    
    return trained_model, correct_examples, incorrect_examples

In [None]:
# =============================================================================
#                   Training the model
# =============================================================================

model, correct_examples, incorrect_examples = train_custom_model(
    CNN_MNIST(), "MNIST", batch_size=1024, fraction=0.3, num_epochs=5
)

In [None]:

# =============================================================================
#                  Visualizing the results
# =============================================================================

def visualize_examples(examples, title, example_count=6):
    # Nothing to plot (e.g. no misclassified samples): avoid an empty figure
    if not examples:
        print(f"{title}: no examples to show")
        return
    np.random.shuffle(examples)
    num_examples = len(examples)
    num_examples_to_show = min(num_examples, example_count)
    rows = (num_examples_to_show + 2) // 3
    cols = min(num_examples_to_show, 3)
    plt.figure(figsize=(8, 3 * rows))

    for i, (image, true_label, predicted_label) in enumerate(examples[:num_examples_to_show]):
        plt.subplot(rows, cols, i + 1)
        image = np.transpose(image, (1, 2, 0))  # Convert (C, H, W) to (H, W, C)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title(f'True: {true_label}\nPredicted: {predicted_label}')
        plt.axis('off')

    plt.suptitle(title,fontsize=8)
    plt.show()

visualize_examples(correct_examples, "Correctly Classified Examples")
visualize_examples(incorrect_examples, "Incorrectly Classified Examples")

In [None]:

# =============================================================================
#                  Visualizing the training
# =============================================================================

%load_ext tensorboard
%tensorboard --logdir=runs

# =============================================================================