In [None]:
# GET PARAMETERS
import copy
import os
import sys

import pandas as pd
import torch

# Get Parent folder : 
current_path = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_path, '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from constants.paths import SAVE_DIRECTORY
from examples.train_and_visu_non_recurrent import get_ds,evaluate_config,analysis_on_specific_training_mode
from high_level_DL_method import load_model,load_optimizer_and_scheduler
from examples.load_best_config import load_args_of_a_specific_trial
from trainer import Trainer


def apply_transfer_learning(model,current_path,save_folder,trial_id,add_name_id,fold_name):
    # Load trained weights:
    model_param = torch.load(f"{current_path}/{SAVE_DIRECTORY}/{save_folder}/best_models/{trial_id}{add_name_id}_f{fold_name}.pkl")

    # Dupplicate Output Weights if needed: 
    output_weight = f'core_model.output.fc2.weight'
    output_bias = f'core_model.output.fc2.bias'

    size_output_init = model_param['state_dict'][output_weight].size()
    size_output_current = model.state_dict()[output_weight].size()

    if not (size_output_current == size_output_init):
        model_param['state_dict'][output_weight] = model_param['state_dict'][output_weight].repeat(size_output_current[0],1)
        model_param['state_dict'][output_bias] = model_param['state_dict'][output_bias].repeat(size_output_current[0])
    # ...

    # Tranfer learning: 
    model.load_state_dict(model_param['state_dict'], strict=True)
    return model


def fine_tune_model(model,ds,args,reduce_lr,freeze,stations_to_plot,training_mode,epochs_fine_tune):
    # Modification: 
    args.epochs = epochs_fine_tune
    # Reduce LR: 
    if reduce_lr:
        args.lr =args.lr/5

    # Freeze weights, excepted the 'output' module : 
    if freeze:
        for name, param in model.named_parameters(): #model.core_model.named_parameters():
            param.requires_grad = False 
        for name, param in model.core_model.output.named_parameters():
            param.requires_grad = True  
        # ...


    optimizer,scheduler,loss_function = load_optimizer_and_scheduler(model,args)
    trainer = Trainer(ds,model,args,optimizer,loss_function,scheduler = scheduler)
    trainer.train_and_valid(normalizer =ds.normalizer)

    analysis_on_specific_training_mode(trainer,ds,training_mode,station = stations_to_plot)
    return trainer

## Select the config of a trained model: 

In [2]:
#Trained Model with Subway-in / Subway-out 
if True:
    save_folder = 'K_fold_validation/training_with_HP_tuning/re_validation'
    add_name_id = ''
    trial_id ='subway_in_STGCN_MSELoss_2025_01_20_14_27_20569'

if False:
    save_folder = 'K_fold_validation/training_with_HP_tuning/re_validation'
    add_name_id = 'concat_early'
    trial_id ='subway_in_subway_out_STGCN_VariableSelectionNetwork_MSELoss_2025_01_20_05_38_87836' 
    #trainer2,ds2,args2 = get_trainer_and_ds_from_saved_trial(trial_id2,add_name_id2,save_folder2,modification)

#### Change Objective function and the output-dim

In [None]:
modification = {'shuffle':True,
                'loss_function_type':'quantile',
                'alpha':0.05,
                'track_pi':True,
                'type_calib':'classic',
                #'data_augmentation':False
                }
fold_name = 'complete_dataset'

#args,_ = load_configuration(trial_id1,load_config=True)
args = load_args_of_a_specific_trial(trial_id,add_name_id,save_folder,fold_name)
fold_to_evaluate=[args.K_fold-1]
ds,args,_,_,_ =  get_ds(args_init=args,modification = modification,fold_to_evaluate=fold_to_evaluate)

#### Load weights of trained model and transfer learning on the quantile regressor:
Il semblerait que les meilleurs résultats obtenus pour le moment soient avec :
- reduce_lr = True 
- freeze = False

Donc plutôt ne pas geler (freeze) les couches déjà entraînées. 

Détail des étapes pour charger un modèle, transférer des poids déjà appris (et dupliquer les poids le long des nouvelles output-dim), et fine-tuner le modèle : 
```
model = load_model(ds, args)
transfered_model = apply_transfer_learning(model,current_path,save_folder,trial_id,add_name_id,fold_name)
trainer = fine_tune_model(transfered_model,ds,args,reduce_lr,freeze,stations_to_plot,training_mode,epochs_fine_tune)
```

