## Setup

Add any imports needed here

Mount Google Drive

In [14]:
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding

!pip -q install datasets
!pip -q install transformers
!pip -q install -U peft
!pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
!pip -q install rouge

Looking in indexes: https://pypi.org/simple, https://jllllll.github.io/bitsandbytes-windows-webui


In [2]:
import numpy as np
from datasets import load_dataset
import torch
from torch.utils.data import Dataset

from transformers import AutoTokenizer, AutoModelForCausalLM
from google.colab import drive
from bs4 import BeautifulSoup

drive.mount('/content/drive')

#%cd /content/drive/MyDrive/CS6220 Folder #add direct access to the folder if you get an error in this cell
import os
os.chdir("/content/drive/MyDrive/CS6220 Folder")

import helpers
import importlib

Mounted at /content/drive


## **IBM Python Dataset**

## Dataset

We use a different code dataset which we then partition into artificial "repositories" using PyTorch dataloader

In [3]:
## TODO: Diana
## Download Dataset
%cd /content/drive/.shortcut-targets-by-id/1ZKIbzAJY4RXvlu64yVQt4OckCxyNqRJ6/CS6220 Folder

/content


The next approach uses the Scripts_Project_CodeNet_Python800 folder. It contains .py scripts. Each script has multiple chunks of code (combined from the previous structure, where in each folder we had multiple scripts). The delimiter between chunk and chunk is a string of hashtags. (#####################################################)

In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import random


class CustomDataset(Dataset):

    def __init__(self, folder_path, partition, delimiter="#", chunk_size=100): #initializes the dataset with the folder path, delimiter, and chunk size
        self.folder_path = folder_path
        self.delimiter = delimiter
        self.chunk_size = chunk_size

        self.script_files = [f for f in os.listdir(folder_path) if f.endswith(".py")] # Get a list of all Python scripts in the specified folder with a ".py" extension
        if partition == 1:
          self.script_files = self.script_files[:10]
        else:
          self.script_files = self.script_files[10:20]

        self.data = []

        for idx in range(len(self.script_files)):
          script_file_path = os.path.join(self.folder_path, self.script_files[idx])
          with open(script_file_path, 'r', encoding='utf-8') as file:
              code_samples = file.read().split(self.delimiter)
              code_samples = [i for i in code_samples if len(i) > 5]
              self.data.extend(code_samples)
        print(f'Code samples in partition {partition}: {len(self.data)}')

    def __len__(self):
        return len(self.data)


    def __getitem__(self, idx): # method to get a specific item from the dataset (returns a chunk of code)
        return self.data[idx]


In [5]:
folder_path = "/content/drive/.shortcut-targets-by-id/1ZKIbzAJY4RXvlu64yVQt4OckCxyNqRJ6/CS6220 Folder/Scripts_Project_CodeNet_Python800"
delimiter = "#####################################################"
chunk_size = 1
batch_size = 16

# Create instances of the dataset with specified file indices
dataset_1 = CustomDataset(folder_path, 1, delimiter, chunk_size)
dataset_2 = CustomDataset(folder_path, 2, delimiter, chunk_size)

# Create DataLoader instances for each dataset
dataloader_1 = DataLoader(dataset_1, batch_size=batch_size, shuffle=True)
dataloader_2 = DataLoader(dataset_2, batch_size=batch_size, shuffle=True)

client_dataloaders = [dataloader_1, dataloader_2]

total = 0
for batch in dataloader_1:
    total += 1
print(total)

#for batch in data_loader:
#    print(batch)

Code samples in partition 1: 3002
Code samples in partition 2: 3004
188


## Pretrained model

This model [name] from Hugging Face is pretrained on code.

In [6]:
## TODO: Jun
## Download model
# Load the CodeParrot small model and tokenizer

from helpers.train import generic_training_runner
from transformers import BitsAndBytesConfig
from peft import LoraConfig

model_name = "codeparrot/codeparrot-small"
run_title = "codeparrotsm"


SAVE_PATH = f"./ronit_ibm_outputs/checkpoints/{run_title}"
LOSSES_PATH = f"./ronit_ibm_outputs/logs/losses/{run_title}"
TIMES_PATH = f"./ronit_ibm_outputs/logs/times/{run_title}_elapsed_time"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

model_info = {
    'model_name': model_name,
    'tokenizer': tokenizer,
    'client_dataloaders': client_dataloaders,
    'quant_config':
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    'lora_config':
        LoraConfig(
            r=8,
            target_modules=["c_attn", "c_proj", "c_fc"],
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        ),
}

training_args = {
    'clients': 2,
    'data_per_client': 1e3,
    'MAX_LENGTH': 1024,
    'conduct_logging': True,
    'EPOCHS': 1,
    'lr': 4e-4
}


tokenizer_config.json:   0%|          | 0.00/259 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/497k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/277k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/840k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

## Partition Code [outdated]

In [None]:
# ## Partition Dataset

# import random

# # Assuming you have 'n_partitions' as the number of partitions you want
# n_partitions = 5  # for example, change this based on your requirement

# # Shuffle the dataset
# random.shuffle(custom_dataset.script_files)

# # Calculate the size of each partition
# partition_size = len(custom_dataset) // n_partitions

# # Create partitions
# partitions = [custom_dataset.script_files[i * partition_size:(i + 1) * partition_size] for i in range(n_partitions)]

# # Now create a DataLoader for each partition
# dataloaders = []
# for partition in partitions:
#     # Creating a new CustomDataset instance for each partition
#     partition_dataset = CustomDataset(folder_path, delimiter, chunk_size)
#     partition_dataset.script_files = partition  # replace the script_files with the partition

#     # Create a DataLoader for the partitioned dataset
#     partition_dataloader = DataLoader(partition_dataset, batch_size=1, shuffle=True)
#     dataloaders.append(partition_dataloader)

# # Now 'dataloaders' contains separate DataLoader for each partition

# import hashlib
# from itertools import islice

# def hash_code(data):
#     # Create a hash for the given data
#     return hashlib.sha256(str(data).encode()).hexdigest()

# # Iterate through each DataLoader
# for i, each_data_loader in enumerate(dataloaders):
#     print(f"\nDataLoader {i + 1}:")

#     # Iterate through the first 5 batches of each DataLoader
#     for j, batch in enumerate(islice(each_data_loader, 5), 1):
#         # Compute the hash of the batch
#         batch_hash = hash_code(batch)

#         print(f"  Batch {j} Hash: {batch_hash}")





DataLoader 1:


TypeError: ignored

In [None]:
# # Validation Script

# # Check for equal size
# partition_sizes = [len(partition) for partition in partitions]
# print("Sizes of each partition:", partition_sizes)

# # Check if sizes are approximately equal
# # Note: They might not be exactly equal if the total size is not perfectly divisible by n_partitions
# if all(abs(size - partition_size) <= 1 for size in partition_sizes):
#     print("All partitions are of approximately equal size.")
# else:
#     print("Warning: Partitions sizes vary more than expected.")

# # Consistency Check
# # Example: Checking if each partition still contains script files and not something else
# for i, partition_dataset in enumerate(dataloaders):
#     is_consistent = all(isinstance(script, str) and script.endswith('.py') for script in partition_dataset.dataset.script_files)
#     print(f"Partition {i} consistency: {'Consistent' if is_consistent else 'Inconsistent'}")

# # Note: Replace the consistency check as per your criteria


# def hash_code(code_snippet):
#     # Using SHA-256 for hashing
#     return hashlib.sha256(code_snippet.encode()).hexdigest()

# # Check for duplicates within each partition
# for i, partition in enumerate(partitions):
#     hash_set = set()
#     duplicates = set()
#     for snippet in partition:
#         snippet_hash = hash_code(snippet)
#         if snippet_hash in hash_set:
#             duplicates.add(snippet_hash)
#         else:
#             hash_set.add(snippet_hash)

#     if duplicates:
#         print(f"Partition {i} has duplicates: {duplicates}")
#     else:
#         print(f"Partition {i} has no duplicates.")

# # Check for duplicates across partitions
# global_hash_set = set()
# cross_partition_duplicates = set()
# for partition in partitions:
#     for snippet in partition:
#         snippet_hash = hash_code(snippet)
#         if snippet_hash in global_hash_set:
#             cross_partition_duplicates.add(snippet_hash)
#         else:
#             global_hash_set.add(snippet_hash)

# if cross_partition_duplicates:
#     print(f"Cross-partition duplicates found: {cross_partition_duplicates}")
# else:
#     print("No cross-partition duplicates found.")




Sizes of each partition: [160, 160, 160, 160, 160]
All partitions are of approximately equal size.
Partition 0 consistency: Consistent
Partition 1 consistency: Consistent
Partition 2 consistency: Consistent
Partition 3 consistency: Consistent
Partition 4 consistency: Consistent
Partition 0 has no duplicates.
Partition 1 has no duplicates.
Partition 2 has no duplicates.
Partition 3 has no duplicates.
Partition 4 has no duplicates.
No cross-partition duplicates found.


## **Deduplication** [outdated]

In [None]:
#dataset = load_dataset("codeparrot/codeparrot-clean")

Downloading readme:   0%|          | 0.00/1.30k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/54 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/247M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/247M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/247M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/245M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/245M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/245M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/240M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/236M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/240M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/239M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/239M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/239M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/236M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/235M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/236M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/235M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/235M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/236M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/236M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/232M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/232M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/233M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/233M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/230M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
#########not enough RAM -- try process in chunks ??????????
path = '/content/drive/MyDrive/CS6220 Folder/codeparrot-clean/extracted/'

import json
import pandas as pd

all_code_content = []

for i in range(1, 55):
    file_name = f'file-{i:012}.json'
    full_path = f'{path}{file_name}'
    try:
        with open(full_path, 'r') as file:
            for line in file:
                try:
                    data = json.loads(line)
                    code_content = data.get('content', '')
                    if code_content:
                        all_code_content.append({'content': code_content})
                except json.JSONDecodeError as e:
                    print(f"JSON decode error in file {file_name}, line: {line}")
    except FileNotFoundError as e:
        print(f"File not found: {e}")

df = pd.DataFrame(all_code_content)



KeyboardInterrupt: ignored

## **Big Code/ The-Stack** [outdated]

In [None]:
!pip install huggingface-cli
!huggingface-cli login


In [None]:
from  datasets  import  load_dataset

# full dataset (3TB of data)
#ds = load_dataset("bigcode/the-stack", split="train")

# specific language (e.g. Dockerfiles)
ds = load_dataset("bigcode/the-stack", data_dir="data/python", split="train")

# dataset streaming (will only download the data as needed)
#ds = load_dataset("bigcode/the-stack", streaming=True, split="train")
#for sample in iter(ds): print(sample["content"])


##run script

In [10]:
generic_training_runner(
    SAVE_PATH,
    LOSSES_PATH,
    TIMES_PATH,
    model_info,
    training_args,
)

Device Type: cuda
Beginning training iteration
Loaded global model.
Number of batches: 188
Starting training


1it [00:01,  1.92s/it]

Step: 0, Loss: 8.2141


2it [00:04,  2.23s/it]

Step: 1, Loss: 6.0871


3it [00:06,  2.19s/it]

Step: 2, Loss: 1.8793


4it [00:10,  2.77s/it]

Step: 3, Loss: 0.8737


5it [00:13,  2.87s/it]

Step: 4, Loss: 0.8002


6it [00:14,  2.49s/it]

Step: 5, Loss: 0.8529


7it [00:16,  2.11s/it]

Step: 6, Loss: 1.0662


8it [00:17,  1.67s/it]

Step: 7, Loss: 1.2932


9it [00:22,  2.99s/it]

Step: 8, Loss: 0.4014


10it [00:24,  2.44s/it]

Step: 9, Loss: 0.7880


11it [00:27,  2.60s/it]

Step: 10, Loss: 0.5341


12it [00:28,  2.13s/it]

Step: 11, Loss: 0.7711


13it [00:30,  2.16s/it]

Step: 12, Loss: 0.6124


14it [00:32,  2.11s/it]

Step: 13, Loss: 0.5294


15it [00:34,  2.04s/it]

Step: 14, Loss: 0.5654


16it [00:37,  2.29s/it]

Step: 15, Loss: 0.3812


17it [00:39,  2.26s/it]

Step: 16, Loss: 0.5187


18it [00:40,  2.04s/it]

Step: 17, Loss: 0.6696


19it [00:43,  2.16s/it]

Step: 18, Loss: 0.4840


20it [00:47,  2.82s/it]

Step: 19, Loss: 0.2562


21it [00:49,  2.47s/it]

Step: 20, Loss: 0.5513


22it [00:55,  3.49s/it]

Step: 21, Loss: 0.2697


23it [00:56,  2.84s/it]

Step: 22, Loss: 0.5430


24it [00:57,  2.24s/it]

Step: 23, Loss: 0.7625


25it [00:58,  1.92s/it]

Step: 24, Loss: 0.5575


26it [01:00,  2.07s/it]

Step: 25, Loss: 0.4399


27it [01:02,  1.85s/it]

Step: 26, Loss: 0.3854


28it [01:04,  2.02s/it]

Step: 27, Loss: 0.4780


29it [01:07,  2.32s/it]

Step: 28, Loss: 0.4440


30it [01:13,  3.38s/it]

Step: 29, Loss: 0.4066


31it [01:15,  2.93s/it]

Step: 30, Loss: 0.3953


32it [01:16,  2.50s/it]

Step: 31, Loss: 0.4406


33it [01:18,  2.23s/it]

Step: 32, Loss: 0.4677


34it [01:19,  1.85s/it]

Step: 33, Loss: 0.6522


35it [01:21,  2.03s/it]

Step: 34, Loss: 0.3726


36it [01:24,  2.25s/it]

Step: 35, Loss: 0.3707


37it [01:25,  1.94s/it]

Step: 36, Loss: 0.4239


38it [01:27,  1.76s/it]

Step: 37, Loss: 0.6271


39it [01:31,  2.40s/it]

Step: 38, Loss: 0.3604


40it [01:33,  2.39s/it]

Step: 39, Loss: 0.3741


41it [01:34,  1.92s/it]

Step: 40, Loss: 0.6019


42it [01:36,  2.11s/it]

Step: 41, Loss: 0.2763


43it [01:38,  2.06s/it]

Step: 42, Loss: 0.3592


44it [01:42,  2.58s/it]

Step: 43, Loss: 0.3940


45it [01:45,  2.67s/it]

Step: 44, Loss: 0.4004


46it [01:46,  2.31s/it]

Step: 45, Loss: 0.6351


47it [01:47,  1.91s/it]

Step: 46, Loss: 0.7279


48it [01:52,  2.67s/it]

Step: 47, Loss: 0.3067


49it [01:54,  2.39s/it]

Step: 48, Loss: 0.4026


50it [02:00,  3.44s/it]

Step: 49, Loss: 0.5501


51it [02:01,  2.86s/it]

Step: 50, Loss: 0.4606


52it [02:04,  2.77s/it]

Step: 51, Loss: 0.3279


53it [02:05,  2.47s/it]

Step: 52, Loss: 0.5226


54it [02:06,  2.00s/it]

Step: 53, Loss: 0.6862


55it [02:08,  1.80s/it]

Step: 54, Loss: 0.6552


56it [02:10,  2.05s/it]

Step: 55, Loss: 0.3528


57it [02:13,  2.37s/it]

Step: 56, Loss: 0.3208


58it [02:15,  2.27s/it]

Step: 57, Loss: 0.3746


59it [02:17,  2.09s/it]

Step: 58, Loss: 0.3767


60it [02:19,  2.16s/it]

Step: 59, Loss: 0.3918


61it [02:21,  2.06s/it]

Step: 60, Loss: 0.5576


62it [02:24,  2.32s/it]

Step: 61, Loss: 0.3800


63it [02:26,  2.23s/it]

Step: 62, Loss: 0.4630


64it [02:28,  2.11s/it]

Step: 63, Loss: 0.3968


65it [02:31,  2.27s/it]

Step: 64, Loss: 0.3396


66it [02:33,  2.29s/it]

Step: 65, Loss: 0.4237


67it [02:35,  2.34s/it]

Step: 66, Loss: 0.4053


68it [02:39,  2.70s/it]

Step: 67, Loss: 0.2790


69it [02:41,  2.41s/it]

Step: 68, Loss: 0.3201


70it [02:42,  2.12s/it]

Step: 69, Loss: 0.4686


71it [02:44,  1.97s/it]

Step: 70, Loss: 0.3608


72it [02:45,  1.67s/it]

Step: 71, Loss: 0.5371


73it [02:46,  1.61s/it]

Step: 72, Loss: 0.4888


74it [02:47,  1.46s/it]

Step: 73, Loss: 0.5562


75it [02:50,  1.91s/it]

Step: 74, Loss: 0.2854


76it [02:52,  1.74s/it]

Step: 75, Loss: 0.6071


77it [02:55,  2.31s/it]

Step: 76, Loss: 0.4343


78it [02:58,  2.36s/it]

Step: 77, Loss: 0.3579


79it [03:04,  3.41s/it]

Step: 78, Loss: 0.2908


80it [03:05,  2.89s/it]

Step: 79, Loss: 0.3215


81it [03:06,  2.31s/it]

Step: 80, Loss: 0.5082


82it [03:09,  2.31s/it]

Step: 81, Loss: 0.5382


83it [03:10,  2.06s/it]

Step: 82, Loss: 0.4260


84it [03:13,  2.28s/it]

Step: 83, Loss: 0.2578


85it [03:18,  3.14s/it]

Step: 84, Loss: 0.1819


86it [03:19,  2.65s/it]

Step: 85, Loss: 0.3582


87it [03:21,  2.27s/it]

Step: 86, Loss: 0.5024


88it [03:22,  2.08s/it]

Step: 87, Loss: 0.3763


89it [03:24,  1.98s/it]

Step: 88, Loss: 0.2833


90it [03:27,  2.30s/it]

Step: 89, Loss: 0.2462


91it [03:29,  2.13s/it]

Step: 90, Loss: 0.4755


92it [03:30,  1.81s/it]

Step: 91, Loss: 0.4401


93it [03:33,  2.05s/it]

Step: 92, Loss: 0.3320


94it [03:35,  2.13s/it]

Step: 93, Loss: 0.2255


95it [03:36,  1.72s/it]

Step: 94, Loss: 0.6439


96it [03:37,  1.53s/it]

Step: 95, Loss: 0.4146


97it [03:43,  2.83s/it]

Step: 96, Loss: 0.2373


98it [03:44,  2.51s/it]

Step: 97, Loss: 0.4499


99it [03:47,  2.41s/it]

Step: 98, Loss: 0.3846


100it [03:48,  2.06s/it]

Step: 99, Loss: 0.5201


101it [03:51,  2.49s/it]

Step: 100, Loss: 0.2169


102it [03:54,  2.50s/it]

Step: 101, Loss: 0.4534


103it [03:56,  2.42s/it]

Step: 102, Loss: 0.2842


104it [03:59,  2.67s/it]

Step: 103, Loss: 0.3621


105it [04:02,  2.57s/it]

Step: 104, Loss: 0.3467


106it [04:03,  2.32s/it]

Step: 105, Loss: 0.3809


107it [04:05,  2.07s/it]

Step: 106, Loss: 0.4575


108it [04:06,  1.83s/it]

Step: 107, Loss: 0.5476


109it [04:07,  1.59s/it]

Step: 108, Loss: 0.6074


110it [04:09,  1.52s/it]

Step: 109, Loss: 0.5156


111it [04:10,  1.51s/it]

Step: 110, Loss: 0.4349


112it [04:13,  2.02s/it]

Step: 111, Loss: 0.2050


113it [04:16,  2.34s/it]

Step: 112, Loss: 0.2678


114it [04:19,  2.46s/it]

Step: 113, Loss: 0.2740


115it [04:21,  2.34s/it]

Step: 114, Loss: 0.3375


116it [04:23,  2.09s/it]

Step: 115, Loss: 0.4768


117it [04:24,  1.74s/it]

Step: 116, Loss: 0.4337


118it [04:25,  1.64s/it]

Step: 117, Loss: 0.2871


119it [04:27,  1.81s/it]

Step: 118, Loss: 0.2599


120it [04:30,  2.13s/it]

Step: 119, Loss: 0.3328


121it [04:32,  2.08s/it]

Step: 120, Loss: 0.3840


122it [04:35,  2.33s/it]

Step: 121, Loss: 0.2342


123it [04:41,  3.40s/it]

Step: 122, Loss: 0.3529


124it [04:43,  2.99s/it]

Step: 123, Loss: 0.3461


125it [04:44,  2.42s/it]

Step: 124, Loss: 0.4652


126it [04:45,  2.10s/it]

Step: 125, Loss: 0.4691


127it [04:47,  2.07s/it]

Step: 126, Loss: 0.4560


128it [04:51,  2.51s/it]

Step: 127, Loss: 0.2850


129it [04:53,  2.51s/it]

Step: 128, Loss: 0.3874


130it [04:55,  2.33s/it]

Step: 129, Loss: 0.3819


131it [04:57,  2.13s/it]

Step: 130, Loss: 0.3125


132it [04:59,  2.16s/it]

Step: 131, Loss: 0.3556


133it [05:01,  2.19s/it]

Step: 132, Loss: 0.3058


134it [05:03,  1.93s/it]

Step: 133, Loss: 0.4569


135it [05:09,  3.12s/it]

Step: 134, Loss: 0.2650


136it [05:11,  2.79s/it]

Step: 135, Loss: 0.4267


137it [05:12,  2.27s/it]

Step: 136, Loss: 0.4671


138it [05:15,  2.42s/it]

Step: 137, Loss: 0.3569


139it [05:20,  3.46s/it]

Step: 138, Loss: 0.1920


140it [05:22,  2.85s/it]

Step: 139, Loss: 0.3416


141it [05:23,  2.27s/it]

Step: 140, Loss: 0.5973


142it [05:25,  2.29s/it]

Step: 141, Loss: 0.4247


143it [05:27,  2.13s/it]

Step: 142, Loss: 0.4393


144it [05:28,  1.94s/it]

Step: 143, Loss: 0.3488


145it [05:31,  2.10s/it]

Step: 144, Loss: 0.3611


146it [05:33,  2.02s/it]

Step: 145, Loss: 0.3391


147it [05:35,  2.15s/it]

Step: 146, Loss: 0.3317


148it [05:38,  2.45s/it]

Step: 147, Loss: 0.2046


149it [05:43,  3.14s/it]

Step: 148, Loss: 0.2366


150it [05:48,  3.81s/it]

Step: 149, Loss: 0.2022


151it [05:53,  4.10s/it]

Step: 150, Loss: 0.2088


152it [05:55,  3.36s/it]

Step: 151, Loss: 0.3058


153it [05:56,  2.84s/it]

Step: 152, Loss: 0.4344


154it [05:58,  2.57s/it]

Step: 153, Loss: 0.3904


155it [06:04,  3.56s/it]

Step: 154, Loss: 0.1952


156it [06:07,  3.22s/it]

Step: 155, Loss: 0.2422


157it [06:09,  2.87s/it]

Step: 156, Loss: 0.3751


158it [06:12,  2.89s/it]

Step: 157, Loss: 0.3668


159it [06:16,  3.24s/it]

Step: 158, Loss: 0.2419


160it [06:20,  3.45s/it]

Step: 159, Loss: 0.1933


161it [06:21,  2.75s/it]

Step: 160, Loss: 0.5420


162it [06:27,  3.69s/it]

Step: 161, Loss: 0.2792


163it [06:29,  3.35s/it]

Step: 162, Loss: 0.4187


164it [06:35,  4.11s/it]

Step: 163, Loss: 0.1729


165it [06:37,  3.50s/it]

Step: 164, Loss: 0.3377


166it [06:40,  3.40s/it]

Step: 165, Loss: 0.3302


167it [06:42,  2.76s/it]

Step: 166, Loss: 0.4675


168it [06:44,  2.68s/it]

Step: 167, Loss: 0.2859


169it [06:47,  2.90s/it]

Step: 168, Loss: 0.3157


170it [06:49,  2.44s/it]

Step: 169, Loss: 0.3986


171it [06:51,  2.22s/it]

Step: 170, Loss: 0.3916


172it [06:56,  3.32s/it]

Step: 171, Loss: 0.2749


173it [06:59,  3.02s/it]

Step: 172, Loss: 0.2173


174it [07:00,  2.49s/it]

Step: 173, Loss: 0.5231


175it [07:01,  2.10s/it]

Step: 174, Loss: 0.4075


176it [07:03,  1.91s/it]

Step: 175, Loss: 0.3812


177it [07:06,  2.39s/it]

Step: 176, Loss: 0.3146


178it [07:08,  2.13s/it]

Step: 177, Loss: 0.3490


179it [07:09,  1.97s/it]

Step: 178, Loss: 0.5164


180it [07:11,  1.87s/it]

Step: 179, Loss: 0.4547


181it [07:12,  1.71s/it]

Step: 180, Loss: 0.4718


182it [07:15,  2.07s/it]

Step: 181, Loss: 0.2459


183it [07:17,  2.13s/it]

Step: 182, Loss: 0.4249


184it [07:20,  2.23s/it]

Step: 183, Loss: 0.3007


185it [07:21,  2.00s/it]

Step: 184, Loss: 0.4109


186it [07:23,  1.93s/it]

Step: 185, Loss: 0.3266


187it [07:24,  1.70s/it]

Step: 186, Loss: 0.4284


188it [07:26,  2.38s/it]

Step: 187, Loss: 0.4346
Epoch 0 Complete
Ending training
446.595023393631
MODEL SAVED
Beginning training iteration





Loaded global model.
Number of batches: 188
Starting training


1it [00:01,  1.46s/it]

Step: 0, Loss: 7.6228


2it [00:07,  4.05s/it]

Step: 1, Loss: 6.7135


3it [00:08,  2.81s/it]

Step: 2, Loss: 2.8488


4it [00:09,  2.19s/it]

Step: 3, Loss: 1.5082


5it [00:12,  2.22s/it]

Step: 4, Loss: 1.1315


6it [00:17,  3.42s/it]

Step: 5, Loss: 0.7766


7it [00:19,  2.69s/it]

Step: 6, Loss: 1.1592


8it [00:24,  3.70s/it]

Step: 7, Loss: 0.4418


9it [00:27,  3.23s/it]

Step: 8, Loss: 0.7658


10it [00:28,  2.55s/it]

Step: 9, Loss: 1.2916


11it [00:29,  2.17s/it]

Step: 10, Loss: 0.7950


12it [00:31,  2.21s/it]

Step: 11, Loss: 0.5272


13it [00:37,  3.24s/it]

Step: 12, Loss: 0.3984


14it [00:41,  3.53s/it]

Step: 13, Loss: 0.4255


15it [00:42,  2.80s/it]

Step: 14, Loss: 0.9261


16it [00:43,  2.25s/it]

Step: 15, Loss: 0.9083


17it [00:45,  2.13s/it]

Step: 16, Loss: 0.6996


18it [00:48,  2.25s/it]

Step: 17, Loss: 0.4509


19it [00:49,  2.01s/it]

Step: 18, Loss: 0.6074


20it [00:50,  1.68s/it]

Step: 19, Loss: 0.9234


21it [00:51,  1.42s/it]

Step: 20, Loss: 0.8263


22it [00:52,  1.51s/it]

Step: 21, Loss: 0.5134


23it [00:56,  2.03s/it]

Step: 22, Loss: 0.3971


24it [00:57,  1.68s/it]

Step: 23, Loss: 0.7521


25it [00:58,  1.72s/it]

Step: 24, Loss: 0.5140


26it [01:00,  1.67s/it]

Step: 25, Loss: 0.4823


27it [01:01,  1.56s/it]

Step: 26, Loss: 0.6217


28it [01:02,  1.35s/it]

Step: 27, Loss: 0.8976


29it [01:03,  1.24s/it]

Step: 28, Loss: 0.6453


30it [01:04,  1.18s/it]

Step: 29, Loss: 0.7886


31it [01:06,  1.41s/it]

Step: 30, Loss: 0.4737


32it [01:08,  1.47s/it]

Step: 31, Loss: 0.5889


33it [01:09,  1.38s/it]

Step: 32, Loss: 0.7233


34it [01:11,  1.49s/it]

Step: 33, Loss: 0.4968


35it [01:12,  1.57s/it]

Step: 34, Loss: 0.3567


36it [01:14,  1.55s/it]

Step: 35, Loss: 0.5077


37it [01:15,  1.37s/it]

Step: 36, Loss: 0.8623


38it [01:16,  1.36s/it]

Step: 37, Loss: 0.5031


39it [01:20,  2.16s/it]

Step: 38, Loss: 0.3990


40it [01:21,  1.75s/it]

Step: 39, Loss: 0.8629


41it [01:23,  1.78s/it]

Step: 40, Loss: 0.5311


42it [01:24,  1.60s/it]

Step: 41, Loss: 0.6585


43it [01:27,  2.09s/it]

Step: 42, Loss: 0.2993


44it [01:28,  1.83s/it]

Step: 43, Loss: 0.5944


45it [01:31,  1.94s/it]

Step: 44, Loss: 0.5897


46it [01:32,  1.68s/it]

Step: 45, Loss: 0.6057


47it [01:33,  1.58s/it]

Step: 46, Loss: 0.5070


48it [01:35,  1.79s/it]

Step: 47, Loss: 0.3230


49it [01:36,  1.48s/it]

Step: 48, Loss: 0.7545


50it [01:38,  1.49s/it]

Step: 49, Loss: 0.4236


51it [01:43,  2.81s/it]

Step: 50, Loss: 0.6092


52it [01:44,  2.22s/it]

Step: 51, Loss: 0.7173


53it [01:45,  1.88s/it]

Step: 52, Loss: 0.7532


54it [01:47,  1.84s/it]

Step: 53, Loss: 0.4835


55it [01:48,  1.63s/it]

Step: 54, Loss: 0.4525


56it [01:51,  1.81s/it]

Step: 55, Loss: 0.3931


57it [01:52,  1.59s/it]

Step: 56, Loss: 0.6630


58it [01:53,  1.51s/it]

Step: 57, Loss: 0.4888


59it [01:55,  1.68s/it]

Step: 58, Loss: 0.4658


60it [01:57,  1.70s/it]

Step: 59, Loss: 0.4851


61it [01:58,  1.67s/it]

Step: 60, Loss: 0.3690


62it [02:00,  1.54s/it]

Step: 61, Loss: 0.6831


63it [02:03,  2.02s/it]

Step: 62, Loss: 0.3128


64it [02:04,  1.67s/it]

Step: 63, Loss: 0.8700


65it [02:05,  1.60s/it]

Step: 64, Loss: 0.5195


66it [02:11,  2.88s/it]

Step: 65, Loss: 0.2842


67it [02:12,  2.35s/it]

Step: 66, Loss: 0.7322


68it [02:18,  3.40s/it]

Step: 67, Loss: 0.3689


69it [02:20,  3.12s/it]

Step: 68, Loss: 0.3582


70it [02:22,  2.83s/it]

Step: 69, Loss: 0.3684


71it [02:24,  2.51s/it]

Step: 70, Loss: 0.5413


72it [02:26,  2.20s/it]

Step: 71, Loss: 0.4987


73it [02:27,  2.03s/it]

Step: 72, Loss: 0.6176


74it [02:28,  1.74s/it]

Step: 73, Loss: 0.6357


75it [02:30,  1.57s/it]

Step: 74, Loss: 0.5691


76it [02:32,  1.76s/it]

Step: 75, Loss: 0.4711


77it [02:33,  1.52s/it]

Step: 76, Loss: 0.7143


78it [02:34,  1.36s/it]

Step: 77, Loss: 0.6331


79it [02:35,  1.26s/it]

Step: 78, Loss: 0.6501


80it [02:36,  1.37s/it]

Step: 79, Loss: 0.5924


81it [02:39,  1.68s/it]

Step: 80, Loss: 0.3536


82it [02:41,  1.91s/it]

Step: 81, Loss: 0.4219


83it [02:42,  1.71s/it]

Step: 82, Loss: 0.5158


84it [02:48,  2.96s/it]

Step: 83, Loss: 0.2674


85it [02:50,  2.52s/it]

Step: 84, Loss: 0.4179


86it [02:52,  2.32s/it]

Step: 85, Loss: 0.4677


87it [02:53,  1.94s/it]

Step: 86, Loss: 0.5812


88it [02:54,  1.69s/it]

Step: 87, Loss: 0.5161


89it [02:55,  1.54s/it]

Step: 88, Loss: 0.5959


90it [02:58,  1.99s/it]

Step: 89, Loss: 0.3097


91it [03:01,  2.31s/it]

Step: 90, Loss: 0.3255


92it [03:03,  2.18s/it]

Step: 91, Loss: 0.4997


93it [03:04,  1.81s/it]

Step: 92, Loss: 0.6379


94it [03:06,  2.03s/it]

Step: 93, Loss: 0.3029


95it [03:08,  1.75s/it]

Step: 94, Loss: 0.5794


96it [03:09,  1.55s/it]

Step: 95, Loss: 0.5677


97it [03:10,  1.41s/it]

Step: 96, Loss: 0.5573


98it [03:11,  1.26s/it]

Step: 97, Loss: 0.6931


99it [03:12,  1.28s/it]

Step: 98, Loss: 0.5967


100it [03:13,  1.17s/it]

Step: 99, Loss: 0.7790


101it [03:14,  1.03s/it]

Step: 100, Loss: 0.7165


102it [03:15,  1.24s/it]

Step: 101, Loss: 0.3855


103it [03:16,  1.16s/it]

Step: 102, Loss: 0.6796


104it [03:19,  1.50s/it]

Step: 103, Loss: 0.3609


105it [03:20,  1.42s/it]

Step: 104, Loss: 0.5762


106it [03:21,  1.29s/it]

Step: 105, Loss: 0.6741


107it [03:26,  2.45s/it]

Step: 106, Loss: 0.2505


108it [03:27,  2.09s/it]

Step: 107, Loss: 0.6324


109it [03:32,  2.98s/it]

Step: 108, Loss: 0.2671


110it [03:34,  2.63s/it]

Step: 109, Loss: 0.4167


111it [03:35,  2.15s/it]

Step: 110, Loss: 0.6589


112it [03:36,  1.90s/it]

Step: 111, Loss: 0.5601


113it [03:42,  3.09s/it]

Step: 112, Loss: 0.5680


114it [03:44,  2.55s/it]

Step: 113, Loss: 0.5474


115it [03:45,  2.14s/it]

Step: 114, Loss: 0.5466


116it [03:47,  2.02s/it]

Step: 115, Loss: 0.4229


117it [03:48,  1.72s/it]

Step: 116, Loss: 0.6069


118it [03:49,  1.64s/it]

Step: 117, Loss: 0.4673


119it [03:50,  1.40s/it]

Step: 118, Loss: 0.6274


120it [03:51,  1.20s/it]

Step: 119, Loss: 0.7933


121it [03:52,  1.18s/it]

Step: 120, Loss: 0.5396


122it [03:53,  1.15s/it]

Step: 121, Loss: 0.5165


123it [03:54,  1.13s/it]

Step: 122, Loss: 0.5638


124it [03:55,  1.15s/it]

Step: 123, Loss: 0.6581


125it [03:57,  1.40s/it]

Step: 124, Loss: 0.4036


126it [03:58,  1.22s/it]

Step: 125, Loss: 0.6374


127it [03:59,  1.23s/it]

Step: 126, Loss: 0.3867


128it [04:02,  1.68s/it]

Step: 127, Loss: 0.3468


129it [04:04,  1.79s/it]

Step: 128, Loss: 0.4422


130it [04:06,  1.82s/it]

Step: 129, Loss: 0.4256


131it [04:08,  1.80s/it]

Step: 130, Loss: 0.4234


132it [04:10,  1.91s/it]

Step: 131, Loss: 0.3599


133it [04:11,  1.73s/it]

Step: 132, Loss: 0.5978


134it [04:12,  1.54s/it]

Step: 133, Loss: 0.6349


135it [04:13,  1.32s/it]

Step: 134, Loss: 0.6409


136it [04:14,  1.19s/it]

Step: 135, Loss: 0.5904


137it [04:15,  1.16s/it]

Step: 136, Loss: 0.6161


138it [04:16,  1.09s/it]

Step: 137, Loss: 0.5500


139it [04:17,  1.16s/it]

Step: 138, Loss: 0.4675


140it [04:20,  1.70s/it]

Step: 139, Loss: 0.3279


141it [04:21,  1.43s/it]

Step: 140, Loss: 0.6146


142it [04:24,  1.98s/it]

Step: 141, Loss: 0.3429


143it [04:26,  1.84s/it]

Step: 142, Loss: 0.4694


144it [04:28,  1.91s/it]

Step: 143, Loss: 0.3888


145it [04:29,  1.66s/it]

Step: 144, Loss: 0.5348


146it [04:30,  1.61s/it]

Step: 145, Loss: 0.5875


147it [04:31,  1.46s/it]

Step: 146, Loss: 0.5188


148it [04:34,  1.89s/it]

Step: 147, Loss: 0.3501


149it [04:37,  2.16s/it]

Step: 148, Loss: 0.3933


150it [04:38,  1.91s/it]

Step: 149, Loss: 0.6019


151it [04:40,  1.78s/it]

Step: 150, Loss: 0.4541


152it [04:41,  1.52s/it]

Step: 151, Loss: 0.6338


153it [04:43,  1.61s/it]

Step: 152, Loss: 0.3164


154it [04:44,  1.57s/it]

Step: 153, Loss: 0.4448


155it [04:46,  1.59s/it]

Step: 154, Loss: 0.4205


156it [04:48,  1.67s/it]

Step: 155, Loss: 0.4712


157it [04:51,  2.12s/it]

Step: 156, Loss: 0.2691


158it [04:52,  1.88s/it]

Step: 157, Loss: 0.5256


159it [04:53,  1.71s/it]

Step: 158, Loss: 0.4510


160it [04:54,  1.49s/it]

Step: 159, Loss: 0.4844


161it [04:55,  1.36s/it]

Step: 160, Loss: 0.4145


162it [04:57,  1.29s/it]

Step: 161, Loss: 0.5871


163it [04:58,  1.19s/it]

Step: 162, Loss: 0.5029


164it [05:00,  1.44s/it]

Step: 163, Loss: 0.3893


165it [05:01,  1.34s/it]

Step: 164, Loss: 0.4413


166it [05:02,  1.30s/it]

Step: 165, Loss: 0.5929


167it [05:03,  1.24s/it]

Step: 166, Loss: 0.6206


168it [05:04,  1.20s/it]

Step: 167, Loss: 0.4165


169it [05:07,  1.58s/it]

Step: 168, Loss: 0.2919


170it [05:11,  2.48s/it]

Step: 169, Loss: 0.2687


171it [05:12,  2.03s/it]

Step: 170, Loss: 0.6450


172it [05:18,  3.19s/it]

Step: 171, Loss: 0.4529


173it [05:19,  2.56s/it]

Step: 172, Loss: 0.4914


174it [05:23,  2.97s/it]

Step: 173, Loss: 0.2925


175it [05:24,  2.33s/it]

Step: 174, Loss: 0.5377


176it [05:25,  1.94s/it]

Step: 175, Loss: 0.5622


177it [05:26,  1.73s/it]

Step: 176, Loss: 0.4374


178it [05:27,  1.50s/it]

Step: 177, Loss: 0.4674


179it [05:28,  1.29s/it]

Step: 178, Loss: 0.5177


180it [05:29,  1.16s/it]

Step: 179, Loss: 0.6737


181it [05:30,  1.21s/it]

Step: 180, Loss: 0.4436


182it [05:31,  1.15s/it]

Step: 181, Loss: 0.5026


183it [05:32,  1.16s/it]

Step: 182, Loss: 0.4198


184it [05:33,  1.13s/it]

Step: 183, Loss: 0.5681


185it [05:36,  1.48s/it]

Step: 184, Loss: 0.3692


186it [05:37,  1.43s/it]

Step: 185, Loss: 0.4195


187it [05:38,  1.32s/it]

Step: 186, Loss: 0.4842


188it [05:40,  1.81s/it]

Step: 187, Loss: 0.3386
Epoch 0 Complete
Ending training
340.0854561328888
MODEL SAVED





In [31]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

import torch

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

import pickle

import time
from tqdm import tqdm


def generic_training_runner(
    SAVE_PATH,
    LOSSES_PATH,
    TIMES_PATH,
    model_info,
    training_args,
):
    """
    Function to perform finetuning on provided model_name from Hugging Face

    Models are checkpointed and saved at provided locations in Drive.

    :param str SAVE_PATH: Location to save model checkpoint
    :param str LOSSES_PATH: Location to log loss results
    :param str TIMES_PATH: Location to log elapsed time results
    :param model_info: Object containing information about model
    :param training_args: Object containing training arguments
    :return: None
    """
    # Args
    clients, MAX_LENGTH, conduct_logging, EPOCHS, lr = (
        training_args['clients'],
        training_args['MAX_LENGTH'],
        training_args['conduct_logging'],
        training_args['EPOCHS'],
        training_args['lr'],
    )
    model_name, tokenizer, client_dataloaders, quant_config, lora_config = (
        model_info['model_name'],
        model_info['tokenizer'],
        model_info['client_dataloaders'],
        model_info['quant_config'],
        model_info['lora_config'],
    )

    assert len(client_dataloaders) == clients

    # Device setup
    torch.manual_seed(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device Type: {device}")


    # Loggings of losses setup.
    def write_to_file(epoch_number, loss_values, avg_loss, file_path):
        # Open the file in binary append mode
        with open(file_path + ".txt", "ab") as file:
            # Serialize and write the epoch number
            pickle.dump(f"Epoch {epoch_number}", file)
            # Serialize and write the list of loss values
            pickle.dump(loss_values, file)
            # Average loss for the epoch
            pickle.dump(avg_loss, file)

    # Simple test to log.
    conduct_logging = True
    if conduct_logging:
        write_to_file(-1, [-1, -2, -3], [-2], LOSSES_PATH)

    # Begin training process
    # Traditional training loop requires modifications for FL...
    # To do this, we perform the following. Assume only 1 epoch is done.
    """
    Create a variable to aggregate gradients

    For Each client
        Copy the global model
        Fetch the client's dataloader
        Train across the data in one iteration.
        Log losses carefully
        Save model
        Aggregate change in weights to variable
    """

    def train_loop(client_idx):
        """
        Main training loop

        :param int client_idx: The index of the client being trained
        :return: None
        """
        global_model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=quant_config, device_map={"": 0}
        )
        global_model.gradient_checkpointing_enable()

        global_model = prepare_model_for_kbit_training(global_model)
        global_model = get_peft_model(global_model, lora_config).to(device)
        client_model = global_model
        print("Loaded global model.")

        # Set up loss and optimization unique to model.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(client_model.parameters(), lr=lr)

        losses = []
        # Log a 10K losses
        dataloader_client = client_dataloaders[client_idx]
        log_every = max(1, (EPOCHS * len(dataloader_client)) // 10_000)
        print("Number of batches:", len(dataloader_client))
        client_model.train()
        start_time = time.time()
        print("Starting training")
        for epoch in range(EPOCHS):
            total_loss = 0
            for c, batch in tqdm(enumerate(dataloader_client)):
                # Batch size x Max Seq LEn
                sample = tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=MAX_LENGTH,
                )["input_ids"].to(device)
                target = sample.detach()[:, 1:]
                sample = sample[:, :-1]

                # Batch size x Max Seq Len x Vocab Size
                optimizer.zero_grad()
                prediction = client_model(sample).logits

                # Ensure swapping of axes
                loss = criterion(prediction.transpose(1, 2), target)
                loss.backward()

                # Loss logging
                total_loss += loss.item()
                if c % log_every == 0:
                    print(f"Step: {c}, Loss: {loss.item():.4f}")
                    losses.append(loss.item())

                # Change model weights
                optimizer.step()

                # Explicit destruction (may not be needed after previous debugging)
                del loss, prediction, sample, target
                torch.cuda.empty_cache()

            print(f"Epoch {epoch} Complete")
            avg_loss = total_loss / len(dataloader_client)
            if conduct_logging:
                write_to_file(epoch, losses, avg_loss, LOSSES_PATH + f"_{client_idx}")
            losses.clear()

        print("Ending training")

        end_time = time.time()

        optimizer.zero_grad()

        elapsed_time = end_time - start_time
        if conduct_logging:
            with open(TIMES_PATH + f"_{client_idx}.txt", "a") as file:
                file.write(f"{elapsed_time}")

        print(elapsed_time)
        if conduct_logging:
            client_model.save_pretrained(SAVE_PATH + f"_{client_idx}")
            print("MODEL SAVED")
            return client_model
            del client_model
            torch.cuda.empty_cache()

    # Execute training for all clients
    client_rets = []
    for i in range(clients):
        print("Beginning training iteration")
        client_rets.append(train_loop(i))
    return client_rets

client_rets = generic_training_runner(
    SAVE_PATH,
    LOSSES_PATH,
    TIMES_PATH,
    model_info,
    training_args,
)

Device Type: cuda
Beginning training iteration
Loaded global model.
Number of batches: 188
Starting training


1it [00:01,  1.89s/it]

Step: 0, Loss: 8.2141


2it [00:04,  2.18s/it]

Step: 1, Loss: 6.0871


3it [00:06,  2.14s/it]

Step: 2, Loss: 1.8793


4it [00:09,  2.68s/it]

Step: 3, Loss: 0.8737


5it [00:12,  2.77s/it]

Step: 4, Loss: 0.8002


6it [00:14,  2.41s/it]

Step: 5, Loss: 0.8529


7it [00:15,  2.04s/it]

Step: 6, Loss: 1.0662


8it [00:16,  1.61s/it]

Step: 7, Loss: 1.2932


9it [00:22,  2.89s/it]

Step: 8, Loss: 0.4014


10it [00:23,  2.36s/it]

Step: 9, Loss: 0.7880


11it [00:26,  2.50s/it]

Step: 10, Loss: 0.5341


12it [00:27,  2.05s/it]

Step: 11, Loss: 0.7711


13it [00:29,  2.07s/it]

Step: 12, Loss: 0.6124


14it [00:31,  2.02s/it]

Step: 13, Loss: 0.5294


15it [00:33,  1.95s/it]

Step: 14, Loss: 0.5654


16it [00:35,  2.20s/it]

Step: 15, Loss: 0.3812


17it [00:37,  2.17s/it]

Step: 16, Loss: 0.5187


18it [00:39,  1.97s/it]

Step: 17, Loss: 0.6696


19it [00:41,  2.09s/it]

Step: 18, Loss: 0.4840


20it [00:46,  2.74s/it]

Step: 19, Loss: 0.2562


21it [00:47,  2.41s/it]

Step: 20, Loss: 0.5513


22it [00:53,  3.41s/it]

Step: 21, Loss: 0.2697


23it [00:54,  2.78s/it]

Step: 22, Loss: 0.5430


24it [00:55,  2.19s/it]

Step: 23, Loss: 0.7625


25it [00:56,  1.88s/it]

Step: 24, Loss: 0.5575


26it [00:59,  2.04s/it]

Step: 25, Loss: 0.4399


27it [01:00,  1.82s/it]

Step: 26, Loss: 0.3854


28it [01:02,  2.00s/it]

Step: 27, Loss: 0.4780


29it [01:05,  2.30s/it]

Step: 28, Loss: 0.4440


30it [01:11,  3.35s/it]

Step: 29, Loss: 0.4066


31it [01:13,  2.90s/it]

Step: 30, Loss: 0.3953


32it [01:14,  2.48s/it]

Step: 31, Loss: 0.4406


33it [01:16,  2.21s/it]

Step: 32, Loss: 0.4677


34it [01:17,  1.84s/it]

Step: 33, Loss: 0.6522


35it [01:19,  2.01s/it]

Step: 34, Loss: 0.3726


36it [01:22,  2.24s/it]

Step: 35, Loss: 0.3707


37it [01:23,  1.93s/it]

Step: 36, Loss: 0.4239


38it [01:25,  1.74s/it]

Step: 37, Loss: 0.6271


39it [01:29,  2.39s/it]

Step: 38, Loss: 0.3604


40it [01:31,  2.38s/it]

Step: 39, Loss: 0.3741


41it [01:32,  1.91s/it]

Step: 40, Loss: 0.6019


42it [01:34,  2.10s/it]

Step: 41, Loss: 0.2763


43it [01:36,  2.05s/it]

Step: 42, Loss: 0.3592


44it [01:40,  2.57s/it]

Step: 43, Loss: 0.3940


45it [01:43,  2.66s/it]

Step: 44, Loss: 0.4004


46it [01:44,  2.30s/it]

Step: 45, Loss: 0.6351


47it [01:45,  1.89s/it]

Step: 46, Loss: 0.7279


48it [01:50,  2.65s/it]

Step: 47, Loss: 0.3067


49it [01:51,  2.38s/it]

Step: 48, Loss: 0.4026


50it [01:57,  3.43s/it]

Step: 49, Loss: 0.5501


51it [01:59,  2.86s/it]

Step: 50, Loss: 0.4606


52it [02:01,  2.77s/it]

Step: 51, Loss: 0.3279


53it [02:03,  2.47s/it]

Step: 52, Loss: 0.5226


54it [02:04,  2.01s/it]

Step: 53, Loss: 0.6862


55it [02:06,  1.81s/it]

Step: 54, Loss: 0.6552


56it [02:08,  2.07s/it]

Step: 55, Loss: 0.3528


57it [02:11,  2.40s/it]

Step: 56, Loss: 0.3208


58it [02:13,  2.31s/it]

Step: 57, Loss: 0.3746


59it [02:15,  2.12s/it]

Step: 58, Loss: 0.3767


60it [02:18,  2.21s/it]

Step: 59, Loss: 0.3918


61it [02:19,  2.11s/it]

Step: 60, Loss: 0.5576


62it [02:22,  2.37s/it]

Step: 61, Loss: 0.3800


63it [02:25,  2.29s/it]

Step: 62, Loss: 0.4630


64it [02:26,  2.17s/it]

Step: 63, Loss: 0.3968


65it [02:29,  2.34s/it]

Step: 64, Loss: 0.3396


66it [02:32,  2.36s/it]

Step: 65, Loss: 0.4237


67it [02:34,  2.42s/it]

Step: 66, Loss: 0.4053


68it [02:38,  2.78s/it]

Step: 67, Loss: 0.2790


69it [02:40,  2.49s/it]

Step: 68, Loss: 0.3201


70it [02:41,  2.19s/it]

Step: 69, Loss: 0.4686


71it [02:43,  2.02s/it]

Step: 70, Loss: 0.3608


72it [02:44,  1.71s/it]

Step: 71, Loss: 0.5371


73it [02:45,  1.66s/it]

Step: 72, Loss: 0.4888


74it [02:46,  1.50s/it]

Step: 73, Loss: 0.5562


75it [02:49,  1.95s/it]

Step: 74, Loss: 0.2854


76it [02:51,  1.78s/it]

Step: 75, Loss: 0.6071


77it [02:54,  2.37s/it]

Step: 76, Loss: 0.4343


78it [02:57,  2.41s/it]

Step: 77, Loss: 0.3579


79it [03:03,  3.48s/it]

Step: 78, Loss: 0.2908


80it [03:05,  2.94s/it]

Step: 79, Loss: 0.3215


81it [03:06,  2.35s/it]

Step: 80, Loss: 0.5082


82it [03:08,  2.35s/it]

Step: 81, Loss: 0.5382


83it [03:09,  2.10s/it]

Step: 82, Loss: 0.4260


84it [03:12,  2.31s/it]

Step: 83, Loss: 0.2578


85it [03:17,  3.19s/it]

Step: 84, Loss: 0.1819


86it [03:19,  2.69s/it]

Step: 85, Loss: 0.3582


87it [03:20,  2.30s/it]

Step: 86, Loss: 0.5024


88it [03:22,  2.11s/it]

Step: 87, Loss: 0.3763


89it [03:24,  2.01s/it]

Step: 88, Loss: 0.2833


90it [03:27,  2.33s/it]

Step: 89, Loss: 0.2462


91it [03:29,  2.16s/it]

Step: 90, Loss: 0.4755


92it [03:30,  1.84s/it]

Step: 91, Loss: 0.4401


93it [03:32,  2.09s/it]

Step: 92, Loss: 0.3320


94it [03:35,  2.17s/it]

Step: 93, Loss: 0.2255


95it [03:36,  1.75s/it]

Step: 94, Loss: 0.6439


96it [03:37,  1.56s/it]

Step: 95, Loss: 0.4146


97it [03:43,  2.88s/it]

Step: 96, Loss: 0.2373


98it [03:44,  2.55s/it]

Step: 97, Loss: 0.4499


99it [03:47,  2.45s/it]

Step: 98, Loss: 0.3846


100it [03:48,  2.10s/it]

Step: 99, Loss: 0.5201


101it [03:51,  2.54s/it]

Step: 100, Loss: 0.2169


102it [03:54,  2.55s/it]

Step: 101, Loss: 0.4534


103it [03:56,  2.47s/it]

Step: 102, Loss: 0.2842


104it [04:00,  2.73s/it]

Step: 103, Loss: 0.3621


105it [04:02,  2.62s/it]

Step: 104, Loss: 0.3467


106it [04:04,  2.37s/it]

Step: 105, Loss: 0.3809


107it [04:05,  2.12s/it]

Step: 106, Loss: 0.4575


108it [04:07,  1.87s/it]

Step: 107, Loss: 0.5476


109it [04:08,  1.63s/it]

Step: 108, Loss: 0.6074


110it [04:09,  1.55s/it]

Step: 109, Loss: 0.5156


111it [04:11,  1.54s/it]

Step: 110, Loss: 0.4349


112it [04:14,  2.06s/it]

Step: 111, Loss: 0.2050


113it [04:17,  2.39s/it]

Step: 112, Loss: 0.2678


114it [04:20,  2.51s/it]

Step: 113, Loss: 0.2740


115it [04:22,  2.39s/it]

Step: 114, Loss: 0.3375


116it [04:23,  2.13s/it]

Step: 115, Loss: 0.4768


117it [04:24,  1.77s/it]

Step: 116, Loss: 0.4337


118it [04:26,  1.67s/it]

Step: 117, Loss: 0.2871


119it [04:28,  1.85s/it]

Step: 118, Loss: 0.2599


120it [04:31,  2.17s/it]

Step: 119, Loss: 0.3328


121it [04:33,  2.12s/it]

Step: 120, Loss: 0.3840


122it [04:36,  2.37s/it]

Step: 121, Loss: 0.2342


123it [04:42,  3.45s/it]

Step: 122, Loss: 0.3529


124it [04:44,  3.05s/it]

Step: 123, Loss: 0.3461


125it [04:45,  2.46s/it]

Step: 124, Loss: 0.4652


126it [04:47,  2.13s/it]

Step: 125, Loss: 0.4691


127it [04:49,  2.11s/it]

Step: 126, Loss: 0.4560


128it [04:52,  2.55s/it]

Step: 127, Loss: 0.2850


129it [04:55,  2.55s/it]

Step: 128, Loss: 0.3874


130it [04:57,  2.37s/it]

Step: 129, Loss: 0.3819


131it [04:58,  2.16s/it]

Step: 130, Loss: 0.3125


132it [05:01,  2.19s/it]

Step: 131, Loss: 0.3556


133it [05:03,  2.23s/it]

Step: 132, Loss: 0.3058


134it [05:04,  1.97s/it]

Step: 133, Loss: 0.4569


135it [05:10,  3.17s/it]

Step: 134, Loss: 0.2650


136it [05:12,  2.84s/it]

Step: 135, Loss: 0.4267


137it [05:13,  2.31s/it]

Step: 136, Loss: 0.4671


138it [05:16,  2.47s/it]

Step: 137, Loss: 0.3569


139it [05:22,  3.52s/it]

Step: 138, Loss: 0.1920


140it [05:24,  2.90s/it]

Step: 139, Loss: 0.3416


141it [05:25,  2.31s/it]

Step: 140, Loss: 0.5973


142it [05:27,  2.34s/it]

Step: 141, Loss: 0.4247


143it [05:29,  2.17s/it]

Step: 142, Loss: 0.4393


144it [05:30,  1.97s/it]

Step: 143, Loss: 0.3488


145it [05:33,  2.13s/it]

Step: 144, Loss: 0.3611


146it [05:35,  2.05s/it]

Step: 145, Loss: 0.3391


147it [05:37,  2.19s/it]

Step: 146, Loss: 0.3317


148it [05:40,  2.50s/it]

Step: 147, Loss: 0.2046


149it [05:45,  3.20s/it]

Step: 148, Loss: 0.2366


150it [05:51,  3.89s/it]

Step: 149, Loss: 0.2022


151it [05:56,  4.18s/it]

Step: 150, Loss: 0.2088


152it [05:57,  3.43s/it]

Step: 151, Loss: 0.3058


153it [05:59,  2.90s/it]

Step: 152, Loss: 0.4344


154it [06:01,  2.62s/it]

Step: 153, Loss: 0.3904


155it [06:07,  3.63s/it]

Step: 154, Loss: 0.1952


156it [06:09,  3.29s/it]

Step: 155, Loss: 0.2422


157it [06:11,  2.93s/it]

Step: 156, Loss: 0.3751


158it [06:14,  2.94s/it]

Step: 157, Loss: 0.3668


159it [06:19,  3.29s/it]

Step: 158, Loss: 0.2419


160it [06:23,  3.51s/it]

Step: 159, Loss: 0.1933


161it [06:24,  2.81s/it]

Step: 160, Loss: 0.5420


162it [06:30,  3.76s/it]

Step: 161, Loss: 0.2792


163it [06:32,  3.41s/it]

Step: 162, Loss: 0.4187


164it [06:38,  4.19s/it]

Step: 163, Loss: 0.1729


165it [06:40,  3.57s/it]

Step: 164, Loss: 0.3377


166it [06:44,  3.47s/it]

Step: 165, Loss: 0.3302


167it [06:45,  2.81s/it]

Step: 166, Loss: 0.4675


168it [06:47,  2.73s/it]

Step: 167, Loss: 0.2859


169it [06:51,  2.95s/it]

Step: 168, Loss: 0.3157


170it [06:52,  2.49s/it]

Step: 169, Loss: 0.3986


171it [06:54,  2.27s/it]

Step: 170, Loss: 0.3916


172it [07:00,  3.38s/it]

Step: 171, Loss: 0.2749


173it [07:02,  3.07s/it]

Step: 172, Loss: 0.2173


174it [07:04,  2.53s/it]

Step: 173, Loss: 0.5231


175it [07:05,  2.14s/it]

Step: 174, Loss: 0.4075


176it [07:06,  1.95s/it]

Step: 175, Loss: 0.3812


177it [07:10,  2.44s/it]

Step: 176, Loss: 0.3146


178it [07:12,  2.18s/it]

Step: 177, Loss: 0.3490


179it [07:13,  2.02s/it]

Step: 178, Loss: 0.5164


180it [07:15,  1.91s/it]

Step: 179, Loss: 0.4547


181it [07:16,  1.74s/it]

Step: 180, Loss: 0.4718


182it [07:19,  2.11s/it]

Step: 181, Loss: 0.2459


183it [07:22,  2.18s/it]

Step: 182, Loss: 0.4249


184it [07:24,  2.28s/it]

Step: 183, Loss: 0.3007


185it [07:26,  2.04s/it]

Step: 184, Loss: 0.4109


186it [07:27,  1.96s/it]

Step: 185, Loss: 0.3266


187it [07:29,  1.74s/it]

Step: 186, Loss: 0.4284


188it [07:30,  2.40s/it]

Step: 187, Loss: 0.4346
Epoch 0 Complete
Ending training
450.8140070438385
MODEL SAVED
Beginning training iteration





Loaded global model.
Number of batches: 188
Starting training


1it [00:01,  1.49s/it]

Step: 0, Loss: 7.6228


2it [00:07,  4.12s/it]

Step: 1, Loss: 6.7135


3it [00:08,  2.86s/it]

Step: 2, Loss: 2.8488


4it [00:10,  2.24s/it]

Step: 3, Loss: 1.5082


5it [00:12,  2.27s/it]

Step: 4, Loss: 1.1315


6it [00:18,  3.48s/it]

Step: 5, Loss: 0.7766


7it [00:19,  2.75s/it]

Step: 6, Loss: 1.1592


8it [00:25,  3.78s/it]

Step: 7, Loss: 0.4418


9it [00:27,  3.30s/it]

Step: 8, Loss: 0.7658


10it [00:28,  2.61s/it]

Step: 9, Loss: 1.2916


11it [00:30,  2.22s/it]

Step: 10, Loss: 0.7950


12it [00:32,  2.25s/it]

Step: 11, Loss: 0.5272


13it [00:38,  3.30s/it]

Step: 12, Loss: 0.3984


14it [00:42,  3.61s/it]

Step: 13, Loss: 0.4255


15it [00:43,  2.86s/it]

Step: 14, Loss: 0.9261


16it [00:44,  2.30s/it]

Step: 15, Loss: 0.9083


17it [00:46,  2.18s/it]

Step: 16, Loss: 0.6996


18it [00:49,  2.30s/it]

Step: 17, Loss: 0.4509


19it [00:50,  2.05s/it]

Step: 18, Loss: 0.6074


20it [00:51,  1.72s/it]

Step: 19, Loss: 0.9234


21it [00:52,  1.45s/it]

Step: 20, Loss: 0.8263


22it [00:54,  1.55s/it]

Step: 21, Loss: 0.5134


23it [00:57,  2.08s/it]

Step: 22, Loss: 0.3971


24it [00:58,  1.73s/it]

Step: 23, Loss: 0.7521


25it [01:00,  1.77s/it]

Step: 24, Loss: 0.5140


26it [01:01,  1.71s/it]

Step: 25, Loss: 0.4823


27it [01:03,  1.60s/it]

Step: 26, Loss: 0.6217


28it [01:03,  1.38s/it]

Step: 27, Loss: 0.8976


29it [01:04,  1.27s/it]

Step: 28, Loss: 0.6453


30it [01:06,  1.21s/it]

Step: 29, Loss: 0.7886


31it [01:08,  1.45s/it]

Step: 30, Loss: 0.4737


32it [01:09,  1.51s/it]

Step: 31, Loss: 0.5889


33it [01:10,  1.42s/it]

Step: 32, Loss: 0.7233


34it [01:12,  1.53s/it]

Step: 33, Loss: 0.4968


35it [01:14,  1.61s/it]

Step: 34, Loss: 0.3567


36it [01:16,  1.59s/it]

Step: 35, Loss: 0.5077


37it [01:17,  1.40s/it]

Step: 36, Loss: 0.8623


38it [01:18,  1.40s/it]

Step: 37, Loss: 0.5031


39it [01:22,  2.21s/it]

Step: 38, Loss: 0.3990


40it [01:23,  1.80s/it]

Step: 39, Loss: 0.8629


41it [01:25,  1.82s/it]

Step: 40, Loss: 0.5311


42it [01:26,  1.64s/it]

Step: 41, Loss: 0.6585


43it [01:29,  2.14s/it]

Step: 42, Loss: 0.2993


44it [01:30,  1.87s/it]

Step: 43, Loss: 0.5944


45it [01:33,  1.98s/it]

Step: 44, Loss: 0.5897


46it [01:34,  1.72s/it]

Step: 45, Loss: 0.6057


47it [01:35,  1.61s/it]

Step: 46, Loss: 0.5070


48it [01:38,  1.83s/it]

Step: 47, Loss: 0.3230


49it [01:38,  1.51s/it]

Step: 48, Loss: 0.7545


50it [01:40,  1.52s/it]

Step: 49, Loss: 0.4236


51it [01:46,  2.86s/it]

Step: 50, Loss: 0.6092


52it [01:47,  2.27s/it]

Step: 51, Loss: 0.7173


53it [01:48,  1.92s/it]

Step: 52, Loss: 0.7532


54it [01:50,  1.89s/it]

Step: 53, Loss: 0.4835


55it [01:51,  1.66s/it]

Step: 54, Loss: 0.4525


56it [01:53,  1.85s/it]

Step: 55, Loss: 0.3931


57it [01:54,  1.62s/it]

Step: 56, Loss: 0.6630


58it [01:56,  1.54s/it]

Step: 57, Loss: 0.4888


59it [01:58,  1.71s/it]

Step: 58, Loss: 0.4658


60it [01:59,  1.74s/it]

Step: 59, Loss: 0.4851


61it [02:01,  1.71s/it]

Step: 60, Loss: 0.3690


62it [02:02,  1.58s/it]

Step: 61, Loss: 0.6831


63it [02:06,  2.06s/it]

Step: 62, Loss: 0.3128


64it [02:06,  1.71s/it]

Step: 63, Loss: 0.8700


65it [02:08,  1.64s/it]

Step: 64, Loss: 0.5195


66it [02:14,  2.94s/it]

Step: 65, Loss: 0.2842


67it [02:15,  2.40s/it]

Step: 66, Loss: 0.7322


68it [02:21,  3.47s/it]

Step: 67, Loss: 0.3689


69it [02:23,  3.18s/it]

Step: 68, Loss: 0.3582


70it [02:26,  2.89s/it]

Step: 69, Loss: 0.3684


71it [02:27,  2.56s/it]

Step: 70, Loss: 0.5413


72it [02:29,  2.25s/it]

Step: 71, Loss: 0.4987


73it [02:31,  2.07s/it]

Step: 72, Loss: 0.6176


74it [02:32,  1.78s/it]

Step: 73, Loss: 0.6357


75it [02:33,  1.60s/it]

Step: 74, Loss: 0.5691


76it [02:35,  1.80s/it]

Step: 75, Loss: 0.4711


77it [02:36,  1.56s/it]

Step: 76, Loss: 0.7143


78it [02:37,  1.39s/it]

Step: 77, Loss: 0.6331


79it [02:38,  1.29s/it]

Step: 78, Loss: 0.6501


80it [02:40,  1.40s/it]

Step: 79, Loss: 0.5924


81it [02:42,  1.72s/it]

Step: 80, Loss: 0.3536


82it [02:45,  1.95s/it]

Step: 81, Loss: 0.4219


83it [02:46,  1.75s/it]

Step: 82, Loss: 0.5158


84it [02:52,  3.02s/it]

Step: 83, Loss: 0.2674


85it [02:54,  2.58s/it]

Step: 84, Loss: 0.4179


86it [02:56,  2.37s/it]

Step: 85, Loss: 0.4677


87it [02:57,  1.98s/it]

Step: 86, Loss: 0.5812


88it [02:58,  1.73s/it]

Step: 87, Loss: 0.5161


89it [02:59,  1.57s/it]

Step: 88, Loss: 0.5959


90it [03:02,  2.03s/it]

Step: 89, Loss: 0.3097


91it [03:05,  2.36s/it]

Step: 90, Loss: 0.3255


92it [03:07,  2.22s/it]

Step: 91, Loss: 0.4997


93it [03:08,  1.85s/it]

Step: 92, Loss: 0.6379


94it [03:11,  2.07s/it]

Step: 93, Loss: 0.3029


95it [03:12,  1.79s/it]

Step: 94, Loss: 0.5794


96it [03:13,  1.58s/it]

Step: 95, Loss: 0.5677


97it [03:14,  1.44s/it]

Step: 96, Loss: 0.5573


98it [03:15,  1.29s/it]

Step: 97, Loss: 0.6931


99it [03:16,  1.31s/it]

Step: 98, Loss: 0.5967


100it [03:17,  1.20s/it]

Step: 99, Loss: 0.7790


101it [03:18,  1.05s/it]

Step: 100, Loss: 0.7165


102it [03:20,  1.27s/it]

Step: 101, Loss: 0.3855


103it [03:21,  1.19s/it]

Step: 102, Loss: 0.6796


104it [03:23,  1.53s/it]

Step: 103, Loss: 0.3609


105it [03:24,  1.45s/it]

Step: 104, Loss: 0.5762


106it [03:25,  1.32s/it]

Step: 105, Loss: 0.6741


107it [03:31,  2.49s/it]

Step: 106, Loss: 0.2505


108it [03:32,  2.13s/it]

Step: 107, Loss: 0.6324


109it [03:37,  3.03s/it]

Step: 108, Loss: 0.2671


110it [03:39,  2.68s/it]

Step: 109, Loss: 0.4167


111it [03:40,  2.19s/it]

Step: 110, Loss: 0.6589


112it [03:41,  1.94s/it]

Step: 111, Loss: 0.5601


113it [03:47,  3.15s/it]

Step: 112, Loss: 0.5680


114it [03:49,  2.60s/it]

Step: 113, Loss: 0.5474


115it [03:50,  2.18s/it]

Step: 114, Loss: 0.5466


116it [03:52,  2.06s/it]

Step: 115, Loss: 0.4229


117it [03:53,  1.76s/it]

Step: 116, Loss: 0.6069


118it [03:54,  1.68s/it]

Step: 117, Loss: 0.4673


119it [03:55,  1.43s/it]

Step: 118, Loss: 0.6274


120it [03:56,  1.23s/it]

Step: 119, Loss: 0.7933


121it [03:57,  1.20s/it]

Step: 120, Loss: 0.5396


122it [03:58,  1.18s/it]

Step: 121, Loss: 0.5165


123it [03:59,  1.16s/it]

Step: 122, Loss: 0.5638


124it [04:00,  1.17s/it]

Step: 123, Loss: 0.6581


125it [04:02,  1.43s/it]

Step: 124, Loss: 0.4036


126it [04:03,  1.25s/it]

Step: 125, Loss: 0.6374


127it [04:04,  1.26s/it]

Step: 126, Loss: 0.3867


128it [04:07,  1.72s/it]

Step: 127, Loss: 0.3468


129it [04:09,  1.83s/it]

Step: 128, Loss: 0.4422


130it [04:11,  1.86s/it]

Step: 129, Loss: 0.4256


131it [04:13,  1.83s/it]

Step: 130, Loss: 0.4234


132it [04:15,  1.95s/it]

Step: 131, Loss: 0.3599


133it [04:17,  1.77s/it]

Step: 132, Loss: 0.5978


134it [04:18,  1.57s/it]

Step: 133, Loss: 0.6349


135it [04:18,  1.35s/it]

Step: 134, Loss: 0.6409


136it [04:19,  1.21s/it]

Step: 135, Loss: 0.5904


137it [04:21,  1.19s/it]

Step: 136, Loss: 0.6161


138it [04:21,  1.11s/it]

Step: 137, Loss: 0.5500


139it [04:23,  1.18s/it]

Step: 138, Loss: 0.4675


140it [04:26,  1.74s/it]

Step: 139, Loss: 0.3279


141it [04:27,  1.46s/it]

Step: 140, Loss: 0.6146


142it [04:30,  2.03s/it]

Step: 141, Loss: 0.3429


143it [04:32,  1.88s/it]

Step: 142, Loss: 0.4694


144it [04:34,  1.95s/it]

Step: 143, Loss: 0.3888


145it [04:35,  1.70s/it]

Step: 144, Loss: 0.5348


146it [04:36,  1.65s/it]

Step: 145, Loss: 0.5875


147it [04:37,  1.49s/it]

Step: 146, Loss: 0.5188


148it [04:40,  1.93s/it]

Step: 147, Loss: 0.3501


149it [04:43,  2.21s/it]

Step: 148, Loss: 0.3933


150it [04:45,  1.96s/it]

Step: 149, Loss: 0.6019


151it [04:46,  1.82s/it]

Step: 150, Loss: 0.4541


152it [04:47,  1.56s/it]

Step: 151, Loss: 0.6338


153it [04:49,  1.65s/it]

Step: 152, Loss: 0.3164


154it [04:50,  1.61s/it]

Step: 153, Loss: 0.4448


155it [04:52,  1.63s/it]

Step: 154, Loss: 0.4205


156it [04:54,  1.70s/it]

Step: 155, Loss: 0.4712


157it [04:57,  2.16s/it]

Step: 156, Loss: 0.2691


158it [04:59,  1.92s/it]

Step: 157, Loss: 0.5256


159it [05:00,  1.75s/it]

Step: 158, Loss: 0.4510


160it [05:01,  1.52s/it]

Step: 159, Loss: 0.4844


161it [05:02,  1.39s/it]

Step: 160, Loss: 0.4145


162it [05:03,  1.31s/it]

Step: 161, Loss: 0.5871


163it [05:04,  1.22s/it]

Step: 162, Loss: 0.5029


164it [05:06,  1.47s/it]

Step: 163, Loss: 0.3893


165it [05:07,  1.37s/it]

Step: 164, Loss: 0.4413


166it [05:09,  1.33s/it]

Step: 165, Loss: 0.5929


167it [05:10,  1.28s/it]

Step: 166, Loss: 0.6206


168it [05:11,  1.24s/it]

Step: 167, Loss: 0.4165


169it [05:13,  1.62s/it]

Step: 168, Loss: 0.2919


170it [05:18,  2.54s/it]

Step: 169, Loss: 0.2687


171it [05:19,  2.08s/it]

Step: 170, Loss: 0.6450


172it [05:25,  3.25s/it]

Step: 171, Loss: 0.4529


173it [05:26,  2.61s/it]

Step: 172, Loss: 0.4914


174it [05:30,  3.02s/it]

Step: 173, Loss: 0.2925


175it [05:31,  2.38s/it]

Step: 174, Loss: 0.5377


176it [05:32,  1.98s/it]

Step: 175, Loss: 0.5622


177it [05:33,  1.77s/it]

Step: 176, Loss: 0.4374


178it [05:34,  1.53s/it]

Step: 177, Loss: 0.4674


179it [05:35,  1.32s/it]

Step: 178, Loss: 0.5177


180it [05:36,  1.18s/it]

Step: 179, Loss: 0.6737


181it [05:37,  1.24s/it]

Step: 180, Loss: 0.4436


182it [05:38,  1.18s/it]

Step: 181, Loss: 0.5026


183it [05:40,  1.19s/it]

Step: 182, Loss: 0.4198


184it [05:41,  1.16s/it]

Step: 183, Loss: 0.5681


185it [05:43,  1.51s/it]

Step: 184, Loss: 0.3692


186it [05:44,  1.46s/it]

Step: 185, Loss: 0.4195


187it [05:45,  1.35s/it]

Step: 186, Loss: 0.4842


188it [05:47,  1.85s/it]

Step: 187, Loss: 0.3386
Epoch 0 Complete
Ending training
347.58215522766113
MODEL SAVED





In [25]:
importlib.reload(helpers)

from helpers.fl_impl import get_fed_avg_model
from peft import prepare_model_for_kbit_training, get_peft_model

SAVE_PATH = f"./ronit_ibm_outputs/checkpoints/{run_title}"

global_aggregated_model = get_fed_avg_model(SAVE_PATH, model_info, num_clients=2)
torch.cuda.empty_cache()

Device Type: cuda
--------------------------------
Processing client 0
Loading model from ./ronit_ibm_outputs/checkpoints/codeparrotsm_0
Loaded model


AttributeError: ignored

In [34]:
# Dealing with above bug in peft library
weight_dict = {}
for client_idx, client in enumerate(client_rets):
    for k,v in client.state_dict().items():
        if isinstance(v, str):
                weight_dict[k] = v
        else:
            if client_idx == 0:       # For the first client, we need to initialize the value.
                weight_dict[k] = 0
            weight_dict[k] += v/torch.tensor(2)
aggregated_model = AutoModelForCausalLM.from_pretrained(
    model_info['model_name'], quantization_config=model_info['quant_config'], device_map={"": 0}
)
aggregated_model = prepare_model_for_kbit_training(aggregated_model)
global_aggregated_model = get_peft_model(aggregated_model, model_info['lora_config']).to('cuda')

global_aggregated_model.load_state_dict(weight_dict)

<All keys matched successfully>

In [35]:
import importlib

import helpers
import helpers.evaluation
importlib.reload(helpers)
importlib.reload(helpers.evaluation)

from helpers.evaluation import (
    generate_batch_completions,
    get_humaneval_dataloader,
    generate_model_on_problems,
    get_prompt_from_descr,
    clean_samples,
    compute_average_rouge_scores,
    print_gens_on_problem_all_models,
)

In [36]:
import human_eval
importlib.reload(human_eval)
from human_eval.data import write_jsonl, read_problems

problems = read_problems()
probs_dataloader = get_humaneval_dataloader(problems, batch_size=10)

In [37]:
samples = generate_model_on_problems(global_aggregated_model, probs_dataloader, tokenizer)

write_jsonl("human_eval/samples-sf.json", samples)

Device Type: cuda
Generated completions for batch 0 of 17
Device Type: cuda
Generated completions for batch 1 of 17
Device Type: cuda
Generated completions for batch 2 of 17
Device Type: cuda
Generated completions for batch 3 of 17
Device Type: cuda
Generated completions for batch 4 of 17
Device Type: cuda
Generated completions for batch 5 of 17
Device Type: cuda
Generated completions for batch 6 of 17
Device Type: cuda
Generated completions for batch 7 of 17
Device Type: cuda
Generated completions for batch 8 of 17
Device Type: cuda
Generated completions for batch 9 of 17
Device Type: cuda
Generated completions for batch 10 of 17
Device Type: cuda
Generated completions for batch 11 of 17
Device Type: cuda
Generated completions for batch 12 of 17
Device Type: cuda
Generated completions for batch 13 of 17
Device Type: cuda
Generated completions for batch 14 of 17
Device Type: cuda
Generated completions for batch 15 of 17
Device Type: cuda
Generated completions for batch 16 of 17


In [45]:
!python3 human_eval/evaluate_functional_correctness.py human_eval/samples-sf.json

Reading samples...
0it [00:00, ?it/s]164it [00:00, 31829.43it/s]
Running test suites...
100% 164/164 [00:01<00:00, 126.86it/s]
Writing results to human_eval/samples-sf.json_results.jsonl...
100% 164/164 [00:00<00:00, 25607.40it/s]
{'pass@1': 0.0}


In [46]:
references = [problems[f'HumanEval/{i}']['prompt'] for i in range(len(samples))]
hypotheses = [samples[i]['completion'] for i in range(len(samples))]

# Compute average ROUGE scores
average_rouge_scores = compute_average_rouge_scores(references, hypotheses)
print(average_rouge_scores)

{'rouge-1': {'f': 0.181439196656887, 'p': 0.19444999640380556, 'r': 0.19924353054708221}, 'rouge-2': {'f': 0.02186178882053857, 'p': 0.02247974187441289, 'r': 0.0255706528152988}, 'rouge-l': {'f': 0.16800162708168045, 'p': 0.18028544477288624, 'r': 0.18456777342031117}}


In [None]:
print_gens_on_problem_all_models(
    SAVE_PATH,
    "/content/drive/MyDrive/CS6220 Folder/problem_descriptions/p00003.html",
    model_info,
    global_aggregated_model
)

In [48]:
samples = generate_model_on_problems(client_rets[0], probs_dataloader, tokenizer)

write_jsonl("human_eval/samples-client0.json", samples)

!python3 human_eval/evaluate_functional_correctness.py human_eval/samples-client0.json

references = [problems[f'HumanEval/{i}']['prompt'] for i in range(len(samples))]
hypotheses = [samples[i]['completion'] for i in range(len(samples))]

# Compute average ROUGE scores
average_rouge_scores = compute_average_rouge_scores(references, hypotheses)
print(average_rouge_scores)

Device Type: cuda




Generated completions for batch 0 of 17
Device Type: cuda
Generated completions for batch 1 of 17
Device Type: cuda
Generated completions for batch 2 of 17
Device Type: cuda
Generated completions for batch 3 of 17
Device Type: cuda
Generated completions for batch 4 of 17
Device Type: cuda
Generated completions for batch 5 of 17
Device Type: cuda
Generated completions for batch 6 of 17
Device Type: cuda
Generated completions for batch 7 of 17
Device Type: cuda
Generated completions for batch 8 of 17
Device Type: cuda
Generated completions for batch 9 of 17
Device Type: cuda
Generated completions for batch 10 of 17
Device Type: cuda
Generated completions for batch 11 of 17
Device Type: cuda
Generated completions for batch 12 of 17
Device Type: cuda
Generated completions for batch 13 of 17
Device Type: cuda
Generated completions for batch 14 of 17
Device Type: cuda
Generated completions for batch 15 of 17
Device Type: cuda
Generated completions for batch 16 of 17
Reading samples...
164it 

In [52]:
prompt = get_prompt_from_descr("/content/drive/MyDrive/CS6220 Folder/problem_descriptions/p00003.html")
print(generate_batch_completions([prompt], client_rets[0], tokenizer)[0])
print(generate_batch_completions([prompt], client_rets[1], tokenizer)[0])
print(generate_batch_completions([prompt], global_aggregated_model, tokenizer)[0])

Device Type: cuda




"""

Is it a Right Triangle?

Write a program which judges wheather given length of three side form a right triangle. Print "YES" if the given sides (integers) form a right triangle, "NO" if not so.


"""

def main():
    n = int(input())
    if n == 1:
        print("NO")
    elif n == 2:
        print("YES")
    elif n == 3:
        print("NO")
    else:
        print("YES")



Device Type: cuda
"""

Is it a Right Triangle?

Write a program which judges wheather given length of three side form a right triangle. Print "YES" if the given sides (integers) form a right triangle, "NO" if not so.


"""

def main():
    n = int(input())
    if n < 10:
        print(input())
    else:
        print(input())


Device Type: cuda
"""

Is it a Right Triangle?

Write a program which judges wheather given length of three side form a right triangle. Print "YES" if the given sides (integers) form a right triangle, "NO" if not so.


"""

def main():
    print("Yes")




In [49]:
samples = generate_model_on_problems(client_rets[1], probs_dataloader, tokenizer)

write_jsonl("human_eval/samples-client0.json", samples)

!python3 human_eval/evaluate_functional_correctness.py human_eval/samples-client0.json

references = [problems[f'HumanEval/{i}']['prompt'] for i in range(len(samples))]
hypotheses = [samples[i]['completion'] for i in range(len(samples))]

# Compute average ROUGE scores
average_rouge_scores = compute_average_rouge_scores(references, hypotheses)
print(average_rouge_scores)

Device Type: cuda




Generated completions for batch 0 of 17
Device Type: cuda
Generated completions for batch 1 of 17
Device Type: cuda
Generated completions for batch 2 of 17
Device Type: cuda
Generated completions for batch 3 of 17
Device Type: cuda
Generated completions for batch 4 of 17
Device Type: cuda
Generated completions for batch 5 of 17
Device Type: cuda
Generated completions for batch 6 of 17
Device Type: cuda
Generated completions for batch 7 of 17
Device Type: cuda
Generated completions for batch 8 of 17
Device Type: cuda
Generated completions for batch 9 of 17
Device Type: cuda
Generated completions for batch 10 of 17
Device Type: cuda
Generated completions for batch 11 of 17
Device Type: cuda
Generated completions for batch 12 of 17
Device Type: cuda
Generated completions for batch 13 of 17
Device Type: cuda
Generated completions for batch 14 of 17
Device Type: cuda
Generated completions for batch 15 of 17
Device Type: cuda
Generated completions for batch 16 of 17
Reading samples...
164it 