In [None]:
import numpy as np
import random
import logging
import datetime
import torch
from torch import nn, optim
from torch.utils.data import DataLoader,TensorDataset
torch.cuda.set_device(0)
import scipy.io as sio
from thop import profile
import matplotlib.pyplot as plt
from network.utils import *
from network.direnet4 import *
# def setup_seed(seed):
#      torch.manual_seed(seed)
#      torch.cuda.manual_seed_all(seed)
#      np.random.seed(seed)
#      random.seed(seed)
#      torch.backends.cudnn.deterministic = True
# setup_seed(999)

In [None]:
# ---- Experiment configuration ------------------------------------------------
# w_ex / d_ex only tag the experiment: d_ex selects the decoder class name,
# and both appear in the output file names written during/after training.
w_ex, d_ex = 5, 4
model_name = 'DiReNetD5_d{}'.format(d_ex)

# reduction and R are forwarded to the compression module and determine the
# latent sizes computed at the bottom of this cell.
reduction, R = 16, 3

# lamda weights the mutual-information penalty in the training loss.
lamda = 0
# flag picks which pair of tensors the MI estimator is applied to:
#   -1: none, 0: (zx, zy), 1: (data, zw), 2: (data, xy)
flag = -1

epochs, batch_size = 1000, 200
lr, lr_mi = 3e-4, 1e-4  # main-model and MI-estimator learning rates

# Number of MI-estimator update iterations per training step (none when MI is off).
nums = 0 if flag == -1 else 1

# Latent dimensions derived from the 2048-element input and the reduction ratio.
z_dim = 2048 // reduction // R
w_dim = 2048 // reduction - 2 * z_dim

In [None]:
# Read the 'dataP' array from the .mat file and convert it to a torch tensor.
mat = sio.loadmat('./dataset/CDLC300.mat')
dataP = torch.from_numpy(mat['dataP'])

# Global min-max scaling: the whole tensor is mapped into [0, 1] using its
# overall minimum and maximum (not per-sample statistics).
data_min, data_max = dataP.min(), dataP.max()
dataP = (dataP - data_min) / (data_max - data_min)

# First 100000 samples are used for training, the next 20000 for validation.
# Both loaders share the same batch size, shuffling and worker settings.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=4)
train_dataloader = DataLoader(dataP[:100000], **loader_kwargs)
val_dataloader = DataLoader(dataP[100000:120000], **loader_kwargs)
class net(nn.Module):
    """Three-stage model: encoder (`en`), compression stage (`cs`), decoder (`de`).

    The decoder class is looked up by name from the module-level `model_name`
    string, so that string must be defined before this class is instantiated.
    """

    def __init__(self,reduction,R):
        """Build the three sub-modules.

        Args:
            reduction: forwarded unchanged to ``DiReNetC32``.
            R: forwarded unchanged to ``DiReNetC32``.
        """
        super().__init__()
        self.en = DiReNetE()
        self.cs = DiReNetC32(reduction,R)
        # Resolve the decoder class from its name; model_name is set in the
        # notebook's configuration cell, not supplied by an external source.
        decoder_cls = eval(model_name)
        self.de = decoder_cls()

    def forward(self, data):
        """Run the full pipeline on a batch.

        Returns:
            A 6-tuple ``(xy, zw, zx, zy, x_hat, y_hat)``: the encoder's first
            output, the first three outputs of the compression stage, and the
            two decoder outputs.
        """
        joint_feat, x_feat, y_feat = self.en(data)
        zw, zx, zy, x_code, y_code = self.cs(joint_feat, x_feat, y_feat)
        x_hat, y_hat = self.de(x_code, y_code)
        return joint_feat, zw, zx, zy, x_hat, y_hat
# Main model, moved to the current CUDA device.
model = net(reduction,R).cuda()

# Constructor arguments of the CLUBSample estimator for each supported flag.
# Construction order (model -> mi_estimator -> mi_xy_estimator) is kept so the
# random initialisation sequence is unchanged.
_club_dims_by_flag = {
    -1: (1, 1, 1),               # nothing
    0: (z_dim, z_dim, z_dim),    # zx and zy
    1: (2048, w_dim, 2048),      # data and zw
    2: (2048, 512, 2048),        # data and w
}
if flag in _club_dims_by_flag:
    mi_estimator = CLUBSample(*_club_dims_by_flag[flag]).cuda()

# Separate estimator for the two halves of the input (datax and datay).
mi_xy_estimator = CLUBSample(512, 512, 512).cuda()

