In [None]:
!pip install -qqq --upgrade optuna transformers==4.47 bitsandbytes peft accelerate datasets nvidia-ml-py3 matplotlib torchmetrics tensorboard

In [None]:
%load_ext tensorboard

# Utils Cell

In [None]:
import gc
import os
import random
import re
import sys
import time
from dataclasses import dataclass, field
from types import NoneType
from typing import Iterable, Mapping, Union, Callable, Optional, List, Dict, Any

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub.utils import enable_progress_bars, disable_progress_bars
from transformers.utils import logging

# Scalar-like types that `detokenized` accepts as innermost token-id values.
Numeric = Union[int, float, torch.Tensor]
class Identity():
  """Null object: calling it or reading any attribute yields a fresh `Identity`.

  `Logger.__getattr__` returns one when handle forwarding is disabled, so chained
  calls made through the logger silently do nothing.
  """
  def __call__(self, *_args, **_kwargs):
    # Arguments are accepted and ignored.
    return Identity()
  def __getattribute__(self, _name):
    # Intercepts every explicit attribute lookup, including ones that would normally succeed.
    return Identity()
class Logger:
  """Print-based logger that can also proxy attribute access to an optional "handle".

  Messages are printed when the logger's verbosity is less than or equal to the
  message level. Any attribute that is not defined on the logger itself is looked
  up on the current handle (for example a matplotlib axis), or swallowed by an
  `Identity` null object when handle forwarding is disabled.
  """
  def __init__(
      self,
      verbosity=logging.INFO,
      enable_progress_bars=True,
      handle_enabled = True,
    ):
    self.set_verbosity(verbosity)
    self.set_progress_bars(enable_progress_bars)
    self.handle_enabled = handle_enabled
  def set_handle(self, handle=None, enable=None):
    """Optionally replace the proxied handle and/or toggle forwarding; None leaves a setting as is."""
    if enable is not None:
      self.handle_enabled = enable
    # Identity check: `!=` could be overloaded by the handle object itself.
    if handle is not None:
      self.handle = handle
  def set_verbosity(self, verbosity):
    """Set the minimum level a message needs in order to be printed."""
    self.verbosity = verbosity
  def set_progress_bars(self, enable):
    """Globally enable or disable huggingface_hub progress bars."""
    if enable:
      enable_progress_bars()
    else:
      disable_progress_bars()
  def is_verbosity(self, verbosity):
    """Return True when a message of level `verbosity` should be printed."""
    return self.verbosity <= verbosity
  def log(self, verbosity, *args, **kwargs):
    """Print `*args` (with `print` keyword arguments) if the level passes the filter."""
    if self.is_verbosity(verbosity):
      print(*args, **kwargs)
  def critical(self, *args, **kwargs):
    self.log(logging.CRITICAL, *args, **kwargs)
  def error(self, *args, **kwargs):
    self.log(logging.ERROR, *args, **kwargs)
  def warning(self, *args, **kwargs):
    self.log(logging.WARNING, *args, **kwargs)
  def info(self, *args, **kwargs):
    self.log(logging.INFO, *args, **kwargs)
  def debug(self, *args, **kwargs):
    self.log(logging.DEBUG, *args, **kwargs)
  def __getattr__(self, name):
    """Forward unknown attributes to the handle.

    Only invoked when normal lookup fails. State is read through `__dict__` so that
    a missing `handle`/`handle_enabled` cannot re-enter this method and recurse.

    Raises:
      AttributeError: if forwarding is enabled but no handle is set, or the handle
        lacks the attribute.
    """
    state = self.__dict__
    if "handle_enabled" in state and not state["handle_enabled"]:
      return Identity()
    handle = state.get("handle")
    if handle is None:
      raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
    return getattr(handle, name)
# logging.CRITICAL or logging.FATAL: only report the most critical errors.
# logging.ERROR: only report errors.
# logging.WARNING or logging.WARN: only report errors and warnings. This is the default level used by the library.
# logging.INFO: report errors, warnings and basic information.
# logging.DEBUG: report all information.
# GENERAL
def exists(obj, attr):
  """Return True if `obj` has attribute `attr` and its value is not None.

  Uses an identity check so that array-like attributes (whose `!=` is element-wise)
  still produce a plain bool.
  """
  return getattr(obj, attr, None) is not None
def notexists(obj, attr):
  """Return True if `obj` lacks attribute `attr` or its value is None.

  Uses an identity check so that array-like attributes (whose `==` is element-wise)
  still produce a plain bool.
  """
  return getattr(obj, attr, None) is None
# DATASET DEBUG
def is_double_iter(data):
  """Return True when `data` is iterable and every element of it is iterable too.

  Note that strings count as iterables, and that `data` is iterated once.
  """
  if not isinstance(data, Iterable):
    return False
  for inner in data:
    if not isinstance(inner, Iterable):
      return False
  return True
def is_good_for_lens(data):
  """Return True when `data` is an iterable whose elements are all iterable and expose `__len__`."""
  if not isinstance(data, Iterable):
    return False
  for inner in data:
    if not isinstance(inner, Iterable):
      return False
    if getattr(inner, "__len__", None) is None:
      return False
  return True
def check_pretty_print_list_of_lists(data):
  """Log `data` one inner element per line when it is doubly iterable, else log it as is."""
  if not is_double_iter(data):
    logger.info(data)
    return
  logger.info("[")
  for row in data:
    logger.info(f"  {row!r},")
  logger.info("]")
def pretty_print_dict(data):
  """Log a mapping as `{`, one `key: value` entry per item, then `}`.

  Values are rendered by `check_pretty_print_list_of_lists`.
  """
  # Opening delimiter now matches the closing "}" (it used to be "[").
  logger.info("{")
  for key, value in data.items():
    # end="" keeps the value on the same line as its key.
    logger.info(f"{key}: ", end="")
    check_pretty_print_list_of_lists(value)
  logger.info("}")
def pretty_print(data):
  """Log `data` with the mapping printer for mappings and the list printer otherwise."""
  printer = pretty_print_dict if isinstance(data, Mapping) else check_pretty_print_list_of_lists
  printer(data)
def detokenized(
    data,
    tokenizer,
    dontconvert=["attention_mask"],
    exclude=["labels"],
    ):
  """Convert nested token ids back to token strings.

  Args:
    data: either a mapping of column name -> value, or an iterable of iterables of
      `Numeric` ids (a batch of id sequences).
    tokenizer: object exposing `convert_ids_to_tokens`.
    dontconvert: mapping keys whose values are copied through untouched (read-only).
    exclude: mapping keys dropped from the result without being processed (read-only).

  Returns:
    For a mapping, a new dict without the excluded keys. For a batch, the list of
    token lists; if any id is negative the batch is returned unchanged, since
    negative ids cannot be mapped to tokens.

  Raises:
    TypeError: if `data` (or a mapping value) is not a supported nested structure.
  """
  if isinstance(data, Mapping):
    converted = {}
    for key, value in data.items():
      # Skip excluded keys *before* converting so an unsupported value cannot raise.
      if key in exclude:
        continue
      if key in dontconvert:
        converted[key] = value
      elif isinstance(value, Iterable):
        # Forward the key lists so nested mappings follow the same rules.
        converted[key] = detokenized(value, tokenizer, dontconvert, exclude)
      else:
        raise TypeError(f"Unsupported value type for key {key!r}: {type(value)}")
    return converted
  if not isinstance(data, Iterable):
    raise TypeError(f"Unknown type: {type(data)}")
  result = None
  if all(isinstance(inner, Iterable) for inner in data):
    if all(isinstance(item, Numeric) for inner in data for item in inner):
      if any(item < 0 for inner in data for item in inner):
        result = data
      else:
        result = list(map(tokenizer.convert_ids_to_tokens, data))
  if result is None:
    # `iter()` makes this work for lists as well as iterators.
    first = next(iter(data), None)
    raise TypeError(f"Unknown nested type: {type(first)}")
  if any(isinstance(inner, str) for inner in data):
    return type(data)(result)
  return result
def print_detokenized(
    data,
    tokenizer,
    dontconvert=["attention_mask"],
    exclude=["labels"],
    ):
  """Detokenize `data` with `detokenized` and log the result with `pretty_print`."""
  readable = detokenized(data, tokenizer, dontconvert, exclude)
  pretty_print(readable)
def lens(data):
  """Return the length of every element of `data`, in order."""
  return list(map(len, data))
def type_shape(data_dic):
  """Summarise each value of a dict.

  Tensors map to `(dtype, shape, device)`, sized nested iterables to
  `(type, list of inner lengths)`, anything else to its type.
  """
  summary = {}
  for key, value in data_dic.items():
    if isinstance(value, torch.Tensor):
      summary[key] = (value.dtype, value.shape, value.device)
    elif is_good_for_lens(value):
      summary[key] = (type(value), lens(value))
    else:
      summary[key] = type(value)
  return summary
def get_random_or_indexed(dataset, index=None, split="train"):
  """Return `(i, dataset[split][i])`, using `index` when it is an int, else a random position."""
  rows = dataset[split]
  if isinstance(index, int):
    chosen = index
  else:
    chosen = random.randrange(len(rows))
  return chosen, rows[chosen]

# SIZE DEBUG
def get_bit_scaling_from_type(_type=None):
  """Return the multiplier that converts a size in bytes into unit `_type`.

  "b" (bits) gives 8, "B" gives 1, and "KB"/"MB"/"GB"/"TB" give the matching
  negative power of 1024.

  Raises:
    ValueError: for any other unit, including None.
  """
  if _type == "b":
    return 8
  units = ("B", "KB", "MB", "GB", "TB")
  if _type not in units:
    raise ValueError(f"Unknown type: {_type}")
  return 1024**(-units.index(_type))
def get_readable_size(size_bits, size_type_to_print=None, join=True):
  """Scale a raw size into the requested unit (default "b").

  Returns "<value> <unit>" when `join` is true, otherwise the tuple `(value, unit)`.
  """
  unit = "b" if size_type_to_print == None else size_type_to_print
  scaled = size_bits * get_bit_scaling_from_type(unit)
  if join:
    return f"{scaled} {unit}"
  return (scaled, unit)
def get_obj_size(obj, gpu_only=False):
  """Return `sys.getsizeof` of one object, measuring a tensor through its storage.

  With `gpu_only` set, everything except CUDA tensors counts as 0. `None` always
  counts as 0.
  """
  if isinstance(obj, torch.Tensor):
    if gpu_only and not obj.is_cuda:
      return 0
    return sys.getsizeof(obj.storage())
  if gpu_only or obj is None:
    return 0
  return sys.getsizeof(obj)
def get_obj_size_rec(obj, gpu_only=False, depth = None):
  """Sum `get_obj_size` over `obj` and the objects reachable from it.

  The object graph is walked level by level through `gc.get_referents`. Each object
  is counted once (tracked by `id`) and class objects are not followed.

  Args:
    obj: root object.
    gpu_only: forwarded to `get_obj_size`.
    depth: maximum number of levels to visit (the root is level 1); None means no
      limit. Values <= 1 measure the root only.

  Returns:
    Total size as reported by `get_obj_size`.
  """
  seen_ids = {id(obj)}
  frontier = [obj]
  total = 0
  while frontier:
    total += sum(get_obj_size(item, gpu_only) for item in frontier)
    next_level = {}
    for referent in gc.get_referents(*frontier):
      ref_id = id(referent)
      # Skip objects that were already counted, and never descend into classes.
      if ref_id not in seen_ids and not isinstance(referent, type):
        next_level[ref_id] = referent
    frontier = list(next_level.values())
    seen_ids.update(next_level)
    if depth is not None:
      depth -= 1
      if depth <= 0:
        break
  return total
def get_readable_obj_size(
    obj,
    gpu_only=False,
    depth = None,
    size_type_to_print=None,
    join=False,
    ):
  """Measure `obj` with `get_obj_size_rec` and format the result with `get_readable_size`."""
  return get_readable_size(
      get_obj_size_rec(obj, gpu_only, depth),
      size_type_to_print,
      join,
  )
def get_objs_size(
    gpu_only=False,
    depth = None,
    size_type_to_print=None,
    increasing = False,
    old_obj_size = None
    ):
  """Report the size of every object tracked by the garbage collector.

  Args:
    gpu_only: forwarded to the size helpers.
    depth: forwarded to the size helpers.
    size_type_to_print: unit for the reported sizes.
    increasing: sort ascending by size when true, descending otherwise.
    old_obj_size: container of object ids to skip (for example a previous result).
      None means "skip nothing".

  Returns:
    Dict mapping object id -> "<type> <size> <unit>", ordered by size. Objects whose
    measured size is 0 are omitted.

  Raises:
    ValueError: if a negative size is measured.
  """
  # A None default avoids sharing one mutable dict between calls.
  known_ids = old_obj_size if old_obj_size is not None else {}
  sizes = {}
  for obj in gc.get_objects():
    obj_id = id(obj)
    if obj_id in known_ids:
      continue
    size, unit = get_readable_obj_size(
        obj,
        gpu_only,
        depth,
        size_type_to_print
    )
    if size < 0:
      raise ValueError(f"Negative object size for {(size, unit)}")
    if size > 0:
      sizes[obj_id] = (type(obj), size, unit)
  ordered = sorted(sizes.items(), key=lambda item: item[1][1], reverse=not increasing)
  return {
      obj_id: " ".join([str(kind), str(size), unit])
      for obj_id, (kind, size, unit) in ordered
  }

#PARAM DEBUG
def get_info_from_param(param):
  """Collect basic facts about a parameter tensor.

  Returns:
    Dict with "device", "dtype", "shape", "numel", "size" (`sys.getsizeof` of the
    storage) and "requires_grad". When the parameter has a gradient, "grad dtype"
    and "grad mean" are added.
  """
  info = {
      "device": param.device,
      "dtype": param.dtype,
      "shape": param.shape,
      "numel": param.numel(),
      "size": sys.getsizeof(param.storage()),
      "requires_grad": param.requires_grad,
  }
  grad = getattr(param, "grad", None)
  if grad is not None:
    # No trailing commas here: they used to turn both values into 1-tuples.
    info["grad dtype"] = grad.dtype
    info["grad mean"] = grad.mean()
  return info
def info_str_param(name, param_info, size_type_to_print):
  """Render "<name>: k=v, ..." for a parameter-info dict.

  The "size" entry is formatted through `get_readable_size`.
  """
  parts = []
  for key, value in param_info.items():
    shown = get_readable_size(value, size_type_to_print) if key == "size" else value
    parts.append(f"{key}={shown}")
  return f"{name}: " + ", ".join(parts)
