In [1]:
import torch
import torch.nn

In [2]:
# class ProximalOperatorM(torch.nn.Module):

#     def __init__(self, in_channels: int) -> None:

#         super(ProximalOperatorM, self).__init__()

#         self.in_channels = in_channels
#         self.num_features = in_channels // 2
        
#         conv_X = torch.nn.Conv2d(self.in_channels, self.num_features, (3, 3), padding='same')
#         activation_X = torch.nn.ReLU()
#         self.conv = torch.nn.Sequential([conv_X, activation_X])

        

#         for i in range(0, 5):
#             for j in range(1, 3):
#                 # Create conv_1 and conv_2
#                 name = 'iteration_'+str(i)+':'+'conv'+str(j)
#                 conv_X = torch.nn.Conv2d(self.num_features, self.num_features, (3, 3), padding='same')
#                 activation_X = torch.nn.ReLU()
#                 sequence = torch.nn.Sequential([conv_X, activation_X])
#                 self.add_module(name=name, module=sequence)

#         self.out = torch.nn.Conv2d(self.num_features, self.in_channels, (3, 3), padding='same')


#     def forward(self, image):

#         out_conv = self.conv(image)

#         for i in range(0, 5):
            
#             conv_1 = self.get_submodule(target='iteration_'+str(i)+':'+'conv_'+str(1))
#             conv_2 = self.get_submodule(target='iteration_'+str(i)+':'+'conv_'+str(2))

#             out_conv_1 = conv_1(out_conv)
#             out_conv_2 = conv_2(out_conv_1)
#             out_conv = out_conv + out_conv_2

#         out_out = self.out(out_conv)

#         return out_out      

In [3]:
# class ProximalOperatorO(torch.nn.Module):

#     def __init__(self, in_channels: int, num_features: int) -> None:

#         super(ProximalOperatorO, self).__init__()

#         self.in_channels = in_channels
#         self.num_features = num_features
        
#         conv_X = torch.nn.Conv2d(self.in_channels, self.num_features, (3, 3), padding='same')
#         activation_X = torch.nn.ReLU()
#         self.conv = torch.nn.Sequential([conv_X, activation_X])

#         self.out = torch.nn.Conv2d(self.in_channels, (3, 3), padding='same')

#         for i in range(0, 5):
#             for j in range(1, 3):
#                 # Create conv_1 and conv_2
#                 name = 'iteration_'+str(i)+':'+'conv'+str(j)
#                 conv_X = torch.nn.Conv2d(self.in_channels, self.num_features, (3, 3), padding='same')
#                 activation_X = torch.nn.ReLU()
#                 sequence = torch.nn.Sequential([conv_X, activation_X])
#                 self.add_module(name=name, module=sequence)


#     def forward(self, image):

#         out_conv = self.conv(image)

#         for i in range(0, 5):
            
#             conv_1 = self.get_submodule(target='iteration_'+str(i)+':'+'conv_'+str(1))
#             conv_2 = self.get_submodule(target='iteration_'+str(i)+':'+'conv_'+str(2))

#             out_conv_1 = conv_1(out_conv)
#             out_conv_2 = conv_2(out_conv_1)
#             out_conv = out_conv + out_conv_2

#         out_out = self.out(out_conv)
#         B = out_out[:, :, :, 0:1]
#         Z = out_out[:, :, :, 1:self.in_channels]

#         return B, Z

In [4]:
class ProximalOperator(torch.nn.Module):

    def __init__(self, in_channels: int, num_features: int) -> None:

        super(ProximalOperator, self).__init__()

        self.in_channels = in_channels
        self.num_features = num_features
        
        conv_X = torch.nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.num_features,
            kernel_size=(3, 3),
            padding='same'
        )

        activation_X = torch.nn.ReLU()
        self.conv = torch.nn.Sequential(conv_X, activation_X)

        for i in range(0, 5):
            for j in range(1, 3):
                # Create conv_1 and conv_2
                name = 'iteration_'+str(i)+':'+'conv_'+str(j)
                conv_X =  conv_X = torch.nn.Conv2d(
                    in_channels=self.num_features,
                    out_channels=self.num_features,
                    kernel_size=(3, 3),
                    padding='same'
                )
                activation_X = torch.nn.ReLU()
                sequence = torch.nn.Sequential(conv_X, activation_X)
                self.add_module(name=name, module=sequence)

        self.out = torch.nn.Conv2d(
            in_channels=self.num_features,
            out_channels=self.in_channels,
            kernel_size=(3, 3),
            padding='same'
        )


    def forward(self, image):

        out_conv = self.conv(image)

        for i in range(0, 5):
            
            conv_1 = self.get_submodule(target='iteration_'+str(i)+':'+'conv_'+str(1))
            conv_2 = self.get_submodule(target='iteration_'+str(i)+':'+'conv_'+str(2))

            out_conv_1 = conv_1(out_conv)
            out_conv_2 = conv_2(out_conv_1)
            out_conv = out_conv + out_conv_2

        out_out = self.out(out_conv)

        return out_out


