In [None]:
import copy
import random

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
from thop import profile
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Select GPU 1 as the default CUDA device; kept ahead of the project imports below.
torch.cuda.set_device(1)

from network.utils import *
from network.acrnet import *

In [None]:
# Load the MATLAB dataset file and expose its `dataP` array as a torch tensor
# (torch.from_numpy shares memory with the numpy array; no copy is made).
mat = sio.loadmat('./dataset/CDLA30.mat')
dataP = torch.from_numpy(mat['dataP'])

In [None]:
# Experiment configuration.
model_name = 'ACRNet'  # name of the model class to instantiate (presumably from network.acrnet)
reduction = 32         # passed to the model constructor as `reduction`
expansion = 1          # passed to the model constructor as `expansion`
epochs = 1000
batch_size = 200
lr = 3e-4              # Adam learning rate
seed = 42              # RNG seed for reproducible initialisation and shuffling

# Seed every RNG the training pipeline may draw from. torch.manual_seed also
# seeds the CUDA generators and determines the DataLoader shuffling order.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Split the first n_train + n_val samples into train/validation sets and build
# the model, optimiser and loss.
n_train = 100000
n_val = 20000
# Tensor slicing silently truncates, so fail loudly if the dataset is too small.
assert dataP.shape[0] >= n_train + n_val, 'dataset has fewer samples than n_train + n_val'

train_dataloader = DataLoader(dataP[0:n_train, :], batch_size=batch_size, shuffle=True, num_workers=4)
# Validation order is irrelevant for evaluation; not shuffling keeps the per-batch
# NMSE average identical across epochs so epoch-to-epoch changes reflect the model.
val_dataloader = DataLoader(dataP[n_train:n_train + n_val, :], batch_size=batch_size, shuffle=False, num_workers=4)

# Look the model class up by name rather than eval()-ing an arbitrary expression.
model = globals()[model_name](reduction=reduction, expansion=expansion).cuda()
optimizer = optim.Adam(model.parameters(), lr=lr)
MSE_loss = nn.MSELoss().cuda()

In [None]:
# Train for `epochs` epochs, validating after every epoch and keeping a snapshot
# of the model with the lowest validation MSE.


def _train_one_epoch(net, loader, opt, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is trained to reconstruct its own input: the loss compares the
    model output against the input batch.
    """
    net.train()
    batch_losses = []
    for data in loader:
        data = data.cuda()
        data_hat = net(data)
        loss = criterion(data_hat, data)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_losses.append(loss.item())
    return np.average(batch_losses)


def _batch_nmse_db(data, data_hat):
    """Return the NMSE of one batch in dB.

    Channels 0 and 1 are combined as squared components after removing a 0.5
    offset. NOTE(review): presumably these are the real/imaginary parts of the
    channel matrix stored in [0, 1] -- confirm against the dataset.
    """
    sparse_gt = data - 0.5
    sparse_pred = data_hat - 0.5
    power_gt = sparse_gt[:, 0, :, :] ** 2 + sparse_gt[:, 1, :, :] ** 2
    difference = sparse_gt - sparse_pred
    mse_gt = difference[:, 0, :, :] ** 2 + difference[:, 1, :, :] ** 2
    # Per-sample error/power ratio, averaged over the batch, then converted to dB.
    nmse_linear = (mse_gt.sum(dim=[1, 2]) / power_gt.sum(dim=[1, 2])).mean()
    return 10 * np.log10(nmse_linear.item())


def _validate(net, loader, criterion):
    """Evaluate ``net`` on ``loader``.

    Returns:
        (mean batch loss, mean batch NMSE in dB)

    NOTE(review): the NMSE is averaged over batches *in dB*, which differs from
    the dB of the mean linear NMSE; kept as-is so numbers stay comparable with
    earlier runs.
    """
    net.eval()
    batch_losses = []
    batch_nmse_db = []
    # Gradients are not needed for evaluation; without no_grad every batch would
    # build an autograd graph, wasting memory and compute.
    with torch.no_grad():
        for data in loader:
            data = data.cuda()
            data_hat = net(data)
            batch_losses.append(criterion(data_hat, data).item())
            batch_nmse_db.append(_batch_nmse_db(data, data_hat))
    return np.average(batch_losses), np.average(batch_nmse_db)


best_epoch = -1
best_loss = float('inf')  # any finite validation loss improves on this
best_nmse = 0
best_model = copy.deepcopy(model)
train_epochs_loss = []
val_epochs_loss = []
NMSEs = []
print('net:{}, reduction:{}, expansion={}'.format(model_name,reduction,expansion))
for epoch in range(epochs):
    train_epochs_loss.append(_train_one_epoch(model, train_dataloader, optimizer, MSE_loss))
    val_loss, val_nmse = _validate(model, val_dataloader, MSE_loss)
    val_epochs_loss.append(val_loss)
    NMSEs.append(val_nmse)

    # Snapshot the weights: `best_model = model` would only alias the live model,
    # so it would end up holding the final weights instead of the best ones.
    if val_loss < best_loss:
        best_epoch = epoch
        best_loss = val_loss
        best_nmse = val_nmse
        best_model = copy.deepcopy(model)

    if epoch % 10 == 0:
        print("epoch={}/{}, lr={:.3e}, train_loss={:.3e}, val_loss={:.3e}, nmse={}".format(
            epoch, epochs, optimizer.param_groups[0]['lr'],
            train_epochs_loss[epoch], val_epochs_loss[epoch], NMSEs[epoch]))
    if epoch % 50 == 0:
        print("best_epoch={},best_loss={:.3e},best_nmse={}".format(best_epoch, best_loss, best_nmse))

In [None]:
# Summarise the best validation result and plot the train/val loss curves.
print('net:{}, reduction:{}'.format(model_name,reduction))
print('best_epoch={}, best_val_loss={}, best_nmse={}'.format(best_epoch,best_loss,best_nmse))

# Epoch 0 is left out of the plot (presumably because its loss dwarfs the rest
# of the curve). Plot against the real epoch numbers so the x-axis is not
# shifted by one.
plotted_epochs = range(1, len(train_epochs_loss))
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(plotted_epochs, train_epochs_loss[1:], label='train')
ax.plot(plotted_epochs, val_epochs_loss[1:], label='val')
ax.set(xlabel='epoch', ylabel='MSE loss',
       title='{} (reduction={}) training curves'.format(model_name, reduction))
ax.legend()
plt.show()