In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torchsummary import summary

# Architecture hyper-parameters shared by ResBlock/Model defaults:
# channels per convolution, and number of residual blocks in the trunk.
filters, blocks = 96, 16

class ResBlock(nn.Module):
    """Residual block: conv -> SiLU -> conv -> SiLU -> conv, added back onto the input.

    All three convolutions map ``filters`` channels to ``filters`` channels with
    ``padding='same'``, so the output has the same shape as the input.

    Args:
        filters: channel count of the block (defaults to the module-level
            ``filters`` value at the time the class is defined).
        kernel_size: spatial kernel size of every convolution.
    """

    def __init__(self, filters=filters, kernel_size=3):
        super().__init__()

        def make_conv():
            return nn.Conv2d(filters, filters, kernel_size=kernel_size, padding='same')

        # Attribute names (and creation order) are kept stable so saved state_dict
        # keys still load into this module.
        self.act = nn.SiLU()
        self.conv0 = make_conv()
        self.conv1 = make_conv()
        self.conv2 = make_conv()

    def forward(self, input):
        hidden = self.conv0(input)
        hidden = self.act(hidden)
        hidden = self.conv1(hidden)
        hidden = self.act(hidden)
        residual = self.conv2(hidden)
        # Identity skip connection.
        return residual + input

class Model(nn.Module):
    """Super-resolution network: head conv -> residual trunk -> conv -> pixel shuffle.

    Input ``(N, channels, H, W)`` maps to output ``(N, channels, H * upscale_factor,
    W * upscale_factor)``, clipped to ``[0, 1]``.

    Args:
        filters: channel width of the trunk (defaults to the module-level ``filters``).
        kernel_size: kernel size of every convolution, including those in the ResBlocks.
        upscale_factor: spatial upscaling factor ``r`` applied by ``nn.PixelShuffle``.
        blocks: number of ResBlocks (defaults to the module-level ``blocks``).
        channels: number of image channels in the input and output (1 = grayscale).
    """

    def __init__(self, filters=filters, kernel_size=3, upscale_factor=2, blocks=blocks, channels=1):
        super(Model, self).__init__()
        self.conv0 = nn.Conv2d(channels, filters, kernel_size=kernel_size, padding='same')
        # Forward this model's own width/kernel settings; a bare ResBlock() would fall back
        # to the module-level defaults and break any non-default `filters`.
        self.res_blocks = nn.ModuleList(
            [ResBlock(filters=filters, kernel_size=kernel_size) for _ in range(blocks)]
        )
        self.conv1 = nn.Conv2d(filters, filters, kernel_size=kernel_size, padding='same')
        # PixelShuffle(r) turns channels * r**2 feature maps into `channels` maps at r x resolution,
        # so the channel count must follow upscale_factor (it was hard-coded to 4 = 1 * 2**2).
        self.feats_conv = nn.Conv2d(
            filters, channels * upscale_factor ** 2, kernel_size=kernel_size, padding='same'
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, input):
        conv0 = self.conv0(input)
        x = conv0
        for block in self.res_blocks:
            x = block(x)
        conv1 = self.conv1(x)
        # Global skip: trunk output plus the head features.
        x = self.feats_conv(conv1 + conv0)
        x = self.pixel_shuffle(x)
        return torch.clip(x, 0.0, 1.0)

import os

# Build the network on whatever device is available (later cells also fall back to CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
# torchsummary runs a dummy forward pass on the named device ("cuda" or "cpu").
summary(model, (1, 256, 256), device=device.type)

# Resume from the saved weights if present. The training cell below writes this file,
# so on a first run there is nothing to load yet.
checkpoint_path = "/content/r16f96_torch.pth"
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f"No checkpoint found at {checkpoint_path}; starting from random weights.")

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
import cv2
import numpy as np
import glob
import os
from pathlib import Path
from tqdm import tqdm

import re

rotations = [0, 90, 180, 270]

# Augmented copies are named "<stem>_rot<deg>.png" and "<stem>_rot<deg>f.png" (f = mirrored).
# The separator keeps names unique: concatenating "<stem><deg>" directly made e.g.
# "img1"+"90" collide with "img19"+"0", and "img1"+"0" overwrote the still-unprocessed
# source "img10.png".
AUGMENTED_STEM = re.compile(r".+_rot(0|90|180|270)f?")


def _write_png(path, image):
    """Write `image` to `path`, raising instead of silently continuing if OpenCV fails."""
    if not cv2.imwrite(path, image):
        raise IOError(f"Failed to write image: {path}")


filelist = sorted(glob.glob('/content/HR/*.png'))

for myFile in tqdm(filelist):
    src = Path(myFile)
    # Outputs of a previous (possibly interrupted) run are skipped so re-running the cell
    # does not augment already-augmented images. NOTE: assumes no source image is itself
    # named like "<name>_rot90".
    if AUGMENTED_STEM.fullmatch(src.stem):
        continue

    img = cv2.imread(myFile, cv2.IMREAD_UNCHANGED)  # keep the file's own channels / bit depth
    if img is None:  # cv2.imread returns None instead of raising
        print(f"Error reading image: {myFile}")
        continue

    flipped_img = cv2.flip(img, 1)  # horizontal mirror

    # 8 dihedral variants: 4 rotations of the image and 4 of its mirror.
    for rotation in rotations:
        quarter_turns = rotation // 90  # np.rot90 rotates counter-clockwise in 90° steps
        _write_png(
            str(src.with_name(f"{src.stem}_rot{rotation}.png")),
            np.ascontiguousarray(np.rot90(img, quarter_turns)),
        )
        _write_png(
            str(src.with_name(f"{src.stem}_rot{rotation}f.png")),
            np.ascontiguousarray(np.rot90(flipped_img, quarter_turns)),
        )

    # Only reached when all eight writes succeeded, so a failed write never loses the source.
    os.remove(myFile)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import glob
