# basic train demo



### general import

In [None]:
# General Tools
import numpy as np
import scipy as sp
import pandas as pd

# Machine Learning
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline



# Deep Learning
import torch
import torch.nn            as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchinfo
from torchmetrics.classification import MulticlassAccuracy
import torchvision
from torchvision.transforms import v2 as TorchVisionTrns

# Improve performance by benchmarking
torch.backends.cudnn.benchmark = True

# Miscellaneous
import copy
from enum import auto, Enum, unique
import math
import os
from platform import python_version
import random
import time

# Typing
from typing import Callable, Dict, Generator, List, Optional, Self, Set, Tuple, Union

# Visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Jupyter
from IPython import get_ipython
from IPython.display import HTML, Image
from IPython.display import display
from ipywidgets import Dropdown, FloatSlider, interact, IntSlider, Layout, SelectionSlider
from ipywidgets import interact

### utils import

In [None]:
import sys
sys.path.append('../')
from DataVisualization import PlotLabelsHistogram, PlotMnistImages
from armoTrain import basicModel

### get the data

In [None]:
# Template placeholder: the right hand side must be replaced with the code
# which loads the train, validation and test data sets.
dsTrain, dsVal ,dsTest = ## TODO: get the data (loader code is missing in this template)



#### split methods


In [None]:
# Split arrays into train and test sets (Scikit-Learn).
# `random_state` makes the split reproducible.
mXTrain, mXTest, vYTrain, vYTest = train_test_split(mX, vY, train_size = trainRatio, test_size = testRatio, random_state = seedNum)

# Split data to train and validation:
# Wrap the tensors in a data set and split that same data set randomly.
oDataSet        = torch.utils.data.TensorDataset(mX, vF)
dsTrain, dsVal  = torch.utils.data.random_split(oDataSet, [numSamplesTrain, numSamplesVal]) #<! Splits the data set built on the line above


### data preprocess

In [None]:
# Calculate the Standardization Parameters
# Both statistics are estimated from the train set only, so the validation
# and test sets are standardized with exactly the values the model was trained on.
# The division by 255 assumes 8 bit pixel values - TODO confirm for the data in use.
vMean = np.mean(dsTrain.data / 255.0, axis = (0, 1, 2)) #<! Reduces over the first 3 axes
vStd  = np.std(dsTrain.data / 255.0, axis = (0, 1, 2))  #<! Same data and axes as the mean

print('µ =', vMean)
print('σ =', vStd)

### Transforms

In [None]:
## V1
# Transform pipeline built with the classic `torchvision.transforms` API.
# Normalization uses the statistics calculated above (`vMean`, `vStd`).
oDataTrns = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(vMean, vStd),
    ])
# Update the DS transformer
dsTrain.transform = oDataTrns
dsVal.transform   = oDataTrns
dsTest.transform  = oDataTrns

## V2
# Transform pipeline built with the `torchvision.transforms.v2` API.
# NOTE: When both variants are executed, V2 overrides the V1 assignment above.
oDataTrns = TorchVisionTrns.Compose([
    TorchVisionTrns.ToImage(),
    TorchVisionTrns.ToDtype(torch.float32, scale = True),
    TorchVisionTrns.Normalize(mean = vMean, std = vStd),
])
# Update the DS transformer
dsTrain.transform = oDataTrns
dsVal.transform   = oDataTrns
dsTest.transform  = oDataTrns

### data loaders

In [None]:
# Train loader shuffles the samples; the validation loader keeps their order
# and uses a batch twice as large.
# `persistent_workers` is enabled only when worker processes exist (`numWork > 0`),
# as PyTorch raises a `ValueError` for persistent workers with `num_workers = 0`.
dlTrain = torch.utils.data.DataLoader(dsTrain, shuffle = True, batch_size = 1 * batchSize, num_workers = numWork, persistent_workers = (numWork > 0))
dlVal   = torch.utils.data.DataLoader(dsVal, shuffle = False, batch_size = 2 * batchSize, num_workers = numWork, persistent_workers = (numWork > 0))


In [None]:
# Iterate on the Loader
# Pulls a single batch from the train loader and prints the dimensions of
# its features and labels. The iterator is temporary (not kept in a variable).
tX, vY = next(iter(dlTrain)) #<! PyTorch Tensors
print(f'The batch features dimensions: {tX.shape}')
print(f'The batch labels dimensions: {vY.shape}')

### layers to model

In [None]:
class CustomLAYER():
    """
    Template of a custom layer with explicit forward and backward passes.
    The `????` placeholders must be replaced with the actual computation
    of the layer before the class can be used.
    """
    def __init__( self ) -> None:
        self.mX = None #<! Required for the backward pass
        self.dGrads = {} #<! Starts empty; presumably holds the parameters gradients - TODO confirm
    
    def Forward( self: Self, mX: np.ndarray ) -> np.ndarray:
        """
        Forward pass: stores the input on the instance and returns the output `mZ`.
        """
        self.mX = mX                 #<! Store for Backward
        mZ      = ????
        return mZ
    
    def Backward( self: Self, mDz: np.ndarray ) -> np.ndarray:
        """
        Backward pass: reads the input stored by `Forward()` and returns `mDx`.
        Presumably `mDz` / `mDx` are the gradients with respect to the output / input - TODO confirm.
        """
        mX    = self.mX #<! The input stored by `Forward()`
        mDx = ????
        return mDx

