In [None]:
import networkx as nx
from sentence_transformers import util
from nltk import word_tokenize
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer
from transformers import LlamaForCausalLM, LlamaTokenizer
import numpy as np
from ConceptNet import ConceptNet
from torch_geometric.nn import GATConv
import torch
from torch import nn
import torch.nn.functional as F
import random
from nltk import stem
from nltk.corpus import wordnet
import re
from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
from scipy.special import expit
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Load the project's ConceptNet wrapper from Data/ (numberbatch=False) and expose its concept vectors
# as an embedding table on the GPU (`from_pretrained` freezes it by default).
# `glove_embedding` is read as a module-level global by SharedLayer.forward.
conceptnet = ConceptNet('Data/', numberbatch=False)
glove_embedding = nn.Embedding.from_pretrained(conceptnet.concept_embedding.vectors).cuda()

In [None]:
class GAT(nn.Module):
    """Two-layer graph attention network.

    Every ``GATConv`` layer, including the output layer, is followed by a LeakyReLU.

    Args:
        input_dim: width of the input node features.
        hidden_dim: width after the first attention layer.
        output_dim: width of the returned node features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names form the state_dict keys (conv1.*, conv2.*) that are copied between
        # models elsewhere in the notebook, so they must stay stable.
        self.conv1, self.relu1 = GATConv(input_dim, hidden_dim), nn.LeakyReLU()
        self.conv2, self.relu2 = GATConv(hidden_dim, output_dim), nn.LeakyReLU()

    def forward(self, x, edge_index):
        h = x
        for conv, activation in ((self.conv1, self.relu1), (self.conv2, self.relu2)):
            h = activation(conv(h, edge_index))
        return h

In [None]:
class KeywordPredictor(nn.Module):
    """Pretrained keyword-scoring model whose sub-modules warm-start the D3QN.

    Holds an MLP scoring head (``fc``), an optional graph encoder (``GCN``), a bidirectional
    GRU (``GRU``) and a frozen copy of the ConceptNet concept embeddings. ``forward`` only uses
    ``fc``; the other parts exist so their weights can be restored from a checkpoint and copied
    into other models.

    Args:
        linear_input_dim: input width of the MLP head.
        GCN_hidden_dim: hidden width of the graph encoder.
        GRU_hidden_dim: hidden size of each GRU direction.
        embedding_dim: concept-embedding width (graph encoder in/out size and GRU input size).
        cross_entropy_weight: ``pos_weight`` for ``BCEWithLogitsLoss``.
        device: device every sub-module is moved to.
        balanced_loss: use ``FocalLoss`` instead of the weighted BCE loss.
        GCN_layer: 'GCN', 'GraphSAGE', 'GAT' or 'TransEGCN'; any other value disables the graph encoder.
        num_relations: number of relation embeddings (only used with 'TransEGCN').

    NOTE(review): ``FocalLoss``, ``GCN``, ``GraphSAGE`` and ``KETransEGCN`` are not defined in this
    notebook; only the default 'GAT' path can be constructed from what is visible here.
    """

    def __init__(self,
                 linear_input_dim=1112,
                 GCN_hidden_dim=512,
                 GRU_hidden_dim=256,
                 embedding_dim=300,
                 cross_entropy_weight=[1],
                 device='cuda:0',
                 balanced_loss=False,
                 GCN_layer='GAT',
                 num_relations=-1):
        super().__init__()

        self.device = device
        self.input_dim = linear_input_dim
        # Module names ('linear1', ...) determine the state_dict keys that are loaded into
        # AdvantageNet.fc, so they are kept as-is even though the activations are plain ReLUs.
        self.fc = nn.Sequential()
        self.fc.add_module('linear1', nn.Linear(self.input_dim, 512))
        self.fc.add_module('leaky_relu1', nn.ReLU())
        self.fc.add_module('linear2', nn.Linear(512, 128))
        self.fc.add_module('leaky_relu2', nn.ReLU())
        self.fc.add_module('linear3', nn.Linear(128, 1))
        self.fc = self.fc.to(self.device)

        if not balanced_loss:
            self.loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor(cross_entropy_weight).to(device))
        else:
            self.loss_func = FocalLoss(ignore_index=-1)

        self.GCN_input = embedding_dim
        self.GCN_output = embedding_dim
        self.GCN_hidden = GCN_hidden_dim
        if GCN_layer == 'GCN':
            self.GCN = GCN(self.GCN_input, self.GCN_hidden, self.GCN_output).to(device)
        elif GCN_layer == 'GraphSAGE':
            self.GCN = GraphSAGE(self.GCN_input, self.GCN_hidden, self.GCN_output).to(device)
        elif GCN_layer == 'GAT':
            self.GCN = GAT(self.GCN_input, self.GCN_hidden, self.GCN_output).to(device)
        elif GCN_layer == 'TransEGCN':
            # Honour `device` like every other sub-module instead of forcing the current CUDA device.
            self.relation_embedding = nn.Embedding(num_relations, 1113).to(device).requires_grad_(False)
            self.GCN = KETransEGCN(self.GCN_input, self.GCN_hidden, self.GCN_output).to(device)
        else:
            self.GCN = None

        self.GRU_input_dim = embedding_dim
        self.GRU_hidden_dim = GRU_hidden_dim
        self.GRU_layer_size = 1
        self.GRU = nn.GRU(self.GRU_input_dim,
                          self.GRU_hidden_dim,
                          self.GRU_layer_size,
                          bidirectional=True).to(self.device)

        self.glove_embedding = nn.Embedding.from_pretrained(conceptnet.concept_embedding.vectors).to(device).requires_grad_(False)

    def forward(self, state):
        """Score ``state`` features with the MLP head (raw logits, last dim of size 1)."""
        x = self.fc(state)
        return x

    def predict(self, state):
        """``forward`` without building an autograd graph."""
        with torch.no_grad():
            y = self.forward(state)
        return y

    def initialize(self):
        """Re-initialise the weight of every ``nn.Linear`` with Kaiming-uniform init (biases untouched)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # `kaiming_uniform_` is the in-place API; the un-suffixed alias is deprecated.
                nn.init.kaiming_uniform_(m.weight)

