# In [None]:  (Jupyter cell marker left over from the notebook export)
from pathlib import Path  # Importing the Path class from the pathlib module
import numpy as np  # Importing NumPy, a library for numerical computing in Python
import math  # Importing the math module for basic mathematical functions
from itertools import groupby  # Importing groupby function from the itertools module to group elements
import h5py  # Importing the h5py package for reading and writing HDF5 files
import unicodedata  # Importing the unicodedata module to work with Unicode characters
import cv2  # Importing the OpenCV library for computer vision tasks
import torch  # Importing the PyTorch library for tensor computations and deep learning
from torch import nn  # Importing the neural network module from PyTorch
from torch.autograd import Variable  # Importing the Variable class for automatic differentiation
from torch.utils.data import Dataset  # Importing the Dataset class for creating custom datasets
import time  # Importing the time module to measure time
import albumentations  # Importing the albumentations library for image augmentation
import albumentations.pytorch  # Importing the PyTorch integration for albumentations
import timm  # Importing the timm library for using pre-trained models and architectures
import argparse  # Importing the argparse module for handling command-line arguments
import string  # Importing the string module to work with strings and text
import json  # Importing the json module for working with JSON data
#os module which provides a way of using operating system dependent functionality like reading or writing to the file system.        
import os
#the datetime module which provides classes for working with dates and times.
import datetime
#the string module which provides a collection of string constants and helper functions.
import string


# Seed NumPy's and PyTorch's global random number generators with the same
# fixed value so that random draws start from a known state.
np.random.seed(13)
torch.manual_seed(13)
# Turn on the cuDNN benchmark flag.
torch.backends.cudnn.benchmark = True

