In [None]:
!mkdir mpi
!tar -xf "/content/drive/MyDrive/eye_w/MPIIGaze.tar.gz" -C mpi

In [None]:
!rm -r eye

In [None]:
!mkdir eye
!git clone https://github.com/Etzelkut/Eye-Tracking.git /content/eye

In [None]:
!pip install local-attention
!pip install axial-positional-embedding
!pip install adabelief-pytorch
!pip install ranger-adabelief
!pip install pytorch-lightning
!pip install comet-ml
!pip install einops

In [None]:
from comet_ml import Experiment
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

In [None]:
from eye.basem.basic_dependency import *
from eye.gaze_track.pl_model_train import Gaze_Track_pl

from eye.gaze_track.dataset_mpi import Dataset_mpi_pl
from eye.gaze_track.pl_mpi_train import MPI_Gaze_Track_pl

In [None]:
class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N global steps, instead of Lightning's default that
    checkpoints based on validation loss.

    Files are named ``{prefix}_{epoch}_{global_step}.ckpt`` unless
    ``use_modelcheckpoint_filename`` is set.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="N-Step-Checkpoint",
        use_modelcheckpoint_filename=False,
        pathh = False,
    ):
        """
        Args:
            save_step_frequency: how often to save, in global steps; a checkpoint
                is written whenever ``trainer.global_step`` is a positive multiple
                of this value. Must be > 0.
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours (falls back to ours when the
                trainer has no ModelCheckpoint filename).
            pathh: directory the checkpoints are written to. When falsy (the
                default), ``trainer.default_root_dir`` is used.

        Raises:
            ValueError: if save_step_frequency is not positive (it would otherwise
                surface later as a ZeroDivisionError in the middle of training).
        """
        super().__init__()
        if save_step_frequency <= 0:
            raise ValueError(
                f"save_step_frequency must be positive, got {save_step_frequency!r}"
            )
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename
        self.pathh = pathh

    def on_batch_end(self, trainer: pl.Trainer, _):
        """Check after every train batch whether a checkpoint is due, and save it.

        Uses the Lightning 1.x ``on_batch_end`` hook (deprecated in later
        releases in favour of ``on_train_batch_end``).
        """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        # Skip global step 0: the run has essentially not trained yet, so that
        # checkpoint would only cost disk space.
        if global_step == 0 or global_step % self.save_step_frequency != 0:
            return

        # ModelCheckpoint.filename may be None (or the callback may be absent);
        # fall back to our own naming instead of crashing in os.path.join.
        ckpt_callback_name = getattr(trainer.checkpoint_callback, "filename", None)
        if self.use_modelcheckpoint_filename and ckpt_callback_name:
            filename = ckpt_callback_name
        else:
            filename = f"{self.prefix}_{epoch}_{global_step}.ckpt"

        directory = self.pathh if self.pathh else trainer.default_root_dir
        ckpt_path = os.path.join(directory, filename)
        trainer.save_checkpoint(ckpt_path)
        print(f"Saved step checkpoint: {ckpt_path}")

In [None]:
# Optimizer / schedule hyperparameters for the MPIIGaze fine-tuning run.
# NOTE: key spellings (including "eplison_belief") are presumably read by the
# eye.gaze_track modules, so they must stay exactly as written.
mpi_training = dict(
    optimizer="adamW",   # one of: "belief", "ranger_belief", "adam", "adamW"
    lr=3e-4,
    epochs=40,
    add_sch=False,       # presumably toggles an LR scheduler — TODO confirm
    # --- belief-optimizer options ---
    eplison_belief=1e-16,
    beta=[0.9, 0.999],   # not used
    weight_decouple=True,
    weight_decay=1e-4,
    rectify=True,
)

# Dataset / model settings. `epochs` is copied from the optimizer config so the
# two dictionaries cannot disagree, and the full optimizer config is nested
# under "training".
mpii_set = dict(
    size=(96, 160),                # input size (presumably height, width — TODO confirm)
    main_path="./mpi/MPIIGaze",    # under the `tar -C mpi` extraction directory
    batch_size=64,
    num_workers=2,
    dataloader_shuffle=True,
    epochs=mpi_training["epochs"],
    training=mpi_training,
    lock_main_weights=True,
    unlock_tokens=True,
    new_gaze_weights=True,
    mlp_drop=0.1,                  # 0.05 was an alternative value
    updated_gaze=False,
    loss_function="mse",           # "mse" or "mae"
    add_encoder_for_gaze=True,
    add_token=False,               # cannot be enabled yet: ViT_pos_emb would need to add the Parameter
)

In [None]:
# Fine-tune the pretrained gaze-tracking model on MPIIGaze.
# Writes roughly one checkpoint per epoch into `root_dir`, then a final
# checkpoint plus a timestamped copy.
import os
import shutil
from datetime import datetime

seed_v = 42
seed_everything(seed_v)

root_dir = "/content/drive/MyDrive/eye_w/weights"
naming = "mpi_3e4_40_lockedMain_unlockToken_newGaze_addedEncoder_mse_"

# Base gaze-tracking model that is wrapped and fine-tuned below.
pre_trained_name_file = "/content/drive/MyDrive/eye_w/weights/trans_2_3e4_att_256_1learnparam_noNorm_land_alt_MdataN-Step-Checkpoint_29_60088.ckpt"
pretrained_model = Gaze_Track_pl.load_from_checkpoint(pre_trained_name_file)

