In [5]:
import io
import os
import pandas as pd
import zstandard as zst
import json
import zstandard as zstd
import matplotlib.pyplot as plt
import time 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group

from torch.cuda.amp import GradScaler, autocast
from copy import deepcopy

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline, Cache

In [8]:
# --- Training data / optimisation ---
TEXT_BATCH_SIZE = 4  # 32 here increased training x5 times
TEXT_LEN = 128  # max tokens per text; longer texts are truncated when tokenized
LR = 1e-2  # initial Adam learning rate
L2 = 1e-8  # to regularize Adam with weight_decay
MIN_LR = 1e-5  # lower bound for the per-file learning-rate decay
LR_STEP_RATE = 1  # how many zst files to do lr rate decay

# --- Model sizes ---
LATENT_SIZE = 2048  # width of the black-box hidden states fed to PWNet
NUM_PROTOTYPES = 32000  # passed to PWNet (presumably its prototype count)

# --- Evaluation generation ---
GEN_TEXT_LEN = 128  # max tokens generated per evaluation prompt
TEXT_EVAL_GEN = ['Once upon a time,']
QA_TEXT_EVAL_GEN = ['<|system|>\nYou are a friendly chatbot who always responds in a helpful manner</s>\n<|user|>\nCan you give me some helpful financial advice?</s>\n<|assistant|>\n']

# Prefer GPU 2 as configured, but fall back to GPU 0 on machines with fewer than
# three GPUs ('cuda:2' would otherwise fail with an invalid device ordinal).
DEVICE = ('cuda:2' if torch.cuda.device_count() > 2
          else 'cuda:0' if torch.cuda.is_available()
          else 'cpu')
DIR = 'SlimPajama-627B'  # dataset root; training reads DIR/train/<chunk>/*.zst
k = 10  # top-k for accuracy reporting; NOTE(review): not passed to train_proto_llm, whose own default k is also 10

In [7]:
def plot_training(plotting_acc, plotting_loss, plotting_topk_acc, folder_name='plots/') -> None:
	"""
	Save training-curve figures into `folder_name`.

	Writes two PNG files:
	  * ``plotting_acc_and_topk.png`` - running accuracy and top-k accuracy per iteration
	  * ``plotting_loss.png``         - running loss per iteration

	Args:
		plotting_acc: per-iteration accuracy values.
		plotting_loss: per-iteration loss values.
		plotting_topk_acc: per-iteration top-k accuracy values.
		folder_name: output directory; created if it does not exist yet.
	"""
	# savefig does not create missing directories, so make sure the target exists.
	os.makedirs(folder_name, exist_ok=True)

	# Running accuracy and top-k accuracy on one axes.
	fig, ax = plt.subplots(figsize=(10, 5))
	ax.plot(plotting_acc, label='Running Accuracy', color='b')
	ax.plot(plotting_topk_acc, label='Top-k Accuracy', color='g')
	ax.set_xlabel('Iterations')
	ax.set_ylabel('Accuracy')
	ax.set_title('Training Accuracy')
	ax.legend()
	fig.savefig(os.path.join(folder_name, 'plotting_acc_and_topk.png'))
	plt.close(fig)  # release the figure; this is called repeatedly during training

	# Running loss.
	fig, ax = plt.subplots(figsize=(10, 5))
	ax.plot(plotting_loss, label='Running Loss', color='r')
	ax.set_xlabel('Iterations')
	ax.set_ylabel('Loss')
	ax.set_title('Training Loss')
	ax.legend()
	fig.savefig(os.path.join(folder_name, 'plotting_loss.png'))
	plt.close(fig)
    

def load_data(compressed_file_path) -> pd.DataFrame:
	"""
	Read a local zstd-compressed JSON Lines file (e.g. a SlimPajama ``.zst`` shard)
	into a DataFrame.

	Each non-blank line is decoded as one JSON object and becomes one row; the
	columns are the union of the objects' keys. The whole file is materialised in
	memory, which is fine for single shards but should be revisited for larger runs.

	Args:
		compressed_file_path: path to a local ``.zst`` file.

	Returns:
		pd.DataFrame with one row per JSON record.
	"""
	def read_jsonl_zst(file_path):
		"""Yield one decoded JSON object per non-blank line of a zstd-compressed file."""
		with open(file_path, 'rb') as file:
			reader = zst.ZstdDecompressor().stream_reader(file)
			# Closing the text wrapper also closes the decompression reader beneath it.
			with io.TextIOWrapper(reader, encoding="utf-8") as stream:
				for line in stream:
					if line.strip():  # tolerate blank lines, e.g. a trailing empty line
						yield json.loads(line)

	return pd.DataFrame(list(read_jsonl_zst(compressed_file_path)))
	 

