In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import os
import glob
from time import time
import math
from torch.nn import init
import copy
import cv2
from skimage.metrics import structural_similarity as ssim

# Run configuration: hyper-parameters and directory layout
config = dict(
    epoch_num=200,
    layer_num=9,
    learning_rate=1e-4,
    group_num=1,
    cs_ratio=1,
    gpu_list="0",
    matrix_dir="sampling_matrix",
    model_dir="model",
    data_dir="data",
    log_dir="log",
    result_dir="result",
    test_name="Set11",
)

# GPU selection (must be set before CUDA is first initialised)
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES=config["gpu_list"],
)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Number of measurements per block, keyed by the CS sampling ratio (in percent)
ratio_dict = {
    1: 10,
    4: 43,
    10: 109,
    25: 272,
    30: 327,
    40: 436,
    50: 545,
}
n_input = ratio_dict[config["cs_ratio"]]
n_output = 33 * 33  # 1089 pixels per block

# Load the sampling matrix for the configured CS ratio
Phi_data_Name = './{}/phi_0_{}_1089.mat'.format(config["matrix_dir"], config["cs_ratio"])
Phi_data = sio.loadmat(Phi_data_Name)
Phi_input = Phi_data['phi']  # the matrix is stored under the 'phi' key of the .mat file

# ISTA-Net model definition
class BasicBlock(torch.nn.Module):
    """One unrolled ISTA-Net phase.

    A gradient step with the learnable step size ``lambda_step`` is followed by
    a learned shrinkage: two forward 3x3 convolutions, a soft-threshold with
    the learnable level ``soft_thr``, and two backward 3x3 convolutions.
    """

    def __init__(self):
        super().__init__()
        # Learnable scalars: gradient step size and soft-threshold level.
        self.lambda_step = nn.Parameter(torch.Tensor([0.5]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]))
        # 3x3 kernels, Xavier-initialised and registered in this exact order
        # so that state_dict keys keep their original ordering.
        kernel_shapes = (
            ("conv1_forward", (32, 1, 3, 3)),
            ("conv2_forward", (32, 32, 3, 3)),
            ("conv1_backward", (32, 32, 3, 3)),
            ("conv2_backward", (1, 32, 3, 3)),
        )
        for attr_name, shape in kernel_shapes:
            setattr(self, attr_name, nn.Parameter(init.xavier_normal_(torch.empty(*shape))))

    def forward(self, x, PhiTPhi, PhiTb):
        """Apply one phase to ``x`` (rows of 1089 values) and return the same layout.

        ``x`` is right-multiplied by ``PhiTPhi`` and ``PhiTb`` is added
        element-wise, so both must be shape-compatible with ``x``.
        """
        # Gradient step: x - lambda * x @ PhiTPhi + lambda * PhiTb
        descended = x - self.lambda_step * x.mm(PhiTPhi)
        descended = descended + self.lambda_step * PhiTb

        # Work on 33x33 single-channel images from here on.
        image = descended.view(-1, 1, 33, 33)

        # Forward transform
        hidden = F.relu(F.conv2d(image, self.conv1_forward, padding=1))
        coeffs = F.conv2d(hidden, self.conv2_forward, padding=1)

        # Soft-threshold in the transform domain
        shrunk = torch.sign(coeffs) * F.relu(coeffs.abs() - self.soft_thr)

        # Backward transform
        hidden = F.relu(F.conv2d(shrunk, self.conv1_backward, padding=1))
        restored = F.conv2d(hidden, self.conv2_backward, padding=1)

        return restored.view(-1, 1089)

