In [None]:
import torch
from torch import nn, optim 
from torch.autograd import Variable 
import os 
from utils import * 
get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np 
import scipy.io as scio
import matplotlib.pyplot as plt 
import scipy.io
import math
import time
import pandas as pd
import datetime
import copy


In [None]:
# Choose the data set and the sampling case; the case name is also the
# sub-directory that holds the case-specific .npy files.
data_set = 'Pavia'
case = 'MR_90'
path2 = f'{case}/'

# Read the four input arrays from disk. The ground truth lives next to the
# notebook, the other three inside the case directory.
img = np.load(f'{data_set}_gtruth.npy')
sampled_image = np.load(f'{path2}{data_set}_sampled.npy')
interplted_img = np.load(f'{path2}{data_set}_init.npy')
mask = np.load(f'{path2}{data_set}_mask.npy')

print('data set: ' + data_set)
print('case: ' + case)

# Keep NumPy handles under the names used by the later cells and build the
# matching tensors.
# NOTE(review): `dtype` and `psnr3d` are not defined in this cell; presumably
# they come from `from utils import *` - confirm.
X_np, X_in_np, X_GTruth_np = sampled_image, interplted_img, img

X = torch.from_numpy(X_np).type(dtype).cuda()
X_in = torch.from_numpy(X_in_np).type(dtype).cuda()
# Unlike the other tensors, the mask gets no explicit .cuda() call; its
# placement is decided by `dtype` alone.
mask = torch.from_numpy(mask).type(dtype)
X_GTruth = torch.from_numpy(X_GTruth_np).type(dtype).cuda()

# Quality of the interpolated starting point against the ground truth.
psnr = psnr3d(X_GTruth_np, X_in_np)
print(f'psnr of the initialized data: {psnr:.3f}')

# Network hyper-parameters.
band_shrink = 1   # multiplier applied to each mode size to obtain n_4[...]
CNN_layer = 1     # number of conv stages per coupled_block
CNN_size = 3      # convolution kernel size

# Run label (used for output directory / file names) and output path prefix.
label = f'3D-D-LRTR_{data_set}_{case}'
path0 = ""

# Sizes of the first three axes of the observed tensor.
n_1, n_2, n_3 = (X.shape[axis] for axis in range(3))

# Transform-domain size for each of the three branches.
n_4 = {
    "tm1": n_1*band_shrink,
    "tm2": n_2*band_shrink,
    "tm3": n_3*band_shrink,
}


In [3]:
#define networks
class coupled_block(nn.Module):
    """Per-channel convolution stack followed by a bias-free linear layer.

    Every convolution uses ``groups=n_4`` with ``n_4`` input and output
    channels, so each channel is filtered on its own. The linear layer then
    maps the channel axis from ``n_4`` to ``n_0`` entries. Both parts end in
    a ``LeakyReLU``.

    Args:
        n_4: channel count of the convolutions and input width of the
            linear layer.
        n_0: output width of the linear layer.
        CNN_layer: number of Conv2d + LeakyReLU stages; one stage is always
            built, values below 2 add nothing.
        CNN_size: kernel size; the padding is ``int((CNN_size-1)/2)``.
    """

    def __init__(self, n_4, n_0, CNN_layer, CNN_size):
        super().__init__()
        pad = int((CNN_size - 1) / 2)

        def conv_stage():
            # One channel-wise convolution and its activation.
            conv = nn.Conv2d(n_4, n_4, CNN_size, padding=pad, groups=n_4, bias=False)
            return conv, nn.LeakyReLU()

        # Each extra stage wraps the stack built so far in a new Sequential.
        # This nesting determines the state_dict key names, so it is kept
        # exactly as is.
        stack = nn.Sequential(*conv_stage())
        for _ in range(CNN_layer - 1):
            stack = nn.Sequential(stack, *conv_stage())
        self.cnn_block = stack

        self.fc_block = nn.Sequential(nn.Linear(n_4, n_0, bias=False),
                                      nn.LeakyReLU())

    def forward(self, x):
        """Apply the block to ``x``, whose last (third) axis is the channel axis.

        The channel axis is moved to the front for the convolutions and back
        to the end for the linear layer, so the result keeps the first two
        axes of ``x`` and has ``n_0`` entries on the last one.
        """
        channels_first = x.permute(2, 0, 1)
        filtered = self.cnn_block(channels_first)
        return self.fc_block(filtered.permute(1, 2, 0))
