In [None]:
from huggingface_hub import login
login(token="")

In [9]:
import torch

if torch.cuda.is_available():
    print(f"CUDA is available! GPU: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available.")

CUDA is available! GPU: NVIDIA GeForce RTX 2080 Ti


In [10]:
#!/usr/bin/env python
# coding: utf-8

import networkx as nx
import random
from typing import List, Union
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import time
import csv
import os
import torch
import numpy as np
import re
import logging
import transformers
from transformers import AutoTokenizer, pipeline
import pandas as pd

In [11]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [12]:
model = "meta-llama/Llama-3.1-8B-Instruct"

class LlamaModel():
    def __init__(self):
        self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.tokenizer = None
        self.pipeline = None
        self.__load_model__()
        
    def __load_model__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=True)
        
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_name,
            tokenizer=self.tokenizer,
            torch_dtype=torch.float16,
            device_map="auto",
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            max_new_tokens=30,
        )
        
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        
        if self.pipeline is None:
            raise RuntimeError("Failed to initialize generation pipeline.")
    
    def generate_response(self, query):
        try:
            query = f"<s>[INST] {query} [/INST]"
            sequences = self.pipeline(
                query,
                truncation=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            generated_text = sequences[0]['generated_text']
            sequence_return = generated_text[len(query):].strip()
            sequence_return = sequence_return.replace("[INST]", "").replace("[/INST]", "").strip()
            return sequence_return
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "Error"

In [13]:
def generate_fukushima_event():
    return """
    The Fukushima Nuclear Disaster was a 2011 nuclear accident at the Daiichi Nuclear Power Plant in Fukushima, Japan. The cause of the nuclear disaster was the Tōhoku earthquake on March 11, 2011, the most powerful earthquake ever recorded in Japan. The earthquake triggered a tsunami with waves up to 130 feet tall, with 45 foot tall waves causing direct damage to the nuclear power plant. The damage inflicted dramatic harm both locally and globally. The damage caused radioactive isotopes in reactor coolant to discharge into the sea, therefore Japanese authorities quickly implemented a 100-foot exclusion zone around the power plant. Large quantities of radioactive particles were found shortly after throughout the Pacific Ocean and reached the California coast.
    The exclusion zone resulted in the displacement of approximately 156,000 people in years to follow. Independent commissions continue to recognize that affected residents are still struggling and facing grave concerns. Indeed, a WHO report predicts that infant girls exposed to the radiation are 70% more likely to develop thyroid cancer.
    The resulting energy shortage inspired media campaigns to encourage Japanese households and businesses to cut back on electrical usage, which led to the national movement Setsuden ("saving electricity"). The movement caused a dramatic decrease in the country's energy consumption during the crisis and later inspired the Japanese government to pass a battery of policies focused on reducing the energy consumption of large companies and households.
    """

class TimeStep:
    def __init__(self, observation, reward=None, terminal=False):
        self.observation = observation
        self.reward = reward
        self.terminal = terminal

def generate_hashtag_llama(player_name, event, llama, previous_guesses=None, current_round=1):
    """Generate a hashtag for a player based on the event and previous guesses."""
    
    # Table containing previous hashtags
    previous_table = "Round, Your Guess, Neighbor's Guess\n"
    
    # Add previous rounds' data
    if previous_guesses and current_round > 1:
        previous_rows = [f"{round_num}, {player_hashtag}, {neighbor_hashtag}" 
                        for round_num, player_hashtag, neighbor_hashtag in previous_guesses]
        previous_table += "\n".join(previous_rows)
    
    # Create prompt based on round number
    if current_round > 1:
        prompt = (
            f"In the experiment, you are awarded 1 point if you guess the same hashtag as your randomly-assigned neighbor, "
            f"and 0 points if you don't guess the same hashtag. Your goal is to earn as many points as possible.\n\n"
            f"You are in round {current_round} of the experiment. Your guesses and your neighbor's guesses have been as follows, "
            f"as represented in the CSV below:\n\n"
            f"{previous_table}\n\n"
            f"Based on this information and the event provided to you in round 1:\n{event}\n\n"
            f"Please guess a short (max 5 words) hashtag for this event with the goal of matching your randomly-assigned neighbor in this round. The hashtag should not contain any extra words."
            f"You can use your hashtag from the last round, but don't always use it—especially if you do not believe your next neighbor will have that hashtag."
        )
    else:
        prompt = (
            f"In the experiment, you are awarded 1 point if you guess the same hashtag as your randomly-assigned neighbor, "
            f"and 0 points if you don't guess the same hashtag. Your goal is to earn as many points as possible.\n\n"
            f"The event is as follows:\n{event}\n\n"
            f"Please guess a short (max 5 words) hashtag for this event. The hashtag should not contain any extra words."
        )

    # Generate response with retries
    max_retries = 3
    retries = 0
    
    while retries < max_retries:
        try:
            response = llama.generate_response(prompt)
            
            # Extract the hashtag
            hashtags = [word.strip(",.:*)([]") for word in response.split() if word.startswith("#") and len(word) > 1 and word != '##']
            
            if hashtags:
                chosen_hashtag = hashtags[0]
                # Clean up any special characters
                special = re.search(r"[^a-zA-Z0-9#]", chosen_hashtag)
                if special:
                    chosen_hashtag = chosen_hashtag[:special.start()]
                return chosen_hashtag
            
            retries += 1
            
        except Exception as e:
            logger.error(f"Error generating hashtag: {e}")
            retries += 1
    
    # Fallback hashtag if generation fails
    return f"#Error"

class FukushimaHashtagGame:
    def __init__(self, num_players=20, total_rounds=40):
        """Initialize the game with players and the LLM."""
        self.llama = LlamaModel()
        self.event = generate_fukushima_event()
        self.num_players = num_players
        self.players = [f"player_{i+1}" for i in range(num_players)]
        
        # Initialize the game
        self.total_rounds = total_rounds
        self.cur_round = 1
        self.turn = 0
        self.selected_hashtags = {}
        self.selected_hashtag_history = []
        self.scores = {player: 0 for player in self.players}
        self.round_scores = {}
        self.player_names = self.players.copy()
        self.previous_hashtags = []
        self.previous_neighbors = []
        self.cur_neighbor = {}
        self._terminal = False
        
        # Track convergence between players
        self.convergence_data = {
            "player_agreements": {player: {} for player in self.players},
            "hashtag_frequency": {},
            "round_agreement_rates": []
        }
        
        # Create the network graph
        self.network_graph = self._create_watts_strogatz_network()
        
        self.reset()

    def _create_watts_strogatz_network(self):
        """Create a Watts-Strogatz small-world network for players."""
        G = nx.watts_strogatz_graph(self.num_players, k=4, p=0.6)
        node_to_player = {i: player for i, player in enumerate(self.players)}
        player_graph = nx.Graph()
        
        # Add nodes with player attributes
        for i in range(self.num_players):
            player = node_to_player[i]
            player_graph.add_node(player, name=player)
        
        # Add edges from the Watts-Strogatz graph
        for i, j in G.edges():
            player_i = node_to_player[i]
            player_j = node_to_player[j]
            player_graph.add_edge(player_i, player_j)
        
        # Ensure the graph is connected
        if not nx.is_connected(player_graph):
            components = list(nx.connected_components(player_graph))
            for i in range(len(components) - 1):
                node1 = random.choice(list(components[i]))
                node2 = random.choice(list(components[i+1]))
                player_graph.add_edge(node1, node2)
        
        return player_graph

    def reset(self):
        """Reset the game for a new round."""
        self.player_names = self.players.copy()
        random.shuffle(self.player_names)
        self.turn = 0
        self._terminal = False
        self.round_scores = {player: 0 for player in self.players}
        self.scored_pairs = set()
        self.selected_hashtags.clear()
        
        # Store previous neighbor pairings
        if self.cur_neighbor:
            self.previous_neighbors.append(self.cur_neighbor.copy())
        
        # Create new neighbor pairings
        self.cur_neighbor = {}
        available_neighbors = self.player_names[:]
        while len(available_neighbors) > 1:
            player = available_neighbors.pop()
            neighbor = available_neighbors.pop()
            self.cur_neighbor[player] = neighbor
            self.cur_neighbor[neighbor] = player
            logger.info(f"{player} is paired with {neighbor}")
        
        # Handle odd number of players
        if available_neighbors:
            player = available_neighbors[0]
            neighbor = random.choice([p for p in self.player_names if p != player])
            self.cur_neighbor[player] = neighbor
            logger.info(f"{player} is paired with {neighbor} (repeat pairing)")
        
        logger.info(f"Round {self.cur_round}: Hashtag generation started")
        observation = self.get_observation(self.get_next_player())
        return TimeStep(observation=observation, reward=None, terminal=False)

    def get_observation(self, player_name=None):
        """Get a simplified observation for a player."""
        if player_name is None:
            return None
        
        neighbor = self.cur_neighbor.get(player_name)
        previous_guesses = self.get_neighbor_hashtags(player_name)
        
        return {
            "player": player_name,
            "neighbor": neighbor,
            "round": self.cur_round,
            "previous_guesses": previous_guesses,
            "event": self.event
        }

    def get_next_player(self):
        """Get the next player in the turn order."""
        return self.player_names[-1] if self.player_names else None

    def get_neighbor_hashtags(self, player_name):
        """Get the previous hashtags of the player and their neighbor."""
        previous_hashtags = []
        if self.cur_round > 1:
            for round_idx in range(self.cur_round - 1):
                if round_idx < len(self.previous_hashtags) and round_idx < len(self.previous_neighbors):
                    player_previous_hashtag = self.previous_hashtags[round_idx].get(player_name)
                    previous_neighbor = self.previous_neighbors[round_idx].get(player_name)
                    neighbor_previous_hashtag = self.previous_hashtags[round_idx].get(previous_neighbor)
                    if player_previous_hashtag and neighbor_previous_hashtag:
                        previous_hashtags.append((round_idx + 1, player_previous_hashtag, neighbor_previous_hashtag))
        return previous_hashtags

    def export_csv(self, filename="fukushima_hashtags_llama.csv"):
        """Export the game data to a CSV file."""
        write_header = not os.path.exists(filename)
        with open(filename, 'a', newline='') as file:
            writer = csv.writer(file)
            if write_header:
                writer.writerow(["Round", "Player", "Hashtag", "Matched", "Round Points", "Total Points"])
            
            if self.selected_hashtags:
                for player, hashtag in self.selected_hashtags.items():
                    neighbor = self.cur_neighbor.get(player)
                    matched = "Yes" if neighbor and self.selected_hashtags.get(neighbor) == hashtag else "No"
                    round_points = self.round_scores.get(player, 0)
                    total_points = self.scores.get(player, 0)
                    
                    writer.writerow([
                        self.cur_round,
                        player,
                        hashtag,
                        matched,
                        round_points,
                        total_points
                    ])

    def step(self, player_name, action):
        """Advance the game by one step."""
        neighbor = self.cur_neighbor.get(player_name)
        previous_guesses = self.get_neighbor_hashtags(player_name)
        
        # Generate hashtag for the player
        chosen_hashtag = generate_hashtag_llama(
            player_name,
            self.event, 
            self.llama, 
            previous_guesses, 
            self.cur_round
        )
        
        self.selected_hashtags[player_name] = chosen_hashtag
        logger.info(f"Player: {player_name}, Hashtag: {chosen_hashtag}")
        
        self.turn += 1
        if player_name in self.player_names:
            self.player_names.remove(player_name)
        
        # Check if all players have selected hashtags
        if not self.player_names:
            # Score all pairs at the end of the round
            for player, hashtag in self.selected_hashtags.items():
                neighbor = self.cur_neighbor.get(player)
                if neighbor and self.selected_hashtags.get(neighbor) == hashtag:
                    pair = tuple(sorted([player, neighbor]))
                    if pair not in self.scored_pairs:
                        round_points = 1
                        self.scores[player] += round_points
                        self.scores[neighbor] += round_points
                        self.round_scores[player] = round_points
                        self.round_scores[neighbor] = round_points
                        self.scored_pairs.add(pair)
            
            # Store hashtags for this round
            self.previous_hashtags.append(self.selected_hashtags.copy())
            self.selected_hashtag_history.append(self.selected_hashtags.copy())
            
            # Update convergence data
            self._update_convergence_data()
            
            # Export data and visualize network
            self.export_csv()
            logger.info(f"End of round {self.cur_round}")
            for player, score in self.scores.items():
                logger.info(f"{player}'s points this round: {self.round_scores[player]}, Total: {score}")
            self.show_network()
            
            # Check if we should continue or end the game
            if self.cur_round < self.total_rounds:
                self.cur_round += 1
                self.reset()
            else:
                self._terminal = True
                self.export_convergence_data()
                self.analyze_convergence()
                self.generate_entropy_graph()
        
        observation = self.get_observation(self.get_next_player())
        return TimeStep(observation=observation, reward=None, terminal=self._terminal)

    def _update_convergence_data(self):
        """Track how often players agree on hashtags and hashtag popularity."""
        # Track hashtag frequency
        for hashtag in self.selected_hashtags.values():
            if hashtag in self.convergence_data["hashtag_frequency"]:
                self.convergence_data["hashtag_frequency"][hashtag] += 1
            else:
                self.convergence_data["hashtag_frequency"][hashtag] = 1
        
        # Track agreements between players
        agreements_this_round = 0
        total_pairs = 0
        
        for player1 in self.players:
            for player2 in self.players:
                if player1 >= player2:  # Avoid counting pairs twice
                    continue
                    
                hashtag1 = self.selected_hashtags.get(player1)
                hashtag2 = self.selected_hashtags.get(player2)
                
                if hashtag1 and hashtag2:
                    total_pairs += 1
                    
                    # Initialize agreement tracking if needed
                    if player2 not in self.convergence_data["player_agreements"][player1]:
                        self.convergence_data["player_agreements"][player1][player2] = {
                            "agreement_rounds": [],
                            "total_rounds": 0
                        }
                    
                    # Update agreement data
                    self.convergence_data["player_agreements"][player1][player2]["total_rounds"] += 1
                    
                    if hashtag1 == hashtag2:
                        agreements_this_round += 1
                        self.convergence_data["player_agreements"][player1][player2]["agreement_rounds"].append(self.cur_round)
        
        # Calculate agreement rate for this round
        if total_pairs > 0:
            agreement_rate = agreements_this_round / total_pairs
        else:
            agreement_rate = 0
            
        self.convergence_data["round_agreement_rates"].append({
            "round": self.cur_round,
            "agreement_rate": agreement_rate,
            "agreements": agreements_this_round,
            "total_pairs": total_pairs
        })

    def export_convergence_data(self, filename="fukushima_convergence_data_llama.csv"):
        """Export convergence data including agreement rates between players."""
        # Export player-to-player agreement data
        with open(filename, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Player1", "Player2", "Agreement_Rounds", "Total_Rounds", "Agreement_Rate"])
            
            for player1 in self.players:
                for player2 in self.convergence_data["player_agreements"][player1]:
                    data = self.convergence_data["player_agreements"][player1][player2]
                    agreement_rounds = data["agreement_rounds"]
                    total_rounds = data["total_rounds"]
                    agreement_rate = len(agreement_rounds) / total_rounds if total_rounds > 0 else 0.0
                    
                    writer.writerow([
                        player1,
                        player2,
                        ", ".join(map(str, agreement_rounds)),
                        total_rounds,
                        f"{agreement_rate:.4f}"
                    ])
        
        # Export round-by-round agreement rates
        with open("fukushima_round_agreement_rates_llama.csv", 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Round", "Agreement_Rate", "Agreements", "Total_Pairs"])
            
            for data in self.convergence_data["round_agreement_rates"]:
                writer.writerow([
                    data["round"],
                    f"{data['agreement_rate']:.4f}",
                    data["agreements"],
                    data["total_pairs"]
                ])
        
        # Export hashtag frequency data
        with open("fukushima_hashtag_frequency_llama.csv", 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Hashtag", "Frequency"])
            
            # Sort hashtags by frequency
            sorted_hashtags = sorted(
                self.convergence_data["hashtag_frequency"].items(),
                key=lambda x: x[1],
                reverse=True
            )
            
            for hashtag, count in sorted_hashtags:
                writer.writerow([hashtag, count])
        
        logger.info(f"Exported convergence data to {filename}")

    def analyze_convergence(self):
        """Analyze convergence trends and create visualizations."""
        # Calculate overall agreement rate
        total_agreements = 0
        total_pairs = 0
        
        for data in self.convergence_data["round_agreement_rates"]:
            total_agreements += data["agreements"]
            total_pairs += data["total_pairs"]
        
        overall_agreement_rate = total_agreements / total_pairs if total_pairs > 0 else 0
        
        logger.info(f"Overall agreement rate: {overall_agreement_rate:.4f}")
        
        # Plot agreement rate over time
        plt.figure(figsize=(10, 6))
        rounds = [data["round"] for data in self.convergence_data["round_agreement_rates"]]
        rates = [data["agreement_rate"] for data in self.convergence_data["round_agreement_rates"]]
        
        plt.plot(rounds, rates, marker='o')
        plt.axhline(y=overall_agreement_rate, color='r', linestyle='--', label=f'Overall: {overall_agreement_rate:.4f}')
        plt.xlabel('Round')
        plt.ylabel('Agreement Rate')
        plt.title('Hashtag Agreement Rate Over Time')
        plt.legend()
        plt.grid(True)
        plt.savefig("fukushima_agreement_rate_llama.png")
        plt.close()
        
        # Plot hashtag popularity
        plt.figure(figsize=(12, 8))
        
        # Sort hashtags by frequency
        sorted_hashtags = sorted(
            self.convergence_data["hashtag_frequency"].items(),
            key=lambda x: x[1],
            reverse=True
        )
        
        # Take top 10 hashtags for clarity
        top_hashtags = sorted_hashtags[:10]
        
        labels = [hashtag for hashtag, _ in top_hashtags]
        values = [count for _, count in top_hashtags]
        
        plt.bar(labels, values)
        plt.xlabel('Hashtag')
        plt.ylabel('Frequency')
        plt.title('Top 10 Hashtag Frequencies')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig("fukushima_hashtag_frequency_llama.png")
        plt.close()

    def generate_entropy_graph(self):
        """Generate an entropy graph to visualize the diversity of hashtags over time."""
        entropy_values = []
        for round_hashtags in self.selected_hashtag_history:
            hashtag_counts = {}
            total_hashtags = 0
            for hashtag in round_hashtags.values():
                if hashtag in hashtag_counts:
                    hashtag_counts[hashtag] += 1
                else:
                    hashtag_counts[hashtag] = 1
                total_hashtags += 1
            entropy = 0
            for count in hashtag_counts.values():
                probability = count / total_hashtags
                entropy -= probability * np.log2(probability)
            entropy_values.append(entropy)
        
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, self.total_rounds + 1), entropy_values, marker='o')
        plt.xlabel('Round')
        plt.ylabel('Entropy (bits)')
        plt.title('Entropy of Hashtags Over Time')
        plt.grid(True)
        plt.savefig("fukushima_entropy_graph_llama.png")
        plt.close()

    def show_network(self):
        """Visualize the network graph with hashtag matching."""
        pos = nx.spring_layout(self.network_graph, seed=42)
        
        matched_edges = []
        edge_colors = []
        
        # Color edges based on hashtag matching
        for edge in self.network_graph.edges():
            player1, player2 = edge
            if (player1 in self.selected_hashtags and 
                player2 in self.selected_hashtags and 
                self.selected_hashtags[player1] == self.selected_hashtags[player2]):
                matched_edges.append(edge)
                edge_colors.append('red')
            else:
                edge_colors.append('gray')
        
        # Color nodes based on their hashtag
        unique_hashtags = list(set(self.selected_hashtags.values()))
        color_map = plt.cm.get_cmap('tab20', len(unique_hashtags))
        hashtag_to_color = {hashtag: color_map(i) for i, hashtag in enumerate(unique_hashtags)}
        
        node_colors = [hashtag_to_color.get(self.selected_hashtags.get(node), 'lightgray') for node in self.network_graph.nodes()]
        
        plt.figure(figsize=(12, 10))
        nx.draw_networkx_nodes(self.network_graph, pos, node_color=node_colors, node_size=300, alpha=0.8)
        nx.draw_networkx_edges(self.network_graph, pos, width=1.0, alpha=0.5, edge_color=edge_colors)
        
        # Add labels showing player and hashtag
        labels = {node: f"{node}\n{self.selected_hashtags.get(node, '')}" 
                 for node in self.network_graph.nodes()}
        nx.draw_networkx_labels(self.network_graph, pos, labels=labels, font_size=8)
        
        # Create legend for hashtags
        legend_elements = []
        for hashtag, color in hashtag_to_color.items():
            legend_elements.append(plt.Line2D([0], [0], marker='o', color='w', 
                          label=f"{hashtag}", markerfacecolor=color, markersize=10))
                          
        # Add legend for matching edges
        legend_elements.append(plt.Line2D([0], [0], color='red', lw=2, label='Matching Hashtags'))
        legend_elements.append(plt.Line2D([0], [0], color='gray', lw=2, label='Different Hashtags'))
        
        plt.legend(handles=legend_elements, title="Hashtags", loc="upper left", bbox_to_anchor=(1, 1))
        plt.title(f"Fukushima Hashtag Network - Round {self.cur_round}")
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(f"fukushima_network_round_{self.cur_round}_llama.png", bbox_inches="tight")
        plt.close()

In [14]:
def run_simulation():
    game = FukushimaHashtagGame(num_players=20, total_rounds=40)
    timestep = game.reset()
    
    while not timestep.terminal:
        player = game.get_next_player()
        if player:
            timestep = game.step(player, None)
    
    logger.info("Simulation complete!")

if __name__ == "__main__":
    run_simulation()



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

INFO:__main__:player_4 is paired with player_1
INFO:__main__:player_18 is paired with player_17
INFO:__main__:player_10 is paired with player_3
INFO:__main__:player_12 is paired with player_8
INFO:__main__:player_9 is paired with player_14
INFO:__main__:player_2 is paired with player_5
INFO:__main__:player_16 is paired with player_13
INFO:__main__:player_19 is paired with player_15
INFO:__main__:player_20 is paired with player_7
INFO:__main__:player_11 is paired with player_6
INFO:__main__:Round 1: Hashtag generation started
INFO:__main__:player_7 is paired with player_3
INFO:__main__:player_13 is paired with player_15
INFO:__main__:player_12 is paired with player_1
INFO:__main__:player_10 is paired with player_9
INFO:__main__:player_14 is paired with player_18
INFO:__main__:player_17 is paired with player_19
INFO:__main__:player_16 is paired with player_4
INFO:__main__:player_5 is paired with player_2
INFO:__main__:player_8 is paired with player_6
INFO:__main__:player_20 is paired wit