In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchsr.datasets import Div2K  # Ensure the correct import path for your Div2K dataset

class SineLayer(nn.Module):
    """Linear layer followed by a frequency-scaled sine activation.

    The layer computes ``sin(omega_0 * linear(x))``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the underlying ``nn.Linear`` has a bias term.
        is_first: Selects the weight range used by :meth:`init_weights`.
        omega_0: Multiplier applied to the linear output before the sine.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Redraw the linear weights uniformly from ``[-bound, bound]``.

        ``bound`` is ``1 / in_features`` for a first layer and
        ``sqrt(6 / in_features) / omega_0`` otherwise. The bias is left as
        ``nn.Linear`` initialised it.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        # nn.init.uniform_ fills the tensor in place without tracking gradients.
        nn.init.uniform_(self.linear.weight, -bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)

class SimpleSiren(nn.Module):
    """Sine-activated MLP mapping (coordinate, image one-hot code) to an output vector.

    The network is one first ``SineLayer``, ``hidden_layers`` further
    ``SineLayer`` blocks, and a plain ``nn.Linear`` output layer.

    Args:
        coord_dim: Number of coordinate components per point.
        num_images: Length of the one-hot image code appended to each point.
        hidden_features: Width of every sine layer.
        hidden_layers: Number of sine layers after the first one.
        out_features: Size of the output vector per point.
        omega_0: Frequency multiplier forwarded to every ``SineLayer``.
    """

    def __init__(self, coord_dim, num_images, hidden_features, hidden_layers, out_features, omega_0=30):
        super().__init__()
        # Every point is fed as its coordinate concatenated with the image code.
        stem = SineLayer(coord_dim + num_images, hidden_features, is_first=True, omega_0=omega_0)
        trunk = [
            SineLayer(hidden_features, hidden_features, is_first=False, omega_0=omega_0)
            for _ in range(hidden_layers)
        ]
        head = nn.Linear(hidden_features, out_features)
        self.net = nn.Sequential(stem, *trunk, head)

    def forward(self, coords, image_ids_one_hot):
        """Evaluate the network at every coordinate of every batch element.

        Args:
            coords: Tensor of shape ``(batch, N, coord_dim)``.
            image_ids_one_hot: Tensor of shape ``(batch, num_images)``.

        Returns:
            Tensor of shape ``(batch, N, out_features)``.
        """
        batch_size, num_points, _ = coords.shape
        # Repeat each image code once per point so it can sit beside the coordinates.
        per_point_ids = torch.unsqueeze(image_ids_one_hot, 1).expand(-1, num_points, -1)
        features = torch.cat((coords, per_point_ids), dim=-1)
        # Run the MLP over a flat list of points, then restore the batch layout.
        flat_out = self.net(features.view(batch_size * num_points, -1))
        return flat_out.view(batch_size, num_points, -1)

class Div2KWrapper(Dataset):
    """Serve the first ``num_images`` samples of a dataset as SIREN training triples.

    Each item is ``(one_hot, coords, pixels)``:

    * ``one_hot``: float tensor of length ``num_images`` with a 1 at the
      sample index.
    * ``coords``: the shared ``(img_size * img_size, 2)`` grid in ``[-1, 1]``.
    * ``pixels``: the image resized to ``img_size x img_size``, rescaled with
      ``x * 2 - 1`` and flattened to ``(img_size * img_size, C)``.

    Only ``min(num_images, len(div2k_dataset))`` items are exposed, so no two
    images ever share a one-hot code.

    Args:
        div2k_dataset: Indexable dataset; ``div2k_dataset[i][0]`` is used as
            the image (a tensor, or anything ``transforms.ToTensor`` accepts).
        num_images: Length of the one-hot code and maximum number of items.
        img_size: Side length of the square images returned.
    """

    def __init__(self, div2k_dataset, num_images, img_size=64):
        self.dataset = div2k_dataset
        self.num_images = num_images
        self.img_size = img_size

        # Pre-compute the coordinate grid once. With indexing='ij' the first
        # coordinate follows the row index, so coords[h * img_size + w] lines
        # up with pixel (h, w) of the flattened [H, W, C] image.
        xs = torch.linspace(-1, 1, steps=self.img_size)
        ys = torch.linspace(-1, 1, steps=self.img_size)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='ij')
        self.coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1)

        # Transform to convert PIL images to tensors.
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of usable images, capped at ``num_images``."""
        return min(self.num_images, len(self.dataset))

    def __getitem__(self, idx):
        """Return ``(one_hot, coords, pixels)`` for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        original_idx = idx
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(
                f"index {original_idx} is out of range for {n_items} image(s)"
            )

        # Get image from the underlying dataset.
        img = self.dataset[idx][0]
        # If the returned image is not a tensor (e.g. PIL Image), convert it.
        if not torch.is_tensor(img):
            img = self.to_tensor(img)

        # Bilinear interpolation needs [B, C, H, W]: add the channel axis for
        # a bare [H, W] image and the batch axis for any [C, H, W] image,
        # whatever its channel count.
        if img.dim() == 2:
            img = img.unsqueeze(0)
        if img.dim() == 3:
            img = img.unsqueeze(0)

        # Resize image using bilinear interpolation.
        img = torch.nn.functional.interpolate(img, size=(self.img_size, self.img_size),
                                              mode='bilinear', align_corners=False).squeeze(0)
        # Permute from [C, H, W] to [H, W, C] for later processing and visualization.
        img = img.permute(1, 2, 0)
        # Rescale image values from [0, 1] to [-1, 1].
        # NOTE(review): assumes the image values are floats in [0, 1], as
        # ToTensor produces for PIL input; confirm for tensor inputs.
        img = img * 2 - 1
        # reshape (not view) so the flatten works for any memory layout.
        img = img.reshape(-1, img.shape[-1])

        # One-hot code identifying this image; idx < num_images is guaranteed above.
        one_hot = torch.zeros(self.num_images)
        one_hot[idx] = 1.0

        return one_hot, self.coords, img

# Experiment configuration.
num_images = 5          # number of distinct images (and one-hot codes) to fit
img_size = 64           # images are resized to img_size x img_size
coord_dim = 2
hidden_features = 128
hidden_layers = 3
out_features = 3        # RGB
learning_rate = 1e-4
total_steps = 1000      # each step is one full pass over the training images

div2k_dataset = Div2K(root="./data", scale=2, download=False)  # Set download=True if needed
dataset = Div2KWrapper(div2k_dataset, num_images, img_size=img_size)

# The model only has `num_images` one-hot codes, so train on at most that many
# images. This keeps every image on its own code and avoids decoding the rest
# of the dataset on every pass.
train_set = torch.utils.data.Subset(dataset, range(min(num_images, len(dataset))))
if len(train_set) == 0:
    # Fail early with a clear message rather than a NameError on `loss` below.
    raise RuntimeError("No training images available: the dataset is empty or num_images is 0.")
dataloader = DataLoader(train_set, batch_size=1, shuffle=True)

model = SimpleSiren(coord_dim, num_images, hidden_features, hidden_layers, out_features)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training Loop
for step in range(total_steps):
    for one_hot, coords, ground_truth in dataloader:
        one_hot = one_hot.float()
        coords = coords.float()
        ground_truth = ground_truth.float()

        pred = model(coords, one_hot)
        # Mean squared error over all pixels and channels.
        loss = ((pred - ground_truth) ** 2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if step % 100 == 0:
        # `loss` and `pred` are from the last batch of this pass.
        print(f"Step {step}, Loss: {loss.item()}")
        with torch.no_grad():
            # Reshape output to image dimensions.
            output_img = pred[0].view(img_size, img_size, 3)
            output_img = torch.clamp(output_img, -1, 1) * 0.5 + 0.5  # Map from [-1, 1] to [0, 1]
            plt.imshow(output_img.cpu().numpy())
            plt.title(f"Step: {step}")
            plt.show()


OSError: image file is truncated