def parse_args():
    """Parse the command-line options of the 'ocr' training script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, holding
        ``target_path``, ``name_file``, ``file_path``, ``epochs``,
        ``batch_size``, ``optimizer``, ``lr``, ``charset_base``, ``device``
        and ``finetune``.
    """
    parser = argparse.ArgumentParser(description='ocr')

    # One row per option: (flag, value type, default, help text).
    option_table = (
        ('--target_path', str, "saved_models/", 'target folder'),
        ('--name_file', str, "resnest_backbone", 'name of the state dict'),
        ('--file_path', str, "", 'path of the hdf5 file'),
        ('--epochs', int, 200, 'Number of total epochs'),
        ('--batch_size', int, 36, 'Size of one batch'),
        ('--optimizer', int, 0, 'load the optimizer as well'),
        ('--lr', float, 0.00006, 'Initial learning rate'),
        ('--charset_base', str, string.printable[:95], 'path to vocab'),
        ('--device', int, 0, 'cuda device'),
        ('--finetune', str, '', 'pretrain model path'),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )

    return parser.parse_args()


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The encoding table is computed once in the constructor and stored in the
    non-trainable buffer ``pe`` of shape ``(max_len, 1, d_model)``. Even
    feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: size of the last (feature) dimension of the input. Both even
            and odd values are supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=46):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Column vector of position indices 0 .. max_len - 1.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per sine/cosine pair: 10000^(-2i / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even column,
        # so the cosine part is truncated to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (max_len, 1, d_model) so it broadcasts over dimension 1 of
        # the input. A buffer follows the module across devices and is saved
        # in the state dict without being a learnable parameter.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        ``x`` is indexed by position along dimension 0 and must not contain
        more than ``max_len`` positions, otherwise the addition cannot
        broadcast.
        """
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
    
"""
Defines a class OCR that inherits from nn.Module.
Initializes the constructor with several arguments: vocab_len (the length of the vocabulary), 
max_len (the maximum length of the input sequence), hidden_dim (the number of hidden dimensions in 
the transformer model), nheads (the number of heads in the transformer model), num_encoder_layers (the number
of encoder layers in the transformer model), and num_decoder_layers (the number of decoder layers in the
transformer model).
Calls the constructor of the base class nn.Module.
"""    
    
# We define an OCR (Optical Character Recognition) model that uses a transformer-based architecture for text recognition.
# The OCR class inherits from 'nn.Module', which is the base class for all neural network modules in PyTorch.
class OCR(nn.Module):
    """Image-to-text recogniser: ResNeSt feature extractor + ``nn.Transformer``.

    The image is turned into a feature map by the backbone, projected to
    ``hidden_dim`` channels, flattened into a sequence and fed to the
    transformer encoder. The decoder consumes the embedded target tokens and
    a linear head maps its output to vocabulary logits.

    Args:
        vocab_len: number of entries in the vocabulary (embedding rows and
            logits per position).
        max_len: number of positions precomputed by the decoder-side
            ``PositionalEncoding``.
        hidden_dim: transformer model width.
        nheads: number of attention heads.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
    """

    def __init__(
        self,
        vocab_len,
        max_len,
        hidden_dim,
        nheads,
        num_encoder_layers,
        num_decoder_layers,
    ):
        super().__init__()
        # Feature extractor: timm's "resnest101e" created with pretrained
        # weights. Its classification layer is not used, so it is removed.
        self.backbone = timm.create_model("resnest101e", pretrained=True)
        del self.backbone.fc

        # 1x1 convolution projecting the 2048 backbone channels to hidden_dim.
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # Encoder/decoder transformer operating on hidden_dim-wide vectors.
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # Output head: one logit per vocabulary entry.
        self.vocab = nn.Linear(hidden_dim, vocab_len)

        # Token embedding for the decoder input.
        self.decoder = nn.Embedding(vocab_len, hidden_dim)

        # Sinusoidal positions (dropout 0.2) added to the embedded targets.
        self.query_pos = PositionalEncoding(hidden_dim, 0.2, max_len)

        # Learned 2-D position embeddings for the feature map. Each covers at
        # most 15 rows / 15 columns and contributes half of hidden_dim.
        self.row_embed = nn.Parameter(torch.rand(15, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(15, hidden_dim // 2))

        # Cached causal mask for the decoder, built lazily in forward().
        self.trg_mask = None

    def get_feature(self, x):
        """Run the backbone stem and its four residual stages on ``x``."""
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.act1(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        return x

    def generate_square_subsequent_mask(self, sz):
        """Return a ``(sz, sz)`` additive causal mask.

        Entries strictly above the diagonal are ``-inf`` (future positions
        are hidden); the diagonal and everything below it are ``0``.
        """
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def make_len_mask(self, inp):
        """Return a boolean mask that is True where ``inp`` equals 0.

        Index 0 is treated as padding. The result has the first two
        dimensions of ``inp`` swapped.
        """
        return (inp == 0).transpose(0, 1)

    def forward(self, inputs, trg):
        """Predict vocabulary logits for every target position.

        Args:
            inputs: image batch accepted by the backbone.
            trg: integer token indices of shape ``(batch, seq_len)``; index 0
                marks padding.

        Returns:
            Logits from the ``vocab`` head with the batch dimension first.
        """
        x = self.get_feature(inputs)

        # Project the backbone channels (2048) down to hidden_dim.
        h = self.conv(x)

        # Build one position vector per feature-map cell by concatenating the
        # column embedding and the row embedding of that cell.
        bs, _, H, W = h.shape
        pos = (
            torch.cat(
                [
                    self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
                    self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
                ],
                dim=-1,
            )
            .flatten(0, 1)
            .unsqueeze(1)
        )
        # Flatten the map to (H*W, batch, hidden_dim); the features are scaled
        # by 0.1 before the position vectors are added.
        h = pos + 0.1 * h.flatten(2).permute(2, 0, 1)

        # Rebuild the cached causal mask whenever the target sequence length
        # (dimension 1 of trg) or the device changes. Comparing against the
        # batch size would leave a stale, wrongly sized mask in place.
        seq_len = trg.shape[1]
        if (
            self.trg_mask is None
            or self.trg_mask.size(0) != seq_len
            or self.trg_mask.device != trg.device
        ):
            self.trg_mask = self.generate_square_subsequent_mask(seq_len).to(
                trg.device
            )

        # Padding mask computed from the raw indices, before embedding.
        trg_pad_mask = self.make_len_mask(trg)

        # Embed the target tokens, move the sequence dimension first and add
        # the sinusoidal positions.
        trg = self.decoder(trg)
        trg = self.query_pos(trg.permute(1, 0, 2))

        # make_len_mask swapped the first two dimensions; swap them back so
        # the key padding mask is (batch, seq_len).
        output = self.transformer(h, trg, tgt_mask=self.trg_mask,
                                  tgt_key_padding_mask=trg_pad_mask.permute(1, 0))

        # Put the batch dimension first again and map to vocabulary logits.
        return self.vocab(output.transpose(0, 1))

    
    
# Defines a function make_model that takes in several hyperparameters and returns an instance of the OCR class
# with those hyperparameters. The OCR class is defined in the code above and contains the actual model architecture.
def make_model(
    vocab_len,
    maxlen,
    hidden_dim=256,
    nheads=6,
    num_encoder_layers=2,
    num_decoder_layers=6,
):
    """Construct an ``OCR`` model from the given hyperparameters.

    Args:
        vocab_len: size of the vocabulary.
        maxlen: forwarded to ``OCR`` as ``max_len``.
        hidden_dim: transformer model width.
        nheads: number of attention heads.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.

    Returns:
        The newly constructed ``OCR`` instance.

    NOTE(review): the default ``hidden_dim=256`` is not divisible by the
    default ``nheads=6``; PyTorch's multi-head attention requires the width
    to be divisible by the head count, so the two defaults presumably cannot
    be used together - confirm. The call site in this file passes
    ``hidden_dim=384``.
    """
    model = OCR(
        vocab_len=vocab_len,
        max_len=maxlen,
        hidden_dim=hidden_dim,
        nheads=nheads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
    )
    return model

    
    
  
# Define a Tokenizer class for managing tokens and character set properties
class Tokenizer:
    """Map text to vocabulary indices and back.

    The vocabulary is the four special tokens PAD ("¶"), UNK ("¤"), "SOS" and
    "EOS" (indices 0-3, in that order) followed by the characters of
    ``chars``.

    Args:
        chars: iterable of characters that make up the charset.
        max_text_length: stored as ``maxlen``; not enforced by this class.
    """

    def __init__(self, chars, max_text_length=630):
        self.PAD_TK, self.UNK_TK, self.SOS, self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = (
            [self.PAD_TK] + [self.UNK_TK] + [self.SOS] + [self.EOS] + list(chars)
        )
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

        # Token -> index lookup. setdefault keeps the first occurrence of a
        # duplicated token, which is the index list.index() would return.
        self._index = {}
        for position, token in enumerate(self.chars):
            self._index.setdefault(token, position)

    def encode(self, text):
        """Encode ``text`` as a NumPy array of vocabulary indices.

        The text is converted with ``str()``, stripped of surrounding
        whitespace, and wrapped in the SOS and EOS tokens. Characters that
        are not in the vocabulary are encoded as the UNK index instead of
        raising.
        """
        tokens = [self.SOS] + list(str(text).strip()) + [self.EOS]
        # dict.get supplies the UNK fallback; list.index would raise
        # ValueError for an unknown character rather than return -1.
        encoded = [self._index.get(token, self.UNK) for token in tokens]
        return np.asarray(encoded)

    def decode(self, text):
        """Decode a sequence of indices back to a string.

        Negative indices are skipped. PAD and UNK symbols are removed from
        the result; the "SOS" and "EOS" markers are left in place.
        """
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        return self.remove_tokens(decoded)

    def remove_tokens(self, text):
        """Return ``text`` with the PAD and UNK symbols removed."""
        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")

    

    

# Parse the command-line options once and echo them so each run's log records
# its configuration.
args = parse_args()
print(args)

# Mini-batch size and number of training epochs, taken from the command line.
batch_size = args.batch_size
epochs = args.epochs

# Maximum text length handed to the tokenizer.
max_text_length = 635

# Characters the tokenizer can encode: digits, lowercase ASCII letters,
# punctuation and the space character.
# NOTE: the `--charset_base` command-line option is parsed but not consulted
# here; the charset is hard-coded.
charset_base = '0123456789abcdefghijklmnopqrstuvwxyz!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '

# Use the requested CUDA device when CUDA is usable; otherwise fall back to
# the CPU instead of failing later when the model is moved to the device.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(args.device))
else:
    print("CUDA is not available; falling back to CPU")
    device = torch.device("cpu")

# Tokenizer built from the charset above and the maximum text length.
tokenizer = Tokenizer(charset_base, max_text_length)








# Two Albumentations pipelines: `transform_train` (augmentation followed by
# normalisation) for the training split and `transform_valid` (normalisation
# only) for the validation split.

# Training pipeline.
transform_train = albumentations.Compose(
    transforms=[
        # With probability 0.5, apply exactly one of the augmentations below.
        # They all carry the same weight (p=1).
        albumentations.OneOf(
            transforms=[
                # Motion blur, blur_limit 7.
                albumentations.MotionBlur(blur_limit=7, p=1),
                # Optical distortion, distort_limit 0.05.
                albumentations.OpticalDistortion(distort_limit=0.05, p=1),
                # Gaussian noise, variance limit (10.0, 100.0).
                albumentations.GaussNoise(var_limit=(10.0, 100.0), p=1),
                # Histogram equalisation.
                albumentations.Equalize(p=1),
                # Solarize with threshold 50.
                albumentations.Solarize(threshold=50, p=1),
                # Random brightness/contrast, brightness_limit 0.2.
                albumentations.RandomBrightnessContrast(brightness_limit=0.2, p=1),
                # Downscale with a scale between 0.8 and 0.9.
                albumentations.Downscale(scale_min=0.8, scale_max=0.9, p=1),
            ],
            p=0.5,
        ),
        # Normalize() is called without arguments, so the library's default
        # mean/std apply.
        # NOTE(review): presumably the ImageNet statistics rather than
        # 0.5 / 0.5 - confirm for the installed albumentations version.
        albumentations.Normalize(),
        # Convert the image to a PyTorch tensor.
        albumentations.pytorch.ToTensorV2(),
    ]
)

# Validation pipeline: the same normalisation and tensor conversion, with no
# augmentation.
transform_valid = albumentations.Compose(
    transforms=[
        albumentations.Normalize(),
        albumentations.pytorch.ToTensorV2(),
    ]
)



# Number of layers in the transformer encoder and decoder stacks.
num_encoder_layers, num_decoder_layers = 2, 6

# Build the OCR network. The vocabulary size and maximum text length come from
# the tokenizer; the width (384) and head count (6) are fixed here.
ddp_model = make_model(
    tokenizer.vocab_size,
    tokenizer.maxlen,
    hidden_dim=384,
    nheads=6,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
)


# takes a batch of image and label data and pads the labels to the maximum length of the batch,
#stacks the labels and images, and returns the batch.
def collate_fn(batch):
    """Collate ``(image, label)`` pairs into two stacked tensors.

    Every label is right-padded with zeros up to the length of the longest
    label in the batch, then the labels and the images are each stacked
    along a new leading batch dimension.

    Args:
        batch: sequence of ``(image_tensor, label_tensor)`` pairs.

    Returns:
        ``(images, labels)`` as stacked tensors.
    """
    images, targets = zip(*batch)
    longest = max(len(target) for target in targets)

    padded_targets = []
    for target in targets:
        missing = longest - len(target)
        padded_targets.append(
            torch.nn.functional.pad(target, (0, missing), 'constant', 0)
        )

    return torch.stack(list(images)), torch.stack(padded_targets)

# Path given with --file_path (its help text describes it as the hdf5 file).
file_path = args.file_path

# NOTE(review): DataGenerator is neither defined nor imported in the visible
# part of this file - confirm where it comes from.

# Validation loader: "valid" split, validation transform, batches four times
# the training batch size, default (unshuffled) order.
val_loader = torch.utils.data.DataLoader(
    dataset=DataGenerator("{}".format(file_path), "valid", transform_valid, tokenizer),
    batch_size=batch_size * 4,
    num_workers=1,
    collate_fn=collate_fn,
)

# Training loader: "train" split, augmenting transform, shuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=DataGenerator("{}".format(file_path), "train", transform_train, tokenizer),
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    collate_fn=collate_fn,
)

# Move the network to the selected device.
model = ddp_model.to(device)


#The nn.Module is a base class for all neural network modules in PyTorch.
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For each row the target class receives probability ``1 - smoothing`` and
    every other class receives ``smoothing / (size - 2)``. The padding class
    column is then set to 0, and rows whose target is the padding index are
    zeroed entirely so they contribute nothing to the loss.

    Args:
        size: number of classes (vocabulary size).
        padding_idx: index of the padding token.
        smoothing: probability mass spread over the non-target classes.
    """

    def __init__(self, size, padding_idx=0, smoothing=0.0):
        super().__init__()
        # Summed (not averaged) KL divergence. reduction="sum" is the current
        # spelling of the deprecated size_average=False.
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Last smoothed target distribution, kept for inspection.
        self.true_dist = None

    def forward(self, x, target):
        """Return the summed KL divergence between ``x`` and the smoothed target.

        Args:
            x: log-probabilities of shape ``(N, size)``.
            target: class indices of shape ``(N,)``.
        """
        assert x.size(1) == self.size

        # Build the target distribution outside the autograd graph.
        true_dist = torch.full_like(x, self.smoothing / (self.size - 2)).detach()
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)

        # The padding class never receives probability mass.
        true_dist[:, self.padding_idx] = 0
        # Zero whole rows whose target is padding. A boolean row mask behaves
        # the same for zero, one or many padded rows.
        padded_rows = (target == self.padding_idx).unsqueeze(1)
        true_dist.masked_fill_(padded_rows, 0.0)
        self.true_dist = true_dist

        return self.criterion(x, true_dist)
    

