# :train and val .txt file creation

In [2]:
import os

# Set the folder path containing the images and labels
folder_path = '/workspace/Datasets/DIV2K_valid_HR'

# Set the filename for the output text file
output_file = '/workspace/val_files.txt'
image_count = 0
# Open the output file for writing
with open(output_file, 'w') as f:
    # Loop over all files in the folder
    for filename in os.listdir(folder_path):
        # Check if the file is an image (e.g. .jpg or .png)
        if filename.endswith('.jpg') or filename.endswith('.png'):
            # Extract the label from the filename (assumes label is the first character(s) before the underscore)
            label = filename.split('_')[0]
            image_count += 1
            # Write the label to the output file
            f.write(label + '\n')

print(image_count)

100


# datalabel end

In [1]:
pip install ptflops

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting ptflops
  Downloading ptflops-0.7.3-py3-none-any.whl (18 kB)
Installing collected packages: ptflops
Successfully installed ptflops-0.7.3
You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install pytorch-msssim

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0
You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install torchsummaryX

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting torchsummaryX
  Downloading torchsummaryX-1.3.0-py3-none-any.whl (3.6 kB)
Installing collected packages: torchsummaryX
Successfully installed torchsummaryX-1.3.0
You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [5]:
pip install scikit-image

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting scikit-image
  Downloading scikit_image-0.21.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.9 MB)
