In [1]:
import torch
from torch import nn, optim
from DIP import DIP
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from PIL import Image
from customDataset import CustomDataset
import os

In [2]:
input_shape = (3, 64, 64)
model = DIP(input_shape)
input_image = torch.randn(1, 3, 64, 64)
print(input_image.shape)
output_image = model(input_image)
print(output_image.shape)

torch.Size([1, 3, 64, 64])
torch.Size([1, 3, 512, 512])


In [3]:
# show the input image
input_image = input_image.squeeze(0)
input_image = transforms.ToPILImage()(input_image)
input_image.show()

In [4]:
#show the output image
output_image = output_image.squeeze(0)
output_image = transforms.ToPILImage()(output_image)
output_image.show()

In [5]:
# Pretraining
# hyperparameters
batch_size = 4
LR_image_path = "LR_images_mini"
HR_image_path = "HR_images_mini"

loss_fn = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Define a transformation operation to normalize the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize to [-1, 1]
])


dataset = CustomDataset(LR_image_path, HR_image_path, transform=transforms.ToTensor())
print(len(dataset))
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

500


In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU")
else:
    device = torch.device("cpu")
    print("Using CPU")
import torch

# Check for GPU availability and set device
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    device = torch.device("cuda")
    current_device_idx = torch.cuda.current_device()
else:
    device_count = 0
    device = torch.device("cpu")
    current_device_idx = None

# Print device information
print(f"Number of available GPUs: {device_count}")
print(f"Current device index: {current_device_idx}")
print(f"Current device name: {torch.cuda.get_device_name(current_device_idx)}")

Using GPU
Number of available GPUs: 1
Current device index: 0
Current device name: NVIDIA GeForce MX450


In [7]:
############################################# TRAINING #############################################

# Move your model and tensors to the device
model.to(device)
# Move the optimizer to the same device
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 50
for epoch in range(epochs):
    print("Epoch: ", epoch)
    model.train()
    epoch_loss = 0.0  # Variable to accumulate loss within each epoch
    num_batches = 1
    for batch in dataset:
        # get the inputs
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_batches += 1

    # Print epoch information
    print(f"Epoch {epoch+1}/{epochs}: Loss: {epoch_loss / num_batches}")


Epoch:  0
Epoch 1/50: Loss: 0.016194780484053813
Epoch:  1
Epoch 2/50: Loss: 0.0067130597419813726
Epoch:  2
Epoch 3/50: Loss: 0.006312004876547974
Epoch:  3
Epoch 4/50: Loss: 0.006161907119880953
Epoch:  4
Epoch 5/50: Loss: 0.005965906772074734
Epoch:  5
Epoch 6/50: Loss: 0.005885573986084935
Epoch:  6
Epoch 7/50: Loss: 0.0056574147644484355
Epoch:  7
Epoch 8/50: Loss: 0.005656150594900201
Epoch:  8
Epoch 9/50: Loss: 0.005656649520064884
Epoch:  9
Epoch 10/50: Loss: 0.005526103487934456
Epoch:  10
Epoch 11/50: Loss: 0.005766941148377843
Epoch:  11
Epoch 12/50: Loss: 0.005553724064501914
Epoch:  12
Epoch 13/50: Loss: 0.0054126765209291045
Epoch:  13
Epoch 14/50: Loss: 0.005517432463593305
Epoch:  14
Epoch 15/50: Loss: 0.005415005641691981
Epoch:  15
Epoch 16/50: Loss: 0.005444692246901232
Epoch:  16
Epoch 17/50: Loss: 0.005531341573613846
Epoch:  17
Epoch 18/50: Loss: 0.005455289158143825
Epoch:  18
Epoch 19/50: Loss: 0.00537303610025866
Epoch:  19
Epoch 20/50: Loss: 0.0053581275917708

In [33]:
# Preprocess the input image
image = Image.open("LR_images/nn/downscaled_nn_flickr_dog_000742.jpg")
image.show()
input_image = transform(image)

# # Load the trained model
# model = DeepImagePrior(input_shape)
# model.load_state_dict(torch.load(PATH_TO_SAVED_WEIGHTS))

# Set the model to evaluation mode
model.eval()

# Perform the inference
with torch.no_grad():
    input_image = input_image.to(device)  # Move the input image to the device
    output = model(input_image)

# Convert the output to a PIL image
output = output.squeeze(0)
output = transforms.ToPILImage()(output)
output.show()


In [34]:
# Postprocess the output as needed
# Define the inverse transformation
# Convert PIL image to tensor
transform_to_PIL = transforms.ToTensor()
output = transform_to_PIL(output)

inverse_transform = transforms.Compose([
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1/0.5, 1/0.5, 1/0.5])  # denormalize from [-1, 1] to [0, 1]
])
output = inverse_transform(output)
print(output.shape)

# Convert the tensor to a PIL Image
output = Image.fromarray((output * 255).byte().numpy(), 'RGB')
output.show()

torch.Size([3, 512, 512])