In [None]:
import copy


def candidate_nodes(graph, start_concepts, hops=2):
    """Collect the nodes reachable from ``start_concepts`` by breadth-first expansion.

    The expansion runs for ``hops`` rounds, but nodes first discovered in the last round are
    not included. The result is therefore every node within ``hops - 1`` edges of a start
    concept. With the default ``hops=2`` that is the start concepts plus their direct
    neighbours. Start concepts missing from ``graph`` are ignored.

    Args:
        graph: object exposing ``has_node(n)`` and ``neighbors(n)`` (e.g. a networkx graph).
        start_concepts: iterable of start nodes; it is not modified.
        hops: number of expansion rounds.

    Returns:
        A duplicate-free list of nodes in first-visited order.
    """
    result = []
    visited = set()
    frontier = list(start_concepts)
    for hop in range(hops):
        next_frontier = []
        for node in frontier:
            # Skipping already-visited nodes avoids the quadratic list lookups and the repeated
            # re-expansion of the original implementation without changing the returned set.
            if node in visited or not graph.has_node(node):
                continue
            visited.add(node)
            result.append(node)
            # Neighbours found in the final round would be discarded, so don't collect them.
            if hop < hops - 1:
                next_frontier.extend(graph.neighbors(node))
        frontier = next_frontier
    return result

In [None]:
from torch import optim
from torch_geometric.utils import from_networkx
from nltk import word_tokenize

In [None]:
class QNet(nn.Module):
    """State-value stream of the dueling network.

    One ``GATConv`` maps each node's 1112-dim feature to a single channel. Averaging over the
    node axis then gives one value for the whole graph.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(1112, 1).cuda()

    def forward(self, x, egde_index):
        # The parameter keeps its original spelling ("egde") so keyword callers keep working.
        per_node = self.conv1(x, egde_index)
        return per_node.mean(dim=0)

In [None]:
class AdvantageNet(nn.Module):
    """Advantage stream of the dueling network: a 1112 -> 512 -> 128 -> 1 MLP applied per node.

    Layer names match ``KeywordPredictor.fc`` so that its weights can be loaded here.
    """

    def __init__(self):
        super().__init__()
        named_layers = (
            ('linear1', nn.Linear(1112, 512)),
            ('relu1', nn.LeakyReLU()),
            ('linear2', nn.Linear(512, 128)),
            ('relu2', nn.LeakyReLU()),
            ('linear3', nn.Linear(128, 1)),
        )
        mlp = nn.Sequential()
        for layer_name, layer in named_layers:
            mlp.add_module(layer_name, layer)
        self.fc = mlp.cuda()

    def forward(self, x):
        return self.fc(x)

In [None]:
class ContextEncoder(nn.Module):
    """Single-layer bidirectional GRU that summarises a sequence of token embeddings.

    ``forward`` returns the final hidden state of both directions, flattened into one vector.
    This assumes an unbatched ``(seq_len, GRU_input_dim)`` input, in which case the output has
    length ``2 * GRU_hidden_dim``. TODO confirm no batched inputs are passed.
    """

    def __init__(self,
                 GRU_input_dim=300,
                 GRU_hidden_dim=256):
        super().__init__()

        self.GRU_input_dim = GRU_input_dim
        self.GRU_hidden_dim = GRU_hidden_dim
        self.GRU_layer_size = 1
        recurrent = nn.GRU(self.GRU_input_dim, self.GRU_hidden_dim, self.GRU_layer_size,
                           bidirectional=True)
        self.GRU = recurrent.cuda()

    def forward(self, x):
        final_hidden = self.GRU(x)[1]
        return final_hidden.reshape(-1)

In [None]:
class GraphEncoder(nn.Module):
    """Thin wrapper around a two-layer ``GAT`` on the current CUDA device.

    Args:
        GCN_input_dim: width of the input node features.
        GCN_hidden_dim: hidden width of the GAT.
        GCN_output_dim: width of the returned node features.
    """

    def __init__(self, GCN_input_dim, GCN_hidden_dim, GCN_output_dim):
        super().__init__()

        self.GCN_input = GCN_input_dim
        self.GCN_hidden = GCN_hidden_dim
        self.GCN_output = GCN_output_dim
        encoder = GAT(self.GCN_input, self.GCN_hidden, self.GCN_output)
        self.GCN = encoder.cuda()

    def forward(self, x, edge_index):
        return self.GCN(x, edge_index)

In [None]:
class SharedLayer(nn.Module):
    """Feature trunk shared by the value and advantage streams.

    Each row of the output belongs to one node of the global graph. It concatenates:
      * that node's graph-encoder embedding,
      * the encoder embedding of node 0 of the target graph (the same for every row),
      * the context-GRU summary (the same for every row).

    With the default sizes this is 300 + 300 + 512 = 1112 features per node.
    NOTE: embeddings are looked up through the module-level ``glove_embedding``.
    """

    def __init__(self,
                 GCN_input_dim=300,
                 GCN_hidden_dim=512,
                 GCN_output_dim=300
                 ):
        super().__init__()

        self.graph_encoder = GraphEncoder(GCN_input_dim, GCN_hidden_dim, GCN_output_dim)
        self.context_encoder = ContextEncoder()

    def forward(self, global_graph, target_graph, context_ids):
        n_nodes = global_graph.num_nodes

        global_feats = self.graph_encoder(glove_embedding(global_graph.x), global_graph.edge_index)
        target_feats = self.graph_encoder(glove_embedding(target_graph.x), target_graph.edge_index)
        target_rows = target_feats[0].repeat(n_nodes, 1)

        context_vec = self.context_encoder(glove_embedding(context_ids))
        context_rows = context_vec.reshape(-1).repeat(n_nodes, 1)

        return torch.cat([global_feats, target_rows, context_rows], dim=1)

In [None]:
class D3QN(nn.Module):
    """Dueling graph Q-network.

    ``Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))``. Here ``V`` is the graph-level ``QNet``
    output and ``A`` is the per-node ``AdvantageNet`` output. Both are computed on the shared
    per-node features.
    """

    def __init__(self):
        super().__init__()

        self.shared_layer = SharedLayer()
        self.q = QNet()
        self.adv = AdvantageNet()

    def forward(self, global_graph, target_graph, context_ids):
        features = self.shared_layer(global_graph, target_graph, context_ids)
        state_value = self.q(features, global_graph.edge_index)
        advantages = self.adv(features)
        centred_advantages = advantages - advantages.mean(dim=0, keepdim=True)
        return state_value + centred_advantages

    @torch.no_grad()
    def predict(self, global_graph, target_graph, context_ids):
        """``forward`` without tracking gradients."""
        return self.forward(global_graph, target_graph, context_ids)

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay.

    Once ``buffer_size`` transitions are stored, adding one more evicts the oldest.

    Attributes:
        buffer_size: maximum number of stored transitions.
        buffer: a bounded ``collections.deque`` of transitions, oldest first.
    """

    def __init__(self, buffer_size):
        from collections import deque

        self.buffer_size = buffer_size
        # A bounded deque evicts from the left in O(1) (list.pop(0) is O(n)) and also
        # handles buffer_size == 0 without raising.
        self.buffer = deque(maxlen=buffer_size)

    def add(self, experience):
        """Append one transition, evicting the oldest one when the buffer is full."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions drawn uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions.
        """
        return random.sample(self.buffer, batch_size)