class BinnedCounter():
    """Accumulates a histogram over fixed bin edges using `torch.bucketize`.

    With `n` edges there are `n + 1` counters: one underflow bucket, `n - 1` interior
    buckets and one overflow bucket.
    """
    def __init__(self, bins=None, right=False, device="cpu"):
        """
        Args:
          bins: either `(start, end, num_edges)` (list or tuple) or a tensor of edges.
          right: forwarded to `torch.bucketize`.
            right=True  -> edge[i-1] <= v <  edge[i]
            right=False -> edge[i-1] <  v <= edge[i]
          device: device holding the edges and the counters.
        """
        self.right = right
        self.device = device
        self.bins, self._bin_diff = self.create_bins(bins)
        self.length = len(self.bins)+1
        self._counts = torch.zeros(self.length).long().to(self.device)
    def create_bins(self, binargs):
        """Build the edge tensor and the bar width used for plotting.

        Returns:
          `(edges, width)`, both on `self.device`; `width` has shape `(1,)`.

        Raises:
          ValueError: if a list/tuple spec does not have exactly 3 entries.
          TypeError: if `binargs` is neither a list/tuple nor a tensor.
        """
        if isinstance(binargs, (list, tuple)):
            if len(binargs) != 3:
                raise ValueError(f"binargs has the wrong length {len(binargs)}")
            num_bins = int(binargs[-1])
            start, end = tuple(map(float, binargs[:-1]))
            bins = torch.linspace(start, end, num_bins)
            bin_diff = (end - start) / num_bins
        elif isinstance(binargs, torch.Tensor):
            bins = binargs
            # Use the local `bins`: `self.bins` is not assigned yet at this point.
            bin_diff = float(bins[1] - bins[0]) if len(bins) > 1 else 0.1
        else:
            raise TypeError(f"binargs has the wrong type {type(binargs)}")
        return bins.to(self.device), torch.tensor([bin_diff]).to(self.device)
    def update(self, data):
        """Add the values of `data` to the running counts."""
        bin_indices = torch.bucketize(data, self.bins, right=self.right)
        # `unique` yields each bucket once, so the indexed in-place add is safe.
        indices, counts = bin_indices.unique(return_inverse=False, return_counts = True)
        self._counts[indices] += counts
    def _bucket_bounds(self, i):
        """Return `(start, end)` of bucket `i`, with infinite outer bounds."""
        start = float("-inf") if i == 0 else self.bins[i-1]
        end = float("inf") if i >= len(self.bins) else self.bins[i]
        return start, end
    def _labels(self):
        """Return "(start)-(end)" range labels, one per bucket."""
        labels = []
        for i in range(self.length):
            start, end = self._bucket_bounds(i)
            labels.append(f"({start:.4f})-({end:.4f})")
        return labels
    def labels(self):
        """Return one edge label per bucket: the closed edge of each interval."""
        labels = []
        for i in range(self.length):
            start, end = self._bucket_bounds(i)
            labels.append(f"{start if self.right else end:.4f}")
        return labels
    def xticks(self, start=True):
        """Return `length` x positions: the edges plus one extra tick before (or after) them."""
        res = None
        if start:
            _start = self.bins[0] - self._bin_diff
            res = torch.concat([_start, self.bins])
        else:
            end = self.bins[-1] + self._bin_diff
            res = torch.concat([self.bins, end])
        return res.cpu()
    def counts(self):
        """Return the counters as a CPU tensor."""
        return self._counts.cpu()
    def bin_diff(self):
        """Return the bar width as a Python float."""
        return self._bin_diff.item()
# Accepted spellings of the `print_trainable` selector, keyed by short name:
# "a" = all, "t" = trainable only, "nt" = non-trainable only, "n" = none.
OPTIONS = {
      "a": ("a", "all"),
      "t": ("t", "train", "trainable"),
      "nt": ("nt", "nontrain", "nontrainable"),
      "n": ("n", "none", None),
  }
def map_named_params(callback, model):
  """Call `callback(name, param)` for every entry of `model.named_parameters()`."""
  for pair in model.named_parameters():
    name, param = pair
    callback(name, param)
def map_named_modules(callback, model):
  """Call `callback(name, module)` for every entry of `model.named_modules()`."""
  for pair in model.named_modules():
    name, module = pair
    callback(name, module)
def print_model_compare_any(model, compare, allOrAny, threshold = 1, grads = False):
  """Log (at debug level) the names of parameters whose values satisfy a comparison.

  Args:
    model: object exposing `named_parameters()`.
    compare: one of "gt", "lt", "eq", "neq", "geq", "leq".
    allOrAny: name of the reduction method called on the comparison result
      (for example "all" or "any").
    threshold: right-hand side of the comparison.
    grads: compare `param.grad` instead of `param.data`.
  """
  _attr = "grad" if grads else "data"
  comparators = {
      "gt": lambda values: values > threshold,
      "lt": lambda values: values < threshold,
      "eq": lambda values: values == threshold,
      "neq": lambda values: values != threshold,
      "geq": lambda values: values >= threshold,
      "leq": lambda values: values <= threshold,
  }
  def print_if_satisfies(name, param):
    # Parameters without the requested attribute (e.g. no grad yet) are skipped.
    if not exists(param, _attr):
      return
    mask = comparators[compare](getattr(param, _attr))
    if getattr(mask, allOrAny)():
      logger.debug(name)
  map_named_params(print_if_satisfies, model)

def print_info_params_and_ret_hist(
    model,
    size_type_to_print=None,
    print_trainable="t",
    binargs = (-1, 1, 100),
    calc_param_counts = True,
    calc_grad_counts = True,
  ):
  """Log per-parameter info and aggregate counts/sizes; optionally bin values into histograms.

  Args:
    model: object exposing `named_parameters()` and a `device` attribute.
    size_type_to_print: unit passed to `get_readable_size`. Either a single unit
      (or None) used everywhere, a 2-sequence `(all, trainable)` where the
      non-trainable unit falls back to `all`, or a 3-sequence
      `(all, trainable, nontrainable)`.
    print_trainable: which parameters to report; one of the spellings in `OPTIONS`
      (case-insensitive). None selects nothing, so only the totals are logged.
    binargs: `BinnedCounter` bin specification. When None, neither the per-parameter
      lines nor the histograms are produced.
    calc_param_counts: accumulate a histogram of parameter values.
    calc_grad_counts: accumulate a histogram of gradient values (only for
      parameters that currently have a gradient).

  Returns:
    `(param_counter, grad_counter)`; each is a `BinnedCounter`, or None when it was
    never updated.

  Raises:
    ValueError: if `size_type_to_print` has an unsupported type or length.
  """
  # Resolve the display unit of each of the three summaries.
  if size_type_to_print is None or isinstance(size_type_to_print, str):
    unit_all = unit_trainable = unit_nontrainable = size_type_to_print
  elif isinstance(size_type_to_print, (tuple, list)):
    if len(size_type_to_print) == 2:
      unit_all, unit_trainable = size_type_to_print
      unit_nontrainable = unit_all
    elif len(size_type_to_print) == 3:
      unit_all, unit_trainable, unit_nontrainable = size_type_to_print
    else:
      raise ValueError(f"size_type_to_print is tuple but has wrong length: {len(size_type_to_print)}")
  else:
    raise ValueError(f"Unknown size_type_to_print: {size_type_to_print}")

  # `OPTIONS["n"]` lists None as a valid selector, so do not call `.lower()` on it.
  mode = print_trainable.lower() if isinstance(print_trainable, str) else print_trainable
  is_trainable_set = mode in (OPTIONS["a"] + OPTIONS["t"])
  is_nontrainable_set = mode in (OPTIONS["a"] + OPTIONS["nt"])

  all_params = 0
  all_size = 0
  trainable_params = 0
  trainable_size = 0
  nontrainable_params = 0
  nontrainable_size = 0
  param_counter, grad_counter = None, None

  def update_counts_sizes_then_print(name, param):
    nonlocal all_params, all_size, trainable_params, trainable_size
    nonlocal nontrainable_params, nontrainable_size, param_counter, grad_counter
    param_info = get_info_from_param(param)
    # A parameter is "selected" when it belongs to a group the caller asked for.
    is_trainable = param.requires_grad and is_trainable_set
    is_nontrainable = (not param.requires_grad) and is_nontrainable_set

    # Totals always cover every parameter, regardless of the selection.
    all_params += param_info["numel"]
    all_size += param_info["size"]
    if is_trainable:
      trainable_params += param_info["numel"]
      trainable_size += param_info["size"]
    if is_nontrainable:
      nontrainable_params += param_info["numel"]
      nontrainable_size += param_info["size"]
    if (is_trainable or is_nontrainable) and binargs is not None:
      unit = unit_trainable if is_trainable else unit_nontrainable
      logger.debug(info_str_param(name, param_info, unit))
      if calc_param_counts:
        if param_counter is None:
          param_counter = BinnedCounter(binargs, device=model.device)
        param_counter.update(param.view(-1))
      if calc_grad_counts and exists(param, "grad"):
        if grad_counter is None:
          grad_counter = BinnedCounter(binargs, device=model.device)
        grad_counter.update(param.grad.view(-1))

  map_named_params(update_counts_sizes_then_print, model)

  def percent(part):
    # Guard against a model without parameters.
    return 100 * part / all_params if all_params else 0.0

  all_size = get_readable_size(all_size, unit_all)
  trainable_size = get_readable_size(trainable_size, unit_trainable)
  nontrainable_size = get_readable_size(nontrainable_size, unit_nontrainable)

  print_str_list = []
  print_str_list.append(f"all params: {all_params}")
  print_str_list.append(f"all size: {all_size}")
  if is_trainable_set:
    print_str_list.append(f"trainable params: {trainable_params}")
    print_str_list.append(f"trainable size: {trainable_size}")
    print_str_list.append(f"trainable%: {percent(trainable_params)}")
  if is_nontrainable_set:
    print_str_list.append(f"nontrainable params: {nontrainable_params}")
    print_str_list.append(f"nontrainable size: {nontrainable_size}")
    print_str_list.append(f"nontrainable%: {percent(nontrainable_params)}")
  logger.info(" || ".join(print_str_list))
  return param_counter, grad_counter

def plot_hist_params(
    param_counter,
    grad_counter,
    binargs = (-1, 1, 100),
    calc_param_counts = True,
    calc_grad_counts = True,
  ):
  """Plot the parameter and gradient histograms as two stacked bar charts.

  Nothing is drawn when both `calc_*` flags are false or `binargs` is None. A
  counter that is None (never updated) leaves its axis empty.

  All matplotlib calls go through `logger`, which proxies them to the current
  handle; when the logger's handle forwarding is disabled they become no-ops.
  """
  def plot_hist_params_ax(ax, counter, title):
    # Point the logger proxy at this axis so the calls below reach `ax`.
    logger.set_handle(ax)
    xticks = counter.xticks()
    yticks = counter.counts()
    logger.set_title(title)
    logger.set_xticks(xticks, counter.labels(), rotation=45)
    logger.set_xlabel(f"Values {'[)' if counter.right else '(]'}")
    logger.set_ylabel("Number of elements")
    logger.bar(xticks, yticks, width=counter.bin_diff())

  if (calc_param_counts or calc_grad_counts) and binargs is not None:
    _fig, (ax1,ax2) = plt.subplots(2,1)
    # Both counters may be None if no parameter was selected / had a gradient.
    if calc_param_counts and param_counter is not None:
      plot_hist_params_ax(ax1, param_counter, "Parameters")
    if calc_grad_counts and grad_counter is not None:
      plot_hist_params_ax(ax2, grad_counter, "Gradients")
    # Switch the proxy to pyplot itself (this also re-enables forwarding).
    logger.set_handle(plt, True)
    logger.gcf().set_size_inches(40,10)
    logger.show()
    logger.clf()
def print_info_params(
    model,
    size_type_to_print=None,
    print_trainable="t",
    binargs = (-1, 1, 100),
    calc_param_counts = True,
    calc_grad_counts = True,
):
  """Log parameter statistics for `model`, then plot the requested histograms.

  Thin wrapper around `print_info_params_and_ret_hist` and `plot_hist_params`;
  see those functions for the meaning of the arguments.
  """
  param_counter, grad_counter = print_info_params_and_ret_hist(
    model,
    size_type_to_print,
    print_trainable,
    binargs,
    # Forward the flags so histograms are not computed when the caller disabled them.
    calc_param_counts,
    calc_grad_counts,
  )
  plot_hist_params(
    param_counter,
    grad_counter,
    binargs,
    calc_param_counts,
    calc_grad_counts,
  )
def compare_models(model1, model2):
  """Compare two models parameter by parameter, pairing them by position.

  Logs the name (from `model1`) of every differing pair at debug level and a
  one-line verdict at info level.

  Returns:
    True when every compared pair is equal according to `torch.equal`.
  """
  mismatched = []
  pairs = zip(model1.named_parameters(), model2.named_parameters())
  for (name1, param1), (name2, param2) in pairs:
    if not torch.equal(param1, param2):
      logger.debug(f"Weights differ: {name1}")
      mismatched.append(name1)
  allEqual = not mismatched
  logger.info("All params equal" if allEqual else "Some weights differ")
  return allEqual

def print_model(model):
  """Log a heading followed by the `model.model` attribute."""
  logger.info("Model architecture")
  inner = model.model
  logger.info(inner)
def find_all_linear_names(model):
  """Return the distinct leaf names of all `torch.nn.Linear` sub-modules, without "lm_head".

  The order of the returned list is unspecified.
  """
  leaf_names = {
      name.rsplit('.', 1)[-1]
      for name, module in model.named_modules()
      if isinstance(module, torch.nn.Linear)
  }
  # needed for 16-bit
  leaf_names.discard('lm_head')
  return list(leaf_names)
# LOSS DEBUG
def error_report(x, y):
  """Log the mean absolute error and mean squared error between `x` and `y`."""
  mean_abs = F.l1_loss(x, y)
  mean_sq = F.mse_loss(x, y)
  report = "\n".join([
      f"Mean absolute error: {mean_abs:>8.5f}",
      f"Mean squared error:  {mean_sq:>8.5f}",
  ])
  logger.info(report)

