<a href="https://colab.research.google.com/github/1ucky40nc3/TREX/blob/main/TREX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Check if Runtime is connected with a GPU ❓ 💪 
!nvidia-smi

Tue Sep  7 15:39:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P0    27W /  70W |   6226MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# ***Set up TREX*** 🦖💬	


In [None]:
# @title Install Dependencies for the entire Notebook ⇩
# @markdown 🆘 The runtime will automatically CRASH and RESTART after you execute this cell. 🆘

# @markdown 👉 This happens to switch dependency versions. Just keep on going and execute the next cell. 😉

#@markdown ---
# When unchecked, the pip output of this cell is captured and hidden by `execute`.
VERBOSE = False # @param {type:"boolean"}


from IPython.utils.io import capture_output


def execute(func, *args, verbose: bool = False, **kwargs):
    """Call ``func(*args, verbose=verbose, **kwargs)`` and return its result.

    When ``verbose`` is False the call runs inside IPython's ``capture_output``
    context, so whatever it prints is captured and discarded; when True the
    output is shown normally.
    """
    call_kwargs = {"verbose": verbose, **kwargs}
    if not verbose:
        with capture_output():
            return func(*args, **call_kwargs)
    return func(*args, **call_kwargs)

def install_notebook_dependencies(**kwargs):
    """Upgrade numpy and PyYAML, then kill the kernel so the runtime restarts.

    ``**kwargs`` absorbs the ``verbose`` flag that ``execute`` always passes.
    """
    # NOTE(review): unpinned upgrades -- the installed versions can change
    # between runs; consider pinning them.
    !pip install -U numpy
    !pip install -U PyYAML
    # NOTE(review): this import is not used afterwards; its purpose is unclear.
    import numpy

    import os
    # Signal 9 (SIGKILL) terminates the running kernel process; per the notes at
    # the top of this cell, the runtime then restarts with the upgraded packages.
    os.kill(os.getpid(), 9)

# pip output is hidden unless VERBOSE is checked.
execute(install_notebook_dependencies, verbose=VERBOSE)

In [None]:
# @title Utils for the entire Notebook
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

from IPython.utils.io import capture_output


def execute(func, *args, verbose: bool = False, **kwargs):
    """Call ``func(*args, verbose=verbose, **kwargs)`` and return its result.

    Redefined here because the install cell kills the runtime, which wipes the
    earlier definition. When ``verbose`` is False, the call runs inside
    IPython's ``capture_output`` so its printed output is discarded.
    """
    call_kwargs = {"verbose": verbose, **kwargs}
    if not verbose:
        with capture_output():
            return func(*args, **call_kwargs)
    return func(*args, **call_kwargs)

---


## ***Natural Language Processing (NLP)*** 📰🤯


---

In [None]:
#@markdown ### Language selection during operation 🏳️‍🌈/🏴‍☠️
# "de" additionally downloads the German<->English translation pipelines in
# initialize_nlp_pipelines; "en" skips them.
LANGUAGE = "de" #@param ["en", "de"]

In [None]:
# @title | NLP | Install Dependencies ⇩
VERBOSE = False # @param {type:"boolean"}
    

def install_nlp_dependencies(**kwargs):
    """Install the NLP stack: sentencepiece, transformers, PyTorch and torch-geometric.

    ``**kwargs`` absorbs the ``verbose`` flag that ``execute`` always passes.
    """
    # NOTE(review): sentencepiece/transformers/torch-geometric are unpinned.
    !pip install sentencepiece
    !pip install transformers
    # PyTorch 1.9.0 builds for CUDA 10.2, fetched from the PyTorch wheel index.
    !pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    !pip install torch-geometric
    # The torch-scatter wheel index is selected for torch 1.9.0 + cu102, so it
    # must be kept in sync with the torch pin above.
    !pip install torch-scatter==2.0.8 -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html


# pip output is hidden unless VERBOSE is checked.
execute(install_nlp_dependencies, verbose=VERBOSE)

In [None]:
# @title | NLP | Set up Services for Data
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

from typing import Any

import io
import pandas as pd
import transformers


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for general utilities.
~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Simulated "now" and position of the user, editable through the Colab form.
DATE = "2021-08-29" #@param {type: "string"}
TIME = "7:00" #@param {type: "string"}
LOCATION = "Munich" #@param {type: "string"}


def date() -> str:
    """Return the simulated current date (the DATE form value)."""
    today = DATE
    return today


def time() -> str:
    """Return the simulated current time of day (the TIME form value)."""
    now = TIME
    return now


def location() -> str:
    """Return the user's simulated location (the LOCATION form value)."""
    here = LOCATION
    return here

def set_dtype(df: pd.DataFrame, dtype: Any) -> pd.DataFrame:
    """Return a copy of ``df`` with every column cast to ``dtype``."""
    mapping = dict.fromkeys(df.columns.values, dtype)
    return df.astype(mapping)

def df_to_csv(df) -> str:
    """Serialize ``df`` to CSV text, without the index column."""
    with io.StringIO() as buffer:
        df.to_csv(buffer, index=False)
        return buffer.getvalue()

def table(string: str) -> pd.DataFrame:
    """Parse CSV text into a DataFrame whose columns are all cast to ``str``."""
    parsed = pd.read_csv(io.StringIO(string))
    return set_dtype(parsed, str)
 

"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for utilities to create travel tables.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Mock train connections used as the Table-QA context for the travel skill.
# NOTE(review): the four Roedermark RB61 rows list an arrival time one minute
# *before* the departure time (e.g. 15:31 -> 15:30) while "Duration" says
# 0:30 / 0:16 / 0:30 / 0:03 -- the arrival column looks wrong; confirm the
# intended values. "Duration" also mixes H:MM and HH:MM formats.
TRAVEL_TABLE = """Location,Train,Start,Destination,Departure Time,Arrival Time,Departure Track,Arrival Track,Duration
Roedermark,RB61,Roedermark,Frankfurt (Main) main station,15:31,15:30,2,2,0:30
Roedermark,RB61,Roedermark,Dieburg train station,15:47,15:46,1,1,0:16
Roedermark,RB61,Roedermark,Frankfurt (Main) South Station,16:00,15:59,2,2,0:30
Roedermark,RB61,Roedermark,Rödermark-Ober-Roden station,16:17,16:16,1,1,0:03
Munich,ICE 1655,Frankfurt (Main) main station,Leipzig main station,17:21,20:24,9,14,03:03
Munich,ICE 594,Frankfurt (Main) main station,Leipzig main station,18:14,21:10,9,13,02:56
Munich,FLX 1354,Berlin main station (low),Hamburg main station,08:07,10:07,8,5,02:00
Munich,ICE 806,Berlin main station (low),Hamburg main station,08:38,10:21,8,5,01:43
Munich,ICE 598,Stuttgart main station,Mannheim main station,12:51,13:29,9,2,00:38
Munich,ICE 576,Stuttgart main station,Mannheim main station,13:23,14:02,10,3,00:39
Munich,ICE 1223,Nuremberg main station,Munich main station,14:07,15:12,9,22,01:05
Munich,ICE 705,Nuremberg main station,Munich main station,14:55,16:07,8,21,01:12"""

def travel_table() -> pd.DataFrame:
    """Return TRAVEL_TABLE parsed into an all-string DataFrame via ``table``."""
    return table(TRAVEL_TABLE)


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for utilities to create event tables.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Mock local events used as the Table-QA context for the event skill.
# NOTE(review): the Urban Priol row has the literal description `nan`, which
# pd.read_csv treats as a missing value before `table` casts it to a string.
EVENT_TABLE = """Title,Date,Start,End,Description,Location
Literary reading tour with Lou Heinrich,2021-09-17,18:30,19:30,Excerpts from the novel cycle "Leute von Seldwyla" will be presented,Bücherturm Ober-Roden Trinkbrunnenstr. 8 Raum Rothahasaal 63322 Rödermark
Autumn-Winter-Bazaar,2021-09-18,14:00,16:00,Autumn-Winter-Bazaar of the Förderverein Kindergarten St. Gallus and Rejoice,Halle Urberach Am Schellbusch 2 63322 Rödermark
Urban Priol "In the river" cabaret,2021-09-23,20:00,22:15,nan,Kulturhalle Rödermark
Musical "Ausgetickt?",2021-09-26,15:00,17:00,Musical for children from 8-13 years with the Rejoice Kids & Teens.,KSV Halle Turngartenstraße 63322 Rödermark
Info evening "Well prepared for self-employment",2021-09-29,19:00,21:00,This free info event "Well prepared for self-employment" will be held with the team of our cooperation partner "gruenderberatungen.de",Rathaus Ober-Roden Dieburger Straße 9-11 im Zehnthof 63322 Rödermark"""

def event_table() -> pd.DataFrame:
    """Return EVENT_TABLE parsed into an all-string DataFrame via ``table``."""
    return table(EVENT_TABLE)


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for utilities to create restaurant tables.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Mock nearby restaurants used as the Table-QA context for the restaurant skill.
# "Distance" appears to be in km (the few-shot samples phrase 0.6 as "0.6 km").
# NOTE(review): "Reatuarant zagreb" and "Ristaurante Tie-Break" look misspelled,
# but the few-shot restaurant samples repeat these exact names -- change both together.
RESTAURANT_TABLE = """Restaurant Name,Price Class,Average Rating,Distance,Category
Wolfsschlucht Restaurant,3.0,4.0,14.4,German
Reatuarant zagreb,3.0,4.5,1.1,Balkan
Pizzeria Romana,2.0,4.5,2.1,Italian
La Scala,2.0,4.0,0.6,Italian
Ristaurante Tie-Break,2.0,4.5,2.2,Italian
Cuervo,2.0,4.0,0.7,Mexican"""

def restaurant_table() -> pd.DataFrame:
    """Return RESTAURANT_TABLE parsed into an all-string DataFrame via ``table``."""
    return table(RESTAURANT_TABLE)


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for utilities to create timetable tables.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Mock lecture timetable used as the Table-QA context for the timetable skill.
# "Duration" is in minutes (e.g. 12:15-12:45 is 30.0). Empty cells are written
# as the literal "None"; NOTE(review): whether pd.read_csv treats "None" as a
# missing value depends on the pandas version -- confirm the intended handling.
TIMETABLE_TABLE = """Name,Day,Date,Start Time,End Time,Duration,Professor,Room,Virtual Room
Current Affairs,Monday,2021-08-30,12:15,12:45,30.0,None,Assembly hall,None
IT Law,Monday,2021-08-30,12:45,16:00,195.0,Leonardo da Vinci,Assembly hall,None
Design and Implementation of Databases,Tuesday,2021-08-31,08:30,11:45,195.0,Alan Turing,Assembly hall,None
Finance and Investment,Tuesday,2021-08-31,12:45,16:00,195.0,Henry Ford,Assembly hall,None
Practice/Project groups,Wednesday,2021-09-01,08:30,16:00,450.0,None,None,None
Practice/Project groups,Thursday,2021-09-02,08:30,16:00,450.0,None,None,None
Servicemanagement und ERP,Friday,2021-09-03,08:30,11:45,195.0,Nikola Tesla,Assembly hall,None"""

def timetable_table() -> pd.DataFrame:
    """Return TIMETABLE_TABLE parsed into an all-string DataFrame via ``table``."""
    return table(TIMETABLE_TABLE)

#### Display the Tables

In [None]:
# @title Display the Travel Table

# Load Colab's data_table IPython extension and show the travel table with
# Colab's DataTable display.
%load_ext google.colab.data_table
from google.colab import data_table

data_table.DataTable(travel_table())

Unnamed: 0,Location,Train,Start,Destination,Departure Time,Arrival Time,Departure Track,Arrival Track,Duration
0,Roedermark,RB61,Roedermark,Frankfurt (Main) main station,15:31,15:30,2,2,0:30
1,Roedermark,RB61,Roedermark,Dieburg train station,15:47,15:46,1,1,0:16
2,Roedermark,RB61,Roedermark,Frankfurt (Main) South Station,16:00,15:59,2,2,0:30
3,Roedermark,RB61,Roedermark,Rödermark-Ober-Roden station,16:17,16:16,1,1,0:03
4,Munich,ICE 1655,Frankfurt (Main) main station,Leipzig main station,17:21,20:24,9,14,03:03
5,Munich,ICE 594,Frankfurt (Main) main station,Leipzig main station,18:14,21:10,9,13,02:56
6,Munich,FLX 1354,Berlin main station (low),Hamburg main station,08:07,10:07,8,5,02:00
7,Munich,ICE 806,Berlin main station (low),Hamburg main station,08:38,10:21,8,5,01:43
8,Munich,ICE 598,Stuttgart main station,Mannheim main station,12:51,13:29,9,2,00:38
9,Munich,ICE 576,Stuttgart main station,Mannheim main station,13:23,14:02,10,3,00:39


In [None]:
# @title Display the Event Table

# Loaded again so this cell also works on its own; if the extension is
# already loaded, IPython only prints an "already loaded" notice.
%load_ext google.colab.data_table
from google.colab import data_table

data_table.DataTable(event_table())

The google.colab.data_table extension is already loaded. To reload it, use:
  %reload_ext google.colab.data_table


Unnamed: 0,Title,Date,Start,End,Description,Location
0,Literary reading tour with Lou Heinrich,2021-09-17,18:30,19:30,"Excerpts from the novel cycle ""Leute von Seldw...",Bücherturm Ober-Roden Trinkbrunnenstr. 8 Raum ...
1,Autumn-Winter-Bazaar,2021-09-18,14:00,16:00,Autumn-Winter-Bazaar of the Förderverein Kinde...,Halle Urberach Am Schellbusch 2 63322 Rödermark
2,"Urban Priol ""In the river"" cabaret",2021-09-23,20:00,22:15,,Kulturhalle Rödermark
3,"Musical ""Ausgetickt?""",2021-09-26,15:00,17:00,Musical for children from 8-13 years with the ...,KSV Halle Turngartenstraße 63322 Rödermark
4,"Info evening ""Well prepared for self-employment""",2021-09-29,19:00,21:00,"This free info event ""Well prepared for self-e...",Rathaus Ober-Roden Dieburger Straße 9-11 im Ze...


In [None]:
# @title Display the Restaurant Table

# Loaded again so this cell also works on its own; if the extension is
# already loaded, IPython only prints an "already loaded" notice.
%load_ext google.colab.data_table
from google.colab import data_table

data_table.DataTable(restaurant_table())

The google.colab.data_table extension is already loaded. To reload it, use:
  %reload_ext google.colab.data_table


Unnamed: 0,Restaurant Name,Price Class,Average Rating,Distance,Category
0,Wolfsschlucht Restaurant,3.0,4.0,14.4,German
1,Reatuarant zagreb,3.0,4.5,1.1,Balkan
2,Pizzeria Romana,2.0,4.5,2.1,Italian
3,La Scala,2.0,4.0,0.6,Italian
4,Ristaurante Tie-Break,2.0,4.5,2.2,Italian
5,Cuervo,2.0,4.0,0.7,Mexican


In [None]:
# @title Display the Timetable Table

# Loaded again so this cell also works on its own; if the extension is
# already loaded, IPython only prints an "already loaded" notice.
%load_ext google.colab.data_table
from google.colab import data_table

data_table.DataTable(timetable_table())

The google.colab.data_table extension is already loaded. To reload it, use:
  %reload_ext google.colab.data_table


