In [None]:
!pip install uv
!uv pip install evaluate optuna

In [1]:

import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.optim import AdamW
import torch
import os
import re
from transformers import Trainer, TrainingArguments
import evaluate
from transformers import EarlyStoppingCallback
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed=42):
    # 1. Python & Libraries
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    # 2. PyTorch Standard Seeds
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # 3. SPEED SETTINGS (Keep these to stay at 3.1 it/s)
    torch.backends.cudnn.deterministic = False # Setting to True slows you down
    torch.backends.cudnn.benchmark = True      # Setting to False slows you down

set_seed(42)

In [3]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
# 1. Get the specific GPU name
gpu_name = torch.cuda.get_device_name(0)
print(f"Detected GPU: {gpu_name}")

# 2. Logic: If it's a T4, force FP16. Otherwise, check for BF16 support.
if "T4" in gpu_name:
    print("T4 Detected: Forcing FP16 for hardware acceleration.")
    use_bf16 = False
    use_fp16 = True
     # T4 often struggles with Flash Attn 2, standard is safe
else:
    # Check for native BF16 support (A100, L4, 3090, 4090 etc.
    is_bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_bf16 = is_bf16_supported
    use_fp16 = not is_bf16_supported
    attn_impl = "flash_attention_2" if is_bf16_supported else "eager"
