<a href="https://colab.research.google.com/github/CyberMaryVer/llm-notebooks/blob/master/llama_fine_tuning_with_peft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

----

Note: **Use the best GPU available** (go to Runtime → Change runtime type).

In [None]:
! nvidia-smi

Tue Aug 15 17:55:01 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   74C    P0    33W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               

<h1><strong> Fine-tuning of LLaMA 2 with PEFT </strong></h1>

▶️ Sources: [Fine-Tuning](https://mlabonne.github.io/blog/posts/Fine_Tune_Your_Own_Llama_2_Model_in_a_Colab_Notebook.html),
[Dataset](https://github.com/mshumer/gpt-llm-trainer)

▶️ Author: [Maria Startseva](https://www.linkedin.com/in/maria-startseva/)

----

The main idea behind prompt tuning, and parameter-efficient fine-tuning (PEFT) methods in general, is to add a small number of new parameters to a pretrained LLM and fine-tune only those newly added parameters, so that the LLM performs better on (a) a target dataset (for example, a domain-specific dataset of medical or legal documents) and (b) a target task (for example, sentiment classification). This notebook uses LoRA (low-rank adaptation), one such method, through the `peft` library.

<img src="https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fd9c855ba-814f-4f95-9b1b-c97a46eb2f42_1646x880.png" width="900">
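To make "a small number of new parameters" concrete for LoRA, here is a minimal plain-Python sketch of the LoRA forward pass, h = W0·x + (alpha/r)·B·(A·x), where the pretrained weight W0 stays frozen and only the low-rank matrices A and B are trained. It illustrates the technique and is not PEFT's implementation. The toy matrices are made up, and the 4096 × 4096 size is an assumed example (Llama-2-7B's hidden size). With r = 64, as in this notebook, such a matrix gets 64 · (4096 + 4096) = 524,288 trainable parameters instead of about 16.8 million.

```python
# Minimal LoRA sketch in plain Python (no torch): h = W0 x + (alpha / r) * B (A x)
# W0 is frozen (d_out x d_in); only A (r x d_in) and B (d_out x r) would be trained.

def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W0, A, B, x, alpha, r):
    base = matvec(W0, x)              # frozen pretrained path
    update = matvec(B, matvec(A, x))  # low-rank trainable path
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

def lora_param_count(d_out, d_in, r):
    """Trainable parameters LoRA adds for one d_out x d_in weight matrix."""
    return r * (d_in + d_out)

# Toy example: 3x3 frozen weight (identity) with a rank-1 adapter
W0 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
A = [[0.5, 0.5, 0.5]]            # r x d_in = 1 x 3
B_init = [[0.0], [0.0], [0.0]]   # d_out x r = 3 x 1, initialized to zero
x = [1.0, 2.0, 3.0]
print(lora_forward(W0, A, B_init, x, alpha=16, r=1))  # [1.0, 2.0, 3.0]

# Savings for one 4096 x 4096 projection with r = 64
full = 4096 * 4096
added = lora_param_count(4096, 4096, 64)
print(full, added, added / full)  # 16777216 524288 0.03125
```

Because B starts at zero (as in the LoRA paper), the adapted model initially behaves exactly like the base model, and training only moves it away from that starting point as needed.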



<h1><strong> Content: </strong></h1>

>[⚗️ Install necessary libraries](#scrollTo=AbrFgrhG_xYi)

>[📅 Data generation](#scrollTo=Way3_PuPpIuE)

>[💫 Fine-Tuning](#scrollTo=moVo0led-6tu)

>[🏃 Run Inference](#scrollTo=F6fux9om_c4-)

>[🫙Merge & store the model](#scrollTo=Ko6UkINu_qSx)

>[⭐ Load a fine-tuned model from Drive and run inference *](#scrollTo=do-dFdE5zWGO)



# ⚗️ **Install necessary libraries**

In [None]:
#@markdown ☑️ Install necessary libraries
from IPython.display import clear_output
import ipywidgets as widgets
import os

def inf(msg, style, wdth):
    """Show a disabled ipywidgets button as a status badge."""
    btn = widgets.Button(description=msg,
                         disabled=True,
                         button_style=style,
                         layout=widgets.Layout(min_width=wdth))
    display(btn)

! pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7
! pip install -q "openai<1.0"  # the code below uses the pre-1.0 API (openai.ChatCompletion, openai.error)
! pip install -q backoff tenacity

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

import backoff
import openai
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
    retry_if_exception_type
)  # for exponential backoff

@retry(
    retry=retry_if_exception_type((openai.error.APIError, openai.error.APIConnectionError, openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout)),
    wait=wait_random_exponential(multiplier=1, max=60),
    stop=stop_after_attempt(10)
)
def completions_with_backoff(**kwargs):
    return openai.ChatCompletion.create(**kwargs)

clear_output()
inf('\u2714 Done','success', '50px')


# 📅 **Data generation**

<h2>🔑 Secret key </h2>

In [None]:
#@markdown ☑️ Enter OpenAI Key
from getpass import getpass
secret = getpass('Enter the secret value: ')
clear_output()
inf('\u2714 Done','success', '50px')


1. Write your prompt here. Make it as descriptive as possible!

2. Choose the temperature (between 0 and 1) to use when generating data. Lower values are great for precise tasks, like writing code, whereas larger values are better for creative tasks, like writing stories.

3. Choose how many examples you want to generate. The more you generate, the longer and more expensive data generation will be, but more examples generally lead to a higher-quality model. 100 is usually the minimum to start. Keep in mind that some generations may fail to parse: in the run below, 44 of 100 examples remained after parsing and deduplication.
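Conceptually, temperature rescales the model's next-token scores (logits) before the softmax: dividing by a small temperature sharpens the distribution toward the most likely token, while a larger one flattens it. The sketch below shows this standard formulation on made-up logits; it is an illustration, not the OpenAI API's internal sampler, and it assumes temperature > 0 (temperature 0 is usually treated as greedy decoding).

```python
import math

def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities; temperature must be > 0."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical next-token scores for three candidate tokens
logits = [2.0, 1.0, 0.5]
for t in (0.2, 0.4, 1.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At low temperature almost all probability mass lands on the first token, which is why lower values suit precise tasks and higher values give more varied, "creative" samples.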

In [None]:
prompt = "A model that takes in a list of words in English and returns only a sentence with all these words included."
example_input = ["left", "rock", "play"]
temperature = .4
number_of_examples = 100

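Define the wrapper prompt for data generation here. Your task description (`prompt`) is appended to the end of it, in backticks, when the examples are generated.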

In [None]:
%%writefile prompt.txt
You are generating data which will be used to train a machine learning model.

You will be given a high-level description of the model we want to train, and from that, you will generate data samples, each with a prompt/response pair.

You will do so in this format:
```
prompt
-----------
$prompt_goes_here
-----------

response
-----------
$response_goes_here
-----------
```

Only one prompt/response pair should be generated per turn.

For each turn, make the example slightly more complex than the last, while ensuring diversity.

Make sure your samples are unique and diverse, yet high-quality and complex enough to train a well-performing model.

Make sure to return only the prompt/response pair, without any additional comments, notes, or text.

Here is the type of model we want to train:


Overwriting prompt.txt


Run this to generate the dataset.

In [None]:
import os
import openai
import random

openai.api_key = secret
DEBUG = False

with open("prompt.txt", "r") as f:
    prompt_string = f.read()

def generate_example(prompt, prev_examples, temperature=.5):
    messages=[
        {
            "role": "system",
            "content": f"{prompt_string}`{prompt}`"
        }
    ]

    if len(prev_examples) > 0:
        if len(prev_examples) > 10:
            prev_examples = random.sample(prev_examples, 10)
        for example in prev_examples:
            messages.append({
                "role": "assistant",
                "content": example
            })

    response = completions_with_backoff(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=temperature,
        max_tokens=1354,
    )

    return response.choices[0].message['content']

# Generate examples
prev_examples = []
for i in range(number_of_examples):
    print(f'Generating example {i}')
    example = generate_example(prompt, prev_examples, temperature)
    prev_examples.append(example)

    ### DEBUG ###
    if i == 10 and DEBUG:
        break

clear_output()
inf('\u2714 Done','success', '50px')
prev_examples


['prompt\n-----------\nGive me a sentence with the word "cat".\n-----------\n\nresponse\n-----------\nI have a pet cat named Whiskers.\n-----------',
 'prompt\n-----------\nPlease provide a sentence that includes the words "cat" and "dog".\n-----------\n\nresponse\n-----------\nThe cat and the dog are playing in the backyard.\n-----------',
 'prompt\n-----------\nCan you create a sentence that contains the words "cat", "dog", and "bird"?\n-----------\n\nresponse\n-----------\nThe cat chased the dog while the bird watched from the tree.\n-----------',
 'prompt\n-----------\nI need a sentence that includes the words "cat", "dog", "bird", and "fish".\n-----------\n\nresponse\n-----------\nThe cat and the dog were chasing each other, while the bird flew overhead and the fish swam in the pond.\n-----------',
 'prompt\n-----------\nCould you give me a sentence that includes the words "cat", "dog", "bird", "fish", and "rabbit"?\n-----------\n\nresponse\n-----------\nThe cat and the dog were p
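For reference, here is an offline sketch of the chat payload that `generate_example` sends: the wrapper prompt plus the task description as the system message, followed by up to 10 randomly sampled previous generations replayed as assistant turns (presumably so the model can see what it has already produced and keep new examples diverse, while keeping the request size bounded). The function name `build_messages` and the seeded `random.Random(0)` generator are hypothetical additions for illustration; no API call is made.

```python
import random

def build_messages(prompt_string, prompt, prev_examples, max_examples=10, rng=random):
    """Hypothetical offline sketch of the messages list built in generate_example."""
    messages = [{"role": "system", "content": f"{prompt_string}`{prompt}`"}]
    # Cap the history so the request does not grow with every generated example
    if len(prev_examples) > max_examples:
        prev_examples = rng.sample(prev_examples, max_examples)
    for example in prev_examples:
        messages.append({"role": "assistant", "content": example})
    return messages

examples = [f"example {i}" for i in range(25)]
msgs = build_messages("WRAPPER ", "TASK", examples, rng=random.Random(0))
print(len(msgs))           # 11: one system message + 10 sampled examples
print(msgs[0]["content"])  # WRAPPER `TASK`
```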

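We also need a system message: the instruction the fine-tuned model receives with every prompt, both during training and at inference. Here GPT-4 writes it from the task description.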

In [None]:
def generate_system_message(prompt):

    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
          {
            "role": "system",
            "content": "You will be given a high-level description of the model we are training, and from that, you will generate a simple system prompt for that model to use. Remember, you are not generating the system message for data generation -- you are generating the system message to use for inference. A good format to follow is `Given $INPUT_DATA, you will $WHAT_THE_MODEL_SHOULD_DO.`.\n\nMake it as concise as possible. Include nothing but the system prompt in your response.\n\nFor example, never write: `\"$SYSTEM_PROMPT_HERE\"`.\n\nIt should be like: `$SYSTEM_PROMPT_HERE`."
          },
          {
              "role": "user",
              "content": prompt.strip(),
          }
        ],
        temperature=temperature,
        max_tokens=500,
    )

    return response.choices[0].message['content']

system_message = generate_system_message(prompt)

print(f'The system message is: `{system_message}`. Feel free to re-run this cell if you want a better result.')

The system message is: `Given a list of English words, generate a sentence incorporating all these words.`. Feel free to re-run this cell if you want a better result.


Now let's put our examples into a dataframe and turn them into a final pair of datasets.

In [None]:
# len(prompts), len(responses)  # check if needed

In [None]:
import pandas as pd

# Initialize lists to store prompts and responses
prompts = []
responses = []

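# Parse out prompts and responses from examples.
# Splitting on the separator gives ['prompt', <prompt>, 'response', <response>, ''],
# so the prompt is at index 1 and the response at index 3 (check these if you change the format).
for example in prev_examples:
    split_example = example.split('-----------')
    if len(split_example) < 4:
        continue  # skip malformed generations so prompts and responses stay aligned
    prompts.append(split_example[1].strip())
    responses.append(split_example[3].strip())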

# Create a DataFrame
df = pd.DataFrame({
    'prompt': prompts,
    'response': responses
})

# Remove duplicates
df = df.drop_duplicates()

print('There are ' + str(len(df)) + ' successfully-generated examples. Here are the first few:')

df.head()

There are 44 successfully-generated examples. Here are the first few:


|   | prompt | response |
|---|--------|----------|
| 0 | Give me a sentence with the word "cat". | I have a pet cat named Whiskers. |
| 1 | Please provide a sentence that includes the wo... | The cat and the dog are playing in the backyard. |
| 2 | Can you create a sentence that contains the wo... | The cat chased the dog while the bird watched ... |
| 3 | I need a sentence that includes the words "cat... | The cat and the dog were chasing each other, w... |
| 4 | Could you give me a sentence that includes the... | The cat and the dog were playing in the garden... |
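The parsing cell keeps index 1 and index 3 of the split text. A slightly stricter variant, sketched below as a hypothetical `parse_example` helper, also checks the `prompt`/`response` labels and returns `None` for anything that does not match, so malformed generations are dropped explicitly rather than by catching exceptions. Being stricter, it may discard a few more examples than the original loop.

```python
SEPARATOR = '-----------'

def parse_example(example):
    """Hypothetical helper: return (prompt, response), or None if the layout doesn't match."""
    parts = [p.strip() for p in example.split(SEPARATOR)]
    # Expected layout: ['prompt', <prompt>, 'response', <response>, '']
    if len(parts) < 4 or parts[0].lower() != 'prompt' or parts[2].lower() != 'response':
        return None
    return parts[1], parts[3]

# One of the generated examples shown earlier
sample = ('prompt\n-----------\nGive me a sentence with the word "cat".\n-----------\n\n'
          'response\n-----------\nI have a pet cat named Whiskers.\n-----------')
print(parse_example(sample))
```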


In [None]:
df.loc[df['response'].apply(lambda x: "note" in str(x).lower())]  # check whether any response contains redundant text

|   | prompt | response |
|---|--------|----------|


Split into train and test sets.

In [None]:
# Split the data into train and test sets, with 90% in the train set
train_df = df.sample(frac=0.9, random_state=42)
test_df = df.drop(train_df.index)

# Save the dataframes to .jsonl files
train_df.to_json('train.jsonl', orient='records', lines=True)
test_df.to_json('test.jsonl', orient='records', lines=True)

# 💫 **Fine-Tuning**

Define Hyperparameters

In [None]:
model_name = "NousResearch/llama-2-7b-chat-hf"  # ungated copy of Llama 2 7B chat; if you have access to the official "meta-llama/Llama-2-7b-chat-hf", you can use it instead (you'll need to pass a Hugging Face access token)
dataset_name = "/content/train.jsonl"
new_model = "llama-2-7b-custom"
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
output_dir = "./results"
num_train_epochs = 2
fp16 = False
bf16 = False
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 5
max_seq_length = None
packing = False
device_map = {"": 0}
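With such a small dataset it helps to check how many optimizer steps these settings actually produce. The sketch below assumes about 40 training rows (roughly 90% of the 44 parsed examples; use `len(train_df)` for your own run) and a single GPU, and mirrors the way the Hugging Face Trainer counts update steps (batches per epoch divided by `gradient_accumulation_steps`). It surfaces two things: with about 20 total steps, the `save_steps = 25` interval is never reached, and with `lr_scheduler_type = "constant"` the `warmup_ratio` has no effect (`"constant_with_warmup"` uses it). `lora_alpha / lora_r` is the factor that scales the LoRA update.

```python
import math

# Assumption: ~40 training rows; replace with len(train_df) in the notebook.
n_train = 40
per_device_train_batch_size = 4
gradient_accumulation_steps = 1
num_train_epochs = 2
warmup_ratio = 0.03
lora_r, lora_alpha = 64, 16
save_steps = 25

effective_batch = per_device_train_batch_size * gradient_accumulation_steps  # single GPU
batches_per_epoch = math.ceil(n_train / per_device_train_batch_size)         # last batch may be partial
steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
total_steps = steps_per_epoch * num_train_epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)  # ignored by the "constant" scheduler
lora_scaling = lora_alpha / lora_r                    # multiplier on the LoRA update B @ A

print(f"effective batch size: {effective_batch}")
print(f"optimizer steps: {steps_per_epoch}/epoch, {total_steps} total")
print(f"warmup steps (only with a warmup scheduler): {warmup_steps}")
print(f"LoRA scaling alpha/r: {lora_scaling}")
print(f"save_steps interval reached: {total_steps >= save_steps}")
```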

Load Datasets and Train

In [None]:
# # Reload after restart if needed
# system_message = "Given a list of English words, generate a sentence incorporating all these words."
# example_input = ["left", "rock", "play"]

In [None]:
# Load datasets
train_dataset = load_dataset('json', data_files='/content/train.jsonl', split="train")
valid_dataset = load_dataset('json', data_files='/content/test.jsonl', split="train")

# Preprocess datasets: wrap each pair in the Llama 2 chat layout
# ([INST] <<SYS>> system <</SYS>> user [/INST] response)
def format_examples(examples):
    return {'text': [
        f'[INST] <<SYS>>\n{system_message.strip()}\n<</SYS>>\n\n{prompt} [/INST] {response}'
        for prompt, response in zip(examples['prompt'], examples['response'])
    ]}

train_dataset_mapped = train_dataset.map(format_examples, batched=True)
valid_dataset_mapped = valid_dataset.map(format_examples, batched=True)

compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)
# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="all",
    per_device_eval_batch_size=per_device_eval_batch_size,
    evaluation_strategy="steps",
    eval_steps=5  # Evaluate every 5 steps
)
# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset_mapped,
    eval_dataset=valid_dataset_mapped,  # Pass validation dataset here
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)
trainer.train()
trainer.model.save_pretrained(new_model)

# Clear output
clear_output()
inf('\u2714 Done','success', '50px')


Test the model

In [None]:
# Test the model
logging.set_verbosity(logging.CRITICAL)
example_input = ["girl", "check", "rock", "politics"]
prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{example_input} [/INST]"
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
result = pipe(prompt)
print(result[0]['generated_text'])



[INST] <<SYS>>
Given a list of English words, generate a sentence incorporating all these words.
<</SYS>>

['girl', 'check', 'rock', 'politics'] [/INST]  The girl checked the rock in her pocket and thought about the latest political news.
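The training text and the inference prompt should share the same Llama 2 chat layout; otherwise the model sees inputs at inference time that differ from what it was fine-tuned on. The hypothetical helper below builds both from one place. It mirrors the strings used in this notebook (with the system message stripped, as in the training cell) and, like the notebook, leaves the `<s>` BOS token to the tokenizer. Note that the training text does not end with the EOS token `</s>`; appending `tokenizer.eos_token` to each training example is a common way to help the model learn where to stop.

```python
def build_llama2_prompt(system_message, user_message, response=None):
    """Hypothetical helper: Llama 2 chat layout as used in this notebook (BOS left to the tokenizer)."""
    text = f"[INST] <<SYS>>\n{system_message.strip()}\n<</SYS>>\n\n{user_message} [/INST]"
    if response is not None:
        text += f" {response}"  # training text: prompt followed by the target response
    return text

system = "Given a list of English words, generate a sentence incorporating all these words."
print(build_llama2_prompt(system, ["girl", "check", "rock", "politics"]))
```

Using one function for both the `map` step and inference makes it harder for the two formats to drift apart.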


# 🏃 **Run Inference**

In [None]:
#@markdown ☑️ Helper functions
from IPython.display import display, HTML
from transformers import pipeline

def display_html(text):
    html_code = f"""
    <div style="border: 4px solid #4CAF50; border-radius: 10px; padding: 10px; font-size: 20px; color: #333;">
        <blockquote style="quotes: '“' '”';">
            <strong>{text}</strong>
        </blockquote>
    </div>
    """
    display(HTML(html_code))

def extract_result(result_object):
    raw = result_object[0]['generated_text']

    # Keep only the text generated after the last [/INST] tag (drops the echoed prompt;
    # does not rely on a global `prompt` variable that may belong to a different call)
    res = raw.split('[/INST]')[-1].strip()

    # Strip redundant preambles such as 'Sure! Here is a sentence: "..."'
    if ":" in res and "\"" in res:
        res = res.split(":")[-1]
        res = res.replace("\"", "").replace("\n", "")

    return res.strip()

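import re

def format_result(result_text, words, color='lightgreen'):
    # Highlight every word in a single regex pass so text inside already-inserted
    # <span> tags is never matched again (longest words first for overlapping matches)
    if not words:
        return result_text
    pattern = re.compile('|'.join(re.escape(w) for w in sorted(words, key=len, reverse=True)))
    return pattern.sub(
        lambda m: f'<span style="font-weight: bold; background-color: {color};">{m.group(0)}</span>',
        result_text,
    )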

def pprint_response(result_object, words):
    res = extract_result(result_object)
    html = format_result(res, words)
    display_html(html)

# # Example usage
# result = [{'generated_text':"This is a sample text with some important words."}]
# words = ["sample", "important"]
# pprint_response(result, words)

def run_inference(example_input, system_message=system_message):
    prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{example_input} [/INST]" # replace the command here with something relevant to your task
    num_new_tokens = 200  # change to the number of new tokens you want to generate
    # Count the number of tokens in the prompt
    num_prompt_tokens = len(tokenizer(prompt)['input_ids'])
    # Calculate the maximum length for the generation
    max_length = num_prompt_tokens + num_new_tokens
    gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_length=max_length)
    result = gen(prompt)
    return result

In [None]:
user_input = ["happy", "sad", "erewwer"]  #@param {type: "raw"}
result = run_inference(user_input)
pprint_response(result, user_input)

In [None]:
user_input = ["sad", "creative", "data-scientist", "Israel", "USA"] #@param {type: "raw"}
result = run_inference(user_input)
pprint_response(result, user_input)

In [None]:
user_input = ["stochastic", "backpropagation", "happy"] #@param {type: "raw"}
result = run_inference(user_input)
pprint_response(result, user_input)

# 🫙**Merge & store the model**

We need to merge the LoRA weights into the base model. As far as I know, with the PEFT version pinned here the adapter cannot be merged into the 4-bit quantized model used for training, so we reload the base model in FP16 precision and use the PEFT library to merge everything. Loading this second copy of the model also causes VRAM problems (even after emptying the cache), so restart the runtime and then execute the cells below.
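Merging folds the adapter into the dense weight, W = W0 + (alpha/r)·B·A, so the merged model gives the same outputs without any adapter at inference time. The plain-Python sketch below checks that equivalence on toy matrices (values chosen to be exact in floating point); it illustrates the arithmetic behind PEFT's `merge_and_unload()` and is not its implementation. It also shows why W0 has to be in a dense floating-point format such as FP16: the update is added element-wise to it.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

def merge_lora(W0, A, B, alpha, r):
    """Fold the adapter into the weight: W = W0 + (alpha / r) * B @ A."""
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W0, delta)]

W0 = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight, d_out x d_in = 2 x 2
A = [[1.0, 0.0]]               # r x d_in = 1 x 2
B = [[0.5], [0.25]]            # d_out x r = 2 x 1
alpha, r = 2, 1
x = [1.0, 1.0]

# Unmerged: base path plus scaled adapter path, computed separately
unmerged_out = [b + (alpha / r) * u for b, u in zip(matvec(W0, x), matvec(B, matvec(A, x)))]
# Merged: a single dense matrix, no adapter needed
merged_out = matvec(merge_lora(W0, A, B, alpha, r), x)
print(unmerged_out, merged_out)  # [4.0, 7.5] [4.0, 7.5]
```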

In [None]:
#@markdown ☑️ Reload necessary libraries
from IPython.display import clear_output
import ipywidgets as widgets
import os

def inf(msg, style, wdth):
    """Show a disabled ipywidgets button as a status badge."""
    btn = widgets.Button(description=msg,
                         disabled=True,
                         button_style=style,
                         layout=widgets.Layout(min_width=wdth))
    display(btn)

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import locale

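# Colab workaround: report UTF-8 as the preferred encoding
# (avoids "A UTF-8 locale is required" errors that can appear after some installs)
def getpreferredencoding(do_setlocale=True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding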

model_name = "NousResearch/llama-2-7b-chat-hf"  # ungated copy of Llama 2 7B chat; if you have access to the official "meta-llama/Llama-2-7b-chat-hf", you can use it instead (you'll need to pass a Hugging Face access token)
dataset_name = "/content/train.jsonl"
new_model = "llama-2-7b-custom"
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
output_dir = "./results"
num_train_epochs = 2
fp16 = False
bf16 = False
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 5
max_seq_length = None
packing = False
device_map = {"": 0}

clear_output()
inf('Done', 'success', '50px')


In [None]:
# Merge and save the fine-tuned model
from google.colab import drive
drive.mount('/content/drive')

model_path = "/content/drive/MyDrive/llama-2-7b-custom"  # change to your preferred path

# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Save the merged model
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).



('/content/drive/MyDrive/llama-2-7b-custom/tokenizer_config.json',
 '/content/drive/MyDrive/llama-2-7b-custom/special_tokens_map.json',
 '/content/drive/MyDrive/llama-2-7b-custom/tokenizer.json')

# ⭐ Load a fine-tuned model from Drive and run inference *

Run this section if you have Colab Pro+ or another runtime with enough memory to hold the full merged 7B model.
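As a rough guide to why this section needs a large runtime, the weights alone take about (number of parameters) × (bytes per parameter). The sketch below uses an approximate count of 6.74 billion parameters for Llama-2-7B (an assumption; `model.num_parameters()` gives the exact figure for a loaded model) and ignores activations, the KV cache, and quantization overhead. For comparison, the T4 in the `nvidia-smi` output above has 15360 MiB (15 GiB). Note that `from_pretrained` loads weights in float32 unless `torch_dtype` is given.

```python
def weight_memory_gib(n_params, bits_per_param):
    """Memory for the weights alone (no activations, KV cache, or quantization constants)."""
    return n_params * bits_per_param / 8 / 2**30

N_PARAMS = 6.74e9  # approximate parameter count of Llama-2-7B (assumption)
for label, bits in [("float32", 32), ("float16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label:>12}: {weight_memory_gib(N_PARAMS, bits):5.1f} GiB")
```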

In [None]:
#@markdown ☑️ Reload necessary libraries
from IPython.display import clear_output
import ipywidgets as widgets
import os

def inf(msg, style, wdth):
    """Show a disabled ipywidgets button as a status badge."""
    btn = widgets.Button(description=msg,
                         disabled=True,
                         button_style=style,
                         layout=widgets.Layout(min_width=wdth))
    display(btn)

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import locale

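# Colab workaround: report UTF-8 as the preferred encoding
# (avoids "A UTF-8 locale is required" errors that can appear after some installs)
def getpreferredencoding(do_setlocale=True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding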

model_name = "NousResearch/llama-2-7b-chat-hf"  # ungated copy of Llama 2 7B chat; if you have access to the official "meta-llama/Llama-2-7b-chat-hf", you can use it instead (you'll need to pass a Hugging Face access token)
dataset_name = "/content/train.jsonl"
new_model = "llama-2-7b-custom"
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False
output_dir = "./results"
num_train_epochs = 2
fp16 = False
bf16 = False
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 5
max_seq_length = None
packing = False
device_map = {"": 0}

clear_output()
inf('Done', 'success', '50px')


In [None]:
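import torch
from google.colab import drive
from transformers import AutoModelForCausalLM, AutoTokenizer

drive.mount('/content/drive')

model_path = "/content/drive/MyDrive/llama-2-7b-custom"  # change to the path where your model is saved

# Load in float16 and let accelerate place the weights
# (without torch_dtype, the weights load as float32 on the CPU)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_path)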

In [None]:
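from transformers import pipeline

# Use the same Llama 2 chat layout the model was fine-tuned on
system_message = "Given a list of English words, generate a sentence incorporating all these words."
example_input = ["girl", "check", "rock", "politics"]  # change to your own list of words
prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{example_input} [/INST]"
gen = pipeline('text-generation', model=model, tokenizer=tokenizer, max_new_tokens=200)
result = gen(prompt)
print(result[0]['generated_text'])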

In [None]:
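# Optional: to reduce CUDA memory fragmentation, PYTORCH_CUDA_ALLOC_CONF must be set before CUDA is
# initialized (e.g. at the top of the notebook). A `! export` runs in a subshell and has no effect here:
# import os; os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

import gc
import torch

gc.collect()              # first free unreferenced Python objects (and the tensors they hold)
torch.cuda.empty_cache()  # then release cached, unused GPU memory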