# One Adam optimizer per trainable component.
optimizer = optim.Adam(model.parameters(), lr=lr)
mi_optimizer = optim.Adam(mi_estimator.parameters(), lr=lr_mi)
mi_xy_optimizer = optim.Adam(mi_xy_estimator.parameters(), lr=lr_mi)

MSE_loss = nn.MSELoss().cuda()

In [None]:
# ---- Training / validation loop ----------------------------------------------
# Fixes relative to the previous version of this cell:
#   * the x/y MI estimator step called estimator methods on the *optimizer*
#     (mi_xy_optimizer.train / .learning_loss / __call__); it now uses
#     mi_xy_estimator and only uses the optimizer for zero_grad/step;
#   * x/y MI estimates are logged to mi_xy_all (they were appended to mi_all);
#   * `mi_all.mean()` was called on a Python list and raised AttributeError on
#     the first step; loss_xy_mi is now the mean of the estimates produced in
#     the current step and keeps its previous value when no estimate was made;
#   * flattening uses the actual batch size and `reshape`, because `.view` on
#     the sliced halves of `data` is applied to a non-contiguous tensor;
#   * the per-epoch MI average no longer averages an empty list when flag == -1;
#   * validation and the model forward pass used only to feed the MI estimator
#     run under torch.no_grad().
best_epoch = -1
best_loss = 100
best_nmse = 0
best_model = model  # alias of the live model (not a copy); best weights are saved to disk below
train_epochs_loss = []
val_epochs_loss = []
NMSEs = []
mi_values = []      # per-epoch mean of the MI term used in the loss
mi_xy_values = []   # per-epoch mean of the x/y MI reference value
mi_all = []         # every estimate produced by mi_estimator
mi_xy_all = []      # every estimate produced by mi_xy_estimator
loss_xy_mi = 0

