In [1]:
import os
import pickle
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import sklearn.metrics as metrics
import torch
import torch.nn.functional as F
import torchmetrics
from pointcnn import PointCNN
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from sklearn.metrics import classification_report as clrp
from torch import nn
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from utils import PointCloudsInFiles, folder_structure

In [2]:
if __name__ == "__main__":
    # Train PointCNN, reload the best (lowest val_loss) checkpoint, evaluate it
    # on the test split, and write a classification report plus a confusion
    # matrix figure to ./output/results.

    # Set model name (prefix of every result file written below)
    MODEL_NAME = "PointCNN-8Class"  # change for current model

    # Class names in label-index order. Used by both the report and the
    # confusion matrix so the two always agree.
    CLASS_NAMES = [
        "JackPine",
        "WhiteSpruce",
        "BlackSpruce",
        "BalsamFir",
        "EasternWhiteCedar",
        "AmericanLarch",
        "PaperBirch",
        "TremblingAspen",
    ]
    # Explicit label set: keeps report/matrix well-formed even if some class
    # never appears in the test split (otherwise target_names length mismatches).
    LABELS = list(range(len(CLASS_NAMES)))

    # Create folder structure
    folder_structure()

    # Read and load datasets
    train_dataset = PointCloudsInFiles(
        "./input/train",
        "*.laz",
        "Class",
        max_points=1024,
        use_columns=["intensity"],
    )

    test_dataset = PointCloudsInFiles(
        "./input/test",
        "*.laz",
        "Class",
        max_points=1024,
        use_columns=["intensity"],
    )

    val_dataset = PointCloudsInFiles(
        "./input/val",
        "*.laz",
        "Class",
        max_points=1024,
        use_columns=["intensity"],
    )

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
    # test_loader must stay unshuffled: predictions and labels are paired by order.
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)

    # Define model checkpoint (keeps only the lowest-val_loss checkpoint)
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        dirpath="./lightning_logs/checkpoints",
    )

    # Define trainer. The checkpoint callback has to be registered in
    # `callbacks`; `enable_checkpointing` is only an on/off flag, so passing the
    # callback object there silently discarded it.
    trainer = pl.Trainer(
        gpus=1,
        enable_progress_bar=True,
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=20),
            checkpoint_callback,
        ],
        enable_checkpointing=True,
        max_epochs=100,
        log_every_n_steps=1,
    )

    # Define model
    model = PointCNN()

    # Fit model
    trainer.fit(model, train_loader, val_loader)

    # Get best model straight from the callback instead of listing a
    # hard-coded `version_0` directory (which breaks on every re-run).
    best_model = checkpoint_callback.best_model_path
    if not best_model:
        raise RuntimeError(
            "No checkpoint was saved; validation must run at least once before "
            "a best model can be selected"
        )
    print("Done Learning: " + best_model)

    model = PointCNN.load_from_checkpoint(checkpoint_path=best_model)
    trainer.test(model, dataloaders=test_loader)

    # Collect test-set outputs (one entry per batch). NOTE(review): relies on
    # PointCNN's predict_step (Lightning's default is `self(batch)`) returning
    # per-sample class scores of shape (batch, n_classes) — confirm.
    all_preds = trainer.predict(model, dataloaders=test_loader)
    # NOTE(review): assumes batch.y holds integer class indices, as the earlier
    # `tmp.y.tolist()` approach did — confirm against PointCloudsInFiles.
    all_labels = [batch.y for batch in test_loader]

    # Predicted class = argmax over the class dimension (exp() is monotonic,
    # so it would not change the argmax and is omitted).
    predicted = list(chain.from_iterable(p.argmax(dim=1).tolist() for p in all_preds))
    ground = list(chain.from_iterable(y.view(-1).tolist() for y in all_labels))
    if len(predicted) != len(ground):
        raise ValueError(
            f"Got {len(predicted)} predictions but {len(ground)} labels"
        )

    report_text = clrp(
        ground,
        predicted,
        labels=LABELS,
        target_names=CLASS_NAMES,
        digits=3,
    )
    print(report_text)
    with open(f"./output/results/{MODEL_NAME}_results.txt", "w") as file:
        file.write(report_text)
        file.write("\n" + best_model + "\n")

    # confusion_matrix returns a bare ndarray; ConfusionMatrixDisplay plots it.
    cm_display = metrics.ConfusionMatrixDisplay(
        confusion_matrix=metrics.confusion_matrix(ground, predicted, labels=LABELS),
        display_labels=CLASS_NAMES,
    )
    cm_display.plot(xticks_rotation="vertical")
    plt.savefig(
        f"./output/results/{MODEL_NAME}_results.eps",
        bbox_inches="tight",
    )
    plt.close(cm_display.figure_)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type     | Params
---------------------------------------
0 | train_acc | Accuracy | 0     
1 | val_acc   | Accuracy | 0     
2 | test_acc  | Accuracy | 0     
3 | conv1     | XConv    | 8.3 K 
4 | conv2     | XConv    | 26.9 K
5 | conv3     | XConv    | 87.3 K
6 | conv4     | XConv    | 270 K 
7 | lin1      | Linear   | 98.6 K
8 | lin2      | Linear   | 32.9 K
9 | lin3      | Linear   | 1.0 K 
---------------------------------------
525 K     Trainable params
0         Non-trainable params
525 K     Total params
2.103     Total estimated model params size (MB)


Directory model cannot be created or is already created
Directory results cannot be created or is already created


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
                not been set for this class (_ResultMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done Learning: epoch=42-step=516.ckpt


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]


KeyboardInterrupt

