In [27]:
import torch
from torchvision import transforms

import beacon

from torchinfo import summary

import matplotlib.pyplot as plt

from tqdm import tqdm

import ParabolaGen
import NoiseGen

In [28]:
# Dataset configuration.
x_dim = 92            # image width  (last axis of each image tensor)
y_dim = 120           # image height
time_dimension = 1000
num_of_files = 1      # not used by the visible cells
origin = None    # Set to None to have random origins, or input coordinates (x, y, ToF)
n_train = 128         # number of training images
n_test = 32           # number of test images
seed = 0

# Seeds torch's global RNG (this also fixes the model initialisation done later).
# NOTE(review): whether ParabolaGen/NoiseGen draw from torch's RNG is not visible here;
# if they use numpy/random instead, seed those too for a reproducible dataset.
torch.manual_seed(seed)


def add_noise(clean, std=0.1, magnitude=1, p=0.3):
    """Return a noisy copy of ``clean``, clamped to [0, 1].

    ``clean`` has shape (n, 1, height, width). Gaussian noise (mean 0, ``std``) and
    NoiseGen binary noise (``magnitude``, ``p``) are generated at (n, height, width),
    summed, given a channel axis, and added to ``clean``. The shape is taken from
    ``clean`` so it can never drift from the configured image size.
    """
    n, _, height, width = clean.shape
    noise = (
        NoiseGen.generate_gaussian_noise(n, height, width, mean=0, std=std)
        + NoiseGen.generate_binary_noise(n, height, width, magnitude=magnitude, p=p)
    )
    return torch.clamp(noise.unsqueeze(1) + clean, 0, 1)


# Clean targets, shaped (n, 1, y_dim, x_dim) after adding the channel axis.
train_data = ParabolaGen.generate_parabola(x_dim, y_dim, time_dimension, n_train, origin).unsqueeze(1)
test_data = ParabolaGen.generate_parabola(x_dim, y_dim, time_dimension, n_test, origin).unsqueeze(1)

# Noisy inputs for the denoising task.
train_data_noisy = add_noise(train_data)
test_data_noisy = add_noise(test_data)

In [29]:
# Sanity-check the tensors line up as (batch, channel, height, width) before training.
assert train_data.shape[1:] == (1, y_dim, x_dim), train_data.shape
assert train_data_noisy.shape == train_data.shape, train_data_noisy.shape
assert test_data_noisy.shape == test_data.shape, test_data_noisy.shape
train_data.shape

torch.Size([128, 1, 120, 92])

In [None]:
# Visual check: one clean training target next to its noisy input.
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(train_data[0][0])
ax[0].set_title("Clean target")
ax[1].imshow(train_data_noisy[0][0])
ax[1].set_title("Noisy input")
plt.show()

In [None]:
class LinearAutoencoder(beacon.Module):
    """Fully connected autoencoder: input_features -> 128 -> 16 -> 128 -> input_features.

    Both halves are ``torch.nn.Sequential`` stacks with ReLU activations and 20% dropout
    after the first hidden layer; the decoder has no output activation. There is no
    Flatten layer, so inputs must already have ``input_features`` values in their last
    dimension (e.g. images flattened by the caller).
    """

    def __init__(self, input_features):
        super().__init__()
        nn = torch.nn
        hidden_size, latent_size, drop_rate = 128, 16, 0.2

        # Encoder: squeeze the input down to a 16-value latent code.
        self.encoder = nn.Sequential(
            nn.Linear(input_features, hidden_size),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_size, latent_size),
            nn.ReLU(),
        )
        # Decoder: expand the latent code back to input_features values.
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_size, input_features),
        )

    def forward(self, x):
        """Encode ``x`` to the latent code, then decode it back."""
        latent_code = self.encoder(x)
        return self.decoder(latent_code)


In [24]:
class ConvolutionAutoencoder(beacon.Module):
    """Convolutional denoising autoencoder for single-channel images.

    Every conv uses kernel 3, stride 1 and no padding, so each encoder Conv2d shrinks
    height and width by 2 and each decoder ConvTranspose2d grows them by 2. The output
    therefore has the same spatial size as the input. A final Sigmoid maps outputs
    into (0, 1).

    Shapes for a (N, 1, 120, 92) input:
        encoder: (N, 64, 118, 90) -> (N, 16, 116, 88) -> (N, 8, 114, 86)
        decoder: (N, 16, 116, 88) -> (N, 64, 118, 90) -> (N, 1, 120, 92)

    Note: the latent (8 x 114 x 86 = 78,432 values) is larger than the input
    (120 x 92 = 11,040 values). This is not a compressing bottleneck; any denoising
    comes from training noisy -> clean.
    """

    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, stride=1),   # (N, 64, H-2, W-2)
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 16, 3, stride=1),  # (N, 16, H-4, W-4)
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.Conv2d(16, 8, 3, stride=1),   # (N, 8, H-6, W-6)
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(8, 16, 3, stride=1),   # (N, 16, H-4, W-4)
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(16, 64, 3, stride=1),  # (N, 64, H-2, W-2)
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(64, 1, 3, stride=1),   # (N, 1, H, W)
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a (N, 1, H, W) batch to a reconstruction of the same shape."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [26]:
# Layer-by-layer output shapes and parameter counts for a single (1, 1, y_dim, x_dim) image,
# computed on the CPU so the summary does not depend on which accelerators are present.
# NOTE(review): the saved output shows this cell failing inside torchinfo; the traceback was
# not kept, so the root cause is unconfirmed — re-run after re-executing the class cell.
summary(ConvolutionAutoencoder(), input_size=(1, 1, y_dim, x_dim), device="cpu")

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Conv2d: 2]

In [None]:
# Prefer Apple's MPS backend when available; fall back to CPU so the notebook runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
autoencoder = ConvolutionAutoencoder().to(device)
loss_function = torch.nn.MSELoss()  # pixel-wise reconstruction error vs. the clean target
optimiser = torch.optim.Adam(autoencoder.parameters(), lr=0.003)

In [None]:
# Full-batch training: each epoch is one optimiser step over all training images.
autoencoder.train()  # undo any earlier eval() so BatchNorm uses batch statistics again
device = next(autoencoder.parameters()).device
noisy_inputs = train_data_noisy.to(device)   # moved once, not on every epoch
clean_targets = train_data.to(device)

loss_history = []
progress = tqdm(range(100))
for epoch in progress:
    optimiser.zero_grad()
    outputs = autoencoder(noisy_inputs)
    loss = loss_function(outputs, clean_targets)
    loss.backward()
    optimiser.step()
    loss_history.append(loss.item())
    progress.set_postfix(loss=loss_history[-1])

In [None]:
# Reconstruct the first test image. The model stays on its training device and only the
# prediction is copied to the CPU, so the training cell can still be re-run afterwards.
autoencoder.eval()  # BatchNorm layers use their running statistics in eval mode
device = next(autoencoder.parameters()).device

with torch.no_grad():
    # test_data_noisy[:1] keeps the batch axis: shape (1, 1, H, W).
    pred = autoencoder(test_data_noisy[:1].to(device)).cpu()

In [None]:
# Side-by-side comparison for the first test image: noisy input, clean target, model output.
panels = {
    "Noisy Input": test_data_noisy[0][0],
    "Original Image": test_data[0][0],
    "Reconstructed Image": pred[0][0],
}

fig, ax = plt.subplots(1, len(panels), figsize=(12, 4))
for panel_ax, (title, image) in zip(ax, panels.items()):
    panel_ax.imshow(image.cpu())  # no-op for CPU tensors; guards against device tensors
    panel_ax.set_title(title)
plt.show()