In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from torch import nn
import torch
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
from timm.models.vision_transformer import VisionTransformer

from models.dinov2 import DINOLoss, DINOv2
from data_modules.seismic.image import TransformedImageDataset, SeismicImageDataModule
from transforms.dinov2_transforms import DINOTransform, DINOTransformPrime


# This function must save the weights of the pretrained model
def pretext_save_backbone_weights(pretext_model, checkpoint_filename):
    """Persist only the backbone ``state_dict`` of ``pretext_model``.

    The full pretext model (teacher, heads, ...) is not saved, only
    ``pretext_model.backbone``. This lets the file be loaded later into a
    freshly built backbone with ``load_state_dict``.

    Args:
        pretext_model: trained pretext model exposing a ``backbone`` module.
        checkpoint_filename: destination path for ``torch.save``.
    """
    message = f"Saving backbone pretrained weights at {checkpoint_filename}"
    print(message)
    backbone_weights = pretext_model.backbone.state_dict()
    torch.save(backbone_weights, checkpoint_filename)

# This function must instantiate and configure the datamodule for the pretext task
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure it returns a Lightning DataModule.
def build_pretext_datamodule() -> L.LightningDataModule:
    """Build the seismic image datamodule used for DINOv2 pretext training.

    Images are read from ``data`` through ``TransformedImageDataset``. Each
    sample gets two augmentation pipelines (``DINOTransform`` and
    ``DINOTransformPrime``), and the output size is (255, 701), which matches
    the backbone's ``img_size``.
    """
    image_size = (255, 701)
    view_transform = DINOTransform()
    view_transform_prime = DINOTransformPrime()
    return SeismicImageDataModule(
        'data',
        TransformedImageDataset,
        transform=view_transform,
        transform_prime=view_transform_prime,
        output_image_size=image_size,
    )

# This function must instantiate and configure the pretext model
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure it returns a Lightning model.
def build_pretext_model() -> L.LightningModule:
    """Build the DINOv2 pretext (self-supervised) model.

    Student and teacher use the same ViT architecture, but they must be
    *separate* module instances. Previously the very same ``backbone`` object
    was passed as ``teacher_backbone``. That made the teacher share (and be
    modified together with) the student's parameters instead of holding its
    own copy. A deep copy gives the teacher independent parameters that start
    from the same initial weights as the student.

    Returns:
        The configured ``DINOv2`` Lightning module.
    """
    import copy

    embed_dim = 768
    # Student backbone. num_classes=0 builds the ViT without a classification head.
    backbone = VisionTransformer(img_size=(255, 701), patch_size=16, embed_dim=embed_dim,
                                 depth=2, num_heads=2, num_classes=0)
    # Per the original author, DINOv2 here does not need a separate projection head.
    projection_head = None
    # Independent teacher with the same initial weights. Presumably DINOv2
    # updates it from the student (e.g. via EMA); confirm in models/dinov2.py.
    teacher_backbone = copy.deepcopy(backbone)
    model = DINOv2(backbone=backbone, projection_head=projection_head,
                   teacher_backbone=teacher_backbone, out_dim=embed_dim)
    return model

# This function must instantiate and configure the lightning trainer
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure you return a Lightning trainer.
def build_lightning_trainer() -> L.Trainer:
    """Build the Lightning trainer for the pretext (self-supervised) stage.

    Training runs for at most 3 epochs on GPU. It stops earlier if the logged
    ``train_loss`` fails to decrease by at least 0.01 for one epoch.

    NOTE(review): assumes the pretext model logs a metric named
    ``train_loss``. EarlyStopping (strict by default) raises if the
    monitored metric is missing.
    """
    early_stop_callback = EarlyStopping(
        monitor='train_loss',  # Monitor the *training* loss (there is no validation metric here)
        min_delta=0.01,  # Minimum decrease that counts as an improvement
        patience=1,  # Number of epochs without improvement before stopping
        verbose=True,
        mode='min'  # Lower training loss is better
    )

    trainer = L.Trainer(
        max_epochs=3,
        accelerator='gpu',
        callbacks=[early_stop_callback],
        log_every_n_steps=1  # Log every step; pretext epochs may contain few batches
    )

    return trainer