class Prox_M(ProximalOperator):

    def __init__(self, in_channels: int) -> None:
        super(Prox_M, self).__init__(in_channels, in_channels // 2)

class Prox_O(ProximalOperator):

    def __init__(self, in_channels: int, num_features: int) -> None:
        super(Prox_O, self).__init__(in_channels, num_features)

    def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out_out = super().forward(image)
        B = out_out[:, 0:1, :, :]
        Z = out_out[:, 1:self.in_channels, :, :]
        return B, Z

In [5]:
class Unfolding(torch.nn.Module):

    def __init__(self, in_channels: int, num_features: int = 48, iterations: int = 10) -> None:

        """
            in_channels : img.shape[2]
                + grey level => in_channels=1
                + rgb color => in_channels=3
        """

        super(Unfolding, self).__init__()

        self.in_channels = in_channels
        self.num_features = num_features
        self.iterations = iterations

        # Initial
        self.O_0 = torch.nn.Conv2d(
            in_channels=self.in_channels, 
            out_channels=self.num_features,
            kernel_size=(3, 3),
            padding='same'
        )

        self.add_module(
            name='O_0',
            module=self.O_0
        )

        self.stepO = torch.tensor(data=0.1, dtype=torch.float, requires_grad=True)
        self.stepM = torch.tensor(data=0.1, dtype=torch.float, requires_grad=True)

        self.prox_M = Prox_M(in_channels=self.num_features*3)
        self.add_module(name='Prox_M', module=self.prox_M)

        self.prox_O = Prox_O(in_channels=self.num_features+self.in_channels, num_features=self.num_features)
        self.add_module(name='Prox_O', module=self.prox_O)

        for i in range(0, iterations):
            self.__init_iteration(i)


    def __init_iteration(self, i: int) -> None:

        self.add_module(
            name='iteration_'+str(i)+':X1',
            module=torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
        )

        self.add_module(
            name='iteration_'+str(i)+':X2',
            module=torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
        )

        self.add_module(
            name='iteration_'+str(i)+':X4',
            module=torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.num_features, kernel_size=(3, 3), dilation=(4, 4), padding='same', bias=False)
        )


        self.add_module(
            name='iteration_'+str(i)+':X11',
            module=torch.nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
        )

        self.add_module(
            name='iteration_'+str(i)+':X22',
            module=torch.nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
        )

        self.add_module(
            name='iteration_'+str(i)+':X44',
            module=torch.nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=(3, 3), dilation=(4, 4), padding='same', bias=False)
        )

        if 0 < i :

            self.add_module(
                name='iteration_'+str(i)+':X111',
                module=torch.nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
            )
        
            self.add_module(
                name='iteration_'+str(i)+':X222',
                module=torch.nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
            )
            
            self.add_module(
                name='iteration_'+str(i)+':X444',
                module=torch.nn.Conv2d(in_channels=self.num_features, out_channels=self.num_features, kernel_size=(3, 3), dilation=(4, 4), padding='same', bias=False)
            )

    def __apply_layer(self, iter: int, name: str, input: torch.Tensor) -> torch.Tensor:
        layer = self.get_submodule(target='iteration_'+str(iter)+':'+name)
        return layer(input)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        
        # Initial
        out_O_0 = self.O_0(image)
        tmp = torch.concat([out_O_0, image], 1)
        O_previous, Z = self.prox_O(tmp)
        H = image - O_previous

        # Iteration 0

        X_1 = self.__apply_layer(iter=0, name='X1', input=H)
        X_2 = self.__apply_layer(iter=0, name='X2', input=H)
        X_4 = self.__apply_layer(iter=0, name='X4', input=H)
        
        M = self.prox_M(torch.concat([X_1, X_2, X_4], 1))
  
        X_1 = self.__apply_layer(iter=0, name='X11', input=M[:, 0:self.num_features, :, :])
        X_2 = self.__apply_layer(iter=0, name='X22', input=M[:, self.num_features:self.num_features*2, :, :])
        X_4 = self.__apply_layer(iter=0, name='X44', input=M[:, self.num_features*2:self.num_features*3, :, :])
    
        h_current = torch.concat([X_1, X_2, X_4], 1)
        # H_current = torch.sum(h_current, h_current.dim(), keepdim=True)
        H_current = h_current.sum(1).unsqueeze(1)

        O_current = image-H_current
        tmp = torch.concat([Z, self.stepO*O_current+(1.0-self.stepO)*O_previous], 1)

        O_current, Z = self.prox_O(tmp)

        # Iteration 1 to 9
        for i in range(1, self.iterations):

            O_previous = O_current
            H = image - O_previous

            X_1 = self.__apply_layer(iter=i, name='X11', input=M[:, 0:self.num_features, :, :])
            X_2 = self.__apply_layer(iter=i, name='X22', input=M[:, self.num_features:self.num_features*2, :, :])
            X_4 = self.__apply_layer(iter=i, name='X44', input=M[:, self.num_features*2:self.num_features*3, :, :])

            H_star = torch.concat([X_1, X_2, X_4], 1)
            # H_current = torch.sum(h_current, h_current.dim(), keepdim=True)
            H_star = h_current.sum(1).unsqueeze(1)

            X_1 = self.__apply_layer(iter=i, name='X1', input=H_star-H)
            X_2 = self.__apply_layer(iter=i, name='X2', input=H_star-H)
            X_4 = self.__apply_layer(iter=i, name='X4', input=H_star-H)

            # stepM = self.get_submodule(target='iteration_'+str(i)+':stepM')
            M = self.prox_M(M-self.stepM*torch.concat([X_1, X_2, X_4], 1))

            X_1 = self.__apply_layer(iter=i, name='X111', input=M[:, 0:self.num_features, :, :])
            X_2 = self.__apply_layer(iter=i, name='X222', input=M[:, self.num_features:self.num_features*2, :, :])
            X_4 = self.__apply_layer(iter=i, name='X444', input=M[:, self.num_features*2:self.num_features*3, :, :])

            h_current = torch.concat([X_1, X_2, X_4], 1)
            # H_current = torch.sum(h_current, h_current.dim(), keepdim=True)
            H_current = h_current.sum(1).unsqueeze(1)

            O_current = image-H_current
            tmp = torch.concat([Z, self.stepO*O_current+(1.0-self.stepO)*O_previous], 1)
            O_current, Z = self.prox_O(tmp)

        final_out = O_current
        
        return final_out

