In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from PIL import Image, ImageOps
import numpy as np

In [115]:
class ImageDataset(Dataset):
    """Dataset that pairs every image with a grid of normalised pixel coordinates.

    Each item is a ``(grid, image)`` tuple. The image is loaded as RGB with its
    EXIF orientation applied and is expected to be channels-first ``[C, H, W]``
    once ``image_transform`` has run. The raw grid is channels-last
    ``[H, W, 2]`` holding ``(y, x)`` coordinates in ``[0, 1)``; if
    ``grid_transform`` is given, its output is returned in place of the raw grid.

    Args:
        image_paths: A NumPy array, any other sequence of paths, or a single
            path string.
        grid_transform: Optional callable applied to the ``[H, W, 2]`` grid.
        image_transform: Optional callable applied to the PIL image. If it is
            omitted, or returns a PIL image, the image is converted to a
            float32 ``[C, H, W]`` tensor scaled to ``[0, 1]``.
    """

    def __init__(self, image_paths, grid_transform=None, image_transform=None):
        if isinstance(image_paths, str):
            # A lone path would otherwise be split into characters by list().
            image_paths = [image_paths]
        if hasattr(image_paths, 'tolist'):
            # NumPy arrays (the original input type) become plain Python strings.
            self.image_paths = image_paths.tolist()
        else:
            self.image_paths = list(image_paths)
        self.image_transform = image_transform
        self.grid_transform = grid_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(grid, image)`` pair for the image at position ``idx``."""
        image_path = self.image_paths[idx]
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(image_path) as handle:
            image = ImageOps.exif_transpose(handle.convert('RGB'))
        if self.image_transform is not None:
            image = self.image_transform(image)
        if isinstance(image, Image.Image):
            # PIL images have no ``.shape``; fall back to a tensor so the grid
            # can be sized and the sample can be batched.
            image = self._pil_to_tensor(image)

        # The image is channels-first, so the spatial size is the last two dims.
        _, height, width = image.shape
        yx_grid = self._create_yx_grid(grid_size=(height, width))

        # Only the rank is checked: the grid is [H, W, 2], the image [C, H, W].
        assert len(yx_grid.shape) == len(image.shape)

        if self.grid_transform is not None:
            return self.grid_transform(yx_grid), image
        return yx_grid, image

    @staticmethod
    def _pil_to_tensor(image):
        """Convert an RGB image (H x W x C, 0-255) to a float32 [C, H, W] tensor in [0, 1]."""
        pixels = np.asarray(image, dtype=np.float32) / 255.0
        return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()

    @staticmethod
    def _create_yx_grid(grid_size):
        """
        Creates mesh grid of normalised pixel coordinates based on matrix indexing convention.

        Args:
            grid_size: ``(height, width)`` of the grid.

        Returns:
            float32 tensor of shape ``[height, width, 2]`` where entry ``[i, j]``
            is ``(i / height, j / width)``.
        """
        h, w = grid_size
        coords_i = np.linspace(0, 1, h, endpoint=False)
        coords_j = np.linspace(0, 1, w, endpoint=False)
        grid = torch.from_numpy(np.stack(np.meshgrid(coords_i, coords_j, indexing='ij'), axis=-1).astype('float32'))
        return grid

In [116]:
class ImageDataModule(pl.LightningDataModule):
    """Lightning data module that serves ``ImageDataset`` samples for training.

    Only a training loader is provided; ``val_dataloader`` and
    ``test_dataloader`` return ``None``.

    Args:
        hr_path: Image paths forwarded unchanged to ``ImageDataset``.
        grid_transform: Callable applied to each coordinate grid (may be None).
        n_splits: Stored as ``self.n_splits``; not used by this class yet.
        batch_size: Samples per training batch (default 1, as before).
        num_workers: DataLoader worker processes (default 1, as before).
    """

    def __init__(self, hr_path, grid_transform, n_splits=0, batch_size=1, num_workers=1):
        super().__init__()
        self.hr_path = hr_path
        self.grid_transform = grid_transform
        # Previously accepted but silently dropped; kept so callers can read it back.
        self.n_splits = n_splits
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.fold = 0

    def setup(self, stage=None):
        """Build the image transforms and the training dataset."""
        # Image transformations
        # NOTE(review): Normalize shifts pixel values outside [0, 1], while
        # NerfModel ends in a sigmoid and can only emit values in (0, 1) —
        # confirm that the targets should be normalised this way.
        self.image_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.train_dataset = ImageDataset(self.hr_path, self.grid_transform, self.image_transforms)

    def set_fold(self, fold):
        """Record the active fold and rebuild the dataset.

        ``setup`` does not read ``self.fold``, so every fold currently yields
        the same dataset.
        """
        self.fold = fold
        self.setup()  # Re-setup the datasets for the new fold

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        # No validation split exists.
        # NOTE(review): some Lightning versions may reject a None dataloader —
        # confirm against the installed release.
        return None

    def test_dataloader(self):
        # No test split exists; see the note on val_dataloader.
        return None

In [1]:
# Fourier Encoded Features
class FourierEncoding(object):
    """Random Fourier feature encoding for channels-last coordinate grids.

    A fixed random matrix ``B`` of shape ``[num_input_channels, mapping_size]``
    is drawn once at construction. Calling the encoder maps an ``[H, W, C]``
    grid to ``[H, W, 2 * mapping_size]`` by concatenating
    ``sin(2*pi * x @ B)`` and ``cos(2*pi * x @ B)`` along the last axis.

    Args:
        num_input_channels: Size of the last axis of the input grid.
        mapping_size: Number of random frequencies; the output has twice as
            many channels.
        scale: Multiplier applied to the standard-normal entries of ``B``.
        seed: Optional seed for a private generator, making ``B`` reproducible.
            When None (default), ``B`` is drawn from NumPy's global RNG exactly
            as before.
    """

    def __init__(self, num_input_channels: int, mapping_size: int = 256, scale: float = 10, seed=None):
        super().__init__()
        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        if seed is None:
            projection = np.random.randn(num_input_channels, mapping_size)
        else:
            projection = np.random.default_rng(seed).standard_normal((num_input_channels, mapping_size))
        self._B = projection * scale

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Encode an ``[H, W, C]`` grid and return a float32 ndarray.

        ``x`` may be an ndarray or anything ``np.asarray`` can convert, such as
        a CPU torch tensor or nested lists.

        Raises:
            AssertionError: If ``x`` is not 3-D or its last axis does not have
                ``num_input_channels`` entries.
        """
        # ImageDataset hands over a torch tensor; do all the maths in NumPy.
        x = np.asarray(x)
        assert len(x.shape) == 3, 'Expected 3D input (got {}D input)'.format(len(x.shape))
        height, width, channels = x.shape
        assert channels == self._num_input_channels, \
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)
        # Flatten [H, W, C] to [(H*W), C] so it can be multiplied with _B,
        # then restore the spatial layout as [H, W, mapping_size].
        projected = x.reshape(height * width, channels) @ self._B
        phase = 2 * np.pi * projected.reshape(height, width, self._mapping_size)
        return np.concatenate((np.sin(phase), np.cos(phase)), axis=-1).astype('float32')