### Fine Tune the model :
- only on the output layers if `freeze = True`

In [4]:
epochs_fine_tune = 10

## Plot Quantificaiton of Uncertainty: 
stations_to_plot = ['CHA','PER','PAR']
training_mode = 'test'

In [None]:
freeze = False
reduce_lr = True 

model = load_model(ds, args)
transfered_model = apply_transfer_learning(model,current_path,save_folder,trial_id,add_name_id,fold_name)
trainer = fine_tune_model(transfered_model,ds,args,reduce_lr,freeze,stations_to_plot,training_mode,epochs_fine_tune)

## Comparaison des résultats avec un training à partir de 0:
#### Entraînement complet (100 epochs, plus long) :

In [None]:
modification_bis = {key:value for key,value in modification.items()}
trainer,ds,ds_no_shuffle,args = evaluate_config(args.model_name,args.dataset_names,args.dataset_for_coverage,
                                                station = stations_to_plot,
                                                modification=modification_bis,
                                                training_mode_to_visualise=[training_mode],
                                                args_init =args,
                                                fold_to_evaluate =fold_to_evaluate)

#### Autant d'epochs que pour le fine-tuning : 

In [None]:
modification_bis = {key:value for key,value in modification.items()}
modification_bis.update({'epochs':epochs_fine_tune})
trainer,ds,ds_no_shuffle,args = evaluate_config(args.model_name,args.dataset_names,args.dataset_for_coverage,
                                                station = stations_to_plot,
                                                modification=modification_bis,
                                                training_mode_to_visualise=[training_mode],
                                                args_init =args,
                                                fold_to_evaluate =fold_to_evaluate)

## Maintenant, fine-tuning sur le modèle qui utilise 'subway-out', pour voir s'il y a toujours des gains : 

In [None]:
# Init
save_folder = 'K_fold_validation/training_with_HP_tuning/re_validation'
add_name_id = 'concat_early'
trial_id ='subway_in_subway_out_STGCN_VariableSelectionNetwork_MSELoss_2025_01_20_05_38_87836' 
epochs_fine_tune = 10
stations_to_plot = ['CHA','PER','PAR']
training_mode = 'test'
freeze = False
reduce_lr = True 
# ...

modification = {'shuffle':True,
                'loss_function_type':'quantile',
                'alpha':0.05,
                'track_pi':True,
                'type_calib':'classic',
                #'data_augmentation':False
                }
fold_name = 'complete_dataset'

# Load config : 
args = load_args_of_a_specific_trial(trial_id,add_name_id,save_folder,fold_name)
fold_to_evaluate=[args.K_fold-1]
ds,args,_,_,_ =  get_ds(args_init=args,modification = modification,fold_to_evaluate=fold_to_evaluate)

# Load model and fine tune: 
model = load_model(ds, args)
transfered_model = apply_transfer_learning(model,current_path,save_folder,trial_id,add_name_id,fold_name)
trainer = fine_tune_model(transfered_model,ds,args,reduce_lr,freeze,stations_to_plot,training_mode,epochs_fine_tune)

## Check if the transfer learning worked well: 

In [None]:
stations_to_plot = ['CHA','PER','PAR']
training_mode = 'test'

plot_prediction(trainer,ds,stations_to_plot,training_mode)

# Load Trained Model without modificiation :

modification = {}
fold_name = 'complete_dataset'
#args,_ = load_configuration(trial_id1,load_config=True)
args_init = load_args_of_a_specific_trial(trial_id,add_name_id,save_folder,fold_name)
ds_init,args_init,_,_,_ =  get_ds(args_init=args_init,modification = modification,fold_to_evaluate=[args_init.K_fold-1])
model_init = load_model(ds_init, args_init)

transfered_model_init = apply_transfer_learning(model_init,current_path,save_folder,trial_id,add_name_id,fold_name)

trainer_init = get_trainer(ds_init,transfered_model_init,args_init)

plot_prediction(trainer_init,ds_init,stations_to_plot,training_mode)