In [6]:
# Dataset customization : https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

import torchvision
import pathlib
import os
import torch.nn.functional
import torch.utils.data


class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, root_dir: pathlib.Path) -> None:
        """
        Args:
            root_dir (pathlib.Path): Directory with all the images.
        """
        super().__init__()
        self.root_dir: pathlib.Path = root_dir
        self.image_names: list[pathlib.Path] = [ filename.stem for filename in map(lambda e : pathlib.Path(e), os.listdir(root_dir / 'Artifacts')) ]

    def __len__(self) -> int:
        return len(self.image_names)

    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
        
        filename_artifact: pathlib.Path = self.root_dir / 'Artifacts' / (self.image_names[index] + '.jpg')
        image_artifact_string: torch.Tensor = torchvision.io.read_file(str(filename_artifact))
        image_artifact_decoded: torch.Tensor = torchvision.io.decode_jpeg(input=image_artifact_string, mode=torchvision.io.ImageReadMode.GRAY) / 255.0

        filename_result: pathlib.Path = self.root_dir / 'Results' / (self.image_names[index] + '.png')
        image_result_string: torch.Tensor = torchvision.io.read_file(str(filename_result))
        image_result_decoded: torch.Tensor = torchvision.io.decode_png(input=image_result_string, mode=torchvision.io.ImageReadMode.GRAY) / 255.0

        image_artifact_decoded = torch.nn.functional.max_pool2d(image_artifact_decoded, (2,2))
        image_result_decoded = torch.nn.functional.max_pool2d(image_result_decoded, (2,2))

        return image_artifact_decoded, image_result_decoded
        # return image_artifact_decoded.squeeze(), image_result_decoded.squeeze()
        

        # return { 'Artifacts' : image_artifact_decoded, 'Results' : image_result_decoded }

