## Imports

In [None]:
import os
import re
import sys
import typing
import gc
import pandas as pd
import pickle
import logging

# Make the local ./src directory importable for the project imports below.
src_dir = os.path.join('.', 'src')
sys.path.append(src_dir)

from src.models import Pipeline
from src.nodes import *

In [None]:
# Send DEBUG-and-above log records to logs/logs.log.
# The file handler opens its target immediately, so a missing 'logs/'
# directory would raise FileNotFoundError on a fresh checkout; create it first.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(filename='logs/logs.log', level=logging.DEBUG)

# Tweets Model

In [None]:
# Build the tweets pipeline from its JSON config. load_model_data=True
# presumably also loads the train/val/test datasets that the cells below
# read from `pipeline` — confirm in src.models.
pipeline = Pipeline('CONFIG_MODEL_TWEETS.json', load_model_data = True)

In [None]:
# Summarise each split of the tweets data: number of items (len) and the
# value reported by the dataset's token_len().
train_ds = pipeline.train_dataset
val_ds = pipeline.val_dataset
test_ds = pipeline.test_dataset
d = dict([
    ('train_set_len', len(train_ds)),
    ('train_set_tokens', train_ds.token_len()),
    ('val_set_len', len(val_ds)),
    ('val_set_tokens', val_ds.token_len()),
    ('test_set_len', len(test_ds)),
    ('test_set_tokens', test_ds.token_len()),
])

In [None]:
# Display the dataset size summary built in the previous cell.
d

In [None]:
# Baseline: perplexity reported by the pipeline before train_model() is called.
pipeline.perplexity()

In [None]:
# Train the tweets model (hyper-parameters presumably come from the JSON config — confirm).
pipeline.train_model()

In [None]:
# Perplexity again after train_model(), to compare with the pre-training value.
pipeline.perplexity()

In [None]:
# Re-create the tweets pipeline with load_model_data=False (presumably skips
# dataset loading — confirm); this rebinds `pipeline`, discarding the
# instance trained above.
pipeline = Pipeline('CONFIG_MODEL_TWEETS.json', load_model_data = False)

In [None]:
# Load the saved tweets checkpoint; the path is relative to the notebook's working directory.
pipeline.load_model('models/tweets/tweets.pth')

In [None]:
# Sample text from the tweets model seeded with 'all' (num_words=40).
pipeline.generate(start_text = 'all', num_words = 40)

In [None]:
# Sample text from the tweets model seeded with 'what' (num_words=40).
pipeline.generate(start_text = 'what', num_words = 40)

In [None]:
# Sample text from the tweets model seeded with the two-word prompt 'i like' (num_words=40).
pipeline.generate(start_text = 'i like', num_words = 40)

# Wiki Model

In [None]:
# Build the wiki pipeline from its JSON config; this rebinds `pipeline`, so the
# tweets pipeline is no longer reachable from here on.
pipeline = Pipeline('CONFIG_MODEL_WIKI.json', load_model_data = True)

In [None]:
# Perplexity before the saved checkpoint is loaded. with_recall=True presumably
# adds recall metrics and with_tqdm=True a progress bar — confirm in Pipeline.perplexity.
pipeline.perplexity(with_recall = True, with_tqdm = True)

In [None]:
# Load the saved wiki103 checkpoint; the path is relative to the notebook's working directory.
pipeline.load_model('models/wiki103/wiki103.pth')

In [None]:
# Same evaluation as before, now with the loaded checkpoint, for comparison.
pipeline.perplexity(with_recall = True, with_tqdm = True)

In [None]:
# Continue training the wiki model starting from the checkpoint loaded above.
pipeline.train_model()

In [None]:
# Perplexity after training. Called with default arguments, unlike the two
# earlier wiki evaluations, so the outputs may not be directly comparable.
pipeline.perplexity()

In [None]:
# Sample text from the wiki model seeded with 'all' (num_words=40).
pipeline.generate(start_text = 'all', num_words=40)

In [None]:
# Sample text from the wiki model seeded with 'what' (num_words=40).
pipeline.generate(start_text = 'what', num_words=40)

