In [37]:
import json
from io import open

def preprocess_data(file_path):
    train_data = {}

    with open(file_path, 'r') as file:
        i = 0
        for line in file:
            # Parse the line as JSON
            train_data[i] = json.loads(line)
            i+=1

    inputs, targets = [], []

    for key, instance in train_data.items():
        if isinstance(instance, dict) and 'messages' in instance:
            messages = instance['messages']
            game_scores = instance['game_score']

        for idx, (message, score) in enumerate(zip(messages, game_scores)):
            start_index = max(0, idx-5)
            end_index = min(len(messages)-1, idx)
            context = " ".join(messages[start_index:end_index])
            inputs.append(f"Given the following context, generate an appropriate repsonse: {context}")
            targets.append(message)
    return inputs, targets

training_inputs, training_targets = preprocess_data("2020_acl_diplomacy/data/train.jsonl")
validation_inputs, validation_targets = preprocess_data("2020_acl_diplomacy/data/validation.jsonl")
test_inputs, test_targets = preprocess_data("2020_acl_diplomacy/data/test.jsonl")

print(len(training_inputs)) 
print(len(training_targets))
print(training_inputs[3])
print(training_targets[3])

13132
13132
Given the following context, generate an appropriate repsonse: Germany!

Just the person I want to speak with. I have a somewhat crazy idea that I’ve always wanted to try with I/G, but I’ve never actually convinced the other guy to try it. And, what’s worse, it might make you suspicious of me. 

So...do I suggest it?

I’m thinking that this is a low stakes game, not a tournament or anything, and an interesting and unusual move set might make it more fun? That’s my hope anyway.

What is your appetite like for unusual and crazy? You've whet my appetite, Italy. What's the suggestion? 👍
It seems like there are a lot of ways that could go wrong...I don't see why France would see you approaching/taking Munich--while I do nothing about it--and not immediately feel skittish


In [38]:
%pip install transformers
%pip install torch

from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset, DataLoader

class DiplomacyDataset(Dataset):
    def __init__(self, data, targets, tokenizer, max_len=512):
        self.data = data
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        tokenized_input = self.tokenizer.encode(self.data[index], max_length=self.max_len, truncation=True, padding='max_length', return_tensors='pt')
        tokenized_target = self.tokenizer.encode(self.targets[index], max_length=self.max_len, truncation=True, padding='max_length', return_tensors='pt')
        item = {
            'input_ids': tokenized_input.flatten(),
            # 'decoder_input_ids': tokenized_target.flatten()[:-1], # remove the last token of the target
            'labels': tokenized_target.flatten()[1:],
        }
        return item

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

train_dataset = DiplomacyDataset(training_inputs, training_targets, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

validation_dataset = DiplomacyDataset(validation_inputs, validation_targets, tokenizer)
validation_loader = DataLoader(validation_dataset, batch_size=4, shuffle=True)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [39]:
%pip install accelerate -U
import accelerate


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)


# for epoch in range(3):
#     model.train()
#     for batch in train_loader:
#         input_ids = batch['input_ids'].to(device)
#         labels = batch['labels'].to(device)

#         outputs = model(input_ids=input_ids, labels=labels)
#         loss = outputs[0]

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         print(f"Epoch: {epoch}, Loss: {loss.item()}")
        


training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    num_train_epochs=3,
    learning_rate=3e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset
)



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [40]:
trainer.train() 

  0%|          | 10/9849 [00:20<3:22:21,  1.23s/it]

{'loss': 2.3186, 'learning_rate': 5.999999999999999e-06, 'epoch': 0.0}


  0%|          | 20/9849 [00:31<3:02:11,  1.11s/it]

{'loss': 2.1954, 'learning_rate': 1.1999999999999999e-05, 'epoch': 0.01}


  0%|          | 30/9849 [00:43<3:07:30,  1.15s/it]

{'loss': 2.1135, 'learning_rate': 1.7999999999999997e-05, 'epoch': 0.01}


  0%|          | 40/9849 [00:54<3:09:26,  1.16s/it]