In [None]:
class Simulator:
    """Dialogue environment in which an RL agent steers a chatbot towards a target topic.

    Each episode starts from an opening utterance and a target topic index into ``topic_words``.
    An action is a collection of node indices into ``node_list`` (ConceptNet concepts):
      * a keyword-conditioned BlenderBot turns the chosen concepts into a chatbot reply,
      * a second BlenderBot plays the user,
      * a multi-label topic classifier scores the user's reply against the target topic.

    Args:
        d3qn: policy network (stored as ``self.d3qn``; not called by this class).
        glove_embedding: embedding module used by ``get_rl_state``.
        conceptnet: project ConceptNet wrapper (graph, embeddings, ``bidirectional_reasoning``).
        gamma, clip_ratio: accepted for call compatibility; not used by this class.
        device: device for the language models and the returned tensors.
    """

    def __init__(self, d3qn, glove_embedding, conceptnet, gamma=0.99, clip_ratio=0.2, device='cuda:0'):
        self.device = device

        self.labels = None
        self.context_ids = None

        self.G_global = None

        self.episode_blacklist = []
        self.blacklist_mask = []
        self.context = []
        self.user_bot = []

        # Keyword-conditioned chatbot that the agent steers.
        self.chatbot_model_path = '../BlenderbotTraining/Chatbot/Model/'
        self.chatbot_model_name = 'FullDatasetSuffixMultiK'
        self.chatbot_model = BlenderbotForConditionalGeneration.from_pretrained(
            self.chatbot_model_path + 'model/' + self.chatbot_model_name).to(self.device)
        self.chatbot_tokenizer = BlenderbotTokenizer.from_pretrained(
            self.chatbot_model_path + 'tokenizer/' + self.chatbot_model_name)

        # Off-the-shelf BlenderBot that simulates the user.
        model_name = 'facebook/blenderbot-400M-distill'
        self.user_model = BlenderbotForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.user_tokenizer = BlenderbotTokenizer.from_pretrained(model_name)

        # Multi-label topic classifier that provides the reward.
        # NOTE(review): `topic_words` is presumably in the classifier's label order — verify
        # against the model's id2label config.
        topic_model_name = "cardiffnlp/tweet-topic-21-multi"
        self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
        self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name).to(self.device)
        self.topic_words = ['arts', 'business', 'celebrity', 'diaries', 'family', 'fashion', 'film', 'fitness', 'food', 'gaming', 'education', 'music',
               'news', 'hobbies', 'relationships', 'science', 'sports', 'travel', 'student']
        self.target_topic = None
        self.topic_dist = None

        self.target_encoded = None

        self.conceptnet = conceptnet
        self.previous_dist = 10
        self.target = None

        self.global_graph = None
        self.node_list = None
        self.done = False

        self.glove_embedding = glove_embedding
        self.d3qn = d3qn

    def reset(self, initial_input, topic_idx):
        """Start a new episode.

        Args:
            initial_input: opening utterance.
            topic_idx: index into ``self.topic_words`` of the target topic.

        Returns:
            The initial state tuple, as produced by ``get_state``.
        """
        global_target = self.topic_words[topic_idx]
        self.target_topic = topic_idx
        self.context = [initial_input]
        self.previous_dist = 10
        self.target = global_target
        self.target_encoded = self.conceptnet.concept_embedding[global_target]
        self.topic_dist = None  # no user reply has been scored in this episode yet
        self.done = False

        vocab = self.conceptnet.concept_embedding.stoi
        # `concept_extractor` is defined elsewhere in the notebook (not visible here).
        initial_concepts = concept_extractor(initial_input)
        if len(initial_concepts) == 0:
            # Fall back to alphabetic tokens that have an embedding.
            initial_concepts = [token for token in word_tokenize(initial_input)
                                if token in vocab and token.isalpha()]

        # Episode graph: compose bidirectional_reasoning from the opening concepts to the target
        # with, for each opening concept, bidirectional_reasoning from the target (K=10, hops=3).
        global_graph = self.conceptnet.bidirectional_reasoning(initial_concepts, global_target)
        for start_concept in initial_concepts:
            global_graph = nx.compose(
                global_graph,
                self.conceptnet.bidirectional_reasoning([global_target], start_concept, K=10, hops=3))

        self.global_graph = global_graph
        self.G_global = from_networkx(self.global_graph).to(self.device)
        self.node_list = list(global_graph.nodes)
        self.conceptnet_subgraph = self.conceptnet.conceptnet.subgraph(self.node_list)

        # Star graph of the target and its neighbours. The target is added first, so it is node 0
        # (SharedLayer reads row 0 of the encoded target graph).
        target_neighbours = list(global_graph.neighbors(self.target))
        target_graph = nx.Graph()
        target_graph.add_node(self.target, x=vocab[self.target])
        for target_neighbour in target_neighbours:
            target_graph.add_node(target_neighbour, x=vocab[target_neighbour])
            target_graph.add_edge(self.target, target_neighbour)
        self.target_graph = target_graph
        self.G_target = from_networkx(self.target_graph).to(self.device)

        # `blacklist` is a global defined elsewhere in the notebook; mask 0 = blacklisted node.
        is_blocked = [node in blacklist for node in self.node_list]
        self.episode_blacklist = [node for node, blocked in zip(self.node_list, is_blocked) if blocked]
        self.blacklist_mask = torch.tensor([0 if blocked else 1 for blocked in is_blocked])

        return self.get_state()

    def generate_response(self, action):
        """Generate the chatbot reply conditioned on the selected keywords.

        Args:
            action: iterable of node indices into ``self.node_list``.

        Returns:
            The decoded chatbot reply.
        """
        keywords = [self.node_list[a] for a in action]
        print("[Keywords] ", keywords)

        # Last three utterances followed by the keywords, in the special-token layout fed to the
        # fine-tuned chatbot.
        concat_tokens = '<s>' + '</s><s>'.join(self.context[-3:]) + '<keyword> ' + '</keyword><keyword>'.join(keywords) + '</keyword>'
        input_tokenized = self.chatbot_tokenizer(concat_tokens, padding='max_length', truncation=True,
                                                 max_length=128, return_tensors='pt')
        response = self.chatbot_model.generate(input_ids=input_tokenized.input_ids.to(self.device),
                                               attention_mask=input_tokenized.attention_mask.to(self.device))
        return self.chatbot_tokenizer.decode(response[0], skip_special_tokens=True)

    def greedy_action(self, state):
        """Return the index of the valid candidate node with the shortest path to the target.

        Valid candidates are nodes with ``state[3] == 1``, the same convention used when the
        policy masks invalid actions. Blacklisted nodes are excluded unless that would leave
        nothing to choose from. Only path lengths below 10 are considered, and nodes with no
        path to the target are skipped.

        Returns:
            A node index, or ``None`` if no candidate qualifies.
        """
        candidate_mask = (state[3] == 1).reshape(-1).float()
        # Keep both masks on the same device; mixing CPU and CUDA tensors raises.
        allowed_mask = candidate_mask * self.blacklist_mask.to(candidate_mask.device).float()
        if (allowed_mask == 1).any():
            candidate_mask = allowed_mask

        best_length = 10
        best_action = None
        for idx, node in enumerate(self.node_list):
            if candidate_mask[idx] != 1:
                continue
            try:
                length = nx.shortest_path_length(self.global_graph, node, self.target)
            except nx.NetworkXNoPath:
                continue
            if length < best_length:
                best_length = length
                best_action = idx
        return best_action

    def step(self, action):
        """Play one dialogue turn.

        The chatbot replies using the chosen keywords, and the user bot answers the last three
        utterances. The user's reply is scored by ``reward``.

        Returns:
            ``(next_state, reward, done)``, where ``next_state`` comes from ``get_state``.
        """
        sentence = self.generate_response(action)
        self.context.append(sentence)
        concat_context = ' <s> ' + ' </s> <s> '.join(self.context[-3:]) + ' </s> '

        input_tokenized = self.user_tokenizer(concat_context, padding='max_length', truncation=True,
                                              max_length=128, return_tensors='pt')
        user_response_tokens = self.user_model.generate(input_ids=input_tokenized.input_ids.to(self.device),
                                                        attention_mask=input_tokenized.attention_mask.to(self.device))
        user_response_text = self.user_tokenizer.decode(user_response_tokens[0], skip_special_tokens=True)

        user_concepts = word_tokenize(user_response_text)

        self.context.append(user_response_text)
        reward = self.reward(user_response_text)

        done = self.is_complete(user_concepts)
        self.done = done

        return self.get_state(), reward, done

    def get_state(self):
        """Build the agent's observation for the current context.

        The candidate mask marks nodes reachable from the "bridge" concepts. Bridge concepts are
        tokens of the last utterance that are graph nodes. If there are none, the filtered tokens
        of the previous utterance (or of the first utterance) are used instead.

        Returns:
            ``(G_global, G_target, context_x, candidate_mask)``:
              * ``G_global`` and ``G_target`` are the PyG graphs built in ``reset``,
              * ``context_x`` holds the embedding ids of all in-vocabulary context tokens,
              * ``candidate_mask`` is an int64 tensor with 1 for selectable nodes, in ``node_list`` order.
        """
        graph_nodes = list(self.global_graph.nodes)
        node_set = set(graph_nodes)
        vocab = self.conceptnet.concept_embedding.stoi

        bridge_concepts = list(set(word_tokenize(self.context[-1])) & node_set)
        if len(bridge_concepts) == 0:
            fallback_text = self.context[-2] if len(self.context) >= 2 else self.context[0]
            fallback_tokens = {token for token in word_tokenize(fallback_text)
                               if token in vocab and token.isalpha()}
            bridge_concepts = list(fallback_tokens & node_set)

        candidates = set(candidate_nodes(self.global_graph, bridge_concepts, 2))
        node_index = {node: i for i, node in enumerate(graph_nodes)}
        candidate_mask = torch.zeros(len(graph_nodes), dtype=int)
        candidate_mask[[node_index[node] for node in candidates]] = 1

        context_ids = [vocab[token]
                       for utterance in self.context
                       for token in word_tokenize(utterance)
                       if token in vocab]
        # Explicit dtype keeps the ids integral even when no token is in the vocabulary.
        context_x = torch.tensor(context_ids, dtype=torch.long).to(self.device)

        return self.G_global, self.G_target, context_x, candidate_mask

    def get_rl_state(self, G_global, G_target, context_x):
        """Look up embeddings for the global graph nodes, the target graph and the context.

        ``G_global`` and ``G_target`` may be graph objects exposing ``.x`` (as returned by
        ``get_state``) or node-id tensors directly.

        Returns:
            ``(global_node_embeddings, target_embedding_repeated_per_global_node, context_embeddings)``.
        """
        global_x = getattr(G_global, 'x', G_global)
        target_x = getattr(G_target, 'x', G_target)
        G_graph_embd = self.glove_embedding(global_x).to(self.device)
        G_target_embd = self.glove_embedding(target_x)[0].repeat(global_x.size(0), 1).to(self.device)
        context_ids = self.glove_embedding(context_x).to(self.device)
        return G_graph_embd, G_target_embd, context_ids

    def get_keyword(self, text):
        """Return the distinct tokens of ``text`` that are nodes of the episode graph."""
        return list({token for token in word_tokenize(text) if self.global_graph.has_node(token)})

    def reward(self, user_text):
        """Score a user reply; also updates ``self.topic_dist``.

        The reward is the classifier's sigmoid probability for the target topic, plus 10 when
        that probability exceeds 0.5 (see ``is_complete``).
        """
        r = 0

        topic_tokens = self.topic_tokenizer(user_text, return_tensors='pt', truncation=True)
        # Inference only: no autograd graph is needed for the reward model.
        with torch.no_grad():
            topic_output = self.topic_model(input_ids=topic_tokens.input_ids.to(self.device),
                                            attention_mask=topic_tokens.attention_mask.to(self.device))

        topic_scores = expit(topic_output[0][0].detach().cpu().numpy())

        self.topic_dist = topic_scores[self.target_topic]
        r += self.topic_dist

        if self.is_complete([]):
            r += 10

        return r

    def is_complete(self, user_concepts):
        """Return True once the latest scored reply has target-topic probability > 0.5.

        ``user_concepts`` is accepted for interface compatibility and is not used.
        """
        return self.topic_dist is not None and bool(self.topic_dist > 0.5)

