In [None]:
pip install fire

In [None]:
pip install transformers

In [None]:
###############################
##### importing libraries #####
###############################
import json 
import os
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import Dataset   
torch.backends.cudnn.benchmark=True

import pyarrow.parquet as pq
import pandas as pd
import random
import fire
import logging
import os
import csv

from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F
import math
from statistics import mean
import matplotlib.pyplot as plt
import pickle

In [None]:
##### Hyperparameters for federated learning #########
num_clients = 20       # number of shards the training texts are split into
num_selected = 5       # clients sampled (and client models kept) per round
num_rounds = 100       # federated rounds
epochs = 5             # local passes over a client's data per round
batch_size = 16        # NOTE(review): not referenced by the visible cells (loaders use batch_size=1)
datapath='/content/drive/MyDrive/Reserach_Notebooks/data/reddit_clean_jokes/'
model_name="gpt2"
tokenizer_name="gpt2"
# Fall back to CPU when no GPU is visible so `.to(device)` does not raise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr=0.001
max_seq_len=200        # upper bound on the packed sequence length used during training
warmup_steps=5000      # warmup steps handed to the LR schedule
output_dir=""
output_prefix="jokes"
save_model_on_epoch=False
ckpt_path="/content/drive/MyDrive/Reserach_Notebooks/ckpt/"
misc_path="/content/drive/MyDrive/Reserach_Notebooks/misc/"
#np.random.seed(112)
temp=0.5
ctrl_code="<|joke|>"

In [None]:

class DataComposer(torch.utils.data.Dataset):
    """Dataset of token-id tensors, one per input text.

    Every text is cut to ``max_length`` characters, prefixed with a control
    token of the form ``<|code|>``, suffixed with ``<|endoftext|>`` and then
    encoded with the tokenizer.
    """

    def __init__(self, control_code, texts, truncate=False, max_length=768, tokenizer=None):
        """
        Args:
            control_code: Control tag, either bare (``"joke"``) or already
                wrapped (``"<|joke|>"``). Both produce the prefix ``<|joke|>``.
            texts: Iterable of strings to encode.
            truncate: When true, keep at most the first 20000 encoded texts.
            max_length: Maximum number of characters (not tokens) kept from
                each text before encoding.
            tokenizer: Optional tokenizer exposing ``encode(str)``. Passing one
                lets several datasets share a single instance; when omitted,
                ``GPT2Tokenizer.from_pretrained(model_name)`` is loaded.
        """
        if tokenizer is None:
            tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer = tokenizer

        # A caller that passes an already-wrapped code ("<|joke|>") used to get
        # the doubled prefix "<|<|joke|>|>"; peel one layer before wrapping.
        code = control_code
        if len(code) >= 4 and code.startswith("<|") and code.endswith("|>"):
            code = code[2:-2]
        prefix = f"<|{code}|>"

        self.texts = [
            torch.tensor(self.tokenizer.encode(f"{prefix}{row[:max_length]}<|endoftext|>"))
            for row in texts
        ]

        if truncate:
            self.texts = self.texts[:20000]
        self.text_count = len(self.texts)

    def __len__(self):
        """Number of encoded texts held by the dataset."""
        return self.text_count

    def __getitem__(self, item):
        """Return the 1-D token-id tensor stored at position ``item``."""
        return self.texts[item]

#Pack Tensors