{'loss': 1.7398, 'learning_rate': 2.3999999999999997e-05, 'epoch': 0.01}


  1%|          | 50/9849 [01:06<3:07:53,  1.15s/it]

{'loss': 1.9124, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.02}


  1%|          | 60/9849 [01:17<3:06:15,  1.14s/it]

{'loss': 0.852, 'learning_rate': 3.5999999999999994e-05, 'epoch': 0.02}


  1%|          | 70/9849 [01:28<3:01:24,  1.11s/it]

{'loss': 0.4436, 'learning_rate': 4.2e-05, 'epoch': 0.02}


  1%|          | 80/9849 [01:40<3:03:38,  1.13s/it]

{'loss': 0.3931, 'learning_rate': 4.7999999999999994e-05, 'epoch': 0.02}


  1%|          | 90/9849 [01:51<3:00:51,  1.11s/it]

{'loss': 0.2479, 'learning_rate': 5.399999999999999e-05, 'epoch': 0.03}


  1%|          | 100/9849 [02:02<2:59:00,  1.10s/it]

{'loss': 0.8029, 'learning_rate': 5.9999999999999995e-05, 'epoch': 0.03}


  1%|          | 110/9849 [02:13<2:59:14,  1.10s/it]

{'loss': 0.3474, 'learning_rate': 6.599999999999999e-05, 'epoch': 0.03}


  1%|          | 120/9849 [02:24<3:00:00,  1.11s/it]

{'loss': 0.366, 'learning_rate': 7.199999999999999e-05, 'epoch': 0.04}


  1%|▏         | 130/9849 [02:36<2:57:35,  1.10s/it]

{'loss': 0.3028, 'learning_rate': 7.8e-05, 'epoch': 0.04}


  1%|▏         | 140/9849 [02:47<3:00:17,  1.11s/it]

{'loss': 0.3121, 'learning_rate': 8.4e-05, 'epoch': 0.04}


  2%|▏         | 150/9849 [02:58<3:00:41,  1.12s/it]

{'loss': 0.3043, 'learning_rate': 8.999999999999999e-05, 'epoch': 0.05}


  2%|▏         | 160/9849 [03:09<2:56:33,  1.09s/it]

{'loss': 0.2634, 'learning_rate': 9.599999999999999e-05, 'epoch': 0.05}


  2%|▏         | 170/9849 [03:21<3:00:39,  1.12s/it]

{'loss': 0.6878, 'learning_rate': 0.000102, 'epoch': 0.05}


  2%|▏         | 180/9849 [03:32<2:58:44,  1.11s/it]

{'loss': 0.242, 'learning_rate': 0.00010799999999999998, 'epoch': 0.05}


  2%|▏         | 190/9849 [03:43<2:55:59,  1.09s/it]

{'loss': 0.1979, 'learning_rate': 0.00011399999999999999, 'epoch': 0.06}


  2%|▏         | 200/9849 [03:54<2:55:31,  1.09s/it]

{'loss': 0.4965, 'learning_rate': 0.00011999999999999999, 'epoch': 0.06}


  2%|▏         | 210/9849 [04:05<2:56:15,  1.10s/it]

{'loss': 0.2829, 'learning_rate': 0.00012599999999999997, 'epoch': 0.06}


  2%|▏         | 220/9849 [04:17<2:57:51,  1.11s/it]

{'loss': 0.2175, 'learning_rate': 0.00013199999999999998, 'epoch': 0.07}


  2%|▏         | 230/9849 [04:28<2:57:08,  1.10s/it]

{'loss': 0.216, 'learning_rate': 0.000138, 'epoch': 0.07}


  2%|▏         | 240/9849 [04:39<2:58:43,  1.12s/it]

{'loss': 0.2191, 'learning_rate': 0.00014399999999999998, 'epoch': 0.07}


  3%|▎         | 250/9849 [04:50<2:56:18,  1.10s/it]

{'loss': 0.2363, 'learning_rate': 0.00015, 'epoch': 0.08}


  3%|▎         | 260/9849 [05:01<2:56:23,  1.10s/it]

