-
Notifications
You must be signed in to change notification settings - Fork 1
/
show_result.py
91 lines (74 loc) · 2.9 KB
/
show_result.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import os

import numpy as np
import torch
import torch.utils.data as Data
from torch import nn

from model import Wcnn
"""
===============================================================
Input Params
===============================================================
"""
parser = argparse.ArgumentParser(description='manual to this script')
# (flag, type, default) for every command-line option, registered in this order.
_OPTIONS = (
    ('--gpus', int, 1),                          # CUDA device index handed to torch.cuda.set_device
    ('--n', int, 2),                             # number of classes (WCNN num_class)
    ('--kmers', str, '3,7,11,15'),               # comma-separated convolution kernel sizes
    ('--t', float, 0.6),                         # top softmax score must exceed this, else rejected
    ('--embed', str, "embed.pkl"),               # path of the saved nn.Embedding state dict
    ('--classifier', str, "Reject_params.pkl"),  # path of the saved WCNN state dict
)
for _flag, _type, _default in _OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# "3,7,11,15" -> [3, 7, 11, 15]
kmers = [int(size) for size in args.kmers.split(',')]
"""
===========================================================
Load Trained Model
===========================================================
"""
# Pin the CUDA device selected by --gpus; the .cuda() call below targets it.
torch.cuda.set_device(args.gpus)
# NOTE(review): kernel_nums is presumably the filter count for each entry of
# kernel_sizes (Wcnn is not visible here -- confirm). It is sized from
# len(kmers) so both lists stay aligned when --kmers does not list exactly
# four sizes; for the default four sizes this is [256, 256, 256, 256].
cnn = Wcnn.WCNN(
    num_token=100,
    num_class=args.n,
    kernel_sizes=kmers,
    kernel_nums=[256] * len(kmers),
)
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; consider weights_only=True where the installed torch supports it.
pretrained_dict = torch.load(args.classifier, map_location='cpu')
cnn.load_state_dict(pretrained_dict)
# Switch to evaluation mode, then move the weights onto the selected GPU.
cnn = cnn.eval()
cnn = cnn.cuda()
# Embedding table (64 ids -> 100-dim vectors) restored from --embed. It stays
# on the CPU and is frozen, so embedding the dataset builds no autograd graph.
torch_embeds = nn.Embedding(64, 100)
torch_embeds.load_state_dict(torch.load(args.embed, map_location='cpu'))
torch_embeds.weight.requires_grad = False
"""
===========================================================
Load Validation Dataset
===========================================================
"""
# NOTE(review): rows are presumably token ids followed by the class label in
# the last column -- confirm against whatever produced dataset/val.csv.
# atleast_2d keeps a one-row file 2-D so the column slicing below works.
val = np.atleast_2d(np.genfromtxt('dataset/val.csv', delimiter=','))
val_label = val[:, -1]
val_feature = val[:, :-1]
# genfromtxt yields floats; nn.Embedding needs integer indices.
val_feature = torch.from_numpy(val_feature).long()
val_label = torch.from_numpy(val_label).float()
# (N, L) ids -> (N, L, 100) embeddings -> (N, 1, L, 100) by inserting a
# channel axis. Identical to reshape(N, 1, 248, 100) when L == 248, but
# without hard-coding the sequence length.
val_feature = torch_embeds(val_feature)
val_feature = val_feature.unsqueeze(1)
"""
===========================================================
Predict Classes and Reject Low-Confidence Samples
===========================================================
"""
def softmax(x):
    """Return softmax probabilities of *x*, normalised along axis 0.

    For the 1-D score vectors this script passes in, the result sums to 1.
    The column-wise maximum is subtracted before exponentiating. That leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf``, which would otherwise turn the ratio into ``nan`` for large
    logits.
    """
    x = np.asarray(x)
    exps = np.exp(x - np.max(x, axis=0))
    return exps / np.sum(exps, axis=0)
# Classify each validation sample. A sample whose top softmax probability
# exceeds --t goes to result.txt as "<index>-><class>". Otherwise its index is
# written to early_stop.txt. Samples are numbered from 1.
# Create the output directory first; open() raises FileNotFoundError otherwise.
os.makedirs("prediction", exist_ok=True)
with open("prediction/early_stop.txt", 'w') as stop_file:
    with open("prediction/result.txt", 'w') as result_file:
        with torch.no_grad():
            # Labels are not needed to predict, so iterate the features only.
            for idx, feature in enumerate(val_feature, start=1):
                # Add a batch axis of size 1, run on the GPU, keep the single row.
                logits = cnn(torch.unsqueeze(feature.cuda(), 0))
                probs = softmax(logits.cpu().detach().numpy()[0])
                if max(probs) > args.t:
                    result_file.write(f"{idx}->{int(np.argmax(probs))}\n")
                else:
                    # Not confident enough: record only the sample index.
                    stop_file.write(f"{idx}\n")