dataset: ImageDataset = ImageDataset(root_dir=pathlib.Path('./phantom-datas'))

In [7]:
# model = Unfolding(in_channels=1, num_features=48, iterations=10)
# O_0 = torch.nn.Conv2d(1, 48, (3, 3), padding='same')
# O_0(dataset[0][0])
# model.forward(dataset[0][0])
# print(O_0)

In [8]:
import ignite.engine
import torch.optim
import torch.optim.lr_scheduler
import ignite.contrib.handlers


def split_dataset(dataset: ImageDataset, train_size: float) -> tuple[ImageDataset, ImageDataset]:
    n = len(dataset)
    train_size = int(0.8*n)
    test_size = n-train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    return train_dataset, test_dataset

def get_dataloaders(config: dict[str, str|int]) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    dataset_full = ImageDataset(pathlib.Path(config['dataset_path']))
    train_dataset, test_dataset = split_dataset(dataset=dataset_full, train_size=config.get('train_size', 0.8))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=config.get('batch_size', 32),
        shuffle=config.get('shuffle', False)
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config.get('batch_size', 32))
    return train_loader, test_loader

def initialize(config: dict[str, str|int|dict]):

    config_model: dict = config.get('model')

    model = Unfolding(
        in_channels=config_model.get('input_channels', 1),
        num_features=config_model.get('num_features', 48),
        iterations=config_model.get('iterations', 10)
    )
   
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr = config.get('learning_rate', 0.001)
    )

    # MAE
    criterion = torch.nn.L1Loss()

    # # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
    # le = config["num_iters_per_epoch"]
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=le, gamma=0.9)
    lr_scheduler = None

    return model, optimizer, criterion, lr_scheduler





In [11]:
config = {
    'model' : {
        'input_channels' : 1,
        'iterations' : 10,
        'num_features' : 48
    },
    'dataset_path' : './phantom-datas',
    'train_size' : 0.8,
    'batch_size' : 1,
    'output_path' : 'output',
    'shuffle' : True,
    'learning_rate' : 0.001,
    'max_epochs' : 1
}

In [12]:

def create_train_step(
    model: torch.nn.Module, 
    optimizer: torch.optim.Optimizer, 
    criterion,
    lr_scheduler: torch.optim.lr_scheduler.StepLR = None
):

    # model, optimizer, criterion, lr_scheduler = initialize(config)
    # Define any training logic for iteration update
    def train_step(engine, batch):
        
        # x, y = batch[0].to(idist.device()), batch[1].to(idist.device())
        # artifact, result = batch[0], batch[1]
        artifacts, results = batch

        model.train()
        predictions = model(artifacts)
        loss: torch.Tensor = criterion(predictions, results)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if not(lr_scheduler is None):
            lr_scheduler.step()

        output = {
            'prediction' : predictions,
            'result' : results,
            'loss' : loss.item()
        }

        return output

    return train_step

In [13]:
def create_evaluate_function(
    model: torch.nn.Module
):

    # model, optimizer, criterion, lr_scheduler = initialize(config)
    # Define any evaluation
    def eval_step(engine, batch):
        
        # x, y = batch[0].to(idist.device()), batch[1].to(idist.device())
        # artifact, result = batch[0], batch[1]
        artifacts, results = batch

        model.eval() # model.train(False)
        predictions = model(artifacts)
        
        output = {
            'prediction' : predictions, 
            'result' : results
        }
        
        return output

    return eval_step

In [14]:
train_loader, validation_loader = get_dataloaders(config)
model, optimizer, criterion, lr_scheduler = initialize(config)

In [15]:
import ignite.metrics

# Define trainer engine and Setup model trainer
train_step = create_train_step(model, optimizer, criterion, lr_scheduler)
trainer = ignite.engine.Engine(train_step)

loss_history: list = []
def update_loss_history(engine: ignite.engine.Engine, loss_history: list):
    loss_history.append(engine.state.output['loss'])

def print_logs(engine: ignite.engine.Engine):
    strp = 'Epoch [{}/{}] : Loss {:.2f}'
    print(
        strp.format(
            engine.state.epoch,
            engine.state.epoch_length,
            engine.state.iteration,
            engine.state.output['loss']
        )
    )

trainer.add_event_handler(
    ignite.engine.Events.EPOCH_COMPLETED,
    # Callback
    update_loss_history,
    # Parameters of callback
    loss_history
)