In [None]:
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Try to merge ``new_tensor`` into the running ``packed_tensor``.

    Returns a triple ``(tensor, packed_ok, leftover)``:

    * nothing packed yet            -> ``(new_tensor, True, None)``
    * combined size along dim 1 would exceed ``max_seq_len``
                                    -> ``(packed_tensor, False, new_tensor)``
    * otherwise ``new_tensor`` is placed in front of ``packed_tensor`` with
      the latter's first column dropped -> ``(merged, True, None)``
    """
    if packed_tensor is None:
        return new_tensor, True, None

    combined_width = new_tensor.size(1) + packed_tensor.size(1)
    if combined_width <= max_seq_len:
        merged = torch.cat((new_tensor, packed_tensor[:, 1:]), dim=1)
        return merged, True, None

    return packed_tensor, False, new_tensor

#Read Data

In [None]:
# Dataset-specific loading: read the jokes CSV and keep the first 1620
# entries of its 'Joke' column as a plain Python list.
text_csv = pd.read_csv(datapath + 'reddit-cleanjokes.csv')
text_raw = text_csv['Joke'].tolist()[:1620]

In [None]:
# Dividing the training texts into num_clients random shards of (near-)equal size
def make_client_dataset(text_raw):
    """Randomly partition ``text_raw`` into ``num_clients`` lists.

    Args:
        text_raw: Sequence of training texts (anything supporting ``len`` and
            integer indexing).

    Returns:
        A list of ``num_clients`` lists. Shard sizes differ by at most one:
        when ``len(text_raw)`` is not divisible by ``num_clients`` the first
        ``len(text_raw) % num_clients`` shards receive one extra item.
    """
    base, extra = divmod(len(text_raw), num_clients)
    # random_split rejects integer lengths that do not sum to len(text_raw),
    # so the remainder is spread over the first `extra` shards instead of
    # being silently assumed to be zero.
    lengths = [base + 1 if i < extra else base for i in range(num_clients)]
    splits = torch.utils.data.random_split(text_raw, lengths)
    # Materialise each Subset into a plain list of texts.
    return [list(split) for split in splits]

traindata_split = make_client_dataset(text_raw)

#Make Victim dataset

In [None]:
# Build one shuffled, single-sample DataLoader per client shard.
train_loaders = []
for client_texts in traindata_split:
    client_dataset = DataComposer("<|joke|>", client_texts, truncate=True)
    train_loaders.append(
        torch.utils.data.DataLoader(client_dataset, batch_size=1, shuffle=True)
    )

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:
def client_update(client_model, optimizer, train_dataloader, epoch=5):
    """
    Train ``client_model`` on one client's data and return the last step's loss.

    Samples from ``train_dataloader`` are merged with ``pack_tensor`` until the
    next one would push the pack past ``max_seq_len``; every finished pack is
    one forward/backward/optimizer step.

    Args:
        client_model: Model called as ``client_model(ids, labels=ids)`` whose
            first output is the loss.
        optimizer: Optimizer stepped after every pack.
        train_dataloader: Loader yielding one tensor per iteration; must
            support ``len``. ``pack_tensor`` reads ``size()[1]``, so entries
            are assumed to be 2-D (batch, sequence) -- TODO confirm.
        epoch: Number of passes over ``train_dataloader``.

    Returns:
        float: Loss of the final optimisation step, or NaN when no step ran
        (empty loader or ``epoch <= 0``).
    """
    client_model = client_model.to(device)
    client_model.train()

    # NOTE(review): a new schedule is created on every call, so the warmup
    # restarts each time; with warmup_steps much larger than the steps taken
    # per call the learning rate presumably stays close to zero, and
    # num_training_steps=-1 looks like a placeholder -- confirm this is intended.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
    )

    def _train_step(batch):
        """Run one optimisation step on ``batch`` and return its loss tensor."""
        batch = batch.to(device)
        step_loss = client_model(batch, labels=batch)[0]
        step_loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        client_model.zero_grad()
        return step_loss

    last_loss = None
    input_tensor = None
    last_idx = len(train_dataloader) - 1

    # Use the `epoch` argument; the loop previously shadowed it and iterated
    # over the global `epochs` instead.
    for epoch_idx in range(epoch):

        print(f"Training epoch {epoch_idx}")
        for idx, entry in enumerate(train_dataloader):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, max_seq_len)

            # Keep packing until the pack is full or the data runs out.
            if carry_on and idx != last_idx:
                continue

            last_loss = _train_step(input_tensor)

            # The sample that did not fit starts the next pack instead of
            # being dropped.
            input_tensor = remainder

            # Nothing follows the last sample, so flush a pending remainder now.
            if idx == last_idx and input_tensor is not None:
                last_loss = _train_step(input_tensor)
                input_tensor = None

    return last_loss.item() if last_loss is not None else float("nan")

In [None]:
def server_aggregate(global_model, client_models):
    """
    Federated averaging with an unweighted mean.

    For every entry of the global state dict, the matching client tensors are
    cast to float, stacked and averaged. The result is loaded into
    ``global_model`` and then copied back into every client model.
    """
    client_states = [client.state_dict() for client in client_models]

    averaged = global_model.state_dict()
    for name in averaged.keys():
        stacked = torch.stack([state[name].float() for state in client_states], 0)
        averaged[name] = stacked.mean(0)
    global_model.load_state_dict(averaged)

    # Re-synchronise every client with the freshly averaged weights.
    for client in client_models:
        client.load_state_dict(global_model.state_dict())

In [None]:
############################################
#### Initializing models and optimizer  ####
############################################

#### global model ##########
gen_model=GPT2LMHeadModel.from_pretrained(model_name)

### Resume from last ckpt (the training loop saves checkpoints as 'part_<round>.pt')
#gen_model.load_state_dict(torch.load(ckpt_path+'part_50.pt'))

global_model = gen_model

############## client models ##############
# Each client must be its OWN module. Reusing `gen_model` for every entry made
# all clients and the global model one shared object, so averaging was a no-op
# and several optimizers updated the same parameters. Build fresh models from
# the global config; their weights are overwritten by the sync loop below.
client_models = [GPT2LMHeadModel(global_model.config) for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict()) ### initial synchronizing with global model 

############### optimizers ################
# One optimizer per client model, matched by index.
opt = [AdamW(model.parameters(), lr=lr) for model in client_models]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]



In [None]:
# Drop the references to the raw joke list and the source dataframe.
text_raw = None
text_csv = None

## FL Training

In [None]:
for r in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]
    # client update: client model i trains on the data of the i-th sampled client
    for i in tqdm(range(num_selected)):
        client_update(client_models[i], opt[i], train_loaders[client_idx[i]], epoch=epochs)

    # server aggregate
    server_aggregate(global_model, client_models)

    # Log the name that is actually written (the message used to say
    # 'global_<round>.pt' while the file was saved as 'part_<round>.pt').
    ckpt_name = 'part_'+str(r+1)+'.pt'
    print('after round ',r+1, 'saving global ckpt', ckpt_name)

    torch.save(global_model.state_dict(), ckpt_path+ckpt_name)

In [None]:
#get perplexity of a text sample
def get_ppl(
    m_name,
    tokenizer,
    sample
):
    """Return the perplexity of ``sample`` under a saved checkpoint.

    The checkpoint ``ckpt_path + m_name`` is loaded into the module-level
    ``global_model`` (overwriting its current weights), the sample is encoded
    with ``tokenizer`` and scored with itself as labels; the result is
    ``exp(loss)``.

    Args:
        m_name: Checkpoint file name relative to ``ckpt_path``.
        tokenizer: Tokenizer exposing ``encode(str)``.
        sample: Text to score.

    Returns:
        float: Perplexity of ``sample``.
    """
    global_model.load_state_dict(torch.load(ckpt_path+m_name,map_location='cpu'))
    model=global_model
    model.eval()

    # load_state_dict copies values into the existing parameters and does not
    # move the model, so the input has to follow the model's device (it may be
    # on the GPU after training).
    model_device = next(model.parameters()).device

    with torch.no_grad():
        generated = torch.tensor(tokenizer.encode(sample)).unsqueeze(0).to(model_device)
        outputs = model(generated, labels=generated)
        loss = outputs[0]
        ppl = torch.exp(loss)

        return ppl.item()

In [None]:
## Generating texts with seed

def generate_text(
    m_name,
    tokenizer,
    prompt,
    entry_length=5,
    entry_count=1,
    top_p=0.8,
    temperature=1.,
):
    """Sample continuations of ``prompt`` from a saved checkpoint.

    The checkpoint ``ckpt_path + m_name`` is loaded into the module-level
    ``global_model``. Tokens are drawn one at a time with top-p (nucleus)
    sampling until ``<|endoftext|>`` is produced or ``entry_length`` tokens
    have been added.

    Args:
        m_name: Checkpoint file name relative to ``ckpt_path``.
        tokenizer: Tokenizer exposing ``encode`` and ``decode``.
        prompt: Seed text every sample starts from.
        entry_length: Maximum number of tokens generated per sample.
        entry_count: Number of samples to generate.
        top_p: Cumulative-probability cutoff for nucleus sampling.
        temperature: Logit divisor; values <= 0 are treated as 1.0.

    Returns:
        tuple: ``(generated_list, ppxls)`` -- the decoded samples (a sample
        that hit ``entry_length`` without an end token gets ``<|endoftext|>``
        appended) and the perplexity of each as computed by ``get_ppl``.
    """
    global_model.load_state_dict(torch.load(ckpt_path+m_name,map_location='cpu'))
    model=global_model
    model.eval()

    # Inputs must live on the same device as the model's weights.
    model_device = next(model.parameters()).device

    # Encode once instead of on every loop iteration.
    prompt_ids = tokenizer.encode(prompt)
    eos_ids = tokenizer.encode("<|endoftext|>")

    generated_list = []
    filter_value = -float("Inf")

    with torch.no_grad():

        for _ in trange(entry_count):

            entry_finished = False

            generated = torch.tensor(prompt_ids).unsqueeze(0).to(model_device)

            # Using top-p (nucleus sampling): https://github.com/huggingface/transformers/blob/master/examples/run_generation.py
            for _step in range(entry_length):
                # Only the logits are needed here, so no labels are passed and
                # no loss is computed.
                logits = model(generated)[0]
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Mask tokens beyond the nucleus; the mask is shifted right by
                # one so the first token crossing top_p is kept, and the most
                # likely token is never removed.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if next_token.item() in eos_ids:
                    entry_finished = True
                    # tolist() works for CPU and CUDA tensors alike.
                    generated_list.append(tokenizer.decode(generated[0].tolist()))
                    break

            if not entry_finished:
                output_text = f"{tokenizer.decode(generated[0].tolist())}<|endoftext|>"
                generated_list.append(output_text)

    # The scoring helper is `get_ppl`; the previous `get_PPl` spelling raised NameError.
    ppxls = [get_ppl(m_name, tokenizer, g) for g in generated_list]
    return generated_list, ppxls

In [None]:
saved_model='part_26.pt'
# The function is named `generate_text`; the previous `generate_txt` call raised NameError.
generated_texts, pp = generate_text(
    saved_model,
    GPT2Tokenizer.from_pretrained('gpt2'),
    "Gallen was born in",
    entry_length=3,
    entry_count=20,
)

# Print every sample followed by its perplexity.
for text, perplexity in zip(generated_texts, pp):
  print(text)
  print(perplexity)