In [2]:
library(tidyverse)
library(autognet)
library(arrow)
library(igraph)
library(Matrix)

“package ‘ggplot2’ was built under R version 4.3.1”
── [1mAttaching core tidyverse packages[22m ───────────────────────────────────────────── tidyverse 2.0.0 ──
[32m✔[39m [34mdplyr    [39m 1.1.2     [32m✔[39m [34mreadr    [39m 2.1.4
[32m✔[39m [34mforcats  [39m 1.0.0     [32m✔[39m [34mstringr  [39m 1.5.0
[32m✔[39m [34mggplot2  [39m 3.5.1     [32m✔[39m [34mtibble   [39m 3.2.1
[32m✔[39m [34mlubridate[39m 1.9.2     [32m✔[39m [34mtidyr    [39m 1.3.0
[32m✔[39m [34mpurrr    [39m 1.0.1     
── [1mConflicts[22m ─────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()
[36mℹ[39m Use the conflicted package ([3m[34m<http://conflicted.r-lib.org/>[39m[23m) to force all conflicts to become errors

Attaching package: ‘arrow’


The following object is masked from ‘package:

In [3]:
source('autognet_functions.R')

In [4]:
# Load the node-level feature table (file index 0, zero-padded to three
# digits, i.e. 'data/feature_000.feather') and the network edge table.
# Paths are relative to the current working directory.
datO.feature <- read_feather(
  paste0(
    'data/feature_',
    # `flag` is documented as a character string of flag characters, so pass
    # "0" (zero padding) rather than the number 0.
    formatC(0, width = 3, flag = "0"),
    '.feather'
  )
)
datO.network <- read_feather('data/network.feather')

In [5]:
# Append two constant columns to the feature table: T.new1 is 1 in every row
# and T.new2 is 0 in every row (the scalar is recycled to the table length).
# NOTE(review): presumably the "everyone treated" / "no one treated"
# scenarios -- confirm; they are not used in the cells visible here.
datO.feature[["T.new1"]] <- 1
datO.feature[["T.new2"]] <- 0

In [6]:
# Number of nodes = number of rows in the feature table.
n_node <- nrow(datO.feature)

In [7]:
# Display the feature table for a visual check of its columns and values.
datO.feature

X1,X2,X3,W1,W2,W3,p,T,m,Y,T.new1,T.new2
<int>,<int>,<int>,<dbl>,<dbl>,<dbl>,<dbl>,<int>,<dbl>,<int>,<dbl>,<dbl>
0,1,1,-0.25,0.25,-0.25,0.5000000,1,0.62245933,0,1,0
0,0,1,0.25,-0.25,-0.25,0.2227001,1,0.90465054,1,1,0
1,0,0,-0.25,0.25,-0.25,0.5445246,0,0.32082130,1,1,0
0,1,1,-0.25,0.25,-0.25,0.6513549,1,0.40733340,1,1,0
1,1,1,0.25,0.25,0.25,0.5000000,1,0.62245933,1,1,0
0,1,0,-0.25,-0.25,0.25,0.7772999,1,0.22270014,1,1,0
0,1,0,-0.25,-0.25,0.25,0.5000000,0,0.37754067,1,1,0
1,1,1,0.25,0.25,0.25,0.6026853,1,0.47917871,0,1,0
0,0,1,0.25,-0.25,-0.25,0.7772999,0,0.09534946,0,1,0
0,0,1,0.25,-0.25,-0.25,0.2227001,0,0.77729986,1,1,0


In [8]:
# Build the igraph object from the edge table. The two edge columns are
# swapped (`c(2, 1)`), so the table's second column is used as the edge
# source and its first column as the edge target. Vertices are declared
# explicitly as 1..n_node so that nodes without any edge are still present.
# `graph_from_data_frame()` replaces the deprecated `graph.data.frame()`
# alias, and `seq_len()` avoids the `1:0` pitfall if the table were empty.
G <- graph_from_data_frame(datO.network[c(2, 1)], vertices = seq_len(n_node))

In [9]:
# Convert the graph to a dense base-R matrix.
# NOTE(review): presumably the n_node x n_node adjacency matrix -- confirm
# what as.matrix() returns for an igraph object in the installed version.
adjmat <- as.matrix(Matrix(as.matrix(G), sparse = FALSE))

# Row totals of the matrix, floored at 1. These are used later as divisors
# for the neighbour summaries, so the floor prevents division by zero for
# rows that sum to 0.
weights <- pmax(rowSums(adjmat), 1)
# Disabled alternative: unit weights for every node.
# weights <- rep(1,1000)

In [10]:
# Length of the parameter vector of the covariate model (it sizes the random
# start vector handed to optim()).
d_alpha <- 15
# NOTE(review): presumably the number of outcome-model coefficients -- confirm;
# it is not referenced in the cells visible here.
d_beta <- 10

In [14]:
# Fit the covariate model and the outcome model on the observed data and
# report how long the fitting took.
iter_time <- Sys.time()

# STEP 2A. SETUP DATASET S
# Each node's own covariates, treatment and outcome. Columns are selected by
# name (X1, X2, X3, T, Y) instead of by position so that adding or reordering
# columns in the feature table cannot silently pick the wrong variable.
data.i <- as.matrix(datO.feature)
cov1.i <- data.i[, "X1"]
cov2.i <- data.i[, "X2"]
cov3.i <- data.i[, "X3"]
trt.i <- data.i[, "T"]
outcome.i <- data.i[, "Y"]

# Neighbour summaries: adjacency-weighted sums of each variable divided by
# the per-node weights (row totals floored at 1), giving one-column matrices.
cov1.n <- (adjmat %*% cov1.i) / weights
cov2.n <- (adjmat %*% cov2.i) / weights
cov3.n <- (adjmat %*% cov3.i) / weights
trt.n <- (adjmat %*% trt.i) / weights
outcome.n <- (adjmat %*% outcome.i) / weights

# STEP 2B. COVARIATE MODEL
## fit
# Minimise `cov.pl` over a d_alpha-dimensional parameter vector with BFGS.
# NOTE(review): `cov.pl` is not defined in the visible cells (presumably it
# comes from 'autognet_functions.R' -- confirm).
# NOTE(review): the start values are drawn with runif() and no seed is set,
# so repeated runs may end at different estimates.
fit.cov <- optim(
  par = runif(d_alpha, -1, 1),
  fn = cov.pl,
  gr = NULL,
  covariate = cbind(cov1.i, cov2.i, cov3.i),
  covariate.n = cbind(cov1.n, cov2.n, cov3.n),
  hessian = TRUE,
  method = "BFGS"
)
# optim() reports success with convergence code 0; anything else means the
# estimates below come from an unfinished optimisation.
if (fit.cov$convergence != 0) {
  warning(
    "covariate model: optim() did not converge (code ",
    fit.cov$convergence, ")",
    call. = FALSE
  )
}

## estimates
alpha <- fit.cov$par

# STEP 2C. OUTCOME MODEL
## fit
# Logistic regression of the node outcome on own/neighbour treatment, own and
# neighbour covariates, and the neighbour outcome summary.
fit.outcome <- glm(
  outcome.i ~ trt.i + trt.n + cov1.i + cov1.n + cov2.i + cov2.n +
    cov3.i + cov3.n + outcome.n,
  family = binomial(link = "logit")
)
if (!fit.outcome$converged) {
  warning("outcome model: glm() did not converge", call. = FALSE)
}

## estimates
beta <- fit.outcome$coefficients

message(
  paste0(
    "elapsed time: ",
    difftime(Sys.time(), iter_time, units = "secs")[[1]],
    " secs."
  )
)

elapsed time: 3.05115914344788 secs.



In [15]:
# Input values
group_lengths <- c(1, 1, 1)
group_functions <- c(1, 1, 1)

# Handed to agcEffect() below as `treatment_allocation`.
pr_trt <- 1

## Make object for AGC pkg
# Both parameter vectors are reordered because of the coding function output.
# For beta (fitted with the glm formula order) the permutation yields:
# intercept, trt.i, cov1.i, cov2.i, cov3.i, outcome.n, then the neighbour
# terms trt.n, cov1.n, cov2.n, cov3.n.
beta.new <- beta[c(1, 2, 4, 6, 8, 10, 3, 5, 7, 9)]
alpha.new <- alpha[c(1:9, 11, 10, 12, 14, 15, 13)]

# Assemble the parameter object: one-row matrices of alpha and beta, three NA
# placeholders, the group settings and the adjacency matrix. The element
# names (including the three literal "NA" names) are kept exactly as before.
outlist.point <- setNames(
  list(
    t(as.matrix(alpha.new)),
    t(as.matrix(beta.new)),
    NA, NA, NA,
    group_lengths,
    group_functions,
    adjmat
  ),
  c("alpha", "beta", "NA", "NA", "NA", "group_lengths", "group_functions", "adjmat")
)
# Tag the list with the extra class "agcParamClass".
# NOTE(review): presumably the class agcEffect() expects -- confirm.
class(outlist.point) <- c(class(outlist.point), "agcParamClass")

## Run AGC package
R <- 100
burnin_R <- 20

iter_time <- Sys.time()

point.estimate <- agcEffect(
  outlist.point,
  burnin = 0,
  thin = 1,
  treatment_allocation = pr_trt,
  subset = 0,
  R = R,
  burnin_R = burnin_R,
  burnin_cov = 0,
  average = TRUE,
  index_override = 0,
  return_effects = 0
)

# Report the wall-clock time of the agcEffect() call.
message(
  paste0(
    "elapsed time: ",
    difftime(Sys.time(), iter_time, units = "secs")[[1]],
    " secs."
  )
)

elapsed time: 2995.63994002342 secs.



In [16]:
# Show the first two elements of the agcEffect() result.
# NOTE(review): which quantities these two elements represent is not visible
# here -- confirm against the autognet documentation.
c(point.estimate[1], point.estimate[2])