In [None]:
# Reference (NER with BERT in PyTorch): https://towardsdatascience.com/named-entity-recognition-with-bert-in-pytorch-a454405e0b6a

In [None]:
pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.23.1-py3-none-any.whl (5.3 MB)
[K     |████████████████████████████████| 5.3 MB 4.7 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 67.5 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 43.5 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.10.1 tokenizers-0.13.1 transformers-4.23.1


In [None]:
###############################
##### importing libraries #####
###############################

import os
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import Dataset   
torch.backends.cudnn.benchmark=True

import pyarrow.parquet as pq
import pandas as pd
import random
import logging
import os
import csv

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
import torch.nn.functional as F

from transformers import BertTokenizerFast, BertForTokenClassification
from torch.optim import SGD

In [None]:
##### Hyperparameters for federated learning #########
num_clients = 10    # number of partitions the training set is divided into
num_selected = 5    # clients sampled (and client models trained) in each round
num_rounds = 100    # number of federated rounds
epochs = 5          # local epochs per selected client per round
batch_size = 2
lr = 3e-5

# NOTE(review): the four settings below are not referenced by any other cell
# visible in this notebook (the data cell reads 'ner.csv' and the model cells
# load 'bert-base-cased'). Kept unchanged in case something else relies on them.
datapath = 'train-processed-sample.csv'
model_name = "gpt2"
tokenizer_name = "gpt2"
output_dir = ""

# Seed every RNG the notebook draws from: NumPy picks the clients each round,
# while torch drives random_split and DataLoader shuffling.
SEED = 112
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Single source of truth for the compute device (falls back to CPU without CUDA).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
# Read the NER corpus; the preview below shows a `text` column and a `labels`
# column holding whitespace-separated tags.
df = pd.read_csv(filepath_or_buffer='ner.csv')
df.head(n=5)

Unnamed: 0,text,labels
0,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
1,Iranian officials say they expect to get acces...,B-gpe O O O O O O O O O O O O O O B-tim O O O ...
2,Helicopter gunships Saturday pounded militant ...,O O B-tim O O O O O B-geo O O O O O B-org O O ...
3,They left after a tense hour-long standoff wit...,O O O O O O O O O O O
4,U.N. relief coordinator Jan Egeland said Sunda...,B-geo O O B-per I-per O B-tim O B-geo O B-gpe ...


In [None]:
# When False, continuation word-pieces are labelled -100 by the alignment
# helpers in this notebook; when True, they receive the word's label as well.
label_all_tokens = False

# The alignment helpers call `.word_ids()` on this tokenizer's output.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

def align_label(texts, labels):
    """Build one label id per token position for a single sentence.

    The sentence is split on whitespace and tokenized with
    ``is_split_into_words=True`` so that ``word_ids()`` indexes the same
    whitespace-separated words that ``labels`` is indexed by. Tokenizing the raw
    string instead lets the tokenizer split punctuation into separate "words",
    which shifts every label after the first punctuated word.

    Args:
        texts: the sentence (converted with ``str`` before splitting).
        labels: list of tag strings, one per whitespace-separated word.

    Returns:
        list[int]: one entry per token position (padded/truncated to 512).
        Positions with no word (``word_ids()`` gives None) get -100. The first
        word-piece of a word gets ``labels_to_ids[tag]``. Later word-pieces get
        the same id when ``label_all_tokens`` is true, otherwise -100. A word
        with no tag, or a tag missing from ``labels_to_ids``, gets -100.
    """
    words = str(texts).split()
    tokenized_inputs = tokenizer(words, is_split_into_words=True,
                                 padding='max_length', max_length=512, truncation=True)

    word_ids = tokenized_inputs.word_ids()

    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:

        if word_idx is None:
            label_ids.append(-100)

        elif word_idx != previous_word_idx or label_all_tokens:
            # Only the two expected lookup failures are tolerated. Anything else
            # (e.g. `labels_to_ids` not defined yet) must surface instead of
            # silently producing an all -100 label vector.
            try:
                label_ids.append(labels_to_ids[labels[word_idx]])
            except (IndexError, KeyError):
                label_ids.append(-100)

        else:
            label_ids.append(-100)

        previous_word_idx = word_idx

    return label_ids

class DataSequence(torch.utils.data.Dataset):
    """Dataset of (tokenized sentence, aligned label ids) pairs.

    ``data`` is iterated as ``(text, label_string)`` pairs, where each label
    string holds whitespace-separated tags. Everything is tokenized eagerly in
    the constructor.
    """

    def __init__(self, data):
        sentences = []
        tag_strings = []
        for sentence, tag_string in data:
            sentences.append(sentence)
            tag_strings.append(tag_string)

        # One list of tags per sentence.
        tag_lists = [tag_string.split() for tag_string in tag_strings]

        # Model inputs: one tokenizer output per sentence, as PyTorch tensors.
        self.texts = []
        for sentence in sentences:
            encoded = tokenizer(str(sentence),
                                padding='max_length', max_length = 512, truncation=True, return_tensors="pt")
            self.texts.append(encoded)

        # Targets: label ids aligned to token positions.
        self.labels = []
        for sentence, tags in zip(sentences, tag_lists):
            self.labels.append(align_label(sentence, tags))

    def __len__(self):
        """Number of sentences held."""
        return len(self.labels)

    def get_batch_data(self, idx):
        """Return the tokenizer output stored for sentence ``idx``."""
        return self.texts[idx]

    def get_batch_labels(self, idx):
        """Return the aligned label ids of sentence ``idx`` as a LongTensor."""
        return torch.LongTensor(self.labels[idx])

    def __getitem__(self, idx):
        """Return ``(tokenizer_output, label_tensor)`` for sentence ``idx``."""
        return self.get_batch_data(idx), self.get_batch_labels(idx)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# Keep only the first 5000 rows.
df = df[0:5000]

# One list of tag strings per sentence, and the set of distinct tags.
labels = [i.split() for i in df['labels'].values.tolist()]
unique_labels = set()
for lb in labels:
    unique_labels.update(lb)

# Enumerate the tags in sorted order. Iteration order of a set of strings can
# change between interpreter sessions, so ids built straight from the set would
# not match a checkpoint that is reloaded in a new kernel.
sorted_labels = sorted(unique_labels)
labels_to_ids = {k: v for v, k in enumerate(sorted_labels)}
ids_to_labels = {v: k for v, k in enumerate(sorted_labels)}

# Shuffle once (fixed seed), then cut into 80% train / 10% validation / 10% test.
df_shuffled = df.sample(frac=1, random_state=42)
train_end = int(.8 * len(df))
val_end = int(.9 * len(df))
df_train = df_shuffled.iloc[:train_end]
df_val = df_shuffled.iloc[train_end:val_end]
df_test = df_shuffled.iloc[val_end:]

In [None]:
class client_data():
    """Sequence-like view of the rows of ``df`` selected by index label.

    Indexing yields ``(text, labels)`` pairs read from the ``text`` and
    ``labels`` columns, in the order given by ``idces``.
    """

    def __init__(self, df, idces):
        # Look each row up by label once, then pull both fields from it.
        rows = [df.loc[row_label] for row_label in idces]
        self.texts = [row.text for row in rows]
        self.labels = [row.labels for row in rows]

    def __len__(self):
        """Number of selected rows."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(text, labels)`` pair at position ``idx``."""
        return self.texts[idx], self.labels[idx]

In [None]:
# Divide the training rows across `num_clients` clients. The sizes must sum to
# len(df_train) or random_split raises, so any remainder is spread one row at a
# time over the first clients instead of being dropped.
base_size, remainder = divmod(len(df_train), num_clients)
client_sizes = [base_size + (1 if client < remainder else 0) for client in range(num_clients)]
traindata_split = torch.utils.data.random_split(list(df_train.index), client_sizes)

# Each split is a collection of index labels; wrap it as (text, labels) pairs.
train_split = [client_data(df_train, list(ts)) for ts in traindata_split]

# One shuffled training loader per client.
train_loaders = [torch.utils.data.DataLoader(DataSequence(d), batch_size=batch_size, shuffle=True) for d in train_split]

# A single loader over the whole validation split (the plural name is kept
# because later cells refer to `val_loaders`).
val_loaders = torch.utils.data.DataLoader(DataSequence(client_data(df_val, list(df_val.index))), batch_size=batch_size, shuffle=True)

In [None]:
class BertModel(torch.nn.Module):
    """Wrapper around ``BertForTokenClassification`` with a positional forward.

    The classification head is sized to ``len(unique_labels)``.
    """

    def __init__(self):
        super().__init__()
        num_tags = len(unique_labels)
        self.bert = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=num_tags)

    def forward(self, input_id, mask, label):
        """Run the wrapped model and return its output as a plain tuple.

        ``return_dict=False`` makes the wrapped model return a tuple; the
        training loop in this notebook unpacks it as ``(loss, logits)`` when
        ``label`` is given.
        """
        return self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)

