In [28]:
import random
from enum import Enum
from dataclasses import dataclass, fields

In [26]:
import torch
import numpy as np
from transformers import BertModel, BertTokenizer
from scipy.spatial.distance import cosine

In [3]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)



### Defining the State class of the AI agent

In [9]:
class Mood(Enum):
    HAPPY = "happy"
    SAD = "sad"
    ANGRY = "angry"
    EXCITED = "excited"
    CALM = "calm"


class Background(Enum):
    OFFICE = "office"
    BEACH = "beach"
    SPACE = "space"
    CITYSCAPE = "cityscape"
    STUDIO = "studio"


class Light(Enum):
    BRIGHT = "bright"
    DIM = "dim"
    NATURAL = "natural"
    COLORFUL = "colorful"
    SOFT = "soft"


class Scene(Enum):
    MAIN = "main"
    BE_RIGHT_BACK = "be_right_back"
    ENDING = "ending"
    INTERVIEW = "interview"
    GAMEPLAY = "gameplay"


class Poll(Enum):
    YES_NO = "yes_no"
    MULTIPLE_CHOICE = "multiple_choice"
    RATING_SCALE = "rating_scale"
    FEEDBACK = "feedback"
    CUSTOM = "custom"


class Text(Enum):
    ALERT = "alert"
    NEWS_UPDATE = "news_update"
    STREAM_TITLE = "stream_title"
    VIEWER_COUNT = "viewer_count"
    SPONSOR_MESSAGE = "sponsor_message"


@dataclass
class AgentState:
    uptime: float
    mood: Mood
    background: Background
    light: Light
    scene: Scene
    poll: Poll
    text: Text


@dataclass
class AudienceState:
    viewers: int
    subscribers: int
    comments: int
    likes: int
    sensitivity: float


@dataclass
class State:
    agent: AgentState
    audience: AudienceState

In [30]:
def generate_random_state():
    return State(
        agent=AgentState(
            uptime=random.uniform(0, 24),  # Uptime in hours
            mood=random.choice(list(Mood)),
            background=random.choice(list(Background)),
            light=random.choice(list(Light)),
            scene=random.choice(list(Scene)),
            poll=random.choice(list(Poll)),
            text=random.choice(list(Text))
        ),
        audience=AudienceState(
            viewers=random.randint(0, 10000),
            subscribers=random.randint(0, 5000),
            comments=random.randint(0, 1000),
            likes=random.randint(0, 1000),
            sensitivity=random.uniform(0.0, 1.0)
        )
    )

In [31]:
import random
from dataclasses import dataclass, fields
from enum import Enum


def mutate_field(value):
    if isinstance(value, Enum):
        possible_values = list(type(value))
        possible_values.remove(value)  # Remove current value to ensure mutation
        return random.choice(possible_values)
    elif isinstance(value, int):
        return random.randint(0, 10000)  # or some appropriate range for integers
    elif isinstance(value, float):
        return random.uniform(0.0, 24.0)  # or some appropriate range for floats
    return value  # Fallback if type is unrecognized


def mutate_state(state: State) -> State:
    # Choose either the agent or the audience to mutate
    component_to_mutate = random.choice(['agent', 'audience'])
    component = getattr(state, component_to_mutate)

    # Get all fields of the component
    comp_fields = fields(component)
    field_to_mutate = random.choice(comp_fields)

    # Get the current value of the field
    current_value = getattr(component, field_to_mutate.name)

    # Mutate the field
    new_value = mutate_field(current_value)

    # Create a copy of the component with the mutated value
    new_component = {field.name: getattr(component, field.name) for field in comp_fields}
    new_component[field_to_mutate.name] = new_value  # Set new value to the selected field

    # Create a new state with the mutated component
    if component_to_mutate == 'agent':
        return State(agent=AgentState(**new_component), audience=state.audience)
    else:
        return State(agent=state.agent, audience=AudienceState(**new_component))

