In [None]:
import datetime

import datasets
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
from transformers import AdamWeightDecay, AutoTokenizer, TFAutoModelForSeq2SeqLM, TrainingArguments, \
                         Trainer, DataCollatorForSeq2Seq
from transformers.keras_callbacks import KerasMetricCallback

In [None]:
# Turn on memory growth for every GPU TensorFlow reports, to avoid
# out-of-memory errors.
gpus = tf.config.experimental.list_physical_devices("GPU")
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

In [None]:
# Show the GPU devices that were listed in the previous cell.
gpus

# Process data

In [None]:
# Checkpoint identifier; the tokenizer is loaded from it here and the model
# weights are loaded from the same identifier in a later cell.
checkpoint = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Parse the conversation transcript into two parallel lists.
# A line whose first character is "U" is a user turn: the text after the first
# 6 characters is kept. A line starting with "J" is an assistant turn: the text
# after the first 8 characters is kept. All other lines are ignored.
# NOTE(review): the prefix lengths presumably correspond to "User: " and
# "Jarvis: " — confirm against the transcript format.
CONVERSATION_PATH = "D:/Users/Natha/Datasets/MyJarvisConversation/conversation.txt"
USER_PREFIX_LEN = 6
JARVIS_PREFIX_LEN = 8

query = []
response = []

# NOTE(review): no encoding is passed, so the platform default is used — confirm
# the transcript's encoding and pass it explicitly if it differs.
with open(CONVERSATION_PATH, "r") as f:
    # Iterate the file object directly so the transcript is streamed line by line.
    for line in f:
        if line.startswith("U"):
            query.append(line[USER_PREFIX_LEN:].rstrip("\n"))
        elif line.startswith("J"):
            response.append(line[JARVIS_PREFIX_LEN:].rstrip("\n"))

# The two lists become the columns of one table, so they must pair up one-to-one.
# Fail here with a clear message instead of with a length error further down.
if len(query) != len(response):
    raise ValueError(
        f"Unbalanced transcript: {len(query)} user lines but {len(response)} Jarvis lines "
        f"in {CONVERSATION_PATH}"
    )

In [None]:
# Column-oriented view of the transcript: one column per speaker.
data = dict(query=query, response=response)

In [None]:
# Build the dataset in three stages:
#   1. a two-column (query, response) dataset,
#   2. the same rows nested under a single "conversation" column,
#   3. an 80/20 train/test split.
# A fixed seed makes the split identical on every Restart & Run All.
SPLIT_SEED = 42

pairs_ds = datasets.Dataset.from_pandas(pd.DataFrame(data=data))
conversation_ds = datasets.Dataset.from_pandas(pd.DataFrame(data={"conversation": pairs_ds}))
dataset = conversation_ds.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [None]:
# Keys of each conversation record: model input text and training target text.
source_lang, target_lang = "query", "response"

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of conversation records.

    Each item of ``examples["conversation"]`` is indexed with the module-level
    ``source_lang`` key to get the input text and with ``target_lang`` to get
    the target text. Both lists are handed to the module-level ``tokenizer``
    with ``max_length=128`` and truncation enabled.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    inputs, targets = [], []
    for pair in examples["conversation"]:
        inputs.append(pair[source_lang])
        targets.append(pair[target_lang])
    return tokenizer(inputs, text_target=targets, max_length=128, truncation=True)

In [None]:
# Apply preprocess_function to the dataset; batched=True makes map() pass
# batches of examples rather than one example at a time.
tokenized_ds = dataset.map(preprocess_function, batched=True)

In [None]:
# Batch collator for seq2seq examples, configured to return TensorFlow tensors.
# NOTE(review): `model` is given the checkpoint name (a string), not a loaded
# model instance — confirm this is intended.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint, return_tensors="tf")

In [None]:
def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and labels.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded label strings.

    Returns:
        A ``(preds, labels)`` tuple where ``preds`` is a list of stripped
        strings and ``labels`` is a list of single-element lists, each holding
        one stripped label.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return cleaned_preds, wrapped_labels


def compute_metrics(eval_preds):
    """Decode predictions and labels and score them.

    Args:
        eval_preds: a ``(preds, labels)`` pair. If ``preds`` is a tuple, only
            its first element is used.

    Returns:
        A dict with two entries, both rounded to 4 decimals:
        ``"bleu"`` (the ``"score"`` field returned by ``metric.compute``) and
        ``"gen_len"`` (mean count of non-pad tokens per prediction).

    NOTE(review): relies on module-level ``tokenizer``, ``metric`` and
    ``postprocess_text``; ``metric`` is not defined in the visible notebook
    cells — confirm where it is created.
    """
    predictions, references = eval_preds
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # -100 entries in the labels are swapped for the pad token id before decoding.
    references = np.where(references != -100, references, tokenizer.pad_token_id)
    ref_texts = tokenizer.batch_decode(references, skip_special_tokens=True)

    pred_texts, ref_texts = postprocess_text(pred_texts, ref_texts)

    scored = metric.compute(predictions=pred_texts, references=ref_texts)
    non_pad_counts = [np.count_nonzero(row != tokenizer.pad_token_id) for row in predictions]

    summary = {"bleu": scored["score"], "gen_len": np.mean(non_pad_counts)}
    return {name: round(value, 4) for name, value in summary.items()}

