# Hybrid RAG model

In [1]:
pip install gradio huggingface-hub accelerate tqdm langchain-pinecone xformers accelerate nomic 

Collecting urllib3~=2.0
  Using cached urllib3-2.2.1-py3-none-any.whl (121 kB)
Installing collected packages: urllib3
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.26.6
    Uninstalling urllib3-1.26.6:
      Successfully uninstalled urllib3-1.26.6
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gradient 2.0.5 requires attrs<=19, but you have attrs 23.2.0 which is incompatible.
gradient 2.0.5 requires marshmallow<3.0, but you have marshmallow 3.21.1 which is incompatible.
fastai 2.7.9 requires torch<1.14,>=1.7, but you have torch 2.2.2 which is incompatible.
botocore 1.27.27 requires urllib3<1.27,>=1.25.4, but you have urllib3 2.2.1 which is incompatible.[0m[31m
[0mSuccessfully installed urllib3-2.2.1
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
pip install huggingface_hub

[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
pip install stack

[31mERROR: Could not find a version that satisfies the requirement stack (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for stack[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [4]:
pip install transformers --upgrade

[0mNote: you may need to restart the kernel to use updated packages.


In [5]:
pip install -i https://pypi.org/simple/ bitsandbytes

Looking in indexes: https://pypi.org/simple/
[0mNote: you may need to restart the kernel to use updated packages.


In [6]:
pip install urllib3==1.26.6

Collecting urllib3==1.26.6
  Using cached urllib3-1.26.6-py2.py3-none-any.whl (138 kB)
Installing collected packages: urllib3
  Attempting uninstall: urllib3
    Found existing installation: urllib3 2.2.1
    Uninstalling urllib3-2.2.1:
      Successfully uninstalled urllib3-2.2.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gradio 4.28.3 requires urllib3~=2.0, but you have urllib3 1.26.6 which is incompatible.
gradient 2.0.5 requires attrs<=19, but you have attrs 23.2.0 which is incompatible.
gradient 2.0.5 requires marshmallow<3.0, but you have marshmallow 3.21.1 which is incompatible.
fastai 2.7.9 requires torch<1.14,>=1.7, but you have torch 2.2.2 which is incompatible.[0m[31m
[0mSuccessfully installed urllib3-1.26.6
[0mNote: you may need to restart the kernel to use updated packages.


# Setup

In [7]:
# Python version
import sys 
print(sys.version)

3.9.13 (main, May 23 2022, 22:01:06) 
[GCC 9.4.0]


In [8]:
# Environment Variables
#from dotenv import load_dotenv
import yaml
import os
import tiktoken

# Load env
#load_dotenv()

In [9]:
# Torch config
from torch import cuda, bfloat16, float16
import torch

# Torch options
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# Parameters

In [10]:
# Load parameters from YAML file
import os

# Change the current working directory to the directory containing the YAML file
os.chdir('/notebooks/TFM/TFM_LAW_LLM')

# Load parameters from YAML file
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

In [11]:
# Use optimum
use_optimum = config["use_optimum"]

# Show
use_optimum

False

# Reference

- https://colab.research.google.com/drive/1rt318Ew-5dDw21YZx2zK2vnxbsuDAchH?usp=sharing#scrollTo=YFw8HWIyTCnJ
- https://www.reddit.com/r/LocalLLaMA/comments/16j624z/some_questions_of_implementing_llm_to_generate_qa
- https://www.anyscale.com/blog/a-comprehensive-guide-for-building-rag-based-llm-applications-part-1
- https://towardsdatascience.com/rag-how-to-talk-to-your-data-eaf5469b83b0
- https://github.com/edumunozsala/question-answering-pinecone-sts
- https://medium.com/@pankaj_pandey/fine-tuning-rag-models-for-custom-content-generation-849d7ffce97f

# Directory

In [12]:
# Set directory to file location
from pathlib import Path
import sys
notebook_location = Path(os.path.abspath(''))
os.chdir(notebook_location)

# Get the current working directory
current_directory = os.getcwd()
current_directory

'/notebooks/TFM/TFM_LAW_LLM'

# Libraries

In [13]:
# Libraries for display and visualization
from IPython.display import Markdown, display
import gradio as gr

# Libraries for managing data and serialization
import pinecone
import yaml
import json
import numpy as np

# General utility libraries
import gc
import os
import time

# Libraries related to HuggingFace
from huggingface_hub import notebook_login

# Libraries related to Transformers
from transformers import BitsAndBytesConfig
from sentence_transformers import CrossEncoder
from typing import List
#import accelerate

# Libraries related to Langchain
from sentence_transformers import SentenceTransformer
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.schema import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
from langchain.chains import (
    SimpleSequentialChain, 
    RetrievalQA, 
    LLMChain,
    RetrievalQAWithSourcesChain,
    ConversationalRetrievalChain
)
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate
)
from langchain import HuggingFacePipeline
from langchain import PromptTemplate

# Libraries related to Pinecone
from langchain_pinecone import PineconeVectorStore  
from pinecone import Pinecone

# Libraries related to optimization
import xformers

# Other miscellaneous libraries
from tqdm.notebook import tqdm
from nomic import atlas
import nomic

# Local custom functions
from functions import *

In [14]:
# Warnings
import warnings
warnings.filterwarnings("ignore")

# Device

In [15]:
# Setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# CUDA information
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_memory = torch.cuda.memory_allocated(0) / (1024**3)  # Convert bytes to GB
    cached_memory = torch.cuda.memory_reserved(0) / (1024**3)  # Convert bytes to GB
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert bytes to GB
    available_memory = total_memory - cached_memory
    print('Allocated:   ', round(allocated_memory, 1), 'GB')
    print('Cached:      ', round(cached_memory, 1), 'GB')
    print('Available:  ', round(available_memory, 1), 'GB')
    print('Total:      ', round(total_memory, 1), 'GB')

Using device: cuda

NVIDIA RTX A4000
Memory Usage:
Allocated:    0.0 GB
Cached:       0.0 GB
Available:   15.7 GB
Total:       15.7 GB


In [16]:
# Clean memory
torch.cuda.empty_cache()
gc.collect()

260

# Pinecone

Let's get Pinecone vector store ready.

In [17]:
# Init pinecone
pinecone = Pinecone(api_key = "03b29f67-c297-4462-825b-13ce23b3d577")

pc = Pinecone(api_key = pinecone)
# Connect
index_name = 'lawllm-unstructured-database'
index = pinecone.Index(index_name)

# Index stats
index.describe_index_stats()

{'dimension': 768,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 1068}},
 'total_vector_count': 1068}

# Embedding model

In [18]:
# Model ID
embed_model_id = config["embedding_model"]

# Show
embed_model_id

'sentence-transformers/multi-qa-mpnet-base-cos-v1'

In [19]:
# Embed model
embed_model = HuggingFaceEmbeddings(
    model_name = embed_model_id,
    model_kwargs = {'device': device},
    encode_kwargs = {'device': device, 'batch_size': 32}
) 

# Show
embed_model

HuggingFaceEmbeddings(client=SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
), model_name='sentence-transformers/multi-qa-mpnet-base-cos-v1', cache_folder=None, model_kwargs={'device': device(type='cuda')}, encode_kwargs={'device': device(type='cuda'), 'batch_size': 32}, multi_process=False, show_progress=False)

In [20]:
# CUDA information
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_memory = torch.cuda.memory_allocated(0) / (1024**3)  # Convert bytes to GB
    cached_memory = torch.cuda.memory_reserved(0) / (1024**3)  # Convert bytes to GB
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert bytes to GB
    available_memory = total_memory - cached_memory
    print('Allocated:   ', round(allocated_memory, 1), 'GB')
    print('Cached:      ', round(cached_memory, 1), 'GB')
    print('Available:  ', round(available_memory, 1), 'GB')
    print('Total:      ', round(total_memory, 1), 'GB')

NVIDIA RTX A4000
Memory Usage:
Allocated:    0.0 GB
Cached:       0.0 GB
Available:   15.7 GB
Total:       15.7 GB


# Load LLM model

In [21]:
# Model ID
use_quantization = config["use_quantization"]

# Show
use_quantization

True

In [22]:
# Select model
model_id = config["model"]

# Show
model_id

'mistralai/Mistral-7B-Instruct-v0.2'

In [23]:
from transformers import AutoTokenizer

# Your Hugging Face API token
api_token = "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"

# Model identifier
model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Load tokenizer with authentication
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=api_token
)

In [24]:
# Tokenizer
import transformers
from transformers import AutoTokenizer

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id
)

In [25]:
# Set BNB configuration if quantization is enabled
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = 'nf4',
    bnb_4bit_use_double_quant = True,
    bnb_4bit_compute_dtype = torch.bfloat16
) if use_quantization else None

In [26]:
# Set model
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=api_token,
    trust_remote_code=True
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [27]:

# CUDA information
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_memory = torch.cuda.memory_allocated(0) / (1024**3)  # Convert bytes to GB
    cached_memory = torch.cuda.memory_reserved(0) / (1024**3)  # Convert bytes to GB
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert bytes to GB
    available_memory = total_memory - cached_memory
    print('Allocated:   ', round(allocated_memory, 1), 'GB')
    print('Cached:      ', round(cached_memory, 1), 'GB')
    print('Available:  ', round(available_memory, 1), 'GB')
    print('Total:      ', round(total_memory, 1), 'GB')

NVIDIA RTX A4000
Memory Usage:
Allocated:    0.0 GB
Cached:       0.0 GB
Available:   15.7 GB
Total:       15.7 GB


# Command

We now import the pre_prompt and the prompt_context from the yaml file.

In [28]:
# Get pre-prompt
pre_prompt = config["pre_prompt"]

# Create prompt context
prompt_context = config["prompt_context"]

In [29]:
# General template
general_template = pre_prompt + prompt_context + "A continuación se proporciona el contexto: {context}" + " " + "pregunta: {query}"

In [30]:
# Mistral template
mistral_template = "<s>[INST]" + pre_prompt + prompt_context +  "A continuación se proporciona el contexto: [/INST] {context}" + "</s>" + "[INST] pregunta: {query} [/INST]"

In [31]:
# Google template
google_template = f"""
<start_of_turn>user
{pre_prompt}. {prompt_context} A continuación se proporciona el contexto: 
Contexto: {{context}} 
Pregunta: {{query}}
<end_of_turn>
<start_of_turn>model
Respuesta: """

In [32]:
# Define the final template selection logic
if "mistral" in model_id:
    template = mistral_template
    selected_template_message = "Mistral template selected."
elif "google" in model_id:
    template = google_template
    selected_template_message = "Google template selected."
else:
    template = general_template
    selected_template_message = "Default template selected."

# Print out the selected template message
print(selected_template_message)

Mistral template selected.


In [33]:
# Prompt Template
prompt = PromptTemplate(
    template = template, 
    input_variables = ["context", "query"]
)

We can now print the prompt.

In [34]:
# Show
prompt

PromptTemplate(input_variables=['context', 'query'], template='<s>[INST]Eres un asistente experto en derecho y leyes españolas y tu objetivo es proporcionar respuestas exhaustivas y precisas a las preguntas planteadas por tus clientes.\nAsegúrate de basar tus respuestas en el contexto proporcionado, utilizando todas las leyes y normativas relevantes para fundamentar tus argumentos.\nEs crucial que todas las respuestas estén redactadas en español y presentadas de forma clara y coherente.\nConsidera ofrecer ejemplos o casos hipotéticos para ilustrar tus puntos de vista.\nA continuación, se presenta la información relevante que debes usar para responder a las consultas de los clientes. \nEn caso de no encontrar la respuesta, debes indicarlo de forma explícita.\nA continuación se proporciona el contexto: [/INST] {context}</s>[INST] pregunta: {query} [/INST]')

# LLM Pipeline

Let's define the LLM Pipeline.

In [35]:
# Define pipeline with parameters from config file
generate_text = transformers.pipeline(
    model = model,
    tokenizer = tokenizer,
    task = 'text-generation',
    return_full_text = config["return_full_text"],
    max_new_tokens = config["max_new_tokens"],
    repetition_penalty = config["repetition_penalty"],
    temperature = config["temperature"],
    pad_token_id = tokenizer.eos_token_id,
    batch_size = 1
)

# HF pipeline
llm = HuggingFacePipeline(pipeline = generate_text)

# Create llm chain 
llm_chain = LLMChain(llm = llm, prompt = prompt)

# Vector store

In [36]:
# Text field
text_field = "text"  

# Vector store
vectorstore = PineconeVectorStore(index, embed_model, text_field)  

# Show
vectorstore

<langchain_pinecone.vectorstores.PineconeVectorStore at 0x7f2862ba2d60>

# Test models

In [37]:
# Simple context
context = "Eres una API con conocimientos legales. Debes responder a preguntas en Español. Si no conoces la respuesta, admítelo."

# Query
query = 'Explícame el Artículo 245 del Código Penal de España referente a ocupaciones ilegales de bienes inmuebles'

# Find closer docs

We can now see the closer docs to the query and it's scores.

In [38]:
# Similarity output
similarity_output = vectorstore.similarity_search_with_score(query, k = config['top_k_docs'])

# Context preprocessed
context_processed = [{"context": doc.page_content, "score": score} for doc, score in similarity_output]

# Show
context_processed[0:3]

[{'context': 'De la acusación y denuncia falsas y de la\nsimulación de delitos\n\nArtículo 456\n1. Los que, con conocimiento de su falsedad o temerario desprecio hacia la verdad, imputaren a alguna persona hechos\nque, de ser ciertos, constituirían infracción penal, si esta\n295',
  'score': 0.664225459},
 {'context': 'Trata de seres humanos, 177 bis.\nTratos degradantes, 173 a 177.\nTregua:\nViolación, 593.\nTribunal Constitucional:\nAtentado contra sus miembros, 550.\nCalumnias, injurias o amenazas contra, 504.\nTribunal de Cuentas:\nObstáculos a la investigación, 502.2.\nTribunal Superior de Justicia:\nCalumnias, injurias o amenazas contra, 504.\nTribunal Supremo:\nCalumnias, injurias o amenazas contra, 504.\nTribunales:\nPerturbación del orden, 558.\nTutela, 120, 192 y 440.\nInhabilitación para el ejercicio de, 46.\n\nU\nUltrajes:\nA España, Comunidades Autónomas, símbolos o emblemas, 543.\nUrbanismo:\nDelitos contra, 319 y 320.\nUsurpación:\nAlteración de términos o lindes, 246.\n

# Re-ranking

In [39]:
# Model ID
reranking_model = config["reranking_model"]

# Show
reranking_model 

'cross-encoder/ms-marco-MiniLM-L-6-v2'

In [40]:
# Model ID
top_reranked_docs = config["top_reranked_docs"]

# Show
top_reranked_docs

15

In [41]:
# Extracting 'title' keys
final_context = [entry['context'] for entry in context_processed]

In [42]:
# Cross encoder
cross_encoder = CrossEncoder(reranking_model)

# Show
cross_encoder

<sentence_transformers.cross_encoder.CrossEncoder.CrossEncoder at 0x7f29bc6215e0>

In [43]:
# CUDA information
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_memory = torch.cuda.memory_allocated(0) / (1024**3)  # Convert bytes to GB
    cached_memory = torch.cuda.memory_reserved(0) / (1024**3)  # Convert bytes to GB
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert bytes to GB
    available_memory = total_memory - cached_memory
    print('Allocated:   ', round(allocated_memory, 1), 'GB')
    print('Cached:      ', round(cached_memory, 1), 'GB')
    print('Available:  ', round(available_memory, 1), 'GB')
    print('Total:      ', round(total_memory, 1), 'GB')

NVIDIA RTX A4000
Memory Usage:
Allocated:    0.4 GB
Cached:       0.5 GB
Available:   15.3 GB
Total:       15.7 GB


In [44]:
def rank_documents(cross_encoder, text_field, query:str, retrieved_documents:List[dict]):
    """
    Ranks retrieved documents based on their relevance to a given query using a cross-encoder model.

    Parameters:
    - cross_encoder (CrossEncoder): A cross-encoder model from the sentence-transformers library.
    - query (str): The query string for which the documents are to be ranked.
    - retrieved_documents (List[dict]): A list of dictionaries representing documents. Each dictionary should have a 'text' field
      containing the document text and any additional fields that you want to retain in the output dictionary.

    Returns:
    - dict: A dictionary where the key is the rank position (starting from 0 for the most relevant document)
      and the value is a dictionary containing the document text and any additional fields. The documents are ranked
      in descending order of relevance to the query.

    Usage:
    ranked_docs = rank_documents(cross_encoder, query, retrieved_documents)

    Note: This function requires the sentence-transformers library and a pretrained cross-encoder model.
    """
    pairs = [[query, doc[text_field]] for doc in retrieved_documents]
    scores = cross_encoder.predict(pairs)
    ranks = np.argsort(scores)[::-1]
    ranked_docs = {rank_num: {text_field: retrieved_documents[rank_num][text_field], **retrieved_documents[rank_num]} for rank_num in ranks}
    return ranked_docs

In [45]:
# Re-ranking
text_field = 'context'
ranked_context = rank_documents(cross_encoder, text_field, query, context_processed)

In [46]:
# Sort
sorted_ranked_context = dict(sorted(ranked_context.items())[:top_reranked_docs])

In [47]:
# Format
sorted_ranked_context = list(sorted_ranked_context.values())

In [48]:
# Show
sorted_ranked_context[0:3]

[{'context': 'De la acusación y denuncia falsas y de la\nsimulación de delitos\n\nArtículo 456\n1. Los que, con conocimiento de su falsedad o temerario desprecio hacia la verdad, imputaren a alguna persona hechos\nque, de ser ciertos, constituirían infracción penal, si esta\n295',
  'score': 0.664225459},
 {'context': 'Trata de seres humanos, 177 bis.\nTratos degradantes, 173 a 177.\nTregua:\nViolación, 593.\nTribunal Constitucional:\nAtentado contra sus miembros, 550.\nCalumnias, injurias o amenazas contra, 504.\nTribunal de Cuentas:\nObstáculos a la investigación, 502.2.\nTribunal Superior de Justicia:\nCalumnias, injurias o amenazas contra, 504.\nTribunal Supremo:\nCalumnias, injurias o amenazas contra, 504.\nTribunales:\nPerturbación del orden, 558.\nTutela, 120, 192 y 440.\nInhabilitación para el ejercicio de, 46.\n\nU\nUltrajes:\nA España, Comunidades Autónomas, símbolos o emblemas, 543.\nUrbanismo:\nDelitos contra, 319 y 320.\nUsurpación:\nAlteración de términos o lindes, 246.\n

# Get max docs

In [49]:
# Model ID
max_model_tokens = config["max_model_tokens"]

# Show
max_model_tokens

5120

In [50]:
def count_tokens(string: str, encoding_name: str = "cl100k_base") -> int:
    # Get encoding from tiktoken
    encoding = tiktoken.get_encoding(encoding_name)
    # Encode the string using the specified encoding
    encoded_string = encoding.encode(string)
    # Count the number of tokens
    num_tokens = len(encoded_string)
    return num_tokens

In [51]:
# Initialize cumulative token count
cumulative_tokens = 0

# Filtered list to store dictionaries
filtered_context = []

# Iterate through the list of dictionaries
for item in sorted_ranked_context:
    # Calculate number of tokens for 'context' value
    token_count = count_tokens(item['context'])
    
    # Cumulative sum of token counts
    cumulative_tokens += token_count
    
    # Check if cumulative tokens are still less than max_model_tokens
    if cumulative_tokens < max_model_tokens:
        filtered_context.append(item)
    else:
        break

In [52]:
# Show
filtered_context[0:3]

[{'context': 'De la acusación y denuncia falsas y de la\nsimulación de delitos\n\nArtículo 456\n1. Los que, con conocimiento de su falsedad o temerario desprecio hacia la verdad, imputaren a alguna persona hechos\nque, de ser ciertos, constituirían infracción penal, si esta\n295',
  'score': 0.664225459},
 {'context': 'Trata de seres humanos, 177 bis.\nTratos degradantes, 173 a 177.\nTregua:\nViolación, 593.\nTribunal Constitucional:\nAtentado contra sus miembros, 550.\nCalumnias, injurias o amenazas contra, 504.\nTribunal de Cuentas:\nObstáculos a la investigación, 502.2.\nTribunal Superior de Justicia:\nCalumnias, injurias o amenazas contra, 504.\nTribunal Supremo:\nCalumnias, injurias o amenazas contra, 504.\nTribunales:\nPerturbación del orden, 558.\nTutela, 120, 192 y 440.\nInhabilitación para el ejercicio de, 46.\n\nU\nUltrajes:\nA España, Comunidades Autónomas, símbolos o emblemas, 543.\nUrbanismo:\nDelitos contra, 319 y 320.\nUsurpación:\nAlteración de términos o lindes, 246.\n

In [53]:
# Sum tokens for all contents in filtered_data
total_tokens = sum(count_tokens(item["context"]) for item in filtered_context)

# Print total tokens
print("Total tokens for all contexts in filtered_context:", total_tokens)

Total tokens for all contexts in filtered_context: 3509


In [54]:
len(filtered_context)

15

# Enhanced model

Now, let's use the RAG model.

In [55]:
# Filter contexts, keeping only the context strings
filtered_context_ready = [item["context"] for item in filtered_context]

In [None]:
# Enhanced model
enhanced_model = llm_chain({"context": str(filtered_context_ready), "query": query})

In [None]:
# Output
enhanced_result = enhanced_model['text'].strip()

# Markdown
display(Markdown(f"<b>{query}</b>"))
display(Markdown(f"<p>{enhanced_result}</p>"))

# Clean

In [None]:
# CUDA information
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_memory = torch.cuda.memory_allocated(0) / (1024**3)  # Convert bytes to GB
    cached_memory = torch.cuda.memory_reserved(0) / (1024**3)  # Convert bytes to GB
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert bytes to GB
    available_memory = total_memory - cached_memory
    print('Allocated:   ', round(allocated_memory, 1), 'GB')
    print('Cached:      ', round(cached_memory, 1), 'GB')
    print('Available:  ', round(available_memory, 1), 'GB')
    print('Total:      ', round(total_memory, 1), 'GB')

In [None]:
# Clean memory
torch.cuda.empty_cache()
gc.collect()