In [None]:
from dataset.dataset import HENetdataset
from dataset.transformation import augmenter, to_tensor
from model.lightning_wraper import HENetWrapper
from model.henet import HENet

import os 
import random

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as L
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    StochasticWeightAveraging
)
from model.configs import load_config

In [None]:
# Training configuration. Later cells read training.train_ratio, test_ratio,
# batch_size, n_classes, learning_rate and k from it.
CONFIG_PATH = "./configs/train_cfg.yaml"
config = load_config(CONFIG_PATH)


In [None]:
# Collect the image files of the dataset.
# The default is a machine-specific absolute path; set HENET_DATASET_PATH to
# point the notebook at the dataset on another machine.
dataset_path = os.environ.get("HENET_DATASET_PATH", "F:/tfont/")

# os.listdir returns entries in arbitrary order; sort so the split cell starts
# from a deterministic order (otherwise even a seeded shuffle is not reproducible).
img_list = sorted(os.listdir(dataset_path))
# os.path.join works whether or not dataset_path ends with a separator, and
# sub-directories are skipped so they are not handed to the dataset as images.
img_paths = [
    os.path.join(dataset_path, img)
    for img in img_list
    if os.path.isfile(os.path.join(dataset_path, img))
]
assert img_paths, f"no image files found in {dataset_path!r}"

In [None]:
# Shuffle and split data into train / test / validation subsets.
# A fixed seed makes the split reproducible across kernel restarts. Without it,
# every run (and every resumed run) trains and evaluates on different images,
# so test images can leak into training.
SPLIT_SEED = 42

train_ratio = config['training']['train_ratio']
test_ratio = config['training']['test_ratio']
assert train_ratio >= 0 and test_ratio >= 0 and train_ratio + test_ratio <= 1, (
    "train_ratio + test_ratio must be within [0, 1]; the remainder is validation"
)

# Shuffle a sorted copy with a dedicated RNG, so that:
#  - the result does not depend on the incoming order of img_paths,
#  - img_paths itself is not mutated (re-running this cell gives the same split),
#  - the global `random` state is left untouched.
shuffled_paths = sorted(img_paths)
random.Random(SPLIT_SEED).shuffle(shuffled_paths)

total_size = len(shuffled_paths)
train_size = int(train_ratio * total_size)
test_size = int(test_ratio * total_size)
valid_size = total_size - train_size - test_size  # remainder goes to validation

train_data = shuffled_paths[:train_size]
test_data = shuffled_paths[train_size:train_size + test_size]
valid_data = shuffled_paths[train_size + test_size:]
assert len(valid_data) == valid_size

In [None]:
# Torch datasets, built in train -> test -> valid order. The training split
# uses the `augmenter` transform; the evaluation splits use `to_tensor`.
train_dataset, test_dataset, valid_dataset = (
    HENetdataset(img_paths=split_paths, transform=split_transform)
    for split_paths, split_transform in (
        (train_data, augmenter),
        (test_data, to_tensor),
        (valid_data, to_tensor),
    )
)


In [None]:
# Data loaders.
batch_size = config['training']['batch_size']
NUM_WORKERS = 8

# persistent_workers keeps worker processes alive between iterations of a
# loader. That pays off for loaders iterated every epoch (train and validation),
# where re-spawning workers is costly, especially with the spawn start method
# on Windows. The test loader is iterated once, so persistent workers would
# only hold processes open afterwards.
# persistent_workers=True is rejected by DataLoader when num_workers == 0,
# so derive the flag from the worker count.
keep_workers_alive = NUM_WORKERS > 0

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=NUM_WORKERS, persistent_workers=keep_workers_alive)

valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=NUM_WORKERS, persistent_workers=keep_workers_alive)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=NUM_WORKERS, persistent_workers=False)

In [None]:
# Let float32 matrix multiplications use faster reduced-precision internal
# algorithms (e.g. TF32) where the hardware supports them.
MATMUL_PRECISION = "high"
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [None]:
# Define model.
# Seed Python, NumPy and torch so that weight initialisation and the training
# loader's shuffling are reproducible. workers=True also lets the Trainer seed
# the DataLoader worker processes.
MODEL_SEED = 42
L.seed_everything(MODEL_SEED, workers=True)

n_classes = config['training']['n_classes']
lr = config['training']['learning_rate']

model = HENetWrapper(model=HENet(n_classes=n_classes),
                     num_classes=n_classes,
                     learning_rate=lr
                     )
# Release unused cached CUDA memory held by the allocator (e.g. from an earlier
# run in the same kernel); tensors that are still referenced are not freed.
torch.cuda.empty_cache()

In [None]:
# Define trainer.
# Training hyper-parameters not (yet) read from train_cfg.yaml.
MAX_EPOCHS = 40
SWA_LR = 1e-2
# 3 is EarlyStopping's default patience, spelled out so it is visible here.
# NOTE(review): StochasticWeightAveraging starts averaging at 80% of max_epochs
# by default (epoch 32 of 40). With patience 3, early stopping may end training
# before averaging ever begins; confirm that this combination is intended.
EARLY_STOPPING_PATIENCE = 3

training_callbacks = [
    # Stop once the logged validation loss stops decreasing.
    EarlyStopping(monitor="val_loss", mode="min", patience=EARLY_STOPPING_PATIENCE),
    StochasticWeightAveraging(swa_lrs=SWA_LR),
    LearningRateMonitor(logging_interval="step"),
    # Keep the `k` lowest-val_loss checkpoints plus the most recent one.
    # NOTE(review): assumes HENetWrapper logs both "val_loss" and "val_accuracy"
    # (both appear in the filename template); confirm.
    ModelCheckpoint(
        dirpath="./output",
        save_top_k=config['training']['k'],
        monitor="val_loss",
        filename="HENet-{epoch:02d}-{val_loss:.4f}-{val_accuracy:.4f}",
        save_last=True,
    ),
    # max_depth=-1: include every nesting level in the printed layer summary.
    ModelSummary(max_depth=-1),
]

trainer = L.Trainer(max_epochs=MAX_EPOCHS, callbacks=training_callbacks)

trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
    ckpt_path=None,  # train from scratch; pass a checkpoint path to resume
)

In [None]:
# Evaluate on the held-out test split. Because `model` is passed and no
# ckpt_path is given, Lightning uses the weights currently in memory after
# fit(), not the best checkpoint saved by ModelCheckpoint.
# NOTE(review): pass ckpt_path="best" to evaluate the best val_loss checkpoint.
test_metrics = trainer.test(model, dataloaders=test_loader)
test_metrics