In [3]:
# import packages
import os
import torch 
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from skimage.metrics import structural_similarity as ssim 

Utility functions: select the compute device, create the image output directory, and save reconstructed images


In [None]:
# utility functions
def get_device():
    """Pick the device string used for tensors and the model.

    Returns:
        str: 'cuda:0' when ``torch.cuda.is_available()`` reports True,
        otherwise 'cpu'.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'
def make_dir():
    """Create the 'MNIST_Images' directory under the current working directory.

    ``exist_ok=True`` makes the call a no-op when the directory is already
    present, which replaces the separate exists-then-create check (and the
    window between the two calls in which another process could create it).
    If a non-directory entry with that name exists, ``os.makedirs`` raises
    ``FileExistsError`` instead of the situation being silently ignored.
    """
    image_dir = 'MNIST_Images'
    os.makedirs(image_dir, exist_ok=True)
def save_decoded_image(img, epoch):
    """Write a batch of decoded images to disk as one PNG.

    Args:
        img: tensor that is viewed as ``(img.size(0), 1, 28, 28)`` before
            saving.
        epoch: value substituted into the output file name.

    The file is written to ``./MNIST_Images/``; this function does not create
    that directory itself.
    """
    batch_size = img.size(0)
    as_images = img.view(batch_size, 1, 28, 28)
    target_path = './MNIST_Images/linear_ae_image{}.png'.format(epoch)
    save_image(as_images, target_path)

In [None]:
# constants
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
SEED = 42

# Seed PyTorch's global RNG before any layer is built or any loader is
# iterated, so weight initialisation and batch shuffling repeat across
# Restart & Run All.
torch.manual_seed(SEED)

# image transformations: each dataset image is converted to a tensor
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
# MNIST train and test splits; files are downloaded into ./data when missing.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Training batches are reshuffled every epoch.
trainloader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)
# The evaluation loader keeps dataset order so the batch that gets
# reconstructed and scored is the same on every run.
testloader = DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False
)

In [2]:
# utility functions
def get_device():
    """Pick the device string used for tensors and the model.

    Returns:
        str: 'cuda:0' when ``torch.cuda.is_available()`` reports True,
        otherwise 'cpu'.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'
def make_dir():
    """Create the 'MNIST_Images' directory under the current working directory.

    ``exist_ok=True`` makes the call a no-op when the directory is already
    present, which replaces the separate exists-then-create check (and the
    window between the two calls in which another process could create it).
    If a non-directory entry with that name exists, ``os.makedirs`` raises
    ``FileExistsError`` instead of the situation being silently ignored.
    """
    image_dir = 'MNIST_Images'
    os.makedirs(image_dir, exist_ok=True)
def save_decoded_image(img, epoch):
    """Write a batch of decoded images to disk as one PNG.

    Args:
        img: tensor that is viewed as ``(img.size(0), 1, 28, 28)`` before
            saving.
        epoch: value substituted into the output file name.

    The file is written to ``./MNIST_Images/``; this function does not create
    that directory itself.
    """
    batch_size = img.size(0)
    as_images = img.view(batch_size, 1, 28, 28)
    target_path = './MNIST_Images/linear_ae_image{}.png'.format(epoch)
    save_image(as_images, target_path)

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder: 784 -> 256 -> 128 -> 64 -> 32 -> 16.

    A ReLU is applied after every linear layer, including the last one.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in enc1..enc5 order and keep their attribute
        # names, so state_dict keys are unchanged.
        self.enc1 = nn.Linear(784, 256)
        self.enc2 = nn.Linear(256, 128)
        self.enc3 = nn.Linear(128, 64)
        self.enc4 = nn.Linear(64, 32)
        self.enc5 = nn.Linear(32, 16)

    def forward(self, x):
        """Run ``x`` through enc1..enc5, applying ReLU after each layer."""
        stages = (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5)
        for stage in stages:
            x = F.relu(stage(x))
        return x


## The encoder outputs a 16-dimensional vector and the decoder expects a 16-dimensional input, so the channel's input and output dimensions are both 16


In [1]:
# Sizes passed to the Channel constructor (both 16).
input_dim, output_dim = 16, 16

In [None]:
class Channel(nn.Module):
    """Two linear maps applied back to back: serialize, then deserialize.

    Args:
        input_dim: feature size accepted by ``serialize`` and produced by
            ``deserialize``.
        output_dim: feature size between the two layers.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.serialize = nn.Linear(in_features=input_dim, out_features=output_dim)
        self.deserialize = nn.Linear(in_features=output_dim, out_features=input_dim)

    def forward(self, x):
        """Return ``deserialize(serialize(x))``."""
        transmitted = self.serialize(x)
        # Placeholder: noise n is meant to be injected at this point; nothing
        # is added yet.
        received = self.deserialize(transmitted)
        return received


In [None]:
class Decoder(nn.Module):
    """Fully connected decoder: 16 -> 32 -> 64 -> 128 -> 256 -> 784.

    A ReLU is applied after every linear layer, including the output layer.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in dec1..dec5 order and keep their attribute
        # names, so state_dict keys are unchanged.
        self.dec1 = nn.Linear(16, 32)
        self.dec2 = nn.Linear(32, 64)
        self.dec3 = nn.Linear(64, 128)
        self.dec4 = nn.Linear(128, 256)
        self.dec5 = nn.Linear(256, 784)

    def forward(self, x):
        """Run ``x`` through dec1..dec5, applying ReLU after each layer."""
        stages = (self.dec1, self.dec2, self.dec3, self.dec4, self.dec5)
        for stage in stages:
            x = F.relu(stage(x))
        return x