NameError: name 'np' is not defined

In [2]:
# Wavelet Encoded Features
class WaveletEncoding(object):
    """Placeholder for a wavelet-based grid encoding; nothing is implemented yet."""

    def __init__(self):
        super().__init__()

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Accept a grid and return ``None`` — TODO: implement the encoding."""
        return None

NameError: name 'np' is not defined

In [134]:
class NerfModel(nn.Module):
    """Per-pixel MLP built from 1x1 convolutions.

    The network stacks ``num_layers - 1`` hidden stages, each a 1x1 ``Conv2d``
    followed by ReLU and then ``BatchNorm2d``, and finishes with a 1x1 output
    convolution passed through a sigmoid.

    Args:
        input_shape: Per-sample input shape; only ``input_shape[0]`` (the
            channel count) is read.
        output_dim: Number of output channels.
        num_layers: Total number of convolutions including the output layer.
            With ``num_layers <= 1`` there are no hidden stages and the output
            layer consumes the input channels directly.
        num_channels: Width of every hidden stage.
    """

    def __init__(self, input_shape, output_dim: int = 3, num_layers: int = 4, num_channels: int = 256):
        super().__init__()
        self.num_layers = num_layers
        self.input_channels = input_shape[0]
        # Flat [conv, norm, conv, norm, ...] layout, kept so existing
        # state_dict keys stay valid.
        self.conv_layers = nn.ModuleList()

        in_channels = self.input_channels
        for _ in range(num_layers - 1):
            self.conv_layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=1, padding=0))
            self.conv_layers.append(nn.BatchNorm2d(num_channels))
            in_channels = num_channels

        # Fed by the last hidden stage, or by the raw input if there is none.
        # (Previously always sized for num_channels, which broke num_layers <= 1.)
        self.output_layer = nn.Conv2d(in_channels, output_dim, kernel_size=1, padding=0)

    def forward(self, x):
        """Map ``[N, C, H, W]`` features to ``[N, output_dim, H, W]`` values in (0, 1)."""
        # Even indices hold convolutions, odd indices their batch norms.
        for conv, norm in zip(self.conv_layers[0::2], self.conv_layers[1::2]):
            x = norm(F.relu(conv(x)))
        return torch.sigmoid(self.output_layer(x))

In [135]:
class NerfTrainer(pl.LightningModule):
    """Lightning wrapper that trains ``NerfModel`` with a mean-squared-error loss.

    Every step logs ``<stage>_loss`` (MSE) and ``<stage>_psnr``. PSNR replaces
    the earlier classification ``Accuracy`` metric, which is not defined for
    continuous image targets.

    Args:
        input_shape, output_dim, num_layers, num_channels: Forwarded to
            ``NerfModel``.
        learning_rate: Learning rate for the Adam optimiser.
    """

    def __init__(self, input_shape, output_dim: int = 3, num_layers: int = 4, num_channels: int = 256, learning_rate: float = 1e-3):
        super().__init__()
        self.model = NerfModel(input_shape, output_dim, num_layers, num_channels)
        self.loss_fn = torch.nn.MSELoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        """Compute and log loss and PSNR for one ``(inputs, targets)`` batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # PSNR for a unit peak value (the model's sigmoid output lies in (0, 1));
        # the clamp avoids log10(0) on a perfect reconstruction.
        psnr = -10.0 * torch.log10(loss.detach().clamp_min(1e-10))
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_psnr', psnr, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [136]:
# # Load the original image
# original_image = ImageOps.exif_transpose(Image.open('../data/hr_ground_truth.JPG'))

