### DWT Preprocessing

In [None]:
# Install PyWavelets (imported below as `pywt`); the leading `!` runs the
# command in the notebook's shell rather than in Python.
!pip install pywavelets
# Mount Google Drive so the image folders under /content/drive are reachable.
# `google.colab` only exists inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import numpy as np
from PIL import Image
import pywt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim

In [None]:
def dwt_transform(image):
    """Run a single-level 2-D Haar DWT and pack the four sub-bands together.

    Args:
        image: Array-like of pixel values; the caller in this file passes a
            2-D grayscale image converted with ``np.array``.

    Returns:
        ``np.ndarray`` containing the approximation, horizontal-detail,
        vertical-detail and diagonal-detail sub-bands, in that order,
        stacked along ``axis=1``. For 2-D sub-bands of shape ``(h, w)`` the
        result has shape ``(h, 4, w)``: the sub-band index is the middle
        axis, not the leading one.
    """
    approx, details = pywt.dwt2(image, 'haar')
    horizontal, vertical, diagonal = details
    # axis=1 places the sub-band index between rows and columns; the training
    # and reconstruction loops permute it into the channel position later.
    return np.stack((approx, horizontal, vertical, diagonal), axis=1)

# Preprocessing
def preprocess_and_save_image(image_path, output_path):
    """Convert one image file to DWT coefficients and store them as ``.npy``.

    The image is read as 8-bit grayscale (PIL mode ``'L'``), transformed with
    ``dwt_transform`` and written with ``np.save``.

    Args:
        image_path: Path of the source image file.
        output_path: Destination path for the coefficient array. ``np.save``
            appends ``.npy`` if the name does not already end with it.

    Returns:
        The coefficient array that was written, so a caller can reuse it
        without reloading the file. Callers that ignore the return value are
        unaffected.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixels have been copied into the NumPy array.
    with Image.open(image_path) as source:
        pixels = np.array(source.convert('L'))
    dwt_image = dwt_transform(pixels)
    np.save(output_path, dwt_image)  # Saved as .npy
    return dwt_image

# Source images and the folder that receives one .npy coefficient file each.
original_image_folder = '/content/drive/MyDrive/UBI_images/amplitude_dynamic'
npy_output_folder = '/content/drive/MyDrive/UBI_images/train'

# exist_ok=True makes this a no-op when the folder is already there and
# avoids the separate exists()/makedirs() check-then-create sequence.
os.makedirs(npy_output_folder, exist_ok=True)

# Count the conversions done in this run; the output folder may already hold
# files from earlier runs, so its directory listing is not a reliable count.
processed_count = 0
for image_file in os.listdir(original_image_folder):
    # str.endswith accepts a tuple of suffixes; anything else is skipped.
    if not image_file.endswith(('.jpg', '.png')):
        continue
    input_image_path = os.path.join(original_image_folder, image_file)
    # Same base name as the source image, with the extension swapped for .npy.
    output_npy_path = os.path.join(npy_output_folder, f"{os.path.splitext(image_file)[0]}.npy")
    preprocess_and_save_image(input_image_path, output_npy_path)
    processed_count += 1

print(f"Processed and saved {processed_count} images as .npy files.")

### Autoencoder Architecture

In [None]:
class DWT_Autoencoder(nn.Module):
    """Convolutional autoencoder over 4-channel DWT coefficient maps.

    The encoder applies three stride-2 3x3 convolutions (4 -> 16 -> 64 -> 128
    channels), so each spatial dimension is roughly divided by eight. The
    decoder applies three stride-1 3x3 convolutions (128 -> 64 -> 16 -> 4
    channels) followed by a sigmoid, so it does not enlarge the feature map
    itself; the final bilinear resize restores the spatial size.

    Because of the sigmoid, every output value lies in (0, 1).
    NOTE(review): wavelet detail coefficients can be negative, which a
    sigmoid output cannot represent — confirm this is acceptable for the data.

    Args:
        output_size: Spatial size the decoder output is resized to, given as
            an int or a ``(height, width)`` pair. Defaults to ``(90, 90)``.
            Pass ``None`` to resize to the spatial size of whatever input is
            given to ``forward``.
    """

    def __init__(self, output_size=(90, 90)):
        super().__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 4, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        # Upsample layer to match output dimensions. It has no parameters, so
        # state_dict keys are the same whether or not it is created.
        self.output_size = output_size
        self.upsample = (
            nn.Upsample(size=output_size, mode='bilinear', align_corners=False)
            if output_size is not None
            else None
        )

    def forward(self, x):
        """Encode, decode and resize a batch shaped ``(N, 4, H, W)``.

        Returns:
            Tensor of shape ``(N, 4, *output_size)``, or ``(N, 4, H, W)``
            when the model was built with ``output_size=None``.
        """
        decoded = self.decoder(self.encoder(x))
        if self.upsample is not None:
            return self.upsample(decoded)
        # No fixed target: resize to the spatial size of this input.
        return nn.functional.interpolate(
            decoded, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
        )

# Smoke test: push a random batch through an untrained model and print the
# output size. With the 4-channel final convolution and the (90, 90) upsample
# target, the printed size is expected to be torch.Size([16, 4, 90, 90]).
dummy_input = torch.randn(16, 4, 90, 90) # Batch of 16, 4 channels, 90x90 images
model = DWT_Autoencoder()

output = model(dummy_input)
print(f"Output size: {output.size()}") # Check

### Training Data Preparation

In [None]:
class DWTAutoencoderDataset(Dataset):
    """Dataset that serves pre-computed DWT coefficient arrays from ``.npy`` files.

    Args:
        image_folder: Directory that is scanned (non-recursively) for files
            whose names end in ``.npy``.
        normalize: When true, each array is divided by its own maximum value
            before being returned.
    """

    def __init__(self, image_folder, normalize=True):
        self.image_folder = image_folder
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.image_paths = sorted(
            os.path.join(image_folder, f)
            for f in os.listdir(image_folder)
            if f.endswith('.npy')
        )
        self.normalize = normalize

    def __len__(self):
        """Return the number of ``.npy`` files found at construction time."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the array at position ``idx`` and return it as a float32 tensor.

        The tensor keeps the axis order stored in the file; no axes are moved
        here.
        """
        # image_paths only ever contains .npy files, so every item is loaded
        # straight from disk.
        dwt_image = self.preprocess_image(self.image_paths[idx])

        if self.normalize:
            # Scale so the largest value becomes 1. Negative values stay
            # negative, so the result is not guaranteed to lie in [0, 1].
            peak = np.max(dwt_image)
            # An all-zero array would otherwise turn into NaNs (0 / 0).
            if peak != 0:
                dwt_image = dwt_image / peak
        return torch.tensor(dwt_image, dtype=torch.float32)

    def preprocess_image(self, image_path):
        """Load one stored coefficient array from ``image_path``."""
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only point this dataset
        # at files you produced yourself; plain numeric arrays written by
        # np.save also load with allow_pickle=False.
        dwt_image = np.load(image_path, allow_pickle=True)
        return dwt_image

# Build the dataset over the same folder the preprocessing cell writes its
# .npy files to, with per-array max normalisation enabled.
image_folder = '/content/drive/MyDrive/UBI_images/train'
train_dataset = DWTAutoencoderDataset(image_folder, normalize=True)
# shuffle=False: batches are drawn in dataset order on every pass.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

# Check
print(f"Number of images: {len(train_dataset)}")

### Training the Autoencoder

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DWT_Autoencoder().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Smooth L1 (Huber-style) reconstruction loss; nn.MSELoss() is a drop-in
# alternative.
criterion = nn.SmoothL1Loss()  # Huber Loss

num_epochs = 10

# Fail early with a clear message instead of a ZeroDivisionError when the
# per-epoch average is computed over zero batches.
if len(train_loader) == 0:
    raise ValueError("train_loader is empty: no .npy files to train on.")

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for data in train_loader:
        # Samples are stored with the sub-band index on axis 1 of each array,
        # i.e. axis 2 of the batch; swap it into the channel position that
        # Conv2d expects. The batch must also live on the same device as the
        # model, otherwise the forward pass fails when training on CUDA.
        data = data.permute(0, 2, 1, 3).to(device)
        optimizer.zero_grad()
        outputs = model(data)
        # Autoencoder objective: the input is its own target.
        loss = criterion(outputs, data)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader):.4f}")

# Saved relative to the current working directory.
torch.save(model.state_dict(), "dwt_autoencoder.pth")

### Reconstruction with DWT

In [None]:
import matplotlib.pyplot as plt
import torch
import pywt

def inverse_dwt(coeffs):
    """Rebuild a 2-D signal from single-level Haar DWT coefficients.

    Args:
        coeffs: ``(cA, (cH, cV, cD))`` — the approximation sub-band followed
            by a 3-item group of horizontal, vertical and diagonal details.

    Returns:
        The array produced by ``pywt.idwt2`` with the ``'haar'`` wavelet.
    """
    approx, details = coeffs
    # Unpacking enforces that exactly three detail sub-bands were supplied.
    horizontal, vertical, diagonal = details
    return pywt.idwt2((approx, (horizontal, vertical, diagonal)), 'haar')

# Reconstruct one batch with the trained model and compare one sample against
# its input, both converted back to image space with the inverse DWT.
model.eval()
with torch.no_grad():
    for data in train_loader:

        print(f"Raw data shape: {data.shape}")
        # Move the sub-band index into the channel position, as in training.
        data = data.permute(0, 2, 1, 3)
        print(f"Permuted data shape: {data.shape}")

        # Keep the untouched input coefficients for the "original" panel.
        original_dwt = data.cpu().numpy()

        # NOTE(review): the training loop feeds unclamped coefficients, while
        # this clamp zeroes any negative values before inference — confirm
        # the mismatch is intended.
        data = torch.clamp(data, min=0, max=1)

        # Run on whatever device the model lives on; feeding a CPU batch to a
        # CUDA model raises a device-mismatch error.
        model_device = next(model.parameters()).device
        reconstructed_dwt = model(data.to(model_device))
        print(f"Reconstructed DWT Shape: {reconstructed_dwt.shape}") # Checking

        reconstructed_dwt = reconstructed_dwt.cpu().numpy()

        # Inverse-transform every sample: channel 0 is the approximation,
        # channels 1-3 are the three detail sub-bands.
        reconstructed_images = np.array([
            inverse_dwt((sample[0], (sample[1], sample[2], sample[3])))
            for sample in reconstructed_dwt
        ])

        # Visualise sample 4 when the batch is large enough, otherwise the
        # last available sample (avoids an IndexError on small batches).
        sample_idx = min(4, reconstructed_images.shape[0] - 1)

        # Show the input in image space too, so both panels are comparable;
        # the raw 4-channel coefficient tensor is not an image.
        original_sample = original_dwt[sample_idx]
        original_image = inverse_dwt(
            (original_sample[0], (original_sample[1], original_sample[2], original_sample[3]))
        )

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(original_image, cmap='gray')
        ax[0].set_title("Original Image")
        ax[0].axis('off')

        ax[1].imshow(reconstructed_images[sample_idx], cmap='gray')
        ax[1].set_title("Reconstructed Image")
        ax[1].axis('off')

        plt.show()
        # Only the first batch is inspected.
        break

In [None]:
# Report the shape of the NumPy array of model outputs left in
# `reconstructed_dwt` by the reconstruction cell above.
print("Reconstructed DWT Coefficients Shape: ", reconstructed_dwt.shape)