In [76]:
###CNN and MLP primitive

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## The encoder in the paper takes x in R^(TxHxW) and y in R^(TxE) and maps them to R^(TxL). The dimensionality L of both the
## encoder and the context encoder is set to 5, so this CNN needs to output (32x5), i.e. each surface in R^(5x5) -> z in R^5.

## We first raise the dimensionality of the IV surface by increasing the number of channels. The reasoning is similar to why
## this is done in transformer architectures: a higher-dimensional space can capture more nuanced information and represent it numerically.
## We then compress this to something digestible.
class CNN(nn.Module):
    """Convolutional feature extractor that maps each surface to a 5-dimensional vector.

    Three 3x3 convolutions (stride 1, padding 1, so height and width are
    preserved) are followed by a linear layer. The linear layer is sized for
    5x5 surfaces and always produces 5 output features.

    Args:
        input_size: number of channels in the input data (e.g. RGB = 3 channels).
        output_size: number of channels produced by each convolution.
    """

    def __init__(self, input_size, output_size):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_size, output_size, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(output_size, output_size, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(output_size, output_size, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(output_size * 5 * 5, 5)

    def forward(self, x):
        """Encode a stack of surfaces.

        Args:
            x: tensor whose last two dimensions are the surface grid, e.g.
                (N, H, W) or (N, C, H, W) with C == input_size.

        Returns:
            Tensor of shape (N, 5).
        """
        # Derive the layout from the input itself instead of the notebook-level
        # globals batch_size / H / W, so the module works for any number of
        # surfaces and does not depend on hidden kernel state.
        x = x.reshape(-1, self.conv1.in_channels, x.shape[-2], x.shape[-1])
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Flatten (channels, H, W) into one feature vector per surface.
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, y):
        """Return fc2(relu(fc1(y))); the last dimension of y must be input_size."""
        hidden = F.relu(self.fc1(y))
        # The original forward stopped after fc1 and implicitly returned None,
        # leaving fc2 unused; apply the output layer and return the result.
        return self.fc2(hidden)

class TCNN(nn.Module):
    """Fully connected head that expands a feature vector into 5x5 surfaces.

    Layer widths are input_size -> 128 -> 64 -> 5 * 5 * num_surfaces, with a
    ReLU after each of the first two layers.

    Args:
        input_size: number of input features.
        output_size: stored on the module; not used in the forward pass.
        num_surfaces: number of 5x5 surfaces produced per input row.
    """

    def __init__(self, input_size, output_size, num_surfaces):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 5 * 5 * num_surfaces)
        self.output_size = output_size
        self.num_surfaces = num_surfaces

    def forward(self, x):
        """Return a tensor of shape (-1, num_surfaces, 5, 5)."""
        first = F.relu(self.fc1(x))
        second = F.relu(self.fc2(first))
        flat_surfaces = self.fc3(second)
        # Unflatten the last layer's output into num_surfaces grids of 5x5.
        return flat_surfaces.reshape(-1, self.num_surfaces, 5, 5)

In [77]:
###ENCODER DECODER CONTEXTENCODER

class Encoder(nn.Module):
    """Encoder: CNN surface features + extra features -> LSTM -> (z, mu, log_var).

    Args:
        input_size: number of input channels handed to the CNN.
        hidden_size: hidden width of the two-layer LSTM.
        latent_size: dimensionality of mu, log_var and the sampled z.
    """

    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.cnn = CNN(input_size, 5)
        # Extra features are passed through unchanged.
        self.mlp = nn.Identity()
        # LSTM input width: 5 CNN features plus 3 extra features per step.
        self.lstm = nn.LSTM(
            5 + 3,
            hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
        )
        self.linear_mu = nn.Linear(hidden_size, latent_size)
        # Despite its name, this head's output is used as the log-variance.
        self.linear_sigma = nn.Linear(hidden_size, latent_size)

    def forward(self, x, y):
        """Encode surfaces x and extra features y.

        Args:
            x: surfaces fed to the CNN.
            y: extra features; a size-1 dimension at index 1 is squeezed away
                (assumes (T, 1, 3) as in the export cell — TODO confirm).

        Returns:
            Tuple (z, mu, log_var), where z is sampled via reparameterize.
        """
        surface_features = self.cnn(x)
        extra_features = self.mlp(y).squeeze(dim=1)
        lstm_input = torch.cat((surface_features, extra_features), dim=-1)
        _, (h_n, _) = self.lstm(lstm_input)
        top_layer_state = h_n[-1]  # final hidden state of the last LSTM layer
        mu = self.linear_mu(top_layer_state)
        log_var = self.linear_sigma(top_layer_state)
        z = self.reparameterize(mu, log_var)
        return z, mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample mu + eps * exp(0.5 * log_var) with eps drawn from N(0, I)."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

class ContextEncoder(nn.Module):
    """Context encoder: CNN surface features + extra features -> LSTM -> zeta.

    Args:
        input_size: number of input channels handed to the CNN.
        hidden_size: hidden width of the two-layer LSTM.
        context_size: dimensionality of the returned context vector.
    """

    def __init__(self, input_size, hidden_size, context_size):
        super().__init__()
        self.cnn = CNN(input_size, 5)
        # Extra features are passed through unchanged.
        self.mlp = nn.Identity()
        # LSTM input width: 5 CNN features plus 3 extra features per step.
        self.lstm = nn.LSTM(
            5 + 3,
            hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
        )
        self.linear = nn.Linear(hidden_size, context_size)

    def forward(self, x_c, y_c):
        """Encode context surfaces x_c and context extra features y_c into zeta.

        A size-1 dimension at index 1 of y_c is squeezed away if present.
        """
        surface_features = self.cnn(x_c)
        extra_features = self.mlp(y_c).squeeze(dim=1)
        lstm_input = torch.cat((surface_features, extra_features), dim=-1)
        _, (h_n, _) = self.lstm(lstm_input)
        top_layer_state = h_n[-1]  # final hidden state of the last LSTM layer
        return self.linear(top_layer_state)