In [None]:
# NOTE(review): none of these names is read by the cells in this notebook (`gamma` is re-assigned
# in the later config cell before use). They look like leftovers from an actor-critic setup —
# confirm before removing.
(lr, max_timesteps, gamma,
 epsilon_clip, value_coef, entropy_coef) = (1e-6, 30, 0.99, 0.2, 0.5, 0.01)

In [None]:
# Drop all self-loop edges from the ConceptNet graph.
edges_to_remove = list(nx.selfloop_edges(conceptnet.conceptnet))
conceptnet.conceptnet.remove_edges_from(edges_to_remove)

# Remove every blacklisted concept that occurs in the graph
# (`blacklist` is defined elsewhere in the notebook).
blacklist_nodes = [node for node in set(blacklist) if conceptnet.conceptnet.has_node(node)]
conceptnet.conceptnet.remove_nodes_from(blacklist_nodes)

## Actor-Critic

## D3QN

In [None]:
import subprocess

def get_gpu_usage(gpu_index=0):
    """Return the memory used on one GPU in MiB, as reported by ``nvidia-smi``.

    ``nvidia-smi`` prints one line per GPU. Previously the whole output was parsed as a single
    integer, which failed on multi-GPU machines.

    Args:
        gpu_index: which GPU's line to report (default: the first).

    Returns:
        The used memory as an ``int``, or ``None`` if the query or the parsing fails.
        An error message is printed in that case.
    """
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
                                stdout=subprocess.PIPE)
        per_gpu = result.stdout.decode('utf-8').strip().splitlines()
        return int(per_gpu[gpu_index].strip())
    except Exception as e:
        print("Error fetching GPU usage:", str(e))
        return None