class ISTANet(torch.nn.Module):
    """A chain of ``LayerNo`` BasicBlock phases applied one after another."""

    def __init__(self, LayerNo):
        super().__init__()
        self.LayerNo = LayerNo
        phases = []
        for _ in range(LayerNo):
            phases.append(BasicBlock())
        self.fcs = nn.ModuleList(phases)

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements ``Phix`` using ``Phi`` and the initialiser ``Qinit``.

        The quantities shared by every phase (``Phi^T Phi`` and ``Phix @ Phi``)
        are computed once; the starting estimate is ``Phix @ Qinit^T``.
        """
        PhiTPhi = Phi.T.mm(Phi)
        PhiTb = Phix.mm(Phi)
        estimate = Phix.mm(Qinit.T)
        for phase in self.fcs:
            estimate = phase(estimate, PhiTPhi, PhiTb)
        return estimate

# Build the network, wrap it in DataParallel and restore the trained weights
model = nn.DataParallel(ISTANet(config["layer_num"])).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

model_dir = (
    f"./{config['model_dir']}/CS_ISTA_Net"
    f"_layer_{config['layer_num']}"
    f"_group_{config['group_num']}"
    f"_ratio_{config['cs_ratio']}"
    f"_lr_{config['learning_rate']:.4f}"
)
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(f'{model_dir}/net_params_{config["epoch_num"]}.pkl', map_location=device)
model.load_state_dict(checkpoint)


<All keys matched successfully>

In [7]:
# Print the name and shape of every learnable tensor in the restored model
for name, param in model.named_parameters():
    print("Layer: {}, Shape: {}".format(name, param.shape))


Layer: module.fcs.0.lambda_step, Shape: torch.Size([1])
Layer: module.fcs.0.soft_thr, Shape: torch.Size([1])
Layer: module.fcs.0.conv1_forward, Shape: torch.Size([32, 1, 3, 3])
Layer: module.fcs.0.conv2_forward, Shape: torch.Size([32, 32, 3, 3])
Layer: module.fcs.0.conv1_backward, Shape: torch.Size([32, 32, 3, 3])
Layer: module.fcs.0.conv2_backward, Shape: torch.Size([1, 32, 3, 3])
Layer: module.fcs.1.lambda_step, Shape: torch.Size([1])
Layer: module.fcs.1.soft_thr, Shape: torch.Size([1])
Layer: module.fcs.1.conv1_forward, Shape: torch.Size([32, 1, 3, 3])
Layer: module.fcs.1.conv2_forward, Shape: torch.Size([32, 32, 3, 3])
Layer: module.fcs.1.conv1_backward, Shape: torch.Size([32, 32, 3, 3])
Layer: module.fcs.1.conv2_backward, Shape: torch.Size([1, 32, 3, 3])
Layer: module.fcs.2.lambda_step, Shape: torch.Size([1])
Layer: module.fcs.2.soft_thr, Shape: torch.Size([1])
Layer: module.fcs.2.conv1_forward, Shape: torch.Size([32, 1, 3, 3])
Layer: module.fcs.2.conv2_forward, Shape: torch.Size(

In [10]:
# Output kernel of the last phase (index 8)
weights = model.state_dict()["module.fcs.8.conv2_backward"]
print(weights.size())


torch.Size([1, 32, 3, 3])


In [19]:
# Learned step size of the last phase, unpacked element by element
[*model.state_dict()['module.fcs.8.lambda_step']]

[tensor(2.4472)]

In [11]:
# The state dict is ordered, so its final key belongs to the last registered tensor
last_layer_name = next(reversed(model.state_dict()))
last_layer_weights = model.state_dict()[last_layer_name]
print(f'Poids de la dernière couche: {last_layer_weights}')

Poids de la dernière couche: tensor([[[[ 9.3944e-02,  7.8764e-03,  4.3230e-02],
          [ 1.5584e-01,  4.0211e-02,  5.5612e-02],
          [-2.0164e-02,  4.6392e-02, -3.4072e-02]],

         [[ 1.0583e-01, -1.8420e-02,  5.8330e-02],
          [ 4.7325e-02,  1.6555e-02,  5.8020e-02],
          [-3.3384e-02,  1.6411e-01, -2.1265e-03]],

         [[-3.6796e-02,  6.3130e-02,  8.4416e-02],
          [-9.0057e-03,  6.1645e-02,  7.4977e-02],
          [-3.6470e-02,  4.1247e-02,  8.5854e-02]],

         [[-4.0988e-02, -4.5996e-02, -3.7543e-02],
          [-3.3075e-02, -1.4151e-02, -1.7834e-02],
          [-6.7438e-02, -5.0019e-02, -5.9243e-02]],

         [[ 1.8537e-01,  4.1767e-02, -2.4961e-02],
          [ 1.4437e-01, -1.0833e-01,  1.2532e-01],
          [ 1.5701e-01,  2.0168e-01,  2.3425e-01]],

         [[ 6.9351e-01,  2.6046e-01,  4.4288e-02],
          [ 1.8627e-01,  3.5107e-01,  3.7830e-02],
          [ 1.2682e-01,  6.9190e-02,  3.4096e-02]],

         [[ 3.0768e-02,  5.9904e-02,  2.1

In [24]:
# Learned step size of the last phase, one entry per element of the tensor
[step for step in model.state_dict()['module.fcs.8.lambda_step']]

[tensor(2.4472)]

In [44]:
# Scalars of the last phase (index 8) as plain Python floats
lambda_fista, soft_thr_fista = (
    model.state_dict()[key].item()
    for key in ('module.fcs.8.lambda_step', 'module.fcs.8.soft_thr')
)

# Convolution kernels of the last phase as NumPy arrays
conv1_f, conv2_f, conv1_b, conv2_b = (
    model.state_dict()[key].detach().cpu().numpy()
    for key in (
        'module.fcs.8.conv1_forward',
        'module.fcs.8.conv2_forward',
        'module.fcs.8.conv1_backward',
        'module.fcs.8.conv2_backward',
    )
)

In [45]:
# Display the learned step size and threshold. soft_thr is an unconstrained
# nn.Parameter, so a negative value is possible.
(lambda_fista, soft_thr_fista)

(2.4471731185913086, -0.018895795568823814)

In [50]:
import numpy as np
import scipy.signal

def soft_thresholding(x, threshold):
    """Shrink the magnitude of ``x`` by ``threshold``, clamp it at zero, and keep the sign."""
    shrunk_magnitude = np.maximum(np.abs(x) - threshold, 0)
    return np.sign(x) * shrunk_magnitude

def convolve2d(x, kernel):
    """ 2D convolution with symmetric padding; the output has the shape of ``x`` """
    options = {"mode": "same", "boundary": "symm"}
    return scipy.signal.convolve2d(x, kernel, **options)

def _conv2d_zero_pad(feature_maps, weights):
    """Multi-channel 'same' cross-correlation with zero padding.

    NumPy counterpart of ``F.conv2d(x, weights, padding=k // 2)`` for a single
    sample, so the learned kernels are applied the way BasicBlock applies them.

    - feature_maps : array (C_in, H, W)
    - weights : array (C_out, C_in, kH, kW) with odd kH and kW
    Returns an array (C_out, H, W).
    """
    feature_maps = np.asarray(feature_maps, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    out_channels, _, k_rows, k_cols = weights.shape
    _, height, width = feature_maps.shape
    pad_r, pad_c = k_rows // 2, k_cols // 2
    padded = np.pad(feature_maps, ((0, 0), (pad_r, pad_r), (pad_c, pad_c)), mode="constant")

    result = np.zeros((out_channels, height, width))
    # One channel-mixing product per kernel tap instead of one 2-D filter per
    # (output, input) channel pair.
    for dr in range(k_rows):
        for dc in range(k_cols):
            window = padded[:, dr:dr + height, dc:dc + width]
            result += np.tensordot(weights[:, :, dr, dc], window, axes=([1], [0]))
    return result


def _learned_prox(image, conv1_f, conv2_f, conv1_b, conv2_b, soft_thr):
    """Apply conv1_f -> ReLU -> conv2_f -> soft-threshold -> conv1_b -> ReLU -> conv2_b to one (H, W) image."""
    features = np.maximum(_conv2d_zero_pad(image[np.newaxis], conv1_f), 0)
    features = _conv2d_zero_pad(features, conv2_f)
    features = soft_thresholding(features, soft_thr)
    features = np.maximum(_conv2d_zero_pad(features, conv1_b), 0)
    return _conv2d_zero_pad(features, conv2_b)[0]


def fista_conv(y, Phi, PhiT, conv1_f, conv2_f, conv1_b, conv2_b, lambda_fista, soft_thr_fista, num_iters=50):
    """
    FISTA whose proximal step is the learned ISTA-Net transform.

    Each iteration mirrors BasicBlock.forward: a gradient step on the data
    term with step size ``lambda_fista``, then the learned shrinkage
    (conv1_f -> ReLU -> conv2_f -> soft-threshold -> conv1_b -> ReLU -> conv2_b),
    followed by the usual FISTA momentum update.

    - y : observation (M, 1) or (M,)
    - Phi : measurement matrix (M, N)
    - PhiT : transpose of Phi (N, M)
    - conv1_f, conv2_f, conv1_b, conv2_b : kernels (C_out, C_in, k, k), odd k;
      every channel is used
    - lambda_fista : gradient step size
    - soft_thr_fista : threshold applied in the transform domain
    - num_iters : number of iterations; 0 returns the initial estimate PhiT @ y

    Returns the reconstructed image (H, W) with H = W = sqrt(N).
    Raises ValueError if N is not a perfect square.
    """
    n_pixels = PhiT.shape[0]
    side = int(round(np.sqrt(n_pixels)))
    if side * side != n_pixels:
        raise ValueError(f"cannot reshape a signal of size {n_pixels} into a square image")

    # Phi^T y does not change across iterations; compute it once.
    PhiTy = np.dot(PhiT, np.asarray(y, dtype=np.float64).reshape(-1))

    x = PhiTy.reshape(side, side)  # Initial estimate as an image
    z = x.copy()
    t = 1.0

    for _ in range(num_iters):
        # Gradient step on the data term
        z_flat = z.reshape(-1)
        gradient = np.dot(PhiT, np.dot(Phi, z_flat)) - PhiTy
        r = (z_flat - lambda_fista * gradient).reshape(side, side)

        # Learned proximal step. Slicing the kernels with [0, 0] would keep a
        # single filter and discard every other learned channel.
        x_new = _learned_prox(r, conv1_f, conv2_f, conv1_b, conv2_b, soft_thr_fista)

        # FISTA momentum update
        t_new = (1 + np.sqrt(1 + 4 * t**2)) / 2
        z = x_new + ((t - 1) / t_new) * (x_new - x)
        x, t = x_new, t_new

    return x  # Image (H, W)


In [51]:
# Generate a measurement matrix and an observation of matching size
N = 1089  # Flattened image size (33x33)
M = N // 2  # Half as many measurements as pixels
rng = np.random.default_rng(0)  # Fixed seed so that Restart & Run All reproduces the same draw
A = rng.standard_normal((M, N))  # Measurement matrix (M x N)
y = rng.standard_normal((M, 1))  # Observed signal (M x 1)
A_T = A.T  # Transpose of the measurement matrix

# Reconstruction with FISTA using the ISTA-Net weights
x_reconstructed = fista_conv(y, A, A_T, conv1_f, conv2_f, conv1_b, conv2_b, lambda_fista, soft_thr_fista)

print("Reconstruction terminée !")
print("Taille de l'image reconstruite :", x_reconstructed.shape)  # Expected: (33, 33)


Reconstruction terminée !
Taille de l'image reconstruite : (33, 33)
