Skip to content

Commit

Permalink
Merge pull request #182 from AI-SDC/hist_kaplan_R
Browse files Browse the repository at this point in the history
adding histogram and kaplan meier to R
  • Loading branch information
mahaalbashir committed Oct 31, 2023
2 parents aaca88c + c6ccb53 commit ac19e10
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 358 deletions.
43 changes: 31 additions & 12 deletions acro.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
library(reticulate) # import Python modules
library(admiraldev)
library(stringr)
library(png)
library(grid)

acro <- import("acro")
ac <- acro$ACRO()
Expand All @@ -17,32 +18,32 @@ acro_table <- function(index, columns, dnn=NULL, deparse.level=0, ...)
"ACRO crosstab without aggregation function"
if (is.null(dnn)) {
if (deparse.level == 0) {
rownames <- list("")
colnames <- list("")
row_names <- list("")
col_names <- list("")
} else if (deparse.level == 1) {
tryCatch({
index_symbol <- assert_symbol(substitute(index))
rownames <- list(deparse(index_symbol))},
row_names <- list(deparse(index_symbol))},
error = function(e) {
rownames <<- list("")
row_names <- list("")
})
tryCatch({
column_symbol <- assert_symbol(substitute(columns))
colnames <- list(deparse(column_symbol))},
col_names <- list(deparse(column_symbol))},
error = function(e) {
colnames <<- list("")
col_names <- list("")
})
} else if (deparse.level == 2) {
rownames <- list(deparse((substitute(index))))
colnames <- list(deparse(substitute(columns)))
row_names <- list(deparse((substitute(index))))
col_names <- list(deparse(substitute(columns)))
}
}
else {
rownames <- list(dnn[1])
colnames <- list(dnn[2])
row_names <- list(dnn[1])
col_names <- list(dnn[2])
}

table <- ac$crosstab(index, columns, rownames=rownames, colnames=colnames)
table <- ac$crosstab(index, columns, rownames=row_names, colnames=col_names)
# Check for any unused arguments
if (length(list(...)) > 0) {
warning("Unused arguments were provided: ", paste0(names(list(...)), collapse = ", "), "\n", "To find more help about the function use: acro_help(\"acro_table\")\n")
Expand Down Expand Up @@ -77,6 +78,24 @@ acro_glm <- function(formula, data, family)
model$summary()
}

# Draw a histogram through the Python ACRO session and display the saved
# PNG on the current graphics device.
#
# Arguments (all forwarded to ac$hist()):
#   data     - forwarded as `data`.
#   column   - forwarded as `column`.
#   breaks   - forwarded as `bins`; a whole-number scalar is coerced to integer.
#   freq     - forwarded unchanged as `density`.
#   col      - forwarded as `color`.
#   filename - forwarded as `filename`.
#
# Returns the value of grid.raster() for the image file whose path ac$hist()
# returned. Stops with an error when ac$hist() returns NULL (no file produced).
acro_hist <- function(data, column, breaks=10, freq=TRUE, col=NULL, filename="histogram.png"){
  "ACRO histogram"
  # reticulate passes an R double such as the default 10 to Python as a float,
  # but a scalar bin count has to be an integer on the Python side. Only a
  # single finite whole number is coerced; anything else is forwarded as given.
  if (is.numeric(breaks) && length(breaks) == 1 && is.finite(breaks) &&
      breaks == round(breaks)) {
    breaks <- as.integer(breaks)
  }

  # NOTE(review): `freq` is passed straight through as `density`. In R's hist()
  # freq=TRUE means counts, which is the opposite of a density plot - confirm
  # whether this should be `density = !freq`. Left unchanged to preserve the
  # current output.
  histogram <- ac$hist(data=data, column=column, bins=breaks, density=freq, color=col, filename=filename)

  # The Python hist() returns None instead of a path when it declines to plot
  # (more than one column requested, or filename without an extension).
  if (is.null(histogram)) {
    stop("ACRO did not produce a histogram file. ",
         "Check that a single column and a filename with an extension were given.",
         call. = FALSE)
  }

  # Load the saved histogram
  image <- readPNG(histogram)
  grid.raster(image)
}

