In [2]:
#Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from torch.nn import init
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
#scipy-numpy
import scipy.io as sio
import numpy as np
import os
import platform
from time import time
from datetime import datetime

import matplotlib as mpl

## CS-Ratio and Epoch

In [3]:
#define EPOCHS
start_epoch = 0
end_epoch = 1
cs_ratio = 20
#Sampling matrix
data_dir='training_data'
mask_dir='sampling_matrix_istanet'
mask_type='10'    #q1 for DLMRi
##############
layer_num = 9
group_num = 1
#############
nrtrain = 800   # number of training blocks
batch_size = 4
learning_rate = 1e-4

In [4]:
#Assign GPU
gpu_list = '0'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load CS Sampling Matrix: phi
# Phi_data_Name = './%s/mask_%s_%d.mat'  %(mask_dir,mask_type,cs_ratio)
Phi_data_Name = './%s/mask_%d.mat'  %(mask_dir,cs_ratio)
Phi_data = sio.loadmat(Phi_data_Name)
mask_matrix = Phi_data['mask_matrix']
mask_matrix = torch.from_numpy(mask_matrix).type(torch.FloatTensor)
mask = torch.unsqueeze(mask_matrix, 2)
mask = torch.cat([mask, mask], 2)
mask = mask.to(device)
#Load train data
Training_data_Name = './%s/Training_BrainImages_256x256_100.mat' %(data_dir)
Training_data = sio.loadmat(Training_data_Name)
Training_labels = Training_data['labels']

In [5]:
import torch; 
print([(i, torch.cuda.get_device_properties(i)) for i in range(torch.cuda.device_count())])