In [None]:
# Sample text from the wiki model seeded with the two-word prompt 'i like'.
# num_words=40 is passed explicitly so this sample has the same length setting
# as every other generate() call on `pipeline` in this notebook.
pipeline.generate(start_text = 'i like', num_words=40)

In [None]:
# Summarise each split of the wiki data: number of items (len) and the
# value reported by the dataset's token_len().
d = {}
d['train_set_len'] = len(pipeline.train_dataset)
d['train_set_tokens'] = pipeline.train_dataset.token_len()
d['val_set_len'] = len(pipeline.val_dataset)
d['val_set_tokens'] = pipeline.val_dataset.token_len()
d['test_set_len'] = len(pipeline.test_dataset)
d['test_set_tokens'] = pipeline.test_dataset.token_len()

In [None]:
# Display the wiki dataset size summary built in the previous cell.
d

## FedAVG

In [1]:
from src.federated_pipeline import Federated_AVG

import os
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [2]:
# Set up the FedAVG experiment from the tweets model config and the federated
# config. testing=True presumably selects a reduced/debug setup — confirm in
# src.federated_pipeline.
federated = Federated_AVG(
    "CONFIG_MODEL_TWEETS.json",
    "CONFIG_FEDERATED_TWEETS.json",
    testing = True
)

100%|██████████| 5000/5000 [00:00<00:00, 7736.48it/s]
100%|██████████| 5000/5000 [00:00<00:00, 7556.60it/s]
100%|██████████| 1/1 [00:00<00:00, 178.76it/s]
100%|██████████| 50/50 [00:15<00:00,  3.14it/s]


In [3]:
# Size of each node's local data (len of node.data), in node iteration order.
node_sizes = [len(node.data) for node in federated.nodes.values()]
a = np.array(node_sizes)
a

array([2948, 2954, 2851, 4121, 4065, 3986, 1801, 4089, 1930, 3477, 2594,
       3342, 2995, 2657, 2675, 3628, 2156, 2323, 3309, 3575, 3177, 4176,
       2776, 2854, 3333, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000,
       3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000,
       3000, 3000, 3000, 3000, 3000, 3000])

In [4]:
# Run federated training with 20 as the first argument; the recorded output
# shows rounds 0 through 20. save_results=False presumably keeps the metrics
# in memory only (they are read from federated.results below) — confirm.
federated.train(20, save_results = False)

100%|██████████| 50/50 [00:00<00:00, 5332.33it/s]
  2%|▏         | 3/193 [00:00<00:06, 28.53it/s]

round 0


