In [None]:
!pip install -U BitsandBytes
!pip install -U accelerate



In [None]:
import time
from tqdm.notebook import tqdm
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig
import pandas as pd
import os
import bitsandbytes
import accelerate

In [None]:
# Google Drive locations used by this notebook. Every file lives under the same
# Colab folder, so each path is built from that shared root.
_colab_dir = '/content/drive/MyDrive/Colab Notebooks'
data_path = _colab_dir + '/dataset_150k_with_summaries.csv'         # input dataset read by pd.read_csv
output_csv_path = _colab_dir + '/summarized_abstracts.csv'          # dataset plus generated 'summary' column
output_base_path = _colab_dir + '/summarized_abstracts_checkpoint'  # not referenced by the cells shown here
model_save_path = _colab_dir + '/Model/llama-3-8b-bnb-4bit'         # target of model/tokenizer save_pretrained
checkpoint_path = _colab_dir + '/checkpoint.txt'                    # index of the next row to summarize


In [None]:
# Loading the dataset.
# When a checkpoint exists, resume from the previously written output CSV
# instead of the raw dataset: the summarization loop rewrites output_csv_path
# from this DataFrame, so starting from the raw file again would discard the
# summaries generated in earlier sessions.
MAX_ROWS = 10000  # only the first MAX_ROWS rows are summarized
resuming = os.path.exists(checkpoint_path) and os.path.exists(output_csv_path)
source_path = output_csv_path if resuming else data_path
df = pd.read_csv(source_path)
df = df.iloc[:MAX_ROWS]
print(f"Loaded {len(df)} rows from {source_path}")

In [None]:
# Resume support: checkpoint_path holds the index of the next row to summarize.
# With no checkpoint file, processing begins at the first row.
start_index = 0
if os.path.exists(checkpoint_path):
    with open(checkpoint_path, 'r') as checkpoint_file:
        saved_index = checkpoint_file.read().strip()
    start_index = int(saved_index)
    print(f"Resuming from checkpoint at index {start_index}")

Resuming from checkpoint at index 4020


In [None]:
# Loading the tokenizer and the 4-bit quantized Llama-3 8B model.
# NOTE(review): this checkpoint appears to be stored pre-quantized (its name and
# the Linear4bit layers in the printed model suggest so), so the config passed
# here may be redundant -- confirm.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_name = "unsloth/llama-3-8b-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
# Inference only: switch to evaluation mode. The trailing ';' stops Jupyter from
# printing the very long module repr as this cell's output.
model.eval();

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.25k [00:00<?, ?B/s]

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096, padding_idx=128255)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): Llama