import numpy as np
import random

class ImagePairDataset(Dataset):
    """(low-res input, high-res target) grayscale pairs built from a list of image paths.

    Each image is read, converted to grayscale, and cropped so both sides are divisible
    by `scale`. It is then downscaled by `scale` with ``cv2.INTER_LINEAR_EXACT`` to form
    the network input.

    Args:
        filelist: paths of the high-resolution images.
        scale: downscale factor; should match the model's upscale factor (default 2).

    Each item is a pair of float32 tensors in ``[0, 1]``:
    ``input`` of shape ``(1, H // scale, W // scale)`` and ``target`` of shape ``(1, H, W)``,
    where H and W are the image size after cropping.
    """

    def __init__(self, filelist, scale=2):
        self.filelist = filelist
        self.scale = scale

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        file = self.filelist[idx]
        image = cv2.imread(file, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError(f"Could not read image: {file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Crop to a multiple of `scale`; otherwise the downscaled size is rounded and the
        # upscaled prediction no longer has the target's shape.
        height, width = image.shape
        image = image[:height - height % self.scale, :width - width % self.scale]
        height, width = image.shape

        low_res = cv2.resize(
            image,
            (width // self.scale, height // self.scale),  # cv2 dsize is (width, height)
            interpolation=cv2.INTER_LINEAR_EXACT,
        )

        return self._to_tensor(low_res), self._to_tensor(image)

    @staticmethod
    def _to_tensor(image):
        """uint8 (H, W) array -> float32 tensor of shape (1, H, W) scaled to [0, 1]."""
        return torch.from_numpy(image.astype(np.float32) / 255.0).unsqueeze(0)

import os

# Create DataLoader for training data
filelist = sorted(glob.glob('/content/HR/*.png'))
if not filelist:
    # Fail with an actionable message instead of RandomSampler's opaque "num_samples=0" error.
    raise FileNotFoundError(
        "No training images found in /content/HR/*.png - run the unzip and augmentation cells first."
    )
train_dataset = ImagePairDataset(filelist)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    # Never spawn more worker processes than there are CPUs (PyTorch warns otherwise).
    num_workers=min(8, os.cpu_count() or 1),
    # Pinned host memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Define loss function
loss_function = nn.L1Loss()

# Move the model to its device before building the optimizer so the optimizer is
# created over the parameters that will actually be trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

checkpoint_path = "/content/r16f96_torch.pth"

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f"[Epoch {epoch + 1}/{num_epochs}]", unit="batch")

    for inputs, targets in progress_bar:
        # non_blocking lets the host->device copy overlap compute when batches are pinned.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() synchronizes with the device; fetch the value once per batch and reuse it.
        batch_loss = loss.item()
        progress_bar.set_postfix({'loss': batch_loss})
        # Weight by batch size so the epoch mean is exact even if the last batch is short.
        running_loss += batch_loss * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.8f}')

    # Checkpoint every epoch so a runtime disconnect does not lose all progress.
    torch.save(model.state_dict(), checkpoint_path)

print('Finished Training')

In [None]:
import torch
import cv2
import numpy as np

# Upscale one image with the trained model and save the result.
input_path = '/content/downscaled.png'
output_path = '/content/prediction.png'

image = cv2.imread(input_path, cv2.IMREAD_COLOR)
if image is None:  # cv2.imread returns None instead of raising
    raise FileNotFoundError(f"Could not read {input_path}")
# The model takes one channel. cvtColor's third positional parameter is the `dst`
# output array, so nothing else is passed here.
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = image.astype(np.float32) / 255.0
# (H, W) -> (1, 1, H, W): add batch and channel axes (NCHW).
image = torch.from_numpy(image)[None, None]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
image = image.to(device)

with torch.no_grad():
    output = model(image)

output = output.cpu().numpy()
output = np.squeeze(output)  # (1, 1, H', W') -> (H', W')
output = np.clip(output, 0.0, 1.0)
output = np.around(output * 255.0).astype(np.uint8)

if not cv2.imwrite(output_path, output):
    raise IOError(f"Failed to write {output_path}")

In [None]:
!pip install onnx
import onnx

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

x = torch.ones((1, 1, 256, 256))  # N x C x W x H
x = x.to(device)
torch.onnx.export(
    model, x, '/content/r16f64_relu_torch.onnx',
    opset_version=20,
    input_names = ['input'],
    output_names = ['output'],
    dynamic_axes={
        'input' : {0 : 'batch', 2: 'width', 3: 'height'},
        'output' : {0 : 'batch', 2: 'width', 3: 'height'},
    }
)
torch.save(model.state_dict(), "/content/r16f64_relu_torch.pth")

!cp /content/r8f64_relu_torch.pth /content/drive/MyDrive/tmp/r8f64_relu_torch.pth
!cp /content/r8f64_relu_torch.onnx /content/drive/MyDrive/tmp/r8f64_relu_torch.onnx