100%|██████████| 193/193 [00:04<00:00, 39.17it/s]
100%|██████████| 50/50 [00:00<00:00, 627.30it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 1


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 42.24it/s]
100%|██████████| 50/50 [00:00<00:00, 646.03it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 2


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 42.14it/s]
100%|██████████| 50/50 [00:00<00:00, 643.73it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 3


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 41.73it/s]
100%|██████████| 50/50 [00:00<00:00, 634.90it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 4


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 41.83it/s]
100%|██████████| 50/50 [00:00<00:00, 647.14it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 5


100%|██████████| 50/50 [00:40<00:00,  1.23it/s]
100%|██████████| 193/193 [00:04<00:00, 42.09it/s]
100%|██████████| 50/50 [00:00<00:00, 651.47it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 6


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 41.92it/s]
100%|██████████| 50/50 [00:00<00:00, 636.03it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 7


100%|██████████| 50/50 [00:43<00:00,  1.15it/s]
100%|██████████| 193/193 [00:04<00:00, 38.77it/s]
100%|██████████| 50/50 [00:00<00:00, 644.41it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 8


100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
100%|██████████| 193/193 [00:04<00:00, 42.67it/s]
100%|██████████| 50/50 [00:00<00:00, 623.81it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 9


100%|██████████| 50/50 [00:41<00:00,  1.22it/s]
100%|██████████| 193/193 [00:04<00:00, 42.74it/s]
100%|██████████| 50/50 [00:00<00:00, 680.90it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 10


100%|██████████| 50/50 [00:42<00:00,  1.19it/s]
100%|██████████| 193/193 [00:04<00:00, 40.57it/s]
100%|██████████| 50/50 [00:00<00:00, 645.59it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 11


100%|██████████| 50/50 [00:42<00:00,  1.19it/s]
100%|██████████| 193/193 [00:04<00:00, 42.15it/s]
100%|██████████| 50/50 [00:00<00:00, 656.17it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 12


100%|██████████| 50/50 [00:42<00:00,  1.19it/s]
100%|██████████| 193/193 [00:04<00:00, 40.83it/s]
100%|██████████| 50/50 [00:00<00:00, 652.61it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 13


100%|██████████| 50/50 [00:42<00:00,  1.18it/s]
100%|██████████| 193/193 [00:04<00:00, 42.58it/s]
100%|██████████| 50/50 [00:00<00:00, 592.10it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 14


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 42.70it/s]
100%|██████████| 50/50 [00:00<00:00, 679.29it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 15


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 42.61it/s]
100%|██████████| 50/50 [00:00<00:00, 675.77it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 16


100%|██████████| 50/50 [00:40<00:00,  1.23it/s]
100%|██████████| 193/193 [00:04<00:00, 42.28it/s]
100%|██████████| 50/50 [00:00<00:00, 646.26it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 17


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 42.26it/s]
100%|██████████| 50/50 [00:00<00:00, 634.70it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 18


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 41.98it/s]
100%|██████████| 50/50 [00:00<00:00, 649.72it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 19


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 40.00it/s]
100%|██████████| 50/50 [00:00<00:00, 576.69it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

round 20


100%|██████████| 50/50 [00:40<00:00,  1.24it/s]
100%|██████████| 193/193 [00:04<00:00, 42.31it/s]
100%|██████████| 50/50 [00:00<00:00, 652.68it/s]


In [5]:
# Show the collected metrics as a table: one row per metric, one column per round.
# NOTE(review): in the recorded output the global perplexity rises from ~82
# (round 0) to ~425 (round 20) with spikes above 16000, while attack_perplexity
# falls to ~1.0 in later rounds — training looks unstable; confirm this is
# expected for this setup.
pd.DataFrame(federated.results)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,20
perplexity,82.280398,189.175464,312.735829,313.627502,330.508332,378.797200,395.273117,377.667653,377.596313,511.838256,...,540.186345,7488.576066,8820.806284,490.042779,540.147629,16000.969242,505.977410,807.295065,1656.615388,424.982982
loss,4.406816,5.239529,5.742709,5.745860,5.798581,5.936045,5.978216,5.931737,5.932259,6.235802,...,6.291490,8.916497,9.082331,6.193844,6.290457,9.676824,6.225039,6.691768,7.411625,6.050663
f1_recall,0.295508,0.225968,0.190668,0.190580,0.186166,0.173482,0.171231,0.176204,0.179279,0.156943,...,0.155825,0.059152,0.018525,0.128427,0.136299,0.039346,0.134592,0.132076,0.072836,0.147261
f3_recall,0.402290,0.332956,0.298657,0.293904,0.292815,0.277483,0.269066,0.277571,0.282412,0.253572,...,0.248422,0.111432,0.043628,0.214138,0.222539,0.086785,0.220509,0.216624,0.150027,0.238313
attack_perplexity,598.074532,444.141736,1208.317450,863.370627,574.217410,487.995960,970.634404,667.914356,572.748604,936.209325,...,1166.273667,1.000026,1.000000,9.211939,59.540461,1.000003,1.714386,6.365241,1.000007,1.291710
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
attack_perplexity_25,598.074532,444.141736,1208.317450,863.370627,574.217410,487.995960,970.634404,667.914356,572.748604,936.209325,...,1166.273667,1.000026,1.000000,9.211939,59.540461,1.000003,1.714386,6.365241,1.000007,1.291710
train_len_25,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,...,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000
train_tokens_25,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,...,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000,3333.000000
len_25,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,...,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000,711.000000


In [6]:
# Sample text from the aggregated (general) model, seeded with 'all' and
# using the federated vocabulary; generation length is left at its default.
federated.general_model.generate(federated.vocabulary, 'all')

'all have to go stories on his trumps deserves the pin is not having an silent at the one of the other dc as this thread is spread there is the role and which can he had many of ppl to rump that the fact pm leaking health room is one of trumps anything he was very living in the rally this is short of his things that the leaving divide where it almost not the debate in major weeks million the money and foundation had been endorsed in their camera campaign greater source jim tape in unless the nyt form'

In [None]:
# Per-node perplexity across federated rounds, one line per node.
# Reads federated.results (the same data shown in the results table) directly:
# `res` is only assigned much further down the notebook, so `res[0]` raises
# NameError on a fresh kernel, and this run used save_results=False.
results_df = pd.DataFrame(federated.results)  # rows: metric names, columns: rounds
fig, ax = plt.subplots()
for metric, per_round in results_df.iterrows():
    # startswith rather than a substring test, so that the
    # 'attack_perplexity_<node>' rows are not drawn on the same axes.
    if metric.startswith('perplexity_'):
        ax.plot(per_round.to_numpy())
ax.set_xlabel('round')
ax.set_ylabel('perplexity')
ax.set_title('Per-node perplexity (FedAVG)');

In [None]:
# Per-node f1_recall across federated rounds, one line per node.
# Reads federated.results directly: `res` is only assigned much further down
# the notebook, so `res[0]` raises NameError on a fresh kernel.
results_df = pd.DataFrame(federated.results)  # rows: metric names, columns: rounds
fig, ax = plt.subplots()
for metric, per_round in results_df.iterrows():
    # Only the per-node rows ('f1_recall_<node>'), not the global 'f1_recall'.
    if metric.startswith('f1_recall_'):
        ax.plot(per_round.to_numpy())
ax.set_xlabel('round')
ax.set_ylabel('f1_recall')
ax.set_title('Per-node f1_recall (FedAVG)');

In [None]:
# Per-node f3_recall across federated rounds, one line per node.
# Reads federated.results directly: `res` is only assigned much further down
# the notebook, so `res[0]` raises NameError on a fresh kernel.
results_df = pd.DataFrame(federated.results)  # rows: metric names, columns: rounds
fig, ax = plt.subplots()
for metric, per_round in results_df.iterrows():
    # Only the per-node rows ('f3_recall_<node>'), not the global 'f3_recall'.
    if metric.startswith('f3_recall_'):
        ax.plot(per_round.to_numpy())
ax.set_xlabel('round')
ax.set_ylabel('f3_recall')
ax.set_title('Per-node f3_recall (FedAVG)');

In [None]:
from src.federated_pipeline import Federated_LICCHAVI

import os
import pickle
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Set up the LICCHAVI experiment with the same two config files as the FedAVG
# run; this rebinds `federated`, replacing the FedAVG instance.
# testing=True presumably selects a reduced/debug setup — confirm.
federated = Federated_LICCHAVI(
    "CONFIG_MODEL_TWEETS.json",
    "CONFIG_FEDERATED_TWEETS.json",
    testing = True
)

In [None]:
# Run training with 5 as the first argument (presumably the number of rounds).
# save_results is left at its default here; the following cell reads
# metrics.pickle files from tweets/LICCHAVI, presumably written by this call — confirm.
federated.train(5)

In [None]:
# Load the metrics saved by every run found under tweets/LICCHAVI into
# `res`, keyed by a 0-based index.
dir_ = os.path.join('tweets', 'LICCHAVI')
res = {}
# os.listdir returns entries in arbitrary order, so sort them to make res[0]
# refer to the same run on every execution. Entries without a metrics file
# (stray files, interrupted runs) are skipped instead of crashing the loop.
run_names = sorted(
    name for name in os.listdir(dir_)
    if os.path.isfile(os.path.join(dir_, name, 'metrics.pickle'))
)
for i, name in enumerate(run_names):
    # NOTE: pickle.load can execute arbitrary code; only load files produced
    # by this project.
    with open(os.path.join(dir_, name, 'metrics.pickle'), 'rb') as f:
        res[i] = pickle.load(f)

In [None]:
# Plot the global 'perplexity' series of the first loaded run, y-axis clipped to [0, 300].
f, ax = plt.subplots(1)
ax.set_ylim([0, 300])
metrics_by_name = pd.DataFrame(res[0]).T.to_dict()
if 'perplexity' in metrics_by_name:
    # Nothing is drawn when the run has no 'perplexity' entry.
    plt.plot(list(metrics_by_name['perplexity'].values()))
        

In [None]:
from src.federated_pipeline import Federated_LICCHAVI_avg

import os
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch

In [None]:
# Turn on autograd anomaly detection for the run below.
# NOTE(review): this is a debugging aid and, per the PyTorch documentation,
# slows execution — confirm it is still needed before long training runs.
torch.autograd.set_detect_anomaly(True)

In [None]:
# Set up the LICCHAVI_avg experiment with the same two config files; this
# rebinds `federated` again. Unlike the two earlier setups, testing=False is
# passed here (presumably the full, non-debug configuration — confirm).
federated = Federated_LICCHAVI_avg(
    "CONFIG_MODEL_TWEETS.json",
    "CONFIG_FEDERATED_TWEETS.json",
    testing = False
)

In [None]:
# Run training with 5 as the first argument (presumably the number of rounds).
# save_results is left at its default here; the following cell reads
# metrics.pickle files from tweets/LICCHAVI_avg, presumably written by this call — confirm.
federated.train(5)

In [None]:
# Load the metrics saved by every run found under tweets/LICCHAVI_avg into
# `res`, keyed by a 0-based index.
dir_ = os.path.join('tweets', 'LICCHAVI_avg')
res = {}
# os.listdir returns entries in arbitrary order, so sort them to make res[0]
# refer to the same run on every execution. Entries without a metrics file
# (stray files, interrupted runs) are skipped instead of crashing the loop.
run_names = sorted(
    name for name in os.listdir(dir_)
    if os.path.isfile(os.path.join(dir_, name, 'metrics.pickle'))
)
for i, name in enumerate(run_names):
    # NOTE: pickle.load can execute arbitrary code; only load files produced
    # by this project.
    with open(os.path.join(dir_, name, 'metrics.pickle'), 'rb') as f:
        res[i] = pickle.load(f)

In [None]:
# Show which run indices were loaded into `res`.
res.keys()

In [None]:
# Plot the global 'perplexity' series of the first loaded run, y-axis clipped to [0, 80].
f, ax = plt.subplots(1)
ax.set_ylim([0, 80])
perplexity_by_round = pd.DataFrame(res[0]).T.to_dict().get('perplexity')
if perplexity_by_round is not None:
    # Nothing is drawn when the run has no 'perplexity' entry.
    plt.plot(list(perplexity_by_round.values()))

In [None]:
# Plot the global 'f3_recall' series of the first loaded run on a [0, 1] y-axis.
f, ax = plt.subplots(1)
ax.set_ylim([0, 1])
run_metrics = pd.DataFrame(res[0]).T.to_dict()
matching = (rounds for name, rounds in run_metrics.items() if name == 'f3_recall')
for rounds in matching:
    plt.plot(list(rounds.values()))

In [None]:
# Plot every series of the first loaded run whose name contains 'f3'
# on a shared [0, 1] y-axis.
f, ax = plt.subplots(1)
ax.set_ylim([0, 1])
metric_table = pd.DataFrame(res[0]).T.to_dict()
f3_series = [
    list(rounds.values())
    for name, rounds in metric_table.items()
    if 'f3' in name
]
for series in f3_series:
    plt.plot(series)

In [None]:
# Inspect the `losses` attribute of the node stored under key 7
# (presumably that node's recorded training losses — confirm in src.nodes).
federated.nodes[7].losses