# predict.deepnet.R
#' Predict Function for Deepnet
#'
#' Generates predictions from a fitted \code{deepnet} model. The input goes
#' through these steps:
#' \enumerate{
#'   \item Restrict to the stored predictor names, when the model stores more
#'     than one.
#'   \item Dummy-encode character/factor columns with
#'     \code{fastDummies::dummy_cols()}.
#'   \item Min-max scale with the input ranges stored in the model.
#'   \item Align to the row names of the first weight matrix.
#'   \item Pass through \code{feedForward()}.
#' }
#' The raw outputs are then mapped back with the stored output ranges.
#'
#' @param object deepnet model object (must inherit from class "deepnet").
#' @param newData data.frame (or matrix) of new observations to predict.
#' @param ... further arguments passed to or from other methods; currently
#'   ignored.
#'
#' @return A data.frame of predictions with one row per row of
#'   \code{newData}:
#'   \itemize{
#'     \item For a single-row input to a single-output "regress" model, the
#'       column is named \code{ypred}.
#'     \item For "multiClass" models, the leading "y_" prefix is removed from
#'       the class-score column names. An extra \code{ypred} column holds the
#'       label of the highest-scoring class; ties go to the first such class.
#'   }
#' @export predict.deepnet
#' @export
#' @importFrom graphics barplot
#' @importFrom stats formula predict runif
predict.deepnet <- function(object,
                            newData, ...) {
  if (!inherits(object, "deepnet")) {
    stop("Not a legitimate \"deepnet\" object")
  }
  if (is.null(nrow(newData))) {
    stop("`newData` must be a data.frame or matrix")
  }

  # A one-row input is duplicated, and only the first prediction is kept at
  # the end. This guards against one-row matrices collapsing to vectors
  # during subsetting.
  # NOTE(review): confirm whether feedForward() still needs this.
  singleRow <- nrow(newData) == 1L
  if (singleRow) {
    newData <- rbind(newData, newData)
  }
  newData <- data.frame(newData)

  weightMatrix <- object[["weightMatrix"]]
  activation <- object[["activation"]]
  modelType <- object[["modelType"]]
  baisUnits <- object[["baisUnits"]]
  reluLeak <- object[["reluLeak"]]
  inColMin <- object[["inColMin"]]
  inColMax <- object[["inColMax"]]
  xcolnames <- object[["xcolnames"]]

  if (length(xcolnames) > 1L) {
    newData <- newData[, xcolnames, drop = FALSE]
  }

  # Use inherits() rather than class() %in% ... so that ordered factors are
  # detected too: their class() has length 2.
  isFctrChr <- vapply(
    newData,
    function(col) inherits(col, c("character", "factor")),
    logical(1)
  )
  xfctrChrColsIdx <- which(isFctrChr)
  if (length(xfctrChrColsIdx) > 0L) {
    # Name the columns explicitly so that exactly the columns removed below
    # are the ones that get dummy-encoded.
    newData <- fastDummies::dummy_cols(
      newData,
      select_columns = names(newData)[xfctrChrColsIdx]
    )
    # Relies on dummy_cols() appending the new columns after the existing
    # ones, so the original indices still point at the source columns.
    newData <- newData[, -xfctrChrColsIdx, drop = FALSE]
  }
  newData <- data.frame(newData)

  # Min-max scale each column with the stored input ranges. Ranges are
  # matched to columns by position.
  # NOTE(review): confirm this column order matches the one used to compute
  # inColMin/inColMax at fit time.
  for (i in seq_len(ncol(newData))) {
    newData[, i] <- (newData[, i] - inColMin[i]) / (inColMax[i] - inColMin[i])
  }

  newData <- as.matrix(cbind(const = rep(1, nrow(newData)), newData))
  if (ncol(newData) > 2L) {
    # Align to the row names of the first weight matrix:
    #   1. prepend a template row carrying every expected column;
    #   2. let rbind.fill() add any missing columns as NA;
    #   3. drop the template row and reorder to the expected columns.
    reqmat <- matrix(1, ncol = nrow(weightMatrix[[1]]))
    colnames(reqmat) <- row.names(weightMatrix[[1]])
    newData <- plyr::rbind.fill(as.data.frame(reqmat),
                                as.data.frame(newData))
    newData <- as.matrix(newData[-1, , drop = FALSE])
    newData <- newData[, colnames(reqmat), drop = FALSE]
  }
  # Any NA/NaN remaining here becomes 0, e.g. columns added by rbind.fill()
  # or missing inputs.
  newData[is.na(newData)] <- 0

  feedList <- feedForward(newData, weightMatrix, activation, reluLeak,
                          modelType, baisUnits)
  feedOut <- feedList$a_output
  # The last element of a_output is used as the prediction matrix.
  ypred <- feedOut[[length(feedOut)]]

  # Map outputs back using the stored output ranges (inverse of min-max
  # scaling).
  outColMax <- object[["outColMax"]]
  outColMin <- object[["outColMin"]]
  for (i in seq_len(ncol(ypred))) {
    ypred[, i] <- ypred[, i] * (outColMax[i] - outColMin[i]) + outColMin[i]
  }
  ypred <- data.frame(ypred)

  if (singleRow) {
    # drop = FALSE keeps the column names; only a single-output regression
    # is renamed, matching the historical "ypred" column name.
    ypred <- ypred[1, , drop = FALSE]
    if (isTRUE(modelType == "regress") && ncol(ypred) == 1L) {
      names(ypred) <- "ypred"
    }
  }

  if (isTRUE(modelType == "multiClass")) {
    # Remove only the leading "y_" prefix; labels may contain "y_" elsewhere.
    classNames <- sub("^y_", "", names(ypred))
    # max.col() breaks ties at random by default; "first" makes predictions
    # deterministic.
    ypred$ypred <- classNames[max.col(ypred, ties.method = "first")]
    names(ypred) <- c(classNames, "ypred")
  }
  ypred
}