In [1]:
import os
# Restrict CUDA to device index 0. This is set here, in the first cell, so it is
# already in the environment when the next cell imports torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets

import torchvision
import torchvision.transforms as transforms

import scipy.io as mlio
import numpy as np

import os
import argparse
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

## Try the vertical noise first

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Bookkeeping values: best test accuracy so far, and the epoch to start from
# (0, or the epoch stored in the last checkpoint).
best_acc, start_epoch = 0, 0

In [4]:
def split_train_val(train, val_split):
    """Split a dataset into a training part and a validation part.

    Args:
        train: dataset (anything with ``len`` and indexing) to split.
        val_split: fraction of the samples that goes to the validation part.

    Returns:
        ``(train_subset, val_subset)``. The training length is
        ``int(len(train) * (1 - val_split))`` and the validation part receives
        the remainder. The shuffle uses a generator seeded with 42, so repeated
        calls on the same dataset give the same split.
    """
    n_total = len(train)
    n_train = int(n_total * (1.0 - val_split))
    n_val = n_total - n_train

    # Fixed seed keeps the partition reproducible across runs.
    seeded_rng = torch.Generator().manual_seed(42)
    train_subset, val_subset = torch.utils.data.random_split(
        train, (n_train, n_val), generator=seeded_rng
    )
    return train_subset, val_subset

In [204]:
# (res1, res2) is the target size passed to transforms.Resize in the dataset cell
# below, and res1*res2 is the flattened sequence length used by the model.
res1 = 2048
res2 = 128
# NOTE(review): `res` is not referenced anywhere in the visible notebook — confirm before removing.
res = 256
class AddWaveTransform:
    """Add a cosine wave that varies along the image width (vertical stripes).

    Each of the first three channels gets ``magnitude * r * cos(frequency * x)``
    added to every row, where ``x`` is the column index and ``r`` is a fresh
    ``torch.rand(1)`` draw per channel.

    The wave length is taken from the incoming image rather than from a
    notebook-level global, so the transform works for any ``(C, H, W)`` tensor
    with at least three channels.

    Args:
        magnitude: upper bound of the random wave amplitude (default 3).
        frequency: angular frequency in radians per pixel (default 2.99433).
    """

    def __init__(self, magnitude=3, frequency=2.99433):
        self.magnitude = magnitude
        self.frequency = frequency

    def __call__(self, image):
        """Add the wave to ``image`` in place and return the same tensor."""
        width = image.size(2)
        # (1, W) so the wave broadcasts across all rows of a channel.
        carrier = torch.cos(self.frequency * torch.arange(width)).unsqueeze(0)
        # Draw all three amplitudes before touching the image.
        waves = [self.magnitude * torch.rand(1) * carrier for _ in range(3)]
        for channel, wave in enumerate(waves):
            image[channel, :, :] += wave
        return image

In [None]:
# Directory that holds the CelebA dataset (read with download=False below, so
# nothing is fetched by this cell).
data_dir = './data/CelebA'

# Per-sample preprocessing pipeline; the result is a (3, res1*res2) tensor.
transform = transforms.Compose([
    transforms.CenterCrop(178),  # Center crop to 178x178
    transforms.Resize((res1,res2)),      # Resize to (res1, res2), i.e. 2048x128 with the current constants (not 128x128)
    transforms.ToTensor(),       # Convert to tensor
    AddWaveTransform(),  # Add wave noise using whichever AddWaveTransform class is currently defined
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize: (x - 0.5) / 0.5 per channel
    transforms.Lambda(lambda x: x.view(3, res1*res2)),  # Flatten each channel into one sequence of res1*res2 values
])

# Load the CelebA training split with the pipeline above
dataset = datasets.CelebA(root=data_dir, split='train', transform=transform, download=False)

# Create DataLoader for the dataset (shuffled batches of 32, 4 worker processes)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Check the dataset
print(f'Number of samples: {len(dataset)}')

In [206]:
# Function to denormalize and plot images
from PIL import Image
def imshow(img):
    """Denormalize a flattened image tensor and open it in the PIL viewer.

    Args:
        img: tensor whose last dimension has ``res1 * res2`` entries, e.g. the
            ``(3, res1*res2)`` samples produced by the dataset pipeline, with
            values normalized by ``(x - 0.5) / 0.5``.

    The tensor is mapped back to [0, 1], reshaped to ``(3, res1, res2)``,
    clamped, converted to an 8-bit image and resized to 256x256 for display.
    The array shape is printed before the image is shown.
    """
    img = img / 2 + 0.5  # Unnormalize
    img = img.reshape(img.shape[:-1] + (res1, res2))
    img = torch.clamp(img, 0, 1)
    # .cpu() lets this also accept tensors that live on the GPU.
    npimg = np.transpose(img.detach().cpu().numpy(), (1,2,0))
    print(npimg.shape)
    # Scale [0, 1] to the full 8-bit range. The previous factor of 225 capped
    # the brightest pixels at 225 instead of 255.
    rescaled_image = Image.fromarray((npimg*255).astype(np.uint8)).resize((256, 256))
    rescaled_image.show()