In [None]:
# Function to generate a summary
def generate_summary_llama(
    abstract,
    max_new_tokens=150,
    min_new_tokens=40,
    prompt_template="Summarize the following abstract.\n\nAbstract: {abstract}\n\nSummary:",
):
    """Generate a summary of one abstract using the global `model` and `tokenizer`.

    Args:
        abstract: Abstract text. Missing values (non-strings such as pandas NaN)
            and blank strings yield an empty summary instead of making the
            tokenizer raise.
        max_new_tokens: Upper bound on the number of generated tokens.
        min_new_tokens: Lower bound on the number of generated tokens. Unlike
            ``min_length``, this does not count the prompt tokens.
        prompt_template: Format string with an ``{abstract}`` placeholder. The
            loaded checkpoint appears to be a base (non-instruct) model, so an
            explicit cue such as "Summary:" is needed for it to summarize
            rather than simply continue the abstract.

    Returns:
        Only the generated continuation (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    if not isinstance(abstract, str) or not abstract.strip():
        return ''

    prompt = prompt_template.format(abstract=abstract)
    # Truncation applies to the whole prompt; an extremely long abstract would
    # lose the trailing "Summary:" cue (presumably rare for abstracts).
    encoded = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=2048)
    # Put every tensor (input_ids and attention_mask) on the model's device.
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    # Passing a pad id explicitly avoids generate() warning and guessing one.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            do_sample=False,
            pad_token_id=pad_id,
        )

    # Decoder-only generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = inputs['input_ids'].shape[-1]
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True).strip()


# Summarize every abstract from `start_index` onward. Progress (the DataFrame and
# the index of the next row) is saved every SAVE_EVERY rows, and once more when
# the loop ends -- including when it is interrupted -- so a later session can resume.
SAVE_EVERY = 20  # rows between checkpoints

# The checkpoint stores a row label (index + 1) but is used positionally with
# df.iloc, so labels must equal positions.
assert df.index.equals(pd.RangeIndex(len(df))), (
    "expected a default 0..n-1 index; call df.reset_index(drop=True) first"
)


def _save_progress(next_index):
    """Persist `df` to output_csv_path, then record `next_index` in checkpoint_path.

    Each file is written to a '.tmp' sibling and renamed into place, so an
    interruption mid-write leaves the previous complete file instead of a
    truncated one. The CSV is written first so the checkpoint never points past
    data that was actually saved.
    """
    tmp_csv = output_csv_path + '.tmp'
    df.to_csv(tmp_csv, index=False)
    os.replace(tmp_csv, output_csv_path)

    tmp_checkpoint = checkpoint_path + '.tmp'
    with open(tmp_checkpoint, 'w') as f:
        f.write(str(next_index))
    os.replace(tmp_checkpoint, checkpoint_path)


start_time = time.time()

# The model is only used for inference and never changes, so there is no need to
# re-write several GB of weights to Drive at every checkpoint; save the copy once.
# NOTE(review): a partially written directory left by an earlier interrupted
# save would also be skipped here -- delete it if in doubt.
if not os.path.isdir(model_save_path):
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)

summaries = []                   # summaries generated since the last save
batch_start_index = start_index  # row label that summaries[0] belongs to

try:
    with tqdm(total=len(df), initial=start_index, desc="Processing Abstracts") as pbar:
        for index, row in df.iloc[start_index:].iterrows():
            summaries.append(generate_summary_llama(row['abstract']))

            if (index + 1) % SAVE_EVERY == 0:
                # NOTE(review): if the input CSV already has a 'summary' column
                # (its file name suggests it may), these writes overwrite it -- confirm.
                df.loc[batch_start_index:index, 'summary'] = summaries
                _save_progress(index + 1)

                # Reset the buffer before moving the batch start, so an
                # interruption between these lines can never mis-align rows.
                summaries = []
                batch_start_index = index + 1
                print(f"Checkpoint saved at index {index + 1}")

            pbar.update(1)
finally:
    # Runs on normal completion and on interruption (e.g. KeyboardInterrupt), so
    # only the row in flight is lost rather than up to a whole batch.
    if summaries:
        batch_end = batch_start_index + len(summaries)
        df.loc[batch_start_index:batch_end - 1, 'summary'] = summaries
        _save_progress(batch_end)
        print(f"Final checkpoint saved at index {batch_end}")

total_time = time.time() - start_time
print(f"Total time taken: {total_time} seconds")


Processing Abstracts:   0%|          | 0/10000 [00:00<?, ?it/s]

Checkpoint saved at index 4040




Checkpoint saved at index 4060




Checkpoint saved at index 4080




Checkpoint saved at index 4100




Checkpoint saved at index 4120




Checkpoint saved at index 4140




Checkpoint saved at index 4160




Checkpoint saved at index 4180




Checkpoint saved at index 4200




Checkpoint saved at index 4220




Checkpoint saved at index 4240




Checkpoint saved at index 4260




Checkpoint saved at index 4280




Checkpoint saved at index 4300




Checkpoint saved at index 4320




Checkpoint saved at index 4340




Checkpoint saved at index 4360




Checkpoint saved at index 4380




Checkpoint saved at index 4400




Checkpoint saved at index 4420




Checkpoint saved at index 4440




Checkpoint saved at index 4460




Checkpoint saved at index 4480




Checkpoint saved at index 4500




Checkpoint saved at index 4520




Checkpoint saved at index 4540




Checkpoint saved at index 4560




Checkpoint saved at index 4580




Checkpoint saved at index 4600




Checkpoint saved at index 4620




Checkpoint saved at index 4640




Checkpoint saved at index 4660




Checkpoint saved at index 4680




Checkpoint saved at index 4700




Checkpoint saved at index 4720




Checkpoint saved at index 4740




Checkpoint saved at index 4760




Checkpoint saved at index 4780




Checkpoint saved at index 4800




Checkpoint saved at index 4820




Checkpoint saved at index 4840




Checkpoint saved at index 4860




Checkpoint saved at index 4880




Checkpoint saved at index 4900




Checkpoint saved at index 4920




Checkpoint saved at index 4940




Checkpoint saved at index 4960




Checkpoint saved at index 4980




Checkpoint saved at index 5000




Checkpoint saved at index 5020




Checkpoint saved at index 5040




Checkpoint saved at index 5060




Checkpoint saved at index 5080




Checkpoint saved at index 5100




Checkpoint saved at index 5120




Checkpoint saved at index 5140




Checkpoint saved at index 5160




Checkpoint saved at index 5180




Checkpoint saved at index 5200




Checkpoint saved at index 5220




Checkpoint saved at index 5240




Checkpoint saved at index 5260




Checkpoint saved at index 5280




Checkpoint saved at index 5300




Checkpoint saved at index 5320




Checkpoint saved at index 5340




Checkpoint saved at index 5360




Checkpoint saved at index 5380




Checkpoint saved at index 5400




Checkpoint saved at index 5420




Checkpoint saved at index 5440




Checkpoint saved at index 5460




Checkpoint saved at index 5480




Checkpoint saved at index 5500




Checkpoint saved at index 5520




Checkpoint saved at index 5540




Checkpoint saved at index 5560




Checkpoint saved at index 5580




Checkpoint saved at index 5600




Checkpoint saved at index 5620




Checkpoint saved at index 5640




Checkpoint saved at index 5660




Checkpoint saved at index 5680




Checkpoint saved at index 5700




Checkpoint saved at index 5720




Checkpoint saved at index 5740




Checkpoint saved at index 5760




Checkpoint saved at index 5780




Checkpoint saved at index 5800




Checkpoint saved at index 5820




Checkpoint saved at index 5840




Checkpoint saved at index 5860




Checkpoint saved at index 5880




Checkpoint saved at index 5900




Checkpoint saved at index 5920




Checkpoint saved at index 5940




Checkpoint saved at index 5960




Checkpoint saved at index 5980




Checkpoint saved at index 6000




Checkpoint saved at index 6020




Checkpoint saved at index 6040




Checkpoint saved at index 6060




Checkpoint saved at index 6080




Checkpoint saved at index 6100




Checkpoint saved at index 6120




Checkpoint saved at index 6140




Checkpoint saved at index 6160




Checkpoint saved at index 6180




Checkpoint saved at index 6200




Checkpoint saved at index 6220




Checkpoint saved at index 6240




Checkpoint saved at index 6260




Checkpoint saved at index 6280




Checkpoint saved at index 6300




Checkpoint saved at index 6320




Checkpoint saved at index 6340




Checkpoint saved at index 6360




Checkpoint saved at index 6380




Checkpoint saved at index 6400




Checkpoint saved at index 6420




Checkpoint saved at index 6440




Checkpoint saved at index 6460




Checkpoint saved at index 6480




Checkpoint saved at index 6500




Checkpoint saved at index 6520




Checkpoint saved at index 6540




Checkpoint saved at index 6560




Checkpoint saved at index 6580




Checkpoint saved at index 6600




Checkpoint saved at index 6620




Checkpoint saved at index 6640




Checkpoint saved at index 6660




Checkpoint saved at index 6680




Checkpoint saved at index 6700




Checkpoint saved at index 6720




Checkpoint saved at index 6740




Checkpoint saved at index 6760




Checkpoint saved at index 6780




Checkpoint saved at index 6800




Checkpoint saved at index 6820




Checkpoint saved at index 6840




Checkpoint saved at index 6860




Checkpoint saved at index 6880




Checkpoint saved at index 6900




Checkpoint saved at index 6920




Checkpoint saved at index 6940




Checkpoint saved at index 6960




Checkpoint saved at index 6980




Checkpoint saved at index 7000




Checkpoint saved at index 7020




Checkpoint saved at index 7040




Checkpoint saved at index 7060




Checkpoint saved at index 7080




Checkpoint saved at index 7100




Checkpoint saved at index 7120




Checkpoint saved at index 7140




Checkpoint saved at index 7160




Checkpoint saved at index 7180




Checkpoint saved at index 7200




Checkpoint saved at index 7220




Checkpoint saved at index 7240




Checkpoint saved at index 7260




Checkpoint saved at index 7280




Checkpoint saved at index 7300




Checkpoint saved at index 7320




Checkpoint saved at index 7340




Checkpoint saved at index 7360




Checkpoint saved at index 7380




Checkpoint saved at index 7400




Checkpoint saved at index 7420




Checkpoint saved at index 7440




Checkpoint saved at index 7460




Checkpoint saved at index 7480




Checkpoint saved at index 7500




Checkpoint saved at index 7520




Checkpoint saved at index 7540




Checkpoint saved at index 7560




Checkpoint saved at index 7580




Checkpoint saved at index 7600




Checkpoint saved at index 7620




Checkpoint saved at index 7640




Checkpoint saved at index 7660




Checkpoint saved at index 7680




Checkpoint saved at index 7700




Checkpoint saved at index 7720




Checkpoint saved at index 7740




Checkpoint saved at index 7760




Checkpoint saved at index 7780




Checkpoint saved at index 7800




Checkpoint saved at index 7820




Checkpoint saved at index 7840




Checkpoint saved at index 7860




Checkpoint saved at index 7880




Checkpoint saved at index 7900




Checkpoint saved at index 7920




Checkpoint saved at index 7940




Checkpoint saved at index 7960




Checkpoint saved at index 7980




Checkpoint saved at index 8000




Checkpoint saved at index 8020




Checkpoint saved at index 8040




KeyboardInterrupt: 