In [None]:
def client_update(client_model, optimizer, train_dataloader, epoch=5):
    """
    Train one client model on that client's data and report its training loss.

    Args:
        client_model: model called as ``client_model(input_id, mask, label)``
            and returning ``(loss, logits)``.
        optimizer: optimizer bound to ``client_model``'s parameters.
        train_dataloader: yields ``(train_data, train_label)`` batches, where
            ``train_data`` has ``'input_ids'`` and ``'attention_mask'`` entries
            carrying a singleton dimension at position 1.
        epoch: number of passes over ``train_dataloader``.

    Returns:
        float: mean of the per-batch losses over the final epoch. The caller
        averages this across clients and reports it as the average train loss,
        so an epoch mean is returned rather than the loss of the last batch.

    Raises:
        ValueError: if no batch was processed (empty loader or ``epoch < 1``).
    """
    if use_cuda:
        client_model = client_model.cuda()

    client_model.train()
    mean_epoch_loss = None

    for epoch_num in range(epoch):

        running_loss = 0.0
        num_batches = 0

        for train_data, train_label in train_dataloader:

            train_label = train_label.to(device)
            # Drop the singleton dimension at position 1 of each input tensor.
            mask = train_data['attention_mask'].squeeze(1).to(device)
            input_id = train_data['input_ids'].squeeze(1).to(device)

            optimizer.zero_grad()
            # The logits are not needed for the update itself.
            loss, _ = client_model(input_id, mask, train_label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        mean_epoch_loss = running_loss / num_batches if num_batches else None

    if mean_epoch_loss is None:
        raise ValueError("client_update processed no batches: empty dataloader or epoch < 1")

    return mean_epoch_loss

In [None]:
def server_aggregate(global_model, client_models):
    """
    Average the client models into the global model, then sync the clients.

    Each entry of the global state dict is replaced by the element-wise mean of
    the corresponding client entries. The averaged global state is then loaded
    back into every client model.

    Args:
        global_model: model that receives the averaged weights.
        client_models: non-empty sequence of models with the same state-dict
            keys as ``global_model``.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("server_aggregate needs at least one client model")

    # Build each client's state dict once rather than once per key.
    client_states = [client.state_dict() for client in client_models]

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        stacked = torch.stack([state[k].float() for state in client_states], 0)
        # Average in float, then restore the entry's original dtype.
        global_dict[k] = stacked.mean(0).to(global_dict[k].dtype)
    global_model.load_state_dict(global_dict)

    synced_state = global_model.state_dict()
    for model in client_models:
        model.load_state_dict(synced_state)

In [None]:
def test(global_model, val_dataloader):
    """This function test the global model on test data and returns test loss and test accuracy """
    global_model.eval()

    total_acc_val = 0
    total_loss_val = 0

    for val_data, val_label in val_dataloader:

        val_label = val_label.to(device)
        mask = val_data['attention_mask'].squeeze(1).to(device)
        input_id = val_data['input_ids'].squeeze(1).to(device)

        loss, logits = model(input_id, mask, val_label)

        for i in range(logits.shape[0]):

          logits_clean = logits[i][val_label[i] != -100]
          label_clean = val_label[i][val_label[i] != -100]

          predictions = logits_clean.argmax(dim=1)
          acc = (predictions == label_clean).float().mean()
          total_acc_val += acc
          total_loss_val += loss.item()

    val_accuracy = total_acc_val / len(df_val)
    val_loss = total_loss_val / len(df_val)

    return val_loss, val_accuracy

In [None]:
############################################
#### Initializing models and optimizer  ####
############################################

#### global model ##########
gen_model = BertModel()
global_model = gen_model

############## client models ##############
# Every client needs its OWN model instance. A list that repeats `gen_model`
# would alias the global model and all clients to one object, so local training
# would update the same weights and the server-side average would be a no-op.
client_models = [BertModel() for _ in range(num_selected)]
# The loop variable is deliberately left as `model`: it stays bound at notebook
# level after the loop, as it did before.
for model in client_models:
    model.load_state_dict(global_model.state_dict())  ### initial synchronizing with global model

############### optimizers ################
# One optimizer per client model.
opt = [SGD(model.parameters(), lr=lr) for model in client_models]

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [None]:
###### Per-round history of the federated run #########
losses_train = []
losses_test = []
acc_train = []
acc_test = []

# Running FL: each round trains the client models on sampled partitions,
# averages them into the global model, evaluates it and writes a checkpoint.
for r in range(num_rounds):
    # The first `num_selected` entries of a random permutation choose which
    # data partitions are used this round.
    client_idx = np.random.permutation(num_clients)[:num_selected]

    # Local training: accumulate the loss reported by each client.
    loss = 0
    for i in tqdm(range(num_selected)):
        partition_loader = train_loaders[client_idx[i]]
        loss += client_update(client_models[i], opt[i], partition_loader, epoch=epochs)
    losses_train.append(loss)

    # Server step followed by evaluation of the aggregated model.
    server_aggregate(global_model, client_models)
    test_loss, acc = test(global_model, val_loaders)
    losses_test.append(test_loss)
    acc_test.append(acc)

    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))

    # One checkpoint file is written per round.
    ckpt_name = 'global_'+str(r)+'.pt'
    print('after round ',r, 'saving global ckpt', ckpt_name)
    torch.save(global_model.state_dict(), ckpt_name)


In [None]:
# Restore the global model saved after round 2. `map_location` remaps the stored
# tensors onto the current device, so a checkpoint written from a CUDA run can
# still be loaded in a CPU-only kernel.
global_model.load_state_dict(torch.load('global_2.pt', map_location=device))


<All keys matched successfully>

In [None]:
def align_word_ids(texts):
    """Mark which token positions of a sentence should receive a prediction.

    The sentence is split on whitespace and tokenized with
    ``is_split_into_words=True``, so a "word" here is a whitespace-separated
    word, which is the unit the tags in this notebook are defined on.

    Args:
        texts: the sentence (converted with ``str`` before splitting).

    Returns:
        list[int]: one entry per token position (padded/truncated to 512).
        The first word-piece of each word gets 1. Later word-pieces get 1 when
        ``label_all_tokens`` is true, otherwise -100. Positions with no word
        (``word_ids()`` gives None) get -100.
    """
    words = str(texts).split()
    tokenized_inputs = tokenizer(words, is_split_into_words=True,
                                 padding='max_length', max_length=512, truncation=True)

    word_ids = tokenized_inputs.word_ids()

    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:

        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx or label_all_tokens:
            label_ids.append(1)
        else:
            label_ids.append(-100)

        previous_word_idx = word_idx

    return label_ids


def evaluate_one_text(model, sentence):
    """Print ``sentence`` followed by the list of labels predicted for it.

    Args:
        model: model called as ``model(input_id, mask, None)`` and returning a
            tuple whose first element is the logits. It is moved to the
            available device and switched to eval mode.
        sentence: the text to tag.

    One label is printed per token position that ``align_word_ids`` marks with
    a value other than -100.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = model.to(device)
    # Inference: leave training mode so that layers such as dropout are inactive.
    model.eval()

    text = tokenizer(sentence, padding='max_length', max_length = 512, truncation=True, return_tensors="pt")

    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)

    # No gradients are needed to make a prediction.
    with torch.no_grad():
        outputs = model(input_id, mask, None)

    # Without labels the first element of the output tuple is the logits.
    logits = outputs[0]
    logits_clean = logits[label_ids != -100]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels[i] for i in predictions]
    print(sentence)
    print(prediction_label)
            
# Tag one example sentence with the global model; the sentence and its
# predicted labels are printed.
sample_sentence = 'Bill Gates is the founder of Microsoft'
evaluate_one_text(global_model, sample_sentence)