In [None]:
# Quick visual check: fetch the sample at index 2, print its tensor shape
# and display it with imshow.
image, label = dataset[2]
print(image.shape)
imshow(image)

In [185]:
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange, repeat

from src.models.nn import DropoutNd

class S4DKernel_simple(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters.

    Holds, for each of the ``d_model`` features, ``N`` complex diagonal state
    entries ``A = -exp(log_A_real) + i * A_imag``, a complex output vector ``C``
    and a step size ``dt = exp(log_dt)``. ``forward(L)`` turns these into a real
    kernel of length ``L`` per feature.
    """

    def __init__(self, d_model, N=1, lr=0.0001):
        """
        Args:
            d_model: number of features H (one kernel is produced per feature).
            N: number of diagonal state entries per feature.
            lr: learning rate recorded for ``log_A_real`` and ``A_imag`` (see
                ``register``); passing 0 turns them into buffers instead.
        """
        super().__init__()
        H = d_model
        # The interpolation range below is log(1e-3) - log(1e-3) == 0, so every
        # entry equals log(1e-3) whatever torch.rand returns. The call still
        # consumes random numbers.
        log_dt = torch.rand(H) * (
            math.log(1e-3) - math.log(1e-3)
        ) + math.log(1e-3)

        # Complex (H, N) tensor stored as a real parameter with a trailing
        # dimension of 2 (real and imaginary parts).
        C = torch.randn(H, N, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        # lr of 0 makes `register` store log_dt as a buffer, not a parameter.
        self.register("log_dt", log_dt, 0)

        # Real part starts at -0.5 for every entry (forward uses -exp(log_A_real)).
        log_A_real = torch.log(0.5 * torch.ones(H, N))
        # Imaginary part: 10 * pi * (n - N // 2) for n in 0..N-1, repeated for each feature.
        A_imag = math.pi * repeat(torch.arange(N) - N // 2, 'n -> h n', h=H) * 10
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """Build the convolution kernel for sequence length ``L``.

        Args:
            L: number of kernel taps to generate.

        Returns:
            Real tensor of shape (H, L), one kernel per feature.
        """

        # Materialize parameters
        dt = torch.exp(self.log_dt) # (H)
        C = torch.view_as_complex(self.C) # (H N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N)

        # Vandermonde multiplication
        dtA = A * dt.unsqueeze(-1)  # (H N)
        # Exponents dtA * l for l = 0..L-1; exp() of this gives the powers exp(dtA)**l.
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L)
        # NOTE(review): the (exp(dtA) - 1) / A factor looks like a zero-order-hold
        # discretisation term — confirm against the reference derivation.
        C = C * (torch.exp(dtA)-1.) / A
        # Sum over the N state entries and keep the real part, scaled by 2.
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real

        return K

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay

        With ``lr == 0.0`` the tensor is stored as a buffer. Otherwise it becomes
        an ``nn.Parameter`` carrying an ``_optim`` attribute: a dict holding
        ``weight_decay`` 0.0 and, when ``lr`` is not None, the given ``lr``.
        The code that reads ``_optim`` is not part of this class.
        """

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None: optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)

# Default sequence length: one entry per pixel of a res1 x res2 image.
L_image = res1*res2