In [None]:
# Optimizer used by model.compile below. AdamWeightDecay comes from the
# transformers import at the top of the notebook.
LEARNING_RATE = 2e-5
WEIGHT_DECAY_RATE = 0.01

optimizer = AdamWeightDecay(learning_rate=LEARNING_RATE, weight_decay_rate=WEIGHT_DECAY_RATE)

In [None]:
# Load the seq2seq model for the same checkpoint as the tokenizer;
# from_pt=True requests loading from PyTorch weights.
model = TFAutoModelForSeq2SeqLM.from_pretrained(checkpoint, from_pt=True)

In [None]:
# Prepare the train and test splits with model.prepare_tf_dataset.
# Both use the same batch size and collator; only the training split is shuffled.
_shared_dataset_kwargs = dict(batch_size=16, collate_fn=data_collator)

tf_train_set = model.prepare_tf_dataset(tokenized_ds["train"], shuffle=True, **_shared_dataset_kwargs)
tf_val_set = model.prepare_tf_dataset(tokenized_ds["test"], shuffle=False, **_shared_dataset_kwargs)

In [None]:
# Keras callback that evaluates compute_metrics on the validation set.
# NOTE(review): this callback is not passed to model.fit in the training cell,
# so as written it is never invoked — confirm whether that is intended.
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_val_set)

In [None]:
# Compile with the optimizer defined above. No loss is passed here;
# presumably the model supplies its own loss — TODO confirm.
model.compile(optimizer=optimizer)

In [None]:
# Train on the prepared training set, validating against tf_val_set.
NUM_EPOCHS = 30
model.fit(x=tf_train_set, validation_data=tf_val_set, epochs=NUM_EPOCHS)

In [None]:
# Interactive chat loop: read a prompt, sample a reply from the model, then
# expand or act on the "/u..." command tokens found in the lowercased reply.
#
# Fixes relative to the earlier version of this cell:
#   * the processed reply is printed (previously the raw model output was
#     printed, discarding every substitution);
#   * the per-turn text lives in `reply`, so the training-data list `response`
#     is no longer overwritten;
#   * a command emitted without a quoted argument is skipped instead of
#     raising IndexError;
#   * the "/uvolume'<arg>'" token is removed from the displayed text, matching
#     how "/uswitchtab'<arg>'" is handled.
#
# NOTE(review): `weather`, `volume`, `word2num` and `pyautogui` are not defined
# or imported in the visible notebook cells — they must exist before this runs.


def _quoted_argument(text, command):
    """Return the text between the first pair of single quotes after the last
    occurrence of `command` in `text`, or None if no quote follows it."""
    pieces = text.split(command)[-1].split("'")
    return pieces[1] if len(pieces) > 1 else None


# Placeholder token -> zero-argument callable producing its replacement text.
# Callables keep the lookups lazy: nothing is fetched unless the token appears.
_PLACEHOLDER_FILLERS = {
    "/udate": lambda: datetime.date.today().strftime("%B %d, %Y"),
    "/utime": lambda: datetime.datetime.now().strftime("%I:%M:%S"),
    "/utemp": lambda: weather("boston", "tm"),
    "/uhumidity": lambda: weather("boston", "hm"),
    "/uwind": lambda: weather("boston", "ws"),
    "/uprecipitation": lambda: weather("boston", "pp"),
}

while True:
    input_text = input("Enter text: ")
    tokenized = tokenizer(input_text, return_tensors="tf").input_ids
    prediction = model.generate(tokenized, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
    output_text = tokenizer.decode(prediction[0], skip_special_tokens=True)

    # Command tokens are matched against the lowercased model output.
    reply = output_text.lower()

    # Fill in date / time / weather placeholders.
    for token, fill in _PLACEHOLDER_FILLERS.items():
        if token in reply:
            reply = reply.replace(token, fill())

    # Set the volume to the quoted argument.
    if "/uvolume" in reply:
        vol = _quoted_argument(reply, "/uvolume")
        if vol is not None:
            reply = reply.replace("/uvolume" + "'" + vol + "'", "")
            volume(vol)

    # End the session: show the final reply and leave the loop. Tab commands
    # below are intentionally not processed on this path (same as before).
    if "/usleep" in reply:
        reply = reply.replace("/usleep", "")
        print(reply)
        break

    # Browser tab shortcuts.
    if "/unewtab" in reply:
        reply = reply.replace("/unewtab", "")
        pyautogui.hotkey('ctrl', 't')
    if "/uclosetab" in reply:
        reply = reply.replace("/uclosetab", "")
        pyautogui.hotkey('ctrl', 'w')
    if "/uswitchtab" in reply:
        new = _quoted_argument(reply, "/uswitchtab")
        if new is not None:
            reply = reply.replace("/uswitchtab" + "'" + new + "'", "")
            pyautogui.hotkey('ctrl', str(word2num(new)))

    print(reply)