In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from torchvision.utils import make_grid

In [2]:
# Hyperparameters
num_epochs = 10
batch_size = 16
learning_rate = 0.001
tolerance = 0.05

# Load STL-10 Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.STL10(root='./data', split='train', download=True, transform=transform)
testset = torchvision.datasets.STL10(root='./data', split='test', download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


100%|██████████| 2.64G/2.64G [08:55<00:00, 4.93MB/s]


Extracting ./data/stl10_binary.tar.gz to ./data
Files already downloaded and verified


In [28]:
# UNet Model (Standard Version)
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.enc1 = self.conv_block(3, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        # Output layer
        self.conv_final = nn.Conv2d(64, 3, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder path
        enc1 = self.enc1(x)
        enc2 = self.enc2(nn.functional.max_pool2d(enc1, kernel_size=2))
        enc3 = self.enc3(nn.functional.max_pool2d(enc2, kernel_size=2))
        enc4 = self.enc4(nn.functional.max_pool2d(enc3, kernel_size=2))

        # Bottleneck
        bottleneck = self.bottleneck(nn.functional.max_pool2d(enc4, kernel_size=2))

        # Decoder path
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((enc4, dec4), dim=1)
        dec4 = self.dec4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((enc3, dec3), dim=1)
        dec3 = self.dec3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((enc2, dec2), dim=1)
        dec2 = self.dec2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((enc1, dec1), dim=1)
        dec1 = self.dec1(dec1)

        # Final output layer
        return torch.sigmoid(self.conv_final(dec1))

In [29]:
# Gaussian Blur Functions
def apply_gaussian_blur(img, sigma):
    img = img.cpu().numpy().transpose(1, 2, 0)
    blurred = gaussian_filter(img, sigma=(sigma, sigma, 0))
    return torch.from_numpy(blurred.transpose(2, 0, 1))


In [30]:
# Training and Testing Blur Matrices
h1_sigma = 1
h2_sigma = 0.5

In [31]:
# Initialize Model, Loss, and Optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [32]:
# PSNR Calculation
def calculate_psnr(original, restored):
    mse = torch.mean((original - restored) ** 2)
    if mse.item() == 0:
        return float('inf')
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

In [33]:








# Training Loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for i, (inputs, _) in enumerate(trainloader):
        mini_batch_psnr = 0.0
        inputs = inputs.to(device)
        blurred_inputs = torch.stack([apply_gaussian_blur(img, h1_sigma) for img in inputs]).to(device)

        # Forward + Backward + Optimize
        outputs = model(blurred_inputs)
        loss = criterion(outputs, inputs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # PSNR Calculation
        psnr = calculate_psnr(inputs, outputs).item()
        mini_batch_psnr += psnr
        print(f'Mini-batch {i+1}, PSNR: {psnr:.2f} dB')


        train_loss += loss.item()
        total += inputs.size(0)
        correct += torch.sum(torch.abs(outputs - inputs) < tolerance).item()

    train_accuracy = 100 * correct / (total * inputs[0].numel())
    train_psnr = mini_batch_psnr / len(trainloader)
    train_loss /= len(trainloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Train PSNR: {train_psnr:.2f} dB')

    # Testing Loop
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, _ in testloader:
            inputs = inputs.to(device)
            blurred_inputs_h2 = torch.stack([apply_gaussian_blur(img, h2_sigma) for img in inputs]).to(device)
            outputs = model(blurred_inputs_h2)
            loss = criterion(outputs, inputs)

            # PSNR Calculation
            psnr = calculate_psnr(inputs, outputs).item()


            test_loss += loss.item()
            total += inputs.size(0)
            correct += torch.sum(torch.abs(outputs - inputs) < tolerance).item()

    test_loss1 = 0.0
    correct1 = 0
    total1 = 0
    with torch.no_grad():
        for inputs, _ in testloader:
            inputs = inputs.to(device)
            blurred_inputs_h1 = torch.stack([apply_gaussian_blur(img, h1_sigma) for img in inputs]).to(device)
            outputs = model(blurred_inputs_h1)
            loss = criterion(outputs, inputs)

            # PSNR Calculation
            psnr1 = calculate_psnr(inputs, outputs).item()


            test_loss1 += loss.item()
            total1 += inputs.size(0)
            correct1 += torch.sum(torch.abs(outputs - inputs) < tolerance).item()


    test_accuracy = 100 * correct / (total * inputs[0].numel())
    test_psnr = psnr / len(testloader)
    test_loss /= len(testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], for h2:{h2_sigma}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%, Test PSNR: {test_psnr:.2f} dB')

    test_accuracy1 = 100 * correct1 / (total1 * inputs[0].numel())
    test_psnr1 = psnr1 / len(testloader)
    test_loss1 /= len(testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], for h1:{h1_sigma}, Test Loss: {test_loss1:.4f}, Test Accuracy: {test_accuracy1:.2f}%, Test PSNR: {test_psnr1:.2f} dB')

# Save the Model
torch.save(model.state_dict(), 'unet_deblur_stl10.pth')

Mini-batch 1, PSNR: 11.83 dB
Mini-batch 2, PSNR: 10.96 dB
Mini-batch 3, PSNR: 11.08 dB
Mini-batch 4, PSNR: 11.92 dB
Mini-batch 5, PSNR: 11.18 dB
Mini-batch 6, PSNR: 12.57 dB
Mini-batch 7, PSNR: 12.67 dB
Mini-batch 8, PSNR: 12.40 dB
Mini-batch 9, PSNR: 14.05 dB
Mini-batch 10, PSNR: 14.26 dB
Mini-batch 11, PSNR: 17.03 dB
Mini-batch 12, PSNR: 15.04 dB
Mini-batch 13, PSNR: 15.94 dB
Mini-batch 14, PSNR: 17.65 dB
Mini-batch 15, PSNR: 16.42 dB
Mini-batch 16, PSNR: 17.75 dB
Mini-batch 17, PSNR: 17.09 dB
Mini-batch 18, PSNR: 18.01 dB
Mini-batch 19, PSNR: 18.69 dB
Mini-batch 20, PSNR: 19.69 dB
Mini-batch 21, PSNR: 18.82 dB
Mini-batch 22, PSNR: 17.63 dB
Mini-batch 23, PSNR: 20.12 dB
Mini-batch 24, PSNR: 19.99 dB
Mini-batch 25, PSNR: 17.27 dB
Mini-batch 26, PSNR: 18.92 dB
Mini-batch 27, PSNR: 18.76 dB
Mini-batch 28, PSNR: 19.71 dB
Mini-batch 29, PSNR: 18.61 dB
Mini-batch 30, PSNR: 18.36 dB
Mini-batch 31, PSNR: 18.86 dB
Mini-batch 32, PSNR: 19.71 dB
Mini-batch 33, PSNR: 18.29 dB
Mini-batch 34, PSNR

In [34]:
# Visualize Results
model.eval()
with torch.no_grad():
    dataiter = iter(testloader)
    images, _ = dataiter.next()
    original = images[0]
    blurred_h1 = apply_gaussian_blur(original, h1_sigma)
    blurred_h2 = apply_gaussian_blur(original, h2_sigma)

    restored_h1 = model(blurred_h1.unsqueeze(0).to(device)).cpu().squeeze()
    restored_h2 = model(blurred_h2.unsqueeze(0).to(device)).cpu().squeeze()

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes[0, 0].imshow(np.transpose(original.numpy(), (1, 2, 0)))
    axes[0, 0].set_title('Original')
    axes[0, 1].imshow(np.transpose(blurred_h1.numpy(), (1, 2, 0)))
    axes[0, 1].set_title('Blurred with h1')
    axes[0, 2].imshow(np.transpose(blurred_h2.numpy(), (1, 2, 0)))
    axes[0, 2].set_title('Blurred with h2')

    axes[1, 1].imshow(np.transpose(restored_h1.numpy(), (1, 2, 0)))
    axes[1, 1].set_title('Restored from h1')
    axes[1, 2].imshow(np.transpose(restored_h2.numpy(), (1, 2, 0)))
    axes[1, 2].set_title('Restored from h2')

    plt.show()




AttributeError: '_SingleProcessDataLoaderIter' object has no attribute 'next'