# 1.0 Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
from torch.autograd import Variable

import numpy as np

import os

import matplotlib.pyplot as plt

# 2.0 General

In [None]:
# Report whether PyTorch can see a CUDA device; the trailing '\n' argument
# leaves an empty line after the message.
cuda_ready = torch.cuda.is_available()
print('Cuda is available:', cuda_ready, '\n')

# 3.0 Hyperparameters

In [None]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Width of the hidden layers of the classifier network.
hidden_dim = 128

# Number of samples per mini-batch handed out by the data loaders.
batch_size = 32

# Number of passes over the training set and the Adam learning rate.
num_epochs, learn_rate = 1, 0.0002


# 4.0 Load Data

In [None]:
# Make sure the dataset root exists. exist_ok=True replaces the separate
# existence check and avoids the check-then-create race.
folder_path = "../../data/"
os.makedirs(folder_path, exist_ok=True)

# Preprocessing shared by the train and test split: convert the image to a
# tensor, then normalize the single grayscale channel.
# NOTE(review): the commonly quoted MNIST mean is 0.1307, so 0.01307 looks like
# a typo. It is deliberately left unchanged because the pretrained encoder
# loaded later in this file may have been trained with exactly this value --
# confirm against the autoencoder training code before changing it.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.01307,), (0.3081,)),
])

# MNIST datasets (the training split triggers the download into folder_path).
train_dataset = torchvision.datasets.MNIST(
    root=folder_path, train=True, transform=mnist_transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root=folder_path, train=False, transform=mnist_transform
)

# Data loaders
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=True
)

# Preview: a 5x5 grid of randomly drawn training digits titled with their label.
figure = plt.figure(figsize=(10, 8))
spalten, zeilen = 5, 5  # columns, rows
for i in range(1, spalten * zeilen + 1):
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    image, label = train_dataset[sample_idx]
    # add_subplot expects (rows, columns, index).
    figure.add_subplot(zeilen, spalten, i)
    plt.title(label, weight='bold')
    plt.imshow(image.squeeze(), cmap='gray')
plt.subplots_adjust(wspace=0.3, hspace=0.6)
plt.show()


# 5.0 Classes

## 5.1 Class Auto Encoder

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened 28*28 inputs.

    The encoder narrows 784 -> 128 -> 64 -> 36 -> 18 -> 9 with ReLU between
    the linear layers; the decoder mirrors those widths back to 784 and ends
    in a Sigmoid, so every reconstructed value lies between 0 and 1.
    """

    def __init__(self):
        super().__init__()

        # Layer widths from the flattened input down to the 9-value code.
        widths = [28 * 28, 128, 64, 36, 18, 9]

        self.encoder = nn.Sequential(*self._dense_stack(widths))
        self.decoder = nn.Sequential(
            *self._dense_stack(widths[::-1]),
            nn.Sigmoid(),
        )

    @staticmethod
    def _dense_stack(widths):
        """Return Linear layers for consecutive widths with a ReLU between them."""
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            if modules:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(n_in, n_out))
        return modules

    def forward(self, x):
        """Encode ``x`` to the 9-value code and return its reconstruction."""
        return self.decoder(self.encoder(x))

### 5.1.1 Load AE Model

In [None]:
FILE_ModelAE = 'ae2.pth'

# The checkpoint is loaded as a complete pickled module, so the AutoEncoder
# class must already be defined when this runs.
# map_location remaps the stored tensors onto the selected device; without it a
# checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects -- only load files you trust.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled modules; pass weights_only=False there.
ModelAE = torch.load(FILE_ModelAE, map_location=device)

ModelAE = ModelAE.to(device)

# Inference mode for the pretrained autoencoder.
ModelAE.eval()

# Only the encoder half is reused as a feature extractor for the classifier.
trained_encoder = ModelAE.encoder


## 5.2 Class Neural Network

In [None]:
class MNIST_De_Class(nn.Module):
    """Fully connected classifier mapping a 9-value code to 10 class scores.

    The network is Linear(9, hidden_dim), two Linear(hidden_dim, hidden_dim)
    layers and Linear(hidden_dim, 10), with ReLU after each of the first three
    and a Sigmoid after the last one, so every score lies between 0 and 1.

    NOTE(review): the training code in this file feeds these outputs to
    nn.CrossEntropyLoss, which expects unnormalized logits. The trailing
    Sigmoid is kept here because the prediction plot relies on scores in
    [0, 1]; consider returning logits and normalizing only for display.
    """

    def __init__(self):
        super().__init__()

        # Input layer, two hidden layers, then the sigmoid-activated output.
        layers = [nn.Linear(9, hidden_dim), nn.ReLU()]
        for _ in range(2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, 10), nn.Sigmoid()]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the ten per-class scores for the encoded input ``x``."""
        scores = self.network(x)
        return scores

