In [None]:
import torch 
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import matplotlib.pylab as plt
import numpy as np

# Multi-Layer-Perzeptron (MLP)

Dieses Notebook zeigt die Implementierung und Visualisierung eines Multi-Layer-Perzeptrons (MLP) zur Klassifikation des MNIST-Datensatzes mit PyTorch.

In [None]:
# Create and print the training dataset
train_dataset = dsets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
print("Print the training dataset:\n", train_dataset)

# Create and print the validation dataset: the MNIST test split (train=False)
# serves as the validation set in this notebook.
validation_dataset = dsets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
print("Print the validating dataset:\n", validation_dataset)

In [None]:
# Show a 5x5 grid of randomly drawn training images, each titled with its label
rows, cols = 5, 5
fig, axes = plt.subplots(rows, cols, figsize=(10, 8))
for ax in axes.flat:
    # one independent random draw per panel
    rand_idx = torch.randint(len(train_dataset), size=(1,)).item()
    image, label = train_dataset[rand_idx]
    ax.set_title(label)
    ax.axis("off")
    ax.imshow(image.squeeze(), cmap='gray')
plt.show()

In [None]:
# Define the Multi Layer Perceptron (MLP) class
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    ``hidden`` is an ``nn.Sequential`` whose first entry is the input ``Linear``
    layer, so its parameters appear in the state dict as ``hidden.0.weight`` and
    ``hidden.0.bias`` (the plotting helpers read them by these keys). ``out``
    maps the hidden features to one score per class. The last dimension of the
    input must equal ``input_size``; callers flatten the images beforehand.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        first_layer = nn.Linear(input_size, hidden_size)
        self.hidden = nn.Sequential(first_layer, nn.ReLU())
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for the batch ``x``."""
        return self.out(self.hidden(x))


# Layer sizes: flattened 28x28 images in, 10 digit classes out.
input_dim = 28 * 28
hidden_dim = 32  # alternatives tried: 130, 512
output_dim = 10

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple-silicon GPU (MPS), then CPU.
# getattr guards PyTorch builds that have no torch.backends.mps attribute at all.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Seed the global RNG so weight initialisation (and later random draws such as
# DataLoader shuffling) repeat from run to run.
SEED = 0
torch.manual_seed(SEED)

# Create Model
model = MLP(input_dim, hidden_dim, output_dim).to(device)
print('The model: \n', model)

In [None]:
# Helper functions to plot the first-layer weights
def _subplot_rows(n_panels, n_cols):
    """Return how many grid rows are needed for ``n_panels`` panels in ``n_cols`` columns (at least 1)."""
    return max(1, int(np.ceil(n_panels / n_cols)))


def PlotParameters(model, hiddenDim, vlim=1.0):
    """Show each hidden neuron's input weights as a 28x28 image, 10 per row.

    Parameters
    ----------
    model : nn.Module
        Model whose ``state_dict`` holds ``'hidden.0.weight'``; each row of it
        is reshaped to 28x28.
    hiddenDim : int
        Number of hidden neurons (weight rows) to plot.
    vlim : float or None, optional
        Colour limits are ``[-vlim, vlim]``. The default 1.0 keeps plots made at
        different training stages comparable. Pass ``None`` to scale to the
        largest absolute weight instead, which makes small weights visible.
    """
    W = model.state_dict()['hidden.0.weight'].data.cpu()
    if vlim is None:
        vlim = W.abs().max().item()
    n_cols = 10
    n_rows = _subplot_rows(hiddenDim, n_cols)
    # Height follows the rounded-up row count, so fewer than 10 neurons or a
    # partial last row still get a properly sized figure.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 3), squeeze=False)
    fig.subplots_adjust(hspace=0.01, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        if i < hiddenDim:
            ax.set_xlabel(f"neuron: {i}")
            Img = W[i, :].view(28, 28)
            ax.imshow(Img, vmin=-vlim, vmax=vlim, cmap='seismic')
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Hide the unused panels of the last row.
            ax.axis("off")
    plt.show()

PlotParameters(model, hidden_dim)  # first-layer weights right after initialisation, before training

In [None]:
# Define the learning rate, optimizer, loss criterion and data loaders
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Alternative configuration to experiment with (plain SGD with a larger step):
#   learning_rate = 0.1
#   optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
loss_func = nn.CrossEntropyLoss()
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
# Validation only measures performance; a fixed order keeps the evaluation deterministic.
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

In [None]:
# Training function
def train(dataloader, model, loss_func, optimizer):
    """Run one training epoch and record its metrics.

    Uses the notebook globals ``device`` (batches are moved there) and
    ``input_dim`` (inputs are reshaped to ``(-1, input_dim)``). Appends the
    epoch accuracy to ``train_accuracy_list`` and the mean per-batch loss to
    ``loss_list``.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Yields ``(X, y)`` batches.
    model : nn.Module
        Network to train; switched to training mode.
    loss_func : callable
        Criterion returning a scalar batch loss, e.g. ``nn.CrossEntropyLoss()``.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.

    Returns
    -------
    tuple[float, float]
        ``(mean per-batch loss, accuracy)``: the values appended to the lists.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()
    train_loss, correct = 0.0, 0
    for batchNr, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(X.view(-1, input_dim))
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()
        correct += (pred.argmax(dim=1) == y).sum().item()
        # Accumulate a plain float: summing the loss tensor itself would keep
        # chaining every batch's autograd history for the whole epoch.
        batch_loss = loss.item()
        train_loss += batch_loss
        if (batchNr + 1) % 100 == 0:
            current = (batchNr + 1) * len(y)
            print(f"loss: {batch_loss:>7f} [{current:>5d}/{size:>5d}]")
    # Average over batches, as validate() does, so training and validation
    # losses share a scale (dividing by the sample count made the training loss
    # batch_size times too small).
    avg_loss = train_loss / num_batches
    accuracy = correct / size
    train_accuracy_list.append(accuracy)
    loss_list.append(avg_loss)
    return avg_loss, accuracy