class Decoder(nn.Module):
    """Decoder: concatenated (z, zeta) -> LSTM -> surface head and scalar head.

    Args:
        latent_size: number of elements in z.
        hidden_size: hidden width of the two-layer LSTM.
        output_size: forwarded to TCNN.
        num_surfaces: number of 5x5 surfaces produced by the TCNN head.
    """

    def __init__(self, latent_size, hidden_size, output_size, num_surfaces):
        super().__init__()
        # The "+ 5" is the hard-coded width of the context vector zeta.
        self.lstm = nn.LSTM(
            latent_size + 5,
            hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
        )
        self.tcnn = TCNN(hidden_size, output_size, num_surfaces)
        self.mlp = nn.Linear(hidden_size, 1)

    def forward(self, z, zeta):
        """Decode a latent vector z and a context vector zeta.

        Returns:
            Tuple (x_n, r_n): the TCNN output and the linear-head output, each
            with a leading size-1 dimension squeezed away.
        """
        # Flatten both inputs into single rows and join them side by side.
        z_row = z.reshape(1, -1)
        zeta_row = zeta.reshape(1, -1)
        joint = torch.cat((z_row, zeta_row), dim=1)

        # First return value of the LSTM is its per-step output sequence.
        lstm_out, _ = self.lstm(joint)
        x_n = self.tcnn(lstm_out)
        r_n = self.mlp(lstm_out)
        return x_n.squeeze(dim=0), r_n.squeeze(dim=0)

In [78]:
# Hyperparameters
input_size = 1
hidden_size = 100
latent_size = 5
context_size = 5
output_size = 1
num_surfaces = 1
ttm = 10  # Number of time steps to generate

batch_size = 31  # number of surfaces fed through the encoders at once
H, W = 5, 5      # surface grid size; the CNN's linear layer is sized for 5x5

# NOTE(review): CVAE is not defined in the visible cells -- make sure its
# definition has been run before this cell.
# NOTE(review): the fifth argument is input_size; output_size (same value, 1)
# is defined above but unused here -- confirm against CVAE's signature.
model = CVAE(input_size, hidden_size, latent_size, context_size, input_size, num_surfaces)

# Load the saved model state dictionary.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('best_model.pth', map_location="cpu"))
print("Model loaded successfully.")

# Set the model to evaluation mode (disables the LSTM dropout) and show its structure
print(model.eval())

encoder = model.encoder
decoder = model.decoder
context_encoder = model.context_encoder

# Dummy encoder inputs for the ONNX export
x = torch.randn(batch_size, 1, H, W)
y = torch.randn(batch_size, 1, 3)

# Export the encoder to ONNX
torch.onnx.export(encoder, (x, y), "encoder.onnx", opset_version=11, input_names=['x', 'y'], output_names=['z', 'mu', 'log_var'])

# Dummy decoder inputs for the ONNX export
z = torch.randn(1, latent_size)
zeta = torch.randn(1, context_size)

# Export the decoder to ONNX
torch.onnx.export(decoder, (z, zeta), "decoder.onnx", opset_version=11, input_names=['z', 'zeta'], output_names=['x_n', 'r_n'])

# Dummy context-encoder inputs for the ONNX export
x_c = torch.randn(batch_size, H, W)
y_c = torch.randn(batch_size, 3)

# Export the context encoder to ONNX
torch.onnx.export(context_encoder, (x_c, y_c), "context_encoder.onnx", opset_version=11, input_names=['x_c', 'y_c'], output_names=['zeta'])

Model loaded successfully.
CVAE(
  (encoder): Encoder(
    (cnn): CNN(
      (conv1): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fc): Linear(in_features=125, out_features=5, bias=True)
    )
    (mlp): Identity()
    (lstm): LSTM(8, 100, num_layers=2, batch_first=True, dropout=0.2)
    (linear_mu): Linear(in_features=100, out_features=5, bias=True)
    (linear_sigma): Linear(in_features=100, out_features=5, bias=True)
  )
  (context_encoder): ContextEncoder(
    (cnn): CNN(
      (conv1): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fc): Linear(in_features=125, out_features=5, bias=True)
    )
    (mlp): Identity()
    (ls

In [None]:
import torch
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

# Dummy encoder inputs (x, y) created in the ONNX export cell; reused here for
# tracing and for the calibration pass below.
example_inputs = (x, y)

##static quantisation
# Work on a copy so the original encoder is left untouched.
encoder_to_quantise = copy.deepcopy(encoder)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
encoder_to_quantise.eval()
# prepare: trace the model and insert observers
encoder_prepared = quantize_fx.prepare_fx(encoder_to_quantise, qconfig_mapping, example_inputs)
# calibrate: the observers need at least one forward pass before convert_fx,
# otherwise they hold no statistics to derive quantisation parameters from.
# NOTE(review): example_inputs is random data -- replace with representative
# samples for meaningful scales / zero points.
with torch.no_grad():
    encoder_prepared(*example_inputs)
# quantize
model_quantized = quantize_fx.convert_fx(encoder_prepared)