[(0, _CudaDeviceProperties(name='NVIDIA GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82))]


In [6]:
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")

torch.cuda.device_count()



1

In [7]:
device

device(type='cuda')

## Define ista-Net NN

In [8]:
#Define FFT transform, CS mask and inverse FFT
class FFT_Mask_ForBack(torch.nn.Module):
    def __init__(self):
        super(FFT_Mask_ForBack, self).__init__()
    def forward(self, x, mask):
        # if x.ndim > 4:
        #     x = torch.view_as_complex(x)
        x_dim_0 = x.shape[0]
        x_dim_1 = x.shape[1]
        x_dim_2 = x.shape[2]
        x_dim_3 = x.shape[3]
        # print('x:=',x.ndim)
        x = x.view(-1, x_dim_2, x_dim_3, 1)
        # print('x-view-shape',x.shape)
        y = torch.zeros_like(x)
        #print('y:=',y.shape)
        z = torch.cat([x, y], 3)
        #print('z:=',z.shape)        
        #print('mask-Shape:=',mask.shape)
        # fftz = torch.fft.fft2(z)
        fftz = torch.fft.fftn(z)
        # print(fftz[:,:,:,0])
        #print('fftz-Shape:=',fftz.shape)
        #z_hat = torch.fft.ifftn(fftz * mask) z_hat[:, :, :, 0:1]
        
        # print(z_hat[:,:,:,0])
        # z_hat = torch.view_as_real(x)
        # print('z-hat:=',z_hat.shape)
        x = z_hat[:, :, :, 0:1]
        # print('x:::=',x.shape)
        x = x.view(x_dim_0, x_dim_1, x_dim_2, x_dim_3)
        # print('x-final',x.shape)
        
        return x

# Define ISTA-Net-plus Block- Dictionaries, biases, mu
class BasicBlock(torch.nn.Module):
    def __init__(self):
        super(BasicBlock, self).__init__()
        self.lambda_step = nn.Parameter(torch.Tensor([0.5]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]))
        self.W1 = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 1, 3, 3)))
        self.conv1_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv2_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv1_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv2_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv_G = nn.Parameter(init.xavier_normal_(torch.Tensor(1, 32, 3, 3)))
        #define mu
        self.c1 = nn.Parameter(torch.ones(1,1,1,1), requires_grad=True)
        self.c2 = nn.Parameter(torch.ones(1,1,1,1), requires_grad=True)
        self.c3 = nn.Parameter(torch.ones(1,1,1,1), requires_grad=True)
        #bias        
        self.b1 = nn.Parameter(torch.zeros(1,32,1,1), requires_grad=True)
        self.b2 = nn.Parameter(torch.zeros(1,32,1,1), requires_grad=True)
        self.b3 = nn.Parameter(torch.zeros(1,32,1,1), requires_grad=True)
    def forward(self, x, fft_forback, PhiTb, mask):
        if x.ndim > 4:
            x = torch.view_as_complex(x)
        if PhiTb.ndim > 4:
            PhiTb =  torch.view_as_complex(PhiTb)
            # x = x[:, :, :, 0]
        # print("a=", x.shape)
        # print("b=",mask.shape)
        x = x - self.lambda_step * fft_forback(x, mask)
        x = x + self.lambda_step * PhiTb
        # print(self.lambda_step.shape)
        x_input = torch.view_as_real(x)[:,:,:,:,0]
        # print(x_input.shape)
        gamma1 = F.conv2d(x_input, self.W1, padding=1)
        gamma2 = self.c1*F.conv2d(gamma1, self.conv1_forward, padding=1)+self.b1
        gamma2 = F.relu(gamma2)
        gamma3 = F.conv2d(gamma2, self.conv2_forward, padding=1)
        
        for _ in  range(1):            
            # backward computation
            gamma2 = F.conv_transpose2d(gamma3,self.conv2_forward,padding = 1)
            gamma1 = F.conv_transpose2d(gamma2,self.conv1_forward,padding = 1)            
            # forward computation
            gamma1 = F.relu( (gamma1 - self.c1 * F.conv2d( F.conv_transpose2d(gamma1,self.W1,padding=1) - x_input ,self.W1,padding=1)) + self.b1)
            gamma2 = F.relu( (gamma2 - self.c2 * F.conv2d( F.conv_transpose2d(gamma2,self.conv1_forward,padding=1) - gamma1, self.conv1_forward,padding=1)) + self.b2) 
            gamma3 = F.relu( (gamma3 - self.c3 * F.conv2d( F.conv_transpose2d(gamma3,self.conv2_forward,padding=1) - gamma2, self.conv2_forward,padding=1)) + self.b3) 
        gammaE = torch.mul(torch.sign(gamma3), F.relu(torch.abs(gamma3) - self.soft_thr))
        gamma4 = F.conv2d(gammaE, self.conv1_backward, padding=1)
        gamma4 = F.relu(gamma4)
        gamma5 = F.conv2d(gamma4, self.conv2_backward, padding=1)
        gamma6 = F.conv2d(gamma5, self.conv_G, padding=1)
        x_pred = x_input + gamma6
        x = F.conv2d(gamma3, self.conv1_backward, padding=1)
        x = F.relu(x)
        x_D_est = F.conv2d(x, self.conv2_backward, padding=1)
        symloss = x_D_est - gamma1
        return [x_pred, symloss]

    
# Define ISTA-Net-plus
class ISTANetplus(torch.nn.Module):
    def __init__(self, LayerNo):
        super(ISTANetplus, self).__init__()
        onelayer = []
        self.LayerNo = LayerNo
        self.fft_forback = FFT_Mask_ForBack()
        for i in range(LayerNo):
            onelayer.append(BasicBlock())
        self.fcs = nn.ModuleList(onelayer)
        
    def forward(self, PhiTb, mask):
        x = PhiTb
        # print(x.shape)
        layers_sym = []   # for computing symmetric loss
        for i in range(self.LayerNo):
            [x, layer_sym] = self.fcs[i](x, self.fft_forback, PhiTb, mask)
            layers_sym.append(layer_sym)
        x_final = x
        
        return [x_final, layers_sym]


model = ISTANetplus(layer_num)
# device_ids=[0,1]
# torch.cuda.device_count()
# model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# model = nn.DataParallel(model, gpu_ids=[0,1])
# model = model.cuda()
# model = torch.nn.DataParallel(model, device_ids=[0,1])
model = model.to(device)
print_flag = 1   # print parameter number
if print_flag:
    num_count = 0
    for para in model.parameters():
        num_count += 1
        #print('Layer %d' % num_count)
        #print(para.size())
