In [138]:
import torch.nn as nn
import torch

In [139]:
class CNN_ReLU(nn.Module):
	"""Conv2d followed by a ReLU activation.

	Args:
		in_channels: channel count of the incoming tensor.
		out_channels: channel count produced by the convolution.
		filter_size: square kernel size. Padding is (filter_size - 1) // 2, which
			keeps height and width unchanged for odd kernel sizes (stride 1).
	"""

	def __init__(self, in_channels, out_channels, filter_size):
		super().__init__()
		half_kernel = (filter_size - 1) // 2
		conv = nn.Conv2d(in_channels, out_channels, filter_size, padding=half_kernel)
		activation = nn.ReLU()
		# Kept as a Sequential named `layer` so parameter names stay `layer.0.*`.
		self.layer = nn.Sequential(conv, activation)

	def forward(self, x):
		"""Return relu(conv(x))."""
		block = self.layer
		return block(x)


class CNN_BN_ReLU(nn.Module):
	"""Conv2d -> BatchNorm2d -> ReLU block.

	Args:
		in_channels: channel count of the incoming tensor.
		out_channels: channel count produced by the convolution (and therefore
			the number of features the batch norm normalises).
		filter_size: square kernel size. Padding is (filter_size - 1) // 2, which
			keeps height and width unchanged for odd kernel sizes (stride 1).
	"""

	def __init__(self, in_channels, out_channels, filter_size):
		super().__init__()
		padding = (filter_size - 1) // 2
		self.layer = nn.Sequential(
			nn.Conv2d(in_channels, out_channels, filter_size, padding=padding),
			# BatchNorm sits after the conv, so it must be sized by the conv's
			# output channels; sizing it by in_channels breaks whenever
			# in_channels != out_channels.
			nn.BatchNorm2d(out_channels),
			nn.ReLU(),
		)

	def forward(self, x):
		"""Apply conv, batch norm and ReLU; BatchNorm2d requires a 4-D (N, C, H, W) input."""
		return self.layer(x)
	

class CNN(nn.Module):
	"""A single padded Conv2d with no normalisation or activation.

	Args:
		in_channels: channel count of the incoming tensor.
		out_channels: channel count produced by the convolution.
		filter_size: square kernel size. Padding is (filter_size - 1) // 2, which
			keeps height and width unchanged for odd kernel sizes (stride 1).
	"""

	def __init__(self, in_channels, out_channels, filter_size):
		super().__init__()
		self.layer = nn.Conv2d(
			in_channels,
			out_channels,
			filter_size,
			padding=(filter_size - 1) // 2,
		)

	def forward(self, x):
		"""Return conv(x)."""
		conv = self.layer
		return conv(x)
	

In [140]:
class DnCNN(nn.Module):
	"""Stack of a conv+ReLU head, `num_layers` conv+BN+ReLU blocks and a plain conv tail.

	Args:
		num_layers: number of CNN_BN_ReLU blocks between the head and the tail
			(total conv layers = num_layers + 2).
		input_channels: channel count of the input image; the tail maps back to it.
		output_channels: width (channel count) of every hidden feature map.
		filter_size: square kernel size used by every convolution.

	NOTE(review): forward returns the network output directly; if residual
	learning (predicting the noise to subtract) is intended, that subtraction
	happens outside this class — confirm with callers.
	"""

	def __init__(self, num_layers, input_channels, output_channels, filter_size):
		super().__init__()
		# Build in head -> body -> tail order so weight initialisation consumes
		# the RNG in the same sequence as a single nn.Sequential(...) call.
		head = CNN_ReLU(input_channels, output_channels, filter_size)
		body_blocks = []
		for _ in range(num_layers):
			body_blocks.append(CNN_BN_ReLU(output_channels, output_channels, filter_size))
		body = nn.Sequential(*body_blocks)
		tail = CNN(output_channels, input_channels, filter_size)
		self.layers = nn.Sequential(head, body, tail)

	def forward(self, x):
		"""Run `x` through head, body and tail in sequence."""
		stack = self.layers
		return stack(x)

In [141]:
# Smoke test: seed so the random input and weight init are reproducible on Restart & Run All.
torch.manual_seed(0)
test = DnCNN(num_layers=1, input_channels=1, output_channels=64, filter_size=3)
image = torch.rand((128, 1, 40, 40))
output = test(image)
# Odd-kernel padding keeps H and W, and the tail conv maps back to input_channels,
# so the output must have exactly the input's shape.
assert output.shape == image.shape, f"unexpected output shape {tuple(output.shape)}"
output.shape



Initializing DnCNN
Forward DnCNN