Unnamed: 0,Name,Day,Date,Start Time,End Time,Duration,Professor,Room,Virtual Room
0,Current Affairs,Monday,2021-08-30,12:15,12:45,30.0,,Assembly hall,
1,IT Law,Monday,2021-08-30,12:45,16:00,195.0,Leonardo da Vinci,Assembly hall,
2,Design and Implementation of Databases,Tuesday,2021-08-31,08:30,11:45,195.0,Alan Turing,Assembly hall,
3,Finance and Investment,Tuesday,2021-08-31,12:45,16:00,195.0,Henry Ford,Assembly hall,
4,Practice/Project groups,Wednesday,2021-09-01,08:30,16:00,450.0,,,
5,Practice/Project groups,Thursday,2021-09-02,08:30,16:00,450.0,,,
6,Servicemanagement und ERP,Friday,2021-09-03,08:30,11:45,195.0,Nikola Tesla,Assembly hall,


### Set up for the ***Legacy NLP*** Components



In [None]:
# @title | NLP | Initialize the NLP Pipelines
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

from typing import List

import torch
import transformers


def device(boolean: bool) -> int:
    """Translate a "use GPU" flag into a transformers pipeline device index.

    Returns 0 (the first CUDA device) for a truthy flag, otherwise -1 (CPU).
    """
    if not boolean:
        return -1
    return 0


#@markdown ---
#@markdown ### Model selection for the NLP toolkit 🤖📰
ZERO_SHOT_MODEL = "facebook/bart-large-mnli" #@param ["facebook/bart-large-mnli", "typeform/distilbert-base-uncased-mnli", "joeddav/xlm-roberta-large-xnli", "Narsil/deberta-large-mnli-zero-cls"]
TABLE_QA_MODEL = "lysandre/tiny-tapas-random-wtq" #@param ["lysandre/tiny-tapas-random-wtq", "lysandre/tiny-tapas-random-sqa", "google/tapas-base-finetuned-wtq", "google/tapas-base-finetuned-sqa", "google/tapas-base-finetuned-wikisql-supervised", "google/tapas-large-finetuned-wtq", "google/tapas-large-finetuned-sqa", "google/tapas-large-finetuned-wikisql-supervised"]
SMALL_TALK_MODEL = "facebook/blenderbot-90M" #@param ["facebook/blenderbot-90M", "facebook/blenderbot-400M-distill", "facebook/blenderbot-1B-distill", "facebook/blenderbot-3B"]
FEW_SHOT_MODEL = "gpt2" #@param ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B"]

#@markdown ---
#@markdown ### Model selection for translation between English and German
GERMAN_TO_ENGLISH_MODEL = "Helsinki-NLP/opus-mt-de-en" #@param ["Helsinki-NLP/opus-mt-de-en", "facebook/wmt19-de-en"]
ENGLISH_TO_GERMAN_MODEL = "Helsinki-NLP/opus-mt-en-de" #@param ["Helsinki-NLP/opus-mt-en-de", "facebook/wmt19-en-de"]

#@markdown ---
#@markdown ### Select if the individual model shall be on GPU 💻🔥
USE_GPU_FOR_ZERO_SHOT = True # @param {type:"boolean"}
USE_GPU_FOR_SMALL_TALK = False # @param {type:"boolean"}
USE_GPU_FOR_FEW_SHOT = False # @param {type:"boolean"}

USE_GPU_FOR_GERMAN_TO_ENGLISH = False # @param {type:"boolean"}
USE_GPU_FOR_ENGLISH_TO_GERMAN = False # @param {type:"boolean"}

#@markdown ---
VERBOSE = False # @param {type:"boolean"}


def initialize_nlp_pipelines(**kwargs):
    """Download the models and build every pipeline used by the NLP skills.

    The pipelines are returned through ``locals()``, so the resulting dict is
    keyed by the local variable names below (e.g. ``PIPELINES["ZERO_SHOT"]``,
    ``PIPELINES["FEW_SHOT_TOKENIZER"]``). Do not add unrelated locals to this
    function: they would end up in that dict too. ``kwargs`` (the ``verbose``
    flag passed by ``execute``) is included for the same reason.

    Each pipeline with a ``USE_GPU_FOR_*`` form toggle is placed on the device
    chosen by that toggle. The translation pipelines are only created when
    ``LANGUAGE == "de"``.
    """
    print("[DEBUG] Downloading Zero-Shot-Classification Components")
    ZERO_SHOT = transformers.pipeline(
        "zero-shot-classification",
        model=ZERO_SHOT_MODEL,
        device=device(USE_GPU_FOR_ZERO_SHOT))

    # There is no GPU toggle for Table-QA, so it uses the pipeline's default device.
    print("[DEBUG] Downloading Table-QA Components")
    TABLE_QA = transformers.pipeline(
        "table-question-answering",
        model=TABLE_QA_MODEL)

    print("[DEBUG] Downloading Small-Talk Components")
    SMALL_TALK = transformers.pipeline(
        "conversational",
        model=SMALL_TALK_MODEL,
        device=device(USE_GPU_FOR_SMALL_TALK))

    print("[DEBUG] Downloading Text-To-Text Components")
    FEW_SHOT = transformers.pipeline(
        "text-generation",
        model=FEW_SHOT_MODEL,
        device=device(USE_GPU_FOR_FEW_SHOT))
    # Separate tokenizer handle, used to measure few-shot prompt lengths.
    FEW_SHOT_TOKENIZER = transformers.GPT2Tokenizer.from_pretrained(
        FEW_SHOT_MODEL)

    if LANGUAGE == "de":
        print("[DEBUG] Downloading German-To-English Translation Components")
        GERMAN_TO_ENGLISH_TRANSLATOR = transformers.pipeline(
            "translation_de_to_en",
            model=GERMAN_TO_ENGLISH_MODEL,
            device=device(USE_GPU_FOR_GERMAN_TO_ENGLISH))
        print("[DEBUG] Downloading English-To-German Translation Components")
        ENGLISH_TO_GERMAN_TRANSLATOR = transformers.pipeline(
            "translation_en_to_de",
            model=ENGLISH_TO_GERMAN_MODEL,
            device=device(USE_GPU_FOR_ENGLISH_TO_GERMAN))

    return locals()

PIPELINES = execute(initialize_nlp_pipelines, verbose=VERBOSE)

In [None]:
# @title | NLP | Legacy NLP Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

from typing import Dict
from typing import List
from typing import Tuple
from typing import Callable
from typing import Optional

import copy

import pandas as pd


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for Classification on a Zero-Shot basis.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

def zero_shot_classification(input: str, 
                             labels: List[str], 
                             top_k: Optional[int] = 1,
                             **kwargs) -> List[str]:
    """Rank ``labels`` for ``input`` with the zero-shot pipeline.

    Returns the first ``top_k`` entries of the pipeline's ``"labels"`` list
    (``None`` keeps all of them). Extra keyword arguments are ignored.
    """
    result = PIPELINES["ZERO_SHOT"](input, labels)
    ranked_labels = result["labels"]
    return ranked_labels[:top_k]

def skill_classification(input: str, 
                         skills: List[str], 
                         verbose: Optional[bool] = False,
                         **kwargs) -> str:
    """Return the single best-matching skill label for ``input``.

    Args:
        input: user utterance to classify.
        skills: candidate skill labels.
        verbose: print debug information.
        **kwargs: forwarded to ``zero_shot_classification``.

    Returns:
        The top-ranked skill label (a single string, not a list).
    """
    if verbose:
        print(f"[DEBUG] |Skill Classification| input: {input}")
        print(f"[DEBUG] |Skill Classification| skills: {skills}")

    skill = zero_shot_classification(input, skills, **kwargs)[0]

    if verbose:
        print(f"[DEBUG] |Skill Classification| skill: {skill}")
    return skill

def sentiment_classification(input: str,
                             labels: List[str],
                             verbose: Optional[bool] = False,
                             **kwargs) -> str:
    """Return the single best-matching sentiment label for ``input``.

    Args:
        input: text to classify.
        labels: candidate sentiment labels.
        verbose: print debug information.
        **kwargs: forwarded to ``zero_shot_classification``.

    Returns:
        The top-ranked label (a single string, not a list).
    """
    if verbose:
        print(f"[DEBUG] |Sentiment Classification| input: {input}")
        print(f"[DEBUG] |Sentiment Classification| labels: {labels}")

    label = zero_shot_classification(input, labels, **kwargs)[0]

    if verbose:
        print(f"[DEBUG] |Sentiment Classification| label: {label}")
    return label


"""~~~~~~~~~~~~~~~~~~~
Section for Table QA.
~~~~~~~~~~~~~~~~~~~"""

def table_question_answering(input: str, 
                             table: pd.DataFrame, 
                             verbose: Optional[bool] = False, 
                             **kwargs) -> str:
    """Answer the question ``input`` against ``table`` with the Table-QA pipeline.

    Only the ``"answer"`` field of the pipeline result is returned; extra
    keyword arguments are ignored.
    """
    if verbose:
        print(f"[DEBUG] |Table Question Answering| input: {input}")
        print(f"[DEBUG] |Table Question Answering| table: \n{table}")

    result = PIPELINES["TABLE_QA"](table=table, query=input)
    answer = result["answer"]

    if verbose:
        print(f"[DEBUG] |Table Question Answering| output: \n{result}")
    return answer


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for Few-Shot Text Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""


def few_shot(query: str, 
             samples: str, 
             verbose: Optional[bool] = False, 
             **kwargs) -> List[str]:
    """Continue the few-shot prompt ``samples + query`` with the text generator.

    ``kwargs`` are forwarded to the text-generation pipeline (e.g.
    ``max_length``). Returns the ``generated_text`` of every returned sequence.
    """
    if verbose:
        print(f"[DEBUG] |Few-Shot Text Generation| query: \n{query}")
        print(f"[DEBUG] |Few-Shot Text Generation| samples: \n{samples}")

    prompt = samples + query
    generations = PIPELINES["FEW_SHOT"](prompt, **kwargs)

    texts = []
    for generation in generations:
        texts.append(generation["generated_text"])

    if verbose:
        print(f"[DEBUG] |Few-Shot Text Generation| outputs: \n{texts}")
    return texts

def table_qa_few_shot(query: str, 
                      samples: str, 
                      verbose: Optional[bool] = False, 
                      **kwargs) -> List[str]:
    """Run few-shot generation and keep only the newly generated answer text.

    Each generated text is expected to begin with the prompt
    ``samples + query``; that many characters are cut off. The remainder is
    truncated at its first blank line, which is where the next few-shot
    example would start.
    """
    if verbose:
        print(f"[DEBUG] |Table QA Few-Shot Text Generation| query: \n{query}")
        print(f"[DEBUG] |Table QA Few-Shot Text Generation| samples: \n{samples}")

    generations = few_shot(query, 
                           samples, 
                           verbose, 
                           **kwargs)

    prompt_length = len(samples + query)
    answers = [text[prompt_length:].split('\n\n')[0] for text in generations]

    if verbose:
        print(f"[DEBUG] |Table QA Few-Shot Text Generation| outputs: \n{answers}")
    return answers


"""~~~~~~~~~~~~~~~~~~~~~~~~~
Section for Skill Functions
~~~~~~~~~~~~~~~~~~~~~~~~~"""


def legacy_table_qa_skill(conversation: transformers.Conversation, 
                          associations: dict, 
                          verbose: Optional[bool] = False, 
                          **kwargs) -> List[str]:
    """Answer the latest user turn from a table, phrased by few-shot generation.

    Steps:
    1. Pick the association variant by zero-shot classification over the keys of ``associations``.
    2. Look up the variant's table with Table-QA to get the raw answer cell.
    3. Turn that cell into a sentence by continuing the variant's few-shot
       ``samples`` with a ``Q:``/``C:`` prompt.

    Args:
        conversation: conversation whose ``new_user_input`` is the question.
        associations: maps variant label -> dict with ``"data"`` (callable
            returning the DataFrame), ``"samples"`` (few-shot prompt text) and
            ``"config"`` (text-generation kwargs).
        verbose: print debug information.
        **kwargs: forwarded to ``zero_shot_classification`` and to text
            generation, where they override the variant's ``config``.

    Returns:
        One answer per generated sequence, with ``"A: "`` removed.
    """
    if verbose:
        print(f"[DEBUG] |Legacy Table QA Skill| input conversation: \n{conversation}")
        print(f"[DEBUG] |Legacy Table QA Skill| associations: \n{associations}")

    input = conversation.new_user_input
    labels = list(associations.keys())

    variant = zero_shot_classification(
        input, labels, **kwargs)[0]

    data = associations[variant]["data"]()
    samples = associations[variant]["samples"]
    config = associations[variant]["config"]

    cell = table_question_answering(input, data, verbose)

    # The legacy few-shot samples only use the "Q:"/"C:" format, so no
    # location/date/time context is added to the question here.
    query = f'Q: {input}\nC: {cell}\n'

    outputs = table_qa_few_shot(
        query, 
        samples, 
        verbose, 
        **{**config, **kwargs})
    outputs = [output.replace("A: ", "") for output in outputs]

    if verbose:
        print(f"[DEBUG] |Legacy Table QA Skill| variant: {variant}")
        print(f"[DEBUG] |Legacy Table QA Skill| cell: {cell}")
        print(f"[DEBUG] |Legacy Table QA Skill| outputs: {outputs}")
    return outputs

def legacy_small_talk_skill(conversation: transformers.Conversation,
                            associations: dict,
                            verbose: Optional[bool] = False, 
                            **kwargs) -> List[str]:
    """Generate small-talk reply candidates with the conversational pipeline.

    ``associations["num_return_sequences"]`` deep copies of ``conversation``
    are passed to the small-talk pipeline, so the caller's conversation object
    is not modified. The latest generated response of each copy is returned.
    Extra keyword arguments are ignored.
    """
    if verbose:
        print(f"[DEBUG] |Legacy Small Talk Skill| input conversation: \n{conversation}")

    num_return_sequences = associations["num_return_sequences"]

    # NOTE(review): transformers' conversational pipeline may return a bare
    # Conversation rather than a list for a single input -- confirm the
    # behaviour if num_return_sequences can be 1.
    candidates = []
    for _ in range(num_return_sequences):
        candidates.append(copy.deepcopy(conversation))
    answered = PIPELINES["SMALL_TALK"](candidates)

    replies = [candidate.generated_responses[-1] for candidate in answered]

    if verbose:
        print(f"[DEBUG] |Legacy Small Talk Skill| num_return_sequences: {num_return_sequences}")
        print(f"[DEBUG] |Legacy Small Talk Skill| outputs: \n{replies}")
    return replies


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for Personas and Warm Up
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

def legacy_warm_up(conversation: transformers.Conversation, 
                   personas: List[str],
                   verbose: bool = False,
                   **kwargs) -> transformers.Conversation:
    """Prime ``conversation`` by feeding each persona line through small talk.

    Each persona is added as a user turn and answered by the small-talk
    pipeline, so both the persona and the bot's reply become part of the
    conversation history. Extra keyword arguments are ignored.

    Returns:
        The conversation returned by the last pipeline call, or the input
        conversation unchanged when ``personas`` is empty.
    """
    for persona in personas:
        conversation.add_user_input(persona)
        conversation = PIPELINES["SMALL_TALK"](conversation)

    if verbose:
        print(f"[DEBUG] |Legacy Warm Up| personas: {personas}")
        print(f"[DEBUG] |Legacy Warm Up| conversation: {conversation}")
    return conversation


"""~~~~~~~~~~~~~~~~~
Language Processors
~~~~~~~~~~~~~~~~~"""

def legacy_german_to_english_translation(input: str,
                                         verbose: Optional[bool] = False, 
                                         **kwargs) -> str:
    """Translate German ``input`` to English with the legacy translation pipeline.

    Uses ``num_beams=40`` unless ``kwargs`` overrides it; all ``kwargs`` are
    forwarded to the pipeline. Returns the first translation's text. The
    pipeline only exists in ``PIPELINES`` when ``LANGUAGE == "de"``.
    """
    if verbose:
        print(f"[DEBUG] |Legacy German-To-English Translation| input: {input}")

    generation_options = {"num_beams": 40}
    generation_options.update(kwargs)
    translator = PIPELINES["GERMAN_TO_ENGLISH_TRANSLATOR"]
    translation = translator(input, **generation_options)

    if verbose:
        print(f"[DEBUG] |Legacy German-To-English Translation| translation: {translation}")
    first_result = translation[0]
    return first_result["translation_text"]

