In [1]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class SuperResolutionDataset(Dataset):
    def __init__(self, h5_file):
        super(SuperResolutionDataset, self).__init__()
        with h5py.File(h5_file, 'r') as f:
            self.data = f['data'][:]
            self.label = f['label'][:]

    def __len__(self):
        return self.data.shape[3]

    def __getitem__(self, idx):
        lr_patch = self.data[:, :, :, idx]
        hr_patch = self.label[:, :, :, idx]
        return (torch.tensor(lr_patch, dtype=torch.float32).permute(2, 0, 1), 
                torch.tensor(hr_patch, dtype=torch.float32).permute(2, 0, 1))

# Example usage
h5_file = 'train.h5'
dataset = SuperResolutionDataset(h5_file)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)


In [2]:
import torch.nn as nn

class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

# Example usage
model = SRCNN()


In [3]:
import torch.optim as optim

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    for lr_patch, hr_patch in dataloader:
        # Move data to the GPU if available
        lr_patch, hr_patch = lr_patch.cuda(), hr_patch.cuda()
        
        # Forward pass
        outputs = model(lr_patch)
        loss = criterion(outputs, hr_patch)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print('Training complete.')


IndexError: index 24 is out of bounds for axis 3 with size 21