In [None]:
#gc.set_debug(gc.DEBUG_LEAK)
# Module-wide logger used by the helper functions above; DEBUG lets every message through.
logger = Logger(logging.DEBUG)
# logging.CRITICAL or logging.FATAL: only report the most critical errors.
# logging.ERROR: only report errors.
# logging.WARNING or logging.WARN: only report errors and warnings. This is the default level used by the library.
# logging.INFO: report errors, warnings and basic information.
# logging.DEBUG: report all information.

# Dataset Processing

In [None]:
#from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from datasets import DatasetDict
from transformers import AutoTokenizer
import string
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model id; Processor uses it for both the tokenizer and the causal LM.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
#MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Exclusive upper bound of the random example index that Processor.process logs for each split.
DATA_LENGTHS = 8

class Processor:
  """Turns a raw relation-extraction dataset into chat-formatted, tokenized splits.

  Also lazily loads (and caches) the tokenizer and the causal LM for `model_name`.
  """
  def __init__(self, instruction, model_name, device = {"": 0}):
    # 'device' is a dictionary that maps devices to model parts. In this case, it is set to {"": 0}, which means that the entire model will be loaded on GPU 0.
    # The default dict is only read (passed on as `device_map`), never mutated.
    self.instruction = Processor.make_one_line(instruction)
    self.model_name = model_name
    self.label_key = "labels"
    self.device = device
  @staticmethod
  def make_one_line(word):
    """Replace newlines with spaces and strip spaces/commas from both ends."""
    return " ".join(word.split("\n")).strip(" ,")
  def words_to_sentence(self, words: list[str]) -> str:
    """Join tokens into a sentence, with no space before single punctuation tokens.

    Only the first character is upper-cased; the casing of the rest of the sentence
    (entity names in particular) is preserved.
    """
    if not words:
      return ""
    # Handle spacing logic based on punctuation
    sentence = ""
    for i, word in enumerate(words):
      # Avoid spaces before punctuation
      if i > 0 and sentence and word not in string.punctuation:
        sentence += " "
      sentence += word
    # Capitalize the first letter only. `str.capitalize()` would also lower-case
    # every following character.
    return sentence[:1].upper() + sentence[1:]
  def sample_n(self, dataset, n, shuffle=True, seed=None):
    """Return a DatasetDict holding the first `n` rows of each (optionally shuffled) split.

    Args:
      dataset: mapping of split name -> split supporting `shuffle` and `select`.
      n: one int applied to every split, or an indexable iterable with one entry
        per split (in the iteration order of `dataset`).
      shuffle: shuffle each split with `seed` before selecting.
      seed: seed forwarded to `shuffle`.

    Raises:
      ValueError: if `n` is neither an iterable nor an int.
    """
    ns = None
    if isinstance(n, Iterable):
      ns = n
    elif isinstance(n, int):
      ns = [n]*len(dataset.keys())
    else:
      raise ValueError(f"Unknown n: {n}")
    def select_n(dataset, n):
      return dataset.select(range(n))
    def select_split_n_shuffled(dataset, n, shuffle=True):
      return select_n(dataset.shuffle(seed=seed), n) if shuffle else select_n(dataset, n)
    return DatasetDict(
        {
            split_name: select_split_n_shuffled(split, ns[i], shuffle) for i, (split_name, split) in enumerate(dataset.items())
        }
    )
  def split(self, dataset):
    """Return the `(train, validation, test)` splits of `dataset`."""
    train_dataset = dataset["train"]
    val_dataset = dataset["validation"]
    test_dataset = dataset["test"]
    return train_dataset, val_dataset, test_dataset
  def get_tokenizer(self):
    """Load the tokenizer on first use and return the cached instance afterwards."""
    if exists(self, "tokenizer"):
      return self.tokenizer
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, add_eos_token=True, use_fast=True)
    #self.tokenizer.pad_token = self.tokenizer.unk_token
    #self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
    self.tokenizer.padding_side = 'left' # For decoder-only models
    return self.tokenizer
  def get_model(self):
    """Load the causal LM on first use and return the cached instance afterwards.

    The model is loaded with 4-bit NF4 quantization unless `self.device == "cpu"`.
    """
    if exists(self, "model"):
      return self.model
    if torch.cuda.is_bf16_supported():
      compute_dtype = torch.bfloat16 # BrainFloat 16 bits, can represent 32 bit with less precision (Mixed precision)
    else:
      compute_dtype = torch.float16 # 16 bit

    max_seq_length = 4096

    bnb_config = None
    if self.device != "cpu":
      bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, # num of bits when quantizing
        bnb_4bit_quant_type="nf4",# Quantization type 4 bits NormalFloat
        bnb_4bit_use_double_quant=True,# Use nested quantization
        bnb_4bit_compute_dtype=compute_dtype, # num of bits when computing
        # NOTE(review): `max_seq_length` does not look like a quantization option;
        # confirm that BitsAndBytesConfig actually uses it.
        max_seq_length=max_seq_length,
      )
    attn_implementation = 'eager' # Flash attention not supported on gpus older than ampere

    self.model = AutoModelForCausalLM.from_pretrained(
      self.model_name,
      torch_dtype=compute_dtype,
      trust_remote_code=True,
      device_map=self.device,
      attn_implementation=attn_implementation,
      quantization_config=bnb_config,
    )
    return self.model
  def create_sentence_triples_columns(self, row_raw):
    """Parse one raw row into a sentence and the string form of its single triple."""
    # SECURITY: `eval` executes whatever code the dataset's "text" field contains.
    # Only use this with trusted datasets; if the field is a plain Python/JSON
    # literal, `ast.literal_eval` or `json.loads` is the safe replacement.
    row = eval(row_raw["text"])
    sentence = self.words_to_sentence(row["token"])
    triples = str([(row["h"]["name"], row["relation"], row["t"]["name"])])
    return {"sentence": sentence, self.label_key: triples}
  def create_message_column(self, row, training_mode=True):
    """Build the chat messages for a row; the assistant answer is added only in training mode."""
    messages = []
    user = {
        "content": f"{self.instruction}\n Sentence: {row['sentence']}",
        "role": "user"
    }
    messages.append(user)
    if training_mode:
      assistant = {
          "content": row[self.label_key],
          "role": "assistant"
      }
      messages.append(assistant)
    return {"messages": messages}
  def format_dataset_chatml_and_tokenize(self, row, training_mode=True):
    """Apply the chat template to `row["messages"]` and tokenize (truncated to 512 tokens).

    Reads `self.tokenizer` directly, so `get_tokenizer()` must have been called first.
    A generation prompt is appended only outside training mode.
    """
    return self.tokenizer.apply_chat_template(
      row["messages"],
      add_generation_prompt=not training_mode,
      tokenize=True,
      return_dict=True,
      add_special_tokens=False,
      truncation=True,
      padding=False,
      max_length=512,
      return_overflowing_tokens=False,
      return_length=False,
    )
  def format_dataset_chatml(self, row, training_mode=True):
    """Same as `format_dataset_chatml_and_tokenize`, but returns the formatted text under "messages"."""
    return {"messages": self.tokenizer.apply_chat_template(
      row["messages"],
      add_generation_prompt=not training_mode,
      tokenize=False,
      add_special_tokens=False,
      truncation=True,
      padding=False,
      max_length=512,
      return_overflowing_tokens=False,
      return_length=False,
    )}
  def process(self, dataset):
    """Run the full pipeline on every split and return `(train, validation, test)`.

    Steps: raw text -> sentence/label columns -> chat messages -> tokenized chat
    template. Only the "train" split is processed in training mode; the other
    splits keep their label column and get a generation prompt instead of the answer.
    """
    logger.info("Creating sentence-triples columns")
    dataset = dataset.map(self.create_sentence_triples_columns, remove_columns=dataset["train"].column_names)

    tokenizer = self.get_tokenizer()

    # The same random row of each split is logged after every step, for inspection.
    rndindex = random.randrange(DATA_LENGTHS)

    def get_remove_columns():
      # Reads `data_split` and `training_mode` from the enclosing loop at call time.
      remove_columns = list(dataset[data_split].column_names)
      if not training_mode:
        remove_columns.remove(self.label_key)
      return remove_columns

    for data_split in dataset.keys():
      logger.info(f"Split {data_split}")
      logger.debug(dataset[data_split][rndindex])
      training_mode = data_split=="train"
      remove_columns = get_remove_columns()
      logger.info(f"Mapping columns to 'messages'")
      dataset[data_split] = dataset[data_split].map(
        self.create_message_column,
        remove_columns=remove_columns,
        fn_kwargs={"training_mode": training_mode}
      )
      logger.debug(dataset[data_split][rndindex])
      remove_columns = get_remove_columns()
      logger.info("Formatting to chatml format")
      dataset[data_split] = dataset[data_split].map(
        self.format_dataset_chatml_and_tokenize,
        remove_columns=remove_columns,
        fn_kwargs={"training_mode": training_mode},
      )
      logger.debug(dataset[data_split][rndindex])
    return self.split(dataset)
# Prompt prefix listing the allowed predicates; Processor collapses it to one line
# and prepends it to every sentence.
INSTRUCTION_WITH_PREDICATES = """
List of predicates is ['org:founded', 'org:subsidiaries', 'per:date_of_birth', 'per:cause_of_death',
'per:age', 'per:stateorprovince_of_birth', 'per:countries_of_residence', 'per:country_of_birth',
'per:stateorprovinces_of_residence', 'org:website', 'per:cities_of_residence', 'per:parents',
'per:employee_of', 'NA', 'per:city_of_birth', 'org:parents', 'org:political/religious_affiliation',
'per:schools_attended', 'per:country_of_death', 'per:children', 'org:top_members/employees',
'per:date_of_death', 'org:members', 'org:alternate_names', 'per:religion', 'org:member_of',
'org:city_of_headquarters', 'per:origin', 'org:shareholders', 'per:charges', 'per:title',
'org:number_of_employees/members', 'org:dissolved', 'org:country_of_headquarters', 'per:alternate_names',
'per:siblings', 'org:stateorprovince_of_headquarters', 'per:spouse', 'per:other_family', 'per:city_of_death',
'per:stateorprovince_of_death', 'org:founded_by'].
What Subject-Predicate-Object triples are included in the following sentence?
"""


In [None]:
# Build the processor, then load the model and tokenizer (both are cached on the processor).
processor = Processor(INSTRUCTION_WITH_PREDICATES, MODEL_NAME)
model = processor.get_model()
tokenizer = processor.get_tokenizer()

In [None]:
# Log counts and sizes for all parameters (totals in GB, trainable in MB).
# Both calc flags are off, so no histogram is plotted.
print_info_params(
    model,
    size_type_to_print=("GB", "MB"),
    print_trainable="a",
    calc_param_counts = False,
    calc_grad_counts = False,
)

In [None]:
from transformers import set_seed

# Fix the random seed through transformers' `set_seed` so later cells are repeatable.
set_seed(42)

# DEBUG Cell 1

In [None]:
from transformers.generation.utils import (
    GenerationMode,
    GenerationConfig,
    StoppingCriteriaList,
    LogitsProcessorList,
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    stack_model_outputs,
    _split_model_inputs,
)
from transformers.generation.beam_search import BeamScorer, BeamHypotheses
import inspect
from torch import nn
from collections import UserDict
from transformers.generation.logits_process import LogitsProcessor
from collections import Counter

class CounterPrint:
  """Counts how often each label is reported and, when enabled, logs the running count."""
  def __init__(self, enable=True):
    self.enable = enable
    self.counter = Counter()
  def __call__(self, label, *args, **kwargs):
    # The count is updated even when printing is disabled.
    self.counter.update((label,))
    if not self.enable:
      return
    logger.info(self.counter[label], label, *args, **kwargs)

class TemperatureLogitsWarper(LogitsProcessor):
    """Debug variant of a temperature warper: divides the scores by `temperature`.

    On the call where the internal counter equals 65, the raw and the scaled scores
    are printed in full (this changes torch's global print options).
    """
    def __init__(self, temperature: float):
        is_float = isinstance(temperature, float)
        if not (is_float and temperature > 0):
            message = (
                f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                "scores will be invalid."
            )
            if is_float and temperature == 0.0:
                message += " If you're looking for greedy decoding strategies, set `do_sample=False`."
            raise ValueError(message)

        self.temperature = temperature
        # Number of times the processor has been called so far.
        self.counter = 0
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scaled = scores / self.temperature
        dump_this_call = self.counter == 65
        if dump_this_call:
            # Raise the edge-item limit so the tensors are printed without elision.
            torch.set_printoptions(edgeitems=20000)
            print("scores")
            print(scores)
            print(scaled)
        self.counter += 1
        return scaled