# Label-smoothing strength used by the training criterion.
smoothing = 0.1

# Label-smoothed KL-divergence loss over the tokenizer vocabulary; index 0 is
# the padding token.
criterion = LabelSmoothing(
    size=tokenizer.vocab_size, padding_idx=0, smoothing=smoothing
)
criterion.to(device)

# Initial learning rate from the command line.
lr = args.lr

# Multiplicative factor applied to the learning rate on each scheduler step.
scheduler_factor = 0.8

# Optionally start from a saved checkpoint. The file is looked up under
# "saved_models/" and loaded on the CPU; load_state_dict copies the weights
# into the model.
checkpoint = None
if args.finetune:
    checkpoint = torch.load("saved_models/{}".format(args.finetune), map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0004)

# Optionally restore the optimizer state from the same checkpoint. Without
# --finetune no checkpoint has been loaded, so fail with an explicit message
# instead of a NameError on an undefined variable.
if args.optimizer:
    if checkpoint is None:
        raise ValueError(
            "--optimizer requires --finetune: there is no checkpoint to "
            "restore the optimizer state from"
        )
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Each lr_scheduler.step() call multiplies the learning rate by
# scheduler_factor (step size 1).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=scheduler_factor)


def train(model, criterion, optimizer, dataloader, scaler):
    """Run one training epoch and return the mean loss per batch.

    Uses the module-level ``device`` and ``tokenizer``. The forward pass and
    the loss run under CUDA float16 autocast; the backward pass and the
    optimizer step go through ``scaler``.

    Args:
        model: network called as ``model(images, target_prefix)``.
        criterion: loss called as ``criterion(log_probs, flat_targets)``.
        optimizer: optimizer stepped once per batch.
        dataloader: iterable of ``(images, labels)`` batches.
        scaler: gradient scaler used for mixed precision.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0

    for images, targets in dataloader:
        images = images.to(device).float()
        targets = targets.to(device).long()

        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # Teacher forcing: the decoder sees every token except the last
            # and is scored against every token except the first.
            logits = model(images, targets[:, :-1])
            log_probs = logits.log_softmax(-1).contiguous().view(
                -1, tokenizer.vocab_size
            )
            shifted_targets = targets[:, 1:].contiguous().view(-1).long()
            loss = criterion(log_probs, shifted_targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    return running_loss / len(dataloader)



# Define the 'evaluate' function which takes in three arguments:
# 1. model - the neural network model to be evaluated
# 2. criterion - the loss function used to compute the error between the model predictions and the ground truth
# 3. dataloader - an iterable object that provides batches of data and their corresponding labels
def evaluate(
    model,
    criterion,
    dataloader,):
    """Compute the mean loss per batch over ``dataloader`` without gradients.

    Uses the module-level ``device`` and ``tokenizer``. The model is put in
    evaluation mode and the whole pass runs under ``torch.no_grad()``.

    Args:
        model: network called as ``model(images, target_prefix)``.
        criterion: loss called as ``criterion(log_probs, flat_targets)``.
        dataloader: iterable of ``(images, labels)`` batches.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.to(device)

            # The decoder sees every token except the last and is scored
            # against every token except the first.
            logits = model(images.float(), targets.long()[:, :-1])
            log_probs = logits.log_softmax(-1).contiguous().view(
                -1, tokenizer.vocab_size
            )
            shifted_targets = targets[:, 1:].contiguous().view(-1).long()

            running_loss += criterion(log_probs, shifted_targets).item()

    return running_loss / len(dataloader)