# # Calculate the new resolution as 70% of the original size
# new_resolution = (int(original_image.width * 0.7), int(original_image.height * 0.7))

# # Resize the image to the new resolution
# low_res_image = original_image.resize(new_resolution)
# low_res_image.rotate(-90)

# # Save the low-resolution image
# low_res_image.save('path_to_save_low_res_image.jpg')

# print(f"Low-resolution image saved as 'path_to_save_low_res_image.jpg'")


In [139]:
# Grid Transform
# FourierEncoding returns a channels-last [H, W, C] ndarray, but the Conv2d layers
# in NerfModel need channels-first input. ToTensor reorders an H x W x C ndarray
# to C x H x W and leaves float data unscaled.
grid_transform = transforms.Compose([
    FourierEncoding(num_input_channels=2, mapping_size=128, scale=10),
    transforms.ToTensor(),
])

# Inputs
image_inputs = np.array([
    '../data/low_res_train.jpg'
])

# Data Module Setup
data_module = ImageDataModule(image_inputs, grid_transform)
data_module.set_fold(0)

# Trainer Setup
# The data module has no validation loader, so 'val_loss' is never logged;
# monitor the training loss instead.
early_stopping = EarlyStopping(monitor='train_loss', patience=10, mode='min')
lr_monitor = LearningRateMonitor(logging_interval='step')

In [140]:

# Pull one batch from the training loader as a sanity check
first_batch = next(iter(data_module.train_dataloader()))

# Each batch is a (grid features, image) pair, in the order ImageDataset returns them
inputs, labels = first_batch

# One shape per line: the encoded coordinate grid first, then the target image
print(inputs.shape, labels.shape, sep='\n')

In [None]:
# Model Setup
# Per-sample feature shape with the batch dimension dropped. NerfModel reads the
# channel count from index 0, so `inputs` is assumed to be channels-first
# [N, C, H, W] — TODO confirm against the printed batch shape.
input_shape = tuple(inputs.shape[1:])
# trainer.fit needs the LightningModule wrapper, not the bare nn.Module.
model = NerfTrainer(input_shape, output_dim=3, num_layers=4, num_channels=256)

# Callbacks
# `args` and `fold` were never defined in this notebook; take the fold from the
# data module and keep the checkpoint directory in a local variable (adjust as needed).
fold = data_module.fold
save_dir = 'checkpoints'
# 'val_loss' is never logged because there is no validation loader, so the
# checkpoint tracks the training loss.
checkpoint_callback = ModelCheckpoint(monitor='train_loss', dirpath=f'{save_dir}/fold{fold}', filename=f'model-fold{fold}-{{epoch:02d}}-{{train_loss:.2f}}', save_top_k=1, mode='min')

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[early_stopping, checkpoint_callback, lr_monitor],
    accelerator='auto',  # uses a GPU when one is available, otherwise the CPU
    devices=1
)

# Training and Testing

trainer.fit(model, datamodule=data_module)

In [None]:
# Metrics

In [None]:
# Output