# Estimate the survival function through the Python ACRO session.
#
# Arguments (all forwarded to ac$surv_func()):
#   time     - forwarded as `time`.
#   status   - forwarded as `status`.
#   output   - forwarded as `output`; "plot" additionally displays the saved image.
#   filename - forwarded as `filename`.
#
# Returns the value of grid.raster() when output is "plot"; otherwise returns
# whatever ac$surv_func() returned (e.g. the survival table), so the result is
# no longer discarded.
acro_surv_func <- function(time, status, output, filename="kaplan-meier.png"){
  "Estimates the survival function. Produce either a plot or table"
  results <- ac$surv_func(time=time, status=status, output=output, filename=filename)
  if (identical(output, "plot")) {
    # For a plot the Python side returns (plot, filename); the second element
    # is the path of the saved survival plot, which is loaded and drawn here.
    image <- readPNG(results[[2]])
    grid.raster(image)
  } else {
    # Hand the Python result back to the caller instead of dropping it.
    results
  }
}

acro_rename_output <- function(old, new)
{
"Rename an output"
Expand Down
35 changes: 27 additions & 8 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,10 @@ def surv_func( # pylint: disable=too-many-arguments,too-many-locals
)
return table
if output == "plot":
plot = self.survival_plot(
plot, filename = self.survival_plot(
survival_table, survival_func, filename, status, sdc, command, summary
)
return plot
return (plot, filename)
return None

def survival_table( # pylint: disable=too-many-arguments,too-many-locals
Expand Down Expand Up @@ -513,7 +513,22 @@ def survival_plot( # pylint: disable=too-many-arguments,too-many-locals
logger.debug("Directory acro_artifacts created successfully")
except FileExistsError: # pragma: no cover
logger.debug("Directory acro_artifacts already exists")
plt.savefig(f"acro_artifacts/{filename}")

# create a unique filename with number to avoid overwrite
filename, extension = os.path.splitext(filename)
if not extension: # pragma: no cover
logger.info("Please provide a valid file extension")
return None
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
): # pragma: no cover
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"

# save the plot to the acro artifacts directory
plt.savefig(unique_filename)

# record output
self.results.add(
status=status,
Expand All @@ -523,9 +538,9 @@ def survival_plot( # pylint: disable=too-many-arguments,too-many-locals
command=command,
summary=summary,
outcome=pd.DataFrame(),
output=[os.path.normpath(filename)],
output=[os.path.normpath(unique_filename)],
)
return plot
return (plot, unique_filename)

def hist( # pylint: disable=too-many-arguments,too-many-locals
self,
Expand Down Expand Up @@ -606,6 +621,9 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
Returns
-------
matplotlib.Axes
The histogram.
str
The name of the file where the histogram is saved.
"""
logger.debug("hist()")
command: str = utils.get_command("hist()", stack())
Expand All @@ -615,7 +633,7 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
"Calculating histogram for more than one columns is "
"not currently supported. Please do each column separately."
)
return
return None

freq, _ = np.histogram( # pylint: disable=too-many-function-args
data[column], bins, range=(data[column].min(), data[column].max())
Expand Down Expand Up @@ -693,11 +711,11 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
filename, extension = os.path.splitext(filename)
if not extension: # pragma: no cover
logger.info("Please provide a valid file extension")
return
return None
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
):
): # pragma: no cover
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"

Expand All @@ -715,6 +733,7 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
outcome=pd.DataFrame(),
output=[os.path.normpath(unique_filename)],
)
return unique_filename


def create_crosstab_masks( # pylint: disable=too-many-arguments,too-many-locals
Expand Down
22 changes: 21 additions & 1 deletion notebooks/test-nursery.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ output: html_notebook
# install.packages("haven")
# install.packages("reticulate")
# install.packages("farff")
# install.packages("survival")
```

## Import Libraries
Expand Down Expand Up @@ -67,7 +68,7 @@ table
index = data[, c("recommend")]
columns = data[, c("parents")]
table = acro_table(data[, c("recommend")], columns, dnn=c("rows", "columns"), deparse.level = 1, useNa = "no")
table = acro_table(index, columns, dnn= c("recommend", "parents"), deparse.level=0)
```

```{r}
Expand All @@ -94,6 +95,25 @@ table = acro_pivot_table(data, values=values, index=index, aggfunc=aggfunc)
table
```

### ACRO histogram

```{r}
acro_hist(data, "children")
```

### ACRO survival analysis

```{r}
data(package = "survival")
# Load the lung dataset
data(lung)
#head(lung)
acro_surv_func(time=lung$time, status=lung$status, output ="plot")
```
```
# Regression examples using ACRO
Again there is an industry-standard package in python, this time called **statsmodels**.
Expand Down

0 comments on commit ac19e10

Please sign in to comment.