# Never hardcode the Comet API key in the notebook; set COMET_API_KEY in the
# environment instead. NOTE(review): the key previously committed here must be
# rotated. Per the CometLogger docs, a missing key falls back to the comet_ml
# config and then to offline logging under `save_dir`.
comet_logger = CometLogger(
    save_dir="/content/log/",
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="gaze2",
    workspace="etzelkut",
    experiment_name=naming,
)

dataset_mpi = Dataset_mpi_pl(mpii_set)
dataset_mpi.prepare_data()
dataset_mpi.setup()
# Training batches per epoch: used as the checkpoint interval and passed to the model.
steps_per_epoch = len(dataset_mpi.train_dataloader())

every_epoch = CheckpointEveryNSteps(
    save_step_frequency=steps_per_epoch,
    use_modelcheckpoint_filename=False,
    pathh=root_dir,
    prefix=naming + "Step_",
)

mpi_model = MPI_Gaze_Track_pl(mpii_set, model=pretrained_model, steps_per_epoch=steps_per_epoch)

trainer = Trainer(
    callbacks=[every_epoch],
    logger=comet_logger,
    gpus=1,
    profiler="simple",
    check_val_every_n_epoch=1,
    max_epochs=mpii_set["epochs"],
    progress_bar_refresh_rate=0,
    deterministic=True,
)

trainer.fit(mpi_model, dataset_mpi)
trainer.test(mpi_model, dataset_mpi)

final_checkpoint = os.path.join(root_dir, naming + ".ckpt")
trainer.save_checkpoint(final_checkpoint)
# Keep a timestamped copy that a later run with the same `naming` cannot
# overwrite. strftime avoids the spaces/colons of str(datetime.now()), which
# break shell commands and are invalid on some filesystems. Copying the file
# avoids serializing the model a second time.
checkpoint_name = os.path.join(
    root_dir, naming + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".ckpt"
)
shutil.copyfile(final_checkpoint, checkpoint_name)

In [None]:
# Manually (re-)save the current state of `trainer` to the untimestamped final
# checkpoint path, overwriting the file if it already exists.
# Requires `trainer`, `root_dir` and `naming` from a previously executed cell.
checkpoint_name = os.path.join(root_dir, f"{naming}.ckpt")
trainer.save_checkpoint(filepath=checkpoint_name)

In [None]:
# Resume the MPIIGaze fine-tuning run from one of the per-epoch checkpoints,
# then write a final checkpoint plus a timestamped copy into `root_dir`.
import os
import shutil
from datetime import datetime

seed_v = 42
seed_everything(seed_v)

root_dir = "/content/drive/MyDrive/eye_w/weights"
naming = "mpi_3e4_40_lockedMain_unlockToken_newGaze_addedEncoder_mse_"

# The base model is rebuilt from its own checkpoint so it can be passed as the
# `model` argument when restoring MPI_Gaze_Track_pl below.
pre_trained_name_file = "/content/drive/MyDrive/eye_w/weights/trans_2_3e4_att_256_1learnparam_noNorm_land_alt_MdataN-Step-Checkpoint_29_60088.ckpt"
pretrained_model = Gaze_Track_pl.load_from_checkpoint(pre_trained_name_file)

# Never hardcode the Comet API key in the notebook; set COMET_API_KEY in the
# environment instead. NOTE(review): the key previously committed here must be
# rotated. Per the CometLogger docs, a missing key falls back to the comet_ml
# config and then to offline logging under `save_dir`.
comet_logger = CometLogger(
    save_dir="/content/log/",
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="gaze2",
    workspace="etzelkut",
    experiment_name=naming,
)

# Per-step checkpoint named "{prefix}_{epoch}_{global_step}.ckpt" (epoch 19, step 102163).
load_name = "/content/drive/MyDrive/eye_w/weights/mpi_3e4_40_lockedMain_unlockToken_newGaze_addedEncoder_mse_Step__19_102163.ckpt"

mpi_model = MPI_Gaze_Track_pl.load_from_checkpoint(load_name, model=pretrained_model)
# Drop the model object from the hyperparameters so only plain config values are
# passed to the dataset below (presumably it was added via the `model=` kwarg — TODO confirm).
mpi_model.hparams.pop("model", None)

dataset_mpi = Dataset_mpi_pl(mpi_model.hparams)
dataset_mpi.prepare_data()
dataset_mpi.setup()
# Training batches per epoch: used as the checkpoint interval.
steps_per_epoch = len(dataset_mpi.train_dataloader())

every_epoch = CheckpointEveryNSteps(
    save_step_frequency=steps_per_epoch,
    use_modelcheckpoint_filename=False,
    pathh=root_dir,
    prefix=naming + "Step_",
)

trainer = Trainer(
    callbacks=[every_epoch],
    logger=comet_logger,
    gpus=1,
    profiler="simple",
    check_val_every_n_epoch=1,
    max_epochs=mpi_model.hparams["training"]["epochs"],
    progress_bar_refresh_rate=0,
    deterministic=True,
    # Continue from the trainer state saved in the checkpoint rather than epoch 0.
    resume_from_checkpoint=load_name,
)

trainer.fit(mpi_model, dataset_mpi)
trainer.test(mpi_model, dataset_mpi)

final_checkpoint = os.path.join(root_dir, naming + ".ckpt")
trainer.save_checkpoint(final_checkpoint)
# Keep a timestamped copy that a later run with the same `naming` cannot
# overwrite. strftime avoids the spaces/colons of str(datetime.now()), which
# break shell commands and are invalid on some filesystems. Copying the file
# avoids serializing the model a second time.
checkpoint_name = os.path.join(
    root_dir, naming + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".ckpt"
)
shutil.copyfile(final_checkpoint, checkpoint_name)