In [1]:
#pip install torch
#pip install bert-extractive-summarizer

# !pip install transformers
# !pip install datasets
# !pip install -U accelerate
# !pip install -U bertviz
# !pip install -U umap-learn
# !pip install -U sentencepiece

In [2]:
# !git clone https://github.com/vgupta123/sumpubmed.git

In [3]:
import os
from datasets import Dataset
import torch
import pandas as pd
import re
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

In [4]:
def load_text_files(directory):
    files = os.listdir(directory)
    files = [file for file in files if not file.startswith('.DS')]
    files = sorted(files, key=lambda x: int(re.split('\.|_', x)[1]))
    print(len(files))
    data = []
    for file in files:
        if file.endswith('.txt'):
            with open(os.path.join(directory, file), 'r', encoding='utf-8',errors='ignore') as f:
                text = f.read()
                data.append(text)
    return data


In [5]:
#print(re.split('\.|_', 'text_1010.txt'))

In [6]:
# Load text files from directories
original_text = load_text_files('/content/sumpubmed/abstract')
summaries = load_text_files('/content/sumpubmed/shorter_abstract')
# original_text = load_text_files('/content/sumpubmed/text')
# summaries = load_text_files('/content/sumpubmed/abstract')
print(len(original_text))
print(len(summaries))

# Create a dataset from the loaded text files
data = {'original_text': original_text, 'summary': summaries}

32689
32689
32689
32689


In [7]:
# Create a dataset from the loaded text files
df =pd.DataFrame({'original_text': original_text, 'summary': summaries})

# dataset = Dataset.from_dict(data)

print(df.shape)



(32689, 2)


In [8]:
# Load pre-trained model and tokenizer
# model_name = "gpt2"  # or specify another GPT model suitable for text generation
# model = GPT2LMHeadModel.from_pretrained(model_name)
# tokenizer = GPT2Tokenizer.from_pretrained(model_name)

device = 'gpu'
model_ckpt = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)

In [9]:
# Tokenize the dataset
# def tokenize_function(ds):
#     return tokenizer(ds["original_text"] + " [SEP] " + ds["summary"])

# tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Save the tokenized dataset to a text file
# tokenized_datasets.save_to_disk("/content/")

In [10]:
df['input_text'] = df["original_text"].iloc[:1000] + " TL;DR " + df["summary"].iloc[:1000]
df["input_text"].to_csv("/content/combined_text.txt", index=False, header=False)

In [11]:
def get_feature(batch):
  encodings = tokenizer(batch['original_text'], text_target=batch['summary'],
                        max_length=1024, truncation=True)

  encodings = {'input_ids': encodings['input_ids'],
               'attention_mask': encodings['attention_mask'],
               'labels': encodings['labels']}

  return encodings

In [12]:
data = Dataset.from_pandas(df.iloc[:1000])

In [13]:
data_pt = data.map(get_feature, batched=True)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [14]:
data_pt

Dataset({
    features: ['original_text', 'summary', 'input_text', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})

In [15]:
columns = ['input_ids', 'labels', 'attention_mask']
data_pt.set_format(type='torch', columns=columns)

In [16]:
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [17]:
# # Load your dataset
# train_file = "/content/combined_text.txt"
# train_dataset = TextDataset(
#     tokenizer=tokenizer,
#     file_path=train_file,
#     block_size=128,
# )

In [18]:
# data_collator = DataCollatorForLanguageModeling(
#     tokenizer=tokenizer,
#     mlm=False
# )

In [20]:
# # Fine-tuning arguments
# training_args = TrainingArguments(
#     output_dir="/content/gpt2-finetuned-summarization",
#     overwrite_output_dir=True,
#     num_train_epochs=1,
#     per_device_train_batch_size=4,
#     save_steps=10000,
#     save_total_limit=2,
# )

training_args = TrainingArguments(
    output_dir = '/content/CS532_bart',
    overwrite_output_dir=True,
    num_train_epochs=1,
    warmup_steps = 500,
    per_device_train_batch_size=2,
    weight_decay = 0.01,
    logging_steps = 100,
    evaluation_strategy = 'steps',
    eval_steps=1000,
    save_steps=1e6,
    gradient_accumulation_steps=8
    # num_train_epochs=1,
    # warmup_steps = 100,
    # per_device_train_batch_size=4,
    # weight_decay = 0.01,
    # logging_steps = 10,
    # evaluation_strategy = 'steps',
    # eval_steps=500,
    # save_steps=1e6,
    # gradient_accumulation_steps=16
)

trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
                  train_dataset = data_pt)

In [21]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss


TrainOutput(global_step=62, training_loss=0.42466957338394656, metrics={'train_runtime': 384.9989, 'train_samples_per_second': 2.597, 'train_steps_per_second': 0.161, 'total_flos': 1000965299109888.0, 'train_loss': 0.42466957338394656, 'epoch': 0.99})

In [None]:
# Initialize Trainer
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     data_collator=data_collator,
#     train_dataset=train_dataset,
# )

In [22]:
# Fine-tune the model
# trainer.train()

# Save the fine-tuned model
model.save_pretrained('/content/CS532_bart')
tokenizer.save_pretrained('/content/CS532_bart')

('/content/CS532_bart/tokenizer_config.json',
 '/content/CS532_bart/special_tokens_map.json',
 '/content/CS532_bart/vocab.json',
 '/content/CS532_bart/merges.txt',
 '/content/CS532_bart/added_tokens.json',
 '/content/CS532_bart/tokenizer.json')

In [23]:
trainer.save_model('CS532_bart_model')

In [24]:
# Load the fine-tuned model and tokenizer
model_path = '/content/CS532_bart'  # Replace with the actual path to your fine-tuned model
# model = GPT2LMHeadModel.from_pretrained(model_path)
# tokenizer = GPT2Tokenizer.from_pretrained(model_path)
pipe = pipeline('summarization', model='CS532_bart_model')

In [25]:
input_text = df['original_text'][10000]
print(pipe(input_text))

[{'summary_text': 'in mammals, pheromones play an important role in social and innate reproductive behavior within species.\nin rodents, vomeronasal receptor type  <dig> , which is specifically expressed in the vomer onasal organ, is thought to detect pheramones.\nwe found that all of the goat and sheep v1r genes have orthologs in their cross-species counterparts among these three ruminant species and that the sequence identity of v 1r orthologous pairs among these ruminants is much higher than that of mouse-rat v1R orthologOUS pairs.\nthe v1 r gene repertoire differs dramatically between mammalian species, and'}]


In [26]:
df['summary'][10000]

'interestingly, goat male pheromone, which can induce out-of-season ovulation in anestrous females, causes the same pheromone response in sheep, and vice versa, suggesting that there may be mechanisms for detecting "inter-species" pheromones among ruminant species.\nwe found that all of the goat and sheep v1r genes have orthologs in their cross-species counterparts among these three ruminant species and that the sequence identity of v1r orthologous pairs among these ruminants is much higher than that of mouse-rat v1r orthologous pairs.\nwe isolated  <dig> goat and  <dig> sheep intact v1r genes based on sequence similarity with  <dig> cow v1r genes in the cow genome database.\nthe fact that ruminant and rodent v1rs have distinct features suggests that ruminant and rodent v1rs have evolved distinct functions.\nour results suggest that, compared with rodents, the repertoire of orthologous v1r genes is remarkably conserved among the ruminants cow, sheep and goat.\nfurthermore, all goat v1r