print('net:{}, reduction:{}, R:{}, ID:{}, lamda={}, flag={}'.format(model_name,reduction,R,1-2/R,lamda,flag))
for epoch in range(epochs):
    ############  train  ###############
    mi_value = []
    mi_xy_value = []
    train_epoch_loss = []
    for idx, data in enumerate(train_dataloader):
        model.train()
        mi_estimator.eval()
        data = data.cuda()
        n_samples = data.size(0)  # actual batch size; may differ from batch_size on a short last batch
        data_x = data[:, :, 0:16, :]
        data_y = data[:, :, 16:32, :]
        data_flat = data.reshape(n_samples, -1)
        # The halves are slices along dim 2, hence non-contiguous: reshape, not view.
        data_x_flat = data_x.reshape(n_samples, -1)
        data_y_flat = data_y.reshape(n_samples, -1)

        xy, zw, zx, zy, x_hat, y_hat = model(data)
        loss = 0.5 * (MSE_loss(x_hat, data_x) + MSE_loss(y_hat, data_y))

        if flag == -1:
            loss_mi = 0  # nothing
        elif flag == 0:
            loss_mi = mi_estimator(zx, zy)  # zx and zy
        elif flag == 1:
            loss_mi = mi_estimator(data_flat, zw)  # data and zw
        elif flag == 2:
            loss_mi = mi_estimator(data_flat, xy.reshape(n_samples, -1))  # data and w
        if flag != -1 and loss_mi.item() < 0:
            # Negative estimates are replaced by an exact zero that is still a tensor.
            loss_mi = loss_mi - loss_mi
        LOSS = loss + lamda * (loss_mi - loss_xy_mi) ** 2
        optimizer.zero_grad()
        LOSS.backward()
        optimizer.step()

        # ---- update the x/y estimator and refresh the reference MI value ----
        step_xy_estimates = []
        for j in range(nums):
            mi_xy_estimator.train()
            mi_loss = mi_xy_estimator.learning_loss(data_x_flat, data_y_flat)
            step_xy_estimates.append(mi_xy_estimator(data_x_flat, data_y_flat).item())
            mi_xy_optimizer.zero_grad()
            mi_loss.backward()
            mi_xy_optimizer.step()
        if step_xy_estimates:
            mi_xy_all.extend(step_xy_estimates)
            loss_xy_mi = float(np.mean(step_xy_estimates))

        # ---- update the main MI estimator on the current model outputs ----
        for k in range(nums):
            model.eval()
            mi_estimator.train()
            # Only the estimator is optimised here, so no graph through the model is needed.
            with torch.no_grad():
                xy, zw, zx, zy, x_hat, y_hat = model(data)
            if flag in (-1, 0):
                est_inputs = (zx, zy)  # zx and zy
            elif flag == 1:
                est_inputs = (data_flat, zw)  # data and zw
            elif flag == 2:
                est_inputs = (data_flat, xy.reshape(n_samples, -1))  # data and w
            mi_loss = mi_estimator.learning_loss(*est_inputs)
            mi_all.append(mi_estimator(*est_inputs).item())
            mi_optimizer.zero_grad()
            mi_loss.backward()
            mi_optimizer.step()

        # With MI disabled the term is the constant 0; record it so the epoch
        # average is well defined instead of the mean of an empty list.
        mi_value.append(loss_mi.item() if flag != -1 else 0.0)
        mi_xy_value.append(loss_xy_mi)
        train_epoch_loss.append(loss.item())
    train_epochs_loss.append(np.average(train_epoch_loss))
    mi_values.append(np.average(mi_value))
    mi_xy_values.append(np.average(mi_xy_value))

    ############  val  ###############
    model.eval()
    val_epoch_loss = []
    NMSE = []
    with torch.no_grad():  # evaluation only: no autograd bookkeeping
        for idx, data in enumerate(val_dataloader):
            data = data.cuda()
            xy, zw, zx, zy, x_hat, y_hat = model(data)
            loss = 0.5 * (MSE_loss(x_hat, data[:, :, 0:16, :]) + MSE_loss(y_hat, data[:, :, 16:32, :]))
            # Undo the 0.5 offset before computing power; channels 0 and 1 are
            # combined as a sum of squares.
            sparse_gt = data - 0.5
            sparse_pred = torch.cat((x_hat, y_hat), dim=2) - 0.5
            power_gt = sparse_gt[:, 0, :, :] ** 2 + sparse_gt[:, 1, :, :] ** 2
            difference = sparse_gt - sparse_pred
            mse_gt = difference[:, 0, :, :] ** 2 + difference[:, 1, :, :] ** 2
            lossDB = (mse_gt.sum(dim=[1, 2]) / power_gt.sum(dim=[1, 2])).mean()
            val_epoch_loss.append(loss.item())
            NMSE.append(10 * np.log10(lossDB.item()))  # per-batch NMSE in dB
    val_epochs_loss.append(np.average(val_epoch_loss))
    NMSEs.append(np.average(NMSE))

    ############  save best  ###############
    if val_epochs_loss[epoch] < best_loss:
        best_epoch = epoch
        best_loss = val_epochs_loss[epoch]
        best_nmse = NMSEs[epoch]
        np.savetxt("./seg_txt/EX_D{}W{}r{}.txt".format(d_ex,w_ex,reduction),[best_epoch,best_nmse],fmt='%.4f')
        best_model = model
        torch.save(best_model.state_dict(), "./seg_model/EX_D{}W{}r{}.pth".format(d_ex,w_ex,reduction))

    ############  print  ###############
    if epoch%10 == 0:
        print("epoch={}/{}, lr={:.3e}, train_loss={:.3e}, val_loss={:.3e}, nmse={}, mi_xy_z={:.3e}, mi_x_y={:.3e}".format(epoch, epochs,\
            optimizer.state_dict()['param_groups'][0]['lr'],train_epochs_loss[epoch],val_epochs_loss[epoch],NMSEs[epoch],mi_values[epoch],mi_xy_values[epoch]))
    if epoch%50==0:
        print("best_epoch={},best_loss={:.3e},best_nmse={}".format(best_epoch,best_loss,best_nmse))
        logging.info("best_epoch={},best_loss={:.3e},best_nmse={}".format(best_epoch,best_loss,best_nmse))

In [None]:
# Final summary of the run configuration and the best validation result.
print('net:{}, reduction:{}, R:{}, ID:{}, lamda={}, flag={}'.format(model_name,reduction,R,1-2/R,lamda,flag))
print('best_epoch={}, best_val_loss={}, best_nmse={}'.format(best_epoch,best_loss,best_nmse))

# One row of three panels: losses, MI estimates, validation NMSE.
fig, (loss_ax, mi_ax, nmse_ax) = plt.subplots(1, 3, figsize=(25, 5))

# The first epoch is skipped in the loss curves (index 0 is dropped).
loss_ax.plot(train_epochs_loss[1:], label='train')
loss_ax.plot(val_epochs_loss[1:], label='val')
loss_ax.legend()

mi_ax.plot(mi_all, label='MI_xy_z')
mi_ax.plot(mi_xy_all, label='MI_x_y')
mi_ax.legend()

nmse_ax.plot(NMSEs, label='NMSE')
nmse_ax.legend()

plt.show()

# Persist the per-epoch NMSE history as a NumPy array.
final = np.array(NMSEs)
np.save("./seg_nmse/EX_D{}W{}r{}.npy".format(d_ex,w_ex,reduction),final)