In [None]:
class CustomBlock( nn.Module ):
    """
    Template of a custom PyTorch block.
    The `...` placeholders (layers arguments and the forward computation)
    must be filled in before the class can be used.
    """
    def __init__( self, numChnl: int ) -> None:
        super(CustomBlock, self).__init__()
        
        # Layers of the block (arguments are placeholders).
        # NOTE(review): `CustomLAYER` in this file is a plain class working on NumPy arrays,
        # not an `nn.Module` - confirm it is meant to be used inside this block.
        self.oConv2D1       = nn.Conv2d(...
        self.oBatchNorm1    = nn.BatchNorm2d(...
        self.oReLU1         = nn.ReLU(...
        self.CustomLAYER    = CustomLAYER(...
            
    def forward( self: Self, tX: torch.Tensor ) -> torch.Tensor:
        """
        Forward pass of the block: returns `tY` (computation is a placeholder).
        """
        
        tY = ...
        tY += ... #<! In place accumulation; presumably a skip connection - TODO confirm
		
        return tY

In [None]:
def BuildModel( nC: int ) -> nn.Module:
    """
    Template of the model builder: returns an `nn.Sequential` model.
    The layers arguments are `...` placeholders which must be filled in.
    `nC` is called with `len(L_CLASSES)` in this file - presumably the number of classes.
    """
    
    oModel = nn.Sequential(

        nn.Identity(),

        nn.Conv2d(...
        
        CustomBlock(...
        
    )

    return oModel

In [None]:
# Build the model; the size of the classes list is the builder's `nC` argument
oModel = BuildModel(nC = len(L_CLASSES))

### model info summary

In [None]:
# Summarize the model (on the CPU) for an input of a single batch.
# The image size tuple is given in reversed order.
torchinfo.summary(
    oModel,
    (batchSize, *T_IMG_SIZE[::-1]),
    col_names    = ['kernel_size', 'output_size', 'num_params'],
    device       = 'cpu',
    row_settings = ['depth', 'var_names'],
)

### target HW

In [None]:
# Select the run device: CUDA when available, CPU otherwise
if torch.cuda.is_available():
    runDevice = torch.device('cuda:0') #<! The 1st CUDA device
else:
    runDevice = torch.device('cpu')

# Move the model to the selected device
oModel = oModel.to(runDevice)

### loss

In [None]:
class ObjLocLoss( nn.Module ):
    def __init__( self, numCls: int, λ: float, ϵ: float = 0.0 ) -> None:
        super(ObjLocLoss, self).__init__()
        self.numCls     = numCls
    
    def forward( self: Self, mYHat: torch.Tensor, mY: torch.Tensor ) -> torch.Tensor:
        ## 
        return lossVal

# Build the loss object and move it to the run device
hL = ObjLocLoss(
    numCls = len(L_CLASSES),
    λ      = λ,
    ϵ      = ϵ,
).to(runDevice)


### score

In [None]:
class ObjLocScore( nn.Module ):
    def __init__( self, numCls: int ) -> None:
        super(ObjLocScore, self).__init__()
        self.numCls = numCls
    
    def forward( self: Self, mYHat: torch.Tensor, mY: torch.Tensor ) -> Tuple[float, float, float]:
        batchSize = mYHat.shape[0]
        ## 
        return valScore

# Build the score object and move it to the run device
hS = ObjLocScore(numCls = len(L_CLASSES)).to(runDevice)


### optimizer

In [None]:
# AdamW optimizer over all of the model's parameters
oOpt = torch.optim.AdamW(
    oModel.parameters(),
    lr           = 1e-5,
    betas        = (0.9, 0.99),
    weight_decay = 1e-5,
)

### scheduler

In [None]:
# One cycle learning rate policy with a peak learning rate of 5e-4.
# NOTE(review): `total_steps = numEpochs` fits a scheduler which is stepped once per epoch;
# confirm the training loop does not step it per batch.
oSch = torch.optim.lr_scheduler.OneCycleLR(
    oOpt,
    max_lr      = 5e-4,
    total_steps = numEpochs,
)


### train

In [None]:
# Run the training. The 1st returned item is discarded.
# NOTE(review): `TrainModel` is not imported in the visible cells - confirm it is in scope.
(_, lTrainLoss, lTrainScore, lValLoss, lValScore, lLearnRate) = TrainModel(
    oModel, dlTrain, dlVal, oOpt, numEpochs, hL, hS,
    oSch = oSch,
)