In [None]:
import torch
from torch import nn
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
from timm.models.vision_transformer import VisionTransformer

from models.segmentation_model import SegmentationHead, SegmentationModel

from data_modules.seismic import F3SeismicDataModule

# def save_downstream_model_weights(downstream_model, checkpoint_filename):
#     print(f"Saving downstream model weights at {checkpoint_filename}")
#     torch.save(downstream_model.state_dict(), checkpoint_filename)

# This function should load the backbone weights
def load_pretrained_backbone(pretrained_backbone_checkpoint_filename):
    """Build the ViT backbone and load its pretrained weights from disk.

    Args:
        pretrained_backbone_checkpoint_filename: Path to a checkpoint file
            whose content can be passed to ``load_state_dict`` of the
            backbone (i.e. a ``state_dict`` saved with ``torch.save``).

    Returns:
        A ``VisionTransformer`` configured with ``img_size=(255, 701)``,
        ``patch_size=16``, ``embed_dim=768``, ``depth=12``, ``num_heads=12``
        and ``num_classes=0``, with the checkpoint weights loaded.
    """
    backbone = VisionTransformer(
        img_size=(255, 701),
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=0,
    )

    # Deserialize onto the CPU so a checkpoint written from CUDA tensors can
    # still be read on a host without that device; the trainer moves the
    # model to its accelerator afterwards.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source (consider weights_only=True if the installed
    # torch version supports it).
    state_dict = torch.load(
        pretrained_backbone_checkpoint_filename, map_location="cpu"
    )
    backbone.load_state_dict(state_dict)
    return backbone

# This function must instantiate and configure the datamodule for the downstream task.
# You must not change this function (Check with the professor if you need to change it).
def build_downstream_datamodule() -> L.LightningDataModule:
    """Create the datamodule used for the downstream task.

    Returns:
        An ``F3SeismicDataModule`` rooted at ``"../../data/"`` (relative to
        the current working directory) with a batch size of 8.
    """
    return F3SeismicDataModule(root_dir="../../data/", batch_size=8)

# This function must instantiate and configure the downstream model
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure it returns a Lightning model.
def build_downstream_model(backbone) -> L.LightningModule:
    """Assemble the downstream segmentation model around ``backbone``.

    Args:
        backbone: Feature extractor handed to ``SegmentationModel``. The
            prediction head is built with ``in_channels=768``, so the
            backbone is presumably expected to emit 768-wide features
            (TODO confirm against ``SegmentationHead``).

    Returns:
        A ``SegmentationModel`` with a 6-class ``SegmentationHead`` for
        ``(255, 701)`` images, ``lr=0.001`` and ``freeze_backbone=True``.
    """
    return SegmentationModel(
        backbone=backbone,
        prediction_head=SegmentationHead(
            in_channels=768,
            num_classes=6,
            img_size=(255, 701),
        ),
        num_classes=6,
        lr=0.001,
        freeze_backbone=True,
    )

# This function must instantiate and configure the lightning trainer
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure you return a Lightning trainer.
def build_lightning_trainer(SSL_technique_prefix) -> L.Trainer:
    """Create the trainer used to fit the downstream model.

    Args:
        SSL_technique_prefix: Identifier of the self-supervised technique;
            used as the prefix of the checkpoint file name.

    Returns:
        A GPU ``L.Trainer`` capped at 1000 epochs, with logging disabled and
        two callbacks, both driven by ``val_loss``: best-model checkpointing
        and early stopping.
    """
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Halt training when the validation loss stops getting smaller: an
    # improvement must be at least 0.001 to count, and 10 checks in a row
    # without one end the run.
    stop_when_stalled = EarlyStopping(
        monitor='val_loss',
        min_delta=0.001,
        patience=10,
        verbose=True,
        mode='min',
    )

    # Keep only the single checkpoint with the lowest validation loss,
    # written to the current directory.
    keep_best = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./',
        filename=f'{SSL_technique_prefix}-downstream-model',
        save_top_k=1,
        mode='min',
    )

    trainer_options = {
        "accelerator": "gpu",
        "max_epochs": 1000,
        "logger": False,
        "callbacks": [keep_best, stop_when_stalled],
    }
    return L.Trainer(**trainer_options)

# This function must not be changed.
def main(SSL_technique_prefix):
    """Train the downstream model on top of a pretrained backbone.

    Args:
        SSL_technique_prefix: Identifier of the self-supervised technique.
            It selects the backbone checkpoint to read
            (``./<prefix>_pretrained_backbone_weights.pth``) and is passed
            on to ``build_lightning_trainer``.
    """

    # Load the pretrained backbone
    pretrained_backbone_checkpoint_filename = f"./{SSL_technique_prefix}_pretrained_backbone_weights.pth"
    backbone = load_pretrained_backbone(pretrained_backbone_checkpoint_filename)

    # Build the downstream model, the downstream datamodule, and the trainer
    downstream_model = build_downstream_model(backbone)
    downstream_datamodule = build_downstream_datamodule()
    lightning_trainer = build_lightning_trainer(SSL_technique_prefix)

    # Fit the downstream model using the train and validation dataloaders
    # taken from the downstream datamodule.
    # NOTE(review): the dataloaders are requested directly instead of handing
    # the datamodule itself to fit(); confirm F3SeismicDataModule does not
    # depend on the trainer invoking its setup hooks first.
    lightning_trainer.fit(downstream_model, downstream_datamodule.train_dataloader(), downstream_datamodule.val_dataloader())

    # Manual saving of the downstream model is currently disabled:
    # output_filename = f"./{SSL_technique_prefix}_downstream_model.pth"
    # save_downstream_model_weights(downstream_model , output_filename)
    # print(f"Pretrained weights saved at: {output_filename}")

# Script entry point: run the downstream training for the "DINOv2" prefix.
if __name__ == "__main__":
    SSL_technique_prefix = "DINOv2"
    main(SSL_technique_prefix)