{'loss': 0.2465, 'learning_rate': 0.000156, 'epoch': 0.08}


  3%|▎         | 270/9849 [05:12<2:54:40,  1.09s/it]

{'loss': 0.2149, 'learning_rate': 0.000162, 'epoch': 0.08}


  3%|▎         | 280/9849 [05:24<2:57:08,  1.11s/it]

{'loss': 0.3345, 'learning_rate': 0.000168, 'epoch': 0.09}


  3%|▎         | 290/9849 [05:35<2:56:47,  1.11s/it]

{'loss': 0.2939, 'learning_rate': 0.00017399999999999997, 'epoch': 0.09}


  3%|▎         | 300/9849 [05:46<2:55:35,  1.10s/it]

{'loss': 0.2098, 'learning_rate': 0.00017999999999999998, 'epoch': 0.09}


  3%|▎         | 310/9849 [05:57<2:54:18,  1.10s/it]

{'loss': 0.1919, 'learning_rate': 0.000186, 'epoch': 0.09}


  3%|▎         | 320/9849 [06:08<2:57:10,  1.12s/it]

{'loss': 0.1739, 'learning_rate': 0.00019199999999999998, 'epoch': 0.1}


  3%|▎         | 330/9849 [06:19<2:53:58,  1.10s/it]

{'loss': 0.2772, 'learning_rate': 0.000198, 'epoch': 0.1}


  3%|▎         | 340/9849 [06:30<2:55:34,  1.11s/it]

{'loss': 0.2117, 'learning_rate': 0.000204, 'epoch': 0.1}


  4%|▎         | 350/9849 [06:42<2:54:38,  1.10s/it]

{'loss': 0.1953, 'learning_rate': 0.00020999999999999998, 'epoch': 0.11}


  4%|▎         | 360/9849 [06:53<2:58:54,  1.13s/it]

{'loss': 0.2334, 'learning_rate': 0.00021599999999999996, 'epoch': 0.11}


  4%|▍         | 370/9849 [07:04<2:54:11,  1.10s/it]

{'loss': 0.2551, 'learning_rate': 0.00022199999999999998, 'epoch': 0.11}


  4%|▍         | 380/9849 [07:15<2:56:34,  1.12s/it]

{'loss': 0.2751, 'learning_rate': 0.00022799999999999999, 'epoch': 0.12}


  4%|▍         | 390/9849 [07:26<2:53:13,  1.10s/it]

{'loss': 0.2233, 'learning_rate': 0.000234, 'epoch': 0.12}


  4%|▍         | 400/9849 [07:37<2:52:29,  1.10s/it]

{'loss': 0.2342, 'learning_rate': 0.00023999999999999998, 'epoch': 0.12}


  4%|▍         | 410/9849 [07:48<2:50:16,  1.08s/it]

{'loss': 0.2424, 'learning_rate': 0.00024599999999999996, 'epoch': 0.12}


  4%|▍         | 420/9849 [07:59<2:52:01,  1.09s/it]

{'loss': 0.2633, 'learning_rate': 0.00025199999999999995, 'epoch': 0.13}


  4%|▍         | 430/9849 [08:10<2:53:04,  1.10s/it]

{'loss': 0.1598, 'learning_rate': 0.000258, 'epoch': 0.13}


  4%|▍         | 440/9849 [08:22<2:53:20,  1.11s/it]

{'loss': 0.1801, 'learning_rate': 0.00026399999999999997, 'epoch': 0.13}


  5%|▍         | 450/9849 [08:32<2:51:39,  1.10s/it]

{'loss': 0.1594, 'learning_rate': 0.00027, 'epoch': 0.14}


  5%|▍         | 460/9849 [08:44<2:52:39,  1.10s/it]

{'loss': 0.2327, 'learning_rate': 0.000276, 'epoch': 0.14}


  5%|▍         | 470/9849 [08:55<2:54:06,  1.11s/it]

{'loss': 0.1773, 'learning_rate': 0.00028199999999999997, 'epoch': 0.14}


  5%|▍         | 480/9849 [09:06<2:50:35,  1.09s/it]