# trainer.add_event_handler(
#     ignite.engine.Events.GET_BATCH_COMPLETED,
#     # Callback
#     print_logs,
# )

# Add progress bar showing batch loss value adn some metrics
pbar = ignite.contrib.handlers.ProgressBar(
    persist=True
)
pbar.attach(
    engine=trainer, 
    output_transform=lambda output: {'loss': output['loss']}
)



In [16]:
evaluate_function = create_evaluate_function(model)
evaluator = ignite.engine.Engine(evaluate_function)


# METRICS CONFIG
## https://pytorch.org/ignite/metrics.html
## https://pytorch.org/ignite/generated/ignite.metrics.RunningAverage.html

### MAE METRICS

output_transform = lambda output: (output['prediction'], output['result'])


### MAE METRICS

mae = ignite.metrics.MeanAbsoluteError(
    output_transform=output_transform
)

avg_mae = ignite.metrics.RunningAverage(src=mae)

mae.attach(engine=evaluator, name='mae')
avg_mae.attach(engine=evaluator, name='avg_mae')

### MSE METRICS

mse = ignite.metrics.MeanSquaredError(
    output_transform=output_transform
)
avg_mse = ignite.metrics.RunningAverage(src=mse)

mse.attach(engine=evaluator, name='mse')
avg_mse.attach(engine=evaluator, name='avg_mse')

       

# HANDLERS CONFIG
## https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.add_event_handler


def update_history_metrics(
    engine: ignite.engine.Engine, 
    evaluator: ignite.engine.Engine,
    dataloader: torch.utils.data.DataLoader,
    history: dict[str, list],
    mode: str
) -> None:

    evaluator.run(dataloader, max_epochs=1)

    no_epoch = engine.state.epoch
    
    metrics = evaluator.state.metrics
    mae = metrics['mae']
    avg_mae = metrics['avg_mae']
    mse = metrics['mse']
    avg_mse = metrics['avg_mse']
    str_print = mode + ' Results - Epoch {} - mae: {:.2f} Avg mae: {:.2f} mse: {:.2f} Avg mse: {:.2f}'

    # Print logs
    print(str_print.format(no_epoch, mae, avg_mae, mse, avg_mse))

    # Update history
    for key in evaluator.state.metrics.keys():
        history[key].append(evaluator.state.metrics[key])


validation_history = {
    'mae' : [],
    'avg_mae' : [],
    'mse' : [],
    'avg_mse' : []
}

training_history = {
    'mae' : [],
    'avg_mae' : [],
    'mse' : [],
    'avg_mse' : [],
}

# For each epoch completed:
# - we keep metrics
# - we print metrics

## Evaluation on datas using for training
trainer.add_event_handler(
    ignite.engine.Events.ITERATION_COMPLETED,
    # Callback
    update_history_metrics,
    # Parameters of callback
    evaluator, 
    train_loader, 
    training_history,
    'Training Datas'
)

## Evaluation on datas using for validation
trainer.add_event_handler(
    ignite.engine.Events.ITERATION_COMPLETED,
    # Callback
    update_history_metrics,
    # Parameters of callback
    evaluator, 
    validation_loader, 
    validation_history,
    'Validation Datas'
)

# Add progress bar showing batch loss value adn some metrics
pbar = ignite.contrib.handlers.ProgressBar(
    persist=True
)
pbar.attach(
    engine=evaluator
    # metric_names=['mae', 'avg_mae', 'mse', 'avg_mse']
)





In [17]:
# training(config)
trainer.run(train_loader, epoch_length=config.get('max_epochs', 3))

Iteration: [7/7] 100%|██████████ [01:23<00:00]00:00<?]


Training Datas Results - Epoch 1 - mae: 7640.21 Avg mae: 7640.32 mse: 1894.67 Avg mse: 1895.04


Iteration: [2/2] 100%|██████████ [00:12<00:00]
Iteration: [1/1] 100%|██████████, loss=0.241 [02:07<?]

Validation Datas Results - Epoch 1 - mae: 7650.71 Avg mae: 7647.26 mse: 1899.01 Avg mse: 1897.53





State:
	iteration: 1
	epoch: 1
	epoch_length: 1
	max_epochs: 1
	output: <class 'dict'>
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [None]:
# https://github.com/pytorch/ignite/blob/master/examples/notebooks/VAE.ipynb
# plt.plot(range(20), training_history['bce'], 'dodgerblue', label='training')
# plt.plot(range(20), validation_history['bce'], 'orange', label='validation')
# plt.xlim(0, 20);
# plt.xlabel('Epoch')
# plt.ylabel('BCE')
# plt.title('Binary Cross Entropy on Training/Validation Set')
# plt.legend();