# Validation function
def validate(dataloader, model, loss_func):
    """Evaluate ``model`` on every batch of ``dataloader`` and record the metrics.

    Uses the notebook globals ``device`` and ``input_dim``. Appends the accuracy
    to ``val_accuracy_list`` and the mean per-batch loss to ``val_loss_list``,
    then prints both.
    """
    model.eval()
    n_batches = len(dataloader)
    n_samples = len(dataloader.dataset)
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():  # inference only: no autograd bookkeeping
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs.view(-1, input_dim))
            loss_sum += loss_func(logits, targets)
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
    mean_loss = loss_sum / n_batches
    accuracy = n_correct / n_samples
    val_accuracy_list.append(accuracy)
    val_loss_list.append(mean_loss.item())
    print(f"Validation Error: \n Accuracy: {(100*accuracy):>0.1f} Avg loss: {(mean_loss.item()):>8f}")

In [None]:
# Training loop: one training pass and one validation pass per epoch.
# The metric lists are reset here, but the model and optimizer are not, so
# re-running only this cell continues training from the current weights.
n_epochs = 15
loss_list = []
train_accuracy_list = []
val_loss_list = []
val_accuracy_list = []
N_train = len(train_dataset)
N_test = len(validation_dataset)
print(f"Training samples: {N_train}, validation samples: {N_test}")

for t in range(n_epochs):
    print(f"n_epoch {t+1}\n-------------------------------")
    train(train_loader, model, loss_func, optimizer)
    validate(validation_loader, model, loss_func)
print("Done!")

In [None]:
# Plot the per-epoch loss curves (left axis) and accuracy curves (right axis)
fig, ax1 = plt.subplots()
loss_color = 'tab:red'
ax1.plot(loss_list, color=loss_color, label='train loss')
# Validation loss belongs on the loss axis, not on the accuracy axis.
ax1.plot(val_loss_list, color='tab:purple', label='validation loss')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss', color=loss_color)
ax1.tick_params(axis='y', labelcolor=loss_color)  # `color=` would recolour only the tick marks

ax2 = ax1.twinx()
acc_color = 'tab:blue'
ax2.plot(train_accuracy_list, color=acc_color, label='train accuracy')
ax2.plot(val_accuracy_list, color='tab:orange', label='validation accuracy')
ax2.set_ylabel('accuracy', color=acc_color)
ax2.tick_params(axis='y', labelcolor=acc_color)

# One legend listing the lines of both axes (drawn on ax2, which sits on top).
ax2.legend(handles=[*ax1.get_lines(), *ax2.get_lines()], loc='center right')
fig.tight_layout()
plt.show()

# First-layer weights after training, for comparison with the plot made before training
PlotParameters(model, hidden_dim)

In [None]:
# Plot the data
def show_data(data_sample):
    """Draw one ``(image, label)`` sample as a grey-scale image titled ``y = <label>``."""
    image = data_sample[0]
    label = data_sample[1]
    pixels = image.numpy().squeeze()
    plt.imshow(pixels, cmap='gray')
    plt.title('y = ' + str(label))

Softmax_fn = nn.Softmax(dim=1)

