In [None]:
import copy

import lightning as L
import torch
from lightning.pytorch.callbacks import EarlyStopping
from timm.models.vision_transformer import VisionTransformer
from torch import nn

from data_modules.seismic.image import TransformedImageDataset, SeismicImageDataModule
from models.dinov2 import DINOLoss, DINOv2
from transforms.dinov2_transforms import DINOTransform, DINOTransformPrime


# This function must save the weights of the pretrained model
def pretext_save_backbone_weights(pretext_model, checkpoint_filename):
    """Write the backbone's ``state_dict`` (and only the backbone) to disk.

    Args:
        pretext_model: object exposing a ``backbone`` attribute that has a
            ``state_dict()`` method (e.g. a ``torch.nn.Module``).
        checkpoint_filename: destination path handed to ``torch.save``.
    """
    announcement = f"Saving backbone pretrained weights at {checkpoint_filename}"
    print(announcement)
    backbone_weights = pretext_model.backbone.state_dict()
    torch.save(backbone_weights, checkpoint_filename)

# This function must instantiate and configure the datamodule for the pretext task
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure it returns a Lightning DataModule.
def build_pretext_datamodule() -> L.LightningDataModule:
    """Build the DataModule that feeds the DINOv2 pretext task.

    The DataModule itself is returned, not one of its dataloaders. This
    function's contract (and its return annotation) is a
    ``LightningDataModule``. Handing the module to ``Trainer.fit`` lets
    Lightning invoke the module's data hooks itself, instead of calling
    ``train_dataloader()`` eagerly here, possibly before any ``setup()``
    has run.

    Returns:
        A ``SeismicImageDataModule`` built over ``TransformedImageDataset``
        with the DINO view transforms.
    """
    return SeismicImageDataModule(
        "data",  # presumably the dataset root directory — confirm against SeismicImageDataModule
        TransformedImageDataset,
        transform=DINOTransform(),
        transform_prime=DINOTransformPrime(),
        # Same (255, 701) size as the ViT backbone's img_size in
        # build_pretext_model; presumably the produced image size — confirm.
        output_image_size=(255, 701),
    )

# This function must instantiate and configure the pretext model
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure it returns a Lightning model.
def build_pretext_model() -> L.LightningModule:
    """Build the DINOv2 pretext model with a ViT student and a separate teacher.

    Returns:
        A ``DINOv2`` Lightning module whose student is a timm
        ``VisionTransformer`` and whose teacher is an independent copy of it.
    """
    # ViT with 16x16 patches, width 768, 12 blocks and 12 heads.
    # num_classes=0 means no classification head, so the backbone emits features.
    # NOTE(review): 255 and 701 are not multiples of 16. timm's default
    # (non-padding) patch embedding yields a 15 x 43 patch grid, so the
    # trailing 15 rows / 13 columns of each image are never seen — confirm intended.
    backbone = VisionTransformer(
        img_size=(255, 701),
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_classes=0,
    )
    # The teacher must be a distinct module that merely starts from the student's
    # weights. Passing the same object for both makes them share parameters, so
    # (if DINOv2 stores it as-is) an EMA teacher update from the student is a no-op.
    teacher_backbone = copy.deepcopy(backbone)
    # The original author states DINOv2 does not need a projection head here.
    # NOTE(review): confirm DINOv2 builds whatever head it needs when given None.
    projection_head = None
    return DINOv2(
        backbone=backbone,
        projection_head=projection_head,
        teacher_backbone=teacher_backbone,
        out_dim=768,  # matches embed_dim of the backbone above
    )

# This function must instantiate and configure the lightning trainer
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure you return a Lightning trainer.
def build_lightning_trainer() -> L.Trainer:
    """Build the Lightning Trainer used for pretext training.

    Trains for at most 1000 epochs on GPU. Training stops early once the
    logged ``train_loss`` has failed to decrease by at least 0.01 for 10
    consecutive training epochs.

    Returns:
        A configured ``L.Trainer``.
    """
    early_stop_callback = EarlyStopping(
        monitor="train_loss",  # training loss; assumes the model logs a metric with exactly this name
        min_delta=0.01,  # minimum decrease that counts as an improvement
        patience=10,  # epochs without improvement before stopping
        verbose=True,
        mode="min",  # lower train_loss is better
        # The monitored value is a *training* metric, so evaluate the stopping
        # criterion at the end of every training epoch. This matches what
        # Lightning resolves by default for this Trainer, but stays correct
        # even if a validation loop with a different cadence is added later.
        check_on_train_epoch_end=True,
    )

    trainer = L.Trainer(
        max_epochs=1000,
        accelerator="gpu",  # fails fast on machines without a GPU
        callbacks=[early_stop_callback],
        log_every_n_steps=1,  # log metrics on every optimizer step
    )

    return trainer

# This function must not be changed.
def main(SSL_technique_prefix):
    # Run the full pretext pipeline: build the model, datamodule and trainer,
    # fit the model, then save only the backbone weights.
    # SSL_technique_prefix: tag used to name the output file
    # "./{SSL_technique_prefix}_pretrained_backbone_weights.pth".

    # Build the pretext model, the pretext datamodule, and the trainer
    pretext_model = build_pretext_model()
    pretext_datamodule = build_pretext_datamodule()
    lightning_trainer = build_lightning_trainer()

    # Fit the pretext model using the pretext_datamodule
    lightning_trainer.fit(pretext_model, pretext_datamodule)

    # Save the backbone weights
    output_filename = f"./{SSL_technique_prefix}_pretrained_backbone_weights.pth"
    pretext_save_backbone_weights(pretext_model, output_filename)

if __name__ == "__main__":
    # Tag identifying the SSL technique; it also names the saved weights file.
    SSL_technique_prefix: str = "DINOv2"
    main(SSL_technique_prefix=SSL_technique_prefix)