gpu_usage = get_gpu_usage()
if gpu_usage is not None:
    print(f"GPU Memory Used: {gpu_usage} MiB")

In [None]:
# D3QN training configuration.
num_epochs, max_episode = 500, 20          # episodes to train; max dialogue turns per episode
batch_size, buffer_size = 64, 10000        # replay batch size; replay capacity
gamma, epsilon = 0.99, 0.9                 # discount; initial exploration rate (decayed per step)
target_update_freq = 5                     # target-network sync period
hidden_size, learning_rate = 64, 1e-4      # NOTE(review): hidden_size is not read by any cell here

In [None]:
load_model_path = 'model/cross_entropy_batch_32_lr_1e-05_lr_decay_8e-01_bdr_class_balanced_1_GAT_original_2_hop/model_epoch666.pth'
keyword_predictor = KeywordPredictor()
# Map checkpoint tensors straight onto the predictor's device, so the checkpoint still loads if it
# was saved from a different GPU. Only load checkpoints you trust: torch.load unpickles.
state_dict = torch.load(load_model_path, map_location=keyword_predictor.device)
keyword_predictor.load_state_dict(state_dict)

In [None]:
# Online (policy) network, optimised by the D3QN training loop below.
policy_net = D3QN()

In [None]:
# Target network for the TD targets: starts as a copy of the policy network.
target_net = D3QN()
target_net.load_state_dict(policy_net.state_dict())
# Only its outputs are read (and detached), so freeze its parameters. Autograd then does not track
# target-network forward passes; load_state_dict keeps this flag on later syncs.
target_net.requires_grad_(False)
target_net.eval()

