## Get Modules

### Install Modules

In [None]:
!pip install transformers
!pip install datasets
!pip install arabert
!pip install sentencepiece
!pip install  arabicnlp

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m72.6 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.3-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m110.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m86.2 MB/s[0m eta [36m0:00:

## Import the Required Modules

In [None]:
import torch
from torch import nn


## Preprocessing

In [None]:
import re
import string

def _remove_diacritics(text):
  """
  this method for remove diacritics

  Parameters:
    text: string text as input

  return:
    text without diacritics
  """
    text = re.sub(r'\s*[A-Za-z]+\b', '' , text)
    return  re.sub(r"[ًًٌٍَُِّْ]", "", text)

def _remove_extra_spaces(text):
    """
    this method for remove extra spaces

    Parameters:
      text: string text as input

    return:
      text without extra spaces
    """
    return re.sub(" +", " ", text)

def _add_spaces_to_all_special_chars(text):
    """
    this method for add white space after special chars like

    Parameters:
      text: string text as input

    return:
      text without diacritics
    """
    return re.sub(r"(?<=\w)([؟.,،])", r" \1", text)

def _remove_repeated_chars(text):
    """
    this method for remove repeatead chars in the word

    Parameters:
      text: string text as input

    return:
      text withoout repeatead chars
    """
    return re.sub(r"(.)\1+", r"\1\1", text)

def _remove_qoutes(text):
    """
    this method for remove quotes

    Parameters:
      text: string text as input

    return:
      text without quotes
    """
    #text = re.sub(r'"(.*?)"',"",text) # remove the text between two double qoutes
    #text = re.sub("https?:\/\/.*[\r\n]*", "", text) # remove the urls
    text = re.sub(r"[0-9]"," ",text)
    return re.sub(r"\[(.*?)\]", " " , text) # remove the text between two brackets

def _remove_puncs(text):
    """
    this method for preprocessing puncs

    Parameters:
      text: string text as input

    return:
      text after preprocessing puncs
    """
    puncs = string.punctuation.replace(",","").replace("،","").replace("?","؟")
    puncs = puncs.replace(")","").replace("(","")
    text = text.translate(str.maketrans(' ', ' ', puncs))
    text = text.replace("•","")
    text = text.replace(")"," ").replace("("," ").replace(":","")
    return text

def Sequential(fns):
    """
    Compose single-argument functions into one left-to-right pipeline.

    Parameters:
      fns: list of functions; each receives the previous function's result

    return:
      a function that feeds its argument through every entry of fns in order
      (with an empty list the argument is returned unchanged)
    """
    def new_fn(inputs):
        value = inputs
        for step in fns:
            value = step(value)
        return value

    return new_fn

# Stop phrases collected from the source material that are meant to be
# removed from the dataset.
# NOTE(review): no code visible in this notebook actually applies these lists
# to the text - confirm where (or whether) the removal happens.

# WS: first batch of boilerplate phrases and single filler words.
WS = ['نشاط إثرائي اقرأ ثم استنتج','عرف ذلك من النص التالي','السخلة : هي ولد الفتم من الضأن والماعز ساعة وضعه',
      'مصر من الفتح الإسلامي حتى قيام الدول المستقلة',
      'مصر في عصر الولاة','اقرأ','لتتعرف','التالي','ذلك','النص',
      'لاحظ الخريطة التالية لتتعرفها','فكر وناقش']

# WS2: second batch (lesson headings, ordinals, instructions to the reader).
# Some entries contain the tatweel character, so they are kept exactly as typed.
WS2 = ['مرحلة التاسيس و الاستقرار',
       'كل من',
       'اقرأ الشكل التالي لتتعرف أشهرهم بشيء من التفصيل',
       'بشيء من التفصيل',
       'حمورابي بالأكدية تلفظ امورابي وتعني المعتلي',
       'والآن اقرأ الشكل التالي لتتعرف أدوار الجهاد ضد الصليبيين',
       '(بدر الجمالي)','هل تعرف لمـاذا','لاحظ الخريطة الزمنية التالية لتتعرفها',
       'ومن خلال','الخريطة الزمنية السابقة'
       ,'نستنتج أن',
       'رضي الله عنه',
       'نعم يا أحبابنا','ق . م',
       'التواصل الثقافي والفني مع أفريقيا',
       'الصـور المقابلة',
       'مـن خـلال',
       'الدرس الثاني',
       'الدرس الثالث',
       'الدرس الرابع',
       'الدرس الخامس',
       'الدرس',
       'ولنبدأ',
       'والان تعال',
       'والآن تعال معا',
       'استثمر المصرى',
       'من خلال ما يلي','والآن تعال معنا',"أولا","رابعا","خامسا","سادسا","سابعا","والآن لاحظ الشكل التالي"]

# WORDS: base list; WS and WS2 are appended to it further down.
WORDS=['هل','اقرأ وناقش','معلومة إثرائية','علام يدل','للصف الثاني الثانوي','أ-','ب-','ج-','وهذا ما سوف نتعرفه في الدرس القادم','عزيزي الطالب / عزيزتي الطالبة',
         'اقرأ النص التالي','ولعلك تتساءل عزيزي الطالب','يمكنك الإجابة عن هذا التساؤل بعد قراءة النص التالي','وذلك ما سوف نتعرفه في الدرس القادم',
  'هل تعرف أسباب فشلها  اقرأ الشكل التالي لتتعرف أهمها','تعرف ذلك من النص التالي',"لاحظ الخريطة التالية",
          "وذلك ما سوف نتعرفه في الدرس القادم","والآن تعال معا",
       'الفتوحات الأموية في الغرب الإسلامي',"ثالثا","ثانيا",
       "اولا","اقرأ وناقش،المعاني والقيم الواردة",
       "الفتوحات الإسلامية في عصر الدولة الأموية",
       'لاحظ الخريطة التالية لتتعرف عليها',
       "بعد ان","وبعد ان القينا الضوء علي",
       'وهذا ما سوف نتعرفه',
       'الوحدة التالية',
       'رواه مسلم','العامل التاريخي',
       'لتتعرفها','النص التالي',"أ.","ب.","ج.","د.","والآن لاحـظ الشك","ثانيا:",".. الرسم والنقش:","سادسا:"]


WORDS += WS
WORDS += WS2
# De-duplicate, then order deterministically: longest phrases first, ties
# broken alphabetically. A bare list(set(...)) comes back in a different
# order after every kernel restart, and when phrases are stripped one after
# another a short entry must not run before the longer phrases containing it.
WORDS = sorted(set(WORDS), key=lambda phrase: (-len(phrase), phrase))

# Text-cleaning pipeline; the steps run in exactly this order.
text_processor = Sequential([
    _remove_diacritics,
    _remove_qoutes,
    _remove_puncs,
    _remove_extra_spaces,
    _remove_repeated_chars,
])

## Modeling

### Pretrained model

In [None]:
def freeze(model):
    """
    Freeze a model by disabling gradients on all of its parameters.

    Parameters:
      model: object exposing parameters() (e.g. a torch.nn.Module)

    return:
      the same model instance, modified in place
    """
    for p in model.parameters():
        p.requires_grad = False
    return model

In [None]:
# Load The Pretrained Model
from transformers import BertTokenizer,AutoModelForSeq2SeqLM, pipeline,AutoModel,AutoTokenizer,BartForConditionalGeneration


# Candidate checkpoints on the Hugging Face Hub.
# Only `model_bart_2` is actually loaded below.
model_mini = "asafaya/bert-mini-arabic"
model_med = "asafaya/bert-medium-arabic"
model_bart = "moussaKam/AraBART"
model_bart_2 = "abdalrahmanshahrour/auto-arabic-summarization"
model_name = "malmarjeh/mbert2mbert-arabic-text-summarization"

# Model and tokenizer come from the same checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_bart_2)
tokenizer = AutoTokenizer.from_pretrained(model_bart_2)

# Number of model parameters
param_count = sum(p.numel() for p in model.parameters())
print(f"num of parameters is :{param_count}")

num of parameters is :139221504


In [None]:
# Load the Labeled Dataset
import pandas as pd
labeled_path = '/content/drive/MyDrive/data/labeled_validation_dataset.jsonl'
vals = pd.read_json(labeled_path, lines=True)
# Expose the text under the column name `document` (added as the last
# column) and drop the original `paragraph` column.
vals = vals.assign(document=vals['paragraph']).drop(columns='paragraph')
vals.head()

In [None]:
def get_len(x):
    """Return how many tokens the global `tokenizer` splits the text `x` into."""
    return len(tokenizer.tokenize(x))


test_path = '/content/drive/MyDrive/data/validation_data.jsonl'
test_data = pd.read_json(test_path, lines=True)
# Same column convention as the labeled set: `paragraph` becomes `document`.
test_data = test_data.assign(document=test_data['paragraph']).drop(columns='paragraph')
test_data['document_len'] = test_data['document'].apply(get_len)
test_data['document_len'].max()

778

In [None]:
# Token counts of the documents and reference summaries in the labeled set.
for text_col, len_col in (('document', 'document_len'), ('summary', 'summary_len')):
    vals[len_col] = vals[text_col].apply(get_len)
vals['document_len'].min()

274

In [None]:
# Attach a `summary` column to the unlabeled validation frame and measure its
# token length.
# NOTE(review): `pred` is not defined in any cell visible in this notebook, so
# this cell fails on a fresh kernel - confirm where `pred` comes from
# (presumably a frame of previously generated summaries) and load it here.
test_data['summary'] = pred['summary']
test_data['summary_len'] = test_data['summary'].apply(get_len)
test_data.head()

In [None]:
# Load the filtered dataset: two CSV exports, each stripped of the one column
# that is not needed, stacked and de-duplicated.
data = pd.read_csv('/content/drive/MyDrive/data/first_9.csv').drop(columns='address')
data2 = pd.read_csv("/content/drive/MyDrive/data/all_data_25_45_words.csv").drop(columns='Unnamed: 0')
data = pd.concat([data, data2]).drop_duplicates().reset_index(drop=True)
# Expose the text as `document` (appended as the last column), like the
# validation frames.
data = data.assign(document=data['article']).drop(columns='article')
len(data)

209933

In [None]:
# Whitespace-separated word counts of each summary and document, stored in
# new columns.
for text_col, count_col in (('summary', 'sum_len'), ('document', 'para_len')):
    data[count_col] = data[text_col].apply(lambda x: len(x.split()))

In [None]:
# Ratio of summary length to document length (in words); keep only the pairs
# whose summary is strictly between 30% and 40% of the document.
MIN_RATIO, MAX_RATIO = 0.3, 0.40
data['calc'] = data['sum_len'] / data['para_len']
in_band = (data['calc'] > MIN_RATIO) & (data['calc'] < MAX_RATIO)
# Drop duplicates *before* renumbering so the resulting index is contiguous,
# matching the order used when the frame was first built.
data = data[in_band].drop_duplicates().reset_index(drop=True)
len(data)

63604

In [None]:
# Word counts and summary/document length ratio for both validation frames.
for frame in (vals, test_data):
    frame['sum_len'] = frame['summary'].apply(lambda x: len(x.split()))
    frame['para_len'] = frame['document'].apply(lambda x: len(x.split()))
    frame['calc'] = frame['sum_len'] / frame['para_len']

In [None]:
# Combine the datasets into one training frame (only the two text columns).
test_data = test_data[['document','summary']]
vals = vals[['document','summary']]
data = data[['document','summary']]
# The labeled validation pairs are appended to the training data.
# Duplicates are dropped before the index is renumbered so it stays contiguous.
data = pd.concat([data, vals]).drop_duplicates().reset_index(drop=True)
# Size of the combined frame (the recorded output of this cell is a row count).
len(data)


63717

In [None]:

from torch.utils.data import Dataset,DataLoader

# Creating the summarization dataset for data loading
class SummarizationDataset(Dataset):
    """
    Map-style dataset over a dataframe of documents and, optionally, summaries.

    Parameters:
      data: dataframe with a 'document' column and, when with_summary is
            True, a 'summary' column
      with_summary: whether __getitem__ also returns the reference summary
    """

    def __init__(self,
                 data,
                 with_summary=True):
        self.data = data
        self.with_summary = with_summary

    def __len__(self):
        """Dunder method returning the number of rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Dunder method returning one cleaned data point.

        Parameters:
          idx: int positional index into the dataframe

        return:
          (document, summary) when with_summary is True, otherwise document;
          only the document goes through the cleaning pipeline
        """
        row = self.data.iloc[idx]  # single positional lookup for both columns
        document = text_processor(row['document'])
        document = _remove_extra_spaces(document.replace("\n", " "))
        if self.with_summary:
            return document, row['summary']
        return document

# To create the batches
def collate_fn(batch, tok=None, max_input_len=500, max_label_len=250):
    """
    Tokenize a list of data points into padded model inputs.

    Parameters:
      batch : list of data points, either (document, summary) pairs or bare
              document strings
      tok : tokenizer to use; defaults to the notebook-level `tokenizer`
      max_input_len : truncation length for the documents
      max_label_len : truncation length for the summaries

    return:
      (inputs, dec_out) for paired batches, where inputs also carries
      decoder_input_ids / decoder_attention_mask (the summary without its
      last token) and dec_out holds the summary without its first token;
      just inputs for batches of bare documents
    """
    tok = tokenizer if tok is None else tok
    first = batch[0]
    # A bare document is a str, whose len() is also > 1, so the item type has
    # to be checked explicitly instead of only looking at the length.
    paired = isinstance(first, (tuple, list))
    with_summary = paired and len(first) > 1

    documents = [item[0] if paired else item for item in batch]
    inputs = tok(documents,
                 return_tensors='pt',
                 padding='longest',
                 truncation=True,
                 max_length=max_input_len)
    if not with_summary:
        return inputs

    labels = tok([item[1] for item in batch],
                 return_tensors='pt',
                 padding='longest',
                 truncation=True,
                 max_length=max_label_len)
    # Teacher forcing: the decoder reads tokens [0, n-1) and predicts [1, n).
    dec_in = {k: v[:, :-1] for k, v in labels.items()}
    dec_out = {k: v[:, 1:] for k, v in labels.items()}
    inputs.update({'decoder_input_ids': dec_in['input_ids'],
                   "decoder_attention_mask": dec_in['attention_mask']})
    return inputs, dec_out


In [None]:
# Dataset instances: training uses the combined corpus, validation uses the
# validation frame.
train_ds = SummarizationDataset(data)
val_ds = SummarizationDataset(test_data)

# Settings shared by both loaders.
_loader_kwargs = dict(num_workers=2, pin_memory=True, collate_fn=collate_fn)

# Training iterator: shuffled, batches of 18
train_iter = DataLoader(train_ds, batch_size=18, shuffle=True, **_loader_kwargs)

# Validation iterator: fixed order, batches of 4
val_iter = DataLoader(val_ds, batch_size=4, **_loader_kwargs)


In [None]:
# Model configuration: align the special-token ids and vocabulary size with
# the tokenizer.
_cfg = model.config
_cfg.decoder_start_token_id = tokenizer.cls_token_id
_cfg.pad_token_id = tokenizer.pad_token_id
_cfg.vocab_size = tokenizer.vocab_size

In [None]:
# Loss objective: cross-entropy that skips padded target positions.
loss_obj = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
def accuracy(preds,labels,attn_mask):
    """
    Token-level accuracy over the unmasked positions.

    This is exact-match accuracy of the arg-max token at every position; it
    is only a rough stand-in for an overlap metric such as ROUGE-1.

    Parameters:
      preds: logits whose last dimension is the vocabulary
      labels: ground-truth token ids, one per logit row
      attn_mask: attention mask, 1 for positions that count and 0 for padding

    return:
      scalar tensor with the fraction of counted positions predicted
      correctly (0 when the mask selects no position at all)
    """
    vocab_size = preds.shape[-1]
    # reshape (rather than view) also accepts non-contiguous logits
    predicted = preds.reshape(-1, vocab_size).argmax(dim=-1)
    target = labels.reshape(-1)
    mask = attn_mask.reshape(-1)
    hits = (predicted == target) * mask
    # clamp avoids 0/0 = nan poisoning the epoch mean for an all-padding batch
    return hits.sum() / mask.sum().clamp(min=1)

def loss_fn(logits,targets):
    """
    Cross-entropy between flattened logits and flattened targets.

    Parameters:
      logits: prediction, with the vocabulary as last dimension
      targets: ground-truth token ids

    return:
      the value of the notebook-level `loss_obj` for this batch
    """
    vocab_size = logits.size(-1)
    flat_logits = logits.reshape(-1, vocab_size)
    flat_targets = targets.reshape(-1)
    return loss_obj(flat_logits, flat_targets)

In [None]:
from torch.cuda.amp import GradScaler # for mixed precision


def mean(x):
    """Arithmetic mean of a non-empty sequence of numbers."""
    return sum(x) / len(x)


DEVICE = torch.device("cuda")
scaler = GradScaler()  # loss scaler for mixed-precision training


def train_epoch(model,train_iter,opt,times=5):
    """
    Training function for one pass over the training dataset.

    Uses the notebook-level DEVICE, scaler (mixed precision), loss_fn,
    accuracy and mean.

    Parameters:
      model : PretrainedModel
      train_iter: training iterator yielding (inputs, labels) dict pairs
      opt: Optimizer
      times: number of progress logs printed during the epoch (0 disables them)

    return:
      mean_loss, mean_acc
    """
    losses, accs = [], []
    n_batches = len(train_iter)
    # Log roughly `times` times per epoch. The interval is clamped to at least
    # one batch: with fewer batches than `times` the integer division yields
    # 0, and `% 0` would raise ZeroDivisionError.
    log_every = max(1, n_batches // times) if times > 0 else 0

    model.train()
    for i, (enc_in, labels) in enumerate(train_iter):
        enc_in = {k: v.to(DEVICE) for k, v in enc_in.items()}
        labels = {k: v.to(DEVICE) for k, v in labels.items()}
        opt.zero_grad()
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            out = model(**enc_in).logits
            loss = loss_fn(out, labels['input_ids'])
        # The scaled backward / step / update sequence keeps fp16 gradients
        # from underflowing.
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        loss = loss.item()
        # The metric needs no gradient, so the logits are detached first.
        acc = accuracy(out.detach(),
                       labels['input_ids'],
                       labels['attention_mask']).item()
        del out, enc_in, labels
        torch.cuda.empty_cache()
        if log_every and (i + 1) % log_every == 0:
            print(f"Finished Training on {(i+1)*100/len(train_iter):.2f} % of the data, loss:{loss:.3f}, acc:{acc:.3f}.")
        losses.append(loss)
        accs.append(acc)
    return mean(losses), mean(accs)



@torch.no_grad()
def val_epoch(model,val_iter):
    """
    Validation function for one pass over the validation dataset.

    Runs without gradients and with the model in eval mode; uses the
    notebook-level DEVICE, loss_fn, accuracy and mean.

    Parameters:
      model : PretrainedModel
      val_iter: validation iterator yielding (inputs, labels) dict pairs

    return:
      mean_loss, mean_acc
    """
    losses, accs = [], []
    model.eval()
    for enc_in, labels in val_iter:
        enc_in = {k: v.to(DEVICE) for k, v in enc_in.items()}
        labels = {k: v.to(DEVICE) for k, v in labels.items()}
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            out = model(**enc_in).logits
            loss = loss_fn(out, labels['input_ids'])
        loss = loss.item()
        acc = accuracy(out,
                       labels['input_ids'],
                       labels['attention_mask']).item()
        losses.append(loss)
        accs.append(acc)
        del out, enc_in, labels
        torch.cuda.empty_cache()
    return mean(losses), mean(accs)

In [None]:
import math
class LoraLayer(nn.Module):
    """
    Low-rank adapter (LoRA) wrapped around an existing linear layer.

    The wrapped layer's output is augmented with x @ A^T @ B^T * scale, where
    A is (rank, in) and B is (out, rank). B starts at zero, so a new adapter
    leaves the wrapped layer's output unchanged.

    Parameters:
      orig_layer: the layer to apply LoRA on; its weight is read as
                  (out_features, in_features)
      rank: rank of the A and B matrices (must be > 0)
      drop: dropout rate applied to the adapter input (0 disables dropout)
      alpha: scaling factor; the update is scaled by alpha / rank
    """

    def __init__(self,
                 orig_layer,
                 rank=32,
                 drop=0.1,
                 alpha=32.0):
        super().__init__()
        assert rank > 0
        self.scale = alpha/rank
        weight_shape = orig_layer.weight.shape

        self.drop = nn.Dropout(drop) if drop != 0 else nn.Identity()
        self.lora_A = nn.Parameter(torch.zeros(rank, weight_shape[1]))  # (rank, in_shape)
        self.lora_B = nn.Parameter(torch.zeros(weight_shape[0], rank))  # (out_shape, rank)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        self.orig_layer = orig_layer
        self.merged = False

    def reset_parameters(self):
        """Re-initialise the wrapped layer and both adapter matrices."""
        self.orig_layer.reset_parameters()
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def merge(self):
        """
        Fold the adapter into the wrapped layer's weight (done at most once).

        forward() adds x @ (B @ A)^T * scale and the wrapped layer computes
        x @ W^T, so the equivalent weight update is (B @ A) * scale, which
        already has the weight's (out, in) shape and must not be transposed.
        """
        if self.merged:
            return
        with torch.no_grad():
            self.orig_layer.weight.add_((self.lora_B @ self.lora_A) * self.scale)
        self.merged = True

    def forward(self, x):
        """
        Forward pass.

        Parameters:
          x: the inputs

        return:
          wrapped layer output, plus the scaled low-rank update while the
          adapter has not been merged
        """
        out = self.orig_layer(x)
        if self.merged:
            return out
        low_rank = self.drop(x) @ self.lora_A.transpose(0, 1)
        low_rank = (low_rank @ self.lora_B.transpose(0, 1)) * self.scale
        return low_rank + out

# Freeze the whole model except for the LoRA parameters
def freeze_lora(model):
    """
    Disable gradients for every parameter whose name does not contain "lora_".

    Parameters:
      model: object exposing named_parameters() (e.g. a torch.nn.Module)

    return:
      the same model instance, modified in place
    """
    for name, parameter in model.named_parameters():
        if "lora_" not in name:
            parameter.requires_grad = False
    return model
def merge_lora(ch):
    """
    Unwrap a LoRA layer, merging its weights first when that is still pending.

    Parameters:
      ch: layer exposing merged, merge() and orig_layer

    return:
      the wrapped original layer
    """
    if ch.merged:
        return ch.orig_layer
    ch.merge()  # fold the adapter into the original weights
    return ch.orig_layer


def get_apply(Module,
              instance,
              attr_names,
              map_fn,**map_fn_kwargs):
    """
    Apply a function to selected attributes of selected layers, in place.

    The module tree is walked recursively. Children that are already
    LoraLayer instances are skipped. For a child that is an instance of
    `instance`, each named attribute is replaced by map_fn(attribute); such
    a child is not descended into any further.

    Parameters:
      Module: model (or sub-module) to walk
      instance: layer class (or tuple of classes) to match
      attr_names: names of the attributes to replace on matching layers
      map_fn: the function applied to each attribute
      map_fn_kwargs: keyword arguments forwarded to map_fn

    return:
      the same Module, transformed in place
    """
    for _name, child in Module.named_children():
        if isinstance(child, LoraLayer):
            continue
        if isinstance(child, instance):
            for attr_name in attr_names:
                replaced = map_fn(getattr(child, attr_name), **map_fn_kwargs)
                setattr(child, attr_name, replaced)
        else:
            get_apply(child, instance, attr_names, map_fn, **map_fn_kwargs)
    # Returned so `model = get_apply(model, ...)` works as documented.
    return Module

def convert_lora(module,**kwargs):
    """
    Wrap a layer in a LoraLayer.

    Parameters:
      module: the layer to wrap
      kwargs: forwarded to the LoraLayer constructor (rank, drop, alpha)

    return:
      the new LoraLayer
    """
    return LoraLayer(module, **kwargs)



In [None]:
# Get the trainable and untrainable parameter counts
def get_params(model):
    """
    Count a model's parameters by whether they receive gradients.

    return:
      dict with the element counts under "trainable", "untrainable" and "all"
    """
    counts = {"trainable": 0, "untrainable": 0}
    for p in model.parameters():
        bucket = "trainable" if p.requires_grad else "untrainable"
        counts[bucket] += p.numel()
    counts["all"] = counts["trainable"] + counts["untrainable"]
    return counts
# Show the trainable / untrainable / total parameter counts of `model`.
get_params(model)

{'trainable': 139221504, 'untrainable': 0, 'all': 139221504}

In [None]:

# Move the model to the GPU and optimise all of its parameters with AdamW.
model = model.to(DEVICE)
opt = torch.optim.AdamW(params=model.parameters(), lr=5e-5)

In [None]:
# Training and Validation Loop
import time
EPOCHS = 2
best_val = .355  # a checkpoint is saved only when validation accuracy beats this

for epoch in range(1, EPOCHS + 1):
    print(f"Started Training on epoch:{epoch}/{EPOCHS}")
    started = time.time()
    train_loss, train_acc = train_epoch(model, train_iter, opt)
    val_loss, val_acc = val_epoch(model, val_iter)
    if val_acc > best_val:
        # New best: remember it and save under a name carrying the score.
        best_val = val_acc
        model.save_pretrained(f'model_{val_acc:.2f}')
    elapsed_mins = (time.time() - started) / 60
    print(f"Finished Trainin in {elapsed_mins:.2f} mins, train_loss:{train_loss:.3f}, train_acc:{train_acc:.3f},val_loss:{val_loss:.3f},val_acc:{val_acc:.3f}\n.")

Started Training on epoch:1/2
Finished Training on 20.00 % of the data, loss:2.850, acc:0.442.
Finished Training on 40.00 % of the data, loss:2.923, acc:0.451.
Finished Training on 60.00 % of the data, loss:3.041, acc:0.432.
Finished Training on 80.00 % of the data, loss:3.221, acc:0.405.
Finished Training on 100.00 % of the data, loss:3.039, acc:0.423.
Finished Trainin in 30.92 mins, train_loss:3.033, train_acc:0.430,val_loss:1.456,val_acc:0.767
.
Started Training on epoch:2/2
Finished Training on 20.00 % of the data, loss:2.912, acc:0.444.
Finished Training on 40.00 % of the data, loss:2.545, acc:0.506.


In [None]:
# install the rouge metrics
!pip install rouge
from rouge import Rouge

Collecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: rouge
Successfully installed rouge-1.0.1


In [None]:
# Scorer object from the `rouge` package installed above.
# NOTE(review): `rouge` is not used by any later cell visible in this
# notebook - confirm whether the evaluation cell is missing.
rouge = Rouge()

In [None]:
class SemSim:
    """
    Semantic similarity between texts, based on mean-pooled encoder states.

    Parameters:
      encoder_model: encoder whose output exposes last_hidden_state; it is
                     frozen, switched to eval mode and moved to the device
      tokenizer: tokenizer matching the encoder
    """

    def __init__(self,encoder_model,tokenizer):
        for p in encoder_model.parameters():
            p.requires_grad = False
        self.tok = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Use the encoder that was passed in, not a global named `encoder`
        # that may not exist or may be a different model.
        self.encoder = encoder_model.to(self.device)
        # eval mode switches dropout off so similarities are deterministic
        self.encoder.eval()

    @torch.no_grad()
    def get_hiddens(self,
                    text: str):
        """
        text: str
        return the hidden states, averaged over the sequence dimension, that
        encode the meaning of the text
        """
        # truncation keeps over-long texts within the tokenizer's model limit
        encoded = self.tok(text, return_tensors='pt', truncation=True)
        encoded = {k: v.to(self.device) for k, v in encoded.items()}
        hidden = self.encoder(**encoded).last_hidden_state
        return hidden.mean(dim=1)

    def __call__(self,
                text1: str,
                text2: str=None,
                hid2=None):
        """
        text1: str, the first text
        text2: str, the text to compare against (takes precedence over hid2)
        hid2: precomputed result of get_hiddens, used when text2 is None
        return the cosine similarity of the two representations as a float
        """
        if text2 is None and hid2 is None:
            raise ValueError("SemSim needs either text2 or hid2 to compare against")
        hid1 = self.get_hiddens(text1)
        other = self.get_hiddens(text2) if text2 is not None else hid2
        return torch.nn.functional.cosine_similarity(hid1, other).item()

def search(text,sent_length=12):
    """
    Split a long document into sentence-like pieces.

    When the text contains "." or "،" it is simply split on them. Otherwise
    it is cut in front of short connective words: words of 3-5 characters
    that start with "ف" or "و", are not on the exclusion list, and come
    after more than sent_length words of the current piece.

    Parameters:
      text: the document
      sent_length: word count a piece must exceed before it may be cut

    return:
      list of pieces; on the fallback path they concatenate back to the
      original text, blank text gives [], and a text without any cut point
      gives [text]
    """
    texts = text.replace(".", "،").split("،")
    if len(texts) >= 2:
        return texts
    if not text.strip():
        return []

    # Words that look like connectives but are ordinary vocabulary.
    excluded = ('وبعد', "والي", "فتح", 'فرنسا', 'ولايه')
    cut_points = [0]
    lens = 0
    # Real character offsets of each word are taken from the match objects,
    # so the cuts always fall on word starts whatever the spacing is.
    for match in re.finditer(r"\S+", text):
        word = match.group()
        if (3 <= len(word) <= 5
                and word.startswith(("ف", "و"))
                and word not in excluded
                and lens > sent_length):
            if match.start() != 0:
                cut_points.append(match.start())
            lens = 1
        lens += 1
    # End at len(text) (exclusive slice bound) so the last character is kept.
    cut_points.append(len(text))
    return [text[start:end] for start, end in zip(cut_points, cut_points[1:])]

In [None]:
class DocumentProcessor:
    """
    Clean a document and cut it into chunks that are summarised separately.

    Parameters:
      encoder_model: encoder used for the semantic-similarity filter
      tokenizer: tokenizer matching the encoder
    """

    def __init__(self,
                 encoder_model,
                 tokenizer,):
        self.semsim = SemSim(encoder_model,tokenizer)
        self.preprocessor = Sequential([
            _remove_diacritics,
            _remove_qoutes,
            _remove_puncs,
            _remove_extra_spaces,
            _remove_repeated_chars,
        ])

    def split_int_sents(self,doc,max_sent_len=12):
        """
        split the document according to the searching algorithm
        """
        return search(doc,max_sent_len)

    def smart_prepare(self,
                      text:str,
                      max_sub_len: int= 128,
                      semantic_thres:float=0.25,
                      return_hid:bool=True,):
        """
        Select sentences by semantic relevance and pack them into chunks.

        Sentences whose similarity to the whole text is below semantic_thres
        are dropped. The rest are joined with "،" into chunks whose total
        word count stays below max_sub_len (a single longer sentence forms a
        chunk of its own). Chunks of 5 characters or fewer are discarded.

        return:
          (chunks, hid) when return_hid is True, otherwise chunks; hid is the
          representation of the whole text
        """
        split_on = "،"
        hid = self.semsim.get_hiddens(text)
        sentences = [sent for sent in self.split_int_sents(text)
                     if self.semsim(sent, hid2=hid) >= semantic_thres]

        chunks = []
        current, length = [], 0
        for sent in sentences:
            n_words = len(sent.split())
            if current and length + n_words >= max_sub_len:
                # the sentence does not fit: close the chunk, start a new one
                chunks.append(split_on.join(current).strip())
                current, length = [], 0
            current.append(sent)
            length += n_words
        # Flush whatever is still open, so the final sentence is kept even
        # when it was the one that overflowed the previous chunk.
        if current:
            chunks.append(split_on.join(current).strip())

        chunks = [c for c in chunks if len(c) > 5]
        if return_hid:
            return chunks, hid
        return chunks

    def preprocess_text(self,doc):
        """Run the cleaning pipeline, then turn newlines into single spaces."""
        doc = self.preprocessor(doc)
        return _remove_extra_spaces(doc.replace("\n"," "))

    def __call__(self,
                 doc,
                 max_sub_len:int=90,
                 semantic_thres:float=0.25,
                 return_hid=False):
        """Clean doc and return its chunks (see smart_prepare)."""
        doc = self.preprocess_text(doc)
        return self.smart_prepare(doc,
                                  max_sub_len,
                                  semantic_thres,
                                  return_hid)



In [None]:
@torch.no_grad()
def generate(model,tok,text,**kwargs):
    """
    Generate text from the model for one input text.

    Parameters:
      model: model exposing generate() (and parameters())
      tok: tokenizer matching the model
      text: the input text; it is truncated by the tokenizer if too long
      kwargs: forwarded unchanged to model.generate

    return:
      a list of decoded strings when num_return_sequences is passed,
      otherwise a single decoded string
    """
    # Put the inputs wherever the model lives; choosing "cuda if available"
    # breaks when the model itself is still on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    inputs = tok(text, return_tensors='pt', truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    gen = model.generate(**inputs, **kwargs).cpu()

    if kwargs.get("num_return_sequences") is not None:
        # keep the batch dimension so one returned sequence still decodes
        # as one string
        return tok.batch_decode(gen, skip_special_tokens=True)
    return tok.decode(gen[0], skip_special_tokens=True)

In [None]:
import copy  # `copy` is used below but is never imported in the visible cells

# Replace both placeholder paths with the fine-tuned checkpoint and its tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained("your model path")
tokenizer = AutoTokenizer.from_pretrained('your tokenizer path')
# The similarity filter gets its own copy of the encoder, so freezing it does
# not touch the summarisation model.
encoder = copy.deepcopy(model.model.encoder)
doc_prcocessor = DocumentProcessor(encoder,tokenizer)

In [None]:
# Generate a summary: summarise every chunk of the document, then stitch the
# partial summaries together.
docs = doc_prcocessor(test_data['document'][1])
sums = []
for t in docs:
    # NOTE(review): top_p / top_k normally only take effect when sampling is
    # enabled (do_sample=True); with plain beam search they are presumably
    # ignored - confirm against the generation settings you intend.
    summary = generate(model,
                       tokenizer,
                       t,
                       max_length=300,
                       num_beams=5,
                       no_repeat_ngram_size=2,
                       repetition_penalty=2.0,
                       length_penalty=1.0,
                       top_p=0.92,
                       top_k=5000)

    sums.append(summary)
# Join the pieces with an Arabic comma after removing the full stops.
summary = ("،"+" ").join([s.replace('.',"") for s in sums])
summary

0.05769230375739672