# This function must not be changed.
def main(SSL_technique_prefix):
    """Run the pretext (self-supervised) training stage end to end.

    Builds the pretext model, datamodule and trainer, fits the model, and
    saves the backbone weights to
    ``./{SSL_technique_prefix}_pretrained_backbone_weights.pth``.

    Args:
        SSL_technique_prefix: name of the SSL technique, used to name the
            output weights file.
    """

    # Build the pretext model, the pretext datamodule, and the trainer
    pretext_model = build_pretext_model()
    pretext_datamodule = build_pretext_datamodule()
    lightning_trainer = build_lightning_trainer()

    # Fit the pretext model using the pretext_datamodule
    lightning_trainer.fit(pretext_model, pretext_datamodule)

    # Save the backbone weights (only the backbone is needed downstream)
    output_filename = f"./{SSL_technique_prefix}_pretrained_backbone_weights.pth"
    pretext_save_backbone_weights(pretext_model, output_filename)


# Run the pretext stage. The prefix names the saved backbone weights file.
SSL_technique_prefix = "DINOv2"
main(SSL_technique_prefix)

In [None]:
import torch
from torch import nn
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
from timm.models.vision_transformer import VisionTransformer

from models.segmentation_model import SegmentationHead, SegmentationModel

from data_modules.seismic import F3SeismicDataModule

# def save_downstream_model_weights(downstream_model, checkpoint_filename):
#     print(f"Saving downstream model weights at {checkpoint_filename}")
#     torch.save(downstream_model.state_dict(), checkpoint_filename)

# This function should load the backbone weights
def load_pretrained_backbone(pretrained_backbone_checkpoint_filename):
    """Rebuild the ViT backbone and load the pretext-stage weights into it.

    The architecture below must match the one used in the pretext stage.
    Otherwise ``load_state_dict`` (strict by default) raises on missing or
    unexpected keys.

    Args:
        pretrained_backbone_checkpoint_filename: path to a file holding the
            backbone ``state_dict`` saved with ``torch.save``.

    Returns:
        The ``VisionTransformer`` with the pretrained weights loaded.
    """
    backbone = VisionTransformer(img_size=(255, 701), patch_size=16, embed_dim=768, depth=2, num_heads=2, num_classes=0)
    # The weights may have been saved from GPU tensors. Mapping them to CPU
    # lets this load on machines without that device. load_state_dict then
    # copies the values into the backbone's existing parameters.
    state_dict = torch.load(pretrained_backbone_checkpoint_filename, map_location="cpu")
    backbone.load_state_dict(state_dict)
    return backbone

# This function must instantiate and configure the datamodule for the downstream task.
# You must not change this function (Check with the professor if you need to change it).
def build_downstream_datamodule() -> L.LightningDataModule:
    """Return an ``F3SeismicDataModule`` rooted at ``./data/`` with batch size 8."""
    return F3SeismicDataModule(root_dir="./data/", batch_size=8)

# This function must instantiate and configure the downstream model
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure it returns a Lightning model.
def build_downstream_model(backbone) -> L.LightningModule:
    """Attach a 6-class segmentation head to ``backbone``.

    ``in_channels=768`` must equal the backbone's embedding dimension. With
    ``freeze_backbone=True``, presumably only the head is trained; confirm in
    models/segmentation_model.py.

    Args:
        backbone: pretrained feature extractor.

    Returns:
        A ``SegmentationModel`` with learning rate 0.001.
    """
    n_classes = 6
    head = SegmentationHead(in_channels=768, num_classes=n_classes, img_size=(255, 701))
    return SegmentationModel(
        backbone=backbone,
        prediction_head=head,
        num_classes=n_classes,
        lr=0.001,
        freeze_backbone=True,
    )