[K     |████████████████████████████████| 13.9 MB 3.4 MB/s eta 0:00:01
[?25hCollecting networkx>=2.8
  Downloading networkx-3.1-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 25.7 MB/s eta 0:00:01
[?25hCollecting lazy_loader>=0.2
  Downloading lazy_loader-0.4-py3-none-any.whl (12 kB)
Collecting tifffile>=2022.8.12
  Downloading tifffile-2023.7.10-py3-none-any.whl (220 kB)
[K     |████████████████████████████████| 220 kB 31.4 MB/s eta 0:00:01
[?25hCollecting imageio>=2.27
  Downloading imageio-2.35.1-py3-none-any.whl (315 kB)
[K     |████████████████████████████████| 315 kB 43.1 MB/s eta 0:00:01
Collecting PyWavelets>=1.1.1
  Downloading PyWavelets-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.9 MB)
[K     |██████████████████████████████

Data management

In [6]:
# data_management.py
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, Sampler
from skimage.util import view_as_windows
# from utils import load_image
# from transforms import ToTensor


def data_augmentation(image):
    augmented_images_arrays, augmented_images_list = [], []
    to_transform = [image, np.rot90(image, axes=(1, 2))]

    for t in to_transform:
        t_ud = t[:, ::-1, ...]
        t_lr = t[:, :, ::-1, ...]
        t_udlr = t_ud[:, :, ::-1, ...]

        flips = [t_ud, t_lr, t_udlr]
        augmented_images_arrays.extend(flips)

    augmented_images_arrays.extend(to_transform)

    for img in augmented_images_arrays:
        img_unbatch = list(img)
        augmented_images_list.extend(img_unbatch)

    return augmented_images_list


def create_patches(image, patch_size, step):
    image = view_as_windows(image, patch_size, step)
    h, w = image.shape[:2]
    image = np.reshape(image, (h * w, patch_size[0], patch_size[1], patch_size[2]))

    return image


class DataSampler(Sampler):

    def __init__(self, data_source, num_samples=None):
        super().__init__(data_source)
        self.data_source = data_source
        self._num_samples = num_samples
        self.rand = np.random.RandomState(0)
        self.perm = []

    @property
    def num_samples(self):
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self._num_samples is not None:
            while len(self.perm) < self._num_samples:
                perm = self.rand.permutation(n).astype('int32').tolist()
                self.perm.extend(perm)
            idx = self.perm[:self._num_samples]
            self.perm = self.perm[self._num_samples:]
        else:
            idx = self.rand.permutation(n).astype('int32').tolist()

        return iter(idx)

    def __len__(self):
        return self.num_samples


class NoisyImagesDataset(Dataset):
    def __init__(self, files, channels, patch_size, transform=None, noise_transform=None):
        self.channels = channels
        self.patch_size = patch_size
        self.transform = transform
        self.noise_transforms = noise_transform
        self.to_tensor = ToTensor()
        self.dataset = {'image': [], 'noisy': []}
        self.load_dataset(files)

    def __len__(self):
        return len(self.dataset['image'])

    def __getitem__(self, idx):
        image, noisy = self.dataset.get('image')[idx], self.dataset.get('noisy')[idx]
        sample = {'image': image, 'noisy': noisy}
        if self.transform is not None:
            sample = self.transform(sample)
        sample = self.to_tensor(sample)

        return sample.get('noisy'), sample.get('image')

    def load_dataset(self, files):
        patch_size = (self.patch_size, self.patch_size, self.channels)
        for file in tqdm(files):
            image = load_image(file, self.channels)
            if image is None:
                continue

            image = create_patches(image, patch_size, step=self.patch_size)
            sample = {'image': image, 'noisy': None}

            for noise_transform in self.noise_transforms:
                _sample = noise_transform(sample)
                image, noisy = _sample['image'], _sample['noisy']
                image, noisy = list(image), list(noisy)

                self.dataset['image'].extend(image)
                self.dataset['noisy'].extend(noisy)

# Util

In [7]:
# utils.py
import random
import torch
import numpy as np
from skimage import io, color, img_as_ubyte


def load_image(image_path, channels):
    """
    Load image and change it color space from RGB to Grayscale if necessary.
    :param image_path: str
        Path of the image.
    :param channels: int
        Number of channels (3 for RGB, 1 for Grayscale)
    :return: numpy array
        Image loaded.
    """
    image = io.imread(image_path)

    if image.ndim == 3 and channels == 1:       # Convert from RGB to Grayscale and expand dims.
        image = img_as_ubyte(color.rgb2gray(image))
        return np.expand_dims(image, axis=-1)
    elif image.ndim == 2 and channels == 1:     # Handling grayscale images if needed.
        if image.dtype != 'uint8':
            image = img_as_ubyte(image)
        return np.expand_dims(image, axis=-1)

    return image


def mod_crop(image, mod):
    """
    Crops image according to mod to restore spatial dimensions
    adequately in the decoding sections of the model.
    :param image: numpy array
        Image to crop.
    :param mod: int
        Module for padding allowed by the number of
        encoding/decoding sections in the model.
    :return: numpy array
        Copped image
    """
    size = image.shape[:2]
    size = size - np.mod(size, mod)
    image = image[:size[0], :size[1], ...]

    return image


def mod_pad(image, mod):
    """
    Pads image according to mod to restore spatial dimensions
    adequately in the decoding sections of the model.
    :param image: numpy array
        Image to pad.
    :param mod: int
        Module for padding allowed by the number of
        encoding/decoding sections in the model.
    :return: numpy  array, tuple
        Padded image, original image size.
    """
    size = image.shape[:2]
    h, w = np.mod(size, mod)
    h, w = mod - h, mod - w
    if h != mod or w != mod:
        if image.ndim == 3:
            image = np.pad(image, ((0, h), (0, w), (0, 0)), mode='reflect')
        else:
            image = np.pad(image, ((0, h), (0, w)), mode='reflect')

    return image, size


def set_seed(seed=1):
    """
    Sets all random seeds.
    :param seed: int
        Seed value.
    :return: None
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def build_ensemble(image, normalize=True):
    """
    Create image ensemble to estimate denoised image.
    :param image: numpy array
        Noisy image.
    :param normalize: bool
        Normalize image to range [0., 1.].
    :return: list
        Ensemble of noisy image transformed.
    """
    img_rot = np.rot90(image)
    ensemble_list = [
        image, np.fliplr(image), np.flipud(image), np.flipud(np.fliplr(image)),
        img_rot, np.fliplr(img_rot), np.flipud(img_rot), np.flipud(np.fliplr(img_rot))
    ]

    ensemble_transformed = []
    for img in ensemble_list:
        if img.ndim == 2:                                           # Expand dims for channel dimension in gray scale.
            img = np.expand_dims(img.copy(), 0)                     # Use copy to avoid problems with reverse indexing.
        else:
            img = np.transpose(img.copy(), (2, 0, 1))               # Channels-first transposition.
        if normalize:
            img = img / 255.

        img_t = torch.from_numpy(np.expand_dims(img, 0)).float()    # Expand dims again to create batch dimension.
        ensemble_transformed.append(img_t)

    return ensemble_transformed


def separate_ensemble(ensemble, return_single=False):
    """
    Apply inverse transforms to predicted image ensemble and average them.
    :param ensemble: list
        Predicted images, ensemble[0] is the original image,
        and ensemble[i] is a transformed version of ensemble[i].
    :param return_single: bool
        Return also ensemble[0] to evaluate single prediction
    :return: numpy array or tuple of numpy arrays
        Average of the predicted images, original image denoised.
    """
    ensemble_np = []

    for img in ensemble:
        img = img.squeeze()                     # Remove additional dimensions.
        if img.ndim == 3:                       # Transpose if necessary.
            img = np.transpose(img, (1, 2, 0))

        ensemble_np.append(img)

    # Apply inverse transforms to vertical and horizontal flips.
    img = ensemble_np[0] + np.fliplr(ensemble_np[1]) + np.flipud(ensemble_np[2]) + np.fliplr(np.flipud(ensemble_np[3]))

    # Apply inverse transforms to 90º rotation, vertical and horizontal flips
    img = img + np.rot90(ensemble_np[4], k=3) + np.rot90(np.fliplr(ensemble_np[5]), k=3)
    img = img + np.rot90(np.flipud(ensemble_np[6]), k=3) + np.rot90(np.fliplr(np.flipud(ensemble_np[7])), k=3)

    # Average and clip final predicted image.
    img = img / 8.
    img = np.clip(img, 0., 1.)

    if return_single:
        return img, ensemble_np[0]
    else:
        return img


def predict_ensemble(model, ensemble, device):
    """
    Predict batch of images from an ensemble.
    :param model: torch Module
        Trained model to estimate denoised images.
    :param ensemble: list
        Images to estimate.
    :param device: torch device
        Device of the trained model.
    :return: list
        Estimated images of type numpy ndarray.
    """
    y_hat_ensemble = []

    for x in ensemble:
        x = x.to(device)

        with torch.no_grad():
            y_hat = model(x)
            y_hat_ensemble.append(y_hat.cpu().detach().numpy().astype('float32'))

    return y_hat_ensemble

Transform

In [8]:
# transforms.py
import random
import torch
import numpy as np


class AdditiveWhiteGaussianNoise(object):
    """Additive white gaussian noise generator."""
    def __init__(self, noise_level, fix_sigma=False, clip=False):
        self.noise_level = noise_level
        self.fix_sigma = fix_sigma
        self.rand = np.random.RandomState(1)
        self.clip = clip
        if not fix_sigma:
            self.predefined_noise = [i for i in range(5, noise_level + 1, 5)]

    def __call__(self, sample):
        """
        Generates additive white gaussian noise, and it is applied to the clean image.
        :param sample:
        :return:
        """
        image = sample.get('image')

        if image.ndim == 4:                 # if 'image' is a batch of images, we set a different noise level per image
            samples = image.shape[0]        # (Samples, Height, Width, Channels) or (Samples, Channels, Height, Width)
            if self.fix_sigma:
                sigma = self.noise_level * np.ones((samples, 1, 1, 1))
            else:
                sigma = np.random.choice(self.predefined_noise, size=(samples, 1, 1, 1))
            noise = self.rand.normal(0., 1., size=image.shape)
            noise = noise * sigma
        else:                               # else, 'image' is a simple image
            if self.fix_sigma:              # (Height, Width, Channels) or (Channels , Height, Width)
                sigma = self.noise_level
            else:
                sigma = self.rand.randint(5, self.noise_level)
            noise = self.rand.normal(0., sigma, size=image.shape)

        noisy = image + noise

        if self.clip:
            noisy = np.clip(noisy, 0., 255.)

        return {'image': image, 'noisy': noisy.astype('float32')}


class ToTensor(object):
    """Convert data sample to pytorch tensor"""
    def __call__(self, sample):
        image, noisy = sample.get('image'), sample.get('noisy')
        image = torch.from_numpy(image.transpose((2, 0, 1)).astype('float32') / 255.)

        if noisy is not None:
            noisy = torch.from_numpy(noisy.transpose((2, 0, 1)).astype('float32') / 255.)

        return {'image': image, 'noisy': noisy}


class RandomVerticalFlip(object):

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if random.uniform(0., 1.) < self.p:
            image, noisy = sample.get('image'), sample.get('noisy')
            image = np.flipud(image)

            if noisy is not None:
                noisy = np.flipud(noisy)

            return {'image': image, 'noisy': noisy}

        return sample


class RandomHorizontalFlip(object):

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if random.uniform(0., 1.) < self.p:
            image, noisy = sample.get('image'), sample.get('noisy')
            image = np.fliplr(image)

            if noisy is not None:
                noisy = np.fliplr(noisy)

            return {'image': image, 'noisy': noisy}

        return sample


class RandomRot90(object):

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        if random.uniform(0., 1.) < self.p:
            image, noisy = sample.get('image'), sample.get('noisy')
            image = np.rot90(image)

            if noisy is not None:
                noisy = np.rot90(noisy)

            return {'image': image, 'noisy': noisy}

        return sample

Metrics

In [9]:
# metrics.py
import torch
from pytorch_msssim import SSIM as _SSIM


class PSNR(object):
    r"""
    Evaluates the PSNR metric in a tensor.
    It can return a result with different reduction methods.

    Args:
        data_range (int, float): Range of the input images.
        reduction (string): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
        eps (float): Epsilon value to avoid division by zero.
    """
    def __init__(self, data_range, reduction='none', eps=1e-8):
        self.data_range = data_range
        self.reduction = reduction
        self.eps = eps

    def __call__(self, outputs, targets):
        with torch.set_grad_enabled(False):
            mse = torch.mean((outputs - targets) ** 2., dim=(1, 2, 3))
            psnr = 10. * torch.log10((self.data_range ** 2.) / (mse + self.eps))

            if self.reduction == 'mean':
                return psnr.mean()
            if self.reduction == 'sum':
                return psnr.sum()

            return psnr


class SSIM(object):
    r"""
    Evaluates the SSIM metric in a tensor.
    It can return a result with different reduction methods.

    Args:
        channels (int): Number of channels of the images.
        data_range (int, float): Range of the input images.
        reduction (string): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
    """
    def __init__(self, channels, data_range, reduction='none'):
        self.data_range = data_range
        self.reduction = reduction
        self.ssim_module = _SSIM(data_range=data_range, size_average=False, channel=channels)

    def __call__(self, outputs, targets):
        with torch.set_grad_enabled(False):
            ssim = self.ssim_module(outputs, targets)

            if self.reduction == 'mean':
                return ssim.mean()
            if self.reduction == 'sum':
                return ssim.sum()

            return ssim

# modified uNet + attn Model

In [10]:
import torch
import torch.nn as nn
from torchvision import models

from typing import List, Tuple
import torch.nn.functional as F



@torch.no_grad()
def init_weights(init_type='xavier'):
    if init_type == 'xavier':
        init = nn.init.xavier_normal_
    elif init_type == 'he':
        init = nn.init.kaiming_normal_
    else:
        init = nn.init.orthogonal_

    def initializer(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            init(m.weight)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight, 1.0, 0.01)
            nn.init.zeros_(m.bias)

    return initializer



class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.PReLU(out_channels)
        self.stride = stride

        # If the input and output channels differ, use a 1x1 convolution to match them
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(x)
        out = self.relu(out)

        return out

class SkipPath(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SkipPath, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # print('skippath',x.shape)
        x1 = self.conv1(x)
        # print('skippath_con1',x1.shape)
        x2 = self.conv2(x)
        # print('skippath_con2',x2.shape)
        x_0 = x1 + x2
        # print('skippath_add',x_0.shape)

        x3 = self.conv1(x_0)
        x4 = self.conv2(x_0)
        x_1 = x3 + x4

        x5 = self.conv1(x_1)
        x6 = self.conv2(x_1)
        x_2 = x5 + x6

        x7 = self.conv1(x_2)
        x8 = self.conv2(x_2)
        x = x7 + x8


        return x



class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.res_block = ResBlock(out_channels, out_channels)
        self.relu = nn.PReLU(out_channels)



    def forward(self, x):
        # print('encoder',x.shape)
        x = self.conv1(x)
        # print('encoder_conv1',x.shape)
        x = self.conv2(x)
        # print('encoder_conv2',x.shape)
        x = self.res_block(x)
        # print('encoder_res',x.shape)
        x = self.relu(x)


        return x

class Down_DC(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down_DC, self).__init__()
        self.layer = nn.Sequential(
            nn.MaxPool2d(2),
            EncoderBlock(in_channels, out_channels)
        )
    def forward(self, x):
        x = self.layer(x)
        # print('maxpool', x.shape)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.res_block = ResBlock(out_channels, out_channels)
        self.relu = nn.PReLU(out_channels)


    def forward(self, x):
        # print('decoder_in',x.shape)
        x = self.conv1(x)
        # print('dec_con1', x.shape)
        x = self.conv2(x)
        # print('dec_con2', x.shape)
        x = self.res_block(x)
        # print('dec_res', x.shape)
        x = self.relu(x)
        return x


    # def forward(self, x, skip_connection):
    #     print('Decoder_input',x.shape)
    #     x = torch.cat([x, skip_connection], dim=1)
    #     print('dec_cat',x.shape)
    #     print('skip',skip_connection.shape)
    #     x = self.conv1(x)
    #     print('dec_con1', x.shape)
    #     x = self.conv2(x)
    #     print('dec_con2', x.shape)
    #     x = self.res_block(x)
    #     print('dec_res', x.shape)
    #     x = self.relu(x)
    #     return x


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.Tconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)


    def forward(self, x):
        # print('up_in',x.shape)
        x = self.Tconv(x)
        # print('up_out',x.shape)

        return x




class EnDecoderModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EnDecoderModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.res_block = ResBlock(out_channels, out_channels)
        # self.Tconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2)
        self.relu = nn.PReLU(out_channels)

    def forward(self, x):
        # print('bottle',x.shape)
        x = self.conv1(x)
        # print('bottle',x.shape)
        x = self.conv2(x)
        x = self.res_block(x)
        # print('bottle_res',x.shape)
        x = self.relu(x)
        # x = self.Tconv(x)
        # print('bottle_tconv',x.shape)
        return x




class InputBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(InputBlock, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.actv_1 = nn.PReLU(out_channels)
        self.actv_2 = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.actv_1(self.conv_1(x))
        return self.actv_2(self.conv_2(x))


class OutputBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutputBlock, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.conv_2 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.actv_1 = nn.PReLU(in_channels)
        self.actv_2 = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.actv_1(self.conv_1(x))
        return self.actv_2(self.conv_2(x))

class SequentialPolarizedSelfAttention(nn.Module):

    # Multi-head Attention

    def __init__(self, channel=512, num_heads=4):
        super().__init__()
        self.ch_wv=nn.ModuleList([nn.Conv2d(channel, channel//2, kernel_size=(1,1)) for _ in range(num_heads)])
        self.ch_wq=nn.ModuleList([nn.Conv2d(channel, 1, kernel_size=(1,1)) for _ in range(num_heads)])
        self.softmax_channel=nn.Softmax(1)
        self.softmax_spatial=nn.Softmax(-1)
        self.ch_wz=nn.ModuleList([nn.Conv2d(channel//2, channel, kernel_size=(1,1)) for _ in range(num_heads)])
        self.ln=nn.LayerNorm(channel)
        self.sigmoid=nn.Sigmoid()
        self.sp_wv=nn.ModuleList([nn.Conv2d(channel, channel//2, kernel_size=(1,1)) for _ in range(num_heads)])
        self.sp_wq=nn.ModuleList([nn.Conv2d(channel, channel//2, kernel_size=(1,1)) for _ in range(num_heads)])
        self.agp=nn.AdaptiveAvgPool2d((1,1))
        self.num_heads = num_heads
        self.conv = nn.Conv2d(channel*num_heads, channel, kernel_size=(1, 1))

    def forward(self, x):
          b, c, h, w = x.size()
          channel_out = 0
          spatial_heads = []
          for i in range(self.num_heads):
              #Channel-only Self-Attention
              channel_wv=self.ch_wv[i](x) #bs,c//2,h,w
              channel_wq=self.ch_wq[i](x) #bs,1,h,w
              channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w
              channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1
              channel_wq=self.softmax_channel(channel_wq)
              channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1
              channel_weight=self.sigmoid(self.ch_wz[i](channel_wz).reshape(b,c,1).permute(0,2,1)).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1
              channel_out=channel_weight*x

              #Spatial-only Self-Attention
              spatial_wv=self.sp_wv[i](channel_out) #bs,c//2,h,w
              spatial_wq=self.sp_wq[i](channel_out) #bs,c//2,h,w
              spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1
              spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w
              spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2
              spatial_wq=self.softmax_spatial(spatial_wq)
              spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w
              spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w
              spatial_out=spatial_weight*channel_out
              spatial_heads.append(spatial_out)


          # Combine results of Spatial-only Self-Attention across all heads
          spatial_heads = torch.cat(spatial_heads, dim=1) #bs,num_heads,h,w
          spatial_out = self.conv(spatial_heads) #bs,c,h,w
          return spatial_out


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.relu = nn.PReLU(out_channels)


    def forward(self, x):

        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)

        return x

class DeepRDU(nn.Module):
    r"""
    Residual-Dense U-net for image denoising.
    """
    def __init__(self, **kwargs):
        super().__init__()

        channels = kwargs['channels']
        filters_0 = kwargs['base filters']
        # filters_1 = 2 * filters_0
        # filters_2 = 4 * filters_0
        # filters_3 = 8 * filters_0
        # filters_4 = 16 * filters_0

        # Encoder:
        # Level 0:
        self.in_0 = InputBlock(channels, filters_0)

        self.enc_0 = EncoderBlock(64, 64)
#         self.enc_0 = EncoderBlock(1, 64)
        self.DC_0 = Down_DC(64, 128)





        # Level 1:

        self.enc_1 = EncoderBlock(128, 128)
        self.DC_1 = Down_DC(128, 256)

        # Level 2:
        self.enc_2 = EncoderBlock(256, 256)
        self.DC_2 = Down_DC(256, 512)

        # Level 3:
        self.enc_3 = EncoderBlock(512, 512)
        self.DC_3 = Down_DC(512, 1024)


        # Level 4 (Bottleneck)
        self.EnDe = EnDecoderModule(1024, 1024)
        self.EnDe_3 = UpBlock(1024, 1024)
        # self.EnDe = EncoderBlock(1024, 1024)
        # self.up_EnDe = UpBlock(1024, 1024)



        # Decoder
        # Level 3:

        self.dec_3 = DecoderBlock(1024, 512)
        # self.up_1_3 = UpBlock(1024, 1024)
        self.up_3 = UpBlock(1024, 512)
        self.sc_3 = SkipPath(512, 512)

        # Level 2:

        self.dec_2 = DecoderBlock(512, 256)
        # self.up_1_3 = UpBlock(1024, 1024)
        self.up_2 = UpBlock(512, 256)
        self.sc_2 = SkipPath(256, 256)

        # Level 1:


        self.dec_1 = DecoderBlock(256, 128)
        self.up_1 = UpBlock(256, 128)
        self.sc_1 = SkipPath(128, 128)

        # Level 0:


        self.dec_0 = DecoderBlock(128, 64)
        self.conv = nn.Conv2d(128,64, 3, padding=1)
#         self.conv = nn.Conv2d(128,64, 1, padding=1)
        self.sc_0 = SkipPath(64, 64)

        self.at_con = ConvBlock(192, 128)
        self.ca = SequentialPolarizedSelfAttention(128)



        self.output_block = OutputBlock(128, 3)
#         self.output_block = OutputBlock(128, 1)


    def forward(self, inputs):


        out = self.in_0(inputs)
#         print('input',inputs.shape)

        out_0_1 = self.enc_0(out)
#         out_0_1 = self.enc_0(inputs)
        out_0 = self.DC_0(out_0_1)

        out_1_1 = self.enc_1(out_0)
        out_1 = self.DC_1(out_1_1)

        out_2_1 = self.enc_2(out_1)
        out_2 = self.DC_2(out_2_1)

        out_3_1 = self.enc_3(out_2)
        out_3 = self.DC_3(out_3_1)
        #bottleneck
        out_4 = self.EnDe(out_3)
        out_4 = self.EnDe_3(out_4)

        #Decoder
        out_5 = self.dec_3(out_4)
#         out_5_1 = self.sc_3(out_3_1)  #skip3
#         out_5 = torch.cat([out_5, out_5_1], 1)
        out_5 = torch.cat([out_5, out_3_1], 1)
        out_6 = self.up_3(out_5)

        out_7 = self.dec_2(out_6)

        out_7_1 = self.sc_2(out_2_1)  #skip2
        out_7 = torch.cat([out_7, out_7_1], 1)
        # out_7 = torch.cat([out_7, out_2_1], 1)
        out_8 = self.up_2(out_7)

        out_8_1 = self.sc_1(out_1_1) #skip1
        out_9 = self.dec_1(out_8)
        out_9 = torch.cat([out_9, out_8_1], 1)

        out_9 = self.up_1(out_9)

        out_10_1 = self.sc_0(out_0_1) #skip0
        out_10 = self.dec_0(out_9)
        out_10 = torch.cat([out_10_1, out_10], 1)

        #attention
        out = torch.cat([out_10, out_0_1], 1)
        out = self.at_con(out)
        out = self.ca(out)

        out = self.output_block(out)


        out = inputs - out


        return out




# Train

In [11]:
# train.py
import csv
import os
import torch
import time
import numpy as np
from tqdm import tqdm
from torch import optim

# from metrics import PSNR, SSIM



class EpochLogger:
    r"""
    Keeps a log of metrics in the current epoch.
    """
    def __init__(self):
        self.log = {
            'train loss': 0., 'train psnr': 0., 'train ssim': 0., 'val loss': 0., 'val psnr': 0., 'val ssim': 0.
        }

    def update_log(self, metrics, phase):
        """
        Update the metrics in the current epoch, this method is called at every step of the epoch.
        :param metrics: dict
            Metrics to update: loss, PSNR and SSIM.
        :param phase: str
            Phase of the current epoch: training (train) or validation (val).
        :return: None
        """
        for key, value in metrics.items():
            self.log[' '.join([phase, key])] += value

    def get_log(self, n_samples, phase):
        """
        Returns the average of the monitored metrics in the current moment,
        given the number of evaluated samples.
        :param n_samples: int
            Number of evaluated samples.
        :param phase: str
            Phase of the current epoch: training (train) or validation (val).
        :return: dic
            Log of the current phase in the training.
        """
        log = {
            phase + ' loss': self.log[phase + ' loss'] / n_samples,
            phase + ' psnr': self.log[phase + ' psnr'] / n_samples,
            phase + ' ssim': self.log[phase + ' ssim'] / n_samples
        }
        return log


class FileLogger(object):
    """
    Keeps a log of the whole training and validation process.
    The results are recorded in a CSV files.

    Args:
        file_path (string): path of the csv file.
    """
    def __init__(self, file_path):
        """
        Creates the csv record file.
        :param f
        """
        self.file_path = file_path
        header = ['epoch', 'lr', 'train loss', 'train psnr', 'train ssim', 'val loss', 'val psnr', 'val ssim']

        with open(self.file_path, 'w') as csv_file:
            file_writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            file_writer.writerow(header)

    def __call__(self, epoch_log):
        """
        Updates the CSV record file.
        :param epoch_log: dict
            Log of the current epoch.
        :return: None
        """

        # Format log file:
        # Epoch and learning rate:
        log = ['{:03d}'.format(epoch_log['epoch']), '{:.5e}'.format(epoch_log['learning rate'])]

        # Training loss, PSNR, SSIM:
        log.extend([
            '{:.5e}'.format(epoch_log['train loss']),
            '{:.5f}'.format(epoch_log['train psnr']),
            '{:.5f}'.format(epoch_log['train ssim'])
        ])

        # Validation loss, PSNR, SSIM
        # Validation might not be done at all epochs, in that case the default calue is zero.
        log.extend([
            '{:.5e}'.format(epoch_log.get('val loss', 0.)),
            '{:.5f}'.format(epoch_log.get('val psnr', 0.)),
            '{:.5f}'.format(epoch_log.get('val ssim', 0.))
        ])

        with open(self.file_path, 'a') as csv_file:
            file_writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            file_writer.writerow(log)


def fit_model(model, data_loaders, channels, criterion, optimizer, scheduler, device, n_epochs, val_freq, checkpoint_dir, model_name):
    """
    Training of the denoiser model.
    :param model: torch Module
        Neural network to fit.
    :param data_loaders: dict
        Dictionary with torch DataLoaders with training and validation datasets.
    :param channels: int
        Number of image channels
    :param criterion: torch Module
        Loss function.
    :param optimizer: torch Optimizer
        Gradient descent optimization algorithm.
    :param scheduler: torch lr_scheduler
        Learning rate scheduler.
    :param device: torch device
        Device used during training (CPU/GPU).
    :param n_epochs: int
        Number of epochs to fit the model.
    :param val_freq: int
        How many training epochs to run between validations.
    :param checkpoint_dir: str
        Path to the directory where the model checkpoints and CSV log files will be stored.
    :param model_name: str
        Prefix name of the trained model saved in checkpoint_dir.
    :return: None
    """
    psnr = PSNR(data_range=1., reduction='sum')
    ssim = SSIM(channels, data_range=1., reduction='sum')
    os.makedirs(checkpoint_dir, exist_ok=True)
    logfile_path = os.path.join(checkpoint_dir,  ''.join([model_name, '_logfile.csv']))
    model_path = os.path.join(checkpoint_dir, ''.join([model_name, '-{:03d}-{:.4e}-{:.4f}-{:.4f}.pth']))
    file_logger = FileLogger(logfile_path)
    best_model_path, best_psnr = '', -np.inf
    since = time.time()

    flag = True #False
    CP = True


    if flag == True :
      checkpoint = torch.load('/workspace/Checkpoints/checkpoint.pt')
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      epoch_start = checkpoint['epoch']
      loss = checkpoint['loss']
      step = checkpoint['step']
      for param_group in optimizer.param_groups:
            param_group["lr"] = checkpoint['lr']
    else :
      epoch_start = 0
      step = 0
      for param_group in optimizer.param_groups:
            param_group["lr"] = optimizer.param_groups[0]['lr']





    for epoch in range(epoch_start, n_epochs + 1):
        lr = optimizer.param_groups[0]['lr']
        epoch_logger = EpochLogger()
        epoch_log = dict()

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                print('\nEpoch: {}/{} - Learning rate: {:.4e}'.format(epoch, n_epochs, lr))
                description = 'Training - Loss:{:.5e} - PSNR:{:.5f} - SSIM:{:.5f}'
            elif phase == 'val' and epoch % val_freq == 0:
                model.eval()
                description = 'Validation - Loss:{:.5e} - PSNR:{:.5f} - SSIM:{:.5f}'
            else:
                break

            iterator = tqdm(enumerate(data_loaders[phase], 1), total=len(data_loaders[phase]), ncols=110)
            iterator.set_description(description.format(0, 0, 0))
            n_samples = 0

            for step, (inputs, targets) in iterator:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                n_samples += inputs.size()[0]
                metrics = {
                    'loss': loss.item() * inputs.size()[0],
                    'psnr': psnr(outputs, targets).item(),
                    'ssim': ssim(outputs, targets).item()
                }
                epoch_logger.update_log(metrics, phase)
                log = epoch_logger.get_log(n_samples, phase)
                iterator.set_description(description.format(log[phase + ' loss'], log[phase + ' psnr'], log[phase + ' ssim']))

            if phase == 'val':
                # Apply Reduce LR On Plateau if it is the case and save the model if the validation PSNR is improved.
                if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(log['val psnr'])
                if log['val psnr'] > best_psnr:
                    best_psnr = log['val psnr']

                    # best_model_path = model_path.format(epoch, log['val loss'], log['val psnr'], log['val ssim'])
                    best_model_path = model_path
                    torch.save(model.state_dict(), best_model_path)

            elif scheduler is not None:         # Apply another scheduler at epoch level.
                scheduler.step()

            epoch_log = {**epoch_log, **log}



        # Save the current epoch metrics in a CVS file.
        epoch_data = {'epoch': epoch, 'learning rate': lr, **epoch_log}
        file_logger(epoch_data)

        if CP == True :
          torch.save({
              'epoch': epoch + 1,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss,
              'step' : step,
              'lr' : param_group["lr"],
              }, '/workspace/Checkpoints/checkpoint.pt')


    # Save the last model and report training time.
    # best_model_path = model_path.format(epoch, log['val loss'], log['val psnr'], log['val ssim'])
    best_model_path = model_path
    torch.save(model.state_dict(), best_model_path)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best PSNR: {:4f}'.format(best_psnr))

# Main train

In [12]:
# import yaml
# import torch
# import torch.optim as optim
# from os.path import join
# from torch import nn
# from torch.utils.data import DataLoader
# from torchvision.transforms import transforms
# from ptflops import get_model_complexity_info

# from torchsummaryX import summary

# # from model import RDUNet
# # from data_management import NoisyImagesDataset, DataSampler
# # from train import fit_model
# # from transforms import AdditiveWhiteGaussianNoise, RandomHorizontalFlip, RandomVerticalFlip, RandomRot90
# # from utils import set_seed


# def main():
#     with open('/workspace/config.yaml', 'r') as stream:                # Load YAML configuration file.
#         config = yaml.safe_load(stream)

#     model_params = config['model']
#     train_params = config['train']
#     val_params = config['val']

#     # Defining model:
#     set_seed(0)
#     drop_prob = 0.1
#     model = DeepRDU(**model_params)
#     # model = DN(**model_params)
#     # model = RDUNet(**model_params)
#     # model = UNet(n_classes = 1, depth = 4, padding = True)




#     print('Model summary:')
#     test_shape = (model_params['channels'], train_params['patch size'], train_params['patch size'])
#     with torch.no_grad():
#         macs, params = get_model_complexity_info(model, test_shape, as_strings=True,
#                                                  print_per_layer_stat=False, verbose=False)
#         print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
#         print('{:<30}  {:<8}'.format('Number of parameters: ', params))

#     # Define the model name and use multi-GPU if it is allowed.
#     model_name = 'model_color' if model_params['channels'] == 3 else 'model_gray'
#     device = torch.device(train_params['device'])
#     print("Using device: {}".format(device))
#     if torch.cuda.device_count() > 1 and 'cuda' in device.type and train_params['multi gpu']:
#         model = nn.DataParallel(model)
#         print('Using multiple GPUs')

#     model = model.to(device)
#     param_group = []
#     for name, param in model.named_parameters():
#         if 'conv' in name and 'weight' in name:
#             p = {'params': param, 'weight_decay': train_params['weight decay']}
#         else:
#             p = {'params': param, 'weight_decay': 0.}
#         param_group.append(p)

#     # Load training and validation file names.
#     # Modify .txt files if datasets do not fit in memory.
#     with open('/workspace/train_files.txt', 'r') as f_train, open('/workspace/val_files.txt', 'r') as f_val:
#         raw_train_files = f_train.read().splitlines()
#         raw_val_files = f_val.read().splitlines()
#         train_files = list(map(lambda file: join(train_params['dataset path'], file), raw_train_files))
#         val_files = list(map(lambda file: join(val_params['dataset path'], file), raw_val_files))

#     training_transforms = transforms.Compose([
#         RandomHorizontalFlip(),
#         RandomVerticalFlip(),
#         RandomRot90()
#     ])

#     # Predefined noise level
#     train_noise_transform = [AdditiveWhiteGaussianNoise(train_params['noise level'], clip=True)]
#     val_noise_transforms = [AdditiveWhiteGaussianNoise(s, fix_sigma=True, clip=True) for s in val_params['noise levels']]

#     print('\nLoading training dataset:')
#     training_dataset = NoisyImagesDataset(train_files,
#                                           model_params['channels'],
#                                           train_params['patch size'],
#                                           training_transforms,
#                                           train_noise_transform)

#     print('\nLoading validation dataset:')
#     validation_dataset = NoisyImagesDataset(val_files,
#                                             model_params['channels'],
#                                             val_params['patch size'],
#                                             None,
#                                             val_noise_transforms)
#     # Training in sub-epochs:
#     print('Training patches:', len(training_dataset))
#     print('Validation patches:', len(validation_dataset))
#     n_samples = len(training_dataset) // train_params['dataset splits']
#     n_epochs = train_params['epochs'] * train_params['dataset splits']
#     sampler = DataSampler(training_dataset, num_samples=n_samples)

#     data_loaders = {
#         'train': DataLoader(training_dataset, train_params['batch size'], num_workers=train_params['workers'], sampler=sampler),
#         'val': DataLoader(validation_dataset, val_params['batch size'], num_workers=val_params['workers']),
#     }

#     # Optimization:
#     learning_rate = train_params['learning rate']
#     step_size = train_params['scheduler step'] * train_params['dataset splits']

#     criterion = nn.L1Loss()
#     optimizer = optim.AdamW(param_group, lr=learning_rate)
#     lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=train_params['scheduler gamma'])

#     # Train the model
#     fit_model(model, data_loaders, model_params['channels'], criterion, optimizer, lr_scheduler, device,
#               n_epochs, val_params['frequency'], train_params['checkpoint path'], model_name)


# if __name__ == '__main__':
#     main()
#     # ""## Model Compilation"""
#     # base_filters = 128
#     # channels = 3

#     # model = RDUNet(base_filters, channels)
#     # model.summary()


In [None]:
import os
import csv
import time
import yaml
import torch
import torch.optim as optim
from os.path import join
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from ptflops import get_model_complexity_info

# from model import DeepRDU
# from data_management import NoisyImagesDataset, DataSampler
# from train import fit_model
# from transforms import AdditiveWhiteGaussianNoise, RandomHorizontalFlip, RandomVerticalFlip, RandomRot90
# from utils import set_seed


def write_model_summary_to_csv(file_path, model_name, macs, params, model_size, latency_ms):
    file_exists = os.path.isfile(file_path)
    with open(file_path, mode='a', newline='') as csv_file:
        fieldnames = ['Model Name', 'MACs', 'Parameters', 'Size (MB)', 'Latency (ms)']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()
        writer.writerow({
            'Model Name': model_name,
            'MACs': macs,
            'Parameters': params,
            'Size (MB)': f"{model_size:.2f}",
            'Latency (ms)': f"{latency_ms:.2f}"
        })


def main():
    with open('/workspace/config.yaml', 'r') as stream:
        config = yaml.safe_load(stream)

    model_params = config['model']
    train_params = config['train']
    val_params = config['val']

    set_seed(0)
    model = DeepRDU(**model_params)

    print('Model summary:')
    test_shape = (model_params['channels'], train_params['patch size'], train_params['patch size'])

    with torch.no_grad():
        macs, params = get_model_complexity_info(
            model, test_shape, as_strings=True,
            print_per_layer_stat=False, verbose=False
        )

        # Model size
        torch.save(model.state_dict(), "temp.pth")
        model_size_mb = os.path.getsize("temp.pth") / (1024 * 1024)
        os.remove("temp.pth")

        # Inference latency
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        dummy_input = torch.randn(1, *test_shape).to(device)
        start_time = time.time()
        for _ in range(50):
            _ = model(dummy_input)
        end_time = time.time()
        latency_ms = ((end_time - start_time) / 50) * 1000

        print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
        print('{:<30}  {:<8}'.format('Number of parameters: ', params))
        print('{:<30}  {:.2f} MB'.format('Model size: ', model_size_mb))
        print('{:<30}  {:.2f} ms'.format('Inference latency: ', latency_ms))

        # Write to CSV
        summary_file_path = '/workspace/model_summary.csv'
        write_model_summary_to_csv(summary_file_path, 'DeepRDU', macs, params, model_size_mb, latency_ms)

    model_name = 'model_color' if model_params['channels'] == 3 else 'model_gray'
    if torch.cuda.device_count() > 1 and 'cuda' in device.type and train_params['multi gpu']:
        model = nn.DataParallel(model)
        print('Using multiple GPUs')

    model = model.to(device)
    param_group = []
    for name, param in model.named_parameters():
        if 'conv' in name and 'weight' in name:
            p = {'params': param, 'weight_decay': train_params['weight decay']}
        else:
            p = {'params': param, 'weight_decay': 0.}
        param_group.append(p)

    with open('/workspace/train_files.txt', 'r') as f_train, open('/workspace/val_files.txt', 'r') as f_val:
        raw_train_files = f_train.read().splitlines()
        raw_val_files = f_val.read().splitlines()
        train_files = list(map(lambda file: join(train_params['dataset path'], file), raw_train_files))
        val_files = list(map(lambda file: join(val_params['dataset path'], file), raw_val_files))

    training_transforms = transforms.Compose([
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        RandomRot90()
    ])

    train_noise_transform = [AdditiveWhiteGaussianNoise(train_params['noise level'], clip=True)]
    val_noise_transforms = [AdditiveWhiteGaussianNoise(s, fix_sigma=True, clip=True) for s in val_params['noise levels']]

    print('\nLoading training dataset:')
    training_dataset = NoisyImagesDataset(train_files, model_params['channels'],
                                          train_params['patch size'], training_transforms, train_noise_transform)

    print('\nLoading validation dataset:')
    validation_dataset = NoisyImagesDataset(val_files, model_params['channels'],
                                            val_params['patch size'], None, val_noise_transforms)

    print('Training patches:', len(training_dataset))
    print('Validation patches:', len(validation_dataset))
    n_samples = len(training_dataset) // train_params['dataset splits']
    n_epochs = train_params['epochs'] * train_params['dataset splits']
    sampler = DataSampler(training_dataset, num_samples=n_samples)

    data_loaders = {
        'train': DataLoader(training_dataset, train_params['batch size'], num_workers=train_params['workers'], sampler=sampler),
        'val': DataLoader(validation_dataset, val_params['batch size'], num_workers=val_params['workers']),
    }

    # Optimization:
    learning_rate = train_params['learning rate']
    step_size = train_params['scheduler step'] * train_params['dataset splits']

    criterion = nn.L1Loss()
    optimizer = optim.AdamW(param_group, lr=learning_rate)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=train_params['scheduler gamma'])

    # Train the model
    fit_model(model, data_loaders, model_params['channels'], criterion, optimizer, lr_scheduler, device,
              n_epochs, val_params['frequency'], train_params['checkpoint path'], model_name)


if __name__ == '__main__':
    main()


Model summary:
Computational complexity:       16.86 GMac
Number of parameters:           121.27 M
Model size:                     462.76 MB
Inference latency:              9.01 ms
Using multiple GPUs

Loading training dataset:


100%|█████████████████████████████████████████| 800/800 [04:27<00:00,  2.99it/s]



Loading validation dataset:


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:12<00:00,  1.38it/s]


Training patches: 522939
Validation patches: 10794

Epoch: 105/210 - Learning rate: 2.5000e-05


Training - Loss:1.64323e-02 - PSNR:35.11794 - SSIM:0.90500: 100%|█████████| 3269/3269 [11:20<00:00,  4.80it/s]
Validation - Loss:1.92763e-02 - PSNR:33.25536 - SSIM:0.86532: 100%|███████| 1350/1350 [02:54<00:00,  7.72it/s]



Epoch: 106/210 - Learning rate: 2.5000e-05


Training - Loss:1.63106e-02 - PSNR:35.18881 - SSIM:0.90612: 100%|█████████| 3269/3269 [11:05<00:00,  4.91it/s]
Validation - Loss:1.84198e-02 - PSNR:33.49335 - SSIM:0.88001: 100%|███████| 1350/1350 [03:12<00:00,  7.01it/s]



Epoch: 107/210 - Learning rate: 2.5000e-05


Training - Loss:1.64443e-02 - PSNR:35.11454 - SSIM:0.90441: 100%|█████████| 3269/3269 [11:06<00:00,  4.91it/s]
Validation - Loss:1.88700e-02 - PSNR:33.35095 - SSIM:0.87136: 100%|███████| 1350/1350 [03:11<00:00,  7.03it/s]



Epoch: 108/210 - Learning rate: 2.5000e-05


Training - Loss:1.63591e-02 - PSNR:35.17065 - SSIM:0.90595: 100%|█████████| 3269/3269 [11:21<00:00,  4.80it/s]
Validation - Loss:1.91358e-02 - PSNR:33.17848 - SSIM:0.86720: 100%|███████| 1350/1350 [03:12<00:00,  7.01it/s]



Epoch: 109/210 - Learning rate: 2.5000e-05


Training - Loss:1.63290e-02 - PSNR:35.19394 - SSIM:0.90590: 100%|█████████| 3269/3269 [11:11<00:00,  4.87it/s]
Validation - Loss:1.95145e-02 - PSNR:33.15788 - SSIM:0.86058: 100%|███████| 1350/1350 [03:10<00:00,  7.08it/s]



Epoch: 110/210 - Learning rate: 2.5000e-05


Training - Loss:1.64073e-02 - PSNR:35.12811 - SSIM:0.90526: 100%|█████████| 3269/3269 [11:19<00:00,  4.81it/s]
Validation - Loss:1.88699e-02 - PSNR:33.29340 - SSIM:0.87436: 100%|███████| 1350/1350 [02:52<00:00,  7.82it/s]



Epoch: 111/210 - Learning rate: 2.5000e-05


Training - Loss:1.63642e-02 - PSNR:35.15699 - SSIM:0.90546: 100%|█████████| 3269/3269 [11:19<00:00,  4.81it/s]
Validation - Loss:2.06122e-02 - PSNR:32.83115 - SSIM:0.84199: 100%|███████| 1350/1350 [02:51<00:00,  7.88it/s]



Epoch: 112/210 - Learning rate: 2.5000e-05


Training - Loss:1.63093e-02 - PSNR:35.20084 - SSIM:0.90614: 100%|█████████| 3269/3269 [11:09<00:00,  4.89it/s]
Validation - Loss:1.85489e-02 - PSNR:33.44911 - SSIM:0.87825: 100%|███████| 1350/1350 [02:48<00:00,  8.03it/s]



Epoch: 113/210 - Learning rate: 2.5000e-05


Training - Loss:1.63361e-02 - PSNR:35.18783 - SSIM:0.90581: 100%|█████████| 3269/3269 [10:45<00:00,  5.07it/s]
Validation - Loss:1.85527e-02 - PSNR:33.39878 - SSIM:0.87688: 100%|███████| 1350/1350 [02:49<00:00,  7.98it/s]



Epoch: 114/210 - Learning rate: 2.5000e-05


Training - Loss:1.63186e-02 - PSNR:35.19197 - SSIM:0.90621: 100%|█████████| 3269/3269 [10:39<00:00,  5.11it/s]
Validation - Loss:1.85150e-02 - PSNR:33.40105 - SSIM:0.87999: 100%|███████| 1350/1350 [02:48<00:00,  8.03it/s]



Epoch: 115/210 - Learning rate: 2.5000e-05


Training - Loss:1.63834e-02 - PSNR:35.16469 - SSIM:0.90563: 100%|█████████| 3269/3269 [11:13<00:00,  4.85it/s]
Validation - Loss:1.75269e-02 - PSNR:33.83686 - SSIM:0.89152: 100%|███████| 1350/1350 [02:50<00:00,  7.90it/s]



Epoch: 116/210 - Learning rate: 2.5000e-05


Training - Loss:1.64168e-02 - PSNR:35.13228 - SSIM:0.90486: 100%|█████████| 3269/3269 [10:46<00:00,  5.06it/s]
Validation - Loss:1.79646e-02 - PSNR:33.62786 - SSIM:0.88620: 100%|███████| 1350/1350 [02:51<00:00,  7.89it/s]



Epoch: 117/210 - Learning rate: 2.5000e-05


Training - Loss:1.63121e-02 - PSNR:35.19237 - SSIM:0.90578: 100%|█████████| 3269/3269 [11:04<00:00,  4.92it/s]
Validation - Loss:1.76802e-02 - PSNR:33.84379 - SSIM:0.89099: 100%|███████| 1350/1350 [03:07<00:00,  7.20it/s]



Epoch: 118/210 - Learning rate: 2.5000e-05


Training - Loss:1.62869e-02 - PSNR:35.19721 - SSIM:0.90601: 100%|█████████| 3269/3269 [11:15<00:00,  4.84it/s]
Validation - Loss:1.88649e-02 - PSNR:33.34127 - SSIM:0.87036: 100%|███████| 1350/1350 [02:48<00:00,  7.99it/s]



Epoch: 119/210 - Learning rate: 2.5000e-05


Training - Loss:1.63953e-02 - PSNR:35.14587 - SSIM:0.90565: 100%|█████████| 3269/3269 [10:49<00:00,  5.03it/s]
Validation - Loss:1.89603e-02 - PSNR:33.22877 - SSIM:0.87027: 100%|███████| 1350/1350 [02:53<00:00,  7.78it/s]



Epoch: 120/210 - Learning rate: 2.5000e-05


Training - Loss:1.63979e-02 - PSNR:35.16142 - SSIM:0.90581: 100%|█████████| 3269/3269 [11:17<00:00,  4.82it/s]
Validation - Loss:2.03742e-02 - PSNR:32.89293 - SSIM:0.84772: 100%|███████| 1350/1350 [02:49<00:00,  7.96it/s]



Epoch: 121/210 - Learning rate: 2.5000e-05


Training - Loss:1.63171e-02 - PSNR:35.17976 - SSIM:0.90629: 100%|█████████| 3269/3269 [11:00<00:00,  4.95it/s]
Validation - Loss:1.94794e-02 - PSNR:33.14035 - SSIM:0.86376: 100%|███████| 1350/1350 [02:48<00:00,  7.99it/s]



Epoch: 122/210 - Learning rate: 2.5000e-05


Training - Loss:1.63356e-02 - PSNR:35.18038 - SSIM:0.90578: 100%|█████████| 3269/3269 [11:10<00:00,  4.88it/s]
Validation - Loss:1.91038e-02 - PSNR:33.23269 - SSIM:0.86744: 100%|███████| 1350/1350 [02:50<00:00,  7.93it/s]



Epoch: 123/210 - Learning rate: 2.5000e-05


Training - Loss:1.63133e-02 - PSNR:35.18588 - SSIM:0.90559: 100%|█████████| 3269/3269 [10:49<00:00,  5.03it/s]
Validation - Loss:1.77369e-02 - PSNR:33.80292 - SSIM:0.88700: 100%|███████| 1350/1350 [02:53<00:00,  7.77it/s]



Epoch: 124/210 - Learning rate: 2.5000e-05


Training - Loss:1.63151e-02 - PSNR:35.19642 - SSIM:0.90589: 100%|█████████| 3269/3269 [11:07<00:00,  4.90it/s]
Validation - Loss:1.96216e-02 - PSNR:33.05294 - SSIM:0.85801: 100%|███████| 1350/1350 [02:51<00:00,  7.86it/s]



Epoch: 125/210 - Learning rate: 2.5000e-05


Training - Loss:1.63115e-02 - PSNR:35.20604 - SSIM:0.90576: 100%|█████████| 3269/3269 [10:43<00:00,  5.08it/s]
Validation - Loss:2.06355e-02 - PSNR:32.82948 - SSIM:0.84413: 100%|███████| 1350/1350 [02:47<00:00,  8.05it/s]



Epoch: 126/210 - Learning rate: 2.5000e-05


Training - Loss:1.62955e-02 - PSNR:35.20983 - SSIM:0.90612: 100%|█████████| 3269/3269 [10:48<00:00,  5.04it/s]
Validation - Loss:1.79805e-02 - PSNR:33.57939 - SSIM:0.88582: 100%|███████| 1350/1350 [02:51<00:00,  7.87it/s]



Epoch: 127/210 - Learning rate: 2.5000e-05


Training - Loss:1.63691e-02 - PSNR:35.15621 - SSIM:0.90539: 100%|█████████| 3269/3269 [11:33<00:00,  4.71it/s]
Validation - Loss:1.79529e-02 - PSNR:33.68562 - SSIM:0.88637: 100%|███████| 1350/1350 [03:12<00:00,  7.00it/s]



Epoch: 128/210 - Learning rate: 2.5000e-05


Training - Loss:1.63116e-02 - PSNR:35.19708 - SSIM:0.90565: 100%|█████████| 3269/3269 [12:01<00:00,  4.53it/s]
Validation - Loss:1.90426e-02 - PSNR:33.23120 - SSIM:0.86904: 100%|███████| 1350/1350 [03:09<00:00,  7.11it/s]



Epoch: 129/210 - Learning rate: 2.5000e-05


Training - Loss:1.62443e-02 - PSNR:35.23886 - SSIM:0.90652: 100%|█████████| 3269/3269 [11:25<00:00,  4.77it/s]
Validation - Loss:2.19804e-02 - PSNR:32.41868 - SSIM:0.82258: 100%|███████| 1350/1350 [02:48<00:00,  8.02it/s]



Epoch: 130/210 - Learning rate: 2.5000e-05


Training - Loss:1.64276e-02 - PSNR:35.12796 - SSIM:0.90559: 100%|█████████| 3269/3269 [10:32<00:00,  5.16it/s]
Validation - Loss:1.92873e-02 - PSNR:33.19252 - SSIM:0.86520: 100%|███████| 1350/1350 [02:47<00:00,  8.04it/s]



Epoch: 131/210 - Learning rate: 2.5000e-05


Training - Loss:1.63378e-02 - PSNR:35.15799 - SSIM:0.90567: 100%|█████████| 3269/3269 [11:05<00:00,  4.91it/s]
Validation - Loss:1.93118e-02 - PSNR:33.07373 - SSIM:0.86456: 100%|███████| 1350/1350 [02:52<00:00,  7.81it/s]



Epoch: 132/210 - Learning rate: 2.5000e-05


Training - Loss:1.63570e-02 - PSNR:35.16311 - SSIM:0.90569: 100%|█████████| 3269/3269 [10:59<00:00,  4.96it/s]
Validation - Loss:1.99440e-02 - PSNR:33.00771 - SSIM:0.85584: 100%|███████| 1350/1350 [02:56<00:00,  7.66it/s]



Epoch: 133/210 - Learning rate: 2.5000e-05


Training - Loss:1.62951e-02 - PSNR:35.19949 - SSIM:0.90645:  31%|██▊      | 1006/3269 [03:51<08:15,  4.56it/s]

# test

In [None]:
#main_test.py
import os
import yaml
import torch
import numpy as np
import scipy.io as sio
from os.path import join


#from model import RDUNet
# from model import RatUNet, BasicBlock
from torchvision.transforms import transforms
from skimage import io
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# from utils import build_ensemble, separate_ensemble, predict_ensemble, mod_pad, mod_crop

f1=open('/content/drive/MyDrive/RDUNet-main/Results/test_gray_bsd.txt','w')

def predict(model, noisy_dataset, gt_dataset, device, padding, n_channels, results_path):
    # Load test datasets in format .mat
    X = sio.loadmat(noisy_dataset)['data'].flatten()
    Y = sio.loadmat(gt_dataset)['label'].flatten()


    y_pred, y_pred_ens = [], []
    psnr_list, ssim_list = [], []
    ens_psnr_list, ens_ssim_list = [], []

    n_images = len(X)
    print(n_images)
    multi_channel = True if n_channels == 3 else False

    for i in range(n_images):
        x, y = X[i], Y[i]
        # print("image_x",x.shape)
        # print('image_y',y.shape)

        if padding:
            x, size = mod_pad(x, 8)
            # print(x.shape)
            # print(size)
        else:
            x, y = mod_crop(x, 8), mod_crop(y, 8)


        x = build_ensemble(x, normalize=False)

        with torch.no_grad():
            y_hat_ens = predict_ensemble(model, x, device)
            # print(y_hat_ens.shape)
            y_hat_ens, y_hat = separate_ensemble(y_hat_ens, return_single=True)

            if padding:
                y_hat = y_hat[:size[0], :size[1], ...]
                y_hat_ens = y_hat_ens[:size[0], :size[1], ...]

            y_pred.append(y_hat)
            y_pred_ens.append(y_hat_ens)
            psnr = peak_signal_noise_ratio(y, y_hat, data_range=1.)
            ssim = structural_similarity(y, y_hat, data_range=1., multichannel=multi_channel, gaussian_weights=True,
                                         sigma=1.5, use_sample_covariance=False)

            psnr_ens = peak_signal_noise_ratio(y, y_hat_ens, data_range=1.)
            ssim_ens = structural_similarity(y, y_hat_ens, data_range=1., multichannel=multi_channel,
                                             gaussian_weights=True, sigma=1.5, use_sample_covariance=False)

            psnr_list.append(psnr)
            ssim_list.append(ssim)

            ens_psnr_list.append(psnr_ens)
            ens_ssim_list.append(ssim_ens)
            print('Image: {} - PSNR: {:.4f} - SSIM: {:.4f} - ens PSNR: {:.4f}'
                  ' - ens SSIM: {:.4f}'.format(i + 1, psnr, ssim, psnr_ens, ssim_ens))
            f1.write('Image: {} - PSNR: {:.4f} - SSIM: {:.4f} - ens PSNR: {:.4f}'
                  ' - ens SSIM: {:.4f} \n'.format(i + 1, psnr, ssim, psnr_ens, ssim_ens))


    if results_path is not None:
        for i in range(n_images):
            y_hat = (255 * y_pred[i]).astype('uint8')
            y_hat_ens = (255 * y_pred_ens[i]).astype('uint8')

            y_hat = np.squeeze(y_hat)
            y_hat_ens = np.squeeze(y_hat_ens)

            os.makedirs(results_path, exist_ok=True)

            name = os.path.join(results_path, '{}_{:.4f}_{:.4f}.png'.format(i, psnr_list[i], ssim_list[i]))
            io.imsave(name, y_hat)

            name = os.path.join(results_path, '{}_{:.4f}_{:.4f}_ens.png'.format(i, ens_psnr_list[i], ens_ssim_list[i]))
            io.imsave(name, y_hat_ens)

    return np.mean(psnr_list), np.mean(ssim_list), np.mean(ens_psnr_list), np.mean(ens_ssim_list)


if __name__ == '__main__':
    with open('/content/drive/MyDrive/RDUNet-main/config.yaml', 'r') as stream:
        config = yaml.safe_load(stream)

    model_params = config['model']
    test_params = config['test']
    n_channels = model_params['channels']

    if n_channels == 3:
        model_path = join(test_params['pretrained models path'], 'model_color.pth')
        noisy_datasets = ['noisy_McMaster_']  # Also tested in Kodak24 and Urban100 datasets.
        gt_datasets = ['McMaster_label']

    else:
        model_path = join(test_params['pretrained models path'], 'model_gray.pth')
        noisy_datasets = ['noisy_bsd68_']   # Also tested in BSD68 and Kodak24 datasets.
        gt_datasets = ['bsd68_label']

    model_params = config['model']
    model = DeepRDU(**model_params)
    # model = DN(**model_params)
    # model = RDUNet(**model_params)
    # model = RatUNet(BasicBlock, 64)

    device = torch.device(test_params['device'])
    print("Using device: {}".format(device))

    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    base_directory = "/content/drive/MyDrive/RDUNet-main/Datasets/Test_matfile"

    for noisy_dataset, gt_dataset in zip(noisy_datasets, gt_datasets):
        print('Noisy Dataset: ', noisy_dataset)
        print("Gt Dataset:", gt_dataset)

        for noise_level in test_params['noise levels']:
            extension = '_color' if model_params['channels'] == 3 else '_gray'
            print('extension', extension)
            print('noise level:', noise_level)
            print("base directory", base_directory)

            # noisy_path = os.path.join(test_params['/content/drive/MyDrive/RDUNet-main/Datasets/Test_matfile/'], ''.join([noisy_dataset, str(noise_level), extension, '.mat']))
            # label_path = join(test_params['/content/drive/MyDrive/RDUNet-main/Datasets/Test_matfile/'], ''.join([gt_dataset, extension, '.mat']))
            noisy_path = os.path.join(base_directory, ''.join([noisy_dataset, str(noise_level), extension, '.mat']))
            print("Generated Noisy Path:", noisy_path)
            label_path = os.path.join(base_directory, ''.join([gt_dataset, extension, '.mat']))
            print("Generated GT Path:", label_path)

            if test_params['save images']:
                save_path = join(
                    test_params['results path'], ''.join([noisy_dataset.replace('noisy_', ''), 'sigma_', str(noise_level)])
                )
            else:
                save_path = None

            psnr, ssim, psnr_ens, ssim_ens = predict(model, noisy_path, label_path, device,
                                                     test_params['padding'],  n_channels, save_path)

            message = 'sigma = {} - PSNR: {:.4f} - SSIM: {:.4f} - ens PSNR: {:.4f} - ens SSIM: {:.4f} \n'
            print(message.format(noise_level, np.around(psnr, decimals=4), np.around(ssim, decimals=4),
                                 np.around(psnr_ens, decimals=4), np.around(ssim_ens, decimals=4)))
            f1.write(message.format(noise_level, np.around(psnr, decimals=4), np.around(ssim, decimals=4),
                                 np.around(psnr_ens, decimals=4), np.around(ssim_ens, decimals=4)))

Using device: cuda:0
Noisy Dataset:  noisy_bsd68_
Gt Dataset: bsd68_label
extension _gray
noise level: 10
base directory /content/drive/MyDrive/RDUNet-main/Datasets/Test_matfile
Generated Noisy Path: /content/drive/MyDrive/RDUNet-main/Datasets/Test_matfile/noisy_bsd68_10_gray.mat
Generated GT Path: /content/drive/MyDrive/RDUNet-main/Datasets/Test_matfile/bsd68_label_gray.mat
68


  ssim = structural_similarity(y, y_hat, data_range=1., multichannel=multi_channel, gaussian_weights=True,
  ssim_ens = structural_similarity(y, y_hat_ens, data_range=1., multichannel=multi_channel,


Image: 1 - PSNR: 29.8978 - SSIM: 0.8552 - ens PSNR: 30.0117 - ens SSIM: 0.8587
Image: 2 - PSNR: 34.2142 - SSIM: 0.9372 - ens PSNR: 34.3984 - ens SSIM: 0.9394
Image: 3 - PSNR: 33.8786 - SSIM: 0.9141 - ens PSNR: 33.9831 - ens SSIM: 0.9153
Image: 4 - PSNR: 35.0214 - SSIM: 0.9319 - ens PSNR: 35.1360 - ens SSIM: 0.9334
Image: 5 - PSNR: 32.1384 - SSIM: 0.9263 - ens PSNR: 32.2327 - ens SSIM: 0.9281
Image: 6 - PSNR: 36.6900 - SSIM: 0.9481 - ens PSNR: 36.8303 - ens SSIM: 0.9496
Image: 7 - PSNR: 33.0203 - SSIM: 0.9203 - ens PSNR: 33.1457 - ens SSIM: 0.9220
Image: 8 - PSNR: 31.0953 - SSIM: 0.9064 - ens PSNR: 31.1612 - ens SSIM: 0.9077
Image: 9 - PSNR: 33.2749 - SSIM: 0.9252 - ens PSNR: 33.3835 - ens SSIM: 0.9270
Image: 10 - PSNR: 33.8911 - SSIM: 0.9124 - ens PSNR: 34.0049 - ens SSIM: 0.9149
Image: 11 - PSNR: 32.7461 - SSIM: 0.9339 - ens PSNR: 32.9808 - ens SSIM: 0.9368
Image: 12 - PSNR: 32.6704 - SSIM: 0.8973 - ens PSNR: 32.7844 - ens SSIM: 0.8993
Image: 13 - PSNR: 33.7737 - SSIM: 0.9178 - ens PS

In [None]:
import torch
import yaml
from torchsummary import summary
# from your_model_module import DeepRDU  # Replace 'your_model_module' with the actual module where you defined the DeepRDU class

# Load the model configuration from the YAML file
with open('/content/drive/MyDrive/RDUNet-main/config.yaml', 'r') as stream:
    config = yaml.safe_load(stream)
    model_params = config['model']

# Create an instance of the DeepRDU model
model = DeepRDU(**model_params)

# Move the model to the device (cuda if available, else cpu)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Specify the input size (channels, height, width)
input_size = (3, 64, 64)  # Assuming your input size is (3, 256, 256)

# Print the model summary
summary(model, input_size=input_size)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           1,792
            Conv2d-2           [-1, 64, 64, 64]          36,928
            Conv2d-3           [-1, 64, 64, 64]          36,864
       BatchNorm2d-4           [-1, 64, 64, 64]             128
             PReLU-5           [-1, 64, 64, 64]              64
            Conv2d-6           [-1, 64, 64, 64]          36,864
       BatchNorm2d-7           [-1, 64, 64, 64]             128
             PReLU-8           [-1, 64, 64, 64]              64
          ResBlock-9           [-1, 64, 64, 64]               0
            PReLU-10           [-1, 64, 64, 64]              64
     EncoderBlock-11           [-1, 64, 64, 64]               0
        MaxPool2d-12           [-1, 64, 32, 32]               0
           Conv2d-13          [-1, 128, 32, 32]          73,856
           Conv2d-14          [-1, 128,

In [None]:
! pip install -q torchview
! pip install -q -U graphviz

In [None]:
from torchview import draw_graph
# from torchvision.models import resnet18, GoogLeNet, densenet, vit_b_16
import graphviz

# when running on VSCode run the below command
# svg format on vscode does not give desired result
graphviz.set_jupyter_format('png')

# model_graph1 = draw_graph(model, input_size=(1,3,64,64), expand_nested=True)
model_graph1 = draw_graph(model, input_size=(1,3,64,64), roll=True, expand_nested=True)

model_graph1.visual_graph