class BeamSearchScorer(BeamScorer):
    """Beam-search bookkeeping: stores finished hypotheses and selects the beams for the next step.

    For every (batch item, beam group) pair the scorer keeps one `BeamHypotheses` container and
    one entry in the boolean `_done` tensor. `process` is called once per decoding step and
    `finalize` once at the end to assemble the output tensors.

    NOTE(review): this looks like a local copy of the transformers `BeamSearchScorer`, kept in the
    notebook with commented-out debug prints -- confirm it matches the pinned transformers version.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[Union[bool, str]] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        max_length: Optional[int] = None,
    ):
        """Create the per-group hypothesis containers and done flags.

        Args:
            batch_size: number of batch items; `batch_size * num_beam_groups` containers are created.
            num_beams: total number of beams per batch item; must be an int > 1.
            device: device on which `_done` (and later `best_scores`) are allocated.
            length_penalty: forwarded to each `BeamHypotheses`.
            do_early_stopping: forwarded to each `BeamHypotheses` as `early_stopping`.
            num_beam_hyps_to_keep: how many hypotheses per batch item `finalize` returns.
            num_beam_groups: number of beam groups; must be an int dividing `num_beams`.
            max_length: forwarded to each `BeamHypotheses`.

        Raises:
            ValueError: if `num_beams` or `num_beam_groups` fail the checks at the end of this method.
        """
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        # NOTE(review): this division runs before the argument validation below, so
        # num_beam_groups == 0 fails here with ZeroDivisionError instead of the ValueError.
        self.group_size = self.num_beams // self.num_beam_groups

        # Not read anywhere else in this class as shown here.
        self._is_init = False
        # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
        # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.group_size,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
                max_length=max_length,
            )
            for _ in range(batch_size * self.num_beam_groups)
        ]
        # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
        # in the i-th mini-batch is complete.
        self._done = torch.tensor(
            [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
        )
        #print(f"(init) self._done: {self._done}, self._beam_hyps: {self._beam_hyps}")
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

    @property
    def is_done(self) -> bool:
        """Whether every (batch item, group) entry of `_done` is set.

        Note: the value returned is the result of `Tensor.all()` (a tensor), which callers use in
        boolean context, rather than a plain Python `bool`.
        """
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        group_index: Optional[int] = 0,
        decoder_prompt_len: Optional[int] = 0,
    ) -> Dict[str, torch.Tensor]:
        """Consume one step of candidates and choose `group_size` continuing beams per batch item.

        Candidates whose token is an EOS token are stored as finished hypotheses (only if they rank
        within the first `group_size` candidates); the others fill the next-step beam slots in order.

        Args:
            input_ids: current sequences; `input_ids.shape[0] // group_size` must equal the batch size.
            next_scores / next_tokens / next_indices: per-batch-item candidate scores, token ids and
                source-beam indices; they are iterated row by row (`[batch_idx]`) in the given order.
            pad_token_id: token written for batch items that are already done.
            eos_token_id: int, list of ints or tensor; non-tensor values are converted to a tensor.
            beam_indices: optional per-beam tuples of earlier beam indices, extended for finished beams.
            group_index: index of the beam group being processed.
            decoder_prompt_len: prompt length subtracted from `cur_len` to get the generated length.

        Returns:
            A `UserDict` with flattened `next_beam_scores`, `next_beam_tokens` and `next_beam_indices`.

        Raises:
            ValueError: on a batch/beam size mismatch, on missing pad/eos ids for a finished batch
                item, or when fewer than `group_size` non-EOS candidates are available.
        """
        # add up to the length which the next_scores is calculated on (including decoder prompt)
        cur_len = input_ids.shape[-1] + 1
        batch_size = len(self._beam_hyps) // self.num_beam_groups

        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )

        # Output buffers, one row per batch item and one column per beam slot of this group.
        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        # Normalise eos_token_id to a tensor so the membership test below works for int/list input.
        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id)

        for batch_idx in range(batch_size):
            # Flat index of this (batch item, group) pair into _beam_hyps / _done.
            batch_group_idx = batch_idx * self.num_beam_groups + group_index
            #print(batch_idx, batch_group_idx, batch_size)
            #print(f"(long before) input_ids.device: {input_ids.device}")
            #print(f"(long before) self._done.device: {self._done.device}")
            #print(f"(long before) self._done.shape: {self._done.shape}")
            #print(f"(long before) self._beam_hyps: {self._beam_hyps}")
            #print(self._done)
            if self._done[batch_group_idx]:
                # NOTE(review): this raises when the container holds MORE than num_beams hypotheses,
                # which reads oddly next to the error message -- confirm the intended condition.
                if self.num_beams < len(self._beam_hyps[batch_group_idx]):
                    raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            # beam_idx counts how many next-step beam slots have been filled for this batch item.
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                # Row of input_ids the candidate extends (offset by this batch item's block of beams).
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (batch_beam_idx,)
                    else:
                        beam_index = None

                    # Store the sequence without the EOS token itself; it is appended in finalize().
                    self._beam_hyps[batch_group_idx].add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                        generated_len=cur_len - decoder_prompt_len,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                    f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            #print(f"(before) self._done: {self._done}, self._beam_hyps: {self._beam_hyps}")
            self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
            )
            #print(f"(after) self._done: {self._done}, self._beam_hyps: {self._beam_hyps}")

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        decoder_prompt_len: Optional[int] = 0,
    ):
        """Close all still-open beams and build the final padded output.

        Beams of groups that are not done are added to their hypothesis container, then the
        `num_beam_hyps_to_keep` highest-scoring hypotheses per batch item are copied into one tensor.

        `final_beam_tokens` and `final_beam_indices` are accepted but not read in this method.

        Returns:
            A `UserDict` with `sequences`, `sequence_scores` and `beam_indices`
            (`beam_indices` is None when the stored hypotheses carry no beam indices).

        Raises:
            ValueError: if hypotheses differ in length and `pad_token_id` is None.
        """
        batch_size = len(self._beam_hyps) // self.num_beam_groups

        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_group_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for index_per_group in range(self.group_size):
                batch_beam_idx = batch_group_idx * self.group_size + index_per_group
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                generated_len = final_tokens.shape[-1] - decoder_prompt_len
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

        # select the best hypotheses
        # sent_lengths is allocated uninitialised; every entry is written in the loop below.
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_indices = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i in range(batch_size):
            # Pool the hypotheses of all groups belonging to batch item i.
            beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
            candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
            # Ascending by score (element 0), so pop() below yields the best remaining hypothesis.
            sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hyp to lists
                best.append(best_hyp)

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        # Uninitialised output buffer; filled with pad below only when lengths differ.
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            if pad_token_id is None:
                raise ValueError("`pad_token_id` has to be defined")
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                # inserting only the first eos_token_id
                # NOTE(review): indexes eos_token_id unconditionally; this would fail if it is None.
                decoded[i, sent_lengths[i]] = eos_token_id[0]

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )

def _beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    **model_kwargs,
):
        """Beam-search decoding loop with heavy debug instrumentation.

        Written as a free function that takes the model as `self`.
        NOTE(review): this appears to be adapted from the transformers beam-search implementation
        (it relies on private helpers such as `_get_initial_cache_position` and
        `_temporary_reorder_cache`) -- confirm it matches the pinned transformers version.

        Args:
            input_ids: 2-D tensor whose first dimension must equal
                `beam_scorer.num_beams * len(beam_scorer._beam_hyps)`.
            beam_scorer: scorer providing `process`, `finalize`, `is_done`, `_beam_hyps`, `_done`.
            logits_processor: applied to the log-softmaxed next-token scores each step.
            stopping_criteria: checked after each step; its `max_length` is passed to `finalize`.
            generation_config: source of pad/eos tensors, output flags, `low_memory`, `do_sample`.
            synced_gpus: when true, a finished peer keeps running forward passes but skips the rest.
            **model_kwargs: forwarded to `prepare_inputs_for_generation` and updated every step.

        Returns:
            `GenerateBeamEncoderDecoderOutput` or `GenerateBeamDecoderOnlyOutput` when
            `return_dict_in_generate` is set, otherwise the `sequences` tensor from `finalize`.

        Raises:
            ValueError: if the batch dimension of `input_ids` does not match the scorer.
            RuntimeError: if `low_memory` is requested for one of the listed model classes.

        Debug side effects:
            * prints the generation config and logits processors once, and input shapes every step;
            * logs many intermediate tensors through `CounterPrint` (which uses the notebook logger);
            * on the iteration where `_tmp_counter == 65`, writes the files `next_token_scores` and
              `next_token_scores_processed` to the current working directory.
            `type_shape` is not defined in this part of the notebook; it is expected to exist elsewhere.
        """
        print("generation_config:")
        print(generation_config.to_dict())
        print("logits_processor:")
        print(logits_processor)
        # init values
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        sequential = generation_config.low_memory
        do_sample = generation_config.do_sample

        # One hypothesis container per batch item (this count is used as the batch size below).
        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False

        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
        # Debug-only iteration counter; incremented at the end of each fully executed loop body.
        _tmp_counter=0
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # A fresh CounterPrint per iteration: the number it logs is the occurrence of a label
            # within THIS iteration (matching the trailing `#1`, `#2`, ... markers below).
            counter_print = CounterPrint()
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            print(f"input_ids: {input_ids.shape}, model_inputs: {type_shape(model_inputs)}, model_kwargs: {type_shape(model_kwargs)}")

            counter_print("beam_scorer._done", beam_scorer._done)#1

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            counter_print("beam_scorer._done", beam_scorer._done)#2

            # if sequential is True, split the input to batches of batch_size and run sequentially
            if sequential:
                if any(
                    model_name in self.__class__.__name__.lower()
                    for model_name in [
                        "fsmt",
                        "reformer",
                        "ctrl",
                        "gpt_bigcode",
                        "transo_xl",
                        "xlnet",
                        "cpm",
                        "jamba",
                    ]
                ):
                    raise RuntimeError(
                        f"Currently generation for {self.__class__.__name__} is not supported "
                        f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
                    )
                counter_print("if sequential")
                counter_print("beam_scorer._done", beam_scorer._done)#3
                inputs_per_sub_batches = _split_model_inputs(
                    model_inputs,
                    split_size=batch_size,
                    full_batch_size=batch_beam_size,
                    config=self.config.get_text_config(),
                )
                outputs_per_sub_batch = [
                    self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                ]

                outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())
            else:  # Unchanged original behavior
                counter_print("else any")
                counter_print("beam_scorer._done", beam_scorer._done)#3
                outputs = self(**model_inputs, return_dict=True)

            counter_print("outputs", type_shape(outputs))#1
            counter_print("beam_scorer._done", beam_scorer._done)#4

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue
            counter_print("beam_scorer._done", beam_scorer._done)#5
            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            # .float() is needed to retain precision for later logits manipulations
            next_token_logits = outputs.logits[:, -1, :].clone().float()
            next_token_logits = next_token_logits.to(input_ids.device)

            counter_print("next_token_logits", next_token_logits)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            # --- debug helpers (re-defined on every iteration) ---
            # noninf: on iteration 65 only, returns a 1-tuple (note the trailing comma) holding the
            # indices of all entries that are not -inf; on every other iteration returns None.
            def noninf(_tensor):
              if _tmp_counter == 65:
                indices = (_tensor != -torch.inf).nonzero()
                #vals = _tensor[indices]
                return indices, #vals
              return None
            # isnan: whether the tensor contains any NaN.
            def isnan(_tensor):
              return _tensor.isnan().any()
            # printtofile: dumps `var` to `filename` with widened print options, then resets
            # edgeitems to 3. Assumes the model exposes a `vocab_size` attribute -- TODO confirm.
            def printtofile(filename, var):
              torch.set_printoptions(edgeitems=self.vocab_size//2)
              with open(filename, "w") as f:
                print(var, end="", file=f)
              torch.set_printoptions(edgeitems=3)
            counter_print(f"_tmp_counter: {_tmp_counter}")
            counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#1
            counter_print("noninf next_token_scores", noninf(next_token_scores))#1
            counter_print("beam_scorer._done", beam_scorer._done)#6
            counter_print("logits_processor", logits_processor)#1
            counter_print("temperature", generation_config.temperature)#1
            counter_print("input_ids", input_ids, input_ids.shape)#1
            counter_print("noninf input_ids", noninf(input_ids))#1
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            # One-shot dump of the scores before/after the logits processors on iteration 65.
            if _tmp_counter == 65:
              printtofile("next_token_scores", next_token_scores)
              printtofile("next_token_scores_processed", next_token_scores_processed)
            counter_print("next_token_scores_processed", next_token_scores_processed, next_token_scores_processed.shape)
            counter_print("noninf next_token_scores_processed", noninf(next_token_scores_processed))#2
            counter_print("beam_scores", beam_scores, beam_scores.shape)#1
            # Add each beam's running score to all of its candidate next-token scores.
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                next_token_scores_processed
            )
            counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#2
            counter_print("noninf next_token_scores", noninf(next_token_scores))#2
            counter_print("beam_scorer._done", beam_scorer._done)#7
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                counter_print("if return_dict_in_generate")
                counter_print("beam_scorer._done", beam_scorer._done)
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
            counter_print("beam_scorer._done", beam_scorer._done)#8
            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#3
            counter_print("noninf next_token_scores", noninf(next_token_scores))#3
            counter_print("beam_scorer._done", beam_scorer._done)#9
            # Keep enough candidates per batch item that num_beams non-EOS continuations remain
            # even if EOS tokens take some of the top spots.
            n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
            if do_sample:
                counter_print("if do_sample")
                counter_print("beam_scorer._done", beam_scorer._done)#10
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                counter_print("isnan probs", isnan(probs))#1
                counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#4
                counter_print("noninf next_token_scores", noninf(next_token_scores))#4
                counter_print("n_tokens_to_keep", n_tokens_to_keep)#1
                # Sample candidates, then sort them by score so the scorer sees them best-first.
                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
                counter_print("next_tokens", next_tokens, next_tokens.shape)#1
                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
                counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#5
                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
                counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#6
                next_tokens = torch.gather(next_tokens, -1, _indices)
                counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#7
                counter_print("next_tokens", next_tokens, next_tokens.shape)#2
                counter_print("beam_scorer._done", beam_scorer._done)#11
            else:
                counter_print("else do_sample")
                counter_print("beam_scorer._done", beam_scorer._done)#10
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
                )
                counter_print("next_token_scores", next_token_scores, next_token_scores.shape)#4
                counter_print("noninf next_token_scores", noninf(next_token_scores))#4
                counter_print("next_tokens", next_tokens, next_tokens.shape)#1
                counter_print("beam_scorer._done", beam_scorer._done)#11
            counter_print("beam_scorer._done", beam_scorer._done)#12
            # Split the flat (beam * vocab) index into the source beam and the token id.
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size
            #print(f"before process")
            #print(f"before process beam_scorer._done: {beam_scorer._done}")
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
                decoder_prompt_len=decoder_prompt_len,
            )
            #print(f"after process")
            #counter_print("beam_scorer._done", beam_scorer._done)

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            counter_print("beam_scorer._done", beam_scorer._done)

            # Reorder the sequences to follow the selected beams and append the chosen tokens.
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
            # (that way the memory peak does not include outputs.logits)
            del outputs
            counter_print("beam_scorer._done", beam_scorer._done)
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            counter_print("beam_scorer._done", beam_scorer._done)
            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1
            counter_print("beam_scorer._done", beam_scorer._done)
            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                this_peer_finished = True
            counter_print("beam_scorer._done", beam_scorer._done)
            _tmp_counter += 1

        # NOTE(review): next_tokens / next_indices are only bound inside the loop, so this call
        # relies on at least one full iteration having run.
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
            decoder_prompt_len=decoder_prompt_len,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return GenerateBeamEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    logits=raw_logits,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateBeamDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    logits=raw_logits,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return sequence_outputs["sequences"]
@torch.no_grad()
def generate_debug(
    self,
    inputs = None,
    generation_config = None,
    logits_processor = None,
    stopping_criteria = None,
    prefix_allowed_tokens_fn = None,
    synced_gpus = None,
    assistant_model = None,
    streamer = None,
    negative_prompt_ids = None,
    negative_prompt_attention_mask = None,
    **kwargs,
):
        """Beam-only debugging variant of the model's `.generate()` entry point.

        Runs the numbered preparation steps below (config, inputs, attention mask, cache,
        logits processors, stopping criteria) and then hands decoding to the notebook-level
        `_beam_search` instead of the model's own decoding loop. Several checks are
        deliberately disabled for debugging: automatic `synced_gpus` detection, the
        right-padding warning, the device-mismatch warning, the `logits_to_keep` shortcut
        and the legacy-cache conversion.

        Args:
            self: model to generate with (this is a free function, so it is passed explicitly).
            inputs: prompt tensor; may instead be given through `kwargs` (e.g. `input_ids`).
            generation_config: base generation config; `kwargs` may override its fields.
            logits_processor: extra logits processors (defaults to an empty list).
            stopping_criteria: extra stopping criteria (defaults to an empty list).
            prefix_allowed_tokens_fn: forwarded to `_get_logits_processor`.
            synced_gpus: forwarded unchanged to `_beam_search` (no auto-detection here).
            assistant_model: only validated and used to determine the generation mode.
            streamer: optional streamer; rejected when `num_beams > 1`.
            negative_prompt_ids, negative_prompt_attention_mask: forwarded to `_get_logits_processor`.
            **kwargs: generation-config overrides and model kwargs (`tokenizer` and
                `assistant_tokenizer` are popped first).

        Returns:
            Whatever `_beam_search` returns for the prepared inputs.

        Raises:
            NotImplementedError: if the resolved generation mode is not beam search / beam sample.
            ValueError: for a non-2D `attention_mask` or a `streamer` combined with beams.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

        # 2. Set generation parameters if not already defined
        # (disabled for debugging: deriving `synced_gpus` from DeepSpeed ZeRO-3 / FSDP state)
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # decoder-only models must use left-padding for batched generation.
        # (disabled for debugging: the warning emitted when right-padding is detected)

        # 4. Define other model kwargs
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            generation_config.use_cache = True

        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config, model_kwargs
            )
        elif kwargs_has_attention_mask:
            # TODO (joao): generalize this check with other types of inputs
            if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
                raise ValueError("`attention_mask` passed to `generate` must be 2D.")

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                device=inputs_tensor.device,
            )
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        # (disabled for debugging: setting `logits_to_keep = 1`, which would restrict the
        # forward pass to the last-token logits)

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if (
            inputs_tensor.shape[1] != input_ids_length
            and model_input_name == "inputs_embeds"
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
        )

        # 8. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        # Only the beam branch is implemented below. Without this guard every other mode
        # used to fall through to `return result` and die with an UnboundLocalError.
        if generation_mode not in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
            raise NotImplementedError(
                "generate_debug only supports beam search / beam sample "
                f"(set `num_beams` > 1); got generation mode {generation_mode}."
            )

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )
        # (disabled for debugging: the warning about `input_ids` living on a different
        # device type than the model)

        # 9. prepare logits processors and stopping criteria
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # Set model_kwargs `use_cache` so we can use it later in forward runs
        model_kwargs["use_cache"] = generation_config.use_cache

        # 11. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam sample
        result = _beam_search(
            self,
            input_ids,
            beam_scorer,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )
        # (disabled for debugging: converting `result.past_key_values` to the legacy cache format)
        return result