# Show the first few validation images the model classifies wrongly,
# with the predicted class and its softmax probability.
n_misclassified_to_show = 5
count = 0
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for X, y in validation_dataset:
        z = model(X.to(device).reshape(-1, input_dim))
        yhat = int(z.argmax(dim=1).item())
        if yhat != y:
            show_data((X, y))
            plt.show()
            print("yhat:", yhat)
            print("probability of class ", torch.max(Softmax_fn(z)).item())
            count += 1
            if count >= n_misclassified_to_show:
                break

In [None]:
# Inspect one validation sample: its prediction, its hidden-layer activations, and
# the first-layer weight images summed with those activations as coefficients.
sample_idx = 63
X, y = validation_dataset[sample_idx]
with torch.no_grad():  # inference only
    z = model(X.to(device).reshape(-1, input_dim))
yhat = int(z.argmax(dim=1).item())
show_data((X, y))
plt.show()
print("yhat:", yhat)
print("z:", z)
print("probability of class ", torch.max(Softmax_fn(z)).item())
print(X.shape)
X = X.reshape(-1, input_dim)
print(X.shape)
with torch.no_grad():
    hidden_out = model.hidden(X.to(device))

print(hidden_out)

reconstrWeights = hidden_out.cpu()  # shape (1, hidden_dim)
# Named act_min/act_max so the builtins min()/max() stay usable in later cells.
act_min, act_max = reconstrWeights.min(), reconstrWeights.max()
print(f'min={act_min!r} max={act_max!r}')

W = model.state_dict()['hidden.0.weight'].data.cpu()  # shape (hidden_dim, input_dim)
# sum_i activation_i * weight_image_i, computed as a single matrix product
Img = (reconstrWeights @ W).view(28, 28)
# Symmetric limits put zero at the white centre of the diverging colormap;
# fall back to 1.0 if every activation is zero.
lim = Img.abs().max().item() or 1.0
plt.imshow(Img, cmap='seismic', vmin=-lim, vmax=lim)
plt.title(f"activation-weighted weights, sample {sample_idx}")
plt.show()

In [None]:
# Feed the sample's hidden activations through the output layer
classValues = model.out(hidden_out)
print(f'{classValues=}')
# Only a cell's last expression is displayed, so print the probabilities explicitly.
print(Softmax_fn(classValues))

# nn.Linear stores its weight as (out_features, in_features): one row per class.
W_out = model.state_dict()['out.weight'].data.cpu()
b_out = model.state_dict()['out.bias'].data.cpu()
print(W_out)
print(W_out[3,:])  # weights from every hidden neuron into output class 3

def _output_templates(W, W_out, hiddenDim):
    """Combine first-layer weight images into one 28x28 image per output neuron.

    Row ``i`` of the result is ``sum_j W_out[i, j] * W[j]`` over the first
    ``hiddenDim`` hidden neurons, i.e. ``W_out[:, :hiddenDim] @ W[:hiddenDim]``
    reshaped to ``(n_out, 28, 28)``. ReLU and biases are not taken into account.
    """
    return (W_out[:, :hiddenDim] @ W[:hiddenDim]).view(-1, 28, 28)


def PlotOutputLayer(model, hiddenDim, outDim):
    """Plot, per output neuron, its output weights applied to the first-layer weight images.

    Parameters
    ----------
    model : nn.Module
        Model whose ``state_dict`` holds ``'hidden.0.weight'`` and ``'out.weight'``.
    hiddenDim : int
        Number of hidden neurons included in each combination.
    outDim : int
        Number of output neurons to plot, one panel each.
    """
    W = model.state_dict()['hidden.0.weight'].data.cpu()
    W_out = model.state_dict()['out.weight'].data.cpu()
    templates = _output_templates(W, W_out, hiddenDim)
    # One panel per output neuron rather than a fixed 10-column grid;
    # squeeze=False keeps `axes` an array even when outDim == 1.
    fig, axes = plt.subplots(1, outDim, figsize=(2 * outDim, 3), squeeze=False)
    fig.subplots_adjust(hspace=0.01, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_xlabel(f"neuron: {i}")
        ax.imshow(templates[i], vmin=-1, vmax=1, cmap='seismic')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

# Compare the first-layer weight images with their per-class combinations
PlotParameters(model, hidden_dim)
PlotOutputLayer(model, hidden_dim, output_dim)

# A bare mid-cell expression is never displayed, so print the activations explicitly.
hidden_act = hidden_out.data.cpu()
print(hidden_act)

W_out = model.state_dict()['out.weight'].data.cpu()
print(W_out[3,:])
# Contribution of each hidden neuron to the class-3 logit
print(W_out[3,:] * hidden_act)
# Recompute the class probabilities by hand: softmax(h @ W_out^T + b_out)
Softmax_fn(hidden_act @ W_out.T + b_out)