In [3]:
import importlib

# To pick up edits to a project module without restarting the kernel, reload the
# module and then re-run its `from ... import ...` line: importlib.reload() refreshes
# the module object, but names already bound here via `from ... import` keep pointing
# at the old objects until they are imported again. For example:
#
#     import models.n_model
#     importlib.reload(models.n_model)
#     from models.n_model import DenseNet

import wandb
import torch
import gc

from models.MixerMLP import MixerMLP, initialize_weights
from models.n_model import DenseNet

from customs.focal_loss import FocalLoss

from scripts.train import train
from scripts.eval import eval
from scripts.test import test
from scripts.trainer import Trainer
from scripts.data_loader import AudioDatasetModule

from torchsummaryX import summary
import matplotlib.pyplot as plt

In [6]:
from perforatedai import pb_globals as PBG
from perforatedai import pb_models as PBM
from perforatedai import pb_utils as PBU

AttributeError: module 'torch' has no attribute 'version'

In [None]:
### PHONEME LIST
# 42 output symbols: a silence token, 39 phoneme symbols, then start/end-of-sequence markers.
PHONEMES = (
    ['[SIL]']
    + """AA AE AH AO AW AY
         B CH D DH EH ER EY
         F G HH IH IY JH K
         L M N NG OW OY P
         R S SH T TH UH UW
         V W Y Z ZH""".split()
    + ['[SOS]', '[EOS]']
)

# Hyper-parameters for the run; this dict is also passed to the data module and logged to W&B.
config = {
    'subset': 1.0,  # Subset of dataset to use (1.0 == 100% of data)
    'context': 30,  # 30
    'activations': 'Swish',
    'learning_rate': 1e-3,  # optimizer's initial lr; OneCycleLR overrides it with max_lr / div_factor
    'max_lr': 2e-3,  # peak learning rate of the OneCycleLR schedule
    'dropout': 0.3,
    'optimizers': 'AdamW',
    'scheduler': 'OneCycleLR',
    'epochs': 100,       # 30
    'batch_size': 2048,  # 1024, 500
    'patience': 30,
    'save_every': 1,
    'weight_decay': 0.01,
    'weight_initialization': 'xavier_normal',  # e.g kaiming_normal, kaiming_uniform, uniform, xavier_normal or xavier_uniform
    'augmentations': 'FreqMask',  # Options: "FreqMask", "TimeMask", "Both", or None (presumably disables masking — confirm in AudioDatasetModule)
    'freq_mask_param': 4,  # 4
    'time_mask_param': 8,
}

# Prefer CUDA, then Apple MPS, then CPU, so the notebook still runs without an NVIDIA GPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

def clean_cache(device):
    """Ask PyTorch's allocator to release cached memory for an accelerator backend.

    Args:
        device: Backend name. "mps" and "cuda" trigger a flush of that backend's
            cache; any other value (e.g. "cpu") does nothing.
    """
    # Checked in order; lambdas keep each backend module untouched unless it matches.
    flushers = (
        ("mps", lambda: torch.mps.empty_cache()),
        ("cuda", lambda: torch.cuda.empty_cache()),
    )
    for backend, flush in flushers:
        if device == backend:
            flush()
            break

In [None]:
# Build train/val/test loaders from the train-clean-100 / dev-clean / test-clean partitions under ./data.
dm = AudioDatasetModule(
    root="./data",
    phonemes=PHONEMES,
    train_partition="train-clean-100",
    val_partition="dev-clean",
    test_partition="test-clean",
    batch_size=config["batch_size"],
    config=config,
    num_workers=10,
    # Pinned host memory only speeds up host-to-CUDA copies; skip it on CPU/MPS.
    pin_memory=(device == "cuda"),
)

dm.initialize(mode="fit")
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

dm.initialize(mode="test")
test_loader = dm.test_dataloader()

# NOTE(review): 28 is assumed to be the per-frame feature dimension — confirm against AudioDatasetModule.
n_features = 28

print("Batch size     : ", config['batch_size'])
print("Context        : ", config['context'])
# `context` frames on each side of the centre frame, each with n_features values.
print("Input size     : ", (2 * config['context'] + 1) * n_features)
print("Output symbols : ", len(PHONEMES))

# Label each count: the three loaders were previously indistinguishable in the output.
print("train batches = {}".format(len(train_loader)))
print("val batches   = {}".format(len(val_loader)))
print("test batches  = {}".format(len(test_loader)))

In [None]:
# Dense classifier over the flattened context window; one output per phoneme symbol.
# The two tuples have the same length — presumably one dropout rate per layer in `arch` (confirm in DenseNet).
dense_arch = (4096, 2048, 1024, 1024, 750, 512)
dense_dropout = (0.2, 0.15, 0.15, 0.15, 0.05, 0)

model_n = DenseNet(arch=dense_arch,
                   num_ouputs=len(PHONEMES),  # (sic) keyword spelling kept as DenseNet is called with it
                   dropout=dense_dropout).to(device)

model = model_n
model_name = "model"

In [None]:
# Release memory left over from earlier cells (works for CUDA and MPS via clean_cache).
gc.collect()
clean_cache(device)

