## Initialize

Loading the model and tokenizer.

In [2]:
#!pip install torch==2.2.2+cu121 torchvision==0.17.2+cu121 torchaudio==2.2.2+cu121 -f https://pypi.tuna.tsinghua.edu.cn/simple -f https://download.pytorch.org/whl/torch_stable.html

In [1]:
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Setup model path.
# NOTE(review): this is a relative path, so it is resolved against the
# notebook's working directory -- confirm the checkpoint is present there.
model_path = "models--nicolay-r--flan-t5-emotion-cause-thor-base"
# Setup device: use the first CUDA GPU when one is available, otherwise fall
# back to the CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the seq2seq model and its tokenizer from the same location.
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Move the weights to the selected device; as the last expression of the cell
# this also displays the module structure.
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

Set up the `ask` helper function for generating LLM responses.

In [2]:
def ask(prompt):
  """Generate the model's text response to a single prompt.

  Relies on the notebook-level `tokenizer`, `model` and `device` objects.

  Args:
    prompt: text passed to the tokenizer as-is.

  Returns:
    The first decoded sequence produced by `model.generate`, decoded with
    special tokens skipped.
  """
  # Tokenize without adding special tokens and request "pt" tensors.
  encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
  # Called for its effect on `encoded`; the return value is not used.
  encoded.to(device)
  # Generation is capped at a maximum length of 320.
  generated = model.generate(**encoded, max_length=320, temperature=1)
  decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
  # Only the first decoded sequence is returned.
  return decoded[0]

Set up the Chain-of-Thought prompting chain.

In [3]:
def emotion_extraction_chain(context, target):
  """Run a four-step Chain-of-Thought through `ask` and return the last answer.

  Every prompt is the complete previous prompt, followed by the answer it
  received and the next question, so the prompt grows at each hop:
    1. which text spans possibly cause the emotion on `target`;
    2. the implicit opinion towards those spans;
    3. the emotion state of `target`;
    4. a final summary request that lists the candidate labels.

  Args:
    context: conversation text interpolated into the first prompt.
    target: utterance text interpolated into every question.

  Returns:
    Whatever `ask` returns for the fourth prompt.
  """
  # Candidate labels, rendered once as a comma-separated list for the last hop.
  label_options = ", ".join(("anger", "disgust", "fear", "joy", "sadness", "surprise", "neutral"))

  # Hop 1: ask for the text spans.
  span_prompt = f"Given the conversation {context}, which text spans are possibly causes emotion on {target}?"
  span_answer = ask(span_prompt)

  # Hop 2: ask for the implicit opinion towards those spans.
  opinion_prompt = (
    f"{span_prompt}. The mentioned text spans are about {span_answer}. Based on the common sense, what "
    f"is the implicit opinion towards the mentioned text spans that causes emotion on {target}, and why?"
  )
  opinion_answer = ask(opinion_prompt)

  # Hop 3: ask for the emotion state given that opinion.
  state_prompt = (
    f"{opinion_prompt}. The opinion towards the text spans that causes emotion on {target} is {opinion_answer}. "
    f"Based on such opinion, what is the emotion state of {target}?"
  )
  state_answer = ask(state_prompt)

  # Hop 4: final question plus the label list.
  # NOTE(review): there is no space between "only." and "Choose from:"; the
  # wording is kept verbatim -- presumably it matches the prompt format the
  # checkpoint expects, confirm before editing.
  final_prompt = (
    f"{state_prompt}. The emotion state is {state_answer}. Based on these contexts, summarize and return the emotion cause only."
    f"Choose from: {label_options}."
  )
  return ask(final_prompt)

## Usage

In [4]:
# First turn of the conversation (the history).
conv_turn_1 = "John: ohh you made up!"
# Second turn of the conversation (the utterance).
conv_turn_2 = "Jake: yaeh, I could not be mad at him for too long!"
# Each turn is analysed in its own right: the first as `source`, the second as `target`.
source, target = conv_turn_1, conv_turn_2
# The context is the two turns joined back-to-back, with no separator in between.
context = "".join((conv_turn_1, conv_turn_2))
# Run the chain once per utterance, source first and target second.
source_emotion_state, target_emotion_state = [
  emotion_extraction_chain(context, utterance) for utterance in (source, target)
]
# Report both results.
print(f"Emotion state of the speaker of `{source}` is: {source_emotion_state}")
print(f"Emotion state of the speaker of `{target}` is: {target_emotion_state}")

Emotion state of the speaker of `John: ohh you made up!` is: surprise
Emotion state of the speaker of `Jake: yaeh, I could not be mad at him for too long!` is: anger