In [None]:
# Learning rate comes from the config cell so it is tuned in one place.
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

In [None]:
# Experience replay of (state, action, reward, next_state, done) transitions, capacity `buffer_size`.
replay_buffer = ReplayBuffer(buffer_size)

In [None]:
# Warm-start the policy network from the pretrained keyword predictor: its GAT graph encoder,
# its context GRU, and its MLP head (which becomes the advantage stream). This relies on both
# models' default layer sizes matching. The state-value stream (QNet) keeps its random init.
policy_net.shared_layer.graph_encoder.GCN.load_state_dict(keyword_predictor.GCN.state_dict())
policy_net.shared_layer.context_encoder.GRU.load_state_dict(keyword_predictor.GRU.state_dict())
policy_net.adv.fc.load_state_dict(keyword_predictor.fc.state_dict())

In [None]:
# Drop every reference to the pretrained predictor so its (possibly GPU-resident) memory can be
# reclaimed. `state_dict` still points at the loaded checkpoint tensors, so clear it too.
import gc

keyword_predictor = None
state_dict = None
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the dialogue environment. It stores `policy_net` as `d3qn`, but none of the Simulator
# methods defined above call it; actions are chosen by the training loop.
env = Simulator(policy_net, glove_embedding, conceptnet)

In [None]:
# Release unused cached CUDA memory blocks held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
import nltk

In [None]:
# Report GPU memory after all models have been loaded.
gpu_usage = get_gpu_usage()
if gpu_usage is not None:
    print(f"GPU Memory Used: {gpu_usage} MiB")

In [None]:
path = 'starting.txt'

# One conversation opener per line. Read inside a `with` so the handle is closed.
with open(path, encoding='utf-8') as f:
    lines = f.readlines()

targets = []  # NOTE(review): never filled or read by any cell in this notebook.
# Blank lines are skipped: an empty opener gives `env.reset` no concepts to start from.
startings = [line.strip() for line in lines if line.strip()]

In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import random
from torch.cuda.amp import autocast, GradScaler
import time


writer = SummaryWriter('./logs/{}'.format('OriginalModelTopicRemoveBlacklist'))
pbar = tqdm(range(num_epochs))
updated_epoch = 0   # number of gradient updates so far (x-axis of the loss / buffer / GPU logs)
global_step = 0     # number of environment steps so far (x-axis of the Q-value log)
mse_loss = nn.MSELoss()
grad_scaler = GradScaler()
topic_idx = 0
num_topics = len(env.topic_words)