class S4D_simple(nn.Module):
    """Linear encoder -> diagonal-SSM long convolution -> linear decoder.

    The 3 input channels are lifted to ``d_model`` (fixed at 4) features, each
    feature is causally convolved along the sequence with the kernel produced
    by ``S4DKernel_simple``, and the result is projected to ``d_output``
    channels.

    Args:
        d_state: number of diagonal state entries per feature (kernel ``N``).
        L: accepted for interface compatibility; the sequence length is read
            from the input in ``forward``.
        d_output: number of output channels.
        dropout: dropout probability used to build ``self.dropout``.
        transposed: if True, tensors are channel-first ``(B, C, L)``; if False,
            channel-last ``(B, L, C)``. Applies to both input and output.
        **kernel_args: forwarded to ``S4DKernel_simple`` (e.g. ``lr``).
    """

    def __init__(self, d_state = 512, L = L_image, d_output = 3, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()

        self.n = d_state
        self.d_output = d_output
        self.d_model = 4
        self.transposed = transposed
        # Not used by forward(), but kept so existing checkpoints (which
        # contain a "D" entry) still load with strict key matching.
        self.D = nn.Parameter(torch.randn(1))
        self.encoder = nn.Linear(3, self.d_model)
        self.decoder = nn.Linear(self.d_model, d_output)

        # SSM Kernel
        self.kernel = S4DKernel_simple(self.d_model, N=self.n, **kernel_args)

        # The modules below are constructed for interface compatibility but are
        # not applied in forward(); the model is purely linear.
        self.activation = nn.GELU()
        # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11
        dropout_fn = DropoutNd
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = nn.Sequential(
            # nn.Conv1d(self.d_model, 2*self.d_model, kernel_size=1),
            nn.GELU(),
        )

    def forward(self, u, **kwargs): # absorbs return_output and transformer src mask
        """Apply the model to a batch of sequences.

        Args:
            u: ``(B, 3, L)`` when ``transposed`` is True, else ``(B, L, 3)``.

        Returns:
            ``(y, None)`` where ``y`` is ``(B, d_output, L)`` when ``transposed``
            is True, else ``(B, L, d_output)``. The second element is a dummy
            state kept to satisfy the surrounding repo's interface.
        """
        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)

        # Lift the 3 input channels to d_model features.
        u = self.encoder(u.transpose(-1, -2)).transpose(-1, -2) # (B H L)

        # Compute SSM Kernel
        k = self.kernel(L=L) # (H L)

        # Causal convolution via FFT. Zero-padding to 2L (n=2*L) prevents
        # circular wrap-around; kernel and input are both real, so the
        # real-input FFT gives the same result with half the spectrum.
        k_f = torch.fft.rfft(k, n=2*L) # (H L+1)
        u_f = torch.fft.rfft(u, n=2*L) # (B H L+1)
        y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L)

        y = self.decoder(y.transpose(-1, -2)) # (B L d_output)
        # Return in the same layout the input arrived in.
        if self.transposed: y = y.transpose(-1, -2) # (B d_output L)

        return y, None # Return a dummy state to satisfy this repo's interface, but this can be modified

In [187]:
# Build the model and place it on the device chosen at the top of the notebook
# ('cuda' when available, otherwise 'cpu') instead of assuming a GPU exists.
model = S4D_simple()
model = model.to(device)

In [188]:
# Restore the trained weights. map_location='cpu' lets the checkpoint load on
# machines without the GPU it was saved from; the model is moved to the CPU
# afterwards anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load('./checkpoint/ckpt.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model = model.to('cpu')

In [None]:
@torch.no_grad()
def plot_pics():
    """Show every input of the first batch next to the model's output.

    Reads the notebook-level ``dataloader`` and ``model``. Only the first batch
    is used. The loop runs over the actual batch size, so a batch with fewer
    than 32 samples no longer raises an IndexError.
    """
    for inputs, _targets in dataloader:
        outputs, _ = model(inputs)
        for i in range(inputs.size(0)):
            print('Blurred Images')
            imshow(inputs[i,:,:])
            print('Output of the SSM')
            imshow(outputs[i,:,:])
        break  # one batch is enough for a visual check

plot_pics()

## Try the horizontal noise

In [134]:
# (res1, res2) is the target size passed to transforms.Resize in the dataset cell
# below, and res1*res2 is the flattened length of each channel.
res1 = 2048
res2 = 128
# NOTE(review): `res` is not referenced anywhere in the visible notebook — confirm before removing.
res = 256
class AddWaveTransform:
    """Add a cosine wave that varies along the image height (horizontal stripes).

    Each of the first three channels gets ``magnitude * r * cos(frequency * y)``
    added to every column, where ``y`` is the row index and ``r`` is a fresh
    ``torch.rand(1)`` draw per channel.

    The wave length is taken from the incoming image rather than from a
    notebook-level global, so the transform works for any ``(C, H, W)`` tensor
    with at least three channels.

    Args:
        magnitude: upper bound of the random wave amplitude (default 5).
        frequency: angular frequency in radians per pixel (default 2.99433).
    """

    def __init__(self, magnitude=5, frequency=2.99433):
        self.magnitude = magnitude
        self.frequency = frequency

    def __call__(self, image):
        """Add the wave to ``image`` in place and return the same tensor."""
        height = image.size(1)
        # (H, 1) so the wave broadcasts across all columns of a channel.
        carrier = torch.cos(self.frequency * torch.arange(height)).unsqueeze(1)
        # Draw all three amplitudes before touching the image.
        waves = [self.magnitude * torch.rand(1) * carrier for _ in range(3)]
        for channel, wave in enumerate(waves):
            image[channel, :, :] += wave
        return image