# apply_init receives a real batch together with the init function — presumably it runs a
# forward pass so layer shapes are known before initialising weights (TODO confirm in DenseNet).
inputs, _ = next(iter(train_loader))
model.apply_init(inputs.to(device), initialize_weights)

print(inputs.shape)

criterion = torch.nn.CrossEntropyLoss()
# criterion = FocalLoss(gamma=1.5)

optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config["weight_decay"])

# Alternative schedulers that were tried:
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader), T_mult=1,
#                                                                  eta_min = 0.0001)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=round(len(train_loader) * 1.2),
#                                                                  eta_min = 0.00005)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3 * len(train_loader), gamma=0.9)

# OneCycleLR raises ValueError once it is stepped more than `total_steps` times, so the cycle
# must span every epoch the Trainer runs (config["epochs"]), not a fixed 20 epochs.
# Sizing by len(train_loader) assumes the Trainer steps the scheduler once per batch — confirm.
# OneCycleLR also replaces the optimizer's lr: it starts from max_lr / div_factor.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.get("max_lr", 2e-3),
    total_steps=config["epochs"] * len(train_loader),
    pct_start=0.15,
    anneal_strategy="cos",
)

# The scaler is bound to CUDA, so enable it only when actually training on CUDA.
scaler = torch.amp.GradScaler('cuda', enabled=(device == "cuda"))

# Collect unreachable tensors first so the cache flush can actually return their blocks.
gc.collect()
clean_cache(device)

# Fresh run: begin at epoch 0 with no best validation accuracy yet.
start = 0
best_val_acc = 0

In [None]:
# Authenticate with W&B without putting a key in the notebook: wandb.login() picks up the
# WANDB_API_KEY environment variable or ~/.netrc, and otherwise prompts interactively.
wandb.login()

In [None]:
# Smoke test: run the Trainer with W&B logging, best-model saving and checkpoints disabled.
# NOTE(review): this passes the same model / optimizer / scheduler / scaler objects that the
# logged run below uses, so any training done here (and the scheduler's step count) presumably
# carries over — re-run the setup cell before the real run.

# Detach W&B hooks left on the model by an earlier watched run. Recent wandb versions raise
# wandb.Error when unwatch is called with no active run; there is nothing to detach then.
try:
    wandb.unwatch(model)
except wandb.Error:
    pass

test_Trainer = Trainer(config["epochs"], criterion, optimizer, scheduler,
                       config["patience"], config["save_every"], model_name, device=device, scaler=scaler)
test_Trainer.fit(model, train_loader, val_loader, log_epoch=False, log_batch=False, save_best=False,
                 checkpoints=False)

In [None]:
# Create the W&B run used for logging training.
run = wandb.init(
    project=("HW1P2"),             # W&B project name
    group=f"{model_name}",         # group runs by model_name
    name=f"{model_name}_run_1",    # run name shown in the W&B UI
    config=config,                 # hyper-parameters stored with the run
    reinit=True,                   # allows reinitializing the run when re-running this cell
    # To resume a previous run instead: add id="<run id>" and resume="must",
    # and remove reinit=True.
)

In [None]:
# Logged training run. start / best_val_acc come from the setup cell for a fresh run,
# or from the "Load checkpoint" cell when resuming.
gc.collect()          # drop unreachable tensors first...
clean_cache(device)   # ...so the allocator can actually release their cached blocks
wandb.watch(model, log="all")  # log both gradients and parameters

trainer = Trainer(config["epochs"], criterion, optimizer, scheduler,
                  config["patience"], config["save_every"], model_name,
                  start=start, best_val_acc=best_val_acc,
                  device=device, scaler=scaler)
trainer.fit(model, train_loader, val_loader, save_best=True, checkpoints=True, log_freq=20)

In [None]:
"----------------------------------------------------------------------"
"----------------------------Resume Run--------------------------------"
"----------------------------------------------------------------------"

In [None]:
# Load checkpoint: restore model weights from a saved epoch, then continue training from there.
resume_epoch = 40
checkpoint_path = f"checkpoints/{model_name}/{model_name}_epoch_{resume_epoch}.pth"
# map_location lets a checkpoint saved on another device (e.g. a CUDA box) load onto `device`.
checkpoint = torch.load(checkpoint_path, map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
# NOTE(review): optimizer and OneCycleLR state are NOT restored, so both continue from the
# state built in the setup cell; uncomment to continue the original schedule instead.
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

start = resume_epoch
best_val_acc = 0.86  # entered by hand — presumably the best val accuracy of the original run

In [None]:
# Re-attach to the original W&B run so resumed epochs log to the same history;
# afterwards re-run the training cell.
# NOTE(review): confirm reinit=True is compatible with resume="must" for the installed wandb version.
run = wandb.init(
    name    = f"{model_name}_run_1",
    reinit  = True,
    id      = '......',   ### ID of the run to resume (copy it from the W&B run page)
    resume  = "must",     # error out instead of creating a new run if the id does not exist
    project = "HW1P2",
    group=f"{model_name}",
    config=config
)