Define an instance for the model combining these 3 components

In [None]:
# nn.Sequential takes module *instances*; passing the classes themselves
# raises a TypeError. Channel also needs its two dimension arguments.
model = nn.Sequential(
    Encoder(),
    Channel(input_dim, output_dim),
    Decoder(),
)

In [None]:
# Adam over every parameter of the combined model.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Mean-squared-error loss object used by the training loop.
criterion = nn.MSELoss()

Define train and test functions

In [None]:
def train(model, trainloader, NUM_EPOCHS):
    """Train ``model`` to reproduce its flattened input and record the loss.

    Relies on notebook-level names that must exist before this is called:
    ``device``, ``optimizer``, ``criterion`` and ``save_decoded_image``.

    Args:
        model: module called on each flattened image batch.
        trainloader: iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        NUM_EPOCHS: number of passes over ``trainloader``.

    Returns:
        list: one float per epoch, the mean of the per-batch losses.
    """
    # Make sure the model is in training mode even if it was switched to
    # eval mode earlier in the session.
    model.train()
    train_loss = []
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        for img, _ in trainloader:
            img = img.to(device)
            # Flatten every image to one row per sample.
            img = img.view(img.size(0), -1)
            optimizer.zero_grad()
            outputs = model(img)
            # The input itself is the reconstruction target.
            loss = criterion(outputs, img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Separate name so the per-batch loss tensor is not shadowed by a float.
        epoch_loss = running_loss / len(trainloader)
        train_loss.append(epoch_loss)
        print('Epoch {} of {}, Train Loss: {:.3f}'.format(epoch+1, NUM_EPOCHS, epoch_loss))
        if epoch % 5 == 0:
            # Save the last batch's reconstructions; detach() replaces the
            # legacy .data attribute.
            save_decoded_image(outputs.detach().cpu(), epoch)
    return train_loss


In [None]:
def test_image_reconstruction(model, testloader, device, max_images=8):
    """Reconstruct the first batch of ``testloader``, plot it and save it.

    Originals and reconstructions are shown in a two-row figure with the SSIM
    of each pair in the title, and the whole reconstructed batch is written
    to 'mnist_reconstruction.png'.

    Args:
        model: module called on the flattened image batch.
        testloader: iterable yielding ``(images, labels)`` pairs; only the
            first batch is used and the labels are ignored.
        device: device the image batch is moved to before the forward pass.
        max_images: positive upper bound on the number of image pairs
            plotted, so a large batch does not produce hundreds of unreadable
            subplots. ``None`` plots the whole batch. The saved PNG always
            contains every reconstruction.

    Returns:
        None. Returns immediately when ``testloader`` yields no batch.
    """
    first_batch = next(iter(testloader), None)
    if first_batch is None:
        return

    img, _ = first_batch
    img = img.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(img.view(img.size(0), -1))
    outputs = outputs.view(outputs.size(0), 1, 28, 28).cpu()

    img_np = img.cpu().numpy()
    outputs_np = outputs.numpy()

    batch_size = img_np.shape[0]
    n_show = batch_size if max_images is None else min(max_images, batch_size)

    # SSIM is only needed for the pairs that are actually displayed.
    ssim_values = []
    for i in range(n_show):
        original_img = img_np[i, 0]
        reconstructed_img = outputs_np[i, 0]
        ssim_values.append(
            ssim(
                original_img,
                reconstructed_img,
                data_range=original_img.max() - original_img.min(),
            )
        )

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing also works for a one-image batch.
    fig, axes = plt.subplots(2, n_show, figsize=(15, 4), squeeze=False)
    for i in range(n_show):
        axes[0, i].imshow(img_np[i, 0], cmap='gray')
        axes[0, i].set_title('Original')
        axes[0, i].axis('off')

        axes[1, i].imshow(outputs_np[i, 0], cmap='gray')
        axes[1, i].set_title(f'Reconstructed\nSSIM: {ssim_values[i]:.4f}')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

    save_image(outputs, 'mnist_reconstruction.png')

# Example usage (assuming `model`, `testloader`, and `device` are defined):
# test_image_reconstruction(model, testloader, device)


## Training and testing the Model

In [None]:
# get the computation device
device = get_device()
print(device)
# load the neural network onto the device
model.to(device)

# make sure the output directory for the per-epoch images exists
make_dir()

# train the network
train_loss = train(model, trainloader, NUM_EPOCHS)

# plot the mean training loss of every epoch and save the figure
fig, ax = plt.subplots()
ax.plot(train_loss)
ax.set_title('Train Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
fig.savefig('deep_ae_mnist_loss.png')

# test the network; `device` is a required argument of
# test_image_reconstruction and was missing from this call
test_image_reconstruction(model, testloader, device)

## Plot the spectrogram

## Apply filters on evaluation

## Apply filters during training and see whether the model's behavior changes: whether it improves and can "understand" the introduction of noise in the channel. This is essentially the result we are after.