class g1_Xhat(nn.Module):
    """First branch: a learnable tensor plus a ``coupled_block`` acting on axis 0.

    ``X_hat`` is stored with shape ``(n_4, n_2, n_3)``; the block maps the
    ``n_4`` axis to ``n_1`` entries.
    """

    def __init__(self, n_1, n_2, n_3, n_4, CNN_layer, CNN_size):
        super().__init__()
        self.X_hat = nn.Parameter(torch.Tensor(n_4, n_2, n_3))
        self.g = coupled_block(n_4, n_1, CNN_layer, CNN_size)

        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``X_hat`` uniformly in ``[-b, b]`` with ``b = 1/sqrt(X_hat.size(0))``."""
        bound = 1. / math.sqrt(self.X_hat.size(0))
        self.X_hat.data.uniform_(-bound, bound)

    def forward(self):
        """Return ``(X_hat with axis 0 moved last, block output with its last axis moved first)``.

        The first item has shape ``(n_2, n_3, n_4)``; the second has shape
        ``(n_1, n_2, n_3)``.
        """
        latent = self.X_hat.permute(1, 2, 0)
        restored = self.g(latent).permute(2, 0, 1)
        return latent, restored
class g2_Xhat(nn.Module):
    """Second branch: a learnable tensor plus a ``coupled_block`` acting on axis 1.

    ``X_hat`` is stored with shape ``(n_1, n_4, n_3)``; the block maps the
    ``n_4`` axis to ``n_2`` entries.
    """

    def __init__(self, n_1, n_2, n_3, n_4, CNN_layer, CNN_size):
        super().__init__()
        self.X_hat = nn.Parameter(torch.Tensor(n_1, n_4, n_3))
        self.g = coupled_block(n_4, n_2, CNN_layer, CNN_size)

        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``X_hat`` uniformly in ``[-b, b]`` with ``b = 1/sqrt(X_hat.size(0))``."""
        # NOTE(review): axis 0 is the n_1 axis here, not the n_4 axis - confirm
        # that scaling by n_1 is intended.
        bound = 1. / math.sqrt(self.X_hat.size(0))
        self.X_hat.data.uniform_(-bound, bound)

    def forward(self):
        """Return ``(X_hat with axes 1 and 2 swapped, block output with axes 1 and 2 swapped)``.

        The first item has shape ``(n_1, n_3, n_4)``; the second has shape
        ``(n_1, n_2, n_3)``.
        """
        latent = self.X_hat.permute(0, 2, 1)
        restored = self.g(latent).permute(0, 2, 1)
        return latent, restored
