In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from transformers import ViTMAEModel, ViTMAEConfig

In [None]:
class GalaxyDataset(Dataset):
    """Dataset of ``.npy`` arrays from a directory, optionally paired with target arrays.

    Inputs and targets are paired by position after sorting each directory's ``.npy``
    paths, so both directories must hold the same number of ``.npy`` files (presumably
    with matching filenames -- TODO confirm against the data layout).

    Args:
        input_dir: directory containing the input ``.npy`` files.
        target_dir: optional directory containing the target ``.npy`` files. When it is
            falsy (None or ''), items are returned without targets.
        transform: optional callable applied to each input tensor.
        target_transform: optional callable applied to each target tensor.

    Raises:
        ValueError: if ``target_dir`` is given and its ``.npy`` count differs from
            ``input_dir``'s (positional pairing would otherwise silently misalign).
    """

    def __init__(self, input_dir, target_dir = None, transform = None, target_transform = None):
        self.input_files = self._list_npy(input_dir)
        self.target_files = self._list_npy(target_dir) if target_dir else None
        if self.target_files is not None and len(self.target_files) != len(self.input_files):
            raise ValueError(
                f"{input_dir} has {len(self.input_files)} .npy files but "
                f"{target_dir} has {len(self.target_files)}"
            )
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _list_npy(directory):
        """Return the sorted full paths of the ``.npy`` files directly inside ``directory``."""
        return sorted(os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.npy'))

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a tensor of torch's default float dtype.

        Returns:
            ``input_tensor`` alone when the dataset has no targets, otherwise the tuple
            ``(input_tensor, target_tensor)``. Transforms, if given, are applied first.
        """
        input_tensor = torch.Tensor(np.load(self.input_files[idx]))
        if self.transform is not None:
            input_tensor = self.transform(input_tensor)

        if self.target_files is None:
            return input_tensor  # Only return input when no target_dir is provided

        target_tensor = torch.Tensor(np.load(self.target_files[idx]))
        if self.target_transform is not None:
            target_tensor = self.target_transform(target_tensor)
        return input_tensor, target_tensor  # Return input and target

In [None]:
def preprocess_target_patches_batch(batch_tensor, patch_size = 16, image_size = 128):
    """
    Split a batch of 2D target maps into non-overlapping square patches for ViT.

    Patches are emitted in row-major order: patch k covers rows
    ``(k // n_w) * patch_size`` onwards and columns ``(k % n_w) * patch_size`` onwards,
    where ``n_w = W // patch_size``.

    Args:
        batch_tensor (Tensor): Maps of shape (batch_size, H, W). A single map of shape
            (H, W) is treated as a batch of one.
        patch_size (int): Side length of each square patch; must divide H and W.
        image_size (int or None): Required value of both H and W (default 128). Pass
            None to accept any H, W divisible by ``patch_size``.
    Returns:
        patches (Tensor): Tensor of shape (batch_size, num_patches, patch_size, patch_size)
            with num_patches = (H // patch_size) * (W // patch_size).
    Raises:
        AssertionError: on a wrong rank, a size other than ``image_size``, or a size not
            divisible by ``patch_size``.
    """
    if batch_tensor.dim() == 2:
        batch_tensor = batch_tensor.unsqueeze(0)
    assert batch_tensor.dim() == 3, f"Expected (batch_size, H, W), got shape {tuple(batch_tensor.shape)}"

    batch_size, h, w = batch_tensor.shape
    if image_size is not None:
        assert h == w == image_size, f"Target image must be {image_size}x{image_size}"
    assert h % patch_size == 0 and w % patch_size == 0, f"H and W must be divisible by patch_size={patch_size}"

    # Each unfold appends a window axis: (batch_size, n_h, n_w, patch_size, patch_size),
    # where the last two axes are (row within patch, column within patch).
    patches = batch_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)

    # Flatten the (n_h, n_w) grid into one row-major patch axis; no permute is needed.
    return patches.reshape(batch_size, -1, patch_size, patch_size)

In [None]:
class WeakLensingViT(nn.Module):
    """ViT-MAE encoder plus a linear per-patch head that regresses a 2D map.

    The encoder is built from a fresh ``ViTMAEConfig`` (randomly initialised, not
    pretrained). Each patch token is projected to ``patch_size * patch_size`` values, so
    the output is one ``patch_size x patch_size`` tile per patch, in row-major patch order.

    Args:
        patch_size (int): side length of each square patch.
        image_size (int): side length of the square input image; must be divisible by
            ``patch_size``.
        num_channels (int): number of input channels the encoder expects.
    """

    def __init__(self, patch_size = 16, image_size = 128, num_channels = 3):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(f"image_size={image_size} is not divisible by patch_size={patch_size}")
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # mask_ratio=0.0: ViTMAEConfig defaults to masking 75% of patches (MAE
        # pretraining), which would drop most of the patches this model must regress.
        config = ViTMAEConfig(image_size = image_size, patch_size = patch_size,
                              num_channels = num_channels, hidden_size = 768, mask_ratio = 0.0)
        self.vit = ViTMAEModel(config = config)

        # One patch_size x patch_size tile (e.g. 16x16 = 256 values) per patch token
        self.reconstruction_head = nn.Linear(in_features = config.hidden_size,
                                             out_features = patch_size * patch_size)

    def forward(self, patches):
        """Map images to per-patch tiles.

        Args:
            patches (Tensor): despite the name, full images passed to ViTMAE as
                ``pixel_values``; expected (batch, num_channels, image_size, image_size).
        Returns:
            Tensor of shape (batch, num_patches, patch_size, patch_size).
        """
        encoded = self.vit(patches)

        # Drop the leading CLS token: (batch, num_patches, hidden_size)
        tokens = encoded.last_hidden_state[:, 1:, :]

        # ViTMAE shuffles patch tokens even when nothing is masked; ids_restore maps
        # them back to the original row-major patch order.
        restore_index = encoded.ids_restore.unsqueeze(-1).expand(-1, -1, tokens.shape[-1])
        tokens = torch.gather(tokens, dim = 1, index = restore_index)

        map_output = self.reconstruction_head(tokens)  # (batch, num_patches, patch_size**2)
        return map_output.reshape(-1, self.num_patches, self.patch_size, self.patch_size)

In [None]:
# Instantiate the dataset and data loader

input_dir = '1/EPSILON'
target_dir = '1/KAPPA'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch so weight initialisation and shuffling are reproducible across runs
torch.manual_seed(42)

# Use the dataset class
train_dataset = GalaxyDataset(input_dir = input_dir, target_dir = target_dir)

# Use more workers for faster parallel data loading
train_loader = DataLoader(train_dataset, batch_size = 128, shuffle = True, num_workers = 8, pin_memory = True)

# Initialize the model, optimizer and loss criterion
model = WeakLensingViT(patch_size = 16)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr = 1e-4)
# Named separately from the per-batch loss value: reusing one name for both would
# shadow the criterion with a tensor after the first batch.
criterion = nn.HuberLoss()

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, targets in train_loader:
        # pin_memory=True allows these host-to-device copies to be asynchronous
        inputs = inputs.to(device, non_blocking = True)
        targets = targets.to(device, non_blocking = True)

        # squeeze() removes singleton axes (presumably a channel axis -- TODO confirm the
        # target layout), but it also removes the batch axis when the final batch holds
        # a single sample; restore it in that case.
        targets_2d = targets.squeeze()
        if targets_2d.dim() == 2:
            targets_2d = targets_2d.unsqueeze(0)
        targets_patches = preprocess_target_patches_batch(targets_2d)  # Shape: (batch_size, num_patches, 16, 16)

        # Zero gradients
        optimizer.zero_grad()

        outputs = model(inputs)  # Shape: (batch_size, num_patches, 16, 16)

        batch_loss = criterion(outputs, targets_patches)

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")

In [None]:
# Persist only the learned parameters (the state dict), not the module object itself
model_state = model.state_dict()
torch.save(model_state, 'model.pth')

In [None]:
def assemble_patches(patches, original_size = (128, 128), patch_size = 16):
    """
    Reassemble row-major square patches into full images.

    This is the inverse of splitting an (H, W) image into a row-major grid of
    ``patch_size x patch_size`` tiles.

    Args:
        patches (Tensor): Patch tensor of shape (num_patches, patch_size, patch_size), or
            batched as (batch_size, num_patches, patch_size, patch_size).
        original_size (tuple): The original size of the image (H, W); both must be
            divisible by ``patch_size``.
        patch_size (int): The size of each patch.
    Returns:
        full_map (Tensor): (H, W) for a single image (3D input, or 4D input with
            batch_size == 1); (batch_size, H, W) when batch_size > 1.
    Raises:
        AssertionError: on a wrong rank, wrong patch dimensions, or a patch count that
            does not match ``original_size``.
    """
    if patches.dim() == 3:
        patches = patches.unsqueeze(0)
    assert patches.dim() == 4, f"Expected 3D or 4D patches, got shape {tuple(patches.shape)}"
    assert tuple(patches.shape[-2:]) == (patch_size, patch_size), "Patch dimensions do not match patch_size"

    batch_size, num_patches = patches.shape[0], patches.shape[1]
    h, w = original_size
    assert h % patch_size == 0 and w % patch_size == 0, "original_size must be divisible by patch_size"
    rows, cols = h // patch_size, w // patch_size
    assert num_patches == rows * cols, "Mismatch between patches and original size"

    # (B, rows, cols, p, p) -> (B, rows, p, cols, p) -> (B, H, W)
    full_map = (
        patches.reshape(batch_size, rows, cols, patch_size, patch_size)
        .permute(0, 1, 3, 2, 4)
        .reshape(batch_size, h, w)
    )

    return full_map[0] if batch_size == 1 else full_map

In [None]:
# Load the test dataset
test_input_dir = 'EPSILON'
output_dir = 'tierraplana'

# exist_ok avoids the race between an exists() check and makedirs()
os.makedirs(output_dir, exist_ok = True)

test_dataset = GalaxyDataset(input_dir = test_input_dir)
# batch_size=1 with shuffle=False keeps each batch aligned with test_dataset.input_files
test_loader = DataLoader(test_dataset, batch_size = 1, shuffle = False)

# Evaluate the model
model.eval()
with torch.no_grad():
    # Only inputs are yielded since the test dataset has no targets
    for input_filename, inputs in zip(test_dataset.input_files, test_loader):
        # Forward pass
        outputs = model(inputs.to(device))

        # Reassemble the patches back into a full image
        full_output = assemble_patches(outputs.cpu(), original_size = (128, 128), patch_size = 16)

        # Convert the output to float16 before saving
        full_output_float16 = full_output.numpy().astype(np.float16)

        # Save under the same filename as the input, in the output directory
        output_filepath = os.path.join(output_dir, os.path.basename(input_filename))
        np.save(output_filepath, full_output_float16)  # Save the full output as .npy in float16