diff --git a/client/index.html b/client/index.html index ec79a80..42e2634 100644 --- a/client/index.html +++ b/client/index.html @@ -111,7 +111,7 @@
-
+
diff --git a/client/ts/api/S2SApi.ts b/client/ts/api/S2SApi.ts index ab68a93..2d74433 100644 --- a/client/ts/api/S2SApi.ts +++ b/client/ts/api/S2SApi.ts @@ -15,15 +15,32 @@ export type TrainDataIndexResponse = { export class S2SApi { + static project_info(project_id) { + const request = Networking.ajax_request('/api/project_info'); + + + let payload = new Map(); + if (project_id) { + payload = new Map( + [ + ['project_id', project_id] + ]); + } + + return request.get(payload) + + + } + static translate({ - input, partial = [],force_attn = <{[key:number]:number}>{}, + input, partial = [], force_attn = <{ [key: number]: number }>{}, neighbors: neighbors = ['decoder', 'encoder'] //, 'context' }) { const request = Networking.ajax_request('/api/translate'); let force_attn_array = null; - for (const key in force_attn){ - if (!force_attn_array) force_attn_array=[]; + for (const key in force_attn) { + if (!force_attn_array) force_attn_array = []; force_attn_array.push(key); force_attn_array.push(force_attn[key]); } diff --git a/client/ts/controller/PanelController.ts b/client/ts/controller/PanelController.ts index eb60212..47ee433 100644 --- a/client/ts/controller/PanelController.ts +++ b/client/ts/controller/PanelController.ts @@ -67,6 +67,8 @@ export class PanelController { selectProjection(key) { // TODO: remove duck typing for encoder and src/tgt + this.pm.setVisibilityNeighborPanels(true); + this._current.projectedNeighbor = key; let labels = this._current.translations.map( translation => { @@ -134,6 +136,11 @@ export class PanelController { } + clearCompare() { + this._current.translations[1] = null; + this._current.comparison = ComparisonMode.none; + } + update(translation: Translation, main = this.pm.vis.left, extra = this.pm.vis.zero, isCompare = false) { const cur = this._current; @@ -320,9 +327,11 @@ export class PanelController { this.pm.removeMediumPanel(); } - updateAndShowWordProjector(data) { + updateAndShowWordProjector(data, loc) { + this.pm.setVisibilityNeighborPanels(false); const wp = this.pm.getWordProjector(); console.log(data, "--- WP_data"); + wp.options.loc = loc; wp.update(data); } @@ -332,6 +341,13 @@ export class PanelController { } + updateProjectInfo(info) { + + if (!info.has_neighbors) { + this.pm.panels.loadProjectButton.style('display', 'none'); + } + } + _bindEvents() { @@ -437,8 +453,27 @@ export class PanelController { // }) // // } else { - // this.updateAndShowWordList(word_data); - this.updateAndShowWordProjector(word_data); + // this.updateAndShowWordList(word_data); + + if (loc === 'src') { + word_data.compare = word_data.word.map(wd => { + + return { + orig: allWords.map((aw, wi) => + (wi === replaceIndex) ? wd : aw).join(' ') + }; + }) + // word_data.compare = {orig: } + } else { + word_data.compare = word_data.word.map(wd => { + + return { + orig: allWords.map((aw, wi) => + (wi === replaceIndex) ? wd : (wi < replaceIndex) ? aw : '').join(' ').trim() + }; + }) + } + this.updateAndShowWordProjector(word_data, loc); // } @@ -502,11 +537,11 @@ export class PanelController { this.pm.vis.left.decoder_words.actionHighlightWord(0, aChg.selected, true, true, 'selected'); this.pm.vis.left.attention .actionHighlightEdges(aChg.selected, AttentionVis.VERTEX_TYPE.tgt, true, 'highlight'); - console.log("-hen-- AAAAJ"); + } else { alert('Please select a decoder word first. ' + - 'Then you can increase weights respective weights by clicking on encoder'); + 'Then you can increase respective weights by clicking on encoder'); } } @@ -520,8 +555,6 @@ export class PanelController { else if (d.caller === vis.left.beam) { - - const partialDec = this._current.translations[0] .decoderWords[0].slice(0, d.col).join(' ') + ' ' + d.word.word.text; @@ -553,7 +586,7 @@ export class PanelController { const minIndex = _.min(Object.keys(this._current.attnChange.changes).map(d => +d)); const partialDec = this._current.translations[0] - .decoderWords[0].slice(0, minIndex).join(' ') + .decoderWords[0].slice(0, minIndex).join(' ') S2SApi.translate( @@ -579,17 +612,34 @@ export class PanelController { d.caller.highlightWord(d.word, true, true, 'selected'); + const loc = d.caller.options.loc; - S2SApi.translate_compare({ - input: this._current.sentence, - compare: d.sentence, - neighbors: [] - }).then(data => { - // TODO: ENC / DEC difference !!! - this._current.comparison = ComparisonMode.enc_diff; - data = JSON.parse(data); - updateComparisonView(data) - }) + + if (loc === 'src') { + S2SApi.translate_compare({ + input: this._current.sentence, + compare: d.sentence, + neighbors: [] + }).then(data => { + // TODO: ENC / DEC difference !!! + this._current.comparison = ComparisonMode.enc_diff; + data = JSON.parse(data); + updateComparisonView(data) + }) + + } else { + S2SApi.translate({ + input: this._current.sentence, + // compare: d.sentence, + partial: [d.sentence], + neighbors: [] + }).then(data => { + // TODO: ENC / DEC difference !!! + this._current.comparison = ComparisonMode.none; + data = new Translation(JSON.parse(data)); + this.update(data) + }) + } }); diff --git a/client/ts/controller/PanelManager.ts b/client/ts/controller/PanelManager.ts index 40a379f..87ec28a 100644 --- a/client/ts/controller/PanelManager.ts +++ b/client/ts/controller/PanelManager.ts @@ -76,7 +76,9 @@ export class PanelManager { wordBtn: d3.select('#word_vector_fix_btn'), attnBtn: d3.select('#attn_fix_btn'), attnApplyBtn: d3.select('#apply_attn') - } + }, + statePictoPanel: d3.select('#statePictos') + }; private _vis = { @@ -530,4 +532,8 @@ export class PanelManager { } + setVisibilityNeighborPanels(visibility: boolean) { + this.panels.projectorPanel.style('display', visibility?null:'none'); + this.panels.statePictoPanel.style('display', visibility?null:'none'); + } } diff --git a/client/ts/main.ts b/client/ts/main.ts index 34e93a4..5040353 100644 --- a/client/ts/main.ts +++ b/client/ts/main.ts @@ -16,11 +16,12 @@ window.onload = () => { S2SApi.translate({input: value, neighbors: []}) .then((data: string) => { const raw_data = JSON.parse(data); - console.log(raw_data, "--- raw_data"); + panelCtrl.clearCompare(); panelCtrl.update(new Translation(raw_data)); panelCtrl.cleanPanels(); + $('#spinner').hide(); }) .catch((error: Error) => console.log(error, "--- error")); @@ -74,4 +75,16 @@ window.onload = () => { windowResize(); + S2SApi.project_info(null).then((data) =>{ + + data = JSON.parse(data); + + panelCtrl.updateProjectInfo(data); + + + }) + + + + }; \ No newline at end of file diff --git a/client/ts/vis/WordProjector.ts b/client/ts/vis/WordProjector.ts index da63133..28ac215 100644 --- a/client/ts/vis/WordProjector.ts +++ b/client/ts/vis/WordProjector.ts @@ -45,7 +45,8 @@ export class WordProjector extends VComponent { words: d => d.word, compare: d => d.compare }, - text_measurer: null + text_measurer: null, + loc: null }; @@ -144,15 +145,15 @@ export class WordProjector extends VComponent { .text(d => d.word) .style('font-size', d => wordScale(d.score) + 'pt'); - if (this._current.has_compare) { - const bd_max = _.max(renderData.map(d => d.compare.dist)); - const bd_scale = d3.scaleLinear().domain([0, bd_max]) - .range(['#ffffff', '#63676e']); //TODO: hard-coded range ?? - allWords.select('rect').style('fill', d => { - return bd_scale(d.compare.dist) - }) - - } + // if (this._current.has_compare) { + // const bd_max = _.max(renderData.map(d => d.compare.dist)); + // const bd_scale = d3.scaleLinear().domain([0, bd_max]) + // .range(['#ffffff', '#63676e']); //TODO: hard-coded range ?? + // allWords.select('rect').style('fill', d => { + // return bd_scale(d.compare.dist) + // }) + // + // } if (this._current.clearHighlights) { diff --git a/s2s/project.py b/s2s/project.py index 396df04..62d62a1 100644 --- a/s2s/project.py +++ b/s2s/project.py @@ -29,6 +29,8 @@ def __init__(self, config_file, directory): self.indexType = self.config.get('indexType', 'annoy') self.indices = {} + print(self.config,('indices' in self.config) ) + self.has_neighbors = ('indices' in self.config) self.currentIndexName = None self.currentIndex = None @@ -51,6 +53,13 @@ def __init__(self, config_file, directory): self.dicts['t2i'][h][token] = iid raw = f.readline() + + def info(self): + return { + 'model': self.config['model'], + 'has_neighbors': self.has_neighbors + } + def cached_norm(self, loc, matrix): if self.cached_norms[loc] is None: self.cached_norms[loc] = np.linalg.norm(matrix, axis=1) diff --git a/server.py b/server.py index 18a6d62..e5dd620 100644 --- a/server.py +++ b/server.py @@ -34,7 +34,8 @@ parser.add_argument("--nodebug", default=True) parser.add_argument("--port", default="8080") parser.add_argument("--nocache", default=False) -parser.add_argument("--dir", type=str, default=os.path.abspath('model_api/data')) +parser.add_argument("--dir", type=str, + default=os.path.abspath('model_api/data')) # parser.add_argument('-api', type=str, default='pytorch', # choices=['pytorch', 'lua'], # help="""The API to use.""") @@ -526,6 +527,14 @@ def get_neighbor_details(**request): return index.get_details(indices) +def get_info(**request): + if 'project_id' not in request: + current_project = list(projects.values())[0] # type: S2SProject + return current_project.info() + + return request + + def get_close_vectors(**request): current_project = list(projects.values())[0] # type: S2SProject # os.path.join(current_project.directory, request["vector_name"] + ".ann") diff --git a/swagger.yaml b/swagger.yaml index be655cd..32d8729 100644 --- a/swagger.yaml +++ b/swagger.yaml @@ -43,7 +43,16 @@ paths: responses: 200: description: fun - + /project_info: + get: + tags: [All] + operationId: server.get_info + summary: get general project informations + parameters: + - $ref: '#/parameters/project_id' + responses: + 200: + description: fun # /compare_translation: # get: # tags: [Translate, All] @@ -207,7 +216,12 @@ parameters: items: type: integer required: false - + project_id: + name: project_id + description: Project ID + in: query + type: string + required: false # These definitions are only needed for proper documentation