In [None]:
%env CUDA_LAUNCH_BLOCKING=1
%env TORCH_USE_CUDA_DSA=1 #doesnt work anymore due to a bug
os.getenv("TORCH_USE_CUDA_DSA"), os.getenv("CUDA_LAUNCH_BLOCKING")

# DEBUG Cell 2

In [None]:
from transformers import pipeline, TextStreamer
from datasets import Dataset
def generate_response():
    """Smoke-test `generate_debug` on three toy prompts and print every intermediate stage.

    Uses the notebook-level `tokenizer` and `model`: builds chat messages from the queries,
    applies the chat template with tokenization, pads the batch, moves it to "cuda", runs
    beam-sample generation and prints the decoded inputs, outputs and answers.

    Returns:
        The value returned by `generate_debug` for the padded batch.
    """
    #torch.set_printoptions(edgeitems=model.vocab_size//2)
    # Raw string on purpose: in a normal literal "\b" is the backspace escape, so "\boxed{}"
    # would be sent to the model as "\x08oxed{}".
    answer_format = r" Please reason step by step, and put your final answer within \boxed{}"
    dic = {
      "query": [
        "Hi how are you doing." + answer_format,
        "What about solving an 2x + 3 = 7 equation?" + answer_format,
        "What about solving an 3x + 5 = 7 equation?" + answer_format,
      ]
    }
    def create_message_column(row):
      # Wrap the raw query into a single-turn chat conversation.
      messages = [
        {
          "role": "user",
          "content": row['query'],
        }
      ]
      return {"messages": messages}
    def format_dataset_chatml_and_tokenize(row):
      # Render the conversation with the chat template and tokenize it in one step.
      return tokenizer.apply_chat_template(
        row["messages"],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        add_special_tokens=False,
        truncation=True,
        padding=False,
        max_length=512,
        return_overflowing_tokens=False,
        return_length=False,
      )

    dataset = Dataset.from_dict(dic)
    print(1,dataset[:])
    dataset = dataset.map(create_message_column, remove_columns=dataset.column_names)
    print(2,dataset[:])
    dataset = dataset.map(format_dataset_chatml_and_tokenize, remove_columns=dataset.column_names)
    print(3,dataset[:])
    batch = tokenizer.pad(dataset[:], return_tensors="pt").to("cuda")
    print(4,batch[:])

    in_decoded = tokenizer.batch_decode(
      batch["input_ids"],
      skip_special_tokens=True,
      clean_up_tokenization_spaces=True,
    )
    # Lengths of the decoded prompts, used below to cut the prompt off each decoded output.
    prompt_lens = lens(in_decoded)
    print("inputs decoded")
    pretty_print(in_decoded)

    genargs = {
        'max_new_tokens': 256,
        'do_sample': True,
        'num_beams': 4,
        'temperature': 0.0012029803766213753,
        'top_p': 0.0012730068237399462,
        'top_k': None,
        'max_time': 180,
    }
    genargs.update(
      dict(
        generation_config=model.generation_config,
        use_cache=True, # Use caching for faster inference
      )
    )
    # To compare against the stock implementation, call
    # model.generate(input_ids=..., attention_mask=..., **genargs) instead.
    outputs = generate_debug(
        model,
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **genargs,
    )

    out_decoded = tokenizer.batch_decode(
      outputs,
      skip_special_tokens=True,
      clean_up_tokenization_spaces=True,
    )
    print("outputs decoded")
    pretty_print(out_decoded)
    answers = [out[prompt_len:] for prompt_len, out in zip(prompt_lens, out_decoded)]
    print("answers")
    pretty_print(answers)
    return outputs
#outputs = generate_response()

In [None]:
import gc

def clean_up_memory(*objs):
  """Log the given objects, run garbage collection and empty the CUDA cache.

  A function cannot delete its caller's variables: an object passed here is only
  reclaimed if the caller holds no other reference to it (unbind the caller's name
  first, e.g. with `del name`). This function only drops its own references before
  collecting.

  Args:
    *objs: objects the caller wants released; used for the log message.
  """
  if len(objs) > 0:
    logger.info(f"Deleting objects:\n{','.join(map(str,objs))}")
  # The previous `for obj in objs: del obj` only unbound the loop variable while the
  # `objs` tuple kept every object alive during collection. Drop the tuple instead.
  del objs
  logger.info("Doing garbage collection")
  gc.collect()
  gc.collect()
  logger.info("Emptying CUDA cache")
  torch.cuda.empty_cache()
  torch.cuda.empty_cache()

In [None]:
# Disabled cell: Hugging Face login with a token read from Colab user data.
# The code is wrapped in a string literal so it does not execute; remove the quotes to enable it.
"""
from google.colab import userdata
from huggingface_hub import login

HF_TOKEN=userdata.get('HF_TOKEN')
login(token=HF_TOKEN)
"""

In [None]:
from datasets import load_dataset

# Identifier passed to load_dataset below.
dataset_name = "xiaobendanyn/tacred"
# Alternative source: a local ./tacred directory.
#dataset_name = "./tacred"

# NOTE(review): no revision is pinned here, so the loaded data may differ between runs.
dataset = load_dataset(dataset_name)

In [None]:
# Inspect the column names of the freshly loaded dataset.
dataset.column_names

In [None]:
# Sample from the loaded dataset (n=DATA_LENGTHS; presumably per-split sizes, TODO confirm)
# and let the processor turn the sample into train/validation/test datasets.
dataset_sample = processor.sample_n(dataset, n=DATA_LENGTHS)
train_dataset, val_dataset, test_dataset = processor.process(dataset_sample)

In [None]:
# Passing `dataset` as an argument cannot free it while the notebook-level name still
# references the object, so unbind the name first and then collect.
del dataset
clean_up_memory()

In [None]:
# Print every row of the training split (full slice). NOTE(review): may be large.
pretty_print(train_dataset[:])

In [None]:
# Print every row of the validation split (full slice). NOTE(review): may be large.
pretty_print(val_dataset[:])

In [None]:
# Print every row of the test split (full slice). NOTE(review): may be large.
pretty_print(test_dataset[:])

In [None]:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from torch.utils.data import DataLoader

def merge_dicts(*dicts):
  """Merge several dicts into one dict that maps each key to the list of its values.

  Values appear in the order of the input dicts; dicts that lack a key contribute
  nothing to that key's list. Keys keep first-seen order (the previous set-based
  implementation produced an arbitrary key order).

  Args:
    *dicts: mappings to merge.

  Returns:
    dict mapping every key found in any input to a list of its values.
  """
  merged = {}
  for d in dicts:
    for key, value in d.items():
      merged.setdefault(key, []).append(value)
  return merged

class TorchData:
  """Wrapper around a builtin container (`_items`) plus metadata attributes (`attrs`).

  Attribute lookups that fail on the wrapper are forwarded first to the wrapped
  container and then to `attrs`. Assigning a name listed in `defaults` stores the
  value in `attrs`. Subclasses provide `apply` (element-wise operator), `to` and
  `remove_with`.
  """
  # Element-wise operators used by the arithmetic dunder methods below.
  ops = {
  "add": lambda a,b: a+b,
  "truediv": lambda a,b: a/b,
  }
  # Metadata attribute names and their default values.
  defaults = {
      "device": "cpu",
  }
  def __init__(self, _type, *args, **kwargs):
    """Build the wrapped container.

    Args:
      _type: container constructor (e.g. `list` or `dict`).
      *args: a single positional argument is passed to `_type` as is; otherwise the
        argument tuple itself is used (or `kwargs` when `_type` is `dict`).
      **kwargs: names listed in `defaults` are taken out as metadata first.
    """
    self.attrs = {}
    for attr, val in TorchData.defaults.items():
      self.attrs[attr] = kwargs.pop(attr, val)
    if len(args) == 1:
      _args = args[0]
    elif _type == dict:
      _args = kwargs
    else:
      _args = args
    self._items = _type(_args)
  def check_to(self, obj):
    """Return `obj.to(self.device)` for tensors/TorchData/BatchEncoding, else `obj` unchanged."""
    if isinstance(obj, (torch.Tensor, TorchData, BatchEncoding)):
      return obj.to(self.device)
    return obj
  def __len__(self):
    return len(self._items)
  def __repr__(self):
    return f"{self.__class__.__name__}({self._items}, {', '.join([f'{k}={v}' for k, v in self.attrs.items()])})"
  def __getattr__(self, attr):
    """Forward unknown attributes to the wrapped container, then to `attrs`."""
    # __getattr__ only runs after normal lookup failed. If `_items`/`attrs` themselves
    # are missing (instance created without __init__, e.g. by copy or pickle), looking
    # them up below would re-enter this method forever, so fail fast instead.
    if attr in ("_items", "attrs"):
      raise AttributeError(attr)
    if hasattr(self._items, attr):
      return getattr(self._items, attr)
    if attr in self.attrs:
      return self.attrs[attr]
    raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {attr!r}")
  def __setattr__(self, attr, val):
    """Store metadata names in `attrs`; everything else becomes a normal attribute."""
    if attr in TorchData.defaults:
      self.attrs[attr] = val
    else:
      super().__setattr__(attr, val)
  def __iter__(self):
    return iter(self._items)
  def __getitem__(self, i):
    return self._items[i]
  def __setitem__(self, i, v):
    self._items[i] = v
  # Add custom operators here
  def __add__(self, other):
    return self.__class__.apply(TorchData.ops["add"], self, other)
  def __radd__(self, other):
    return self.__class__.apply(TorchData.ops["add"], other, self)
  def __truediv__(self, other):
    return self.__class__.apply(TorchData.ops["truediv"], self, other)
  def __rtruediv__(self, other):
    return self.__class__.apply(TorchData.ops["truediv"], other, self)
  def __eq__(self, other):
    """Compare the wrapped containers (metadata is ignored)."""
    if isinstance(other, TorchData):
      return self._items == other._items
    return self._items == other
  def remove(self, *args):
    """Remove the given entries via the subclass-provided `remove_with`."""
    self.remove_with(lambda k: k in args)
class TorchList(TorchData):
  """List-backed `TorchData` with element-wise arithmetic."""
  def __init__(self, *args, **kwargs):
    super().__init__(list, *args, **kwargs)
  def to(self, device):
    """Record `device` on this list and return a new TorchList with movable items moved."""
    self.device = device
    return TorchList(map(self.check_to, self._items), **self.attrs)
  @classmethod
  def apply(cls, operator, operand1, operand2):
    """Apply a binary `operator` element-wise.

    A `Numeric` operand is broadcast to the length of the list operand. The result
    carries the metadata of the `TorchData` operand (left one if both are), so
    reflected operations such as `2 + lst` no longer fall back to default metadata
    and a plain-iterable left operand no longer fails on a missing `attrs`.

    Raises:
      NotImplementedError: if neither operand is a `Numeric` or an instance of `cls`.
    """
    def aggregate(ops):
      return operator(*ops)
    if isinstance(operand1, cls) and isinstance(operand2, Numeric):
      operand2 = [operand2]*len(operand1)
    elif isinstance(operand1, Numeric) and isinstance(operand2, cls):
      operand1 = [operand1]*len(operand2)
    elif not (isinstance(operand1, (Numeric, cls)) or isinstance(operand2, (Numeric, cls))):
      raise NotImplementedError()
    attrs_source = operand1 if isinstance(operand1, TorchData) else operand2
    return TorchList(map(aggregate, zip(operand1, operand2)), **attrs_source.attrs)
  def remove_with(self, condition):
    """Drop every item for which `condition(item)` is true."""
    self._items = [k for k in self._items if not condition(k)]
  @classmethod
  def merge(cls, operand1, operand2):
    """Concatenate two lists; metadata comes from `operand1`."""
    return TorchList(operand1._items + operand2._items, **operand1.attrs)