In [293]:
# class Unfolding(torch.nn.Module):

#     def __init__(self, in_channels: int, num_features: int = 48, iterations: int = 10) -> None:

#         """
#             in_channels : img.shape[2]
#                 + grey level => in_channels=1
#                 + rgb color => in_channels=3
#         """

#         super(Unfolding, self).__init__()

#         self.in_channels = in_channels
#         self.num_features = num_features
#         self.iterations = iterations

#         # Initial
#         # self.O_0 = torch.nn.Conv2d(self.in_channels, self.num_features, (3, 3), 'same')

#         self.add_module(
#             name='O_0',
#             module=torch.nn.Conv2d(self.in_channels, self.num_features, (3, 3), 'same')
#         )

#         self.stepO = torch.tensor(data=0.1, dtype=torch.float, requires_grad=True)
#         self.stepM = torch.tensor(data=0.1, dtype=torch.float, requires_grad=True)


#         self.add_module(
#             name='Prox_M', 
#             module=Prox_M(in_channels=self.num_features)
#         )

#         self.add_module(
#             name='Prox_O', 
#             module=Prox_O(in_channels=self.num_features*3, num_features=self.num_features)
#         )



        

#         # for i in range(0, iterations):
            
            
        
#         # for i in range(2, iterations-1):
#         #     pass
        
    

#     # def __init_iteration_0(self) -> None:

#     #     self.add_module(
#     #         name='iteration_0:X1',
#     #         module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
#     #     )
#     #     self.add_module(
#     #         name='iteration_0:X11',
#     #         module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
#     #     )

#     #     self.add_module(
#     #         name='iteration_0:X2',
#     #         module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#     #     )
#     #     self.add_module(
#     #         name='iteration_0:X22',
#     #         module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#     #     )

#     #     self.add_module(
#     #         name='iteration_0:X4',
#     #         module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#     #     )
#     #     self.add_module(
#     #         name='iteration_0:X44',
#     #         module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#     #     )

#     def __init_iteration_i(self, i: int) -> None:

#         self.add_module(
#             name='iteration_'+str(i)+':X1',
#             module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
#         )

#         self.add_module(
#             name='iteration_'+str(i)+':X2',
#             module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#         )

#         self.add_module(
#             name='iteration_'+str(i)+':X4',
#             module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(4, 4), padding='same', bias=False)
#         )


#         self.add_module(
#             name='iteration_'+str(i)+':X11',
#             module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
#         )

#         self.add_module(
#             name='iteration_'+str(i)+':X22',
#             module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#         )

#         self.add_module(
#             name='iteration_'+str(i)+':X44',
#             module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(4, 4), padding='same', bias=False)
#         )

#         if 0 < i :

#             self.add_module(
#                 name='iteration_'+str(i)+':X111',
#                 module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(1, 1), padding='same', bias=False)
#             )
        
#             self.add_module(
#                 name='iteration_'+str(i)+':X222',
#                 module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(2, 2), padding='same', bias=False)
#             )
            
#             self.add_module(
#                 name='iteration_'+str(i)+':X444',
#                 module=torch.nn.Conv2d(in_channels=self.num_features, num_features=self.num_features, kernel_size=(3, 3), dilation=(4, 4), padding='same', bias=False)
#             )

#     def __apply_layer(self, iter: int, name: str, input: torch.Tensor) -> torch.Tensor:
#         layer = self.get_submodule(target='iteration_'+str(iter)+':'+name)
#         return layer(input)


#     def forward(self, image):
        
#         # Initial
#         out_O_0 = self.O_0(image)
#         tmp = torch.concat([out_O_0, image], -1)
#         O_previous, Z = self.prox_O(tmp)
#         H = image - O_previous

#         # Iteration 0

#         X_1 = self.__apply_layer(iter=0, name='X1', input=H)
#         X_2 = self.__apply_layer(iter=0, name='X2', input=H)
#         X_4 = self.__apply_layer(iter=0, name='X4', input=H)
        
#         M = self.__apply_layer(iter=0, name='Prox_M', input=torch.concat([X_1, X_2, X_4], -1))
  