def legacy_english_to_german_translation(input: str,
                                         verbose: Optional[bool] = False, 
                                         **kwargs) -> str:
    """Translate English ``input`` to German with the legacy translation pipeline.

    Uses ``num_beams=40`` unless ``kwargs`` overrides it; all ``kwargs`` are
    forwarded to the pipeline. Returns the first translation's text. The
    pipeline only exists in ``PIPELINES`` when ``LANGUAGE == "de"``.
    """
    if verbose:
        print(f"[DEBUG] |Legacy English-To-German Translation| input: {input}")

    generation_options = {"num_beams": 40}
    generation_options.update(kwargs)
    translator = PIPELINES["ENGLISH_TO_GERMAN_TRANSLATOR"]
    translation = translator(input, **generation_options)

    if verbose:
        print(f"[DEBUG] |Legacy English-To-German Translation| translation: {translation}")
    first_result = translation[0]
    return first_result["translation_text"]


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section Few-Shot Samples and their utilities.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Few-shot prompts for the legacy Table-QA skills. Each example is a
# "Q:" (question), "C:" (cell returned by Table-QA) and "A:" (answer) line.
# Examples are separated by blank lines; table_qa_few_shot cuts every
# generation at the first blank line, so that separator is significant.
legacy_travel_samples = """Q: It is 12:30. Which is the next train from Frankfurt to Leipzig?
C: ICE 1655
A: The next train to Leipzig is the ICE 1655.

Q: It is 17:22. Which is the next train from Frankfurt to Leipzig?
C: ICE 594
A: The next train to Leipzig is the ICE 594.

Q: When will the ICE 594 from Frankfurt arrive in Leipzig?
C: 21:10
A: The ICE 594 will arrive at 21:10.

Q: How long will the ICE 1655 need to get from Frankfurt to Leipzig?
C: 03:03
A: The ICE 1655 will need 3 hours and 3 minutes.

Q: At which track will the FLX 1354 from Berlin arrive?
C: 5
A: The FLX 1354 from Berlin will arrive at track 5.

Q: Which train is the fastest option from Berlin to Hamburg?
C: ICE 806
A: The ICE 806 is the fastest option.

Q: Which is the fastest option from Berlin to Hamburg?
C: ICE 806
A: The ICE 806 is the fastest option.

Q: Can I take a Flixtrain from Berlin to Hamburg?
C: FLX 1354
A: The Flixtrain FLX 1354 will travel to Hamburg.

"""

# Event names and dates match EVENT_TABLE (the bazaar is on 2021-09-18).
legacy_event_samples = """Q: It is the 2021-09-16. When is the next event?
C: 2021-09-16
A: The next event will take place on the 16th of September.

Q: Which event will take place on the 18th of September?
C: Autumn-Winter-Bazaar
A: The Autumn-Winter-Bazaar will take place on the 18th of September.

Q: When will the Musical Ausgetickt end?
C: 17:00
A: The Musical Ausgetickt will end at 17:00.

Q: What is the Info evening Well prepared for self-employment about?
C: This free info event "Well prepared for self-employment" will be held with the team of our cooperation partner "gruenderberatungen.de"
A: The info event will be about being well prepared for self-employment.

"""

legacy_timetable_samples = """Q: Which lectures are planned for the 30th of August?
C: Current Affairs, IT Law
A: The lectures Current Affairs and IT Law are planned for the 30th of August.

Q: It is the 2021-08-31. How late will the lecture Design and Implementation of Databases end?
C: 11:45
A: The lecture Design and Implementation of Databases will end at 11:45.

Q: Which lecturer will give the lecture Finance and Investment?
C: Henry Ford
A: Henry Ford will give the lecture Finance and Investment.

Q: In which room will the lecture Servicemanagement und ERP be?
C: Assembly hall
A: The lecture will be given in the assembly hall.

Q: It is the 2021-09-01. What is planned for tomorrow?
C: Practice/Project groups
A: Practice and Project groups are scheduled for tomorrow.

Q: It is the 2021-09-01. When do I have my next lecture?
C: 2021-09-03
A: Your next lecture will be on the 3rd of September.

"""

# Restaurant names are copied verbatim from RESTAURANT_TABLE (including spelling).
legacy_restaurant_samples = """Q: How far away is the nearest restaurant?
C: 0.6 km
A: The nearest restaurant is 0.6 km away.

Q: What is the closest Italian restaurant?
C: La Scala
A: The closest Italian restaurant is La Scala.

Q: What kind of food does the restaurant Cuervo serve?
C: Mexican
A: The restaurant Cuervo serves Mexican food.

Q: Can you tell me the best rated restaurants you know?
C: Reatuarant zagreb, Pizzeria Romana, Ristaurante Tie-Break
A: The best rated restaurants I know are Reatuarant zagreb, Pizzeria Romana, Ristaurante Tie-Break.

"""

def length(samples: str, model: str) -> int:
    """Return the number of tokens the few-shot tokenizer produces for ``samples``.

    ``model`` is currently unused: the tokenizer always comes from
    ``PIPELINES["FEW_SHOT_TOKENIZER"]``.
    """
    encoded = PIPELINES["FEW_SHOT_TOKENIZER"](samples, return_tensors="pt")
    token_ids = encoded.input_ids
    return token_ids.shape[-1]


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for the Legacy Skills configuration.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

LEGACY_PERSONAS = {
    "warm_up": legacy_warm_up,
    "personas": []
}

LEGACY_SMALL_TALK_SKILL = {
    "associations": {
        "num_return_sequences": 2,
    }, 
    "function": legacy_small_talk_skill
}


def _legacy_table_qa_config(data: Callable[[], pd.DataFrame], samples: str) -> dict:
    """Build a legacy Table-QA skill config with a single ("any") variant.

    Args:
        data: zero-argument callable returning the table to query.
        samples: few-shot prompt text used to phrase the answer.

    Returns:
        A skill dict whose ``"function"`` is ``legacy_table_qa_skill``.
        Generation is greedy (``do_sample=False``), so ``temperature`` has no
        effect. ``max_length`` is the token length of ``samples`` plus 100; it
        counts prompt tokens, so the query and the answer share those 100.
    """
    return {
        "associations": {
            "any": {
                "data": data,
                "samples": samples,
                "config": {
                    "temperature": 0.1,
                    "do_sample": False,
                    "max_length": length(
                        samples,
                        FEW_SHOT_MODEL) + 100,
                }
            }
        },
        "function": legacy_table_qa_skill
    }


LEGACY_TRAVEL_SKILL = _legacy_table_qa_config(travel_table, legacy_travel_samples)

LEGACY_EVENT_SKILL = _legacy_table_qa_config(event_table, legacy_event_samples)

LEGACY_TIMETABLE_SKILL = _legacy_table_qa_config(timetable_table, legacy_timetable_samples)

LEGACY_RESTAURANT_SKILL = _legacy_table_qa_config(restaurant_table, legacy_restaurant_samples)

### Set up for the ***AI21 NLP*** Components



In [None]:
# @title | NLP | Set up AI21 Studio API Key
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


from getpass import getpass
import requests
import json

# Read the key with getpass so it is not echoed or stored in the notebook source.
# Pressing Enter leaves it as "", which ai21_pipeline rejects with
# Ai21ApiKeyException, so the AI21 skills cannot be used without a key.
AI21_API_KEY = getpass("""
    ▄▄▄▄▄▄▄ ▄▄▄ ▄▄▄▄▄▄▄ ▄▄▄▄    ▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄ ▄▄   ▄▄ ▄▄▄▄▄▄  ▄▄▄ ▄▄▄▄▄▄▄ 
    █       █   █       █    █  █       █       █  █ █  █      ██   █       █
    █   ▄   █   █▄▄▄▄   ██   █  █  ▄▄▄▄▄█▄     ▄█  █ █  █  ▄    █   █   ▄   █
    █  █▄█  █   █▄▄▄▄█  ██   █  █ █▄▄▄▄▄  █   █ █  █▄█  █ █ █   █   █  █ █  █
    █       █   █ ▄▄▄▄▄▄██   █  █▄▄▄▄▄  █ █   █ █       █ █▄█   █   █  █▄█  █
    █   ▄   █   █ █▄▄▄▄▄ █   █   ▄▄▄▄▄█ █ █   █ █       █       █   █       █
    █▄▄█ █▄▄█▄▄▄█▄▄▄▄▄▄▄██▄▄▄█  █▄▄▄▄▄▄▄█ █▄▄▄█ █▄▄▄▄▄▄▄█▄▄▄▄▄▄██▄▄▄█▄▄▄▄▄▄▄█

    Note: If you DO NOT wish to use the AI21 toolkit simply press Enter.
    Paste your AI21 Studio API key here: """)


    ▄▄▄▄▄▄▄ ▄▄▄ ▄▄▄▄▄▄▄ ▄▄▄▄    ▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄ ▄▄   ▄▄ ▄▄▄▄▄▄  ▄▄▄ ▄▄▄▄▄▄▄ 
    █       █   █       █    █  █       █       █  █ █  █      ██   █       █
    █   ▄   █   █▄▄▄▄   ██   █  █  ▄▄▄▄▄█▄     ▄█  █ █  █  ▄    █   █   ▄   █
    █  █▄█  █   █▄▄▄▄█  ██   █  █ █▄▄▄▄▄  █   █ █  █▄█  █ █ █   █   █  █ █  █
    █       █   █ ▄▄▄▄▄▄██   █  █▄▄▄▄▄  █ █   █ █       █ █▄█   █   █  █▄█  █
    █   ▄   █   █ █▄▄▄▄▄ █   █   ▄▄▄▄▄█ █ █   █ █       █       █   █       █
    █▄▄█ █▄▄█▄▄▄█▄▄▄▄▄▄▄██▄▄▄█  █▄▄▄▄▄▄▄█ █▄▄▄█ █▄▄▄▄▄▄▄█▄▄▄▄▄▄██▄▄▄█▄▄▄▄▄▄▄█

    Note: If you DO NOT wish to use the AI21 toolkit simply press Enter.
    Paste your AI21 Studio API key here: ··········


In [None]:
# @title | NLP | AI21 NLP Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


from typing import List

import json
import requests
import pandas as pd


"""~~~~~~~~~~~~~~~~~~~
Section for utilities.
~~~~~~~~~~~~~~~~~~~"""


class Ai21ApiKeyException(Exception):
    """Raised when no AI21 Studio API key was entered (``AI21_API_KEY`` is empty)."""


class Ai21ApiResponseException(Exception):
    """Raised when an AI21 Studio request returns a status code other than 200."""


def ai21_pipeline(model: str,
                  input: str,
                  num_beams: int = 0,
                  num_return_sequences: int = 1,
                  max_length: int = 100,
                  stop_sequences: Optional[List[str]] = None,
                  top_p: float = 0.98,
                  top_k: int = 0,
                  temperature: float = 0.0,
                  verbose: bool = False,
                  timeout: Optional[float] = 60.0,
                  **kwargs) -> List[str]:
    """Request completions from the AI21 Studio ``complete`` endpoint.

    Args:
        model: AI21 model name, inserted into the endpoint URL.
        input: prompt text.
        num_beams: unused (not sent to the API).
        num_return_sequences: sent as ``numResults``.
        max_length: sent as ``maxTokens``.
        stop_sequences: sent as ``stopSequences``; ``None`` means none.
        top_p: sent as ``topP``.
        top_k: sent as ``topKReturn``. NOTE(review): in the AI21 API this
            appears to control how many alternative tokens are returned rather
            than top-k sampling -- confirm the mapping is intended.
        temperature: sent as ``temperature``.
        verbose: print debug information.
        timeout: seconds passed to ``requests.post`` so a stalled request
            fails instead of blocking the notebook forever; ``None`` waits
            indefinitely.
        **kwargs: ignored.

    Returns:
        The ``text`` of every returned completion.

    Raises:
        Ai21ApiKeyException: if ``AI21_API_KEY`` is empty.
        Ai21ApiResponseException: if the response status code is not 200.
        requests.exceptions.Timeout: if the request exceeds ``timeout``.
    """
    if verbose:
        print(f"[DEBUG] |AI21 Studio Pipeline| model: {model}")
        print(f"[DEBUG] |AI21 Studio Pipeline| input: {input}")
        print(f"[DEBUG] |AI21 Studio Pipeline| config: {locals()}")

    if AI21_API_KEY == "":
        raise Ai21ApiKeyException(
            """[Error] No valid AI21 Studio API key was entered!
                Please rerun the "| NLP | Set up AI21 Studio API Key" Cell 
                and enter your valid API Key.""")

    response = requests.post(
        f"https://api.ai21.com/studio/v1/{model}/complete",
        headers={"Authorization": f"Bearer {AI21_API_KEY}"},
        json={
            "prompt": input,
            "numResults": num_return_sequences,
            "maxTokens": max_length,
            # A fresh list per call instead of a shared mutable default.
            "stopSequences": list(stop_sequences) if stop_sequences else [],
            "topP": top_p,
            "topKReturn": top_k,
            "temperature": temperature,
        },
        timeout=timeout)

    if response.status_code != 200:
        raise Ai21ApiResponseException(
            f"""[Error] The AI21 Studio request has returned a status code other than 200!
                The request returned the following status code: {response.status_code}.
                with the following response body:\n{response.text}""")

    outputs = json.loads(response.text)["completions"]
    outputs = [output["data"]["text"] for output in outputs]

    if verbose:
        print(f"[DEBUG] |AI21 Studio Pipeline| outputs: \n{outputs}")
    return outputs

def ai21_preprocess_table(data: pd.DataFrame) -> str:
    """Render ``data`` as a pipe-delimited text table for AI21 prompts.

    Every CSV line produced by ``df_to_csv`` becomes ``| a | b | c |``.
    NOTE(review): commas inside quoted CSV fields are replaced as well.

    Args:
        data: The table to render.

    Returns:
        The formatted rows joined by "\\n" (no trailing newline).
    """
    table = df_to_csv(data).replace(",", " | ")

    rows = table.split("\n")
    # CSV text usually ends with a newline, which leaves an empty final
    # element. Drop it only when it really is empty, so a last data row is
    # never lost when there is no trailing newline.
    if rows and rows[-1] == "":
        rows.pop()

    return "\n".join(f"| {row} |" for row in rows)

def ai21_warm_up(conversation: transformers.Conversation,
                 *args,
                 verbose: bool = False,
                 **kwargs) -> transformers.Conversation:
    """Warm-up hook for the AI21 small-talk setup.

    No persona turns are injected: the AI21 skills add their own few-shot
    samples to each prompt. The conversation is returned unchanged.

    Args:
        conversation: The conversation to prepare.
        *args: Ignored.
        verbose: Print the conversation for debugging.
        **kwargs: Ignored (``NLP`` passes ``personas=...`` here).

    Returns:
        The same ``conversation`` object.
    """
    if not verbose:
        return conversation
    print(f"[DEBUG] |AI21 Warm Up| personas: {conversation}")
    return conversation


"""~~~~~~~~~~~~~~~~~~~~~~~~~
Section for Skill Functions
~~~~~~~~~~~~~~~~~~~~~~~~~"""

