+
diff --git a/client/ts/api/S2SApi.ts b/client/ts/api/S2SApi.ts
index ab68a93..2d74433 100644
--- a/client/ts/api/S2SApi.ts
+++ b/client/ts/api/S2SApi.ts
@@ -15,15 +15,32 @@ export type TrainDataIndexResponse = {
export class S2SApi {
+ static project_info(project_id) {
+ const request = Networking.ajax_request('/api/project_info');
+
+
+ let payload = new Map();
+ if (project_id) {
+ payload = new Map(
+ [
+ ['project_id', project_id]
+ ]);
+ }
+
+ return request.get(payload)
+
+
+ }
+
static translate({
- input, partial =
[],force_attn = <{[key:number]:number}>{},
+ input, partial = [], force_attn = <{ [key: number]: number }>{},
neighbors: neighbors = ['decoder', 'encoder'] //, 'context'
}) {
const request = Networking.ajax_request('/api/translate');
let force_attn_array = null;
- for (const key in force_attn){
- if (!force_attn_array) force_attn_array=[];
+ for (const key in force_attn) {
+ if (!force_attn_array) force_attn_array = [];
force_attn_array.push(key);
force_attn_array.push(force_attn[key]);
}
diff --git a/client/ts/controller/PanelController.ts b/client/ts/controller/PanelController.ts
index eb60212..47ee433 100644
--- a/client/ts/controller/PanelController.ts
+++ b/client/ts/controller/PanelController.ts
@@ -67,6 +67,8 @@ export class PanelController {
selectProjection(key) {
// TODO: remove duck typing for encoder and src/tgt
+ this.pm.setVisibilityNeighborPanels(true);
+
this._current.projectedNeighbor = key;
let labels = this._current.translations.map(
translation => {
@@ -134,6 +136,11 @@ export class PanelController {
}
+ clearCompare() {
+ this._current.translations[1] = null;
+ this._current.comparison = ComparisonMode.none;
+ }
+
update(translation: Translation, main = this.pm.vis.left, extra = this.pm.vis.zero, isCompare = false) {
const cur = this._current;
@@ -320,9 +327,11 @@ export class PanelController {
this.pm.removeMediumPanel();
}
- updateAndShowWordProjector(data) {
+ updateAndShowWordProjector(data, loc) {
+ this.pm.setVisibilityNeighborPanels(false);
const wp = this.pm.getWordProjector();
console.log(data, "--- WP_data");
+ wp.options.loc = loc;
wp.update(data);
}
@@ -332,6 +341,13 @@ export class PanelController {
}
+ updateProjectInfo(info) {
+
+ if (!info.has_neighbors) {
+ this.pm.panels.loadProjectButton.style('display', 'none');
+ }
+ }
+
_bindEvents() {
@@ -437,8 +453,27 @@ export class PanelController {
// })
//
// } else {
- // this.updateAndShowWordList(word_data);
- this.updateAndShowWordProjector(word_data);
+ // this.updateAndShowWordList(word_data);
+
+ if (loc === 'src') {
+ word_data.compare = word_data.word.map(wd => {
+
+ return {
+ orig: allWords.map((aw, wi) =>
+ (wi === replaceIndex) ? wd : aw).join(' ')
+ };
+ })
+ // word_data.compare = {orig: }
+ } else {
+ word_data.compare = word_data.word.map(wd => {
+
+ return {
+ orig: allWords.map((aw, wi) =>
+ (wi === replaceIndex) ? wd : (wi < replaceIndex) ? aw : '').join(' ').trim()
+ };
+ })
+ }
+ this.updateAndShowWordProjector(word_data, loc);
// }
@@ -502,11 +537,11 @@ export class PanelController {
this.pm.vis.left.decoder_words.actionHighlightWord(0, aChg.selected, true, true, 'selected');
this.pm.vis.left.attention
.actionHighlightEdges(aChg.selected, AttentionVis.VERTEX_TYPE.tgt, true, 'highlight');
- console.log("-hen-- AAAAJ");
+
} else {
alert('Please select a decoder word first. ' +
- 'Then you can increase weights respective weights by clicking on encoder');
+ 'Then you can increase respective weights by clicking on encoder');
}
}
@@ -520,8 +555,6 @@ export class PanelController {
else if (d.caller === vis.left.beam) {
-
-
const partialDec = this._current.translations[0]
.decoderWords[0].slice(0, d.col).join(' ') + ' ' + d.word.word.text;
@@ -553,7 +586,7 @@ export class PanelController {
const minIndex = _.min(Object.keys(this._current.attnChange.changes).map(d => +d));
const partialDec = this._current.translations[0]
- .decoderWords[0].slice(0, minIndex).join(' ')
+ .decoderWords[0].slice(0, minIndex).join(' ')
S2SApi.translate(
@@ -579,17 +612,34 @@ export class PanelController {
d.caller.highlightWord(d.word, true,
true, 'selected');
+ const loc = d.caller.options.loc;
- S2SApi.translate_compare({
- input: this._current.sentence,
- compare: d.sentence,
- neighbors: []
- }).then(data => {
- // TODO: ENC / DEC difference !!!
- this._current.comparison = ComparisonMode.enc_diff;
- data = JSON.parse(data);
- updateComparisonView(data)
- })
+
+ if (loc === 'src') {
+ S2SApi.translate_compare({
+ input: this._current.sentence,
+ compare: d.sentence,
+ neighbors: []
+ }).then(data => {
+ // TODO: ENC / DEC difference !!!
+ this._current.comparison = ComparisonMode.enc_diff;
+ data = JSON.parse(data);
+ updateComparisonView(data)
+ })
+
+ } else {
+ S2SApi.translate({
+ input: this._current.sentence,
+ // compare: d.sentence,
+ partial: [d.sentence],
+ neighbors: []
+ }).then(data => {
+ // TODO: ENC / DEC difference !!!
+ this._current.comparison = ComparisonMode.none;
+ data = new Translation(JSON.parse(data));
+ this.update(data)
+ })
+ }
});
diff --git a/client/ts/controller/PanelManager.ts b/client/ts/controller/PanelManager.ts
index 40a379f..87ec28a 100644
--- a/client/ts/controller/PanelManager.ts
+++ b/client/ts/controller/PanelManager.ts
@@ -76,7 +76,9 @@ export class PanelManager {
wordBtn: d3.select('#word_vector_fix_btn'),
attnBtn: d3.select('#attn_fix_btn'),
attnApplyBtn: d3.select('#apply_attn')
- }
+ },
+ statePictoPanel: d3.select('#statePictos')
+
};
private _vis = {
@@ -530,4 +532,8 @@ export class PanelManager {
}
+ setVisibilityNeighborPanels(visibility: boolean) {
+ this.panels.projectorPanel.style('display', visibility?null:'none');
+ this.panels.statePictoPanel.style('display', visibility?null:'none');
+ }
}
diff --git a/client/ts/main.ts b/client/ts/main.ts
index 34e93a4..5040353 100644
--- a/client/ts/main.ts
+++ b/client/ts/main.ts
@@ -16,11 +16,12 @@ window.onload = () => {
S2SApi.translate({input: value, neighbors: []})
.then((data: string) => {
const raw_data = JSON.parse(data);
- console.log(raw_data, "--- raw_data");
+ panelCtrl.clearCompare();
panelCtrl.update(new Translation(raw_data));
panelCtrl.cleanPanels();
+
$('#spinner').hide();
})
.catch((error: Error) => console.log(error, "--- error"));
@@ -74,4 +75,16 @@ window.onload = () => {
windowResize();
+ S2SApi.project_info(null).then((data) =>{
+
+ data = JSON.parse(data);
+
+ panelCtrl.updateProjectInfo(data);
+
+
+ })
+
+
+
+
};
\ No newline at end of file
diff --git a/client/ts/vis/WordProjector.ts b/client/ts/vis/WordProjector.ts
index da63133..28ac215 100644
--- a/client/ts/vis/WordProjector.ts
+++ b/client/ts/vis/WordProjector.ts
@@ -45,7 +45,8 @@ export class WordProjector extends VComponent {
words: d => d.word,
compare: d => d.compare
},
- text_measurer: null
+ text_measurer: null,
+ loc: null
};
@@ -144,15 +145,15 @@ export class WordProjector extends VComponent {
.text(d => d.word)
.style('font-size', d => wordScale(d.score) + 'pt');
- if (this._current.has_compare) {
- const bd_max = _.max(renderData.map(d => d.compare.dist));
- const bd_scale = d3.scaleLinear().domain([0, bd_max])
- .range(['#ffffff', '#63676e']); //TODO: hard-coded range ??
- allWords.select('rect').style('fill', d => {
- return bd_scale(d.compare.dist)
- })
-
- }
+ // if (this._current.has_compare) {
+ // const bd_max = _.max(renderData.map(d => d.compare.dist));
+ // const bd_scale = d3.scaleLinear().domain([0, bd_max])
+ // .range(['#ffffff', '#63676e']); //TODO: hard-coded range ??
+ // allWords.select('rect').style('fill', d => {
+ // return bd_scale(d.compare.dist)
+ // })
+ //
+ // }
if (this._current.clearHighlights) {
diff --git a/s2s/project.py b/s2s/project.py
index 396df04..62d62a1 100644
--- a/s2s/project.py
+++ b/s2s/project.py
@@ -29,6 +29,8 @@ def __init__(self, config_file, directory):
self.indexType = self.config.get('indexType', 'annoy')
self.indices = {}
+ print(self.config,('indices' in self.config) )
+ self.has_neighbors = ('indices' in self.config)
self.currentIndexName = None
self.currentIndex = None
@@ -51,6 +53,13 @@ def __init__(self, config_file, directory):
self.dicts['t2i'][h][token] = iid
raw = f.readline()
+
+ def info(self):
+ return {
+ 'model': self.config['model'],
+ 'has_neighbors': self.has_neighbors
+ }
+
def cached_norm(self, loc, matrix):
if self.cached_norms[loc] is None:
self.cached_norms[loc] = np.linalg.norm(matrix, axis=1)
diff --git a/server.py b/server.py
index 18a6d62..e5dd620 100644
--- a/server.py
+++ b/server.py
@@ -34,7 +34,8 @@
parser.add_argument("--nodebug", default=True)
parser.add_argument("--port", default="8080")
parser.add_argument("--nocache", default=False)
-parser.add_argument("--dir", type=str, default=os.path.abspath('model_api/data'))
+parser.add_argument("--dir", type=str,
+ default=os.path.abspath('model_api/data'))
# parser.add_argument('-api', type=str, default='pytorch',
# choices=['pytorch', 'lua'],
# help="""The API to use.""")
@@ -526,6 +527,14 @@ def get_neighbor_details(**request):
return index.get_details(indices)
+def get_info(**request):
+ if 'project_id' not in request:
+ current_project = list(projects.values())[0] # type: S2SProject
+ return current_project.info()
+
+ return request
+
+
def get_close_vectors(**request):
current_project = list(projects.values())[0] # type: S2SProject
# os.path.join(current_project.directory, request["vector_name"] + ".ann")
diff --git a/swagger.yaml b/swagger.yaml
index be655cd..32d8729 100644
--- a/swagger.yaml
+++ b/swagger.yaml
@@ -43,7 +43,16 @@ paths:
responses:
200:
description: fun
-
+ /project_info:
+ get:
+ tags: [All]
+ operationId: server.get_info
+ summary: get general project informations
+ parameters:
+ - $ref: '#/parameters/project_id'
+ responses:
+ 200:
+ description: fun
# /compare_translation:
# get:
# tags: [Translate, All]
@@ -207,7 +216,12 @@ parameters:
items:
type: integer
required: false
-
+ project_id:
+ name: project_id
+ description: Project ID
+ in: query
+ type: string
+ required: false
# These definitions are only needed for proper documentation