In [1]:
import torch
from torch import nn
from d2l import torch as d2l

from function.generate_label import SS_data_label
from model.PS_Polarity_Net import get_model, get_loss

from tqdm.notebook import tqdm

from utils import Metric_polarity

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import numpy as np


In [2]:
## define the test function
# test function
def test(model, test_dataset, criterion, num_gpus, ratio=1):

    devices = [d2l.try_gpu(i) for i in range(num_gpus)]

    model.eval()
    metrics = {'polarity':Metric_polarity()}
    
    N = int(len(test_dataset) * ratio)
    
    polarity_predict = np.zeros(N)
    polarity_actual = np.zeros(N)
    sac_name = []

    with torch.no_grad():
        all_loss = []
        with tqdm(range(0,N,1), total=N, smoothing=0.9) as tqdmUpdate:
            for i in tqdmUpdate:
                x, (SS,dection,polarity_label,sacname) = test_dataset.__getitem__(i)
                x = torch.Tensor(x).to(devices[0]).unsqueeze(0)\
                
                if polarity_label == 2:
                    polarity_actual[i] = -1
                elif polarity_label == 1:
                    polarity_actual[i] = 1
                elif polarity_label == 0:
                    polarity_actual[i] = 0

                polarity_label = torch.Tensor([polarity_label]).to(devices[0])
                targets = (polarity_label)

                preds = model(x)
                loss = criterion(preds[0], targets.long())
                all_loss.append(loss.item())

                polarity = [pred[0].cpu().numpy() for pred in preds]
                if np.argmax(polarity) == 0:
                    polarity_predict[i] = 0
                elif np.argmax(polarity) == 1:
                    polarity_predict[i] = 1
                elif np.argmax(polarity) == 2:
                    polarity_predict[i] = -1

                sac_name.append(test_dataset.sacname_labels[i].decode().strip())
                
                tqdmUpdate.set_postfix(ordered_dict={
                    'updata':'{:0>4d}'.format(i+1),
                    'test_loss':'{:.6f}'.format(loss.item()),
                    })
    
    loss_ave = np.array(all_loss).mean()
          
    return metrics, polarity_predict, polarity_actual, sac_name, loss_ave

In [3]:
model = get_model().cuda()
model.load_state_dict(torch.load('./checkpoints/net_params.pth'))

data_root = '../data'
model.eval()
testset = SS_data_label(data_root, 'test', 8000, 1, 'gaussian')
devices = d2l.try_all_gpus()
num_gpus = len(devices)
criterion = get_loss()

metrics, polarity_predict, polarity_actual, sac_name, loss_ave = test(model, testset, criterion, num_gpus, 1)

  0%|          | 0/10 [00:00<?, ?it/s]

In [4]:
print('Predict: ', polarity_predict)
print('Lable: ', polarity_actual)
print(polarity_predict==polarity_actual)

Predict:  [-1. -1.  1.  1.  1.  0.  0.  0.  0.  0.]
Lable:  [-1. -1.  1.  1.  1.  0.  0.  0.  0.  0.]
[ True  True  True  True  True  True  True  True  True  True]