def ai21_table_qa_skill(conversation: transformers.Conversation,
                        associations: dict,
                        verbose: Optional[bool] = False,
                        **kwargs) -> List[str]:
    """Answer the newest user input from a data table with an AI21 prompt.

    The association to use is picked by zero-shot classification over the
    association keys. The prompt is built from the association's table, its
    few-shot Q/A samples, and a question prefixed with location/date/time.

    Args:
        conversation: Conversation whose ``new_user_input`` is the question.
        associations: Maps a label to a dict with "data" (callable returning
            a DataFrame), "samples", "config" (kwargs for ``ai21_pipeline``)
            and "model".
        verbose: Print debug information.
        **kwargs: Forwarded to ``zero_shot_classification``.

    Returns:
        Candidate answers with the "\\nA: " marker removed.
    """
    if verbose:
        print(f"[DEBUG] |AI21 Table QA Skill| input conversation: \n{conversation}")
        print(f"[DEBUG] |AI21 Table QA Skill| associations: \n{associations}")

    question = conversation.new_user_input

    # The first returned label is used (presumably the best match).
    variant = zero_shot_classification(question, list(associations.keys()), **kwargs)[0]
    chosen = associations[variant]

    table_text = ai21_preprocess_table(chosen["data"]())
    query = f"Q: I am in {location()}. It is {date()} at {time()}. {question}"
    prompt = f"{table_text}\n\n{chosen['samples']}{query}"

    completions = ai21_pipeline(
        chosen["model"],
        prompt,
        verbose=verbose,
        **chosen["config"])
    outputs = [completion.replace("\nA: ", "") for completion in completions]

    if verbose:
        print(f"[DEBUG] |AI21 Table QA Skill| variant: {variant}")
        print(f"[DEBUG] |AI21 Table QA Skill| input: \n{prompt}")
        print(f"[DEBUG] |AI21 Table QA Skill| outputs: {outputs}")
    return outputs

def ai21_small_talk_skill(conversation: transformers.Conversation,
                          associations: dict,
                          verbose: Optional[bool] = False,
                          **kwargs) -> List[str]:
    """Generate small-talk replies with AI21 from few-shot sample dialogues.

    Args:
        conversation: The running conversation. Its text rendering, minus the
            first line, is appended to the sample dialogues.
        associations: Dict with "model", "samples" and "config" (keyword
            arguments for ``ai21_pipeline``).
        verbose: Print debug information.
        **kwargs: Ignored.

    Returns:
        Candidate replies with "bot >> " markers and newlines removed.
    """
    if verbose:
        print(f"[DEBUG] |AI21 Small Talk Skill| input conversation: \n{conversation}")

    engine = associations["model"]
    few_shot = associations["samples"]
    settings = associations["config"]

    # The first line of str(conversation) is the "Conversation id: ..." header
    # (see this notebook's debug output). Only the dialogue turns are kept.
    dialogue_lines = str(conversation).split("\n")
    prompt = few_shot + "\n".join(dialogue_lines[1:])

    completions = ai21_pipeline(engine, prompt, verbose=verbose, **settings)

    replies = []
    for completion in completions:
        replies.append(completion.replace("bot >> ", "").replace("\n", ""))

    if verbose:
        print(f"[DEBUG] |AI21 Small Talk Skill| outputs: {replies}")
    return replies


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section Few-Shot Samples and their utilities.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Few-shot examples inserted verbatim into the AI21 prompts, right before the
# new question (table QA skills) or the running dialogue (small talk). Every
# string ends with a blank line that separates it from what follows.
# NOTE(review): the texts contain typos (e.g. "Wich", "servers", "beeing").
# They are kept verbatim because editing few-shot text changes model output.

# Travel Q/A examples (used by AI21_TRAVEL_SKILL).
ai21_travel_samples = """Q: It is 12:30. Wich is the next train from Frankfurt to Leipzig?
A: The next train to Leipzig is ICE 1655 at 17:21 from track 9.

Q: It is 17:22. Wich is the next train from Frankfurt to Leipzig?
A: The next train to Leipzig is ICE 594 at 18:14 from track 9.

Q: When will the ICE 594 from Frankfurt arrive in Leipzig.
A: The ICE 594 will arrive at 21:10 on the track 13.

Q: How long will the ICE 1655 need to get from Frankfurt to Leipzig?
A: The ICE 1655 will need 3 hours and 3 minutes.

Q: At wich track will the FLX 1354 from Berlin arrive?
A: The FLX 1354 from Berlin will arrive at the track 5 at 10:07.

Q: Which train is the fastest option from Berlin to Hamburg?
A: The ICE 806 is the fastest option. It's travel duration is only 1 hour and 43 minutes.

Q: Which is the fastest option from Berlin to Hamburg?
A: The ICE 806 is the fastest option. It's travel duration is only 1 hour and 43 minutes.

Q: Can I take a Flixtrain from Berlin to Hamburg?
A: The Flixtrain FLX 1354 will travel to Hamburg starting at 08:07 from track 8.

"""

# Event Q/A examples (used by AI21_EVENT_SKILL).
ai21_event_samples = """Q: It is the 2021-09-16. When is the next event?
A: The next event will take place at the 16th September.

Q: Wich event will take place the 18th November?
A: The Herbst-Winter-Basar will take place at the 18th November.

Q: When will the Musical Ausgetickt end?
A: The Musical Ausgetickt will end 17:00.

Q: What is the Info evening Well prepared for self-employment about?
A: The info event will be about beeing well prepared for self-employment.

"""

# Lecture timetable Q/A examples (used by AI21_TIMETABLE_SKILL).
ai21_timetable_samples = """Q: Wich lectures are planned for the 30th of August?
A: The lectures Current Affairs and IT Law are planned for the 30th of August.

Q: It is the 2021-08-31. When will the lecture Design and Implementation of Databases end?
A: The lecture Design and Implementation of Databases will end 11:45.

Q: Which lecturer will give the lecture Finance and Investment?
A: Henry Ford will give the lecture Finance and Investment.

Q: In wich room will the lecture Servicemanagement und ERP be?
A: The lecture will be given in the assembly hall. It will start 12:45 and last 195 minutes.

Q: It is the 2021-09-01. What is planned for tomorrow?
A: Practice and Project groups is scheduled for tomorrow.

Q: It is the 2021-09-01. When do I have my next lecture?
A: Your next lecture will be at the 3rd of September.

"""

# Restaurant Q/A examples (used by AI21_RESTAURANT_SKILL).
# NOTE(review): restaurant names may mirror the data table's spelling -- do
# not "correct" them without checking the table.
ai21_restaurant_samples = """Q: How far away is the nearest restaurant.
A: The nearest restaurant is 0.6 km away.

Q: What is the closest Italian restaurant?
A: The nearest restaurant is La Scala. It is just 0.6 km away and servers Italian food.

Q: What kind of food does the restaurant Cuervo serve?
A: The restaurant Cuervo serves Mexican food. It has an average rating of 4 out of 5.

Q: Can you tell me the best rated restaurants you know?
A: The best rated restaurants i know are Reatuarant zagreb, Pizzeria Romana, Ristaurante Tie-Break they share an average rating of 4.5 out of 5.

Q: What is the furthest restaurant?
A: The furthest restaurant I know is the restaurant Wolfsschlucht. It is 14.4 km away and serves German food.

Q: What is the closest restaurant?
A: The nearest restaurant is La Scala. It is just 0.6 km away and servers Italian food.

"""

# Sample dialogues in the "user >> ... / bot >> ..." format that
# ai21_small_talk_skill continues. The "user >>" stop sequence in
# AI21_SMALL_TALK_SKILL ends generation after one bot turn.
ai21_small_talk_samples = """user >> what is your name
bot >> My name is Mia. How can I help you? How are you doing?
user >> what is your job
bot >> I'm your assistant. Feel free to ask me about travel or keep the small talk going.
user >> i m going away now bye
bot >>  Good Bye. It was nice meeting you. See you soon.

user >> how are you today
bot >> I'm fine thanks! How are you? How can I help you?
user >> i am good would you rather have a dragon or unicorn as a pet
bot >> A unicorn! But both would be awesome pets! How about you?
user >> i think i would take dragon they can fly so traveling should be a breeze
bot >> That is a good point. Can I help you with anything else?
user >> no i am fine good bye see you later
bot >>  Good Bye. It was nice meeting you. See you soon.

"""


"""~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Section for the AI21 Skills configuration.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"""

# Persona hook for the AI21 small talk. The warm-up returns the conversation
# unchanged; no persona turns are needed.
AI21_PERSONAS = {
    "warm_up": ai21_warm_up,
    "personas": []
}

# model = "j1-jumbo"
model = "j1-large"

# "config" dicts are passed as **kwargs to ai21_pipeline, so their keys must
# match its keyword parameters (num_return_sequences, max_length, ...).
# Unknown keys end up in its **kwargs and are ignored.
AI21_SMALL_TALK_SKILL = {
    "associations": {
        "model": model,
        "samples": ai21_small_talk_samples,
        "config": {
            # Several candidates, so NLP can skip ones with negative sentiment.
            "num_return_sequences": 10,
            "max_length": 64,
            "temperature": 0.7,
            "top_p": 0.98,
            "stop_sequences": ["user >>"],
        }
    },
    "function": ai21_small_talk_skill
}

# Deterministic single-answer settings shared by all table QA skills.
# Generation stops at the blank line that separates Q/A pairs.
AI21_TABLE_QA_CONFIG = {
    "max_length": 100,
    "temperature": 0.0,
    "top_p": 1.0,
    "stop_sequences": ["\n\n"],
}

# Each table QA skill has a single association. The association key is the
# only candidate label that ai21_table_qa_skill classifies against.
AI21_TRAVEL_SKILL = {
    "associations": {
        "travel": {
            "model": model,
            "data": travel_table,
            "samples": ai21_travel_samples,
            "config": AI21_TABLE_QA_CONFIG
        }
    },
    "function": ai21_table_qa_skill
}

AI21_EVENT_SKILL = {
    "associations": {
        "event": {
            "model": model,
            "data": event_table,
            "samples": ai21_event_samples,
            "config": AI21_TABLE_QA_CONFIG
        }
    },
    "function": ai21_table_qa_skill
}

AI21_TIMETABLE_SKILL = {
    "associations": {
        "timetable": {
            "model": model,
            "data": timetable_table,
            "samples": ai21_timetable_samples,
            "config": AI21_TABLE_QA_CONFIG
        }
    },
    "function": ai21_table_qa_skill
}

AI21_RESTAURANT_SKILL = {
    "associations": {
        "restaurant": {
            "model": model,
            "data": restaurant_table,
            "samples": ai21_restaurant_samples,
            "config": AI21_TABLE_QA_CONFIG
        }
    },
    "function": ai21_table_qa_skill
}

### Set up for the ***DeepL NLP*** Components



In [None]:
# @title | NLP | Set up DeepL API Key
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


from getpass import getpass
import requests

# Ask for the DeepL API key without echoing it. Pressing Enter leaves
# DEEPL_API_KEY == "", which deepl_translation rejects with
# DeeplApiKeyException.
DEEPL_API_KEY = getpass("""
    ██████╗ ███████╗███████╗██████╗ ██╗     
    ██╔══██╗██╔════╝██╔════╝██╔══██╗██║     
    ██║  ██║█████╗  █████╗  ██████╔╝██║     
    ██║  ██║██╔══╝  ██╔══╝  ██╔═══╝ ██║     
    ██████╔╝███████╗███████╗██║     ███████╗
    ╚═════╝ ╚══════╝╚══════╝╚═╝     ╚══════╝

    Note: If you DO NOT wish to use the DeepL component simply press Enter.
    Paste your DeepL API key here: """)


    ██████╗ ███████╗███████╗██████╗ ██╗     
    ██╔══██╗██╔════╝██╔════╝██╔══██╗██║     
    ██║  ██║█████╗  █████╗  ██████╔╝██║     
    ██║  ██║██╔══╝  ██╔══╝  ██╔═══╝ ██║     
    ██████╔╝███████╗███████╗██║     ███████╗
    ╚═════╝ ╚══════╝╚══════╝╚═╝     ╚══════╝

    Note: If you DO NOT wish to use the DeepL component simply press Enter.
    Paste your DeepL API key here: ··········


In [None]:
# @title | NLP | DeepL Translation Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


"""~~~~~~~~~~~~~~~~~
Language Processors
~~~~~~~~~~~~~~~~~"""

class DeeplApiKeyException(Exception):
    """Raised when no DeepL API key has been entered."""


class DeeplApiResponseException(Exception):
    """Raised when the DeepL API answers with a status other than 200."""
    

def deepl_translation(text: str,
                      target_lang: str="DE",
                      verbose: Optional[bool] = False,
                      **kwargs) -> dict:
    """Translate ``text`` with the DeepL API (api-free endpoint).

    Args:
        text: The text to translate.
        target_lang: DeepL target language code, e.g. "DE" or "EN".
        verbose: Currently unused.
        **kwargs: Ignored.

    Returns:
        The first element of the response's "translations" list. The
        wrappers read its "text" entry.

    Raises:
        DeeplApiKeyException: If ``DEEPL_API_KEY`` is empty.
        DeeplApiResponseException: If DeepL answers with a status other
            than 200.
    """
    if DEEPL_API_KEY == "":
        raise DeeplApiKeyException(
            """[Error] No valid  DeepL API key was entered!
            Please rerun the "| NLP | Set up DeepL API Key" Cell 
            and enter your valid API Key.""")
    
    url = "https://api-free.deepl.com/v2/translate"
    # The key goes into the Authorization header instead of the body.
    headers = {"Authorization": f"DeepL-Auth-Key {DEEPL_API_KEY}"}
    # A dict lets requests form-encode the fields. The old hand-built
    # "auth_key=...&text=..." string broke on "&", "+", "=", "%" or "#"
    # in the text.
    form = {"text": text, "target_lang": target_lang}

    response = requests.post(url, headers=headers, data=form)
    
    if response.status_code != 200:
        raise DeeplApiResponseException(
            f"""[Error] The DeepL API request has returned a status code other than 200!
            The request returned the following status code: {response.status_code}.
            with the following request body:\n{response.text}""")

    # Named "payload" so the json module is not shadowed.
    payload = response.json()
    return payload["translations"][0]

def _deepl_translate_with_logging(text: str,
                                  target_lang: str,
                                  label: str,
                                  separator: str,
                                  verbose: Optional[bool]) -> str:
    """Shared body of the DeepL wrappers: log input, translate, log, return text."""
    if verbose:
        print(f"[DEBUG] |{label}| input: {text}")

    result = deepl_translation(text, target_lang=target_lang)

    if verbose:
        print(f"[DEBUG] |{label}| translation: {separator}{result}")
    return result["text"]


def deepl_german_to_english_translation(input: str,
                                        verbose: Optional[bool] = False,
                                        **kwargs) -> str:
    """Translate German ``input`` into English via DeepL; extra kwargs are ignored."""
    return _deepl_translate_with_logging(
        input, "EN", "DeepL German-To-English Translation", "", verbose)


def deepl_english_to_german_translation(input: str,
                                        verbose: Optional[bool] = False,
                                        **kwargs) -> str:
    """Translate English ``input`` into German via DeepL; extra kwargs are ignored."""
    return _deepl_translate_with_logging(
        input, "DE", "DeepL English-To-German Translation", "\n", verbose)

### Set up for the ***NLP*** Module

In [None]:
# @title | NLP | Set up Module
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

#@markdown ---
#@markdown ### Selection between the Skill Versions 👀🔀
SMALL_TALK_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
TRAVEL_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
EVENT_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
TIMETABLE_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
RESTAURANT_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
TRANSLATION_COMPONENT = "deepl" #@param ["legacy", "deepl"]

#@markdown ---
VERBOSE = True # @param {type:"boolean"}

"""~~~~~~~~~~~~~~~
NLP Configuration
~~~~~~~~~~~~~~~"""

