# Google Colab - Summarization Demo

Run the notebook on `https://colab.research.google.com/`.<br>
Remember to set Runtime > Change runtime type > Hardware Accelerator > GPU.

## Initialize

In [None]:
# Report which Python interpreter this runtime provides.
# Python version can't be set on Google Colab
!python --version

In [None]:
%load_ext autoreload
%autoreload

In [None]:
# Clone repository
CODE_TEMP = './_temp'
CODE_BRANCH = 'main'

!git clone -b $CODE_BRANCH --recurse-submodules --single-branch https://github.com/Genisis2/nus_cs5246_project.git $CODE_TEMP

# Explode it in the workspace and remove temp
!mv -f $CODE_TEMP/* . && mv -f $CODE_TEMP/.* .
!rm -rf $CODE_TEMP

# Setup git redirect
!git config --global url.https://github.com/.insteadOf git://github.com/

# Install requirements
!pip install -U pip setuptools
!pip install -r requirements-cuda.txt

# Restart kernel so imported modules are available
print("Restarting kernel. Run the next cell manually.")
import time
time.sleep(2)
import os
os.kill(os.getpid(), 9)

## Dataset

In [None]:
# Dataset helpers from the project's `simplertimes` package
from simplertimes import data

# Display information on the dataset
data.describe_cnn_dm_dataset()

In [None]:
# Keep only the first three entries of the test split; the result is indexed
# by column name first and by sample position second.
three_samples = data.load_cnn_dm_dataset(split='test')[:3]

# Show each sampled article next to its reference highlight, numbered from 1
for idx in range(3):
    sample_report = f"{idx+1}:\n    article: {three_samples['article'][idx]}\n    highlight: {three_samples['highlights'][idx]}"
    print(sample_report)

## Summarize

In [None]:
# Summarization helpers from the project's `simplertimes` package
from simplertimes import summarize

# Create a BART model
bart_summarizer = summarize.create_summarizer(summarize.BART_MODEL_ID)
# Show the GPU status right after creating the summarizer
# (presumably to see the memory the model occupies).
!nvidia-smi

In [None]:
# Print the details reported by the BART summarizer
bart_summarizer.print_details()

In [None]:
# Perform inference using bart
bart_summaries = bart_summarizer.generate_summary(three_samples["article"])

# Remove BART model from memory
del bart_summarizer
import torch
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
# Create a PEGASUS model
peg_summarizer = summarize.create_summarizer(summarize.PEGASUS_MODEL_ID)
# Show the GPU status right after creating the summarizer
# (presumably to see the memory the model occupies).
!nvidia-smi

In [None]:
# Print the details reported by the PEGASUS summarizer
peg_summarizer.print_details()

In [None]:
# Perform inference using PEGASUS
peg_summaries = peg_summarizer.generate_summary(three_samples["article"])

# Remove PEGASUS model from memory
del peg_summarizer
import torch
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
# Compare generated summaries: for each of the three samples, print the
# article, its reference highlight (GT) and both model outputs.
# Every field is indented by four spaces so the report lines up
# (the PEGASUS line previously used three).
for idx in range(3):
    print(
        f"{idx+1}:\n"
        f"    Article: {three_samples['article'][idx]}\n"
        f"    GT Summary: {three_samples['highlights'][idx]}\n"
        f"    BART: {bart_summaries[idx]['summary_text']}\n"
        f"    PEGASUS: {peg_summaries[idx]['summary_text']}"
    )