# Define a function to calculate the elapsed time in minutes and seconds
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into minutes and seconds.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as integers, each truncated toward zero.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds


# Checkpoint location: files are written as target_path + name_file + suffix,
# where name_file is the --name_file value followed by the data file path
# with ".hdf5" removed.
target_path = args.target_path
name_file = args.name_file + args.file_path.replace(".hdf5","")

# Create the directory the checkpoints go into before training starts, so
# that the first torch.save cannot fail on a missing folder after a full
# epoch has already been spent.
checkpoint_dir = os.path.dirname(target_path + name_file)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Best values seen so far. best_CER is initialised but not updated by this
# loop.
best_CER = np.inf
best_valid_loss = np.inf

# Gradient scaler for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()

print("Started training")
print(name_file)

# Number of consecutive epochs without an improvement in validation loss.
c = 0

for epoch in range(epochs):

    start_time = time.time()

    train_loss = train(model, criterion, optimizer, train_loader, scaler)
    valid_loss = evaluate(model, criterion, val_loader)

    epoch_mins, epoch_secs = epoch_time(start_time, time.time())
    print(f"Epoch: {epoch+1:02}", "learning rate{}".format(lr_scheduler.get_last_lr()))
    print(f"Time: {epoch_mins}m {epoch_secs}s")
    print(f"Train Loss: {train_loss:.3f}")
    print(f"Val   Loss: {valid_loss:.3f}")

    c += 1

    # Decide which checkpoint, if any, this epoch produces:
    #   - validation loss improved      -> "best_loss.pt" (and reset c)
    #   - no improvement, not finetune  -> "last.pt"
    #   - no improvement while finetune -> nothing is written
    checkpoint_suffix = None
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        c = 0
        print("saving it", best_valid_loss)
        checkpoint_suffix = "best_loss.pt"
    elif not args.finetune:
        print("saving for last loss")
        checkpoint_suffix = "last.pt"

    if checkpoint_suffix is not None:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": valid_loss,
                "best_loss": best_valid_loss,
            },
            target_path + name_file + checkpoint_suffix,
        )

    # After five epochs in a row without improvement, decay the learning rate
    # and start counting again.
    if c > 4:
        lr_scheduler.step()
        c = 0

