<a href="https://colab.research.google.com/github/Wjh70301/AGB/blob/master/csc421_mnist_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import gzip
import io
import struct
import sys
import urllib.request

import numpy as np

def read_image(fi):
    """Decode an MNIST-style image file (header magic 0x00000803) from a stream.

    The 16-byte big-endian header holds the magic number, the image count and
    the row/column sizes; only 28x28 images are accepted. Every byte left in
    the stream after the header is treated as pixel data and must amount to
    exactly ``count * 28 * 28`` bytes.

    Args:
        fi: Binary file-like object positioned at the start of the header.

    Returns:
        ``float32`` array of shape ``(count, 28, 28)`` whose byte intensities
        0..255 are rescaled to 0.0..1.0.

    Raises:
        AssertionError: wrong magic number, non-28x28 images, or a payload
            whose length disagrees with the header.
        struct.error: fewer than 16 header bytes are available.
    """
    header = fi.read(16)
    magic, count, height, width = struct.unpack(">IIII", header)
    assert magic == 0x00000803
    assert height == 28
    assert width == 28

    expected_bytes = count * height * width
    payload = fi.read()
    assert len(payload) == expected_bytes

    pixels = np.frombuffer(payload, dtype='>u1', count=expected_bytes)
    images = pixels.reshape(count, height, width).astype(np.float32)
    # Map byte intensities 0..255 onto the unit interval.
    return images / 255.0

def read_label(fi):
    """Decode an MNIST-style label file (header magic 0x00000801) from a stream.

    The 8-byte big-endian header holds the magic number and the label count;
    every remaining byte of the stream is one label.

    Args:
        fi: Binary file-like object positioned at the start of the header.

    Returns:
        A writable 1-D array of unsigned bytes with one entry per label.

    Raises:
        AssertionError: wrong magic number, or a payload whose length does not
            match the count in the header.
        struct.error: fewer than 8 header bytes are available.
    """
    magic, n = struct.unpack(">II", fi.read(8))
    assert magic == 0x00000801
    rawbuffer = fi.read()
    assert len(rawbuffer) == n
    # np.frombuffer over an immutable `bytes` object yields a read-only view;
    # copy it so callers can modify or shuffle the labels in place.
    return np.frombuffer(rawbuffer, dtype='>u1', count=n).copy()

def openurl_gzip(url):
    """Fetch a gzip-compressed resource and return a decompressing reader.

    The whole compressed response body is read into memory and the
    connection is closed before returning. The caller then only has to
    manage the returned ``gzip.GzipFile``.

    Args:
        url: URL of a gzip-compressed file (any scheme ``urllib`` supports).

    Returns:
        A ``gzip.GzipFile`` in ``'rb'`` mode over the downloaded bytes.
    """
    request = urllib.request.Request(
        url,
        headers={
            "Accept-Encoding": "gzip",
            # Browser-like agent string; NOTE(review): presumably some mirrors
            # reject urllib's default User-Agent -- confirm it is still needed.
            "User-Agent": "Mozilla/5.0 (X11; U; Linux i686) Gecko/20071127 Firefox/2.0.0.11",
        })
    # GzipFile.close() does not close a fileobj supplied by the caller, so
    # wrapping the live response directly would leak the connection. Buffer
    # the compressed payload and close the response deterministically.
    with urllib.request.urlopen(request) as response:
        payload = response.read()
    return gzip.GzipFile(fileobj=io.BytesIO(payload), mode='rb')

In [0]:
import math
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def plot_weights(weights, *, shared_scale=False):
    """Show each row of ``weights`` as a square greyscale image in a grid.

    Row ``i`` is reshaped into a ``side x side`` image with
    ``side = ceil(sqrt(weights.shape[1]))``. The reshape only succeeds when
    ``weights.shape[1]`` is a perfect square (e.g. 784 -> 28x28). Images are
    laid out on a ``num_grids x num_grids`` grid of sub-plots with
    ``num_grids = ceil(sqrt(weights.shape[0]))``. Surplus cells are left empty.

    Args:
        weights: 2-D array-like of shape ``(num_rows, num_pixels)`` supporting
            ``.reshape`` and ``.shape``.
        shared_scale: If False (default, the original behaviour), each image
            is contrast-stretched to its own min/max. If True, all images share
            one colour scale spanning the global min/max, so intensities can
            be compared across images.

    Raises:
        ValueError: ``weights`` has no rows, or its row length cannot be
            reshaped into a square.
    """
    num_filters = weights.shape[0]
    if num_filters == 0:
        raise ValueError("plot_weights: 'weights' has no rows to plot")

    side = math.ceil(math.sqrt(weights.shape[1]))
    w = weights.reshape(num_filters, side, side)

    if shared_scale:
        # One colour range for every panel so brightness is comparable.
        w_min = np.min(w)
        w_max = np.max(w)

    # Rounded-up square root of the number of images -> grid edge length.
    num_grids = math.ceil(math.sqrt(num_filters))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; otherwise a
    # single image yields a bare Axes object that has no `.flat` attribute.
    _, axes = plt.subplots(num_grids, num_grids, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < num_filters:
            img = w[i, :, :]
            if shared_scale:
                vmin, vmax = w_min, w_max
            else:
                vmin, vmax = np.min(img), np.max(img)
            ax.imshow(img, vmin=vmin, vmax=vmax,
                      interpolation='nearest', cmap='gray')

        # Hide ticks on every cell, including unused ones.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the figure is rendered when several plots share one notebook cell.
    plt.show()
    
def plot_single_training_example(i, *, path='mnist.npz'):
    """Print training example ``i``'s label and draw its pixels as a heatmap.

    Reads the ``train_x`` and ``train_y`` arrays from an ``.npz`` archive and
    shows the selected image as an annotated seaborn heatmap. Each cell is
    labelled with its value to one decimal place.

    Args:
        i: Index into ``train_x`` / ``train_y``.
        path: Location of the ``.npz`` archive. Defaults to ``'mnist.npz'`` in
            the working directory, which was the previously hard-coded path.
    """
    # np.load on an .npz returns an NpzFile that keeps its file handle open;
    # use it as a context manager so the handle is released after reading.
    with np.load(path) as data:
        image = data['train_x'][i]
        label = data['train_y'][i]

    print(label)
    _, ax = plt.subplots(figsize=(12, 12))
    # Draw on the axes just created instead of relying on the implicit
    # "current axes".
    sns.heatmap(image, annot=True, fmt='.1f', square=True, cmap="YlGnBu", ax=ax)
    plt.show()