In [50]:
import numpy as np
import matplotlib.pyplot as plt
# import seaborn as sns 
import tensorflow as tf
# import keras 
# from keras.models import Sequential
# from keras.layers import Conv2D,MaxPool2D, UpSampling2D,Dropout

import numpy as np
import torch.nn as nn
import torch
import torch.optim as optim
import argparse

In [51]:
from keras.datasets import mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# to get the shape of the data 
print("x_train shape:",x_train.shape)
print("x_test shape", x_test.shape)



x_train shape: (60000, 28, 28)
x_test shape (10000, 28, 28)


In [52]:
val_images = x_test[:9000]
test_images = x_test[9000:]

In [53]:


val_images = val_images.astype('float32') / 255.0
val_images = np.reshape(val_images,(val_images.shape[0],28,28,1))

test_images = test_images.astype('float32') / 255.0
test_images = np.reshape(test_images,(test_images.shape[0],28,28,1))

train_images = x_train.astype("float32") / 255.0
train_images = np.reshape(train_images, (train_images.shape[0],28,28,1))



In [54]:
factor = 0.39
train_noisy_images = train_images + factor * np.random.normal(loc = 0.0,scale = 1.0,size = train_images.shape)
val_noisy_images = val_images + factor * np.random.normal(loc = 0.0,scale = 1.0,size = val_images.shape)
test_noisy_images = test_images + factor * np.random.normal(loc = 0.0,scale = 1.0,size = test_images.shape)

# here maximum pixel value for our images may exceed 1 so we have to clip the images
train_noisy_images = np.clip(train_noisy_images,0.,1.)
val_noisy_images = np.clip(val_noisy_images,0.,1.)
test_noisy_images = np.clip(test_noisy_images,0.,1.)



In [55]:

class PrintShape(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self,x):
        print(x.shape)
        return x

In [56]:
class BF_CNN(nn.Module):

    def __init__(self, args):
        super(BF_CNN, self).__init__()

        self.padding = args.padding
        self.num_kernels = args.num_kernels
        self.kernel_size = args.kernel_size
        self.num_layers = args.num_layers
        self.num_channels = args.num_channels

        self.conv_layers = nn.ModuleList([])
        self.running_sd = nn.ParameterList([])
        self.gammas = nn.ParameterList([])


        self.conv_layers.append(nn.Conv2d(self.num_channels,self.num_kernels, self.kernel_size, padding=self.padding , bias=False))

        for l in range(1,self.num_layers-1):
            self.conv_layers.append(PrintShape())
            self.conv_layers.append(nn.Conv2d(self.num_kernels ,self.num_kernels, self.kernel_size, padding=self.padding , bias=False))
            self.running_sd.append( nn.Parameter(torch.ones(1,self.num_kernels,1,1), requires_grad=False) )
            g = (torch.randn( (1,self.num_kernels,1,1) )*(2./9./64.)).clamp_(-0.025,0.025)
            self.gammas.append(nn.Parameter(g, requires_grad=True) )

        self.conv_layers.append(nn.Conv2d(self.num_kernels,self.num_channels, self.kernel_size, padding=self.padding , bias=False))
        


    def forward(self, x):
        relu = nn.ReLU(inplace=True)
        x = relu(self.conv_layers[0](x))
        for l in range(1,self.num_layers-1):
            x = self.conv_layers[l](x)
            # BF_BatchNorm
            sd_x = torch.sqrt(x.var(dim=(0,2,3) ,keepdim = True, unbiased=False)+ 1e-05)

            if self.conv_layers[l].training:
                x = x / sd_x.expand_as(x)
                self.running_sd[l-1].data = (1-.1) * self.running_sd[l-1].data + .1 * sd_x
                x = x * self.gammas[l-1].expand_as(x)

            else:
                x = x / self.running_sd[l-1].expand_as(x)
                x = x * self.gammas[l-1].expand_as(x)

            x = relu(x)

        x = self.conv_layers[-1](x)

        return x



In [57]:
def def_args(grayscale=False):
    '''
    @ grayscale: if True, number of input and output channels are set to 1. Otherwise 3
    @ training_data: models provided in here have been trained on {BSD400, mnist, BSD300}
    @ training_noise: standard deviation of noise during training the denoiser
    '''
    parser = argparse.ArgumentParser(description='BF_CNN_color')
    parser.add_argument('--dir_name', default= '../noise_range_')
    parser.add_argument('--kernel_size', default= 3)
    parser.add_argument('--padding', default= 1)
    parser.add_argument('--num_kernels', default= 28)
    parser.add_argument('--num_layers', default= 20)
    if grayscale is True: 
        parser.add_argument('--num_channels', default= 1)
    else:
        parser.add_argument('--num_channels', default= 3)
    
    args = parser.parse_args('')
    return args


In [58]:
args = def_args(grayscale=True)
# Initialize the model, loss function, and optimizer
model = BF_CNN(args)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Load your dataset (replace 'your_dataset_path' with your actual path)
# transform = transforms.Compose([transforms.ToTensor()])
# dataset = ImageFolder(root='C://Users//ANT//Documents//info//IIN_universal_inverse_problem//test_images//color//BSD100', transform=transform)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Training loop
num_epochs = 50
n_batch = 64
max_std_dev = 0.3  # Set the initial standard deviation of Gaussian noise

In [59]:
Xtrain = torch.FloatTensor(train_images)
Xtest = torch.FloatTensor(test_images)
Xval = torch.FloatTensor(val_images)

Xtrain_noisy = torch.FloatTensor(train_noisy_images)
Xtest_noisy = torch.FloatTensor(test_noisy_images)
Xval_noisy = torch.FloatTensor(val_noisy_images)

N = Xtrain.shape[0]

In [60]:
for epoch in range(num_epochs):
    running_loss = 0.0

    p = np.random.permutation(N)
    X = Xtrain[p,...]
    X_noisy = Xtrain_noisy[p,...]
    for i in range(0,N,n_batch):

        X_batch = X[i:(i+n_batch),...]
        X_noisy_batch = X_noisy[i:(i+n_batch),...]

        optimizer.zero_grad()

        # Forward pass
        outputs = model(X_noisy_batch)
        loss = criterion(outputs, X_batch)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    average_loss = running_loss / N
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}')

# Save the trained model
torch.save(model.state_dict(), 'bf_cnn_mnist.pth')

RuntimeError: Given groups=1, weight of size [28, 1, 3, 3], expected input[64, 28, 28, 1] to have 1 channels, but got 28 channels instead