class RandomDataset(Dataset):
    def __init__(self, data, length):
        self.data = data
        self.len = length
    def __getitem__(self, index):
        return torch.Tensor(self.data[index, :]).float()
    def __len__(self):
        return self.len
##
model_dir='model_dir_brain'
log_dir='log_dir_brain'
##
rand_loader = DataLoader(dataset=RandomDataset(Training_labels, nrtrain), batch_size=batch_size, num_workers=0,shuffle=True)
#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
model_dir = "./%s/MRI_CS_ISTA_Net_plus_layer_%d_group_%d_ratio_%d" % (model_dir, layer_num, group_num, cs_ratio)
log_file_name = "./%s/Log_MRI_CS_ISTA_Net_plus_layer_%d_group_%d_ratio_%d.txt" % (log_dir, layer_num, group_num, cs_ratio)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

SyntaxError: invalid syntax (3237312960.py, line 24)

## Define optimizer

In [19]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=10, gamma=0.2)

## Train Model

In [20]:
if start_epoch > 0:
    pre_model_dir = model_dir
    model.load_state_dict(torch.load('./%s/net_params_%d.pkl' % (pre_model_dir, start_epoch)))
# for graphs
L=len(rand_loader)
total_loss = np.zeros((end_epoch+1,))
loss_disc = np.zeros((end_epoch+1,))
loss_const = np.zeros((end_epoch+1,))
i=1
# Training loop
start_time = datetime.now()
for epoch_i in range(start_epoch+1, end_epoch+1):
    for data in rand_loader:
        batch_x = data
        batch_x = batch_x.to(device)
        batch_x = batch_x.view(batch_x.shape[0], 1, batch_x.shape[1], batch_x.shape[2])
        start = time()
        PhiTb = FFT_Mask_ForBack()(batch_x, mask)
        [x_output, loss_layers_sym] = model(PhiTb, mask)
        end = time()
        # Compute and print loss
        loss_discrepancy = torch.mean(torch.pow(x_output - batch_x, 2))
        loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2))
        for k in range(layer_num-1):
            loss_constraint += torch.mean(torch.pow(loss_layers_sym[k+1], 2))
        gamma = torch.Tensor([0.01]).to(device)
        # loss_all = loss_discrepancy
        loss_all = loss_discrepancy + torch.mul(gamma, loss_constraint)
        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()       
        output_data = "[%02d/%02d] Total Loss: %.6f, Discrepancy Loss: %.6f,  Constraint Loss: %.6f\n" % (epoch_i, end_epoch, loss_all.item(), loss_discrepancy.item(), loss_constraint)
    print(output_data)
    scheduler.step()    
    #
    output_data1 = "%.6f, \n" % (loss_all.item())
    output_data2 = "%.6f, \n" % (loss_discrepancy.item() )
    output_data3 = "%.6f, \n" % (loss_constraint)
    output_file_name1 = "./%s/total_loss__%d.txt" % (log_dir, cs_ratio)
    output_file_name2 = "./%s/disc_loss_%d.txt" % (log_dir, cs_ratio)
    output_file_name3 = "./%s/const_loss_%d.txt" % (log_dir, cs_ratio)
    output_file1 = open(output_file_name1, 'a')
    output_file1.write(output_data1)
    output_file1.close() 
    output_file2 = open(output_file_name2, 'a')
    output_file2.write(output_data2)
    output_file2.close()
    output_file3 = open(output_file_name3, 'a')  
    output_file3.write(output_data3)   
    output_file3.close()
    #    
    total_loss[epoch_i] = loss_all.item()
    loss_disc[epoch_i]  = loss_discrepancy.item() 
    loss_const[epoch_i] = loss_constraint     
end_time = datetime.now()
print('Duration: {}'.format(end_time - start_time))
print("MRI CS training End")

mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256, 256, 2])
fftz-Shape:= torch.Size([4, 256, 256, 2])
mask-Shape:= torch.Size([256

## Save Model paramters for use in test code

In [11]:
torch.save(model.state_dict(), "./%s/net_params_%d.pkl" %(model_dir,end_epoch))  # save only the parameters