def _greedy_generate(model, head, input_ids, max_new_tokens, eos_token_id):
	"""
	Greedily extend `input_ids` (batch of exactly 1) using `head` as the output layer.

	Each step re-encodes the whole running sequence with the backbone ``model.model``,
	maps the first sequence's hidden states through ``head`` and appends the arg-max
	token of the last position. Stops right after ``eos_token_id`` is appended or
	once ``max_new_tokens`` tokens were added.

	Returns:
		Tensor of shape (1, prompt_len + n_new) on ``model.device``.
	"""
	generated_ids = input_ids
	for _ in range(max_new_tokens):
		hidden = model.model(generated_ids).last_hidden_state[0]
		next_token_id = torch.argmax(head(hidden), dim=1)[-1].to(model.device)
		generated_ids = torch.cat((generated_ids, next_token_id.view(1, 1)), dim=-1)
		if next_token_id == eos_token_id:
			break
	return generated_ids


def generate_text(model, pwnet, tokenizer, max_new_tokens=128, prompts=None, show_black_box=False):
	"""
	Print greedy continuations produced with `pwnet` as the output head.

	For each prompt, the black-box backbone ``model.model`` encodes the running
	sequence and ``pwnet`` maps its hidden states to next-token logits. The arg-max
	token is appended until ``tokenizer.eos_token_id`` is produced or
	``max_new_tokens`` tokens have been added. ``pwnet`` runs in eval mode for the
	duration, and its previous train/eval mode is restored afterwards, even on error.

	Args:
		model: causal LM exposing ``.model`` (backbone), ``.lm_head`` and ``.device``.
		pwnet: module mapping hidden states to vocabulary logits.
		tokenizer: tokenizer matching ``model``.
		max_new_tokens: maximum number of tokens appended per prompt.
		prompts: iterable of prompts; each must tokenize to a single sequence
			(batch size 1). Defaults to ``[TEXT_EVAL_GEN, QA_TEXT_EVAL_GEN]``.
		show_black_box: also generate and print the continuation from
			``model.lm_head`` for comparison (doubles the cost).
	"""
	if prompts is None:
		prompts = [TEXT_EVAL_GEN, QA_TEXT_EVAL_GEN]

	was_training = pwnet.training
	pwnet.eval()
	try:
		with torch.no_grad():
			for prompt in prompts:
				input_ids = tokenizer(prompt, padding=True, return_tensors="pt").input_ids
				input_ids = input_ids.to(model.device)

				if show_black_box:
					bb_ids = _greedy_generate(model, model.lm_head, input_ids, max_new_tokens, tokenizer.eos_token_id)
					print('\nBlack-box text: ============================')
					print(tokenizer.decode(bb_ids[0], skip_special_tokens=False))

				pwnet_ids = _greedy_generate(model, pwnet, input_ids, max_new_tokens, tokenizer.eos_token_id)
				print('\nPW-Net text ================================')
				print(tokenizer.decode(pwnet_ids[0], skip_special_tokens=False))
				print(' ')
	finally:
		pwnet.train(was_training)

In [9]:
def _unwrap_model(module):
	"""Return the inner module of a DataParallel/DistributedDataParallel wrapper, else `module` itself."""
	if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
		return module.module
	return module


def _iter_zst_files(root_data_dir):
	"""Yield paths of `.zst` files inside each immediate sub-directory of `root_data_dir`, in sorted order."""
	for sub_dir in sorted(os.listdir(root_data_dir)):
		sub_dir_path = os.path.join(root_data_dir, sub_dir)
		if not os.path.isdir(sub_dir_path):
			continue
		for zst_file in sorted(os.listdir(sub_dir_path)):
			if zst_file.endswith('.zst'):
				yield os.path.join(sub_dir_path, zst_file)


def _tail_mean(values, window=50):
	"""Mean of the last `window` entries of `values`; NaN when `values` is empty."""
	tail = values[-window:]
	return sum(tail) / len(tail) if tail else float('nan')


