In [10]:
import torch
import torch.nn as nn
import torch.optim as optim

class SuperResolutionCNN(nn.Module):
    """Five-layer, fully convolutional network mapping a 3-channel image to a 3-channel image.

    Every convolution uses stride 1 and padding of ``kernel_size // 2``, so the
    output has the same height and width as the input. The first four
    convolutions are each followed by a ReLU; the last one is left linear.
    """

    def __init__(self):
        super().__init__()

        # Feature extraction layer
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4)
        self.relu1 = nn.ReLU()

        # 1x1 convolution: mixes channels per pixel without looking at neighbours.
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)
        self.relu2 = nn.ReLU()

        # Two 3x3 stages working on 32 channels.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()

        # Image output layer
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        """Apply the four conv+ReLU stages in order, then the final convolution."""
        hidden_stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
            (self.conv4, self.relu4),
        )
        features = x
        for conv, activation in hidden_stages:
            features = activation(conv(features))
        # No activation after the last layer: the output is returned as-is.
        return self.conv5(features)

# Instantiate the network; this `model` object is the one later cells train and run.
model = SuperResolutionCNN()

# Print the registered layers as a quick check of the architecture.
print(model)

SuperResolutionCNN(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (relu1): ReLU()
  (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (relu2): ReLU()
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU()
  (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu4): ReLU()
  (conv5): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)


In [11]:
import os

from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

class SuperResolutionDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) image pairs read from two directories.

    Each directory is listed and sorted by file name; the i-th entry of the
    low-resolution listing is paired with the i-th entry of the
    high-resolution listing. Pairing is positional only — file names are not
    compared.

    Args:
        high_res_dir: Directory containing the high-resolution images.
        low_res_dir: Directory containing the low-resolution images.
        transform: Optional callable applied to both images of a pair
            (independently) after they are loaded as RGB PIL images.

    Raises:
        ValueError: If the two directories do not contain the same number of
            entries, since positional pairing would then be wrong or fail
            part-way through an epoch.
    """

    def __init__(self, high_res_dir, low_res_dir, transform=None):
        self.high_res_dir = high_res_dir
        self.low_res_dir = low_res_dir
        self.transform = transform
        self.high_res_images = sorted(os.listdir(high_res_dir))
        self.low_res_images = sorted(os.listdir(low_res_dir))

        # __len__ is driven by the high-res listing and __getitem__ indexes both
        # listings with the same idx, so unequal listings would either raise
        # IndexError mid-training or silently pair unrelated images.
        if len(self.high_res_images) != len(self.low_res_images):
            raise ValueError(
                f"Directory size mismatch: {len(self.high_res_images)} entries in "
                f"'{high_res_dir}' vs {len(self.low_res_images)} entries in '{low_res_dir}'"
            )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.high_res_images)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(low_res_image, high_res_image)`` for position ``idx``."""
        high_res_image_path = os.path.join(self.high_res_dir, self.high_res_images[idx])
        low_res_image_path = os.path.join(self.low_res_dir, self.low_res_images[idx])

        high_res_image = self._load_rgb(high_res_image_path)
        low_res_image = self._load_rgb(low_res_image_path)

        # Compare against None so a falsy-but-valid transform object is not skipped.
        if self.transform is not None:
            high_res_image = self.transform(high_res_image)
            low_res_image = self.transform(low_res_image)

        return low_res_image, high_res_image

# Preprocessing applied to every image: convert to a tensor first, then resize
# that tensor to a fixed 256x256.
# NOTE(review): Resize is applied to a tensor here; whether antialiasing is used
# by default depends on the installed torchvision version — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ]
)

# Change below to desired training datasets
high_res_dir = r'DIV2K_HR'
low_res_dir = r'DIV2K_LR'

dataset = SuperResolutionDataset(
    high_res_dir=high_res_dir,
    low_res_dir=low_res_dir,
    transform=transform,
)

# Shuffled mini-batches of 16 pairs.
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

print(f"Dataset size: {len(dataset)}")


Dataset size: 800


In [12]:
from PIL import Image

