# Imports

In [None]:
from lib.data import (
    H4Tokenizer, 
    LMDataset,
    verify_dataloader
)
from lib.model import (
    CausalMask,
    PadMask,
    PositionalEncoding,
    DecoderOnlyTransformer
)
from lib.utils import (
    create_optimizer,
    create_scheduler,
    plot_lr_schedule
)
from lib.trainers import (
    LMTrainer,
)
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import yaml
import gc
import torch
from torchinfo import summary
import os
import json
import tarfile
import shutil
import wandb
import yaml
# Pick the compute device once for the whole notebook: the GPU when PyTorch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## Config

In [None]:
%%writefile config.yaml

Name                      : "Name"

###### Tokenization ------------------------------------------------------------
tokenization:
  token_type                : "char"       # [char, 1k, 5k, 10k]
  token_map :
      'char': 'lib/data/tokenizer_jsons/tokenizer_char.json'
      '1k'  : 'lib/data/tokenizer_jsons/tokenizer_1000.json'
      '5k'  : 'lib/data/tokenizer_jsons/tokenizer_5000.json'
      '10k' : 'lib/data/tokenizer_jsons/tokenizer_10000.json'

###### Dataset -----------------------------------------------------------------
data:                  
  root                 : "data/p1_data"
  train_partition      : "train" 
  val_partition        : "val"    
  test_partition       : "test"   
  subset               : 1.0     
  batch_size           : 256    
  NUM_WORKERS          : 2       

###### Network Specs -------------------------------------------------------------
model:
  d_model                   : 256
  d_ff                      : 1024
  num_layers                : 2
  num_heads                 : 2
  dropout                   : 0.0
  layer_drop_rate           : 0.0
  weight_tying              : False

###### Common Training Parameters ------------------------------------------------
training:
  use_wandb                   : False  
  wandb_run_id                : "none"  
  resume                      : False 
  epochs                      : 20
  gradient_accumulation_steps : 1
  wandb_project               : "Set-Project-Name-Here"  

###### Loss ----------------------------------------------------------------------
loss:  
  label_smoothing: 0.0

###### Optimizer -----------------------------------------------------------------
optimizer:
  name: "adam"  
  lr: 5.0e-4    

  weight_decay: 0.0001

  param_groups:
    - name: self_attn
      patterns: []   
      lr: 0.0001     
      layer_decay:
        enabled: False
        decay_rate: 0.8

    - name: ffn
      patterns: []  
      lr: 0.0001  
      layer_decay:
        enabled: False
        decay_rate: 0.8

  layer_decay:
    enabled: False
    decay_rate: 0.75

  sgd:
    momentum: 0.9
    nesterov: True
    dampening: 0

  adam:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

  adamw:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

###### Scheduler -----------------------------------------------------------------
scheduler:
  name: "cosine"   

  reduce_lr:
    mode: "min"  
    factor: 0.1  
    patience: 10  
    threshold: 0.0001   
    threshold_mode: "rel"  
    cooldown: 0   
    min_lr: 0.0000001  
    eps: 1.0e-8  

  cosine:
    T_max: 15  
    eta_min: 1.0e-8  
    last_epoch: -1

  cosine_warm:
    T_0: 4  
    T_mult: 4  
    eta_min: 0.0000001  
    last_epoch: -1

  warmup:
    enabled: True
    type: "exponential"  
    epochs: 5
    start_factor: 0.1
    end_factor: 1.0

In [None]:
# Load the experiment configuration written by the %%writefile cell above.
# %%writefile saves the cell as UTF-8, so decode it explicitly as UTF-8 instead
# of relying on the platform's locale-dependent default encoding.
# yaml.safe_load only constructs standard YAML types (dicts, lists, scalars).
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

## Tokenizer

In [None]:
# Construct the tokenizer from the `tokenization` section of config.yaml.
# `token_map` maps each token type ('char', '1k', '5k', '10k') to a tokenizer
# JSON file, and `token_type` names the one to use.
tok_cfg = config['tokenization']
Tokenizer = H4Tokenizer(token_map=tok_cfg['token_map'], token_type=tok_cfg['token_type'])

## Datasets

In [None]:
# Build one LMDataset per split, in train -> val -> test order. All three share
# the `data` config section and the tokenizer; only the partition name differs.
data_cfg = config['data']
train_dataset, val_dataset, test_dataset = (
    LMDataset(partition=data_cfg[split_key], config=data_cfg, tokenizer=Tokenizer)
    for split_key in ('train_partition', 'val_partition', 'test_partition')
)

# Run a full garbage-collection pass once the datasets have been built.
gc.collect()

## Dataloaders