In [32]:
initial_state = generate_random_state()
print(initial_state)

State(agent=AgentState(uptime=12.9255110114144, mood=<Mood.EXCITED: 'excited'>, background=<Background.OFFICE: 'office'>, light=<Light.SOFT: 'soft'>, scene=<Scene.BE_RIGHT_BACK: 'be_right_back'>, poll=<Poll.YES_NO: 'yes_no'>, text=<Text.NEWS_UPDATE: 'news_update'>), audience=AudienceState(viewers=7481, subscribers=3253, comments=436, likes=397, sensitivity=0.5024523800118395))


In [34]:
next_state = mutate_state(initial_state)
print(next_state)

State(agent=AgentState(uptime=12.9255110114144, mood=<Mood.EXCITED: 'excited'>, background=<Background.OFFICE: 'office'>, light=<Light.SOFT: 'soft'>, scene=<Scene.BE_RIGHT_BACK: 'be_right_back'>, poll=<Poll.YES_NO: 'yes_no'>, text=<Text.NEWS_UPDATE: 'news_update'>), audience=AudienceState(viewers=7481, subscribers=9513, comments=436, likes=397, sensitivity=0.5024523800118395))


### Extracting features from State objects so that they can be used during the training process

In [17]:
def enum_to_embedding(value: str) -> torch.Tensor:
    normalized_value = value.lower().replace("_", " ")
    tokens = tokenizer(normalized_value, padding=True, return_tensors="pt")
    return model(**tokens).last_hidden_state[:, 1, :].squeeze(0)

def extract_features(state: State) -> dict:
    return {
        "agent.uptime": state.agent.uptime,
        "agent.mood": enum_to_embedding(state.agent.mood.value),
        "agent.background": enum_to_embedding(state.agent.background.value),
        "agent.light": enum_to_embedding(state.agent.light.value),
        "agent.scene": enum_to_embedding(state.agent.scene.value),
        "agent.poll": enum_to_embedding(state.agent.poll.value),
        "audience.viewers": state.audience.viewers,
        "audience.subscribers": state.audience.subscribers,
        "audience.comments": state.audience.comments,
        "audience.likes": state.audience.likes,
        "audience.sensitivity": state.audience.sensitivity,
    }

In [18]:
features = extract_features(initial_state)
print(features)

{'agent.uptime': 23.5098624720669, 'agent.mood': tensor([-9.0296e-01,  2.5980e-01,  3.4059e-01, -4.3179e-01, -3.2223e-01,
         2.6134e-01, -2.7005e-01,  2.2414e-01, -4.6964e-01, -4.9043e-01,
        -4.1005e-01,  2.0386e-02,  4.4319e-01,  5.0670e-01, -2.0872e-01,
        -2.0963e-01,  3.4370e-01,  5.8429e-02,  3.5448e-01,  4.8145e-01,
        -1.9271e-01, -3.8362e-01, -2.3499e-01, -7.8184e-01, -2.2857e-04,
         3.5514e-01, -1.1904e-01,  7.0581e-02, -1.0334e-01,  1.3219e-01,
        -1.0401e-04, -9.1336e-02,  3.5596e-01,  1.3966e-01, -8.6357e-01,
        -6.7688e-01, -1.8255e-01,  7.1482e-02, -9.9470e-01,  1.9295e-01,
        -1.9316e-02, -7.2796e-01,  2.2535e-01, -4.4662e-01,  3.8422e-02,
         3.2155e-01,  9.6588e-01,  1.6033e-01,  1.8392e-01, -6.4159e-01,
        -9.2196e-02,  6.6242e-01,  6.8340e-01,  8.4368e-02, -3.3759e-01,
        -1.8833e-01,  3.8825e-01,  2.1595e-01, -2.9018e-01, -3.8885e-01,
         3.1222e-01, -7.2961e-01,  4.3923e-01, -2.1921e-02,  4.4889e-02,
  