In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import os
from nltk.tokenize import word_tokenize
import nltk

In [2]:
def to_single_message_format(gamefile):
    """Flatten a conversation-per-line JSONL file into one record per message.

    Each non-blank line of ``gamefile`` is parsed as a JSON object describing
    one conversation. The object must hold the parallel lists ``messages``,
    ``sender_labels``, ``receiver_labels``, ``game_score_delta``, ``speakers``,
    ``receivers``, ``relative_message_index`` and ``absolute_message_index``.

    Args:
        gamefile: Path to the JSONL file to read.

    Returns:
        A list of dicts, one per message, in file order. ``conversation_id``
        is the zero-based position of the conversation among the non-blank
        lines of the file; ``score_delta`` is converted with ``int()``.

    Raises:
        KeyError: If a conversation lacks one of the expected lists.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    per_message_fields = (
        'messages', 'sender_labels', 'receiver_labels', 'game_score_delta',
        'speakers', 'receivers', 'relative_message_index', 'absolute_message_index',
    )

    messages = []
    conversation_id = 0
    # Decode explicitly as UTF-8: the platform default (e.g. cp1252 on Windows)
    # can fail or mis-decode non-ASCII message text such as emoji.
    with open(gamefile, encoding='utf-8') as inh:
        for ln in inh:
            # Tolerate blank/trailing lines instead of crashing in json.loads;
            # they do not consume a conversation id.
            if not ln.strip():
                continue
            conversation = json.loads(ln)
            columns = [conversation[field] for field in per_message_fields]

            # zip stops at the shortest list, so a conversation with ragged
            # lists contributes only its fully-populated messages.
            for (msg, sender_label, receiver_label, score_delta,
                 speaker, receiver, rel_index, abs_index) in zip(*columns):
                messages.append({
                    'conversation_id': conversation_id,
                    'message': msg,
                    'receiver_annotation': receiver_label,
                    'sender_annotation': sender_label,
                    'score_delta': int(score_delta),
                    'speaker': speaker,
                    'receiver': receiver,
                    'relative_message_index': rel_index,
                    'absolute_message_index': abs_index
                })
            conversation_id += 1

    return messages

In [3]:
def save_to_csv(gamefile, output_file):
    """Flatten ``gamefile`` into per-message rows, clean them and write a CSV.

    Steps, in order:
      1. Build a DataFrame from ``to_single_message_format(gamefile)``.
      2. Replace newline and tab characters inside each message with a space.
      3. Cast ``sender_annotation`` to ``int``.
      4. Sort by ``conversation_id`` then ``absolute_message_index``.
      5. Drop rows whose ``receiver_annotation`` is ``'NOANNOTATION'`` while
         ``sender_annotation`` equals 1.
      6. Write the result to ``output_file`` without the index column and
         print how many rows were saved.

    Returns:
        The cleaned DataFrame that was written to disk.
    """
    frame = pd.DataFrame(to_single_message_format(gamefile))

    # Keep every message on a single physical line of text.
    frame['message'] = frame['message'].apply(
        lambda text: text.replace('\n', ' ').replace('\t', ' ')
    )
    frame['sender_annotation'] = frame['sender_annotation'].astype(int)

    frame = frame.sort_values(by=['conversation_id', 'absolute_message_index'])

    # Rows lacking a receiver annotation are removed only when the sender
    # annotation is 1; the remaining un-annotated rows are kept.
    receiver_missing = frame['receiver_annotation'] == 'NOANNOTATION'
    sender_is_one = frame['sender_annotation'] == 1
    frame = frame[~(receiver_missing & sender_is_one)]

    frame.to_csv(output_file, index=False)
    number = len(frame)
    print(f"Saved {number} messages to {output_file}")

    return frame

In [4]:
# Convert the training split to the flat per-message CSV. save_to_csv returns
# the cleaned DataFrame, which is shown as this cell's output.
save_to_csv('data/train.jsonl', 'data/train_new_format.csv')

Saved 12071 messages to data/train_new_format.csv


Unnamed: 0,conversation_id,message,receiver_annotation,sender_annotation,score_delta,speaker,receiver,relative_message_index,absolute_message_index
0,0,Germany! Just the person I want to speak with...,True,1,0,italy,germany,0,74
1,0,"You've whet my appetite, Italy. What's the sug...",True,1,0,germany,italy,1,76
2,0,👍,True,1,0,italy,germany,2,86
3,0,It seems like there are a lot of ways that cou...,True,1,0,germany,italy,3,87
7,0,"Sorry Italy I've been away doing, um, German t...",True,1,0,germany,italy,7,117
...,...,...,...,...,...,...,...,...,...
13127,187,Is there any way of me actually ending this co...,True,1,-1,france,england,11,380
13128,187,Can we agree on peace? What are your demands?,True,1,-1,france,england,12,433
13129,187,"Neutrality in exchange for current holdings, S...",False,1,1,england,france,13,434
13130,187,"Thats a bit too much, can I keep Spain and i h...",True,1,-1,france,england,14,437


In [5]:
# Convert the validation split to the flat per-message CSV. save_to_csv returns
# the cleaned DataFrame, which is shown as this cell's output.
save_to_csv('data/validation.jsonl', 'data/val_new_format.csv')

Saved 1289 messages to data/val_new_format.csv


Unnamed: 0,conversation_id,message,receiver_annotation,sender_annotation,score_delta,speaker,receiver,relative_message_index,absolute_message_index
0,0,"Good afternoon to our friends in the south, ju...",True,1,0,germany,italy,0,6
1,0,Of course! I thank you very much. The future o...,True,1,0,italy,germany,1,11
2,0,"Well, should the French turn out to be a thorn...",True,1,0,germany,italy,2,14
3,0,Greetings from new Italy!,True,1,0,italy,germany,3,96
4,0,Would just like to reiterate that Italy would ...,True,1,0,italy,germany,4,97
...,...,...,...,...,...,...,...,...,...
1406,20,"Well, they’re not off to a good start!",True,1,-2,turkey,france,5,325
1407,20,true.,True,1,2,france,turkey,6,350
1408,20,Oh brother ... What's stopping Germany from wa...,True,1,-2,turkey,france,7,1041
1409,20,"Ok, I didn't think Germany would betray me but...",NOANNOTATION,0,2,france,turkey,8,1119


In [6]:
# Convert the test split to the flat per-message CSV. save_to_csv returns
# the cleaned DataFrame, which is shown as this cell's output.
save_to_csv('data/test.jsonl', 'data/test_new_format.csv')

Saved 2508 messages to data/test_new_format.csv


Unnamed: 0,conversation_id,message,receiver_annotation,sender_annotation,score_delta,speaker,receiver,relative_message_index,absolute_message_index
0,0,"Hi Italy! Just opening up communication, and I...",True,1,0,germany,italy,0,87
1,0,"Well....that's a great question, and a lot of ...",True,1,0,italy,germany,1,132
2,0,"Well, if you want to attack France in the Medi...",False,1,0,germany,italy,2,138
3,0,"Hello, I'm just asking about your move to Tyro...",True,1,1,germany,italy,3,207
4,0,Totally understandable - but did you notice th...,False,1,-1,italy,germany,4,221
...,...,...,...,...,...,...,...,...,...
2736,41,"Interesting, I didn't mean to take Naples- I f...",True,0,8,france,turkey,82,1498
2737,41,Interesting choice to move to Albania as it me...,True,1,8,france,turkey,83,1500
2738,41,*Austria can retreat to Greece,True,1,8,france,turkey,84,1501
2739,41,"This game is over, spending more than 2 minute...",True,1,-8,turkey,france,85,1518


In [2]:
# Download the 'punkt' NLTK data package (a no-op message is logged when it is
# already present). Presumably this is the tokenizer model word_tokenize needs.
# NOTE(review): newer NLTK releases may look for 'punkt_tab' instead - confirm
# against the installed NLTK version.
nltk.download('punkt')

def load_glove_embeddings(glove_file_path):
    """Read a whitespace-separated embeddings text file into a dict.

    Every line is split on whitespace: the first field is used as the token
    and the remaining fields are converted to a float32 NumPy vector. When a
    token appears on several lines, the last occurrence is kept.

    Args:
        glove_file_path: Path to the UTF-8 encoded embeddings file.

    Returns:
        dict mapping token (str) to its vector (np.ndarray, dtype float32).
    """
    print("Loading Glove embdeddings rn")
    table = {}
    with open(glove_file_path, 'r', encoding='utf-8') as handle:
        for row in handle:
            fields = row.split()
            token, coefficients = fields[0], fields[1:]
            table[token] = np.asarray(coefficients, dtype='float32')
    return table

def message_to_embedding(message, glove_embeddings, embedding_dim=300):
    """Represent a message as the mean of its tokens' embedding vectors.

    The message is lower-cased, tokenized with ``word_tokenize`` and every
    token present in ``glove_embeddings`` contributes its vector to an
    element-wise mean.

    Args:
        message: The message text. Non-string values (e.g. the float NaN that
            ``pd.read_csv`` produces for empty or "NA"-like cells, or None)
            are treated as a message with no known tokens.
        glove_embeddings: Mapping from token to its embedding vector.
        embedding_dim: Length of the zero vector returned when no token of the
            message has an embedding. Should match the vectors' length.

    Returns:
        np.ndarray: the mean vector, or ``np.zeros(embedding_dim)`` when the
        message is not a string or none of its tokens is in the vocabulary.
    """
    # A missing message has nothing to tokenize; calling .lower() on it would
    # raise AttributeError, so fall back to the same zero vector as for a
    # message made only of out-of-vocabulary tokens.
    if not isinstance(message, str):
        return np.zeros(embedding_dim)

    vectors = [
        glove_embeddings[token]
        for token in word_tokenize(message.lower())
        if token in glove_embeddings
    ]
    if not vectors:  # no token had an embedding
        return np.zeros(embedding_dim)
    return np.mean(vectors, axis=0)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\vimal\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [20]:
def process_diplomacy_data(input_csv, glove_embeddings, output_file, context_window=2):
    """Embed every message of a flattened CSV together with its preceding context.

    The CSV is grouped by ``conversation_id``; inside each conversation the
    rows are ordered by ``absolute_message_index``. For every message the
    function stores:

    * ``message_embedding``: ``message_to_embedding`` of the message itself;
    * ``context_embeddings``: a list of exactly ``context_window`` vectors for
      the immediately preceding messages of the same conversation, left-padded
      with zero vectors when fewer predecessors exist;
    * ``label``: the row's ``sender_annotation`` value.

    Args:
        input_csv: Path of the CSV to read (needs the columns
            ``conversation_id``, ``absolute_message_index``, ``message`` and
            ``sender_annotation``).
        glove_embeddings: Non-empty mapping from token to embedding vector;
            the embedding dimension is taken from its first value.
        output_file: Destination passed to ``DataFrame.to_hdf`` (key
            ``'diplomacy_data'``, mode ``'w'``). When it is a filesystem path,
            its parent directory is created if missing.
        context_window: Number of preceding messages used as context.

    Returns:
        pd.DataFrame with the three columns described above, one row per
        message.

    Raises:
        ValueError: If ``glove_embeddings`` is empty.
    """
    if not glove_embeddings:
        raise ValueError("glove_embeddings is empty; cannot infer the embedding dimension")
    embedding_dim = len(next(iter(glove_embeddings.values())))

    df = pd.read_csv(input_csv)
    # read_csv turns empty and "NA"-like cells into float NaN, which cannot be
    # tokenized; treat such messages as empty text.
    df['message'] = df['message'].fillna('')

    processed_data = {
        'message_embedding': [],
        'context_embeddings': [],
        'label': []
    }

    count = 0
    for conv_id, group in df.groupby('conversation_id'):
        # Chronological order with a fresh 0..n-1 index so positions and
        # labels coincide.
        group = group.sort_values('absolute_message_index').reset_index(drop=True)

        # Embed each message once per conversation and reuse the vectors as
        # context for the following messages. The inferred dimension is passed
        # along so the out-of-vocabulary fallback matches the vectors' length.
        conv_embeddings = [
            message_to_embedding(text, glove_embeddings, embedding_dim)
            for text in group['message']
        ]

        for idx, msg_embedding in enumerate(conv_embeddings):
            first_context = max(0, idx - context_window)
            # Copies keep context entries independent from the arrays stored
            # in the 'message_embedding' column.
            context_embeddings = [vec.copy() for vec in conv_embeddings[first_context:idx]]

            # Left-pad so every sample carries exactly context_window vectors.
            missing = context_window - len(context_embeddings)
            context_embeddings = (
                [np.zeros(embedding_dim) for _ in range(missing)] + context_embeddings
            )

            processed_data['message_embedding'].append(msg_embedding)
            processed_data['context_embeddings'].append(context_embeddings)
            processed_data['label'].append(group.loc[idx, 'sender_annotation'])
            if count % 100 == 0:
                print(f"Processed {count} messages in conversation {conv_id}")
            count += 1

    result_df = pd.DataFrame(processed_data)

    # Make sure the destination directory exists before writing.
    if isinstance(output_file, (str, os.PathLike)):
        out_dir = os.path.dirname(output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # The embedding columns hold Python objects, so PyTables pickles them
    # (this is what triggers the PerformanceWarning on save).
    result_df.to_hdf(output_file, key='diplomacy_data', mode='w')

    print(f"Processed data saved to {output_file}")
    print(f"Shape of message embeddings: {np.array(processed_data['message_embedding']).shape}")
    print(f"Shape of context embeddings: {np.array(processed_data['context_embeddings']).shape}")
    print(f"Total samples: {len(processed_data['label'])}")

    return result_df

In [21]:
# Location of the GloVe vectors text file, relative to the working directory.
glove_file = "glove.6B/glove.6B.300d.txt"
# Number of preceding messages handed to process_diplomacy_data as context.
context_window = 2   
# Load the vectors once; the processing cells below all reuse this dict.
glove_embeddings = load_glove_embeddings(glove_file)

Loading Glove embdeddings rn


In [None]:
# Build embeddings for the training split and store them as HDF5.
# input_csv / output_file are re-bound by the later test and validation cells.
# NOTE(review): the captured output of this cell reports
# "data/train_processed.h5", so it predates the path below - re-run the cell
# to refresh it.
input_csv = "data/train_new_format.csv"
output_file = "data/processed/train_processed.h5"


train_processed = process_diplomacy_data(input_csv, glove_embeddings, output_file, context_window)

Processed 0 messages in conversation 0
Processed 100 messages in conversation 0
Processed 200 messages in conversation 0
Processed 300 messages in conversation 0
Processed 400 messages in conversation 1
Processed 500 messages in conversation 2
Processed 600 messages in conversation 3
Processed 700 messages in conversation 3
Processed 800 messages in conversation 3
Processed 900 messages in conversation 3
Processed 1000 messages in conversation 4
Processed 1100 messages in conversation 5
Processed 1200 messages in conversation 7
Processed 1300 messages in conversation 8
Processed 1400 messages in conversation 8
Processed 1500 messages in conversation 8
Processed 1600 messages in conversation 8
Processed 1700 messages in conversation 8
Processed 1800 messages in conversation 8
Processed 1900 messages in conversation 10
Processed 2000 messages in conversation 10
Processed 2100 messages in conversation 11
Processed 2200 messages in conversation 13
Processed 2300 messages in conversation 15

your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['message_embedding', 'context_embeddings'], dtype='object')]

  result_df.to_hdf(output_file, key='diplomacy_data', mode='w')


Processed data saved to data/train_processed.h5
Shape of message embeddings: (12071, 300)
Shape of context embeddings: (12071, 2, 300)
Total samples: 12071


In [23]:
# Example of how to access the data
print("Sample of processed data:")
print("First message embedding shape:", train_processed['message_embedding'][0].shape)
print("First context embeddings shape:", np.array(train_processed['context_embeddings'][0]).shape)
print("First label:", train_processed['label'][0])

Sample of processed data:
First message embedding shape: (300,)
First context embeddings shape: (2, 300)
First label: 1


In [24]:
# Build embeddings for the test split and store them as HDF5.
# Re-binds the notebook-level input_csv / output_file names used by the
# training cell.
input_csv = "data/test_new_format.csv"
output_file = "data/processed/test_processed.h5"


test_processed = process_diplomacy_data(input_csv, glove_embeddings, output_file, context_window)

Processed 0 messages in conversation 0
Processed 100 messages in conversation 5
Processed 200 messages in conversation 15
Processed 300 messages in conversation 16
Processed 400 messages in conversation 16
Processed 500 messages in conversation 16
Processed 600 messages in conversation 17
Processed 700 messages in conversation 17
Processed 800 messages in conversation 17
Processed 900 messages in conversation 19
Processed 1000 messages in conversation 20
Processed 1100 messages in conversation 20
Processed 1200 messages in conversation 27
Processed 1300 messages in conversation 29
Processed 1400 messages in conversation 32
Processed 1500 messages in conversation 34
Processed 1600 messages in conversation 36
Processed 1700 messages in conversation 36
Processed 1800 messages in conversation 37
Processed 1900 messages in conversation 38
Processed 2000 messages in conversation 38
Processed 2100 messages in conversation 38
Processed 2200 messages in conversation 40
Processed 2300 messages i

your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['message_embedding', 'context_embeddings'], dtype='object')]

  result_df.to_hdf(output_file, key='diplomacy_data', mode='w')


In [25]:
# Build embeddings for the validation split and store them as HDF5.
# Re-binds the notebook-level input_csv / output_file names used by the
# earlier processing cells.
input_csv = "data/val_new_format.csv"
output_file = "data/processed/val_processed.h5"


val_processed = process_diplomacy_data(input_csv, glove_embeddings, output_file, context_window)

Processed 0 messages in conversation 0
Processed 100 messages in conversation 1
Processed 200 messages in conversation 1
Processed 300 messages in conversation 4
Processed 400 messages in conversation 5
Processed 500 messages in conversation 6
Processed 600 messages in conversation 6
Processed 700 messages in conversation 7
Processed 800 messages in conversation 9
Processed 900 messages in conversation 10
Processed 1000 messages in conversation 13
Processed 1100 messages in conversation 13
Processed 1200 messages in conversation 13
Processed data saved to data/processed/val_processed.h5
Shape of message embeddings: (1289, 300)
Shape of context embeddings: (1289, 2, 300)
Total samples: 1289


your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block0_values] [items->Index(['message_embedding', 'context_embeddings'], dtype='object')]

  result_df.to_hdf(output_file, key='diplomacy_data', mode='w')