# This function must instantiate and configure the lightning trainer
# with the best parameters found for the seismic/HAR task.
# You might change this code, but must ensure you return a Lightning trainer.
def build_lightning_trainer(SSL_technique_prefix) -> L.Trainer:
    """Build the Lightning trainer for the downstream segmentation task.

    The trainer keeps only the best checkpoint by validation loss, saved as
    ``./{SSL_technique_prefix}-downstream-model.ckpt`` (the file the
    evaluation step loads). Training stops once ``val_loss`` has not improved
    by at least 0.001 for 3 validation checks, with a cap of 1000 epochs.

    Args:
        SSL_technique_prefix: name of the SSL technique, used in the
            checkpoint filename.
    """
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Save the best model according to validation loss.
    # NOTE(review): if a checkpoint with the same name already exists from an
    # earlier run, ModelCheckpoint by default writes a versioned file instead
    # (e.g. "-v1"). The evaluation step would then load the stale checkpoint.
    # Delete old checkpoints before retraining, or pass
    # enable_version_counter=False on Lightning versions that provide it.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./',
        filename=f'{SSL_technique_prefix}-downstream-model',
        save_top_k=1,
        mode='min',
    )

    early_stop_callback = EarlyStopping(
        monitor='val_loss',  # Monitor the validation loss
        min_delta=0.001,  # Minimum decrease that counts as an improvement
        patience=3,  # Number of checks without improvement before stopping
        verbose=True,
        mode='min'  # Lower validation loss is better
    )

    return L.Trainer(
        accelerator="gpu",
        max_epochs=1000,  # Effectively unbounded; early stopping ends training
        logger=False,
        callbacks=[checkpoint_callback, early_stop_callback])

# This function must not be changed.
def main(SSL_technique_prefix):
    """Train the downstream segmentation model on top of the pretrained backbone.

    Loads ``./{SSL_technique_prefix}_pretrained_backbone_weights.pth`` and
    builds the downstream model, datamodule and trainer. It then fits on the
    train/validation dataloaders. The best checkpoint is written by the
    trainer's ModelCheckpoint callback, not by this function.

    Args:
        SSL_technique_prefix: name of the SSL technique, used to locate the
            pretrained weights and name the checkpoint.
    """

    # Load the pretrained backbone
    pretrained_backbone_checkpoint_filename = f"./{SSL_technique_prefix}_pretrained_backbone_weights.pth"
    backbone = load_pretrained_backbone(pretrained_backbone_checkpoint_filename)

    # Build the downstream model, the downstream datamodule, and the trainer
    downstream_model = build_downstream_model(backbone)
    downstream_datamodule = build_downstream_datamodule()
    lightning_trainer = build_lightning_trainer(SSL_technique_prefix)

    # Fit the downstream model on the datamodule's train/validation dataloaders
    lightning_trainer.fit(downstream_model, downstream_datamodule.train_dataloader(), downstream_datamodule.val_dataloader())

    # Save the downstream model (disabled: the ModelCheckpoint callback saves the best model)
    # output_filename = f"./{SSL_technique_prefix}_downstream_model.pth"
    # save_downstream_model_weights(downstream_model , output_filename)
    # print(f"Pretrained weights saved at: {output_filename}")

# Run downstream training. The prefix selects the pretrained weights and names the checkpoint.
SSL_technique_prefix = "DINOv2"
main(SSL_technique_prefix)


In [None]:
import sys
sys.path.append('../../')

import torch
import lightning as L

from models.segmentation_model import SegmentationModel, SegmentationHead
from data_modules.seismic import F3SeismicDataModule

from timm.models.vision_transformer import VisionTransformer

from torchmetrics import JaccardIndex