class g3_Xhat(nn.Module):
    """Third branch: a learnable tensor plus a ``coupled_block`` acting on axis 2.

    ``X_hat`` is stored with shape ``(n_1, n_2, n_4)``, which is already the
    layout the block expects, so no permutation is needed. The block maps the
    ``n_4`` axis to ``n_3`` entries.
    """

    def __init__(self, n_1, n_2, n_3, n_4, CNN_layer, CNN_size):
        super().__init__()
        self.X_hat = nn.Parameter(torch.Tensor(n_1, n_2, n_4))
        self.g = coupled_block(n_4, n_3, CNN_layer, CNN_size)

        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``X_hat`` uniformly in ``[-b, b]`` with ``b = 1/sqrt(X_hat.size(0))``."""
        # NOTE(review): axis 0 is the n_1 axis here, not the n_4 axis - confirm
        # that scaling by n_1 is intended.
        bound = 1. / math.sqrt(self.X_hat.size(0))
        self.X_hat.data.uniform_(-bound, bound)

    def forward(self):
        """Return ``(X_hat, block output)``; the output has shape ``(n_1, n_2, n_3)``."""
        latent = self.X_hat
        return latent, self.g(latent)
class ThreeD_D_LRTR(nn.Module):
    """Three single-axis branches whose outputs are fused by one more ``coupled_block``.

    Args:
        n_1, n_2, n_3: sizes of the three axes of the reconstructed tensor.
        n_4_tm1, n_4_tm2, n_4_tm3: transform-domain size handed to the
            first, second and third branch respectively.
        CNN_layer, CNN_size: forwarded unchanged to every ``coupled_block``.
    """

    def __init__(self, n_1, n_2, n_3, n_4_tm1, n_4_tm2, n_4_tm3, CNN_layer, CNN_size):
        super().__init__()

        dims = (n_1, n_2, n_3)
        self.g_tm1 = g1_Xhat(*dims, n_4_tm1, CNN_layer, CNN_size)
        self.g_tm2 = g2_Xhat(*dims, n_4_tm2, CNN_layer, CNN_size)
        self.g_tm3 = g3_Xhat(*dims, n_4_tm3, CNN_layer, CNN_size)

        # Fusion block: consumes the three branch outputs stacked on the last
        # axis (3 * n_3 channels) and maps them back to n_3.
        self.g_CoNot_inte = coupled_block(n_3*3, n_3, CNN_layer, CNN_size)

    def forward(self):
        """Run all branches and the fusion block.

        Returns:
            A 7-tuple: the three (permuted) latent tensors of branches 1-3,
            the fused output, then the three individual branch outputs.
        """
        latent_1, out_1 = self.g_tm1()
        latent_2, out_2 = self.g_tm2()
        latent_3, out_3 = self.g_tm3()

        # Stack the branch outputs along the last axis, then fuse them.
        stacked = torch.cat((out_1, out_2, out_3), 2)
        fused = self.g_CoNot_inte(stacked)

        return (latent_1, latent_2, latent_3,
                fused,
                out_1, out_2, out_3)
# Mean-squared-error criterion (nn.MSELoss with its default arguments);
# training() applies it to mask-multiplied tensors for the fidelity terms.
F_norm = nn.MSELoss()

In [4]:
def _save_curve_plot(values, title, out_file, ylabel=None, start=0):
    """Plot ``values[start:]`` against its iteration index, save it, show it, close it."""
    fig, ax = plt.subplots()
    # Explicit x values keep the iteration numbers correct for tail plots.
    ax.plot(np.arange(start, len(values)), values[start:])
    ax.set_xlabel('Iteration')
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.savefig(out_file)
    plt.show()
    # Without this, backends that do not close figures on show() keep them
    # alive and accumulate memory across calls.
    plt.close(fig)


def save_training_result(model_state, X_save, df, nucs,\
            losses, psnr_curve, iter):
    """Write one training run (table, curves, weights, reconstruction) to a new directory.

    The directory is named ``label + "_" + timestamp`` under the module-level
    prefix ``path0``. If a directory with the minute-resolution timestamp
    already exists, a second name with second resolution is used instead.

    Files written inside that directory:
        * ``<label>_<timestamp>.csv``  - ``df`` without its index
        * ``Training Loss1`` / ``Training Loss2`` - full loss curve / tail
        * ``Training PSNR1`` / ``Training PSNR2`` - full PSNR curve / tail
        * ``nuc_1`` / ``nuc_2`` - full nuclear-norm curve / tail
          (previously both were saved as ``nuc_``, so the first was lost)
        * ``<label>.pth`` - ``{'model': model_state}`` via ``torch.save``
        * ``<label>.npy`` - ``X_save`` via ``np.save``

    Args:
        model_state: object stored under the key ``'model'`` in the .pth file.
        X_save: array stored in the .npy file.
        df: pandas DataFrame stored as CSV.
        nucs, losses, psnr_curve: per-iteration sequences to plot.
        iter: last iteration index; the tail plots start at ``int(iter/5)``.
            (The name shadows the built-in but is kept for compatibility.)

    Raises:
        OSError: if the output directory cannot be created.
    """
    now_time = datetime.datetime.now().strftime('%Y-%m-%d %H.%M')
    path = label + "_" + now_time
    try:
        # Create the directory where the files are actually written
        # (path0 + path), not just `path`.
        os.makedirs(path0 + path)
    except FileExistsError:
        # Same label within the same minute: fall back to a finer timestamp.
        now_time = datetime.datetime.now().strftime('%m-%d %H.%M.%S')
        path = label + "_" + now_time
        os.makedirs(path0 + path)

    path_save = path0 + path + "/"
    df.to_csv(path_save + label + "_" + now_time + ".csv", index=False)

    # The tail plots skip the first fifth of the run.
    start = int(iter/5)

    _save_curve_plot(losses, 'Training Loss', path_save + "Training Loss1", ylabel='Loss')
    _save_curve_plot(losses, 'Training Loss', path_save + "Training Loss2", ylabel='Loss', start=start)

    _save_curve_plot(psnr_curve, 'Training PSNR', path_save + "Training PSNR1", ylabel='PSNR')
    _save_curve_plot(psnr_curve, 'Training PSNR', path_save + "Training PSNR2", ylabel='PSNR', start=start)

    _save_curve_plot(nucs, 'nuc', path_save + "nuc_1")
    _save_curve_plot(nucs, 'nuc', path_save + "nuc_2", start=start)

    torch.save({'model': model_state}, path_save + label + '.pth')
    np.save(path_save + label + '.npy', X_save)
    
def _slice_nuclear_norm(X_LR, iter_idx, n_slices):
    """Nuclear norm of the frontal slice ``X_LR[:, :, iter_idx % n_slices]``.

    If evaluating that slice raises, the next slice
    (``(iter_idx + 1) % n_slices``) is tried instead; an error on the second
    attempt propagates. Only ``Exception`` is caught, so Ctrl-C / kernel
    interrupts are no longer swallowed.
    """
    try:
        return torch.norm(X_LR[:, :, int(iter_idx % n_slices)].cuda(), 'nuc')
    except Exception:
        # NOTE(review): presumably guards against a failing decomposition
        # inside the nuclear norm - confirm which error was seen in practice.
        return torch.norm(X_LR[:, :, int((iter_idx + 1) % n_slices)].cuda(), 'nuc')


def training(max_iter, lr, beta_, theta_, gamma, beta1):
    """Fit a freshly initialised model with Adam and save the results.

    Each iteration minimises
    ``nuc + F_norm(X_Out*mask, X*mask) + gamma * sum_i F_norm(g_i*mask, X*mask)``
    where ``nuc = beta_*(nuc1 + nuc2 + theta_*nuc3)/(2 + theta_)`` and each
    ``nuc_i`` is the nuclear norm of a single frontal slice of branch ``i``'s
    latent tensor (the slice index cycles with the iteration number).

    The reconstruction and model state with the highest PSNR seen so far are
    snapshotted before the optimiser step and passed to
    ``save_training_result`` at the end.

    Reads the module-level names ``X``, ``mask``, ``X_GTruth``, ``n_4``,
    ``F_norm``, ``band_shrink``, ``CNN_layer``, ``CNN_size`` and ``label``.

    Args:
        max_iter: number of optimisation steps; must be at least 1.
        lr: Adam learning rate.
        beta_: weight of the nuclear-norm term.
        theta_: relative weight of the third branch inside the nuclear-norm term.
        gamma: weight of each per-branch fidelity term.
        beta1: first Adam beta (the second is fixed to 0.999).

    Returns:
        The model after the last step (not the best-PSNR snapshot).

    Raises:
        ValueError: if ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1, got {}'.format(max_iter))

    # One weight per branch, all currently equal.
    gamma_1 = gamma_2 = gamma_3 = gamma

    model, params = init_model()
    optimizer = optim.Adam(params, lr=lr, betas=(beta1, 0.999))

    # Loop-invariant quantities, computed once instead of every iteration.
    masked_target = X*mask
    # assumes psnr3d does not modify its arguments - TODO confirm
    gtruth_np = X_GTruth.cpu().detach().numpy()

    psnr_curve = []
    losses = []
    nucs = []

    # Start below every finite PSNR so the first iteration is always
    # snapshotted, even when all PSNR values are <= 0.
    psnr_best = -math.inf
    X_save = None
    model_state = None

    start_time = time.time()
    start_time_1 = time.time()
    for it in range(max_iter):

        X_LR1, X_LR2, X_LR3, X_Out, g1_X, g2_X, g3_X = model()
        X_Out_np = X_Out.cpu().detach().numpy()
        psnr = psnr3d(gtruth_np, X_Out_np)

        if psnr > psnr_best:
            psnr_best = psnr
            X_save = X_Out_np.copy()
            model_state = copy.deepcopy(model).cpu().state_dict()

        # Optimise the nuclear norm of only one frontal slice at a time.
        nuc1 = _slice_nuclear_norm(X_LR1, it, n_4["tm1"])
        nuc2 = _slice_nuclear_norm(X_LR2, it, n_4["tm2"])
        nuc3 = _slice_nuclear_norm(X_LR3, it, n_4["tm3"])

        nuc = beta_*(nuc1 + nuc2 + theta_*nuc3)/(1+1+theta_)

        f1 = gamma_1*F_norm(g1_X*mask, masked_target)
        f2 = gamma_2*F_norm(g2_X*mask, masked_target)
        f3 = gamma_3*F_norm(g3_X*mask, masked_target)

        f = F_norm(X_Out*mask, masked_target)

        loss = nuc + f + f1 + f2 + f3

        optimizer.zero_grad()
        # The graph is rebuilt by the next forward pass, so it does not need
        # to be retained after this backward call.
        loss.backward()
        optimizer.step()

        psnr_curve.append(psnr.item())
        losses.append(loss.item())
        nucs.append(nuc.item())

        if it % 40 == 0:
            end_time = time.time()
            iteration_time = end_time - start_time
            start_time = time.time()
            print('training: ' + label)
            print('iter: {:}/{:}; time used: {:.3f}; PSNR: {:.3f}'.format(it, max_iter, iteration_time, psnr))
            print()

    training_time = round(time.time() - start_time_1, 2)

    # Final evaluation only needs values, not gradients.
    with torch.no_grad():
        X_Out = model()[3]
    X_Out_np = X_Out.cpu().detach().numpy()
    psnr = psnr3d(gtruth_np, X_Out_np)
    psnr_max = max(psnr_curve)

    if model_state is None:
        # No iteration compared greater than -inf (e.g. PSNR was NaN
        # throughout): save the final state instead of failing on an
        # undefined snapshot.
        X_save = X_Out_np.copy()
        model_state = copy.deepcopy(model).cpu().state_dict()

    # Index of the last executed iteration, as recorded by earlier runs.
    last_iter = max_iter - 1

    df = pd.DataFrame(
        [[lr, beta1, beta_, theta_, gamma,
          band_shrink, CNN_layer, CNN_size,
          last_iter, training_time, psnr, psnr_max]],
        columns=['learning rate', 'beta1', 'beta', 'theta', 'gamma',
                 'band_shrink', 'CNN_layer',
                 'CNN_size', 'iter', 'time_used', 'PSNR_stop', 'PSNR_max'])

    save_training_result(model_state, X_save, df, nucs,
                         losses, psnr_curve, last_iter)

    return model
