In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
from tqdm import tqdm
import json
import os
import time 
import psutil
from datasets import load_dataset
import logging
import math
from torch.utils.tensorboard import SummaryWriter
from transformers import get_linear_schedule_with_warmup
from transformers.examples.research_projects.distillation.lm_seqs_dataset import LmSeqsDataset
logger = logging.getLogger(__name__)


ModuleNotFoundError: No module named 'tensorboard'

In [5]:
def gopher_rules_pass(sample) -> bool:
    """
    Return True if the sample complies with the Gopher-style quality rules below.

    Input:
    ------
        sample: mapping whose "quality_signals" entry is a JSON string. Each signal
            read here is indexed as `[0][2]`, i.e. the third field of its first
            record; "rps_lines_start_with_bulletpoint" holds one record per line.

    Output:
    -------
        bool: False as soon as one rule is violated, True otherwise.
    """
    signals = json.loads(sample["quality_signals"])

    # rule 1: number of words between 50 and 10'000
    word_count = signals["rps_doc_word_count"][0][2]
    if word_count < 50 or word_count > 10_000:
        return False

    # rule 2: mean word length between 3 and 10
    mean_word_length = signals["rps_doc_mean_word_length"][0][2]
    if mean_word_length < 3 or mean_word_length > 10:
        return False

    # rule 3: symbol to word ratio of at most 0.1
    symbol_word_ratio = signals["rps_doc_symbol_to_word_ratio"][0][2]
    if symbol_word_ratio > 0.1:
        return False

    # rule 4: reject documents where more than 90% of the lines start with a bullet point
    n_lines = signals["ccnet_nlines"][0][2]
    if not n_lines:
        # A document reporting no lines cannot be scored; reject it instead of
        # dividing by zero below.
        return False
    n_lines_bulletpoint_start = sum(line[2] for line in signals["rps_lines_start_with_bulletpoint"])
    if n_lines_bulletpoint_start / n_lines > 0.9:
        return False

    # rule 5: the ratio between characters in the most frequent 2-gram and the total number
    # of characters must be at most 0.2
    top_2_gram_frac = signals["rps_doc_frac_chars_top_2gram"][0][2]
    if top_2_gram_frac > 0.2:
        return False

    # further rules: ...

    return True

In [6]:


# Open the RedPajama-Data-V2 "default" configuration in streaming mode,
# restricted to the head_middle partition and five languages.
ds_iterator = load_dataset(
    "togethercomputer/RedPajama-Data-V2",
    name="default",
    partition="head_middle",
    # snapshots=["2023-06", "2022-49"],
    languages=["en", "de", "fr", "es", "it"],
    streaming=True,
)

# Walk the training stream until the first sample that passes the Gopher
# filter, show its decoded "documents" field and stop. `sample` stays bound to
# that sample after the loop.
# NOTE(review): a later cell reads sample["raw_content"]; confirm that the
# samples also carry a JSON-encoded "documents" field.
for sample in ds_iterator["train"]:
    if gopher_rules_pass(sample):
        documents = json.loads(sample["documents"])
        print(documents)
        break
    

NameError: name 'json' is not defined

In [10]:
# Show the raw text of `sample`, the loop variable left bound by the filtering loop.
print(sample["raw_content"])

ABOUT AWB
KIDS ARE KIDS
JOIN THE CAST
<< Back to AWB News
Christine Rouse is honored on the “Today Show”
The executive director of Acting Without Boundaries (AWB), Christine Rouse, was featured on the NBC Today Show with “Kathie Lee and Hoda” on March 1, 2012. The monthly segment, called “Everyone Has A Story,” features one ordinary person that has had a life-changing experience in their own life. Christine submitted an essay describing her life’s mission of increasing awareness of and support for people with disabilities. She described the process of creating the two non-profits she manages – “Kids are Kids,” which provides disability awareness workshops and AWB which provides theater arts opportunities for children, youth and young adults with physical disabilities . Christine talked about the importance of both in increasing inclusion for people, especially young people, with physical disabilities.
The March “Everyone Has A Story” segment featured Christine, her mother, and her brot