def train_with_loss_tracking(model, dataloader, epochs=50, lr=1e-4):
    """Train ``model`` with MSE loss and Adam, recording the mean batch loss of each epoch.

    The model is moved to CUDA when available (otherwise CPU) and is updated
    in place. After each epoch a progress line is printed.

    Args:
        model: Network to train; called as ``model(low_res)``.
        dataloader: Sized iterable yielding ``(low_res, high_res)`` tensor
            batches. The model output is compared against ``high_res``.
        epochs: Number of full passes over ``dataloader``.
        lr: Learning rate passed to Adam.

    Returns:
        list[float]: One entry per epoch — the sum of per-batch losses divided
        by the number of batches.

    Raises:
        ValueError: If ``dataloader`` has no batches, because the per-epoch
            average would otherwise divide by zero.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader yields no batches; cannot compute an average loss")

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    loss_history = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for low_res, high_res in dataloader:
            low_res, high_res = low_res.to(device), high_res.to(device)
            optimizer.zero_grad()
            output = model(low_res)
            loss = criterion(output, high_res)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean of the per-batch mean losses (not weighted by batch size).
        avg_loss = running_loss / num_batches
        loss_history.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    print("Finished Training!")
    return loss_history

# Train with the function's default epoch count and learning rate; keep the per-epoch losses.
loss_history = train_with_loss_tracking(model=model, dataloader=dataloader)



Epoch 1/50, Loss: 0.0989
Epoch 2/50, Loss: 0.0182
Epoch 3/50, Loss: 0.0112
Epoch 4/50, Loss: 0.0071
Epoch 5/50, Loss: 0.0058
Epoch 6/50, Loss: 0.0048
Epoch 7/50, Loss: 0.0039
Epoch 8/50, Loss: 0.0033
Epoch 9/50, Loss: 0.0030
Epoch 10/50, Loss: 0.0028
Epoch 11/50, Loss: 0.0026
Epoch 12/50, Loss: 0.0024
Epoch 13/50, Loss: 0.0023
Epoch 14/50, Loss: 0.0022
Epoch 15/50, Loss: 0.0021
Epoch 16/50, Loss: 0.0020
Epoch 17/50, Loss: 0.0019
Epoch 18/50, Loss: 0.0018
Epoch 19/50, Loss: 0.0018
Epoch 20/50, Loss: 0.0017
Epoch 21/50, Loss: 0.0017
Epoch 22/50, Loss: 0.0017
Epoch 23/50, Loss: 0.0017
Epoch 24/50, Loss: 0.0016
Epoch 25/50, Loss: 0.0016
Epoch 26/50, Loss: 0.0015
Epoch 27/50, Loss: 0.0015
Epoch 28/50, Loss: 0.0015
Epoch 29/50, Loss: 0.0015
Epoch 30/50, Loss: 0.0015
Epoch 31/50, Loss: 0.0014
Epoch 32/50, Loss: 0.0014
Epoch 33/50, Loss: 0.0014
Epoch 34/50, Loss: 0.0014
Epoch 35/50, Loss: 0.0013
Epoch 36/50, Loss: 0.0013
Epoch 37/50, Loss: 0.0015
Epoch 38/50, Loss: 0.0013
Epoch 39/50, Loss: 0.

In [13]:
from PIL import Image  # Importing Image from PIL

def test(model, image_path, output_path="upscaled_image.png"):
    """Run ``model`` on a single image file and save the network output as an image.

    The image is loaded as RGB, passed through the module-level ``transform``
    pipeline, given a batch dimension, and fed to the model in eval mode with
    gradients disabled. The output is converted to an 8-bit RGB image and
    written to ``output_path``.

    Args:
        model: Network to run; expected to return a ``(1, 3, H, W)`` tensor
            for a ``(1, 3, H, W)`` input (assumption based on how the output
            is squeezed and permuted below).
        image_path: Path of the input image.
        output_path: Where to save the result. Defaults to
            ``"upscaled_image.png"``.
    """
    model.eval()
    with torch.no_grad():
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0)

        # Send the input to wherever the model's weights actually are, rather
        # than assuming the model is on CUDA whenever CUDA is available.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        image = image.to(device)

        output = model(image)

        # The last layer is linear, so values can fall outside [0, 1]. Without
        # clamping, the uint8 cast wraps around (e.g. 1.01 * 255 -> 1) and
        # produces speckle artifacts. Round so that x * 255 maps to the
        # nearest 8-bit level instead of truncating.
        output = output.squeeze(0).clamp(0.0, 1.0)
        output_image = (
            output.mul(255.0)
            .round()
            .to(torch.uint8)
            .permute(1, 2, 0)  # CHW -> HWC, the layout Image.fromarray expects
            .contiguous()
            .cpu()
            .numpy()
        )

        result_image = Image.fromarray(output_image)
        result_image.save(output_path)

# Record the input's dimensions so the network output can be resized back to them.
original_image = Image.open("original_image.png").convert('RGB')
original_width = original_image.width
original_height = original_image.height

# Run the `model` object created and trained in the earlier cells;
# test() writes its result to "upscaled_image.png".
test(model, "original_image.png")

# Resizing image if necessary
upscaled_image = Image.open("upscaled_image.png")
upscaled_image_resized = upscaled_image.resize(
    (original_width, original_height), Image.BICUBIC
).convert("RGB")
upscaled_image_resized.save("upscaled_image_resized.png")