def td_loss(batch):
    """Double-DQN temporal-difference loss over a list of replay transitions.

    Each transition is ``(state, action, reward, next_state, done)``, with states as returned by
    ``env.reset`` / ``env.step``. The online network picks the next actions (invalid nodes are
    pushed to -50), and the target network scores them.

    Returns:
        The scalar MSE between predicted Q-values and bootstrapped targets. Both sides are
        flattened to 1-D so the loss never broadcasts a ``(B, 1)`` tensor against a ``(B,)`` one.
    """
    predicted, targets = [], []
    for _state, _action, _reward, _next_state, _done in batch:
        q_values = policy_net(_state[0], _state[1], _state[2]).reshape(-1)
        predicted.append(q_values[_action].reshape(-1))

        # The bootstrap target is a constant: no gradients through either network here.
        with torch.no_grad():
            next_q_target = target_net(_next_state[0], _next_state[1], _next_state[2]).reshape(-1)
            next_q_policy = policy_net(_next_state[0], _next_state[1], _next_state[2]).reshape(-1)
            next_q_policy[(_next_state[3] == 0).reshape(-1)] = -50
            next_actions = next_q_policy.topk(len(_action)).indices
            reward_t = torch.tensor([float(_reward)], device=next_q_target.device)
            bootstrap = gamma * (1 - float(_done)) * next_q_target[next_actions]
            targets.append((reward_t + bootstrap).reshape(-1))

    return mse_loss(torch.cat(predicted), torch.cat(targets))


for epoch in pbar:
    # Cycle through all target topics, one per episode.
    topic_idx = (topic_idx + 1) % num_topics
    starting_idx = random.randint(0, len(startings) - 1)

    print("[Target] ", env.topic_words[topic_idx])

    state = env.reset(startings[starting_idx], topic_idx)
    done = False
    total_reward = 0
    episode = 0

    while not done and episode < max_episode:
        print(episode)
        global_step += 1

        epsilon = max(epsilon * 0.99, 0.1)  # decay exploration once per environment step
        if np.random.random() < epsilon:
            valid_actions = [idx for idx, elem in enumerate(env.node_list) if state[3][idx] == 1]
            action = random.sample(valid_actions, k=min(1, len(valid_actions)))
        else:
            valid_mask = state[3] == 1
            with torch.no_grad():
                q_values = policy_net.predict(state[0], state[1], state[2]).reshape(-1)
                q_values[state[3] == 0] = -50
                _, action = q_values.topk(min(1, int(valid_mask.sum())))
                if valid_mask.any():
                    writer.add_scalar('Q_value', q_values[valid_mask].max().item(), global_step)
        print(action)

        next_state, reward, done = env.step(action)
        replay_buffer.add((state, action, reward, next_state, done))
        total_reward += reward

        if len(replay_buffer.buffer) >= batch_size:
            policy_net.train()
            optimizer.zero_grad()

            batch = replay_buffer.sample(batch_size)
            with autocast():
                loss = td_loss(batch)
            print(loss)

            # Backward and optimizer step run outside autocast, as the AMP docs recommend.
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            writer.add_scalar('policy/loss', loss.item(), updated_epoch)
            writer.add_scalar('buffer', len(replay_buffer.buffer), updated_epoch)
            gpu_usage = get_gpu_usage()
            if gpu_usage is not None:  # add_scalar cannot log None
                writer.add_scalar('GPU', gpu_usage, updated_epoch)
            updated_epoch += 1

        state = next_state
        episode += 1

    # Sync the target network once every `target_update_freq` episodes. Previously this ran on
    # every step of those episodes, so the target tracked the policy for the whole episode.
    if epoch % target_update_freq == 0:
        target_net.load_state_dict(policy_net.state_dict())

    pbar.set_description(f"Epoch {epoch+1}, Total Reward: {total_reward}, Done: {done}, Buffer: {len(replay_buffer.buffer)}")
    writer.add_scalar('reward', total_reward, epoch)

writer.flush()

In [None]:
# torch.save does not create missing directories, so make sure the target folder exists.
checkpoint_path = './model/D3GQN/D3GQN_topic_glove.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(policy_net.state_dict(), checkpoint_path)

In [None]:
import pickle

# Despite the .txt extension this is a binary pickle (read it back with mode 'rb').
# Only unpickle it if you trust where the file came from.
with open('replayBuffer.txt', 'wb') as buffer_file:
    pickle.dump(replay_buffer, buffer_file)

## Evaluation

In [None]:
# NOTE(review): this loads 'D3GQN_topic_cycle.pth', while the training section above saves
# 'D3GQN_topic_glove.pth'. Confirm which checkpoint is meant to be evaluated.
policy_net_state_dict = torch.load('./model/D3GQN/D3GQN_topic_cycle.pth')

In [None]:
# Overwrite the in-memory policy network with the evaluation checkpoint.
policy_net.load_state_dict(policy_net_state_dict)

In [None]:
from tqdm import tqdm
import nltk

# Per-topic evaluation: for each target topic, run `test_epochs` greedy episodes from randomly
# sampled openers. Record the success rate and the mean number of turns of successful episodes.
valid_steps_list = []
achieve_rate_list = []
test_epochs = 50

