In [1]:
import os 
import torch.nn as nn
from torchsummary import summary
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split
import random


from conv_layer import conv_layer
from RLFB import RLFB
from SUBP import SubPixelConvBlock
from Trainning_Loop import train_model, CharbonnierLoss

In [8]:
class MESR(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, num_blocks=12, esa_channels=32, upscale_factor=32):
        super(MESR, self).__init__()


        self.conv_in = conv_layer(in_channels, mid_channels, 3)

        self.RLFB_blocks = nn.Sequential(*[RLFB(mid_channels, esa_channels=esa_channels) for _ in range(num_blocks)])

        self.conv_out1 = conv_layer(mid_channels, out_channels, 3)

        self.conv_out2 = conv_layer(mid_channels, out_channels, 3)

        self.sub_pixel_conv = SubPixelConvBlock(out_channels, out_channels, upscale_factor=upscale_factor)



    def forward(self, x):

        out_conv_in = self.conv_in(x)  

        out_RLFB = self.RLFB_blocks(out_conv_in)

        out2_conv_in = self.conv_out1(out_RLFB)  

        out_skip = out2_conv_in + out_conv_in  

        out = self.conv_out2(out_skip)  

        out = self.sub_pixel_conv(out)  

        return out





def model_summary(model, device):

    model.to(device)

    summary(model, input_size=(3, 128, 128)) # Change order & num of channels to match grayscale channel

In [5]:
class MESR(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, num_blocks=12, esa_channels=16, upscale_factor=32):
        super(MESR, self).__init__()

        self.conv_in = conv_layer(in_channels, mid_channels, 3)
        self.RLFB_blocks = nn.Sequential(*[RLFB(mid_channels, esa_channels=esa_channels) for _ in range(num_blocks)])
        self.conv_out = conv_layer(mid_channels, out_channels, 3)
        self.sub_pixel_conv = SubPixelConvBlock(out_channels, out_channels, upscale_factor=upscale_factor)

    def forward(self, x):
        out_conv_in = self.conv_in(x)  
        out_RLFB = self.RLFB_blocks(out_conv_in)  
        out_skip = out_RLFB + out_conv_in  
        out = self.conv_out(out_skip)  
        out = self.sub_pixel_conv(out)  
        return out


def model_summary(model, device):
    model.to(device)
    summary(model, input_size=(3, 128, 128)) # Change order & num of channels to match grayscale channel 


In [9]:
class SuperResolutionDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, lr_transform=None, hr_transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_images = os.listdir(lr_dir)
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_image_path = os.path.join(self.lr_dir, self.lr_images[idx])
        hr_image_path = os.path.join(self.hr_dir, self.lr_images[idx])

        lr_image = Image.open(lr_image_path).convert("RGB")
        hr_image = Image.open(hr_image_path).convert("RGB")

        if self.lr_transform:
            lr_image = self.lr_transform(lr_image)
        if self.hr_transform:
            hr_image = self.hr_transform(hr_image)

        return {'image': lr_image, 'label': hr_image}


In [10]:
def dataloaders(train_dataset, val_dataset, batch_size=1): # setting the batch size to 2
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def setup_training(model, device, train_loader_func, val_loader, stages, epochs=50, patience=50):
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    loss_function = CharbonnierLoss(epsilon=1e-6)

    train_model(
        model=model,
        train_loader_func=train_loader_func,
        val_loader=val_loader,
        optimizer=optimizer,
        loss_function=loss_function,
        device=device,
        epochs=epochs,
        patience=patience,
        val_interval=10,
        output_dir="./model_output",  # Specify the output directory
        stages =stages
    )

In [11]:
import torch

# Check if GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is being used. Device:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU instead of GPU.")

GPU is being used. Device: NVIDIA GeForce RTX 4060 Laptop GPU


In [12]:
def main():
    if torch.cuda.is_available():
        print("CUDA is available!")
        print(f"Number of available GPUs: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        print("CUDA is not available.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MESR(in_channels=3, mid_channels=64, out_channels=3, num_blocks=12)

    transform_lr = transforms.Compose([
 
        transforms.ToTensor(),

    ])

    transform_hr = transforms.Compose([

        transforms.ToTensor(),

    ])

    # load the dataset
    lr_dir = "low_res"
    hr_dir = "high_res"

    # dataset instance with separate transformations for LR and HR images
    full_dataset = SuperResolutionDataset(
        lr_dir=lr_dir,
        hr_dir=hr_dir,
        lr_transform=transform_lr,
        hr_transform=transform_hr
    )

    #splitting the dataset
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader, val_loader = dataloaders(train_dataset, val_dataset, batch_size=2)



    model_summary(model, device)
    setup_training(model, device, train_loader, val_loader)
main()


CUDA is available!
Number of available GPUs: 1
Current device: 0
Device name: NVIDIA GeForce RTX 4060 Laptop GPU


RuntimeError: The size of tensor a (3) must match the size of tensor b (64) at non-singleton dimension 1