{'loss': 0.3558, 'learning_rate': 0.00028799999999999995, 'epoch': 0.15}


  5%|▍         | 490/9849 [09:17<2:52:11,  1.10s/it]

{'loss': 0.22, 'learning_rate': 0.000294, 'epoch': 0.15}


  5%|▌         | 500/9849 [09:28<2:49:34,  1.09s/it]

{'loss': 0.1748, 'learning_rate': 0.0003, 'epoch': 0.15}


  5%|▌         | 510/9849 [09:40<2:50:50,  1.10s/it]

{'loss': 0.2819, 'learning_rate': 0.00029967911006524757, 'epoch': 0.16}


  5%|▌         | 520/9849 [09:51<2:50:58,  1.10s/it]

{'loss': 0.2096, 'learning_rate': 0.0002993582201304952, 'epoch': 0.16}


  5%|▌         | 530/9849 [10:02<2:44:14,  1.06s/it]

{'loss': 0.2804, 'learning_rate': 0.00029903733019574286, 'epoch': 0.16}


  5%|▌         | 540/9849 [10:13<2:48:25,  1.09s/it]

{'loss': 0.1887, 'learning_rate': 0.00029871644026099046, 'epoch': 0.16}


  6%|▌         | 550/9849 [10:24<2:50:32,  1.10s/it]

{'loss': 0.1975, 'learning_rate': 0.00029839555032623805, 'epoch': 0.17}


  6%|▌         | 560/9849 [10:35<2:48:17,  1.09s/it]

{'loss': 0.2207, 'learning_rate': 0.0002980746603914857, 'epoch': 0.17}


  6%|▌         | 570/9849 [10:46<2:47:11,  1.08s/it]

{'loss': 0.2298, 'learning_rate': 0.00029775377045673334, 'epoch': 0.17}


  6%|▌         | 580/9849 [10:57<2:46:40,  1.08s/it]

{'loss': 0.2341, 'learning_rate': 0.00029743288052198094, 'epoch': 0.18}


  6%|▌         | 590/9849 [11:08<2:46:23,  1.08s/it]

{'loss': 0.1785, 'learning_rate': 0.00029711199058722853, 'epoch': 0.18}


  6%|▌         | 600/9849 [11:19<2:46:48,  1.08s/it]

{'loss': 0.2593, 'learning_rate': 0.0002967911006524762, 'epoch': 0.18}


  6%|▌         | 610/9849 [11:30<2:47:35,  1.09s/it]

{'loss': 0.142, 'learning_rate': 0.00029647021071772377, 'epoch': 0.19}


  6%|▋         | 620/9849 [11:41<2:47:50,  1.09s/it]

{'loss': 0.2376, 'learning_rate': 0.0002961493207829714, 'epoch': 0.19}


  6%|▋         | 630/9849 [11:52<2:57:47,  1.16s/it]

{'loss': 0.2733, 'learning_rate': 0.000295828430848219, 'epoch': 0.19}


  6%|▋         | 640/9849 [12:04<2:48:13,  1.10s/it]

{'loss': 0.224, 'learning_rate': 0.00029550754091346666, 'epoch': 0.19}


  7%|▋         | 650/9849 [12:15<2:50:16,  1.11s/it]

{'loss': 0.2246, 'learning_rate': 0.00029518665097871426, 'epoch': 0.2}


  7%|▋         | 660/9849 [12:26<2:48:58,  1.10s/it]

{'loss': 0.2561, 'learning_rate': 0.0002948657610439619, 'epoch': 0.2}


  7%|▋         | 670/9849 [12:37<2:48:16,  1.10s/it]

{'loss': 0.1948, 'learning_rate': 0.0002945448711092095, 'epoch': 0.2}


  7%|▋         | 680/9849 [12:48<2:41:04,  1.05s/it]

{'loss': 0.3193, 'learning_rate': 0.00029422398117445714, 'epoch': 0.21}


  7%|▋         | 690/9849 [12:59<2:45:58,  1.09s/it]