In [None]:
def _make_loader(dataset, shuffle):
    """Return a DataLoader over *dataset* using the shared `data` config settings.

    Worker processes are used only when running on CUDA, which is the same policy
    as before. Pinned host memory exists to speed up host-to-CUDA copies, so it is
    enabled only on CUDA as well. On CPU-only machines recent PyTorch versions
    warn that pinning is unavailable and skip it.

    Args:
        dataset: A dataset exposing a ``collate_fn`` used to assemble batches.
        shuffle: Whether to reshuffle samples every epoch.
    """
    use_cuda = device == 'cuda'
    return DataLoader(
        dataset     = dataset,
        batch_size  = config['data']['batch_size'],
        shuffle     = shuffle,
        num_workers = config['data']['NUM_WORKERS'] if use_cuda else 0,
        pin_memory  = use_cuda,
        collate_fn  = dataset.collate_fn,
    )


# Only the training split is shuffled; val/test keep a fixed order.
train_loader    = _make_loader(train_dataset, shuffle=True)
val_loader      = _make_loader(val_dataset,   shuffle=False)
test_loader     = _make_loader(test_dataset,  shuffle=False)

## Calculate Max Transcript Length




In [None]:
# Longest text across every split; it is fed to the model as `max_len` below.
split_datasets = (train_dataset, val_dataset, test_dataset)
max_transcript_length = max(ds.text_max_len for ds in split_datasets)

rule = "=" * 50
print(rule)
print(f"{'Global Max Transcript Length':<30} : {max_transcript_length}")
print(rule)

## Model

In [None]:
# Add the data-dependent hyperparameters (maximum sequence length and vocabulary
# size) to the `model` section and build the decoder-only transformer from it.
# NOTE: `model_config` is the same dict object as config['model'], so this update
# also adds 'max_len' and 'num_classes' to `config` itself, which is later passed
# to LMTrainer.
model_config = config['model']
model_config.update({
    'max_len': max_transcript_length,
    'num_classes': Tokenizer.vocab_size
})
model = DecoderOnlyTransformer(**model_config)

# Take a single training batch to inspect tensor shapes and to serve as example
# input for the torchinfo summary; fail with a clear message if the loader is empty.
try:
    shifted_transcripts, golden_transcripts, transcript_lengths = next(iter(train_loader))
except StopIteration:
    raise RuntimeError(
        "train_loader produced no batches; check the training partition and "
        "the `subset` setting in config.yaml"
    ) from None
print("Shape of shifted_transcripts : ", shifted_transcripts.shape)
print("Shape of golden_transcripts  : ", golden_transcripts.shape)
print("Shape of transcript_lengths  : ", transcript_lengths.shape)

model_stats = summary(model, input_data=[shifted_transcripts, transcript_lengths])
print(model_stats)

## Wandb

## Trainer

Every time you run the trainer, it will create a new directory in the `expts` folder with the following structure:
```
expts/
    └── {run_name}/
        ├── config.yaml
        ├── model_arch.txt
        ├── checkpoints/
        │   ├── checkpoint-best-metric-model.pth
        │   └── checkpoint-last-epoch-model.pth
        ├── attn/
        │   └── {attention visualizations}
        └── text/
            └── {generated text outputs}
```

In [None]:
# Everything the language-model trainer needs: the model, tokenizer, parsed
# config, the experiment's run name, the config file path, and the device.
trainer_kwargs = dict(
    model       = model,
    tokenizer   = Tokenizer,
    config      = config,
    run_name    = "test-lm",
    config_file = "config.yaml",
    device      = device,
)
trainer = LMTrainer(**trainer_kwargs)

### Set Up the Optimizer and Scheduler


#### Setting up the optimizer

In [None]:
# Build the optimizer from the `optimizer` section of config.yaml and hand it to
# the trainer.
optimizer_cfg = config['optimizer']
trainer.optimizer = create_optimizer(model=model, opt_config=optimizer_cfg)

#### Setting up the scheduler

In [None]:
# Attach a learning-rate scheduler built from the `scheduler` config section.
# The training loader and the gradient-accumulation setting are forwarded to
# create_scheduler together with the trainer's optimizer.
accum_steps = config['training']['gradient_accumulation_steps']
trainer.scheduler = create_scheduler(
    optimizer                   = trainer.optimizer,
    scheduler_config            = config['scheduler'],
    train_loader                = train_loader,
    gradient_accumulation_steps = accum_steps,
)

# Train

In [None]:
# Run the training loop for the number of epochs configured in config.yaml.
num_epochs = config['training']['epochs']
trainer.train(train_loader, val_loader, epochs=num_epochs)

# Evaluate


In [None]:
# Evaluate on the held-out test split. trainer.cleanup() runs in `finally` so that
# whatever it releases (its internals are not visible here) is still released
# when evaluation raises.
try:
    test_metrics, test_generation_results = trainer.evaluate(test_loader)
finally:
    trainer.cleanup()