attn_impl = "sdpa"
print(f"Final Configuration -> BF16: {use_bf16} | FP16: {use_fp16}")
print(f"Attention Implementation: {attn_impl}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

Detected GPU: NVIDIA GeForce RTX 4060 Laptop GPU
Final Configuration -> BF16: True | FP16: False
Attention Implementation: sdpa
cuda


In [5]:
!hf auth login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `hf auth whoami` to get more information or `hf auth logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): Traceback (most recent call last):
  File "/home/amlan/legal/joshi/bail/.venv/bin/hf", 

In [4]:
def model_init():
    seed_value = 42
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    model = AutoModelForSequenceClassification.from_pretrained(
        "ai4bharat/indic-bert",
        num_labels=2,
        attn_implementation=attn_impl
    )
    return model


In [5]:
def is_running_on_colab():
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        return True
    except ImportError:
        return False

# Set a global flag
IN_COLAB = is_running_on_colab()



In [6]:
# class LegalDataset(Dataset):
#     def __init__(self, df, tokenizer):
#         self.labels = torch.tensor(df['label'].values, dtype=torch.long)

#         print("Batch tokenizing... (this will be much faster)")
#         # Tokenize everything at once
#         self.encodings = tokenizer(
#             df['text'].tolist(),
#             add_special_tokens=True,
#             max_length=512,
#             padding=False, #'max_length',
#             truncation=True,
#             return_tensors='pt'
#         )

#     def __len__(self):
#         return len(self.labels)

#     def __getitem__(self, idx):
#         return {
#             'input_ids': self.encodings['input_ids'][idx],
#             'attention_mask': self.encodings['attention_mask'][idx],
#             'label': self.labels[idx]
#         }
from torch.utils.data import Dataset
class LegalDataset(Dataset):
    def __init__(self, df, tokenizer):
        # Store labels as a simple list first
        self.labels = df['label'].tolist()

        print("Batch tokenizing... (this will be much faster)")
        # Tokenize everything at once
        self.encodings = tokenizer(
            df['text'].tolist(),
            add_special_tokens=True,
            max_length=512,
            padding=False,     # Dynamic padding enabled
            truncation=True,
            # return_tensors='pt'  <-- REMOVED THIS (It causes the crash!)
        )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert to tensor HERE, for just this one item
        return {
            'input_ids': torch.tensor(self.encodings['input_ids'][idx]),
            'attention_mask': torch.tensor(self.encodings['attention_mask'][idx]),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long) # Note: Trainer expects 'labels' (plural)
        }

In [7]:


# Define paths on your Drive
load_dir = '/content/drive/MyDrive/bail_prediction_datasets/' if IN_COLAB else 'pt_datasets'

paths = {
    "train": os.path.join(load_dir, "train_dataset.pt"),
    "val": os.path.join(load_dir, "val_dataset.pt"),
    "hp_train": os.path.join(load_dir, "hp_train_dataset.pt"),
    "hp_val": os.path.join(load_dir, "hp_val_dataset.pt")
}

# Helper function to load or create
def load_dataset(file_path):
    if os.path.exists(file_path):
        return torch.load(file_path, weights_only=False)
    else:
        print(f"⚠️ {file_path} not found.")


train_dataset = load_dataset(paths["train"])
val_dataset = load_dataset(paths["val"])
hp_train_dataset = load_dataset(paths["hp_train"])
hp_val_dataset = load_dataset(paths["hp_val"])

print("\nAll datasets ready!")


All datasets ready!


In [8]:
metric1 = evaluate.load("accuracy")
metric2 = evaluate.load("f1")

In [9]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = metric1.compute(predictions=predictions, references=labels)
    f1 = metric2.compute(predictions=predictions, references=labels, average="macro")
    return {'accuracy': accuracy["accuracy"], 'f1-score': f1["f1"]}

In [10]:
def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "weight_decay":trial.suggest_float("weight_decay", 0.005, 0.05),
        "adam_beta1":trial.suggest_float("adam_beta1", 0.75, 0.95),
        "adam_beta2":trial.suggest_float("adam_beta2", 0.99, 0.9999),
        "adam_epsilon":trial.suggest_float("adam_epsilon", 1e-9, 1e-7, log=True)
    }

In [11]:
from transformers import DataCollatorWithPadding
from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("ai4bharat/indic-bert")
# This does the dynamic padding specifically for each batch
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


In [12]:
cpu_count = os.cpu_count()
print(f"Detected CPUs: {cpu_count}")
training_args = TrainingArguments(
    seed=42,           # Random weights & internal shuffling
    data_seed=42,      # Specifically for data sampling
    full_determinism=False, # IMPORTANT: Keep this False for speed!
    group_by_length=True,
    
    output_dir='/content/bail_prediction_datasets/htf2_results' if IN_COLAB else 'hft2_results',          # output directory
    num_train_epochs=5,            # total number of training epochs
    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=8,   # batch size for evaluation

    # 2. Data Loading Workers (Prevent CPU bottlenecks)
    dataloader_num_workers=min(8, cpu_count),
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,  # Prefetch batches for speed

    #lr_scheduler_type="cosine",
    #warmup_ratio = 0.1,
    warmup_steps=500,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,              # strength of weight decay
    logging_dir='/content/bail_prediction_datasets/htf2_logs' if IN_COLAB else 'hft2_logs',           # directory for storing logs
    eval_strategy="epoch",
    logging_steps=1000,
    save_strategy='epoch',
    save_total_limit = 1,
    learning_rate = 0.00001,
    load_best_model_at_end=True,
    metric_for_best_model ="eval_f1-score",

    optim="adamw_torch_fused",
    bf16=use_bf16,                 # Auto-enable BF16 if available
    fp16=use_fp16,

    gradient_checkpointing=False,

    # Speed optimizations
    torch_compile=True,
    torch_compile_mode="default" if "T4" not in gpu_name else "reduce-overhead",
    report_to="none", # Prevents wandb prompts if not needed
    tf32=False if gpu_name == "Tesla T4" else True
)

Detected CPUs: 16


In [None]:
trainer = Trainer(
    model_init=model_init,                        # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=hp_train_dataset,         # training dataset
    eval_dataset=hp_val_dataset,           # evaluation dataset
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
               
)

  trainer = Trainer(
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at ai4bharat/indic-bert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
import optuna
# 1. Create the study with a name and a seeded sampler
study_name = "bail_prediction_study"
study = optuna.create_study(
    study_name=study_name,
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42),
    load_if_exists=True  # Good practice if you rerun cells
)

[I 2025-12-19 13:59:00,980] A new study created in memory with name: bail_prediction_study


In [16]:
import logging

# 1. Standard Python logging is enough to silence the compiler/inductor
logging.getLogger("torch._inductor").setLevel(logging.ERROR)

# 2. Silence Hugging Face Transformers
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()

# (Optional) If you really want to use the torch helper, the correct syntax
# to silence EVERYTHING is 'all', not 'errors':
# import torch
# torch._logging.set_logs(all=logging.ERROR)

In [17]:
best_run = trainer.hyperparameter_search(
    n_trials=10,
    direction="maximize",
    hp_space=my_hp_space,
    backend="optuna",
    study_name=study_name, # Use this instead of study=study
    load_if_exists=True
)

[I 2025-12-19 13:59:05,667] A new study created in memory with name: bail_prediction_study
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.



Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6704,0.657247,0.627894,0.385709
2,0.6426,0.700239,0.627894,0.385709
3,0.6182,0.594696,0.683794,0.604996
4,0.553,0.573003,0.721626,0.689205
5,0.538,0.563026,0.726708,0.688955


W1219 13:59:35.446000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

W1219 14:03:16.560000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s52, s92) | Eq(s92, 1)
W1219 14:03:16.563000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

[I 2025-12-19 14:15:39,046] Trial 0 finished with value: 1.4156626684163396 and parameters: {'learning_rate': 4.982431726668

Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6677,0.658963,0.627894,0.385709
2,0.6052,0.576251,0.710333,0.642378
3,0.5781,0.539568,0.73179,0.679034
4,0.5149,0.519041,0.740824,0.719339
5,0.4926,0.516904,0.749294,0.727105


W1219 14:15:45.458000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/5] _maybe_guard_rel() was called on non-relation expression Eq(s52, s92) | Eq(s92, 1)
W1219 14:15:45.461000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/5] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
[I 2025-12-19 14:30:23,642] Trial 1 finished with value: 1.4763991907818295 and parameters: {'learning_rate': 1.0003708904089926e-05, 'weight_decay': 0.015641228909192135, 'adam_beta1': 0.8434180620407123, 'adam_beta2': 0.9944919294498108, 'adam_epsilon': 2.3368324020433957e-08}. Best is trial 1 with value: 1.4763991907818295.
W1219 14:30:27.402000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/6] _maybe_guard_rel() was called on non-relation expression Eq(s52, s92) | Eq(s92, 1)
W1219 14:30:27.405000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/6] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6673,0.614756,0.639187,0.438956
2,0.6089,0.598089,0.678148,0.608947
3,0.5882,0.585586,0.688312,0.614014
4,0.5574,0.578874,0.688312,0.631957
5,0.548,0.574984,0.688312,0.636432


W1219 14:30:29.524000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/7] _maybe_guard_rel() was called on non-relation expression Eq(s52, s92) | Eq(s92, 1)
W1219 14:30:29.526000 40436 torch/fx/experimental/symbolic_shapes.py:6833] [0/7] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
[I 2025-12-19 14:45:14,681] Trial 2 finished with value: 1.3247432137051607 and parameters: {'learning_rate': 1.544722409477431e-06, 'weight_decay': 0.025510954302690732, 'adam_beta1': 0.9406818845528502, 'adam_beta2': 0.9980877515689371, 'adam_epsilon': 8.52543923480876e-09}. Best is trial 1 with value: 1.4763991907818295.
W1219 14:45:18.458000 40436 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
W1219 14:45:18.458000 40436 torch/_dynamo/convert_frame.py:1358] [0/8]    function: 'forward' (/home/amlan/legal/joshi/bail/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py:818)
W1219 14:45:18.458000 40436 torch/_

Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6671,0.659819,0.627894,0.385709
2,0.6128,0.599616,0.669678,0.627808
3,0.5805,0.575886,0.709204,0.680598
4,0.5224,0.552944,0.732355,0.707631
5,0.5011,0.549802,0.734613,0.711327


[I 2025-12-19 15:00:23,456] Trial 3 finished with value: 1.4459398541820794 and parameters: {'learning_rate': 4.6958881292396195e-06, 'weight_decay': 0.0459072750240278, 'adam_beta1': 0.8744053981677709, 'adam_beta2': 0.9967204619518366, 'adam_epsilon': 2.4747976658992355e-09}. Best is trial 1 with value: 1.4763991907818295.


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6688,0.659765,0.627894,0.385709
2,0.6583,0.631026,0.671372,0.537426
3,0.6101,0.558102,0.737436,0.697997
4,0.5196,0.530063,0.738001,0.713386
5,0.4924,0.518626,0.752682,0.726811


[I 2025-12-19 15:15:36,686] Trial 4 finished with value: 1.47949282496702 and parameters: {'learning_rate': 2.0360367787956355e-05, 'weight_decay': 0.04941991394269555, 'adam_beta1': 0.858338223781299, 'adam_beta2': 0.9994199535980312, 'adam_epsilon': 1.0027757291209822e-09}. Best is trial 4 with value: 1.47949282496702.


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6641,0.641605,0.636364,0.610665
2,0.5932,0.594515,0.690006,0.595508


[I 2025-12-19 15:21:44,022] Trial 5 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.668,0.634942,0.629588,0.390643
2,0.6165,0.615694,0.666855,0.653714
3,0.5975,0.548399,0.72332,0.667247
4,0.5271,0.547748,0.736307,0.700812


[I 2025-12-19 15:33:57,428] Trial 6 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6739,0.660586,0.627894,0.385709
2,0.6689,0.660526,0.627894,0.385709


[I 2025-12-19 15:40:10,037] Trial 7 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6725,0.65911,0.627894,0.385709
2,0.6671,0.660297,0.627894,0.385709


[I 2025-12-19 15:47:40,810] Trial 8 pruned. 


Epoch,Training Loss,Validation Loss,Accuracy,F1-score
1,0.6681,0.613392,0.655562,0.610733
2,0.5914,0.577403,0.701863,0.660067
3,0.5588,0.601004,0.681536,0.675729


[I 2025-12-19 15:58:56,313] Trial 9 pruned. 


In [18]:
print("Best HyperParameters")
print(best_run)

Best HyperParameters
BestRun(run_id='4', objective=1.47949282496702, hyperparameters={'learning_rate': 2.0360367787956355e-05, 'weight_decay': 0.04941991394269555, 'adam_beta1': 0.858338223781299, 'adam_beta2': 0.9994199535980312, 'adam_epsilon': 1.0027757291209822e-09}, run_summary=None)


In [19]:
del trainer
del training_args
import gc
gc.collect()
torch.cuda.empty_cache()

In [21]:
training_args = TrainingArguments(
    seed=42,           # Random weights & internal shuffling
    data_seed=42,      # Specifically for data sampling
    full_determinism=False, # IMPORTANT: Keep this False for speed!
    group_by_length=True,
    
    output_dir='/content/bail_prediction_datasets/tf2_results' if IN_COLAB else 'tf2_results',          # output directory
    
    num_train_epochs=15,            # total number of training epochs
    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=8,   # batch size for evaluation

    # Optimized data loading
    dataloader_num_workers=min(8, cpu_count),
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,

    #lr_scheduler_type="cosine",
    #warmup_ratio = 0.1,
    warmup_steps=500,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,              # strength of weight decay
    logging_dir='/content/bail_prediction_datasets/tf2_logs' if IN_COLAB else 'tf2_logs',           # directory for storing logs
    eval_strategy="epoch",
    logging_steps=1000,
    save_strategy='epoch',
    save_total_limit = 1,
    learning_rate = 0.00001,
    load_best_model_at_end=True,
    metric_for_best_model ="eval_f1-score",


    optim="adamw_torch_fused",
    bf16=use_bf16,                 # Auto-enable BF16 if available
    fp16=use_fp16,
    gradient_checkpointing=False,

    # Speed optimizations
    torch_compile=True,
    torch_compile_mode="default" if "T4" not in gpu_name else "reduce-overhead",
    report_to="none", # Prevents wandb prompts if not needed
    tf32=False if gpu_name == "Tesla T4" else True

)

In [22]:
print("Starting Training...")

Starting Training...


In [23]:
trainer = Trainer(
    model_init=model_init,                        # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,           # evaluation dataset
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

  trainer = Trainer(


In [24]:
for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)
print(trainer.args)

TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.858338223781299,
adam_beta2=0.9994199535980312,
adam_epsilon=1.0027757291209822e-09,
auto_find_batch_size=False,
average_tokens_across_devices=True,
batch_eval_metrics=False,
bf16=True,
bf16_full_eval=False,
data_seed=42,
dataloader_drop_last=False,
dataloader_num_workers=8,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=2,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=True,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=epoch,
eval_use_

In [27]:
trainer.train()

{'loss': 0.6723, 'grad_norm': 20.196247100830078, 'learning_rate': 2.0316484652476957e-05, 'epoch': 0.06464959917248513}
{'loss': 0.6749, 'grad_norm': 4.625257968902588, 'learning_rate': 2.0228542497207625e-05, 'epoch': 0.12929919834497025}
{'loss': 0.6675, 'grad_norm': 0.7067912817001343, 'learning_rate': 2.0140600341938288e-05, 'epoch': 0.1939487975174554}
{'loss': 0.6645, 'grad_norm': 0.6750290989875793, 'learning_rate': 2.0052658186668952e-05, 'epoch': 0.2585983966899405}
{'loss': 0.6651, 'grad_norm': 0.4118005335330963, 'learning_rate': 1.996471603139962e-05, 'epoch': 0.32324799586242564}
{'loss': 0.663, 'grad_norm': 14.254481315612793, 'learning_rate': 1.9876773876130283e-05, 'epoch': 0.3878975950349108}
{'loss': 0.6234, 'grad_norm': 3.576551675796509, 'learning_rate': 1.978883172086095e-05, 'epoch': 0.4525471942073959}
{'loss': 0.6041, 'grad_norm': 0.504767119884491, 'learning_rate': 1.9700889565591613e-05, 'epoch': 0.517196793379881}
{'loss': 0.5885, 'grad_norm': 3.616740226745

TrainOutput(global_step=170148, training_loss=0.38773956216322975, metrics={'train_runtime': 19775.0535, 'train_samples_per_second': 93.862, 'train_steps_per_second': 11.733, 'train_loss': 0.38773956216322975, 'epoch': 11.0})

In [None]:
#trainer.train(resume_from_checkpoint=True)

In [28]:
save_dir="/content/drive/MyDrive/bail_prediction_datasets/model" if IN_COLAB else 'model'
trainer.save_model(save_dir)