Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
🎨 Auto-generated directory tree for repository in Architecture.md
  • Loading branch information
andrewpeng02 committed May 10, 2024
1 parent d4eecde commit 888aca8
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 129 deletions.
28 changes: 18 additions & 10 deletions .github/Architecture.md
Expand Up @@ -6,6 +6,11 @@
📦 training
| |- 📂 training:
| | |- 📂 routes:
| | | |- 📂 training:
| | | | |- 📂 results:
| | | | | |- 📜 results.py
| | | | | |- 📜 __init__.py
| | | | | |- 📜 schemas.py
| | | |- 📂 tabular:
| | | | |- 📜 tabular.py
| | | | |- 📜 __init__.py
Expand All @@ -30,18 +35,24 @@
| | | |- 📜 __init__.py
| | | |- 📜 health_check_middleware.py
| | |- 📂 core:
| | | |- 📜 trainer.py
| | | |- 📜 criterion.py
| | | |- 📜 dl_model.py : torch model based on user specifications from drag and drop
| | | |- 📜 dataset.py : read in the dataset through URL or file upload
| | | |- 📂 celery:
| | | | |- 📜 trainer.py
| | | | |- 📜 criterion.py
| | | | |- 📜 dl_model.py : torch model based on user specifications from drag and drop
| | | | |- 📜 dataset.py : read in the dataset through URL or file upload
| | | | |- 📜 __init__.py
| | | | |- 📜 worker.py
| | | | |- 📜 optimizer.py : what optimizer to use (ie: SGD or Adam for now)
| | | |- 📜 __init__.py
| | | |- 📜 authenticator.py
| | | |- 📜 optimizer.py : what optimizer to use (ie: SGD or Adam for now)
| | |- 📜 asgi.py
| | |- 📜 constants.py : list of helpful constants
| | |- 📜 celery_app.py
| | |- 📜 settings.py
| | |- 📜 __init__.py
| | |- 📜 wsgi.py
| | |- 📜 urls.py
| | |- 📜 celeryconfig.py
| |- 📜 README.md
| |- 📜 docker-compose.yml
| |- 📜 cli.py
Expand All @@ -59,15 +70,11 @@