In [None]:
# Directory that holds the CelebA dataset (read with download=False below, so
# nothing is fetched by this cell).
data_dir = './data/CelebA'

# Per-sample preprocessing pipeline; the result is a (3, res1*res2) tensor.
transform = transforms.Compose([
    transforms.CenterCrop(178),  # Center crop to 178x178
    transforms.Resize((res1,res2)),      # Resize to (res1, res2), i.e. 2048x128 with the current constants (not 128x128)
    transforms.ToTensor(),       # Convert to tensor
    AddWaveTransform(),  # Add wave noise using whichever AddWaveTransform class is currently defined
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize: (x - 0.5) / 0.5 per channel
    transforms.Lambda(lambda x: x.view(3, res1*res2)),  # Flatten each channel into one sequence of res1*res2 values
])

# Load the CelebA training split with the pipeline above
dataset = datasets.CelebA(root=data_dir, split='train', transform=transform, download=False)

# Create DataLoader for the dataset (shuffled batches of 32, 4 worker processes)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Check the dataset
print(f'Number of samples: {len(dataset)}')

In [None]:
@torch.no_grad()
def plot_pics():
    """Show every input of the first batch next to the model's output.

    Reads the notebook-level ``dataloader`` and ``model``. Only the first batch
    is used. The loop runs over the actual batch size, so a batch with fewer
    than 32 samples no longer raises an IndexError.
    """
    for inputs, _targets in dataloader:
        outputs, _ = model(inputs)
        for i in range(inputs.size(0)):
            print('Blurred Images')
            imshow(inputs[i,:,:])
            print('Output of the SSM')
            imshow(outputs[i,:,:])
        break  # one batch is enough for a visual check

plot_pics()

# For some rigorous numbers

In [None]:
from torch.utils.data import Dataset

class AddWaveTransform:
    """Add a random-amplitude cosine wave along the width of an image tensor.

    For each of the first three channels, ``0.5 * torch.rand(1) * cos(2.99433 * x)``
    (``x`` = column index) is added to every row. The tensor is modified in
    place and returned.
    """

    def __call__(self, image):
        n_cols = image.size(2)
        # Shared (1, n_cols) cosine profile; only the amplitude differs per channel.
        profile = torch.cos(2.99433*torch.arange(n_cols)).unsqueeze(0)
        # Draw the three amplitudes up front, one random number per channel.
        offsets = [0.5 * torch.rand(1) * profile for _ in range(3)]
        for channel, offset in enumerate(offsets):
            image[channel, :, :] += offset
        return image

class AddWaveTransform_Horizontal:
    """Add a random-amplitude cosine wave along the height of an image tensor.

    For each of the first three channels, ``0.5 * torch.rand(1) * cos(2.99433 * y)``
    (``y`` = row index) is added to every column. The tensor is modified in
    place and returned.
    """

    def __call__(self, image):
        n_rows = image.size(1)
        # Shared (n_rows, 1) cosine profile; only the amplitude differs per channel.
        profile = torch.cos(2.99433*torch.arange(n_rows)).unsqueeze(1)
        # Draw the three amplitudes up front, one random number per channel.
        offsets = [0.5 * torch.rand(1) * profile for _ in range(3)]
        for channel, offset in enumerate(offsets):
            image[channel, :, :] += offset
        return image

# Define the path to store the CelebA dataset
data_dir = './data/CelebA'

# Target (height, width) handed to transforms.Resize; res1*res2 is the flattened
# length of each channel.
res1 = 2048
res2 = 128


def _compose_pipeline(noise=None):
    """Build the preprocessing pipeline, optionally inserting a noise transform.

    Steps: centre-crop to 178x178, resize to (res1, res2), convert to a tensor,
    apply ``noise`` if one is given, normalize each channel with mean 0.5 and
    std 0.5, then flatten to shape (3, res1*res2).
    """
    stages = [
        transforms.CenterCrop(178),
        transforms.Resize((res1,res2)),
        transforms.ToTensor(),
    ]
    if noise is not None:
        stages.append(noise)
    stages.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    stages.append(transforms.Lambda(lambda x: x.view(3, res1*res2)))
    return transforms.Compose(stages)