#         X_1 = self.__apply_layer(iter=0, name='X11', input=M[:, :, :, 0:self.num_features])
#         X_2 = self.__apply_layer(iter=0, name='X22', input=M[:, :, :, self.num_features:self.num_features*2])
#         X_4 = self.__apply_layer(iter=0, name='X44', input=M[:, :, :, self.num_features*2:self.num_features*3])
    
#         h_current = torch.concat([X_1, X_2, X_4], -1)
#         # H_current = torch.sum(h_current, h_current.dim(), keepdim=True)
#         H_current = h_current.sum(1).unsqueeze(1)

#         O_current = image-H_current
#         stepO = self.get_submodule(target='iteration_0:stepO')
#         tmp = torch.concat([Z, stepO*O_current+(1.0-stepO)*O_previous], -1)

#         O_current, Z = self.__apply_layer(iter=0, name='Prox_O', input=tmp)

#         # Iteration 1 to 9
#         for i in range(1, self.iterations):

#             O_previous = O_current
#             H = image - O_previous

#             X_1 = self.__apply_layer(iter=i, name='X11', input=M[:, :, :, 0:self.num_features])
#             X_2 = self.__apply_layer(iter=i, name='X22', input=M[:, :, :, self.num_features:self.num_features*2])
#             X_4 = self.__apply_layer(iter=i, name='X44', input=M[:, :, :, self.num_features*2:self.num_features*3])

#             H_star = torch.concat([X_1, X_2, X_4], -1)
#             # H_current = torch.sum(h_current, h_current.dim(), keepdim=True)
#             H_star = h_current.sum(1).unsqueeze(1)

#             X_1 = self.__apply_layer(iter=i, name='X1', input=H_star-H)
#             X_2 = self.__apply_layer(iter=i, name='X2', input=H_star-H)
#             X_4 = self.__apply_layer(iter=i, name='X4', input=H_star-H)

#             stepM = self.get_submodule(target='iteration_'+str(i)+':stepM')
#             M = self.__apply_layer(
#                 iter=i,
#                 name='Prox_M',
#                 input=M-stepM*torch.concat([out_X1, out_X2, out_X4], -1)
#             )

#             X_1 = self.__apply_layer(iter=i, name='X111', input=M[:, :, :, 0:self.num_features])
#             X_2 = self.__apply_layer(iter=i, name='X222', input=M[:, :, :, self.num_features:self.num_features*2])
#             X_4 = self.__apply_layer(iter=i, name='X444', input=M[:, :, :, self.num_features*2:self.num_features*3])

#             h_current = torch.concat([X_1, X_2, X_4], -1)
#             # H_current = torch.sum(h_current, h_current.dim(), keepdim=True)
#             H_current = h_current.sum(1).unsqueeze(1)

#             O_current = image-H_current
#             stepO = self.get_submodule(target='iteration_0:stepO')
#             tmp = torch.concat([Z, stepO*O_current+(1.0-stepO)*O_previous], -1)
#             # O_current, Z = self.__apply_layer(iter=i, name='Prox_O', input=tmp)
#             O_current, Z = self.prox_O(tmp)

#         final_out = O_current
        
#         return O_current

In [294]:


# class Unfolding(torch.nn.Module):


#     def __init__(self, in_channels: int, num_features: int = 48, iterations: int = 10) -> None:
#         super().__init__()

#         self.prox_O: Prox_O = Prox_O(in_channels=in_channels+num_features, num_features=num_features)
#         self.prox_M: Prox_M = Prox_M(in_channels=num_features*3)
#         self.stepO = torch.tensor(0.1, dtype=torch.double, requires_grad=True)
#         self.stepM = torch.tensor(0.1, dtype=torch.double, requires_grad=True)
#         self.num_features = num_features
#         self.iterations = iterations

#         #initial scope
#         self.O_0 = torch.nn.Conv2d(in_channels=1, out_channels=num_features, kernel_size=3, padding="same")

