# In [None]:
from typing import Any, Optional
import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT, TRAIN_DATALOADERS
from project.model.CNN1D import simple_cnn1d
from torchmetrics.functional import accuracy
class TrainModel(pl.LightningModule):
    """LightningModule that trains a classifier with a given loss and optimizer.

    Args:
        net: The ``torch.nn.Module`` to train. It is called as ``net(x)``;
            presumably it returns per-class scores of shape
            ``(batch, num_classes)`` (TODO confirm against the model).
        lossFunction: Callable ``(y_hat, y) -> loss``, e.g.
            ``torch.nn.CrossEntropyLoss()``.
        optimizer: Either a ready ``torch.optim.Optimizer`` built over
            ``net``'s parameters, or a factory (an optimizer class or a
            ``functools.partial``) that is called with ``net.parameters()``.
        num_classes: Number of target classes used for the accuracy metric.
        milestones: Scheduler-step indices at which ``MultiStepLR`` decays
            the learning rate. With Lightning's default interval these are
            epochs.
        gamma: Multiplicative learning-rate decay applied at each milestone.

    Raises:
        TypeError: If ``optimizer`` is neither an optimizer nor callable.
    """

    def __init__(
        self,
        net,
        lossFunction,
        optimizer,
        *args,
        num_classes: int = 3,
        milestones=(10, 20, 30, 200),
        gamma: float = 0.5,
        **kwargs,
    ):
        super().__init__()
        self.Model = net
        self.LossFunction = lossFunction
        self.num_classes = num_classes

        # Accept a ready optimizer instance, or build one from a factory
        # over the wrapped model's parameters.
        if isinstance(optimizer, torch.optim.Optimizer):
            self.Optimizer = optimizer
        elif callable(optimizer):
            self.Optimizer = optimizer(self.Model.parameters())
        else:
            raise TypeError(
                "optimizer must be a torch.optim.Optimizer or a callable "
                f"returning one, got {type(optimizer).__name__}"
            )

        # Build the scheduler exactly once so configure_optimizers() has no
        # side effects and never attaches a second scheduler to the optimizer.
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.Optimizer, milestones=list(milestones), gamma=gamma
        )

    def forward(self, x, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped model on ``x`` and return its output."""
        return self.Model(x)

    def _shared_step(self, batch):
        """Compute ``(loss, accuracy)`` for one ``(x, y)`` batch."""
        x, y = batch
        # The model is fed float32 inputs; the loss and metric get integer
        # class labels.
        x = x.to(torch.float32)
        y = y.long()
        y_hat = self(x)
        loss = self.LossFunction(y_hat, y)
        acc = accuracy(y_hat, y, task="multiclass", num_classes=self.num_classes)
        return loss, acc

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute the training loss for one batch and log metrics.

        Returns:
            ``{"loss": loss}``, which Lightning backpropagates.
        """
        # No manual zero_grad() here. Under automatic optimization Lightning
        # zeroes gradients itself, and zeroing here would wipe gradients that
        # are being accumulated across batches (accumulate_grad_batches > 1).
        loss, acc = self._shared_step(batch)

        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        # Read the live LR directly; state_dict() would build a copy on every step.
        self.log("lr", self.Optimizer.param_groups[0]["lr"], on_step=True, on_epoch=True)

        return {"loss": loss}

    def configure_optimizers(self, name: Optional[str] = None):
        """Return the optimizer and learning-rate scheduler to Lightning.

        Args:
            name: Unused. It is accepted only so that callers which pass it
                keep working; Lightning calls this hook with no arguments.

        Returns:
            ``([optimizer], [lr_scheduler])``. Lightning steps the scheduler
            once per epoch by default.
        """
        return [self.Optimizer], [self.lr_scheduler]

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Compute the validation loss and accuracy for one batch and log them.

        Returns:
            ``{"loss": loss}`` for this batch.
        """
        loss, acc = self._shared_step(batch)

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        # This is validation accuracy, but the "test_acc" key is kept as-is
        # for backward compatibility with anything that monitors it.
        self.log("test_acc", acc, on_step=False, on_epoch=True)

        return {"loss": loss}

        