# Per-language translation hooks. "to_native" turns user input into English;
# "to_source" turns the English answer back into the user's language. Each
# conditional expression evaluates only the selected branch, so the
# unselected translation component does not have to be defined.
LANGUAGES = {
    "en": {
        "to_native": lambda x: x,
        "to_source": lambda x: x,
    },
    "de": {
        "to_native": legacy_german_to_english_translation if TRANSLATION_COMPONENT == "legacy" else deepl_german_to_english_translation,
        "to_source": legacy_english_to_german_translation if TRANSLATION_COMPONENT == "legacy" else deepl_english_to_german_translation,
    }
}

# Sentiment candidate labels. NLP.is_sentiment classifies against all of them
# and checks which group the winning label belongs to. "travel" and
# "small talk" are presumably listed as positive so on-topic inputs are not
# pushed into a negative label.
SENTIMENT = {
    "positive": ["non-toxic", "travel", "small talk"],
    "negative": ["toxic", "vulgar", "sex", "sexual", "criminal"],
}

# Persona warm-up matching the selected small-talk implementation.
PERSONAS = LEGACY_PERSONAS if SMALL_TALK_SKILL_VERSION == "legacy" else AI21_PERSONAS

# Skill routing table. Input is classified against the union of all "labels".
# The first skill owning the winning label runs its "pipeline": the
# "function" is called with the pipeline's "associations". Each pipeline
# (legacy or AI21) is chosen by the matching *_SKILL_VERSION form field.
SKILLS = {
    "travel": {
        "labels": ["travel", "travel on time", "travel delayed"],
        "pipeline": LEGACY_TRAVEL_SKILL if TRAVEL_SKILL_VERSION == "legacy" else AI21_TRAVEL_SKILL
    },
    "event": {
        "labels": ["event", "events"],
        "pipeline": LEGACY_EVENT_SKILL if EVENT_SKILL_VERSION == "legacy" else AI21_EVENT_SKILL
    },
    "timetable": {
        "labels": ["lecture", "professor", "university"],
        "pipeline": LEGACY_TIMETABLE_SKILL if TIMETABLE_SKILL_VERSION == "legacy" else AI21_TIMETABLE_SKILL
    },
    "restaurant": {
        "labels": ["restaurant", "food", "serve food", "eat"],
        "pipeline": LEGACY_RESTAURANT_SKILL if RESTAURANT_SKILL_VERSION == "legacy" else AI21_RESTAURANT_SKILL

    },
    "small talk": {
        "labels": ["small talk", "other"],
        "pipeline": LEGACY_SMALL_TALK_SKILL if SMALL_TALK_SKILL_VERSION == "legacy" else AI21_SMALL_TALK_SKILL
    }
}

# Complete configuration consumed by the NLP class.
CONFIG = {
    "languages": LANGUAGES,
    "sentiment": SENTIMENT,
    "personas": PERSONAS,
    "skills": SKILLS,
}


class NegativeInputCapturedException(Exception):
    """Raised when the user's input is classified as having negative sentiment."""


class NegativeOutputsCapturedException(Exception):
    """Raised when no candidate answer of a skill is classified as positive."""


class NLP:
    """Conversation front end that routes each user input to a skill.

    A call does the following:
    - translates the input into English (the "native" language);
    - rejects it if its sentiment is negative;
    - picks a skill via ``skill_classification`` and lets it generate
      candidate answers;
    - keeps the first candidate whose sentiment is positive;
    - translates that answer back into the user's language.
    """

    def __init__(self,
                 config: dict = CONFIG,
                 verbose: bool = True,
                 **kwargs):
        """Set up translation hooks, classifier labels and the conversation.

        Args:
            config: Dict with "languages", "sentiment", "personas" and
                "skills" entries (see CONFIG).
            verbose: Forwarded to the personas' warm-up function.
            **kwargs: Forwarded to the personas' warm-up function.
        """
        self.config = config

        # Translation hooks for the globally selected LANGUAGE.
        language = config["languages"][LANGUAGE]
        self.to_native = language["to_native"]
        self.to_source = language["to_source"]

        # All sentiment labels (positive and negative) are classifier candidates.
        self.sentiment_labels = []
        for _, labels in self.config["sentiment"].items():
            self.sentiment_labels.extend(labels)

        personas = config["personas"]
        self.conversation = personas["warm_up"](
            transformers.Conversation(),
            personas=personas["personas"],
            verbose=verbose,
            **kwargs)

        self.skills = config["skills"]

        # Union of every skill's labels: the skill-classification candidates.
        self.skill_labels = []
        for _, skill in self.skills.items():
            self.skill_labels.extend(skill["labels"])
        
    def __call__(self,
                 input: str,
                 verbose: bool = False,
                 **kwargs) -> str:
        """Answer ``input`` and record the exchange in the conversation.

        Args:
            input: User text in the source language.
            verbose: Print debug information.
            **kwargs: Forwarded to ``skill_classification`` and the skill.

        Returns:
            The answer, translated back into the source language.

        Raises:
            NegativeInputCapturedException: If the input's sentiment is
                negative (the input is not added to the conversation).
            NegativeOutputsCapturedException: If no candidate answer is
                positive (an empty response is recorded first).
        """
        if verbose:
            print(f"[DEBUG] |NLP __call__ <START>|" + "~"*20)
            print(f"[DEBUG] |NLP ATTR skills|: {self.skills}")
            print(f"[DEBUG] |NLP ATTR conversation|: \n{self.conversation}")
            print(f"[DEBUG] |NLP User input|: {input}")
        # Convert input from the source to native language.
        input = self.to_native(input)

        # Check if sentiment of the input is negative.
        if self.is_sentiment("negative", input):
            raise NegativeInputCapturedException(
                "[Warning] A negative input was captured and discarded.")
        # overwrite=True: if an earlier call raised after adding its input
        # (e.g. a failed API request), that input is still pending. Without
        # overwriting, the new input would not be stored and the skills
        # (which read `new_user_input`) would answer the stale question.
        # NOTE(review): relies on transformers' Conversation.add_user_input
        # `overwrite` flag -- confirm for the installed version.
        self.conversation.add_user_input(input, overwrite=True)

        # Match a skill to the given input.
        label = skill_classification(
            input,
            self.skill_labels,
            verbose=verbose,
            **kwargs)
        
        # Collect components to do further processing.
        skill = self.skill_from_label(label)
        pipeline = skill["pipeline"]
        function = pipeline["function"]
        associations = pipeline["associations"]
        
        # Generate the skills outputs.
        outputs = function(
            self.conversation,
            associations=associations,
            verbose=verbose,
            **kwargs)
        
        # Find the first output with a positive sentiment.
        output = ""
        for sample in outputs:
            if self.is_sentiment("positive", sample):
                output = sample
                break

        self.conversation.mark_processed()
        self.conversation.append_response(output)

        # Issue a warning if no outputs where positive.
        if output == "":
            raise NegativeOutputsCapturedException(
                "[Warning] All outputs where negative and discarded.")
        
        # Convert the output from the native back to the source language.
        output = self.to_source(output)

        if verbose:
            print(f"[DEBUG] |NLP conversation |: \n{self.conversation}")
            print(f"[DEBUG] |NLP outputs|: \n{outputs}")
            print(f"[DEBUG] |NLP output|: {output}")
            print(f"[DEBUG] |NLP __call__ <END>|" + "~"*20)
        return output

    def is_sentiment(self, name: str, input: str) -> bool:
        """Return whether the classified sentiment label of ``input`` is in group ``name``."""
        label = sentiment_classification(
            input, self.sentiment_labels)
        labels = self.config["sentiment"][name]
        
        return label in labels
        
    def skill_from_label(self, label: str) -> dict:
        """Return the first skill whose labels contain ``label``.

        Raises:
            Exception: If no skill owns the label.
        """
        for _, skill in self.skills.items():
            if label in skill["labels"]:
                return skill

        raise Exception("The classified skill_label is not mapped to a skill.")


nlp = execute(NLP, verbose=VERBOSE, config=CONFIG)

[DEBUG] |AI21 Warm Up| personas: Conversation id: 16569cc9-8b1c-4a29-b96d-419fe0dc889c 



---


## ***Speech Recognition (STT)*** 🎤💬


---


### Set up for ***Legacy Speech Recognition***

In [None]:
# @title | STT | Installation of Legacy Dependencies ⇩

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

def install_legacy_sst_dependencies(**kwargs):
    """Install the dependencies of the legacy Wav2Vec2 speech recognition.

    ``**kwargs`` absorbs the ``verbose`` keyword that ``execute`` always
    passes. ``execute`` captures the pip output unless VERBOSE is set.
    """
    !pip install transformers
    # Pinned versions, installed after transformers.
    !pip install numpy==1.20
    !pip install numba==0.48
    !pip install ffmpeg-python
    # Presumably provides dl_colab_notebooks.audio (audio_bytes_to_np, used
    # by record()) -- confirm.
    !pip install -q https://github.com/tugstugi/dl-colab-notebooks/archive/colab_utils.zip

execute(install_legacy_sst_dependencies, verbose=VERBOSE)

In [None]:
# @title | STT | Legacy Wav2Vec2 Speech Recognition
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

from typing import List

import torch
import numpy as np