#  6.0 Function: View Predictions

In [None]:
def view_classifications(img, pred):
    """Show an image next to a horizontal bar chart of its ten class scores.

    Args:
        img: Tensor holding exactly 28*28 values (e.g. shaped (1, 28, 28) or
            (784,)). It is not modified.
        pred: Tensor with one score per class (10 values after squeezing).
            The x-axis is fixed to [0, 1.1], so scores are expected in [0, 1].

    Raises:
        RuntimeError: if ``img`` cannot be reshaped to 28x28.
    """
    # detach() drops any autograd history so the values can be handed to numpy.
    pred = pred.detach().cpu().numpy().squeeze()
    # reshape() leaves the caller's tensor untouched and fails loudly on a size
    # mismatch, unlike the in-place resize_().
    picture = img.detach().cpu().reshape(28, 28).numpy()

    fig, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    ax1.imshow(picture)
    bars = ax2.barh(np.arange(10), pred)

    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    ax2.set_yticklabels(np.arange(10))
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)

    # Annotate only the bars that are not below 0.3 to keep the chart readable.
    for bar, val in zip(bars, pred):
        if val < 0.3:
            continue
        ax2.text(val, bar.get_y() + bar.get_height() / 2, round(val, 2), va='center')

    plt.tight_layout()
    plt.show()
    # Close this specific figure so repeated calls do not pile up open figures.
    plt.close(fig)

# 7.0 Training

In [None]:
model1 = MNIST_De_Class().to(device)

# Load model / continue training
continue_training = False

FILE_Model1 = "model1.pth"

if continue_training:
    # map_location lets a checkpoint saved on another device load on this one.
    model1 = torch.load(FILE_Model1, map_location=device).to(device)

# Loss function
# NOTE(review): nn.CrossEntropyLoss expects unnormalized logits, but
# MNIST_De_Class ends in nn.Sigmoid; this limits how small the loss can get.
criterion = nn.CrossEntropyLoss()

# Optimizer: only the classifier's parameters are updated; the encoder stays fixed.
optimizer = optim.Adam(params=model1.parameters(), lr=learn_rate)

for epoch in range(num_epochs):
    for batch_id, (Bild, Label) in enumerate(train_loader):
        Bild = Bild.reshape(-1, 28 * 28).to(device)  # [batch, 1, 28, 28] -> [batch, 784]
        Label = Label.to(device)

        # The pretrained encoder is not trained here, so skip autograd for it.
        # Otherwise every backward pass would accumulate gradients on the
        # encoder's parameters, which are never zeroed.
        with torch.no_grad():
            outputAE = trained_encoder(Bild)

        output = model1(outputAE)
        loss = criterion(output, Label)

        if batch_id % 100 == 0:
            print('Batch Nr. in Epoche', batch_id+1, '// Loss: %.5f ' % loss.item())

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

    # Loss of the last mini-batch of this epoch.
    print('Epoche:', epoch+1, 'Loss: %.5f ' % loss.item())

    # Every second epoch, plot the first sample of the last batch with its scores.
    if epoch % 2 == 0:
        view_classifications(Bild[0].view(1, 28, 28), output[0])

# 8.0 Save Model Neural Network

In [None]:
# torch.save(model1,FILE_Model1)

# Confusion Matrix

In [None]:
# label_list = torch.zeros(batch_size)
# pred_list = torch.zeros(batch_size)

# print(label_list, pred_list)

# with torch.no_grad():
#     for data,target in test_dataset:
#         data = data.reshape(-1,28,28).to(device)
#         target = target.to(device)

#         outconf= model1(data)

        

# Evaluation

In [None]:
# model1 = torch.load(FILE)

# with torch.no_grad():
#     for batch_id, (data,target) in enumerate(test_loader):
#         data = data.reshape(-1,28*28).to(device)  #[100, 1,28,28] > [100,734] ?
#         target = target.to(device)
        
#         output = model1(data) #was macht out?
#         loss = criterion(output,target) #hier out?!

#         view_classifications(data[0].view(1,28,28),output[0])
#         print('Batch in Epoche ',batch_id+1, 'Loss: %.5f ' % loss)

#     print('Epoche:',epoch+1, 'Loss: %.5f ' % loss)