#         self.conv_it = {}
#         for i in range(iterations):
#             self.conv_it["conv_it"+str(i)+"X1"] = torch.nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=3, dilation=1, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X1", module=self.conv_it["conv_it"+str(i)+"X1"])
#             self.conv_it["conv_it"+str(i)+"X2"] = torch.nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=3, dilation=2, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X2", module=self.conv_it["conv_it"+str(i)+"X2"])
#             self.conv_it["conv_it"+str(i)+"X4"] = torch.nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=3, dilation=4, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X4", module=self.conv_it["conv_it"+str(i)+"X4"])
#             self.conv_it["conv_it"+str(i)+"X11"] = torch.nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, dilation=1, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X11", module=self.conv_it["conv_it"+str(i)+"X11"])
#             self.conv_it["conv_it"+str(i)+"X22"] = torch.nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, dilation=2, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X22", module=self.conv_it["conv_it"+str(i)+"X22"])
#             self.conv_it["conv_it"+str(i)+"X44"] = torch.nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, dilation=4, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X44", module=self.conv_it["conv_it"+str(i)+"X44"])
#             self.conv_it["conv_it"+str(i)+"X111"] = torch.nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, dilation=1, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X111", module=self.conv_it["conv_it"+str(i)+"X111"])
#             self.conv_it["conv_it"+str(i)+"X222"] = torch.nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, dilation=2, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X222", module=self.conv_it["conv_it"+str(i)+"X222"])
#             self.conv_it["conv_it"+str(i)+"X444"] = torch.nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=3, dilation=4, padding="same", bias=False)
#             self.add_module(name="conv_it"+str(i)+"X444", module=self.conv_it["conv_it"+str(i)+"X444"])

        



#     def forward(self, J: torch.Tensor) -> torch.Tensor:
#         #initial
#         O_0 = self.O_0(J)
#         a = [O_0,J]
#         tmp = torch.concat([O_0,J],1)
#         O_previous, Z = self.prox_O(tmp)
#         H = J - O_previous

#         #iteration 1
#         X_1 = self.conv_it["conv_it"+str(0)+"X1"](H)
#         X_2 = self.conv_it["conv_it"+str(0)+"X2"](H)
#         X_4 = self.conv_it["conv_it"+str(0)+"X4"](H)

#         M = self.prox_M(torch.concat([X_1,X_2,X_4],1))

#         X_1 = self.conv_it["conv_it"+str(0)+"X11"](M[:,0:self.num_features,:,:])
#         X_2 = self.conv_it["conv_it"+str(0)+"X22"](M[:,self.num_features:self.num_features*2,:,:])
#         X_4 = self.conv_it["conv_it"+str(0)+"X44"](M[:,self.num_features*2:self.num_features*3,:,:])

#         h_current = torch.concat([X_1,X_2,X_4],1)
#         H_current = h_current.sum(1).unsqueeze(1)

#         O_current = J - H_current

#         tmp = torch.concat([Z, self.stepO * O_current + (1.-self.stepO) * O_previous],1)
#         O_current, Z = self.prox_O(tmp)

#         for i in range(1, self.iterations):
#             O_previous = O_current
#             H = J - O_previous

#             X_1 = self.conv_it["conv_it"+str(i)+"X11"](M[:,0:self.num_features,:,:])
#             X_2 = self.conv_it["conv_it"+str(i)+"X22"](M[:,self.num_features:self.num_features*2,:,:])
#             X_4 = self.conv_it["conv_it"+str(i)+"X44"](M[:,self.num_features*2:self.num_features*3,:,:])

#             H_star = torch.concat([X_1,X_2,X_4],1)
#             H_star = H_star.sum(1).unsqueeze(1)

#             X_1 = self.conv_it["conv_it"+str(i)+"X1"](H_star-H)
#             X_2 = self.conv_it["conv_it"+str(i)+"X2"](H_star-H)
#             X_4 = self.conv_it["conv_it"+str(i)+"X4"](H_star-H)

#             M = self.prox_M(M - self.stepM * torch.concat([X_1,X_2,X_4],1))

#             X_1 = self.conv_it["conv_it"+str(i)+"X111"](M[:,0:self.num_features,:,:])
#             X_2 = self.conv_it["conv_it"+str(i)+"X222"](M[:,self.num_features:self.num_features*2,:,:])
#             X_4 = self.conv_it["conv_it"+str(i)+"X444"](M[:,self.num_features*2:self.num_features*3,:,:])

#             h_current = torch.concat([X_1,X_2,X_4],1)
#             H_current = h_current.sum(1).unsqueeze(1)


#             O_current = J - H_current
#             tmp = torch.concat([Z, self.stepO * O_current + (1.-self.stepO) * O_previous],1)
#             O_current, Z = self.prox_O(tmp)


#         return O_current


#     def to(self, *args, **kwargs):
#         self = super().to(*args, **kwargs)
#         for i in self.conv_it:
#           self.conv_it[i].to(*args, **kwargs)
#         self.prox_O.to(*args, **kwargs)
#         self.prox_M.to(*args, **kwargs)
#         return self