def evaluate_model(model, dataset_dl, num_classes=6):
    """Compute the multiclass Jaccard index (IoU) of ``model`` over a dataloader.

    The model is run in eval mode without gradient tracking. Its original
    train/eval mode is restored afterwards.

    Args:
        model: segmentation model returning class logits along dim 1.
        dataset_dl: iterable of ``(X, y)`` batches.
        num_classes: number of segmentation classes (default 6).

    Returns:
        The IoU accumulated over all batches, as a Python float.
    """
    # Initialize the JaccardIndex metric
    jaccard = JaccardIndex(task="multiclass", num_classes=num_classes)

    # Use the GPU if available. Move the model and metric once, not per batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    jaccard.to(device)

    # Evaluation must not use training-mode dropout/batch-norm or build
    # autograd graphs. Remember the caller's mode so it can be restored.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # For each batch, compute the predictions and compare with the labels.
            for X, y in dataset_dl:
                X = X.to(device)
                y = y.to(device)

                logits = model(X.float())
                # keepdim=True keeps a singleton class axis. Presumably this
                # matches the label shape; TODO confirm the shape of y.
                predictions = torch.argmax(logits, dim=1, keepdim=True)
                jaccard(predictions, y)
    finally:
        model.train(was_training)

    # IoU accumulated over every batch seen
    return float(jaccard.compute().to("cpu"))

def report_IoU(model, dataset_dl, prefix=""):
    """Evaluate ``model`` on ``dataset_dl`` and print its IoU to 4 decimals after ``prefix``."""
    score = evaluate_model(model, dataset_dl)
    line = prefix + f" IoU = {score:0.4f}"
    print(line)

### -------------------------------------------------------------------------------

# This function must instantiate and configure the datamodule for the downstream task.
# You must not change this function (Check with the professor if you need to change it).
def build_downstream_datamodule() -> L.LightningDataModule:
    """Return an ``F3SeismicDataModule`` rooted at ``./data/`` with batch size 8."""
    return F3SeismicDataModule(root_dir="./data/", batch_size=8)

# This function must instantiate the downstream model and load its weights
# from checkpoint_filename.
# You might change this code, but must ensure it returns a Lightning model initialized with
# Weights saved by the *_train.py script.
def load_downstream_model(checkpoint_filename) -> L.LightningModule:
    """Restore the trained downstream ``SegmentationModel`` from a Lightning checkpoint.

    A fresh backbone and head with the training-time architecture are passed
    as constructor arguments to ``load_from_checkpoint``. The checkpoint's
    state dict then supplies the trained weights.

    Args:
        checkpoint_filename: path to the ``.ckpt`` file written during training.

    Returns:
        The restored ``SegmentationModel``.
    """
    image_size = (255, 701)
    vit = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=768,
                            depth=2, num_heads=2, num_classes=0)
    head = SegmentationHead(in_channels=768, num_classes=6, img_size=image_size)
    return SegmentationModel.load_from_checkpoint(
        checkpoint_filename,
        backbone=vit,
        prediction_head=head,
    )

# This function must not be changed.
def main(SSL_technique_prefix):
    """Load the trained downstream model and report IoU on the train, validation and test sets.

    Args:
        SSL_technique_prefix: name of the SSL technique. The checkpoint
            ``{SSL_technique_prefix}-downstream-model.ckpt`` is loaded.
    """

    # Load the trained downstream model from its checkpoint
    downstream_model = load_downstream_model(f'{SSL_technique_prefix}-downstream-model.ckpt')

    # Retrieve the train, validation and test sets.
    downstream_datamodule = build_downstream_datamodule()
    train_dl = downstream_datamodule.train_dataloader()
    val_dl   = downstream_datamodule.val_dataloader()
    test_dl  = downstream_datamodule.test_dataloader()

    # Compute and report the mIoU metric for each subset
    report_IoU(downstream_model, train_dl, prefix="   Training dataset")
    report_IoU(downstream_model, val_dl,   prefix=" Validation dataset")
    report_IoU(downstream_model, test_dl,  prefix="       Test dataset")


# Evaluate the downstream model trained for this SSL technique.
SSL_technique_prefix = "DINOv2"
main(SSL_technique_prefix)