def train_proto_llm(tiny_llama, tokenizer, pwnet, k=10):
	"""
	Distil the black-box LM head of `tiny_llama` into `pwnet`.

	Every `.zst` shard under ``DIR/train/<chunk>/`` is visited in sorted order. Its
	texts are tokenized in batches of TEXT_BATCH_SIZE (truncated to TEXT_LEN tokens)
	and encoded by the backbone ``tiny_llama.model``, which runs under no_grad and is
	not optimised. ``pwnet`` is trained with cross-entropy to predict the arg-max
	token of ``tiny_llama.lm_head`` at every non-padding position.

	After each shard:
	  * the learning rate decays (every LR_STEP_RATE shards, floored at MIN_LR);
	  * the metric curves are plotted and summarised;
	  * sample text is generated;
	  * pwnet's weights are saved to ``weights/pwnet_<iteration>.pth``.

	Args:
		tiny_llama: causal LM exposing ``.model`` (backbone) and ``.lm_head``.
		tokenizer: tokenizer matching ``tiny_llama``; must support padding.
		pwnet: module mapping LATENT_SIZE-wide hidden states to vocabulary logits;
			may be wrapped in ``nn.DataParallel``.
		k: k used for the reported top-k accuracy.
	"""
	criterion = nn.CrossEntropyLoss()
	lr_decay_factor = 0.9  # multiplicative LR decay applied every LR_STEP_RATE files
	optimizer = optim.Adam(pwnet.parameters(), lr=LR, weight_decay=L2)
	scaler = GradScaler()

	root_data_dir = DIR + '/train/'
	# torch.save below fails if the checkpoint directory does not exist yet.
	os.makedirs('weights', exist_ok=True)

	plotting_acc = []
	plotting_loss = []
	plotting_topk_acc = []

	optimizer.zero_grad()

	count = 0  # optimizer steps taken so far
	file_count = 0  # .zst files processed so far
	start_time = time.time()

	for file_path in _iter_zst_files(root_data_dir):
		df = load_data(file_path)
		text_data = df.text.values.tolist()
		# Ceil-divide so the final partial batch is trained on instead of dropped.
		num_text_batches = (len(text_data) + TEXT_BATCH_SIZE - 1) // TEXT_BATCH_SIZE

		print(
			"DF Shape:", df.shape,
			"  --len(text data):", len(text_data),
			"  --num text batches:", num_text_batches
			)

		for text_batch_idx in range(num_text_batches):
			text_batch_data = text_data[text_batch_idx * TEXT_BATCH_SIZE: (text_batch_idx + 1) * TEXT_BATCH_SIZE]

			with torch.autocast(device_type='cuda'):
				with torch.no_grad():
					encoded = tokenizer(text_batch_data, max_length=TEXT_LEN, return_tensors="pt", padding=True, truncation=True)
					input_ids = encoded.input_ids.to(DEVICE)
					attention_mask = encoded.attention_mask.to(DEVICE)
					# Pass the mask so padding cannot influence the real tokens' hidden states.
					z = tiny_llama.model(input_ids, attention_mask=attention_mask).last_hidden_state
					bb_logits = tiny_llama.lm_head(z)

			# Flatten to one row per token and keep only real (non-padding) positions,
			# so padding is neither trained on nor counted in the metrics.
			keep = attention_mask.reshape(-1).bool()
			z = z.reshape(-1, LATENT_SIZE)[keep.to(z.device)]
			bb_logits = bb_logits.reshape(-1, bb_logits.size(-1))[keep.to(bb_logits.device)]
			labels = torch.argmax(bb_logits, dim=1)

			with torch.autocast(device_type='cuda'):
				logits = pwnet(z)
				labels = labels.to(logits.device)
				loss = criterion(logits, labels)

			scaler.scale(loss).backward()
			scaler.step(optimizer)
			scaler.update()
			optimizer.zero_grad()

			# Metrics use vectorised on-device ops; only the scalar results leave the GPU.
			with torch.no_grad():
				acc = (logits.argmax(dim=1) == labels).float().mean().item()
				top_k = torch.topk(logits, k, dim=1).indices
				topk_acc = (top_k == labels.unsqueeze(1)).any(dim=1).float().mean().item()

			plotting_loss.append(loss.item())
			plotting_acc.append(acc)
			plotting_topk_acc.append(topk_acc)

			count += 1

		# Plot once per file: rendering and saving two figures after every step is costly.
		plot_training(plotting_acc, plotting_loss, plotting_topk_acc)

		file_count += 1

		# Decrease the learning rate every LR_STEP_RATE .zst files, never below MIN_LR.
		if file_count % LR_STEP_RATE == 0:
			for param_group in optimizer.param_groups:
				param_group['lr'] = max(param_group['lr'] * lr_decay_factor, MIN_LR)
			print(f"Decreased learning rate to {optimizer.param_groups[0]['lr']} after processing {file_count} files")

		print(
			  '\nRun Loss:', round(_tail_mean(plotting_loss), 2),
			  ' -- Acc:', round(_tail_mean(plotting_acc), 2),
			  ' -- Top-'+str(k)+' Accuracy:', round(_tail_mean(plotting_topk_acc), 2),
			  ' -- Iter:', count,
			  ' -- Dir:', file_path,
			  )

		print("\nTime Taken:", time.time() - start_time)

		generate_text(tiny_llama, pwnet, tokenizer, max_new_tokens=GEN_TEXT_LEN)

		# `.module` exists only on (Distributed)DataParallel wrappers; unwrap so single-GPU runs can save too.
		torch.save(_unwrap_model(pwnet).state_dict(), 'weights/pwnet_'+str(count)+'.pth')

In [None]:
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tiny_llama = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): PWNet and PROTOTYPE_SIZE are not defined in the cells shown here;
# confirm they are defined in an earlier cell before running this one.
# Move to DEVICE unconditionally so the single-GPU path also trains on DEVICE.
pwnet = PWNet(LATENT_SIZE, PROTOTYPE_SIZE, NUM_PROTOTYPES, DEVICE).to(DEVICE)
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print("Let's use", n_gpus, "GPUs!")
    # DataParallel requires the module's parameters on device_ids[0] (and gathers
    # outputs there), so list DEVICE first; the default would be cuda:0.
    primary = torch.device(DEVICE).index or 0
    device_ids = [primary] + [i for i in range(n_gpus) if i != primary]
    pwnet = nn.DataParallel(pwnet, device_ids=device_ids)

# Initial evaluation
generate_text(tiny_llama, pwnet, tokenizer, max_new_tokens=GEN_TEXT_LEN)

train_proto_llm(tiny_llama, tokenizer, pwnet)