class TorchDict(TorchData):
  """Dict-backed `TorchData` with key-wise arithmetic."""
  def __init__(self, *args, **kwargs):
    super().__init__(dict, *args, **kwargs)
  def to(self, device):
    """Record `device` on this dict and return a new TorchDict with movable values moved."""
    self.device = device
    return TorchDict({k: self.check_to(v) for k,v in self.items()}, **self.attrs)
  @classmethod
  def apply(cls, operator, operand1, operand2):
    """Apply a binary `operator` key by key.

    A `Numeric` operand is broadcast to every key of the dict operand. Values are
    matched by key (previously they were paired by position, which mixed up values
    when the two dicts listed their keys in a different order). Keys follow the
    order of `operand1`; keys missing from either side are skipped. The result
    carries the metadata of the `TorchData` operand (left one if both are).

    Raises:
      NotImplementedError: if neither operand is a `Numeric` or an instance of `cls`.
    """
    if isinstance(operand1, cls) and isinstance(operand2, Numeric):
      operand2 = dict.fromkeys(operand1.keys(), operand2)
    elif isinstance(operand1, Numeric) and isinstance(operand2, cls):
      operand1 = dict.fromkeys(operand2.keys(), operand1)
    elif not (isinstance(operand1, (Numeric, cls)) or isinstance(operand2, (Numeric, cls))):
      raise NotImplementedError()
    attrs_source = operand1 if isinstance(operand1, TorchData) else operand2
    # Work on the raw mappings so membership tests and lookups are plain dict operations.
    left = operand1._items if isinstance(operand1, TorchData) else operand1
    right = operand2._items if isinstance(operand2, TorchData) else operand2
    combined = {k: operator(left[k], right[k]) for k in left if k in right}
    return TorchDict(combined, **attrs_source.attrs)
  @classmethod
  def merge(cls, operand1, operand2):
    """Return a copy of `operand1` updated with the entries of `operand2`."""
    tmp = TorchDict(operand1._items, **operand1.attrs)
    tmp.update(operand2._items)
    return tmp
  def remove_with(self, condition):
    """Drop every entry whose key satisfies `condition(key)`."""
    self._items = {k: v for k,v in self._items.items() if not condition(k)}