def init_model():
    """Build a ``ThreeD_D_LRTR`` model and start every branch from ``X_in``.

    Uses the module-level names ``n_1``, ``n_2``, ``n_3``, ``n_4``,
    ``CNN_layer``, ``CNN_size``, ``dtype`` and ``X_in``. The model is
    converted with ``.type(dtype)``. Each branch's ``X_hat`` then receives
    its own copy of ``X_in``. The total parameter count is printed.

    Returns:
        ``(model, params)`` where ``params`` is the list of all model
        parameters.

    Raises:
        ValueError: if a branch's ``X_hat`` does not have the shape of
            ``X_in``. Assigning ``.data`` would otherwise silently change
            the parameter's shape and only fail later inside the
            convolution.
    """
    model = ThreeD_D_LRTR(n_1, n_2, n_3,
                          n_4["tm1"], n_4["tm2"], n_4["tm3"],
                          CNN_layer, CNN_size).type(dtype)

    # Initialise X_hat of every branch with an independent copy of X_in.
    for name in ("g_tm1", "g_tm2", "g_tm3"):
        branch = getattr(model, name)
        if branch.X_hat.shape != X_in.shape:
            raise ValueError(
                'cannot initialise {}.X_hat of shape {} from X_in of shape {}'.format(
                    name, tuple(branch.X_hat.shape), tuple(X_in.shape)))
        # clone() copies on the device X_in already lives on; the previous
        # detour through host memory produced the same values.
        branch.X_hat.data = X_in.detach().clone()

    params = list(model.parameters())

    s = sum(p.numel() for p in params)
    print('Number of params: %d' % s)

    return model, params


In [None]:
# Hyper-parameters of this run; see training() for how each one enters the loss.
max_iter = 3000 #120000
lr = math.pow(10, -2)       # Adam learning rate
beta_ = math.pow(10, -6)    # weight of the nuclear-norm term
theta_ = math.pow(10, -1)   # relative weight of the third branch in that term
gamma = math.pow(10, -7)    # weight of each per-branch fidelity term
beta1 = 0.99                # first Adam beta

# Train and keep the final model in the notebook namespace.
model = training(max_iter=max_iter, lr=lr, beta_=beta_,
                 theta_=theta_, gamma=gamma, beta1=beta1)