def speech_to_text_implementation(**kwargs):
    """Load a Wav2Vec2 CTC model and return a speech-to-text function.

    The checkpoint depends on the global LANGUAGE: an English model for "en",
    otherwise a German XLSR model.

    Args:
        **kwargs: Absorbs the ``verbose`` keyword passed by ``execute``.

    Returns:
        ``speech_to_text(audio) -> str`` that transcribes an audio array.
    """
    from transformers import Wav2Vec2Tokenizer
    from transformers import Wav2Vec2ForCTC

    STT_MODEL = "facebook/wav2vec2-large-960h-lv60-self" if LANGUAGE == "en" else "facebook/wav2vec2-large-xlsr-53-german"

    # load model and tokenizer
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(STT_MODEL)
    wav2vec2 = Wav2Vec2ForCTC.from_pretrained(STT_MODEL)

    def speech_to_text(audio: np.ndarray,
                       **kwargs) -> str:
        """Transcribe ``audio`` and return the decoded text.

        NOTE(review): assumes 16 kHz mono samples (record() defaults to
        sample_rate=16000) -- confirm for other audio sources.
        """
        input_values = tokenizer(
            [audio],
            return_tensors="pt",
            padding="longest"
        ).input_values

        # Inference only: skip building the autograd graph, which would
        # otherwise hold memory proportional to the audio length.
        with torch.no_grad():
            logits = wav2vec2(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

        texts = tokenizer.batch_decode(predicted_ids)
        return " ".join(texts)

    return speech_to_text

legacy_stt = execute(speech_to_text_implementation, verbose=VERBOSE)

### Set up for ***Google Cloud Speech Recognition***

In [None]:
# @title | STT | Installation of Google Cloud Dependencies ⇩

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

def install_gcloud_sst_dependencies(**kwargs):
    """Install the dependencies of the Google Cloud speech recognition.

    ``**kwargs`` absorbs the ``verbose`` keyword that ``execute`` always
    passes. ``execute`` captures the pip output unless VERBOSE is set.
    """
    # soundfile writes the FLAC file uploaded by gcloud_stt.
    !pip install soundfile
    !pip install --upgrade google-auth
    !pip install --upgrade google-cloud-speech
    !pip install numpy==1.20
    !pip install numba==0.48
    !pip install ffmpeg-python
    !pip install -q https://github.com/tugstugi/dl-colab-notebooks/archive/colab_utils.zip

execute(install_gcloud_sst_dependencies, verbose=VERBOSE)

In [None]:
# @title | STT | Mount Google Drive to access Google Cloud credentials
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


from google.oauth2 import service_account
from google.colab import drive

# Mount Google Drive and load the Speech-to-Text service-account key from it.
# NOTE(review): the key path is hard-coded to the author's Drive layout.
drive.mount('/content/gdrive')
GCLOUD_STT_CREDENTIALS = service_account.Credentials.from_service_account_file(
    '/content/gdrive/MyDrive/projects/TREX/STT/key.json')

Mounted at /content/gdrive


In [None]:
# @title | STT | Google Cloud Speech Recognition Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


import numpy as np
import soundfile as sf
from google.cloud import speech


def gcloud_stt(data: np.ndarray,
               rate: int=16000,
               language_code: str="en-US",
               speech_file: str="speech_file.flac",
               encoding=speech.RecognitionConfig.AudioEncoding.FLAC,
               credentials=GCLOUD_STT_CREDENTIALS,
               verbose: bool=False) -> str:
    """Transcribe audio data via the Google Cloud Speech-To-Text Service.
    
    Args:
        data (np.ndarray): The audio data.
    
    Kwargs:
        rate (int): The sample rate of the audio.
        language_code (str): The language of the speech.
        speech_file (str): A file in which the audio is stored before upload.
        encoding (enum): The encoding of the audio file.
        credentials: Credentials for the SpeechClient.
        verbose (bool): Print debug information.

    Returns:
        (str) The most likely transcript of every recognized portion of the
        audio, joined with single spaces; None if nothing was recognized.

    Note:
        Transcription is limited to a 60 seconds audio file.
        Use a GCS file for audio longer than 1 minute.
    """
    sf.write(speech_file, data, rate)

    client = speech.SpeechClient(credentials=credentials)

    # Built-in open (io.open is an alias of it) so this cell does not depend
    # on `io` being imported by a later cell.
    with open(speech_file, "rb") as audio_file:
        content = audio_file.read()

    audio = speech.RecognitionAudio(content=content)
    config = speech.RecognitionConfig(
        encoding=encoding,
        sample_rate_hertz=rate,
        language_code=language_code)

    operation = client.long_running_recognize(
        config=config,
        audio=audio)

    if verbose:
        print("[DEBUG] | Google Cloud STT | Waiting for operation to complete...")
    response = operation.result(timeout=90)

    # Each result is for a consecutive portion of the audio. Collect all of
    # them to get the transcript for the entire audio file.
    transcripts = []
    for result in response.results:
        # The first alternative is the most likely one for this portion.
        transcript = result.alternatives[0].transcript
        confidence = result.alternatives[0].confidence

        if verbose:
            print(u"[DEBUG] | Google Cloud STT | Transcript: {}".format(transcript))
            print("[DEBUG] | Google Cloud STT | Confidence: {}".format(confidence))
        transcripts.append(transcript.strip())

    if not transcripts:
        return None
    return " ".join(part for part in transcripts if part)

### Set up audio recording utilities

In [None]:
# @title | STT | Audio Recording Utils
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄
"""Utils for recording audio in a Google Colaboratory notebook.

This code is adapted from:
    https://ricardodeazambuja.com/deep_learning/2019/03/09/audio_and_video_google_colab/
    https://colab.research.google.com/gist/ricardodeazambuja/03ac98c31e87caf284f7b06286ebf7fd/microphone-to-numpy-array-from-your-browser-in-colab.ipynb
"""

# Bash redirection suffix that discards a command's stdout and stderr.
# Not referenced in this cell -- presumably appended to `!` commands elsewhere.
SILENT = "&> /dev/null"

import io
import ffmpeg
import numpy as np
from base64 import b64decode
from scipy.io.wavfile import write
from dl_colab_notebooks.audio import audio_bytes_to_np

from IPython.display import display
from IPython.display import HTML
from google.colab.output import eval_js


# CSS for the record button (rounded, large font, with a 🙏 that slides in on
# hover). A small script injects it into the page's <head>.
STYLES_HTML = """
<script>

var styles = `

button {
    width: 300px;
    height: 54px;

    padding: 20px;
    margin: 5px;

    display: flex;
    justify-content: center;
    align-items: center;
    border-radius: 40px;
    border: none;

    text-align: center;
    font-size: 28px;
    
    transition: all 0.5s;
    cursor: pointer;
}

button span {
    display: inline-block;
    position: relative;

    cursor: pointer;
    transition: 0.5s;
}

button span:after {
    content: '🙏';

    position: absolute;
    right: -20px;

    opacity: 0;
    transition: 0.5s;
}

button:hover span {
    padding-right: 25px;
}

button:hover span:after {
    right: 0;
    opacity: 1;
}
`

var styleSheet = document.createElement("style")
styleSheet.type = "text/css"
styleSheet.innerText = styles
document.head.appendChild(styleSheet);

</script>
"""

# Browser-side recorder. Recording starts as soon as microphone access is
# granted; clicking the button stops it. The global promise `data` resolves to
# the recording as a data URL ("data:<mime>;base64,<payload>") once the
# browser has finished reading it. It resolves to "" if microphone access
# fails, so the Python side fails fast instead of waiting forever.
AUDIO_HTML = """
<script>

var container = document.createElement("div");
var button = document.createElement("button");
var span = document.createElement("span");

button.appendChild(span);
container.appendChild(button);
document.body.appendChild(container);

var base64data = 0;
var reader, recorder, gumStream;

// Resolved from reader.onloadend when the recording is fully read, instead
// of after a fixed delay that could fire before the data was available.
var resolveData;
var data = new Promise(resolve => {
    resolveData = resolve;
});

var handleSuccess = function(stream) {
    gumStream = stream;
    recorder = new MediaRecorder(stream);
    recorder.ondataavailable = function(e) {
        var url = URL.createObjectURL(e.data);
        var preview = document.createElement('audio');

        preview.controls = true;
        preview.src = url;
        container.appendChild(preview);

        reader = new FileReader();
        reader.onloadend = function() {
            base64data = reader.result;
            resolveData(base64data.toString());
        };
        reader.readAsDataURL(e.data);
    };
    recorder.start();
};

span.innerText = "⏸︎";
button.style.verticalAlign = "middle";
navigator.mediaDevices.getUserMedia({audio: true})
    .then(handleSuccess)
    .catch(() => resolveData(""));

function toggleRecording() {
    if (recorder && recorder.state == "recording") {
        recorder.stop();
        gumStream.getAudioTracks()[0].stop();
        span.innerText = "✅"
    }
}

button.onclick = () => {
    toggleRecording()
};

</script>
"""

def record(sample_rate: int = 16000) -> np.ndarray:
    """Record microphone audio in the browser and return it as samples.

    Displays the record button (STYLES_HTML + AUDIO_HTML), blocks until the
    JavaScript promise ``data`` resolves, then decodes the base64 data URL.

    Args:
        sample_rate: Sample rate passed to ``audio_bytes_to_np``.

    Returns:
        The decoded audio from ``audio_bytes_to_np``. It is presumably a
        NumPy array, given the helper's name and the consumers here.

    Raises:
        RuntimeError: If the browser returned no recording.
    """
    display(HTML(STYLES_HTML + AUDIO_HTML))
    data = eval_js("data")

    # A data URL looks like "data:<mime>;base64,<payload>". A value without a
    # comma (e.g. the JS placeholder "0" or "") means nothing was captured.
    _header, separator, payload = str(data).partition(",")
    if not separator:
        raise RuntimeError(
            "No audio was recorded (microphone unavailable or recording not finished).")

    audio_bytes = b64decode(payload)
    return audio_bytes_to_np(audio_bytes, sample_rate)

---


## ***Text-To-Speech (TTS)*** 💭📣


---

### Set up for the ***Legacy Text-To-Speech*** Module

In [None]:
# @title | TTS | Installation of Legacy Dependencies ⇩

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

def install_legacy_tts_dependencies(**kwargs):
    """Install the legacy text-to-speech stack for the selected LANGUAGE.

    "de": downloads a TTS model, a vocoder, their configs and scale stats into
    the working directory, then installs the coqui-ai TTS repo at commit
    540d811. Otherwise: installs the 1ucky40nc3/TransformerTTS fork
    ("package" branch) plus the HiFiGAN vocoder weights.

    Args:
        **kwargs: Absorbs the ``verbose`` keyword that ``execute`` passes.
    """
    !apt-get install -y espeak

    if LANGUAGE == "de":
        # Model, configs and normalization stats, downloaded by Google Drive id.
        !gdown --id 1VG0EI7J6S1bk3h0q1VBc9ALExkdZdeVm -O tts_model.pth.tar
        !gdown --id 1s1GcSihlj58KX0LeA-FPFvdMWGMkcxKI -O config.json
        !gdown --id 1zYFHElvYW_oTeilvbZVLMLscColWRbck -O vocoder_model.pth.tar
        !gdown --id 1ye9kVDbatAKMncRMui7watrLQ_5DaJ3e -O config_vocoder.json
        !gdown --id 1QD40bU_M7CWrj9k0MEACNBRqwqVTSLDc -O scale_stats.npy
        # NOTE(review): espeak is already installed by the first line above.
        !sudo apt-get install espeak
        !git clone https://github.com/coqui-ai/TTS

        %cd TTS
        !git checkout 540d811
        !pip install -r requirements.txt
        !pip install numpy==1.21.2
        !pip install torchaudio
        !python setup.py install

        # sometimes installation does not work, so also put the checkout
        # (the current directory) on sys.path
        import os, sys
        sys.path.append(os.getcwd())
        %cd ..
    else:
        !git clone https://github.com/1ucky40nc3/TransformerTTS.git
        %cd TransformerTTS
        !git checkout package
        !pip install torchaudio
        !pip install -r /content/TransformerTTS/requirements.txt
        !pip install -r /content/TransformerTTS/TransformerTTS/vocoding/extra_requirements.txt
        !python setup.py develop

        # HiFiGAN vocoder weights, copied into the package's vocoding folder.
        !wget https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/hifigan.zip
        !unzip -q hifigan.zip
        !rsync -avq hifigan/ /content/TransformerTTS/TransformerTTS/vocoding/hifigan/
        # NOTE(review): unlike the "de" branch, this never changes back to the
        # parent directory, so the working directory stays inside the
        # TransformerTTS checkout after this cell.

execute(install_legacy_tts_dependencies, verbose=VERBOSE)

In [None]:
# @title | TTS | TTS Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

# @markdown ---

# @markdown #### 🆘 ***If "de" is selected as language, an error may occur.*** 🆘
# @markdown #### ⏩ Just try to run this cell again. 👻

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

import torch
import numpy as np
from torchaudio import functional as F

def text_to_speech_implementation(**kwargs):
    """Load the legacy TTS model for LANGUAGE and return a synthesis function.

    "de": a coqui-ai TTS model plus vocoder, loaded from the files the
    installation cell downloaded to /content. Otherwise: the TransformerTTS
    LJSpeech model with a HiFiGAN vocoder.

    Args:
        **kwargs: Absorbs the ``verbose`` keyword that ``execute`` passes.

    Returns:
        ``text_to_speech(text)``, annotated as returning ``np.ndarray``.
    """
    if LANGUAGE == "de":
        import os
        from TTS.utils.io import load_config
        from TTS.utils.audio import AudioProcessor
        from TTS.tts.utils.io import load_checkpoint
        from TTS.tts.utils.synthesis import synthesis
        from TTS.tts.utils.text.symbols import symbols
        from TTS.tts.utils.generic_utils import setup_model
        from TTS.vocoder.utils.generic_utils import setup_generator
        from TTS.vocoder.utils.io import load_checkpoint as load_vocoder_checkpoint

        TTS_MODEL = "/content/tts_model.pth.tar"
        TTS_CONFIG = "/content/config.json"
        VOCODER_MODEL = "/content/vocoder_model.pth.tar"
        VOCODER_CONFIG = "/content/config_vocoder.json"

        # The config paths are replaced by the loaded config objects.
        TTS_CONFIG = load_config(TTS_CONFIG)
        TTS_CONFIG.audio["stats_path"] = "/content/scale_stats.npy"

        VOCODER_CONFIG = load_config(VOCODER_CONFIG)

        audio_processor = AudioProcessor(**TTS_CONFIG.audio)

        model, _ = load_checkpoint(
            setup_model(
                num_chars=len(symbols), 
                num_speakers=0,
                c=TTS_CONFIG),
            checkpoint_path=TTS_MODEL)

        vocoder, _ = load_vocoder_checkpoint(
            setup_generator(VOCODER_CONFIG), 
            checkpoint_path=VOCODER_MODEL)
        vocoder.remove_weight_norm()
        vocoder.inference_padding = 0

        model.eval()
        vocoder.eval()

        def text_to_speech(text: str, 
                           **kwargs) -> np.ndarray:
            """Synthesize ``text`` and return the vocoder output as a NumPy array."""
            # Only the post-net mel spectrogram (4th of 6 return values) is
            # used. NOTE(review): the positional False is presumably
            # use_cuda -- confirm against the TTS 540d811 signature.
            _, _, _, mel_postnet_spec, _, _ = synthesis(
                model, 
                text, 
                TTS_CONFIG,
                False, 
                audio_processor)
            
            # Transpose the spectrogram and add a leading batch dimension
            # for the vocoder.
            speech = vocoder.inference(
                torch.FloatTensor(
                    mel_postnet_spec.T,
                ).unsqueeze(0))
            speech = speech.flatten().cpu().numpy()

            return speech
        
        return text_to_speech
    
    # The TransformerTTS imports and model loading run from inside the
    # checkout; `%cd ..` below restores the parent directory.
    # NOTE(review): if an import or model load raises, the working directory
    # stays changed.
    %cd /content/TransformerTTS

    from TransformerTTS.model.factory import tts_ljspeech
    from TransformerTTS.vocoding.predictors import HiFiGANPredictor


    folder = "/content/TransformerTTS/TransformerTTS/vocoding/hifigan/en"


    model, _ = tts_ljspeech()
    vocoder = HiFiGANPredictor.from_folder(folder)

    def text_to_speech(text: str, 
                       **kwargs) -> np.ndarray:
        """Predict a mel spectrogram for ``text`` and vocode it with HiFiGAN."""
        speech = model.predict(text)
        speech = speech["mel"].numpy().T
        # The vocoder takes a list of spectrograms; only the first result is used.
        speech = vocoder([speech])[0]

        return speech

    %cd ..
    return text_to_speech

legacy_tts = execute(text_to_speech_implementation, verbose=VERBOSE)

### Set up for the ***Google Cloud Text-To-Speech*** Module

In [None]:
# @title | TTS | Installation of Google Cloud Dependencies ⇩

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

def install_gcloud_tts_dependencies(**kwargs):
    """Install the packages used by the Google Cloud Text-To-Speech module.

    ``**kwargs`` is accepted and ignored because ``execute`` always forwards
    a ``verbose`` keyword to the wrapped function.
    """
    !pip install soundfile
    !pip install --upgrade google-auth
    !pip install --upgrade google-cloud-texttospeech

# pip output is captured (hidden) unless VERBOSE is enabled.
execute(install_gcloud_tts_dependencies, verbose=VERBOSE)

In [None]:
# @title | TTS | Mount Google Drive to access Google Cloud credentials
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


from google.oauth2 import service_account
from google.colab import drive

# Location of the Google Cloud service-account key inside the mounted Drive.
GCLOUD_TTS_KEY_PATH = "/content/gdrive/MyDrive/projects/TREX/TTS/key.json" # @param {type:"string"}

# Mount Google Drive so the service-account key file can be read.
drive.mount('/content/gdrive')

# Credentials handed to the Google Cloud TTS client (default of gcloud_tts).
GCLOUD_TTS_CREDENTIALS = service_account.Credentials.from_service_account_file(
    GCLOUD_TTS_KEY_PATH)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# @title | TTS | Google Cloud Text-To-Speech Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

import google.cloud.texttospeech as texttospeech
import scipy


def gcloud_tts(text: str,
               voice_name: str="en-US-Wavenet-D",
               credentials=GCLOUD_TTS_CREDENTIALS,
               verbose: bool=False,
               sample_rate_hertz: int=22050) -> np.ndarray:
    """Synthesize ``text`` with Google Cloud Text-To-Speech.

    The LINEAR16 response is written to ``<language_code>.wav`` in the current
    working directory and read back as a WAV file.

    Args:
        text: Text to synthesize.
        voice_name: Google Cloud voice name, e.g. "en-US-Wavenet-D". Its first
            two "-"-separated parts are used as the language code ("en-US").
        credentials: Service-account credentials for the TTS client.
        verbose: If True, print where the generated WAV file was saved.
        sample_rate_hertz: Sample rate requested from the service. The default
            of 22050 Hz matches the rate ``postprocessing`` resamples from.
            Pass 0/None to keep the voice's natural rate.

    Returns:
        The audio samples as returned by ``scipy.io.wavfile.read``
        (presumably int16 for LINEAR16 — confirm if the dtype matters).
    """
    # `import scipy` alone does not guarantee the `scipy.io.wavfile` submodule
    # is loaded, so import it explicitly.
    from scipy.io import wavfile

    text_input = texttospeech.SynthesisInput(text=text)

    # "en-US-Wavenet-D" -> "en-US"
    language_code = "-".join(voice_name.split("-")[:2])
    voice_params = texttospeech.VoiceSelectionParams(
        language_code=language_code,
        name=voice_name)

    audio_config_kwargs = {"audio_encoding": texttospeech.AudioEncoding.LINEAR16}
    if sample_rate_hertz:
        # Without an explicit rate the service uses the voice's natural rate,
        # which does not match what postprocessing() assumes.
        audio_config_kwargs["sample_rate_hertz"] = sample_rate_hertz
    audio_config = texttospeech.AudioConfig(**audio_config_kwargs)

    client = texttospeech.TextToSpeechClient(
        credentials=credentials)

    response = client.synthesize_speech(
        input=text_input,
        voice=voice_params,
        audio_config=audio_config)

    filename = f"{language_code}.wav"
    with open(filename, "wb") as out:
        out.write(response.audio_content)

    if verbose:
        print(f'[DEBUG] | Google Cloud TTS | Generated speech saved to "{filename}"')

    _, data = wavfile.read(filename)
    return data

### Set up for audio processing utilities

In [None]:
# @title | TTS | Audio Processing Utils
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

import torch
import numpy as np
from torchaudio import functional as F

def postprocessing(wav: np.ndarray,
                   orig_freq: int = 22050,
                   new_freq: int = 16000) -> np.ndarray:
    """Convert TTS samples into the audio used for the avatar video.

    The samples go through an in-memory WAV (float PCM) round trip with
    ``torchaudio.functional.apply_codec`` and are then resampled.

    Args:
        wav: 1-D array of audio samples at ``orig_freq`` Hz.
        orig_freq: Sample rate of ``wav``. Defaults to 22050 Hz.
        new_freq: Output sample rate. Defaults to 16000 Hz, the ``audio_fps``
            PC-AVS uses when writing the final video.

    Returns:
        1-D numpy array of samples at ``new_freq`` Hz.
    """
    # (samples,) -> (1, samples): torchaudio expects a channels-first layout.
    waveform = torch.from_numpy(wav).unsqueeze(0)

    waveform = F.apply_codec(
        waveform=waveform,
        sample_rate=orig_freq,
        format="wav",
        encoding="PCM_F")
    waveform = F.resample(
        waveform=waveform,
        orig_freq=orig_freq,
        new_freq=new_freq)

    # Take the single channel explicitly; squeeze() would also collapse a
    # length-1 time axis into a 0-d array.
    return waveform[0].numpy()

---


## ***Avatar (PC-AVS)*** 🤗🤖


---

In [None]:
# @title | PC-AVS | Install Dependencies ⇩

#@markdown ---
VERBOSE = False # @param {type:"boolean"}

def install_avatar_dependencies(**kwargs):
    !git clone https://github.com/1ucky40nc3/Talking-Face_PC-AVS.git
    %cd /content/Talking-Face_PC-AVS

    !pip install -r requirements.txt
    !pip install lws
    !pip install face-alignment
    !pip install av
    !pip install torchaudio

    !unzip ./misc/Audio_Source.zip -d ./misc/
    !unzip ./misc/Input.zip -d ./misc/
    !unzip ./misc/Mouth_Source.zip -d ./misc/ 
    !unzip ./misc/Pose_Source.zip -d ./misc/

    !gdown https://drive.google.com/u/0/uc?id=1Zehr3JLIpzdg2S5zZrhIbpYPKF-4gKU_&export=download
    !mkdir checkpoints
    !unzip demo.zip -d ./checkpoints/

execute(install_avatar_dependencies, verbose=VERBOSE)

In [None]:
# @title | PC-AVS | PC-AVS Implementation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄
# Work inside the PC-AVS checkout; its `data` and `models` packages are imported below.
%cd /content/Talking-Face_PC-AVS

import os
import sys
import torch
import torchvision
from tqdm import tqdm

# NOTE(review): after the %cd above this adds /content to the import path —
# confirm it is still needed.
sys.path.append('..')

from data import create_dataloader
from models import create_model


# Seed torch's global RNG so random ops during inference repeat across runs.
torch.manual_seed(0)


class Namespace:
    """Plain attribute container: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for attribute_name, attribute_value in kwargs.items():
            setattr(self, attribute_name, attribute_value)


def pc_avs_inference(opt, 
                     path_label, 
                     model, 
                     wav) -> str:
    """Render a talking-face video driven by ``wav`` and mux the audio into it.

    Args:
        opt: PC-AVS options namespace; ``opt.path_label`` is overwritten.
        path_label: PC-AVS input description handed to the dataloader via ``opt``.
        model: PC-AVS model, already in eval mode.
        wav: 1-D numpy array of audio samples. It is given to the dataloader
            and written as the video's audio track at 16 kHz.

    Returns:
        Path of the written ``G_Pose_Driven_.mp4``.

    Raises:
        RuntimeError: If the dataloader yields no frames.
    """
    opt.path_label = path_label
    dataloader = create_dataloader(opt, wav=wav)

    frames = []
    # Inference only: do not build an autograd graph for every batch.
    with torch.no_grad():
        for data_i in tqdm(dataloader):
            _, fake_image_driven_pose_a = model.forward(
                data_i, mode='inference')
            # Move each batch off the GPU right away to bound GPU memory.
            frames.append(fake_image_driven_pose_a.cpu())

    if not frames:
        raise RuntimeError("PC-AVS produced no frames for the given input.")

    filename = os.path.join(
        dataloader.dataset.get_processed_file_savepath(),
        "G_Pose_Driven_.mp4")

    # Assumes generator output is (N, C, H, W) in [-1, 1] — TODO confirm.
    # write_video expects (N, H, W, C) uint8. A plain permute keeps the frame
    # upright and unmirrored (transpose(1, 3) followed by rot90(-1) mirrors it
    # horizontally).
    video_array = torch.cat(frames, dim=0).permute(0, 2, 3, 1)
    # Map [-1, 1] to [0, 255]; clamp so slight overshoot cannot wrap around
    # when cast to uint8.
    video_array = ((video_array + 1.0) * 127.5).clamp(0, 255).to(torch.uint8)

    # Audio track shaped (channels, samples).
    audio_array = torch.unsqueeze(torch.from_numpy(wav), dim=0)

    torchvision.io.write_video(
        filename=filename,
        video_array=video_array,
        fps=25,
        video_codec="libx264",
        audio_array=audio_array,
        audio_fps=16000,
        audio_codec="aac"
    )

    del dataloader
    return filename


def avatar(opt,
           path_label,
           wav) -> str:
    """Create the PC-AVS model for inference and render a video for ``wav``.

    Sets ``opt.isTrain = False`` on the shared options object and builds a
    new model (moved to the GPU) on every call.

    Returns:
        Path of the rendered video, as returned by ``pc_avs_inference``.
    """
    opt.isTrain = False

    pc_avs_model = create_model(opt)
    pc_avs_model = pc_avs_model.cuda()
    pc_avs_model.eval()

    video_path = pc_avs_inference(opt, path_label, pc_avs_model, wav)
    return video_path
    

# Options namespace handed to PC-AVS create_model (in avatar) and
# create_dataloader (in pc_avs_inference).
opt = Namespace(
    D_input='single',
    VGGFace_pretrain_path='',
    aspect_ratio=1.0,
    audio_nc=256,
    augment_target=False,
    batchSize=16,
    beta1=0.5,
    beta2=0.999,
    checkpoints_dir='./checkpoints',
    clip_len=1,
    crop=False,
    crop_len=16,
    crop_size=224,
    # NOTE(review): absolute path from another machine; presumably unused at
    # inference time — confirm.
    data_path='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv',
    dataset_mode='voxtest',
    defined_driven=False,
    dis_feat_rec=False,
    display_winsize=224,
    driven_type='face',
    driving_pose=True,
    feature_encoded_dim=2560,
    feature_fusion='concat',
    filename_tmpl='{:06}.jpg',
    fitting_iterations=10,
    frame_interval=1,
    frame_rate=25,
    gan_mode='hinge',
    gen_video=True,
    generate_from_audio_only=True,
    generate_interval=1,
    gpu_ids=[0],
    has_mask=False,
    heatmap_size=3,
    hop_size=160,
    how_many=1000000,
    init_type='xavier',
    init_variance=0.02,
    input_id_feature=True,
    input_path='./checkpoints/results/input_path',
    # avatar() also forces this to False before building the model.
    isTrain=False,
    label_mask=False,
    lambda_D=1,
    lambda_contrastive=100,
    lambda_crossmodal=1,
    lambda_feat=10.0,
    lambda_image=1.0,
    lambda_rotate_D=0.1,
    lambda_softmax=1000000,
    lambda_vgg=10.0,
    lambda_vggface=5.0,
    landmark_align=False,
    landmark_type='min',
    list_end=1000000,
    list_num=0,
    list_start=0,
    load_from_opt_file=False,
    load_landmark=False,
    lr=0.001,
    # NOTE(review): same foreign absolute path as data_path — confirm unused.
    lrw_data_path='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv',
    max_dataset_size=9223372036854775807,
    # NOTE(review): points into one specific past conversation directory, not
    # the one created by trex_setup — confirm this is intended.
    meta_path_vox='./conversations/feaa8fc7-8fc7-4ecf-acef-f06ca221b493/15/avatar.csv',
    # NOTE(review): 'cpu' although avatar() moves the model to the GPU —
    # confirm what PC-AVS uses this flag for.
    mode='cpu',
    model='av',
    multi_gpu=False,
    nThreads=4,
    n_mel_T=4,
    name='demo',
    ndf=64,
    nef=16,
    netA='resseaudio',
    netA_sync='ressesync',
    netD='multiscale',
    netE='fan',
    netG='modulate',
    netV='resnext',
    ngf=64,
    no_TTUR=False,
    no_flip=True,
    no_ganFeat_loss=False,
    no_gaussian_landmark=False,
    no_id_loss=False,
    no_instance=False,
    no_pairing_check=False,
    no_spectrogram=False,
    no_vgg_loss=False,
    noise_pose=True,
    norm_A='spectralinstance',
    norm_D='spectralinstance',
    norm_E='spectralinstance',
    norm_G='spectralinstance',
    num_bins_per_frame=4,
    num_classes=5830,
    num_clips=1,
    num_frames_per_clip=5,
    num_inputs=1,
    onnx=False,
    optimizer='adam',
    output_nc=3,
    phase='test',
    pose_dim=12,
    positional_encode=False,
    preprocess_mode='resize_and_crop',
    # NOTE(review): results_dir/save_path are pinned to a past conversation
    # directory as well — confirm.
    results_dir='./conversations/feaa8fc7-8fc7-4ecf-acef-f06ca221b493/15',
    save_path='./conversations/feaa8fc7-8fc7-4ecf-acef-f06ca221b493/15',
    serial_batches=False,
    start_ind=0,
    style_dim=2560,
    style_feature_loss=True,
    target_crop_len=0,
    train_dis_pose=False,
    train_recognition=False,
    train_sync=False,
    train_word=False,
    trainer='audio',
    use_audio=1,
    use_audio_id=0,
    use_transformer=False,
    verbose=False,
    vgg_face=False,
    which_epoch='latest',
    word_loss=False
)

/content/Talking-Face_PC-AVS


---
---


# ***T-REX*** 🦖💬


---
---

In [None]:
# @title | T-REX | Start new Conversation
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

#@markdown ---
#@markdown ### Selection between the Skill Versions 👀🔀
# Each *_VERSION form field picks the legacy or the AI21 implementation of a
# skill; TRANSLATION_COMPONENT picks the German<->English translator.
SMALL_TALK_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
TRAVEL_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
EVENT_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
TIMETABLE_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
RESTAURANT_SKILL_VERSION = "ai21" #@param ["legacy", "ai21"]
TRANSLATION_COMPONENT = "deepl" #@param ["legacy", "deepl"]

"""~~~~~~~~~~~~~~~
NLP Configuration
~~~~~~~~~~~~~~~"""

# Per-language translation hooks. "to_native" turns user input into English
# (identity for "en"); "to_source" turns the English answer back into the
# user's language.
LANGUAGES = {
    "en": {
        "to_native": lambda x: x,
        "to_source": lambda x: x,
    },
    "de": {
        "to_native": legacy_german_to_english_translation if TRANSLATION_COMPONENT == "legacy" else deepl_german_to_english_translation,
        "to_source": legacy_english_to_german_translation if TRANSLATION_COMPONENT == "legacy" else deepl_english_to_german_translation,
    }
}

# Labels considered acceptable ("positive") or to be rejected ("negative");
# presumably used by the NLP input/output filter — TODO confirm.
SENTIMENT = {
    "positive": ["non-toxic", "travel", "small talk"],
    "negative": ["toxic", "vulgar", "sex", "sexual", "criminal", "suicide", "violence"],
}

PERSONAS = LEGACY_PERSONAS if SMALL_TALK_SKILL_VERSION == "legacy" else AI21_PERSONAS

# Each skill: "labels" that presumably route an input to it (classifier
# candidates — TODO confirm) and the "pipeline" chosen by the form fields above.
SKILLS = {
    "travel": {
        "labels": ["travel", "travel on time", "travel delayed"],
        "pipeline": LEGACY_TRAVEL_SKILL if TRAVEL_SKILL_VERSION == "legacy" else AI21_TRAVEL_SKILL
    },
    "event": {
        "labels": ["event", "events"],
        "pipeline": LEGACY_EVENT_SKILL if EVENT_SKILL_VERSION == "legacy" else AI21_EVENT_SKILL
    },
    "timetable": {
        "labels": ["lecture", "professor", "university"],
        "pipeline": LEGACY_TIMETABLE_SKILL if TIMETABLE_SKILL_VERSION == "legacy" else AI21_TIMETABLE_SKILL
    },
    "restaurant": {
        "labels": ["restaurant", "food", "serve food", "eat"],
        "pipeline": LEGACY_RESTAURANT_SKILL if RESTAURANT_SKILL_VERSION == "legacy" else AI21_RESTAURANT_SKILL

    },
    "small talk": {
        "labels": ["small talk", "other"],
        "pipeline": LEGACY_SMALL_TALK_SKILL if SMALL_TALK_SKILL_VERSION == "legacy" else AI21_SMALL_TALK_SKILL
    }
}

# Default NLP configuration, used by the test utilities (NLP(config=CONFIG)).
# trex_setup builds its own config with the custom personas below.
CONFIG = {
    "languages": LANGUAGES,
    "sentiment": SENTIMENT,
    "personas": PERSONAS,
    "skills": SKILLS,
}

#@markdown ---
# Custom personas; trex_setup only uses them when ACTIVATE_LEGACY_PERSONAS is set.
ACTIVATE_LEGACY_PERSONAS = False # @param {type:"boolean"}
PERSONA_1 = "I work in a travel agency" # @param {type:"string"}
PERSONA_1 = f"your persona: {PERSONA_1}"
PERSONA_2 = "My name is Mia" # @param {type:"string"}
PERSONA_2 = f"your persona: {PERSONA_2}"

#@markdown ---
VERBOSE = False # @param {type:"boolean"}


import copy
import uuid
import base64
from IPython.display import HTML


def trex_setup(**kwargs):
    personas = copy.deepcopy(LEGACY_PERSONAS)
    personas["personas"] = [PERSONA_1, PERSONA_2] if ACTIVATE_LEGACY_PERSONAS else []

    config = {
        "languages": LANGUAGES,
        "personas": personas,
        "skills": SKILLS,
        "sentiment": SENTIMENT,
    }

    nlp = NLP(config=config, **kwargs)

    conversation_id = uuid.uuid4()
    conversation_dir = f"./conversations/{conversation_id}"
    !mkdir ./conversations/
    !mkdir {conversation_dir}

    interaction_counter = 0
    f"Current Conversation is logged at: {conversation_dir}"

    !rm -r /content/Talking-Face_PC-AVS/results/id_input_pose_00473_audio_tts_output

    return nlp

nlp = execute(trex_setup, verbose=VERBOSE)

In [None]:
# @title # Interact with T-REX 🦖
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄

#@markdown ---
#@markdown ### Selection between the STT & TTS Module Versions 👀🔀
STT_VERSION = "gcloud" #@param ["legacy", "gcloud"]
TTS_VERSION = "gcloud" #@param ["legacy", "gcloud"]

# Resolve the STT/TTS implementations selected in the form above.
stt = legacy_stt if STT_VERSION == "legacy" else gcloud_stt
tts = legacy_tts if TTS_VERSION == "legacy" else gcloud_tts

# Locale passed to stt(); LANGUAGE comes from an earlier cell.
LANGUAGE_CODE = "de-DE" if LANGUAGE == "de" else "en-US"

#@markdown ### Selection Google Cloud TTS Voice 👀🔀
EN_TTS_VOICE_NAME = "en-US-Wavenet-H" #@param ["en-US-Wavenet-A", "en-US-Wavenet-B", "en-US-Wavenet-C", "en-US-Wavenet-D", "en-US-Wavenet-E", "en-US-Wavenet-F", "en-US-Wavenet-G", "en-US-Wavenet-H", "en-US-Wavenet-I", "en-US-Wavenet-J"]
DE_TTS_VOICE_NAME = "de-DE-Wavenet-C" #@param ["de-DE-Wavenet-A", "de-DE-Wavenet-B", "de-DE-Wavenet-C", "de-DE-Wavenet-D", "de-DE-Wavenet-E", "de-DE-Wavenet-F"]

# Voice passed to tts() in trex(); the legacy TTS accepts it via **kwargs and ignores it.
VOICE_NAME = DE_TTS_VOICE_NAME if LANGUAGE == "de" else EN_TTS_VOICE_NAME

#@markdown ---
#@markdown ### Define the  Input
USE_STT_AS_INPUT = True # @param {type:"boolean"}
TEXT_INPUT = "Wie geht es dir?" # @param {type:"string"}

#@markdown ---
VERBOSE = True # @param {type:"boolean"}

# Input description handed to PC-AVS via avatar(); presumably reference
# images, pose source and frame count, and audio source in PC-AVS's
# path_label format — TODO confirm field meanings.
PATH_LABELS = "./misc/Input/input 1 ./misc/Pose_Source/00473 158 ./misc/Audio_Source/tts_output.mp3 None 0 None"


def trex(input: str="", **kwargs) -> str:
    """Run one T-REX turn: text -> NLP answer -> speech -> avatar video.

    Args:
        input: The user's utterance (from STT or the text form field).
        **kwargs: Forwarded to ``nlp``. A ``verbose`` entry (as passed by
            ``execute``) also controls the debug prints and TTS verbosity,
            falling back to the VERBOSE form field.

    Returns:
        Path of the rendered avatar video.
    """
    verbose = kwargs.get("verbose", VERBOSE)

    if verbose:
        print(f"[DEBUG] |T-REX STT| STT Output: {input}")

    output = nlp(input, **kwargs)
    if verbose:
        print(f"[DEBUG] |T-REX NLP| NLP Output: {output}")

    audio = tts(
        output,
        voice_name=VOICE_NAME,
        verbose=verbose)
    # Convert the TTS samples into the audio format used for the avatar video.
    audio = postprocessing(audio)

    return avatar(
        opt,
        PATH_LABELS,
        audio)

# Named user_input so the builtin input() is not shadowed for later cells.
user_input = stt(record(), language_code=LANGUAGE_CODE, verbose=VERBOSE) if USE_STT_AS_INPUT else TEXT_INPUT
video = execute(trex, input=user_input, verbose=VERBOSE)

# Show the final output: embed the MP4 as a base64 data URL so it plays inline.
with open(video, "rb") as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()

HTML("""
<video width=700 controls autoplay>
    <source src="%s" type="video/mp4">
</video>
""" % data_url)

---
# ***Test*** TREX 🦖💬
---

## ***Test the NLP Module*** 📰🤯

In [None]:
# @title | NLP | Utils for Testing
# @markdown ✋ Rerun Cell if Runtime was restarted 🔄


from typing import Any
from typing import Tuple
from typing import List

import io
from time import sleep
import random
from google.colab import files
import pandas as pd
import transformers


def get_nlp_models(
    order: List[str] = [
        "ZERO_SHOT",
        "SMALL_TALK", 
        "FEW_SHOT", 
        "TABLE_QA", 
        "SOURCE_TO_NATIVE_TRANSLATOR", 
        "NATIVE_TO_SOURCE_TRANSLATOR"],
    delimiter: str = ",") -> str:
    """Return the names of the configured NLP models joined by ``delimiter``.

    Args:
        order: Keys of the models to list, in output order.
        delimiter: Separator placed between model names.

    Returns:
        A single string such as "model_a,model_b,...".
    """
    model_by_role = {
        "ZERO_SHOT": ZERO_SHOT_MODEL,
        "SMALL_TALK": SMALL_TALK_MODEL,
        "FEW_SHOT": FEW_SHOT_MODEL,
        "TABLE_QA": TABLE_QA_MODEL,
        "SOURCE_TO_NATIVE_TRANSLATOR": GERMAN_TO_ENGLISH_MODEL,
        "NATIVE_TO_SOURCE_TRANSLATOR": ENGLISH_TO_GERMAN_MODEL,
    }
    selected_names = (model_by_role[role] for role in order)
    return delimiter.join(selected_names)

def test(dataset: List[Tuple[Any]],
         config: dict = {}) -> dict:
    """Run the NLP pipeline over ``dataset`` and collect its predictions.

    For each ``(language, input, expected)`` row, the module-level ``LANGUAGE``
    is set to ``language`` and a fresh ``NLP(config=CONFIG)`` answers
    ``input``. The expected output is not evaluated here. ``LANGUAGE`` is
    restored afterwards, even if a row raises.

    Args:
        dataset: ``(language, input, expected_output)`` tuples.
        config: Keyword arguments forwarded to every ``nlp(input, **config)`` call.

    Returns:
        Dict with ``dataset``, ``config``, the ``predictions`` (in dataset
        order) and the ``models`` in use.
    """
    # Rebind the module global; a plain assignment would only create a local
    # and the per-row language would never take effect.
    global LANGUAGE

    results = {"dataset": dataset, "config": config}

    previous_language = LANGUAGE
    predictions = []
    try:
        for language, x, _expected in dataset:
            LANGUAGE = language
            pipeline = NLP(config=CONFIG)

            # Random pause between rows, presumably to respect API rate limits.
            sleep(random.randint(5, 15))

            try:
                prediction = pipeline(x, **config)
            except NegativeInputCapturedException:
                prediction = "[EXCEPTION] NEGATIVE INPUT DISCARDED"
            except NegativeOutputsCapturedException:
                prediction = "[EXCEPTION] ONLY NEGATIVE OUPUTS"

            predictions.append(prediction)
    finally:
        LANGUAGE = previous_language

    results["predictions"] = predictions
    results["models"] = get_nlp_models()
    return results

def upload_file(extension: str) -> bytes:
    """Upload files via the Colab picker and return the first one with ``extension``.

    The extension is matched as a case-insensitive filename suffix.

    Args:
        extension: Filename suffix to look for, e.g. ".xlsx".

    Returns:
        The raw bytes of the first matching uploaded file.

    Raises:
        Exception: If no uploaded file ends with ``extension``.
    """
    uploaded = files.upload()

    suffix = extension.lower()
    for filename, content in uploaded.items():
        # Match the suffix rather than any substring (".csv" must not match
        # "report.csv.xlsx").
        if filename.lower().endswith(suffix):
            return content

    raise Exception("No file with specified extension was found in the uploaded files!\n"\
                    "Check the uploaded files and please retry the procedure!")

def dataframe_from_csv(content: bytes,
                       **kwargs) -> pd.DataFrame:
    """Parse raw CSV bytes into a DataFrame.

    Extra keyword arguments (such as ``sheet`` from ``dataframe_from_type``)
    exist only for a uniform loader signature and are ignored.
    """
    csv_buffer = io.BytesIO(content)
    frame = pd.read_csv(csv_buffer)
    return frame

def dataframe_from_excel(content: bytes,
                         sheet: str = None) -> pd.DataFrame:
    """Load one sheet of an Excel workbook into a DataFrame.

    Args:
        content: Raw bytes of the workbook.
        sheet: Name (or index) of the sheet to load. ``None`` or ``""`` selects
            the first sheet. Forwarding ``None`` to ``pd.read_excel`` would
            return a dict of all sheets instead of a DataFrame.

    Returns:
        The selected sheet as a DataFrame.
    """
    sheet_name = 0 if sheet is None or sheet == "" else sheet
    return pd.read_excel(
        io.BytesIO(content),
        sheet_name=sheet_name)
    
def dataframe_from_type(type: str,
                        content: bytes,
                        sheet: str=None) -> pd.DataFrame:
    """Load ``content`` with the loader registered for the file ``type``.

    Supported types are ".xlsx" and ".csv"; ``sheet`` is forwarded to the
    loader (the CSV loader ignores it).

    Raises:
        KeyError: If ``type`` is not a supported extension.
    """
    loaders = dict([
        (".xlsx", dataframe_from_excel),
        (".csv", dataframe_from_csv),
    ])
    loader = loaders[type]
    return loader(content, sheet=sheet)

def preprocess_dataframe(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``dataframe`` with every column cast to ``str``."""
    string_dtypes = dict.fromkeys(dataframe.columns.values, str)
    return dataframe.astype(string_dtypes)

def parse_dict_from_dataframe(dataframe: pd.DataFrame,
                              headers: list) -> dict:
    """Convert ``dataframe`` to ``{column: {index: value}}``, keeping only ``headers``.

    Headers that do not exist in ``dataframe`` are silently skipped.

    Args:
        dataframe: Source table.
        headers: Column names to keep.

    Returns:
        The filtered column mapping (as produced by ``DataFrame.to_dict()``).
    """
    columns = dataframe.to_dict()
    return {key: value for key, value in columns.items()
            if key in headers}

def dataset_from_dict(test: dict,
                      header_l: str="l",
                      header_x: str="x",
                      header_y: str="y") -> List[Tuple[str]]:
    """Build ``(language, input, expected)`` row tuples from column mappings.

    Each column is a ``{index: value}`` mapping (as from
    ``DataFrame.to_dict()``). Rows are paired by position, so the result is
    as long as the shortest column.
    """
    selected_columns = (test[header_l], test[header_x], test[header_y])
    return list(zip(*(column.values() for column in selected_columns)))

def save_test_results(results: dict, 
                      dictionary: dict,
                      filename: str="test.xlsx") -> str:
    """Write the dataset columns plus an "Output" column to an Excel file.

    Args:
        results: Result dict of ``test``. Prediction ``i`` becomes row ``i``
            of the "Output" column.
        dictionary: ``{column: {index: value}}`` mapping of the dataset.
            It is not modified.
        filename: Destination path of the Excel file.

    Returns:
        ``filename``.
    """
    # Shallow copy: only a new top-level key is added, so the caller's
    # mapping stays untouched.
    table = dict(dictionary)
    table["Output"] = dict(enumerate(results["predictions"]))

    pd.DataFrame.from_dict(table).to_excel(filename)

    return filename

In [None]:
# @title | NLP | Prepare Testing Data 📋 🆒
# @markdown ---

# @markdown Select a file type. The first file with the given extension will be loaded. 📗
EXTENSION = ".xlsx" #@param [".xlsx", ".csv"]

# @markdown Select a sheet if the file is an excel file. 📜
SHEET = "Tabelle3" #@param {type: "string"}

# @markdown Specify headers in the sheet that shall be included in the dataset. 📋
DATASET_HEADERS = "Language, Input, Erwartetes Ergebnis" #@param {type: "string"}
# Split on commas and strip surrounding whitespace, so "a,b,c" and "a, b, c" both work.
DATASET_HEADERS = [header.strip() for header in DATASET_HEADERS.split(",")]

assert len(DATASET_HEADERS) == 3, "Warning! There must be exactly three dataset headers! Please refactor and retry!"
DATASET_HEADER_L, DATASET_HEADER_X, DATASET_HEADER_Y = DATASET_HEADERS

# @markdown ✨ Note: There must be three headers in the following order "Language, Input, Expected Output".

# @markdown ✨ Note: The headers must be separated by commas (spaces around them are ignored).


# Load the first uploaded file with EXTENSION, stringify every column,
# keep only the dataset headers and build (language, input, expected) rows.
uploaded_content = upload_file(EXTENSION)
dataframe = dataframe_from_type(
    EXTENSION,
    uploaded_content,
    SHEET)
dataframe = preprocess_dataframe(dataframe)

dictionary = parse_dict_from_dataframe(
    dataframe,
    DATASET_HEADERS)

dataset = dataset_from_dict(
    dictionary,
    DATASET_HEADER_L,
    DATASET_HEADER_X,
    DATASET_HEADER_Y)

Saving Test_Tabelle_3_-_Kopie.xlsx to Test_Tabelle_3_-_Kopie.xlsx


In [None]:
# @title | NLP | Display the Dataset 📋🎉

# Enable Colab's interactive table rendering for DataFrames.
%load_ext google.colab.data_table
from google.colab import data_table

# Show the (language, input, expected output) rows built from `dictionary`.
data_table.DataTable(
    pd.DataFrame(
        dataset_from_dict(
            dictionary,
            DATASET_HEADER_L,
            DATASET_HEADER_X,
            DATASET_HEADER_Y)))

Unnamed: 0,0,1,2
0,en,What is 9 plus 12,21
1,de,"Es ist 14:00, wie komme ich am schnellsten von...",Mit dem ICE 1223 um 14:07 von Gleis 9
2,de,Ich stehe in Nürnberg und es ist 14:30. Wann k...,Der nächste ICE nach München kommt um 14:55 au...
3,de,Wie lange dauert die Fahrt des ICE 806 von Ber...,Die Fahrt dauert 1 Stunde und 42 Minuten
4,de,"Hallo, wie geht es dir?",Mir geht es gut. Wie geht es dir?
5,de,Hast du irgendwelche Hobbies?,Meine Hobbies sind XYZ
6,de,Was ist deine Lieblingsfarbe?,Meine Lieblingsfarbe ist XY
7,de,Magst du Musik?,"Ja ich mag Musik/ nein, ich mag keine Musik"
8,de,Wann findet das Musical Ausgetickt statt?,Das Musical Ausgetickt findet am 26.09. statt
9,de,Wo findet das Musical Ausgetickt statt?,Das Musical Ausgetickt findet in der KSV Halle...


In [None]:
# @title | NLP | Test Component 👻
# @markdown ---

import uuid

from datetime import datetime as dt
from datetime import timedelta as td
from datetime import timezone as tz


FILENAME = "Test_NLP_%u_%t.xlsx" #@param {type: "string"}
# @markdown ⚡ Select if a UUID shall be substituted for the **%u** string.
UUID = True # @param {type:"boolean"}
# @markdown ✨ Note: **%t** in the filename will be replaced with the current timestamp.

#@markdown ---
VERBOSE = True # @param {type:"boolean"}

# Fixed UTC offset used for the filename timestamp (presumably CEST, UTC+2).
TIMESTAMP_UTC_OFFSET_HOURS = 2

# With UUID disabled, "_%u_" collapses to "" (e.g. "Test_NLP%t.xlsx").
uuid_ = f"_{str(uuid.uuid4())}_" if UUID else ""
filename = FILENAME.replace("_%u_", uuid_)

# Timezone-aware, so the timestamp does not depend on the runtime's local timezone.
timestamp = dt.now(tz(td(hours=TIMESTAMP_UTC_OFFSET_HOURS)))
timestamp = f"{timestamp:%Y%m%d%H%M}"
filename = filename.replace("%t", timestamp)

# Save and restore LANGUAGE around the test run, even if the run fails.
tmp_language = LANGUAGE
try:
    results = test(
        dataset,
        {"verbose": VERBOSE})
finally:
    LANGUAGE = tmp_language

print(get_nlp_models())

save_test_results(
    results,
    dictionary,
    filename)

[DEBUG] |AI21 Warm Up| personas: Conversation id: 3cda52e4-ea7e-477a-b09a-b8dbf8cc8c54 

[DEBUG] |NLP __call__ <START>|~~~~~~~~~~~~~~~~~~~~
[DEBUG] |NLP ATTR skills|: {'travel': {'labels': ['travel', 'travel on time', 'travel delayed'], 'pipeline': {'associations': {'travel': {'model': 'j1-jumbo', 'data': <function travel_table at 0x7feefb647cb0>, 'samples': "Q: It is 12:30. Wich is the next train from Frankfurt to Leipzig?\nA: The next train to Leipzig is ICE 1655 at 17:21 from track 9.\n\nQ: It is 17:22. Wich is the next train from Frankfurt to Leipzig?\nA: The next train to Leipzig is ICE 594 at 18:14 from track 9.\n\nQ: When will the ICE 594 from Frankfurt arrive in Leipzig.\nA: The ICE 594 will arrive at 21:10 on the track 13.\n\nQ: How long will the ICE 1655 need to get from Frankfurt to Leipzig?\nA: The ICE 1655 will need 3 hours and 3 minutes.\n\nQ: At wich track will the FLX 1354 from Berlin arrive?\nA: The FLX 1354 from Berlin will arrive at the track 5 at 10:07.\n\nQ: Which 

'Test_NLP_294f224b-a2e1-4a65-8e84-1b87e494e1c3_202109071013.xlsx'

In [None]:
from google.colab import files
# Download the Excel file written by save_test_results in the test cell.
files.download(filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Quick manual check of the default CONFIG pipeline.
# NOTE(review): this rebinds the global `nlp` created by trex_setup, so the
# interaction cell will use this pipeline until setup is re-run — confirm intended.
nlp = NLP(config=CONFIG)
nlp("how are you?")

[DEBUG] |AI21 Warm Up| personas: Conversation id: 03fd9958-df53-4a3a-a2fd-34b358dd8257 



"Mir geht's gut, danke! Wie geht es Ihnen? Wie kann ich Ihnen helfen?"

In [None]:
# NOTE(review): called without any input, while every other call site passes
# the user text — confirm this cell is intended (likely leftover scratch).
nlp()