# Memory-Optimized Chat Memory Model Training

This notebook fine-tunes a Llama-based model (`Meta-Llama-3.1-8B-Instruct`) for chat summarization, using memory optimizations suited to Google Colab.

## Setup Steps:
1. Mount Google Drive
2. Install dependencies
3. Configure memory settings
4. Train model with optimizations
5. Save results

In [None]:
# Install required packages
%%capture
!pip install torch transformers datasets accelerate bitsandbytes trl peft

In [None]:
# Mount Google Drive and setup directories
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
!mkdir -p "/content/drive/MyDrive/chat_memory/models"
!mkdir -p "/content/drive/MyDrive/chat_memory/data"
!mkdir -p "/content/drive/MyDrive/chat_memory/notebooks"

In [None]:
# Configure memory management.
# The CUDA caching-allocator option is exported here, before `torch` is
# imported below, so it is already in the environment when PyTorch starts
# using the GPU. `max_split_size_mb:128` stops the allocator from splitting
# cached blocks larger than 128 MB, which limits fragmentation.
import os
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128')

# Import required libraries
import json
from datasets import Dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer
from peft import LoraConfig, get_peft_model
import gc

In [None]:
# Verify GPU Setup and Memory
def print_gpu_memory():
    """Print GPU memory usage"""
    torch.cuda.empty_cache()
    gc.collect()
    
    print("\nGPU Memory Summary:")
    print(f"Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved()/1e9:.2f} GB")
    
    !nvidia-smi

print("CUDA Available:", torch.cuda.is_available())
print("GPU Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
print_gpu_memory()

In [None]:
# Configuration
# Base checkpoint to fine-tune (Hugging Face Hub repository id).
model_name = "unsloth/Meta-Llama-3.1-8B-Instruct"
# Upper bound on tokens per training example. It is kept modest to limit
# memory use; presumably it is passed to the tokenizer/trainer later in the
# notebook (not visible here).
max_seq_length = 2048

def formatting_prompts_func(examples):
    """Render a batch of chat records into prompt + target training strings.

    Args:
        examples: Column-oriented batch with parallel lists under
            ``'messages'``, ``'start_date'``, ``'end_date'`` and ``'summary'``
            (presumably supplied by ``Dataset.map(batched=True)`` or a trainer
            formatting hook -- confirm against the caller). Each ``messages``
            entry is a list of dicts read via ``'timestamp'`` and ``'content'``.

    Returns:
        ``{"text": texts}`` with one formatted string per record, in input order.

    Notes:
        Reads ``tokenizer.eos_token`` from a module-level ``tokenizer`` that
        must be defined elsewhere in the notebook before this is called.
        This function only builds Python strings, so it performs no GC or
        CUDA cache management.
    """
    # Identical for every example, so build it once per call.
    instruction = """You are a chat summarization assistant. Given a conversation and its date range, create a concise yet comprehensive summary that captures the key points, emotional undertones, and progression of the relationship between participants."""

    texts = []
    for messages, start_date, end_date, summary in zip(
        examples['messages'],
        examples['start_date'],
        examples['end_date'],
        examples['summary'],
    ):
        # One "timestamp | content" line per chat message.
        messages_text = "\n".join(
            f"{msg['timestamp']} | {msg['content']}" for msg in messages
        )

        # User-side input: the date range (stated twice, in prose and as
        # tagged fields) followed by the transcript.
        input_text = f"""Please summarize the following chat conversation that occurred between {start_date} and {end_date}.

[START DATE]
{start_date}
[END DATE]
{end_date}
[CHAT MESSAGES]
{messages_text}"""

        # NOTE(review): `<s>`, `[INST]`, `[/INST]` and `</s>` are Llama-2 chat
        # markers. The Llama 3.x tokenizer does not treat them as special
        # tokens, and the literal `</s>` is followed by the tokenizer's real
        # EOS, so the model learns to emit both. Consider
        # `tokenizer.apply_chat_template`, but any change here must also be
        # made to the inference-time prompt.
        prompt = f"""<s>[INST] {instruction}

{input_text} [/INST]
[SUMMARY]
{summary}</s>{tokenizer.eos_token}"""

        texts.append(prompt)

    return {"text": texts}

def load_and_prepare_data(file_path, chunk_size=50):
    """Load a JSON Lines file into a ``Dataset`` and print the first record.

    Each non-blank line is parsed as one JSON record. Blank or whitespace-only
    lines (e.g. a trailing empty line) are skipped instead of raising. All
    records are held in a list before the Dataset is built.

    Args:
        file_path: Path to a UTF-8 JSONL file. A leading byte-order mark is
            tolerated.
        chunk_size: Ignored; retained for backward compatibility with
            existing callers.

    Returns:
        A ``Dataset`` built with ``Dataset.from_list`` from the parsed records.
        It is empty if the file contains no records.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically and also strips a leading
    # BOM, which json.loads would otherwise reject on the first line.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        data = [json.loads(line) for line in f if line.strip()]

    dataset = Dataset.from_list(data)

    if data:
        print("\nSample data point before formatting:")
        # ensure_ascii=False keeps non-ASCII chat text (accents, emoji) readable.
        print(json.dumps(data[0], indent=2, ensure_ascii=False))
    else:
        print(f"\nNo records found in {file_path}")

    return dataset