# Noise-free pipeline (same steps as transform1).
transform = _compose_pipeline()

# The three pipelines used for evaluation: clean target, width-wise wave noise,
# height-wise wave noise.
transform1 = _compose_pipeline()
transform2 = _compose_pipeline(AddWaveTransform())
transform3 = _compose_pipeline(AddWaveTransform_Horizontal())

class DualTransformDataset(Dataset):
    """Wrap a dataset so each sample yields two noisy views plus the clean target.

    Args:
        dataset: underlying dataset returning ``(image, label)`` pairs.
        transform1: pipeline producing the clean/target image.
        transform2: pipeline producing the first noisy input (returned first).
        transform3: pipeline producing the second noisy input (returned second).

    A transform left as ``None`` is treated as the identity, so the declared
    defaults no longer raise ``TypeError`` when an item is fetched.
    """

    def __init__(self, dataset, transform1=None, transform2=None, transform3=None):
        self.dataset = dataset
        self.transform1 = transform1
        self.transform2 = transform2
        self.transform3 = transform3

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

    @staticmethod
    def _apply(transform, image):
        """Apply ``transform`` to ``image``; ``None`` means pass through unchanged."""
        return image if transform is None else transform(image)

    def __getitem__(self, idx):
        """Return ``(transform2(image), transform3(image), transform1(image), label)``."""
        image, label = self.dataset[idx]
        # Order is kept (transform2, transform3, transform1) so random transforms
        # consume the RNG in the same sequence as before.
        input_image_vertical = self._apply(self.transform2, image)
        input_image_horizontal = self._apply(self.transform3, image)
        true_image = self._apply(self.transform1, image)
        return input_image_vertical, input_image_horizontal, true_image, label

# Load the CelebA training split without a transform: the wrapper below applies
# the three pipelines itself, so it needs the untransformed image.
dataset = datasets.CelebA(root=data_dir, split='train', download=False)

# Wrap it so every sample yields (transform2 output, transform3 output, transform1 output, label)
dual_transform_dataset = DualTransformDataset(dataset, transform1=transform1, transform2=transform2, transform3=transform3)

# Create a DataLoader (shuffled batches of 32, 4 worker processes)
dataloader = DataLoader(dual_transform_dataset, batch_size=32, shuffle=True, num_workers=4)

# Check the dataset (the wrapper reports the same length as the underlying dataset)
print(f'Number of samples: {len(dataset)}')

In [None]:
# Measure how well the model removes each kind of wave noise.
# For every batch we record the MSE between the model output and the clean
# image ("loss") and the MSE between the noisy input and the clean image
# ("noise"), then report the per-batch averages.
total = 0  # number of samples processed
total_vertical = 0
total_horizontal = 0
total_noises_vertical = 0
total_noises_horizontal = 0
num_batches = 0
criterion = nn.MSELoss()

# Use the device selected at the top of the notebook rather than assuming CUDA.
model = model.to(device)
model.eval()

pbar = tqdm(enumerate(dataloader), total=len(dataloader))
# Evaluation only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    for batch_idx, (input_image_vertical, input_image_horizontal, true_image, targets) in pbar:
        input_image_vertical = input_image_vertical.to(device)
        input_image_horizontal = input_image_horizontal.to(device)
        true_image = true_image.to(device)

        foutputs, _ = model(input_image_vertical)
        loss = criterion(foutputs, true_image)
        total_vertical += loss.item()

        foutputs, _ = model(input_image_horizontal)
        loss = criterion(foutputs, true_image)
        total_horizontal += loss.item()

        total_noises_vertical += criterion(input_image_vertical, true_image).item()
        total_noises_horizontal += criterion(input_image_horizontal, true_image).item()

        total += targets.size(0)
        num_batches = batch_idx + 1

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss_V: %.4f | Noise_V: %.4f | Loss_H: %.4f | Noise_H: %.4f' %
            (batch_idx, len(dataloader), total_vertical/(batch_idx+1), total_noises_vertical/(batch_idx+1), total_horizontal/(batch_idx+1), total_noises_horizontal/(batch_idx+1))
        )

# Final summary. The previous version referenced `total_noises` and `total_loss`,
# which are never defined, so it raised NameError after the whole loop had run.
if num_batches == 0:
    print('dataloader produced no batches; nothing to report')
else:
    print('vertical: noise level = ' + str(total_noises_vertical / num_batches) + ' | loss = ' + str(total_vertical / num_batches))
    print('horizontal: noise level = ' + str(total_noises_horizontal / num_batches) + ' | loss = ' + str(total_horizontal / num_batches))