In [None]:
class Distiller:
    """
    Knowledge-distillation trainer: fits `student_model` on the temperature-softened
    output distribution of `teacher_model`, optionally combined with language-modeling
    and logit-MSE losses, with gradient accumulation, linear warmup/decay scheduling,
    TensorBoard logging and checkpointing.

    `params` may be a mapping (anything with `.get`) or an attribute-style object
    (e.g. a namespace); every hyper-parameter is looked up through `_param`, so both
    styles work. Required entries: `n_epoch`, `batch_size`. All others have defaults
    (see `__init__`).
    """

    # Sentinel telling `_param` that no default was supplied (the parameter is required).
    _MISSING = object()

    @staticmethod
    def _param(params, name, default=_MISSING):
        """
        Read hyper-parameter `name` from `params`, whichever access style it supports.

        Tries `params.get(name)` first (mappings), then `getattr(params, name)`
        (namespaces). Returns `default` when the parameter is absent; raises
        `AttributeError` when it is absent and no default was given.
        """
        missing = Distiller._MISSING
        getter = getattr(params, "get", None)
        value = getter(name, missing) if callable(getter) else missing
        if value is missing:
            value = getattr(params, name, missing)
        if value is missing:
            if default is missing:
                raise AttributeError(f"Missing required training parameter: {name!r}")
            return default
        return value

    @staticmethod
    def _extract_logits(model_output):
        """
        Return the logits tensor from a model's forward output.

        Accepts a bare tensor, a mapping with a "logits" key, an object with a
        `.logits` attribute, or a sequence whose first item is the logits.
        """
        if torch.is_tensor(model_output):
            return model_output
        if isinstance(model_output, dict):
            return model_output["logits"]
        if hasattr(model_output, "logits"):
            return model_output.logits
        return model_output[0]

    def __init__(self, params, student_model, teacher_model, dataloader):
        """
        Input:
        ------
            params: mapping or namespace of hyper-parameters (see class docstring).
            student_model: the model being trained.
            teacher_model: the model providing the target distribution (never updated here).
            dataloader: sized iterable yielding `(token_ids, lengths)` batches.
        """

        def cfg(name, default=Distiller._MISSING):
            return self._param(params, name, default)

        # Kept so the full configuration can be dumped to TensorBoard.
        self.params = params

        # Loss weights. `alpha_mlm` defaults to 0 because `prepare_batch` only builds
        # causal-LM labels; an MLM term on those labels would train position i to
        # predict token i itself.
        self.alpha_ce = cfg("alpha_ce", 0.5)
        self.alpha_mlm = cfg("alpha_mlm", 0.0)
        self.alpha_clm = cfg("alpha_clm", 0.5)
        self.alpha_mse = cfg("alpha_mse", 0.0)
        self.alpha_cos = cfg("alpha_cos", 0.0)
        if self.alpha_mlm > 0.0 and self.alpha_clm > 0.0:
            logger.warning(
                "Both alpha_mlm and alpha_clm are > 0: the same labels will be used for the "
                "masked and the causal LM loss, which is consistent with at most one of them."
            )

        self.temperature = cfg("temperature", 2.0)

        self.mlm_mask_prob = cfg("mask_prob", 0.15)
        self.word_rand = cfg("word_rand", 0.1)
        self.word_keep = cfg("word_keep", 0.1)
        self.word_mask = cfg("word_mask", 0.8)
        # Tolerant comparison: the three probabilities are floats.
        assert math.isclose(self.word_rand + self.word_keep + self.word_mask, 1.0)

        self.n_epoch = cfg("n_epoch")
        self.batch_size = cfg("batch_size")
        self.gradient_accumulation_steps = cfg("gradient_accumulation_steps", 50)

        self.warmup_prop = cfg("warmup_prop", 0.05)
        self.weight_decay = cfg("weight_decay", 0.0)
        self.learning_rate = cfg("learning_rate", 5e-4)
        self.adam_epsilon = cfg("adam_epsilon", 1e-6)
        self.max_grad_norm = cfg("max_grad_norm", 5.0)
        self.initializer_range = cfg("initializer_range", 0.02)

        # Runtime settings read by the other methods.
        self.dump_path = cfg("dump_path", "serialization_dir")
        self.is_master = bool(cfg("is_master", True))
        self.local_rank = cfg("local_rank", -1)
        self.fp16 = bool(cfg("fp16", False))
        self.mlm = bool(cfg("mlm", False))
        self.log_interval = cfg("log_interval", 500)
        self.checkpoint_interval = cfg("checkpoint_interval", 4000)
        self.special_tok_ids = cfg("special_tok_ids", None)

        self.student = student_model
        self.teacher = teacher_model
        self.student_config = getattr(student_model, "config", None)
        # Vocabulary size for the sanity check in `prepare_batch`; None disables the upper bound.
        self.vocab_size = getattr(self.student_config, "vocab_size", None)
        if self.vocab_size is None:
            self.vocab_size = cfg("vocab_size", None)

        self.dataloader = dataloader

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        # Start the clock now so the first "global/speed" value is an elapsed time.
        self.last_log = time.time()

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
            logger.warning("alpha_cos > 0 but `step` does not compute a cosine loss; it is ignored.")

        logger.info("--- Initializing model optimizer")
        assert self.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / self.gradient_accumulation_steps * self.n_epoch) + 1
        )

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        trainable = [(n, p) for n, p in student_model.named_parameters() if p.requires_grad]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in trainable if not any(nd in n for nd in no_decay)],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [p for n, p in trainable if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum(p.numel() for _, p in trainable)
        )
        logger.info(
            "------ Number of parameters (student): %i" % sum(p.numel() for p in self.student.parameters())
        )
        self.optimizer = AdamW(
            optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon, betas=(0.9, 0.98)
        )

        warmup_steps = math.ceil(num_train_optimization_steps * self.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_optimization_steps
        )

        # Only the master process logs (see `log_tensorboard`), so only it needs a writer.
        self.tensorboard = None
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training", text_string=str(self.params), global_step=0)
            self.tensorboard.add_text(tag="config/student", text_string=str(self.student_config), global_step=0)

    def round_batch(self, x: torch.Tensor, lengths: torch.Tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequence in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            # Randomly drop sentences, then trim to the longest survivor.
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.special_tok_ids is None:
                raise ValueError("`special_tok_ids` must be provided in params to pad fp16 batches.")
            pad_id = self.special_tok_ids["pad_token"] if self.mlm else self.special_tok_ids["unk_token"]
            padding_tensor = torch.full((bs2, pad), pad_id, dtype=torch.long, device=x.device)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def prepare_batch(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids (possibly re-padded by `round_batch`).
            attn_mask: `torch.tensor(bs, seq_length)` - Boolean mask, True on real tokens.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        # Position j of row i is a real token iff j < lengths[i].
        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.clone()
        clm_labels[~attn_mask] = -100  # ignored by `lm_loss_fct` (ignore_index=-100)

        # sanity checks
        assert 0 <= token_ids.min()
        if self.vocab_size is not None:
            assert token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def train(self):
        """
        Run `n_epoch` passes over the dataloader, then save the final checkpoint.
        """
        self.student.train()
        self.teacher.eval()

        for _ in range(self.n_epoch):
            iter_bar = tqdm(self.dataloader, desc="-Iter", disable=self.local_rank not in [-1, 0])
            for batch in iter_bar:
                input_ids, attn_mask, lm_labels = self.prepare_batch(batch)
                self.step(input_ids, attn_mask, lm_labels)

                # No manual `iter_bar.update()`: iterating the bar already advances it.
                iter_bar.set_postfix(
                    {
                        "Last_loss": f"{self.last_loss:.2f}",
                        "Avg_cum_loss": f"{self.total_loss_epoch / max(self.n_iter, 1):.2f}",
                    }
                )
            iter_bar.close()
            logger.info(f"--- Ending epoch {self.epoch}/{self.n_epoch-1}")
            self.end_epoch()

        logger.info("Save very last checkpoint as `pytorch_model.bin`.")
        self.save_checkpoint(checkpoint_name="pytorch_model.bin")
        logger.info("Training is finished")

    def step(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, lm_labels: torch.Tensor):
        """
        Performs a single training step: forward pass of student and teacher,
        loss computation, then `optimize` (backward + possible parameter update).

        Input:
        ------
            input_ids: `torch.tensor(bs, seq_length)` - The token ids.
            attention_mask: `torch.tensor(bs, seq_length)` - True/1 on positions to distill on.
            lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (-100 = ignore).
        """
        # In the causal setting the models are called without a padding mask
        # (as the original code did); the mask is only forwarded for MLM.
        model_mask = attention_mask if self.mlm else None
        s_logits = self._extract_logits(self.student(input_ids=input_ids, attention_mask=model_mask))
        with torch.no_grad():
            t_logits = self._extract_logits(self.teacher(input_ids=input_ids, attention_mask=model_mask))

        assert s_logits.size() == t_logits.size()

        # Keep only the logits of real (non-padding) positions for the distillation loss.
        mask = attention_mask.bool().unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(s_logits, mask).view(-1, s_logits.size(-1))  # (n_selected, voc_size)
        t_logits_slct = torch.masked_select(t_logits, mask).view(-1, s_logits.size(-1))  # (n_selected, voc_size)
        assert t_logits_slct.size() == s_logits_slct.size()

        # KL divergence between softened distributions, rescaled by T^2.
        loss_ce = (
            self.ce_loss_fct(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm

        if self.alpha_clm > 0.0:
            # Logits at position i are scored against the token at position i + 1.
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.0:
            # Sum-reduced MSE divided by the number of selected rows: reproduces a batchmean reduction.
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct) / s_logits_slct.size(0)
            loss += self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation), followed by backward pass on the loss,
        possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.

        Raises:
        -------
            RuntimeError: if the loss contains NaN.
        """
        # Raise instead of calling `exit()` so a notebook kernel survives and the caller sees the cause.
        if torch.isnan(loss).any():
            logger.error("NaN detected")
            raise RuntimeError("NaN detected in the training loss; aborting the training step.")

        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps

        loss.backward()

        self.iter()
        if self.n_iter % self.gradient_accumulation_steps == 0:
            nn.utils.clip_grad_norm_(self.student.parameters(), self.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master or self.tensorboard is None:
            return

        step = self.n_total_iter
        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(
                tag="parameter_mean/" + param_name, scalar_value=param.data.mean(), global_step=step
            )
            self.tensorboard.add_scalar(
                tag="parameter_std/" + param_name, scalar_value=param.data.std(), global_step=step
            )
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(
                tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=step
            )
            self.tensorboard.add_scalar(
                tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=step
            )

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / max(self.n_iter, 1),
            global_step=step,
        )
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=step)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=step)
        if self.alpha_mlm > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=step)
        if self.alpha_clm > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_clm", scalar_value=self.last_loss_clm, global_step=step)
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=step)

        # `get_last_lr` reports the rate set by the latest `scheduler.step()`.
        self.tensorboard.add_scalar(
            tag="learning_rate/lr", scalar_value=self.scheduler.get_last_lr()[0], global_step=step
        )

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=step,
        )
        # Seconds elapsed since the previous log (or since construction for the first one).
        self.tensorboard.add_scalar(
            tag="global/speed", scalar_value=time.time() - self.last_log, global_step=step
        )

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving, then reset the per-epoch counters.
        """
        logger.info(f"{self.n_sequences_epoch} sequences have been trained during this epoch.")

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
            if self.tensorboard is not None:
                # max(..., 1) keeps an empty epoch from dividing by zero.
                self.tensorboard.add_scalar(
                    tag="epoch/loss",
                    scalar_value=self.total_loss_epoch / max(self.n_iter, 1),
                    global_step=self.epoch,
                )

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.

        Writes the student's config (when it has one exposing `save_pretrained`) and its
        `state_dict` under `self.dump_path`, creating the directory if needed.
        """
        if not self.is_master:
            return
        os.makedirs(self.dump_path, exist_ok=True)
        # Unwrap wrappers that expose the real model as `.module`.
        mdl_to_save = self.student.module if hasattr(self.student, "module") else self.student
        config = getattr(mdl_to_save, "config", None)
        if hasattr(config, "save_pretrained"):
            config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