@dataclass
class CustomDefaultCollator:
  """Collator that pads token fields and passes custom fields through unpadded.

  Attributes are documented inline below.
  """
  # Tokenizer whose `pad` method builds the padded batch.
  tokenizer: PreTrainedTokenizerBase
  # Forwarded to `tokenizer.pad`.
  pad_to_multiple_of: Optional[int] = None
  def __call__(
      self,
      examples: List[Union[List[int], Any, Dict[str, Any]]],
      custom_fields = ["labels"],
    ) -> Dict[str, Any]:
    """Collate `examples` into a padded batch with language-modelling labels.

    Args:
      examples: rows (dicts) of one batch.
      custom_fields: keys that are kept out of padding and returned as lists
        (the default list is never mutated).

    Returns:
      The padded batch with a `labels` entry (copy of `input_ids`, padding set to
      -100). If any custom field was present, a `TorchList([batch, custom_fields])`
      is returned instead.
    """
    # Separate token fields and custom fields, custom fields wont be padded
    token_fields_rows_list = [{k: row[k] for k in row if k not in custom_fields} for row in examples]
    custom_fields_rows_list = [{k: row[k] for k in row if k in custom_fields} for row in examples]
    # Pad token fields
    batch = self.tokenizer.pad(token_fields_rows_list, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
    # Merge padded token fields and custom fields
    custom_fields_rows_list = merge_dicts(*custom_fields_rows_list)
    labels = batch["input_ids"].clone()
    if "attention_mask" in batch:
      # Prefer the attention mask: comparing against pad_token_id would also mask real
      # tokens whenever the tokenizer reuses another token (commonly EOS) as padding.
      labels[batch["attention_mask"] == 0] = -100
    elif self.tokenizer.pad_token_id is not None:
      labels[labels == self.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch if (custom_fields_rows_list == {}) else TorchList([batch, custom_fields_rows_list])

#loader = DataLoader(
#  val_dataset,
#  batch_size = 8,
#  collate_fn = CustomDefaultCollator(processor.tokenizer),
#  num_workers = 0,
#  pin_memory = True,
#  persistent_workers = False,
#)
#for elem in loader:
#  logger.debug(elem)

# Trainer API

In [None]:
from torch import nn, optim
from torch.utils.data import DataLoader
from transformers.utils import logging
from torchmetrics import Metric
from torchmetrics.text import BLEUScore, ROUGEScore
from transformers import get_scheduler, DataCollatorForLanguageModeling, BatchEncoding
from peft import get_peft_model, PeftModel, LoftQConfig, LoraConfig, TaskType, prepare_model_for_kbit_training, replace_lora_weights_loftq
from torch.utils.tensorboard import SummaryWriter
from peft.utils.integrations import gather_params_ctx
import copy
import multiprocessing as mp

class CustomMetric:
  """Base class for metrics that are accumulated one (prediction, reference) pair at a time.

  Subclasses implement `forward` (update the running state from one pair) and
  `ret` (return the current metric value).
  """
  def forward(self, raw_text_preds, raw_text_true):
    """Update the metric state from one prediction/reference pair."""
    raise NotImplementedError()
  def __call__(self, preds, true_labels):
    """Feed every pair to `forward` and return `self.ret()`.

    Pairs for which `forward` raises are skipped on purpose (best effort); the number
    of skipped pairs is accumulated in `self.num_failed`.

    Raises:
      ValueError: if `preds` and `true_labels` differ in length.
    """
    length = len(preds)
    if len(true_labels) != length:
      raise ValueError(f"true labels and predictions have different lengths: {len(true_labels)} != {length}")

    # The former multiprocessing scaffolding was removed: its processes were never
    # started, and it failed outright when os.cpu_count() returned None.
    for index in range(length):
      try:
        self.forward(preds[index], true_labels[index])
      except Exception:
        # Keep going so one malformed sample does not abort the evaluation, but
        # make the skip observable instead of swallowing it silently.
        self.num_failed = getattr(self, "num_failed", 0) + 1
    return self.ret()
class MicroF1(CustomMetric):
  """Overlap metric between predicted and reference triples, accumulated over samples.

  NOTE(review): despite the name, `ret` returns accumulated intersection size divided
  by accumulated union size, which is not the usual F1 formula. The formula is kept as is.
  """
  def __init__(self):
    self.tps = 0    # running count of triples present in both prediction and reference
    self.total = 0  # running size of the prediction/reference union
  def forward(self, raw_text_preds, raw_text_true):
    """Parse one prediction/reference pair and update the running counts.

    Both texts must be Python literals (e.g. a list of tuples). Parsing errors
    propagate to the caller, which skips the sample.

    Returns:
      The cleaned prediction text and the untouched reference text.
    """
    import ast
    raw_text_preds = raw_text_preds.strip("` \n")
    if raw_text_preds.startswith("python"):
      raw_text_preds = self.fix(raw_text_preds)
    # SECURITY: predictions are model-generated text, so they must not be passed to
    # eval(); literal_eval only accepts plain Python literals.
    preds_triples = ast.literal_eval(raw_text_preds)
    true_triples = ast.literal_eval(raw_text_true)
    preds_triples_set = set(preds_triples)
    true_triples_set = set(true_triples)
    true_positives = preds_triples_set & true_triples_set # intersection set
    lentp = len(true_positives)
    self.tps += lentp # add intersection cardinal
    # Union size computed from the raw sequence lengths (duplicates are counted).
    self.total += (len(preds_triples) + len(true_triples) - lentp) # add union cardinal
    return raw_text_preds, raw_text_true
  def ret(self):
    """Return the accumulated score, or 0 when nothing has been accumulated."""
    if self.total == 0:
      return 0
    return self.tps / self.total
  def fix(self, raw_text_preds):
    """Drop everything before the first "[" (e.g. a leading "python" language tag)."""
    return raw_text_preds[raw_text_preds.index("["):]
@dataclass
class DataArg:
  """Description of one dataset split as used by ReinforceTrainer."""
  fullname: str  # long split name, e.g. "training"
  alias: str  # short split name, e.g. "train"
  inference_mode: str  # "train" or "eval"; "eval" makes the forward pass run under inference mode
  dataset: Mapping  # the split's data, handed to the DataLoader
  batch_func: Callable  # per-batch routine for this split
  batch_size: int = 0  # batch size handed to the DataLoader
  loader: Optional[DataLoader] = None  # filled in after construction by the trainer's init_loaders
@dataclass
class CustomArg:
  """A component selected by name together with the keyword arguments to build it."""
  name: str  # identifier of the component to construct
  args: dict  # keyword arguments passed when constructing it
@dataclass
class OptimArg(CustomArg):
  """Optimizer choice: `name` is looked up on `torch.optim`, `args` are its keyword arguments."""
  pass
@dataclass
class SchedulerArg(CustomArg):
  """Scheduler choice: `name` and `args` are forwarded to `get_scheduler`."""
  pass
class ReinforceTrainer:
  def __init__(
        self,
        model,
        tokenizer,
        train_dataset,
        val_dataset,
        test_dataset,
        **kwargs,
    ):
    """Set up defaults, logging writer, collator and then delegate to `update`.

    Args:
      model: base model that `update` wraps with the LoRA adapter.
      tokenizer: tokenizer used for collation.
      train_dataset, val_dataset, test_dataset: the three data splits.
      **kwargs: parameter overrides forwarded to `update`.
    """
    # Must come first: __setattr__ consults the defaults created here.
    self.init_params()
    self.tokenizer = tokenizer
    self.adapter_name = "tacred"
    self.writer = SummaryWriter()
    #self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
    self.data_collator = CustomDefaultCollator(tokenizer=tokenizer)
    self.update(
      model=model,
      train_dataset=train_dataset,
      val_dataset=val_dataset,
      test_dataset=test_dataset,
      **kwargs
    )
  def __getattr__(self, attr):
    """Resolve attributes that normal lookup did not find.

    `defaults` falls back to an empty dict (needed before `init_params` has run);
    any other name is looked up in `self.params`.

    Raises:
      AttributeError: if the name is not a known parameter.
    """
    if attr == "defaults":
      return {}
    # If `params` itself is missing (instance created without __init__, e.g. by copy
    # or pickle), `self.params` below would re-enter this method forever.
    if attr == "params":
      raise AttributeError(attr)
    if attr in self.params:
      return self.params[attr]
    raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}")
  def __setattr__(self, attr, val):
    """Route names registered in `self.defaults` into `self.params`; set the rest normally."""
    if attr not in self.defaults:
      super().__setattr__(attr, val)
      return
    self.params[attr] = val
  def init_update(
        self,
        kwargs,
        update_params_dic,
        callback_update=Identity(),
        update_kwargs={},
        callback_init=Identity(),
        init_kwargs={},
      ):
    """Register a group of parameters and run the matching init/update callback.

    Every entry of `update_params_dic` is added to `self.defaults`. A parameter that
    `notexists` reports as missing is set from `kwargs` (falling back to the default
    given here). If at least one parameter was newly set, `callback_init(**init_kwargs)`
    runs; otherwise `callback_update(**update_kwargs)` runs.
    """
    self.defaults.update(update_params_dic)
    #logger.debug(kwargs, self.params, self.defaults)
    def set_if_missing(name, default):
      # Only fills in parameters that do not exist yet; existing ones are left alone here.
      if notexists(self, name):
        self.params[name] = kwargs.get(name, default)
        return True
      return False
    # Evaluate every parameter (no short-circuit) so all missing ones get initialised.
    newly_set = {}
    for name, default in update_params_dic.items():
      newly_set[name] = set_if_missing(name, default)
    logger.debug(f"Updated params: {newly_set}")
    if any(newly_set.values()):
      logger.info(f"Running init {callback_init.__name__}")
      callback_init(**init_kwargs)
      return
    logger.info(f"Running update {callback_update.__name__}")
    callback_update(**update_kwargs)
  def update_params(self, **kwargs):
    """Write every given keyword argument into `self.params`."""
    for name, value in kwargs.items():
      self.params[name] = value
  def init_params(self, **kwargs):
    """Create `self.defaults` and initialise `self.params` from it.

    Args:
      **kwargs: values that replace the corresponding defaults in `self.params`.
    """
    # `defaults` has to be assigned before `params`: __setattr__ looks at it.
    self.defaults = {
        "num_epochs": 1,
        "gradient_accumulation_steps": 1,
        # EVAL
        "calc_loss_on_eval": False,
        # LOSS FUNCTION
        "shift_labels": True,
        "baseline": True,
        "label_smoothing_factor": 0.1,#
        "train_metric": None,#
        "train_metric_key": None,#
        "eval_metric": [MicroF1()],
        "num_samples": 0,#
        "multiple": True,#
        # GEN KWARGS
        "max_new_tokens": 256,
        "do_sample": True,
        "num_beams": 1,#
        "temperature": 0.3,#
        "top_p": 0.95,#
        "max_time": 180,
    }
    initial_params = {}
    for name, default in self.defaults.items():
      initial_params[name] = kwargs[name] if name in kwargs else default
    self.params = initial_params
  def reset_model_params(self):
    """Re-initialise the LoRA adapter weights of `self.model`.

    Every module that has a child whose name contains "lora" is re-initialised
    according to `self.lora_config.init_lora_weights`.
    """
    # to reset all params
    #self.model.init_weights()
    # to reset only lora params
    def is_in_module(key, module):
      # True if any direct child of `module` has `key` in its name.
      for name, _ in module.named_children():
        if key in name:
          return True
      return False
    def reset_lora_weights(module):
      # Dispatch on the configured initialisation scheme.
      init_lora_weights = self.lora_config.init_lora_weights
      if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
        with gather_params_ctx(module.get_base_layer().weight):
          module.pissa_init(self.adapter_name, init_lora_weights)
      elif isinstance(init_lora_weights, str) and init_lora_weights.startswith("corda"):
        with gather_params_ctx(module.get_base_layer().weight):
          module.corda_init(self.adapter_name, init_lora_weights)
      elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
        with gather_params_ctx(module.get_base_layer().weight):
          module.olora_init(self.adapter_name)
      elif init_lora_weights == "loftq":
        with gather_params_ctx(module.get_base_layer().weight):
          module.loftq_init(self.adapter_name)
      elif init_lora_weights == "eva":
        nn.init.zeros_(module.lora_B[self.adapter_name].weight)
      elif init_lora_weights:
        module.reset_lora_parameters(self.adapter_name, init_lora_weights)
    def reset_if_lora(name, module):
      if is_in_module("lora", module):
        logger.debug(name)
        reset_lora_weights(module)
    logger.debug("Reset parameters")
    # Walk the trainer's own model; the previous code used the notebook-global `model`,
    # which only worked when such a global happened to exist and to be the same network.
    map_named_modules(reset_if_lora, self.model)
  def init_model(self, model, **kwargs):
    """Wrap `model` with a fresh LoRA adapter and store it as `self.model`.

    Args:
      model: base model to adapt. If None and the trainer already holds a PeftModel,
        that model is unloaded and its base model is reused.
      **kwargs: ignored.
    """
    if model is None and hasattr(self, "model") and isinstance(self.model, PeftModel):
      # Only announce the unload when it actually happens.
      logger.info(f"Unloading the model")
      model = self.model.unload()
    # Enable gradient checkpointing to save memory
    _model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": True}
    )
    self.lora_config = LoraConfig(
      r=self.lora_r,
      lora_alpha=self.lora_alpha,
      lora_dropout=self.lora_dropout,
      task_type=TaskType.CAUSAL_LM,
      target_modules="all-linear",# list of the modules that should be targeted by LoRA, all linear layers
      use_rslora=True, # Rank-Stabilized LoRA: scale by lora_alpha/sqrt(lora_r) instead of lora_alpha/lora_r
      use_dora=True,
    )

    self.model = get_peft_model(_model, self.lora_config, adapter_name = self.adapter_name)
    # Lora weights are only used for attention and mlp weights
    # Below takes too long, probably longer with more lora weights
    #replace_lora_weights_loftq(self.model)
  def init_optimizer(self, **kwargs):
    """Build `self.optimizer` from `self.optim_arg` (class name on `torch.optim` plus kwargs)."""
    optimizer_cls = getattr(optim, self.optim_arg.name)
    self.optimizer = optimizer_cls(self.model.parameters(), **self.optim_arg.args)
  def init_scheduler(self, **kwargs):
    """Build `self.scheduler` for the current optimizer from `self.scheduler_arg`."""
    spec = self.scheduler_arg
    self.scheduler = get_scheduler(spec.name, optimizer=self.optimizer, **spec.args)
  def get_loader(self, dataarg):
    """Create the DataLoader for one split, using the trainer's collator."""
    loader_options = dict(
      batch_size=dataarg.batch_size,
      collate_fn=self.data_collator,
      num_workers=0,
      pin_memory=True,
      persistent_workers=False,
    )
    return DataLoader(dataarg.dataset, **loader_options)
  def init_loaders(self, train_dataset, val_dataset, test_dataset):
    """Describe the three splits in `self.dataargs` (keyed by alias) and attach their loaders."""
    split_specs = (
      ("training", "train", "train", train_dataset, self.train_on_batch, self.train_batch_size),
      ("validation", "val", "eval", val_dataset, self.eval_on_batch, self.val_batch_size),
      ("test", "test", "eval", test_dataset, self.eval_on_batch, self.test_batch_size),
    )
    # spec[1] is the alias, which doubles as the dictionary key.
    self.dataargs = {spec[1]: DataArg(*spec) for spec in split_specs}
    for dataarg in self.dataargs.values():
      dataarg.loader = self.get_loader(dataarg)
  def update(
      self,
      model=None,
      train_dataset=None,
      val_dataset=None,
      test_dataset=None,
      **kwargs,
    ):
    """(Re)configure loaders, model, optimizer and scheduler, then store all overrides.

    Each `init_update` call below registers one group of parameters with its defaults
    and runs either the group's init callback or its update callback. The order matters:
    the scheduler default needs the training loader, and optimizer/scheduler need the
    model. Finally every entry of `kwargs` is written into `self.params`.

    Args:
      model: forwarded to `init_model` when the LoRA group is initialised.
      train_dataset, val_dataset, test_dataset: forwarded to `init_loaders` when the
        batch-size group is initialised.
      **kwargs: parameter overrides; `batch_size` (default 8) is the default for all
        three per-split batch sizes.
    """
    batch_size = kwargs.get("batch_size", 8)
    # Group 1: batch sizes -> data loaders (no update callback is given for this group).
    self.init_update(
      kwargs,
      {
          "train_batch_size": batch_size,
          "val_batch_size": batch_size,
          "test_batch_size": batch_size,
      },
      callback_init = self.init_loaders,
      init_kwargs = {
        "train_dataset": train_dataset,
        "val_dataset": val_dataset,
        "test_dataset": test_dataset,
      }
    )
    # Group 2: LoRA hyper-parameters -> build the adapter, or reset its weights on later calls.
    self.init_update(
      kwargs,
      {
        "lora_alpha": 32,
        "lora_r": 16,
        "lora_dropout": 0.05,
      },
      callback_init = self.init_model,
      callback_update = self.reset_model_params,
      init_kwargs = {"model": model},
    )
    # Group 3: optimizer; rebuilt on every call.
    self.init_update(
      kwargs,
      {
        "optim_arg": OptimArg(name="AdamW", args={"lr": 1e-8, "betas": (0.9, 0.999), "weight_decay": 1e-2}),
      },
      callback_init = self.init_optimizer,
      callback_update = self.init_optimizer,
    )
    # Group 4: scheduler; rebuilt on every call. The default number of training steps is
    # epochs times the number of batches in the training loader.
    self.init_update(
      kwargs,
      {
        "scheduler_arg": SchedulerArg(
          name="linear",
          args={
            "num_training_steps": kwargs.get("num_epochs", self.num_epochs) * len(self.dataargs["train"].loader),
            "num_warmup_steps": 0,
          }
        ),
      },
      callback_init = self.init_scheduler,
      callback_update = self.init_scheduler,
    )
    # Overrides are stored only now, i.e. after the callbacks above have already run.
    self.update_params(**kwargs)
    logger.debug(f"Default params:\n {self.defaults}")
    logger.info(f"Current params:\n {self.params}")
  def remove_padding_mask(self, labels):
    """Overwrite negative label ids with the tokenizer's pad token id, in place.

    Returns:
      A tuple `(mask, labels)`: the boolean mask of the positions that were negative,
      and the same (now modified) `labels` tensor.
    """
    ignored_positions = labels.lt(0) # (N, L)
    # In-place on purpose: callers keep using their own `labels` reference afterwards.
    labels.masked_fill_(ignored_positions, self.tokenizer.pad_token_id)
    return ignored_positions, labels
  def inference(self, batch):
    """Run a forward pass and return `(logits, labels)`.

    `labels` is popped from `batch` (the batch is modified). When the active split's
    `inference_mode` is "eval" the forward pass runs under `torch.inference_mode()`;
    otherwise it runs with whatever autograd state is current.
    """
    labels = batch.pop("labels")

    def run_forward():
      return self.model(
          batch.input_ids,
          batch.attention_mask,
          use_cache=True,
      )

    if self.dataarg.inference_mode != "eval":
      outputs = run_forward()
    else:
      with torch.inference_mode():
        outputs = run_forward()
    return outputs.logits, labels
  def loss_function(self, batch):
    """params should be (metric=="ce" and num_samples==0) or (metric!="ce" and num_samples>0)"""
    def calc_scale(candidates, references):
      candidates_detokenized = self.tokenizer.batch_decode(
        candidates,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
      )
      references_detokenized = self.tokenizer.batch_decode(
        references,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
      )
      def ret_metric(res):
        return res if self.train_metric_key==None else res[self.train_metric_key]
      if self.multiple:
        res = [ret_metric(self.train_metric([c], [references_detokenized])) for c in candidates_detokenized]
      else:
        res = [ret_metric(self.train_metric([c], [[r]])) for c, r in zip(candidates_detokenized, references_detokenized)]
      return torch.tensor(res).to(self.model.device)
    logits, labels = self.inference(batch)
    if self.shift_labels:
        logits = logits[..., :-1, :].contiguous()# (N, L, S)
        labels = labels[..., 1:].contiguous()# (N, L)
    # log probs
    log_probs = -F.log_softmax(logits, dim=-1) # (N, L, S)
    # padding mask
    padding_mask, _ = self.remove_padding_mask(labels) # (N, L)
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    # smoothed loss
    smoothed_loss = log_probs.sum(dim=-1) # (N, L)
    smoothed_loss.masked_fill_(padding_mask, 0.0)
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
    lpa = None
    metric_adv = None
    if self.num_samples < 0:
      raise ValueError(f"Number of samples {self.num_samples} cannot be negative")
    cond1ce = isinstance(self.train_metric, (nn.CrossEntropyLoss, NoneType))
    cond2ce = isinstance(self.train_metric, str) and self.train_metric.lower() in ("ce", "crossentropyloss")
    if cond1ce or cond2ce:
      if self.num_samples > 0:
        raise ValueError("Cross entropy loss cannot be used with num_samples")
      metric_adv = 1 # N
      labels_index = labels.view(*labels.shape, 1) # (N, L, 1)
      lpa = torch.gather(log_probs, index=labels_index, dim=-1).squeeze(-1) # (N, L)
    elif isinstance(self.train_metric, ROUGEScore) and self.train_metric_key == None:
      raise ValueError("ROUGEScore requires metric_key")
    elif isinstance(self.train_metric, Metric):# Other metrics
      actions_amax = logits.argmax(dim=-1) # (N, L)
      scale_amax = calc_scale(actions_amax, labels) if self.baseline else 0 # N
      scale_sample = None
      if self.num_samples > 0:
        probs = F.softmax(logits, dim=-1) # (N, L, S)
        probs_view_2d = probs.view(-1, logits.shape[-1]) # (N*L, S)

        actions_sample_2d = torch.multinomial(probs_view_2d, self.num_samples, replacement=True) # (N*L, R)
        actions_sample_2d_T = actions_sample_2d.T # (R, N*L)
        actions_sample = actions_sample_2d_T.reshape(-1, labels.shape[-1]) # (R*N, L)

        labels_sample = labels.expand(self.num_samples, *labels.shape).reshape(-1, labels.shape[-1]) # (R*N, L)
        scale_sample = calc_scale(actions_sample, labels_sample).view(self.num_samples,-1) # (R, N)
        # -(Q - b)
        metric_adv = scale_amax - scale_sample # (R,N)
        actions_sample_index = actions_sample.view(self.num_samples, *labels.shape,1) # (R, N, L, 1)
        log_probs_expanded = log_probs.expand(self.num_samples, *log_probs.shape) # (R, N, L, S)
        lpa = torch.gather(log_probs_expanded, index=actions_sample_index, dim=-1).squeeze(-1) # (R, N, L)
        num_active_elements *= self.num_samples
      else:
        raise ValueError("Metrics other than cross-entropy require num_samples > 0")
    else:
      raise ValueError("Metrics other than cross-entropy should be from torchmetrics library")
    lpa.masked_fill_(padding_mask, 0.0) # (R, N, L) or (N, L)
    lpa_len_1 = len(lpa.shape)-1
    lpa_permuted = lpa.permute(lpa_len_1, *range(lpa_len_1)) # (L, R, N) or (L, N)
    loss = metric_adv * lpa_permuted # (L, R, N) or (L, N)
    loss = loss.sum() / num_active_elements
    return (1 - self.label_smoothing_factor) * loss + self.label_smoothing_factor * smoothed_loss
  def predict_and_score(self, batch, labels):
    """Generate predictions for `batch` and score them against `labels["labels"]`.

    Returns the score mapping produced by `self.score`, extended with a "score time"
    entry (added by `exec_time`) and a "predict time" entry.
    """
    references = labels["labels"]
    # `predict` returns a non-dict result, so exec_time hands back (result, duration).
    predictions, predict_duration = self.exec_time(self.predict, batch)
    results = self.exec_time(self.score, predictions, references, timer_name = "score time")
    results["predict time"] = predict_duration
    return results
  def score(self, preds, true_labels):
    """Apply every metric in `self.eval_metric` to the predictions.

    Returns a `TorchDict` keyed by each metric's lower-cased class name.
    """
    results = {}
    for metric in self.eval_metric:
      metric_name = metric.__class__.__name__.lower()
      results[metric_name] = metric(preds, true_labels)
    return TorchDict(results)
  def predict(self, batch):
    """Generate text for a batch of prompts and return it with the prompt prefix cut off.

    Uses `batch.input_ids` and `batch.attention_mask`, generates with the sampling
    settings stored on the trainer (`max_new_tokens`, `do_sample`, `num_beams`,
    `temperature`, `top_p`, `max_time`) and decodes the output with `self.tokenizer`.

    Returns:
      One string per batch item: the decoded output with its first `prompt_len + 1`
      characters removed. NOTE(review): this assumes the generated sequence starts
      with the prompt text followed by one separator character, and that `lens`
      returns the length of each decoded prompt -- confirm against `lens` and the
      model type.
    """
    logger.debug(f"batch 2: {batch}")
    logger.debug(f"model device 2: {self.model.device}")

    input_ids = batch.input_ids
    attention_mask = batch.attention_mask
    # Use the trainer's own tokenizer rather than a notebook-level global, so the
    # method works no matter which names exist in the surrounding namespace.
    in_decoded = self.tokenizer.batch_decode(
      input_ids,
      skip_special_tokens=True,
      clean_up_tokenization_spaces=True,
    )
    prompt_lens = lens(in_decoded)
    #logger.debug(f"inputs decoded:")
    #pretty_print(in_decoded)

    generate_kwargs = dict(
        max_new_tokens=self.max_new_tokens,
        do_sample=self.do_sample,
        num_beams=self.num_beams,
        temperature=self.temperature,
        top_p=self.top_p,
        max_time = self.max_time,
        # Same reasoning as for the tokenizer: read the config from the model that
        # actually generates, not from a global `model`.
        generation_config=self.model.generation_config,
        use_cache=True, # Use caching for faster inference
    )
    output = self.model.generate(
      input_ids=input_ids,
      attention_mask=attention_mask,
      **generate_kwargs,
    )
    out_decoded = self.tokenizer.batch_decode(
      output,
      skip_special_tokens=True,
      clean_up_tokenization_spaces=True,
    )
    #logger.debug("outputs decoded")
    #pretty_print(out_decoded)
    return [out[prompt_len+1:] for prompt_len, out in zip(prompt_lens, out_decoded)]
  def backward(self, loss):
    """Accumulate gradients for `loss` and step once per accumulation window.

    Gradients are computed for every batch so that they add up across the window;
    the optimizer and scheduler step (followed by `zero_grad`) only when
    `self.batch_step` -- which is 1-based -- reaches a multiple of
    `self.gradient_accumulation_steps`, or on the last batch of the pass so a
    trailing partial window is not carried over.

    Args:
      loss: scalar tensor, already divided by the number of accumulation steps
        by the caller.
    """
    # Must run on every batch: skipping it on non-step batches would discard
    # their contribution instead of accumulating it.
    torch.autograd.backward(loss)
    window_complete = self.batch_step % self.gradient_accumulation_steps == 0
    last_batch = self.batch_step == getattr(self, "num_batches", None)
    if window_complete or last_batch:
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
  def train_on_batch(self):
    """Run one training step on `self.batch`.

    Returns a `TorchDict` with the loss (already divided by the number of gradient
    accumulation steps, detached into a fresh tensor) and the loss/backward timings.
    """
    step_loss, loss_duration = self.exec_time(self.loss_function, self.batch)
    step_loss /= self.gradient_accumulation_steps
    _, backward_duration = self.exec_time(self.backward, step_loss)
    report = {
      "loss": torch.tensor(step_loss.item()),
      "loss time": loss_duration,
      "backward time": backward_duration,
    }
    return TorchDict(report)
  def eval_on_batch(self):
    """Evaluate `self.batch`, which unpacks into `(inputs, labels)`.

    Returns the scores from `predict_and_score`, plus "loss" and "loss time" entries
    when `self.calc_loss_on_eval` is set.

    Raises:
      ValueError: if the resulting score mapping is empty.
    """
    inputs, targets = self.batch
    results = self.predict_and_score(inputs, targets)
    if self.calc_loss_on_eval:
      eval_loss, eval_loss_duration = self.exec_time(self.loss_function, inputs)
      results["loss"] = eval_loss
      results["loss time"] = eval_loss_duration
    if results == {}:
      raise ValueError(f"Scores cannot be empty")
    return results
  def process_batch(self):
    """Move `self.batch` to the model's device and run the active split's batch callback.

    Returns whatever `self.dataarg.batch_func` returns.
    """
    # Disabled debug aid: snapshot the model first with
    #   _model_old = copy.deepcopy(self.model)
    # and afterwards inspect self.scheduler.get_last_lr() /
    #   compare_models(_model_old, self.model)
    target_device = self.model.device
    self.batch = self.batch.to(target_device)
    logger.debug(f"Batch: {self.batch}")
    run_batch = self.dataarg.batch_func
    return run_batch()
  def exec_time(self, func, *args, **kwargs):
    """Call `func(*args, **kwargs)` and report how long the call took, in seconds.

    Args:
      func: callable to time.
      *args, **kwargs: forwarded to `func`, except for the keyword `timer_name`
        (default "time"), which is consumed here.

    Returns:
      If `func` returns a dict or `TorchDict`, that same object with the duration
      stored under `timer_name`; otherwise the tuple `(result, duration)`.
    """
    timer_name = kwargs.pop("timer_name", "time")
    # perf_counter is monotonic and high resolution; time.time() follows the wall
    # clock and can jump when the system clock is adjusted.
    start = time.perf_counter()
    res = func(*args, **kwargs)
    duration = time.perf_counter() - start
    if isinstance(res, dict) or isinstance(res, TorchDict):
      res[timer_name] = duration
      return res
    return res, duration
  def get_iterator(self, dataarg):
    """Wrap `dataarg.loader` in a tqdm progress bar.

    The description is "<fullname> batches", prefixed with "Epoch <n> " unless
    `notexists(self, "epoch")` holds.
    """
    if notexists(self, "epoch"):
      prefix = ""
    else:
      prefix = f"Epoch {self.epoch} "
    return logging.tqdm(dataarg.loader, desc=f"{prefix}{dataarg.fullname} batches")
  def init_batches(self, data_split):
    """Select a data split and prepare to iterate over its batches.

    Sets `self.dataarg`, `self.num_batches` and `self.iterator` (batch numbers start
    at 1), then switches the model into the split's mode by calling the model method
    named by `dataarg.inference_mode`.

    Raises:
      ValueError: if `data_split` is not a key of `self.dataargs`.
    """
    if data_split not in self.dataargs:
      raise ValueError(f"Unknown value for data_split: {data_split}")
    split = self.dataargs[data_split]
    self.dataarg = split
    self.num_batches = len(split.loader)
    progress_bar = self.get_iterator(split)
    self.iterator = enumerate(progress_bar, start=1)
    switch_mode = getattr(self.model, split.inference_mode)
    switch_mode()
    logger.info(f"Activated {self.dataarg.inference_mode} mode, training mode is {self.model.training}")
  def loop_batches(self, data_split):
    """Process every batch of `data_split` and return the summed scores divided by the batch count.

    `self.batch_step` and `self.batch` hold the batch in flight; `self.batch_step` is
    reset to None once the pass has finished.
    """
    self.init_batches(data_split)

    #if logger.is_verbosity(logging.DEBUG) and data_split == "val":
    #  self.device_old = self.model.device
    #  self.model.to_full_precision()
    #  self.model = self.model.to("cpu")
    #  self.model.to_empty_cache()
    score_total = 0
    for step, batch in self.iterator:
      self.batch_step = step
      self.batch = batch
      batch_scores = self.process_batch()
      # print_results calls remove_with on keys ending in "time" before accumulation
      self.print_results(batch_scores)
      score_total += batch_scores
    self.batch_step = None

    # Only reached with a stored device, i.e. when the disabled debug block above is enabled.
    if exists(self, "device_old"):
      self.model.to_full_precision()
      self.model = self.model.to(self.device_old)
      self.model.to_empty_cache()
      del self.device_old
    return score_total / self.num_batches
  def print_results(self, scores):
    """Log `scores` together with the current epoch/batch context.

    The line starts with "Batch step" while a batch step is set, otherwise with
    "Batch mean" while an epoch is set, otherwise with "Mean". Afterwards
    `scores.remove_with` is called with a predicate matching keys that end in "time"
    (presumably dropping the timing entries from `scores` in place -- confirm
    against `TorchDict.remove_with`).
    """
    label = "Mean"
    context = {}
    epoch_known = exists(self, "epoch")
    if epoch_known:
      label = "Batch mean"
      context["epoch"] = self.epoch
    if exists(self, "batch_step"):
      label = "Batch step"
      context["batch"] = self.batch_step
      if epoch_known:
        # running batch index across epochs (epochs and batch steps are 1-based)
        context["global"] = (self.epoch - 1) * self.num_batches + self.batch_step
    context.update(scores)
    fields = ', '.join(f'{k}: {v}' for k, v in context.items())
    logger.info(f"{label} {self.dataarg.fullname} {fields}")
    scores.remove_with(lambda k: k.lower().endswith("time"))
  def process_batch_once(self, data_split):
    """Process the next single batch of `data_split` and return its scores.

    A fresh pass over the split is started when no pass is active (`batch_step` is
    unset or None) or when a different split than the active one is requested.
    Consecutive calls therefore walk through the batches one by one.

    Returns:
      The scores of the processed batch, or None when the split is exhausted (in
      which case `self.batch_step` is reset to None so the next call starts over).

    Raises:
      ValueError: if `data_split` is not a key of `self.dataargs`.
    """
    if data_split not in self.dataargs:
     raise ValueError(f"Unknown value for data split: {data_split}. Data split should be one of {tuple(self.dataargs.keys())}")
    # batch_step is 1-based, so a value of 1 means "first batch done", not "not started";
    # treating it as a restart signal would replay batch 1 forever.
    no_active_pass = getattr(self, "batch_step", None) is None
    split_changed = getattr(self, "dataarg", None) is not self.dataargs[data_split]
    if no_active_pass or split_changed:
      self.init_batches(data_split)
    # Only the iterator advance may signal exhaustion; a StopIteration raised inside
    # batch processing must not be mistaken for the end of the data.
    try:
      self.batch_step, self.batch = next(self.iterator)
    except StopIteration:
      self.batch_step = None
      return None
    scores = self.process_batch()
    self.print_results(scores)
    return scores
  def train_loop(self):
    """Train for `self.num_epochs` epochs and return the per-epoch results averaged over epochs.

    `self.epoch` holds the running 1-based epoch number and is reset to None at the end.
    """
    epoch_total = 0
    epoch_bar = logging.tqdm(range(1,self.num_epochs+1), desc=f"Training epochs")
    for epoch_number in epoch_bar:
      self.epoch = epoch_number
      epoch_result = self.exec_time(self.loop_batches, "train", timer_name="batch time")
      self.print_results(epoch_result)
      epoch_total += epoch_result
    self.epoch = None
    epoch_total /= self.num_epochs
    return epoch_total
  def train(self):
    """Run the full training loop followed by one validation pass.

    Both results are logged, `clean_up_memory()` is called, and the tuple
    `(training result, validation scores)` is returned.
    """
    train_result = self.exec_time(self.train_loop, timer_name="total time")
    self.print_results(train_result)
    val_result = self.exec_time(self.loop_batches, "val", timer_name="total time")
    self.print_results(val_result)
    clean_up_memory()
    return train_result, val_result
  def __del__(self):
    """Log the teardown and close `self.writer` if it was created."""
    logger.info(f"Deleting {self.__class__.__name__} object")
    # __del__ also runs for objects whose construction failed before `writer` was
    # assigned; closing unconditionally would raise AttributeError during cleanup.
    writer = getattr(self, "writer", None)
    if writer is not None:
      writer.close()