policy_net.eval()
# `with` guarantees the log is flushed and closed even if an episode raises.
with open('eval_log.txt', 'w') as f:
    # NOTE(review): only topics 14-18 are evaluated here; widen the range to evaluate all topics.
    for t in range(14, 19):
        pbar = tqdm(range(test_epochs))
        topic_idx = t
        print("[Target] ", env.topic_words[topic_idx])
        f.write("[Target] " + env.topic_words[topic_idx] + '\n')

        task_completion = []
        steps = []
        starting_sampled = random.sample(startings, test_epochs)

        for epoch in pbar:
            starting = starting_sampled[epoch]
            print('[Starting]', starting)

            state = env.reset(starting, topic_idx)
            done = False
            total_reward = 0
            episode = 0

            while not done and episode < max_episode:
                with torch.no_grad():
                    q_values = policy_net.predict(state[0], state[1], state[2]).reshape(-1)
                    q_values[state[3] == 0] = -50
                    # NOTE(review): evaluation picks up to 3 keywords, training picks 1 — confirm intended.
                    _, action = q_values.topk(min(3, int((state[3] == 1).sum())))

                next_state, reward, done = env.step(action)

                total_reward += reward

                state = next_state
                episode += 1

            task_completion.append(done)
            steps.append(episode)

            print(starting)
            print(f"Epoch {epoch+1}, Total Reward: {total_reward}, Done: {done}, Steps: {episode}")
            f.write("[Starting]" + starting + '\n')
            f.write(f"Epoch {epoch+1}, Total Reward: {total_reward}, Done: {done}, Steps: {episode}" + '\n')

        valid_steps = [step for step, completed in zip(steps, task_completion) if completed]
        # With no successful episode the mean is undefined: report NaN instead of dividing by zero.
        mean_achieve_steps = sum(valid_steps) / len(valid_steps) if valid_steps else float('nan')
        achieve_rate = sum(task_completion) / len(task_completion)
        print(mean_achieve_steps, achieve_rate)

        valid_steps_list.append(mean_achieve_steps)
        achieve_rate_list.append(achieve_rate)

In [None]:
# Per evaluated topic: mean turns of successful episodes, and success rate.
print(valid_steps_list, achieve_rate_list)

## Baseline: BlenderBot-3B self-chat without keyword guidance

In [None]:
# Baseline: an un-guided BlenderBot-3B, used below to play both the chatbot and the user.
model_name = 'facebook/blenderbot-3B'
model = BlenderbotForConditionalGeneration.from_pretrained(model_name).cuda()
tokenizer = BlenderbotTokenizer.from_pretrained(model_name)

In [None]:
# Same multi-label topic classifier the environment uses for its reward.
# Keep the hub id in its own name so re-running this cell does not pass a model object to from_pretrained.
topic_model_name = "cardiffnlp/tweet-topic-21-multi"
topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name).cuda()

In [None]:
from tqdm import tqdm

def blenderbot_reply(context):
    """Generate the next baseline BlenderBot turn from the last three utterances of ``context``."""
    concat_context = ' <s> ' + ' </s> <s> '.join(context[-3:]) + ' </s> '
    input_tokenized = tokenizer(concat_context, padding='max_length', truncation=True,
                                max_length=128, return_tensors='pt')
    reply_tokens = model.generate(input_ids=input_tokenized.input_ids.cuda(),
                                  attention_mask=input_tokenized.attention_mask.cuda())
    return tokenizer.decode(reply_tokens[0], skip_special_tokens=True)


# For 50 random openers, let the baseline talk to itself for 20 exchanges. Store the topic
# probabilities of every user-side reply: topic_list[opener][turn] is a per-topic score vector.
topic_list = []
for starting in tqdm(random.sample(startings, 50)):
    episode = 0

    context = [starting]
    topics = []
    print(starting)
    while episode < 20:
        episode += 1
        chatbot_response_text = blenderbot_reply(context)
        context.append(chatbot_response_text)

        user_response_text = blenderbot_reply(context)
        context.append(user_response_text)

        topic_tokens = topic_tokenizer(user_response_text, return_tensors='pt', truncation=True)
        # Inference only: skip autograd bookkeeping for the classifier.
        with torch.no_grad():
            topic_output = topic_model(input_ids=topic_tokens.input_ids.cuda(),
                                       attention_mask=topic_tokens.attention_mask.cuda())
        topic_scores = expit(topic_output[0][0].detach().cpu().numpy())

        print("[Chatbot]", chatbot_response_text)
        print("[Userbot]", user_response_text)

        topics.append(topic_scores)
    topic_list.append(topics)
    print()

In [None]:
import numpy as np

topic_words = np.array(['arts', 'business', 'celebrity', 'diaries', 'family', 'fashion', 'film', 'fitness', 'food', 'gaming', 'education', 'music',
               'news', 'hobbies', 'relationships', 'science', 'sports', 'travel', 'student'])

print(len(topic_list), len(topic_list[0]), topic_list[0][0].shape)

# Count, per topic, how many scored user replies had probability > 0.5 for that topic.
topic_achivement = {}
for topics in topic_list:
    for scores in topics:
        for tw in topic_words[scores > 0.5]:
            topic_achivement[tw] = topic_achivement.get(tw, 0) + 1

# Normalise by the number of replies actually scored above. This replaces the hard-coded 1107,
# which does not match the 50 x 20 replies the baseline cell produces.
total_turns = sum(len(topics) for topics in topic_list)
[(k, topic_achivement[k] / total_turns) for k in topic_achivement]