```
📦 frontend
| |- 📂 layer_docs:
| | |- 📜 Linear.md : Doc for Linear layer
| | |- 📜 Softmax.md : Doc for Softmax layer
| | |- 📜 softmax_equation.png : PNG file of Softmax equation
| | |- 📜 ReLU.md : Doc for ReLU later
| |- 📂 public:
| | |- 📂 images:
| | | |- 📂 logos:
| | | | |- 📂 dlp_branding:
| | | | | |- 📜 dlp-logo.png : DLP Logo, duplicate of files in public, but essential as the frontend can't read public
| | | | | |- 📜 dlp-logo.svg : DLP Logo, duplicate of files in public, but essential as the frontend can't read public
| | | | |- 📜 dsgt-logo-white-back.png
| | | | |- 📜 python-logo.png
Expand Down Expand Up @@ -214,6 +221,7 @@
| | |- 📂 pages:
| | | |- 📂 train:
| | | | |- 📜 [train_space_id].tsx
| | | | |- 📜 metrics_to_charts.tsx
| | | | |- 📜 index.tsx
| | | |- 📜 _app.tsx
| | | |- 📜 forgot.tsx
Expand Down
18 changes: 18 additions & 0 deletions dlp-terraform/ecs/sqs.tf
Expand Up @@ -2,6 +2,24 @@ resource "aws_sqs_queue" "training_queue" {
name = "training-queue.fifo"
fifo_queue = true
message_retention_seconds = 60*24

redrive_policy = jsonencode({
deadLetterTargetArn = aws_sqs_queue.training_queue_deadletter.arn
maxReceiveCount = 4
})
}

resource "aws_sqs_queue" "training_queue_deadletter" {
name = "training-deadletter-queue"
}

resource "aws_sqs_queue_redrive_allow_policy" "training_queue_redrive_allow_policy" {
queue_url = aws_sqs_queue.training_queue_deadletter.id

redrive_allow_policy = jsonencode({
redrivePermission = "byQueue",
sourceQueueArns = [aws_sqs_queue.training_queue.arn]
})
}

output "sqs_queue_url" {
Expand Down
127 changes: 9 additions & 118 deletions frontend/src/pages/train/[train_space_id].tsx
Expand Up @@ -7,13 +7,13 @@ import { DetailedTrainResultsData } from "@/features/Train/types/trainTypes";
import Container from "@mui/material/Container";
import Grid from "@mui/material/Grid";
import Paper from "@mui/material/Paper";
import dynamic from "next/dynamic";
import { useRouter } from "next/router";
import { Data, XAxisName, YAxisName } from "plotly.js";
import React, { useEffect } from "react";
const Plot = dynamic(() => import("react-plotly.js"), { ssr: false });

const LINE_CHART_COLORS = ["red", "blue", "green"];
import {
mapMetricToLinePlot,
mapMetricToAucRocPlot,
mapMetricToConfusionMatrixPlot,
} from "./metrics_to_charts";

const mapTrainResultsDataToCharts = (
detailedTrainResultsData: DetailedTrainResultsData
Expand All @@ -27,120 +27,11 @@ const mapTrainResultsDataToCharts = (
while (i < sortedData.length) {
const metric = sortedData[i];
if (metric.chart_type === "LINE") {
const data = [];
for (let i = 0; i < metric.time_series.length; i++) {
const time_series = metric.time_series[i];
data.push({
name: time_series.y_name,
x: time_series.x_values,
y: time_series.y_values,
type: "scatter",
mode: "markers",
marker: { color: LINE_CHART_COLORS[i], size: 10 },
});
}
charts.push(
<Plot
data={data as Data[]}
layout={{
height: 350,
width: 525,
xaxis: { title: metric.time_series[0].x_name },
// yaxis: { title: "Y axis" },
title: metric.name,
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
charts.push(mapMetricToLinePlot(metric));
} else if (metric.chart_type === "AUC/ROC") {
charts.push(
<Plot
data={[
{
name: "baseline",
x: [0, 1],
y: [0, 1],
type: "scatter",
marker: { color: "grey" },
line: {
dash: "dash",
},
},
...(metric.values.map((x) => ({
name: `(AUC: ${x[2]})`,
x: x[0] as number[],
y: x[1] as number[],
type: "scatter",
})) as Data[]),
]}
layout={{
height: 350,
width: 525,
xaxis: { title: "False Positive Rate" },
yaxis: { title: "True Positive Rate" },
title: "AUC/ROC Curves for your Deep Learning Model",
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
charts.push(mapMetricToAucRocPlot(metric));
} else if (metric.chart_type === "CONFUSION_MATRIX") {
charts.push(
<Plot
data={[
{
z: metric.values,
type: "heatmap",
colorscale: [
[0, "#e6f6fe"],
[1, "#003058"],
],
},
]}
layout={{
height: 525,
width: 525,
title: "Confusion Matrix (Last Epoch)",
xaxis: {
title: "Predicted",
},
yaxis: {
title: "Actual",
autorange: "reversed",
},
showlegend: true,
annotations: metric.values
.map((row, i) =>
row.map((_, j) => ({
xref: "x1" as XAxisName,
yref: "y1" as YAxisName,
x: j,
y: (i + metric.values.length - 1) % metric.values.length,
text: metric.values[
(i + metric.values.length - 1) % metric.values.length
][j].toString(),
font: {
color:
metric.values[
(i + metric.values.length - 1) % metric.values.length
][j] > 0
? "white"
: "black",
},
showarrow: false,
}))
)
.flat(),
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
/>
);
charts.push(mapMetricToConfusionMatrixPlot(metric));
} else {
throw Error("Undefined chart type received");
}
Expand All @@ -163,7 +54,7 @@ const TrainSpace = () => {
router.replace({ pathname: "/login" });
}
}, [user, router.isReady]);

if (error) {
setTimeout(() => refetch(), 3000);
}
Expand Down
135 changes: 135 additions & 0 deletions frontend/src/pages/train/metrics_to_charts.tsx
@@ -0,0 +1,135 @@
import {
AucRocChart,
ConfusionMatrixChart,
TimeSeriesChart,
} from "@/features/Train/types/trainTypes";
import dynamic from "next/dynamic";
import { Data, XAxisName, YAxisName } from "plotly.js";
const Plot = dynamic(() => import("react-plotly.js"), { ssr: false });

const LINE_CHART_COLORS = ["red", "blue", "green"];

const mapMetricToLinePlot = (metric: TimeSeriesChart) => {
const data = [];
for (let i = 0; i < metric.time_series.length; i++) {
const time_series = metric.time_series[i];
data.push({
name: time_series.y_name,
x: time_series.x_values,
y: time_series.y_values,
type: "scatter",
mode: "markers",
marker: { color: LINE_CHART_COLORS[i], size: 10 },
});
}
return (
<Plot
data={data as Data[]}
layout={{
height: 350,
width: 525,
xaxis: { title: metric.time_series[0].x_name },
// yaxis: { title: "Y axis" },
title: metric.name,
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
};

const mapMetricToAucRocPlot = (metric: AucRocChart) => {
return (
<Plot
data={[
{
name: "baseline",
x: [0, 1],
y: [0, 1],
type: "scatter",
marker: { color: "grey" },
line: {
dash: "dash",
},
},
...(metric.values.map((x) => ({
name: `(AUC: ${x[2]})`,
x: x[0] as number[],
y: x[1] as number[],
type: "scatter",
})) as Data[]),
]}
layout={{
height: 350,
width: 525,
xaxis: { title: "False Positive Rate" },
yaxis: { title: "True Positive Rate" },
title: "AUC/ROC Curves for your Deep Learning Model",
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
};

const mapMetricToConfusionMatrixPlot = (metric: ConfusionMatrixChart) => {
<Plot
data={[
{
z: metric.values,
type: "heatmap",
colorscale: [
[0, "#e6f6fe"],
[1, "#003058"],
],
},
]}
layout={{
height: 525,
width: 525,
title: "Confusion Matrix (Last Epoch)",
xaxis: {
title: "Predicted",
},
yaxis: {
title: "Actual",
autorange: "reversed",
},
showlegend: true,
annotations: metric.values
.map((row, i) =>
row.map((_, j) => ({
xref: "x1" as XAxisName,
yref: "y1" as YAxisName,
x: j,
y: (i + metric.values.length - 1) % metric.values.length,
text: metric.values[
(i + metric.values.length - 1) % metric.values.length
][j].toString(),
font: {
color:
metric.values[
(i + metric.values.length - 1) % metric.values.length
][j] > 0
? "white"
: "black",
},
showarrow: false,
}))
)
.flat(),
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
/>;
};

export {
mapMetricToLinePlot,
mapMetricToAucRocPlot,
mapMetricToConfusionMatrixPlot,
};
1 change: 1 addition & 0 deletions training/training/constants.py
@@ -0,0 +1 @@
DLP_EXECUTIONS_BUCKET_NAME = "dlp-executions"
3 changes: 2 additions & 1 deletion training/training/core/celery/worker.py
Expand Up @@ -9,6 +9,7 @@
import boto3


from training.constants import DLP_EXECUTIONS_BUCKET_NAME
from training.core.celery.criterion import getCriterionHandler
from training.core.celery.dataset import SklearnDatasetCreator
from training.core.celery.dataset import ImageDefaultDatasetCreator
Expand All @@ -34,7 +35,7 @@ def saveDetailedTrainResultsDataToS3(
):
s3 = boto3.resource("s3")
s3.Object(
"dlp-executions", f"{detailedTrainResultsData.basic_info.trainspaceId}.json"
DLP_EXECUTIONS_BUCKET_NAME, f"{detailedTrainResultsData.basic_info.trainspaceId}.json"
).put(Body=detailedTrainResultsData.json())


Expand Down

0 comments on commit 888aca8

Please sign in to comment.