In [None]:
# Build the trainer; `model`, `tokenizer` and the three dataset splits are expected
# to have been defined by earlier cells.
trainer = ReinforceTrainer(
    model,
    tokenizer,
    train_dataset,
    val_dataset,
    test_dataset,
)

In [None]:
# Report parameter information for the trainer's model, with sizes in GB and MB.
# NOTE(review): the meaning of print_trainable="a" and of the two calc_* flags is
# defined by print_info_params elsewhere in the notebook -- confirm there.
print_info_params(
    trainer.model,
    size_type_to_print=("GB", "MB"),
    print_trainable="a",
    calc_param_counts = False,
    calc_grad_counts = False,
)

In [None]:
# train() returns (training result, validation scores); the bare `res` displays it.
res = trainer.train()
res

In [None]:
# Small offset that keeps log-scaled search ranges away from 0 and strictly below 1.
EPSILON = 1e-4

def objective(trial):
  """Optuna objective: sample hyperparameters, retrain, and return the validation score.

  The search covers the training loss (metric, metric key, number of samples, label
  smoothing), generation settings, LoRA settings and the AdamW optimizer. Every
  hyperparameter is suggested under its own unique name: Optuna returns the cached
  value when a name is suggested twice within a trial, so a reused name would tie
  two hyperparameters together.

  Args:
    trial: the `optuna` trial used to draw the hyperparameters.

  Returns:
    The value the study maximizes, read from the validation result of `trainer.train()`.
  """
  train_metric_key = trial.suggest_categorical("train_metric_key", [None, 'rouge1_fmeasure', 'rouge2_fmeasure', 'rougeL_fmeasure', 'rougeLsum_fmeasure'])
  if train_metric_key is None:
    # No ROUGE key: either plain cross-entropy (0 samples) or BLEU as the reward.
    num_samples = trial.suggest_int("num_samples", 0, 4)
    if num_samples == 0:
      train_metric = None
    else:
      train_metric = BLEUScore(n_gram=4, smooth=True)
  else:
    # A ROUGE key was chosen: metric rewards need at least one sample.
    num_samples = trial.suggest_int("num_samples", 1, 4)
    train_metric = ROUGEScore(use_stemmer=True, accumulate='avg')

  beta1 = trial.suggest_float("beta1", 0.5, 1-EPSILON, log=True)
  # beta2 is drawn from [beta1, 1), so it is never smaller than beta1.
  beta2 = trial.suggest_float("beta2", beta1, 1-EPSILON, log=True)
  params = {
    # LOSS KWARGS
    "label_smoothing_factor": trial.suggest_float("label_smoothing_factor", EPSILON, 1, log=True),#
    "train_metric": train_metric,#
    "train_metric_key": train_metric_key,#
    "num_samples": num_samples,#
    "multiple": trial.suggest_categorical("multiple", [True, False]),#
    # GEN KWARGS
    "do_sample": trial.suggest_categorical("do_sample", [True, False]),#
    # Own name: under "num_samples" this would just repeat num_samples (possibly 0).
    "num_beams": trial.suggest_int("num_beams", 1, 4),
    "temperature": trial.suggest_float("temperature", EPSILON, 2, log=True),
    "top_p": trial.suggest_float("top_p_1", EPSILON, 1, log=True),
    # MODEL KWARGS
    "lora_r": trial.suggest_int("lora_r", 1, 64),#
    "lora_dropout": trial.suggest_float("lora_dropout", EPSILON, 1, log=True),#
    # OPTIM KWARGS
    "optim_arg": OptimArg(
      name="AdamW",
      args={
        # Own name: under "weight_decay" the weight decay below would reuse this value.
        "lr": trial.suggest_float("lr", 1e-20, 1e-1, log=True), #1e-8
        "betas": (beta1, beta2), #(0.9, 0.999)
        "weight_decay": trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True), #1e-2
      },
    )#,
  }

  trainer.update(**params)
  _, res = trainer.train()
  # NOTE(review): assumes the validation result exposes ["scores"]["microf1"]; the
  # trainer's score() keys results by lower-cased metric class name -- confirm that
  # this lookup matches the configured eval metrics.
  score = res["scores"]["microf1"]
  logger.info(f"{trial.number}) {params} Score: {score}")
  return score


# Hyperparameter Tuning

In [None]:
import optuna
# The study maximizes the value returned by `objective`; no sampler is passed, so
# Optuna's default sampler (TPESampler) is used.
study = optuna.create_study(direction='maximize') # sampler=TPESampler() by default and maximize f1 score

In [None]:
# Each trial reconfigures and retrains the shared `trainer` (see `objective`).
study.optimize(objective, n_trials=15) # about 1 hour

In [None]:
# Best hyperparameter set found by the study and the objective value it achieved.
study.best_params, study.best_value

In [None]:
#%tensorboard --logdir=runs