### Classification model based on VAE-compressed images 

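**Step 1**: Load the PCam data and set up the encoder. Each split (train/test/validation) is divided into two halves: the `*_a` halves (images only) are used for the autoencoder, and the `*_b` halves (image/label pairs) are used for the classifier.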

In [1]:
# Tensorflow 
import tensorflow_datasets as tfds 
import tensorflow as tf 

# Additional
from matplotlib import pyplot as plt 
import numpy as np
from sklearn.metrics import mean_squared_error
import random

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Options shared by every `tfds.load` call in this cell.
pcam_load_kwargs = dict(data_dir='./Data/PCAM/',
                        download=False,
                        shuffle_files=True,
                        read_config=tfds.ReadConfig(shuffle_seed=42))

# Full official splits (kept for anything that needs the complete split, e.g. its length).
d1, d2, d3 = tfds.load('patch_camelyon', split=['train', 'test', 'validation'], **pcam_load_kwargs)


def _load_halves(split):
    """Return the first and second halves of a PCam split as two disjoint datasets.

    The halves are selected with TFDS split slicing rather than `take`/`skip` on the
    shuffled dataset. With `shuffle_files=True` and the default ReadConfig, the file
    order may be reshuffled for every new iterator. `d.take(n)` and `d.skip(n)` could
    then see different orderings, and the two "halves" could overlap.
    """
    return tfds.load('patch_camelyon',
                     split=[f'{split}[:50%]', f'{split}[50%:]'],
                     **pcam_load_kwargs)


# Training data
d1_a, d1_b = _load_halves('train')

# Test data
d2_a, d2_b = _load_halves('test')

# Validation data
d3_a, d3_b = _load_halves('validation')

In [12]:
def convert_sample(sample):
    """Turn one raw PCam record into an ``(image, label)`` pair.

    Args:
        sample: a record with ``'image'`` and ``'label'`` entries.

    Returns:
        The image converted to ``tf.float32`` via ``tf.image.convert_image_dtype``,
        and the label one-hot encoded over 2 classes as ``tf.float32``.
    """
    float_image = tf.image.convert_image_dtype(sample['image'], tf.float32)
    one_hot_label = tf.one_hot(sample['label'], 2, dtype=tf.float32)
    return float_image, one_hot_label

# a = autoencoder, b = classifier.
# The autoencoder datasets keep only the float32 image from `convert_sample`.
# The classifier datasets keep (image, one-hot label) pairs.
# NOTE: this cell rebinds d*_a / d*_b in place. Re-running it without first re-running
# the loading cell would call `convert_sample` on already-converted batches and fail.
BATCH_SIZE = 128


def to_image_batches(ds, batch_size=BATCH_SIZE):
    """Map raw PCam records to float32 images only, then batch them."""
    return ds.map(lambda sample: convert_sample(sample)[0]).batch(batch_size)


def to_labelled_batches(ds, batch_size=BATCH_SIZE):
    """Map raw PCam records to (image, one-hot label) pairs, then batch them."""
    return ds.map(convert_sample).batch(batch_size)


d1_a, d1_b = to_image_batches(d1_a), to_labelled_batches(d1_b)
d2_a, d2_b = to_image_batches(d2_a), to_labelled_batches(d2_b)
d3_a, d3_b = to_image_batches(d3_a), to_labelled_batches(d3_b)

In [13]:
# Define the convolutional encoder architecture.
# NOTE(review): this cell only builds the layers. It does not load pretrained weights
# or set `encoder.trainable`. TODO: load the trained VAE encoder weights and freeze them
# before using the encoder for feature extraction.

# TODO: the latent dimension is planned to be changed to 18 later on.
latent_dim = 16

encoder = tf.keras.models.Sequential(
    # Stem: full-resolution 3x3 conv on 96x96 RGB input.
    [tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same', activation='relu',
                            input_shape=(96, 96, 3))]
    # Two stride-2 convs, each halving the spatial size.
    + [tf.keras.layers.Conv2D(filters, kernel_size=3, strides=2, padding='same', activation='relu')
       for filters in (128, 256)]
    # Head: 2 * latent_dim outputs, per the author's note one half for the mean and
    # one for the standard deviation of the latent distribution.
    + [tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(2 * latent_dim)]
)