{'loss': 0.2225, 'learning_rate': 0.00029390309123970474, 'epoch': 0.21}


  7%|▋         | 700/9849 [13:10<2:47:03,  1.10s/it]

{'loss': 0.2354, 'learning_rate': 0.0002935822013049524, 'epoch': 0.21}


  7%|▋         | 710/9849 [13:21<2:48:14,  1.10s/it]

{'loss': 0.2372, 'learning_rate': 0.0002932613113702, 'epoch': 0.22}


  7%|▋         | 720/9849 [13:32<2:48:29,  1.11s/it]

{'loss': 0.2487, 'learning_rate': 0.0002929404214354476, 'epoch': 0.22}


  7%|▋         | 730/9849 [13:43<2:45:52,  1.09s/it]

{'loss': 0.2006, 'learning_rate': 0.0002926195315006952, 'epoch': 0.22}


  8%|▊         | 740/9849 [13:54<2:43:46,  1.08s/it]

{'loss': 0.2842, 'learning_rate': 0.00029229864156594287, 'epoch': 0.23}


  8%|▊         | 750/9849 [14:05<2:46:34,  1.10s/it]

{'loss': 0.2825, 'learning_rate': 0.00029197775163119046, 'epoch': 0.23}


  8%|▊         | 760/9849 [14:16<2:48:17,  1.11s/it]

{'loss': 0.188, 'learning_rate': 0.0002916568616964381, 'epoch': 0.23}


  8%|▊         | 770/9849 [14:27<2:47:44,  1.11s/it]

{'loss': 0.1891, 'learning_rate': 0.0002913359717616857, 'epoch': 0.23}


  8%|▊         | 780/9849 [14:38<2:46:45,  1.10s/it]

{'loss': 0.2091, 'learning_rate': 0.00029101508182693335, 'epoch': 0.24}


  8%|▊         | 790/9849 [14:49<2:53:51,  1.15s/it]

{'loss': 0.2544, 'learning_rate': 0.000290694191892181, 'epoch': 0.24}


  8%|▊         | 800/9849 [15:01<2:53:23,  1.15s/it]

{'loss': 0.2332, 'learning_rate': 0.0002903733019574286, 'epoch': 0.24}


  8%|▊         | 810/9849 [15:12<2:48:15,  1.12s/it]

{'loss': 0.3364, 'learning_rate': 0.0002900524120226762, 'epoch': 0.25}


  8%|▊         | 820/9849 [15:24<2:55:33,  1.17s/it]

{'loss': 0.2475, 'learning_rate': 0.00028973152208792383, 'epoch': 0.25}


  8%|▊         | 830/9849 [15:35<2:49:06,  1.13s/it]

{'loss': 0.1953, 'learning_rate': 0.0002894106321531714, 'epoch': 0.25}


  9%|▊         | 840/9849 [15:47<2:51:51,  1.14s/it]

{'loss': 0.2189, 'learning_rate': 0.0002890897422184191, 'epoch': 0.26}


  9%|▊         | 850/9849 [15:58<2:50:56,  1.14s/it]

{'loss': 0.1885, 'learning_rate': 0.00028876885228366667, 'epoch': 0.26}


  9%|▊         | 860/9849 [16:10<2:53:26,  1.16s/it]

{'loss': 0.221, 'learning_rate': 0.00028844796234891426, 'epoch': 0.26}


  9%|▉         | 870/9849 [16:21<2:56:01,  1.18s/it]

{'loss': 0.203, 'learning_rate': 0.0002881270724141619, 'epoch': 0.27}


  9%|▉         | 880/9849 [16:33<2:52:51,  1.16s/it]

{'loss': 0.1868, 'learning_rate': 0.00028780618247940956, 'epoch': 0.27}


  9%|▉         | 885/9849 [16:38<2:46:58,  1.12s/it]

KeyboardInterrupt: 

In [None]:
test_dataset = DiplomacyDataset(test_inputs, test_targets, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)

model.eval()
input_ids = test_dataset[10]['input_ids'].unsqueeze(0).to(device)
outputs = model.generate(input_ids)
print(outputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])



