In [None]:
%matplotlib inline
%load_ext tensorboard

import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger


In [None]:
from lightning.pytorch import seed_everything

# Seed the Python, NumPy and torch RNGs (and dataloader workers) before the model and
# data module are constructed in the next cell, so weight init and sampling are
# reproducible across Restart & Run All.
SEED = 42
seed_everything(SEED, workers=True)

# Shared output directory for checkpoints and TensorBoard event files.
LOG_DIR = "logs/RoofSense"

# Keep the single best checkpoint by validation loss, plus the most recent one.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=LOG_DIR, save_top_k=1, save_last=True)
# Stop training once val_loss has not improved for `patience` consecutive validation runs.
# NOTE(review): the Trainer cell uses max_epochs=10; with one validation run per epoch
# (the default) a patience of 10 can never be reached, so this callback is inert.
# Lower `patience` or raise `max_epochs` if early stopping is actually wanted.
early_stopping_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10)
logger = TensorBoardLogger(save_dir=LOG_DIR)

In [None]:
from lightning import Trainer
from torchgeo.trainers import SemanticSegmentationTask

from classification.datamodules import TrainingDataModule

# Data module over the samples stored under ../pretraining.
datamodule = TrainingDataModule(
    root="../pretraining",
    batch_size=3,
    patch_size=512,
    num_workers=10,
)

# U-Net segmentation model with a ResNet-18 backbone, randomly initialised
# (weights=False), taking 6 input channels and predicting 9 classes.
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights=False,
    in_channels=6,
    num_classes=9,
    loss="focal",
    ignore_index=0,  # class index 0 is ignored by torchgeo's loss and metrics
    lr=1e-5,
    patience=6,  # presumably the LR-scheduler patience in torchgeo; confirm for the installed version
)

In [None]:
# Open an inline TensorBoard viewer (localhost:6001) on the directory the
# TensorBoardLogger writes to.
%tensorboard --logdir "logs/RoofSense" --host localhost --port 6001

In [None]:
# Prepare the "fit" stage so the training batch sampler exists, then display how
# many batches it yields.
datamodule.setup("fit")
num_train_batches = len(datamodule.train_batch_sampler)
num_train_batches

In [None]:
# Train for up to 10 epochs, logging every step to TensorBoard and running the
# checkpoint and early-stopping callbacks configured earlier.
trainer = Trainer(
    max_epochs=10,
    logger=logger,
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, early_stopping_callback],
)
trainer.fit(model=task, datamodule=datamodule)

In [None]:
# Evaluate on the test split with the best checkpoint (lowest val_loss) kept by the
# ModelCheckpoint callback. Without ckpt_path, passing `task` tests the weights of
# the final training epoch instead. Lightning loads the checkpoint into `task`, so
# the later cells that use task.model also work with the best weights.
trainer.test(model=task, datamodule=datamodule, ckpt_path="best")

In [None]:
# Pull the network out of the Lightning task for manual inference. Module.cpu() and
# Module.eval() act in place and return the module, so task.model itself is moved to
# the CPU and switched to eval mode (dropout off, batch-norm uses running statistics).
model = task.model.cpu().eval()
model

In [None]:
from tqdm.auto import tqdm
import numpy as np

# Run the CPU model over the whole test split. Collect predicted labels (argmax over
# the class axis) and ground-truth masks, stacked along axis 0.
# NOTE(review): batches read directly from test_dataloader() do not go through the
# datamodule's on_after_batch_transfer hook, which Lightning only calls inside
# fit/test/predict. Any normalisation/augmentation applied there is skipped here;
# confirm the model inputs match what it saw during training.
pred_chunks = []
true_chunks = []
for batch in tqdm(datamodule.test_dataloader(), desc="predicting"):
    images = batch["image"].to("cpu")
    # Masks may carry a singleton channel axis (B, 1, H, W). Drop it so labels line up
    # with the argmax predictions; squeeze(1) is a no-op when that axis is not size 1.
    masks = batch["mask"].squeeze(1)
    with torch.inference_mode():
        preds = model(images).argmax(dim=1)
    true_chunks.append(masks.numpy())
    pred_chunks.append(preds.numpy())

y_preds = np.concatenate(pred_chunks)
y_trues = np.concatenate(true_chunks)
assert y_preds.shape == y_trues.shape, f"prediction/label shape mismatch: {y_preds.shape} vs {y_trues.shape}"

In [None]:
# Print the bounding box of every validation batch. A dedicated loop variable keeps
# this cell from overwriting the `batch` left over from the test-prediction loop with
# a validation batch.
for val_batch in datamodule.val_dataloader():
    print(val_batch["bbox"])

In [None]:


# Pick one sample from the last test batch: its input image, ground-truth mask and
# predicted labels. y_trues / y_preds were filled batch by batch in order, so their
# last len(images) rows belong to that batch. Reading mask and prediction from them
# keeps all three aligned to the same sample, whatever `batch` currently holds, and
# avoids re-running the model outside inference mode.
minibatch_id = 0  # must be smaller than the size of the last test batch (batch_size is 3)
batch_len = images.shape[0]
assert 0 <= minibatch_id < batch_len, f"minibatch_id={minibatch_id} is out of range for a batch of {batch_len}"
img = images[minibatch_id]
msk = torch.from_numpy(y_trues[minibatch_id - batch_len]).squeeze()
prd = torch.from_numpy(y_preds[minibatch_id - batch_len])



In [None]:
# Summarise the stacked predictions (shape, dtype and how many pixels get each class
# label) instead of dumping every pixel label into the notebook output.
pred_labels, pred_counts = np.unique(y_preds, return_counts=True)
y_preds.shape, y_preds.dtype, dict(zip(pred_labels.tolist(), pred_counts.tolist()))

In [None]:
import matplotlib.pyplot as plt

# Let the datamodule render the item at index 0 of its validation dataset.
first_val_sample = datamodule.val_dataset[0]
datamodule.plot(first_val_sample)

In [None]:
# Show the predicted label map of a single test sample. y_preds stacks all test
# samples along axis 0, so squeezing alone only gives a 2-D image when the test split
# holds exactly one sample. Index a sample explicitly instead.
sample_idx = 0
fig, ax = plt.subplots()
# Fixed colour range so a given class always gets the same colour across samples.
# tab10 provides 10 colours, enough for the 9 classes configured on the task.
label_img = ax.imshow(y_preds[sample_idx], cmap="tab10", vmin=0, vmax=9, interpolation="nearest")
fig.colorbar(label_img, ax=ax, label="predicted class")
ax.set_title(f"Predicted labels, test sample {sample_idx}")
ax.set_axis_off()
plt.show()

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score

IGNORE_INDEX = 0  # same value as `ignore_index` passed to SemanticSegmentationTask

# sklearn classification metrics only accept 1-D/2-D label arrays. Flatten the
# per-pixel labels; this also absorbs any singleton channel axis left on the masks.
y_true_flat = y_trues.ravel()
y_pred_flat = y_preds.ravel()
assert y_true_flat.shape == y_pred_flat.shape, f"{y_true_flat.shape} vs {y_pred_flat.shape}"

# Score only pixels whose ground truth is not the ignored class, as in training.
valid = y_true_flat != IGNORE_INDEX
y_true_flat = y_true_flat[valid]
y_pred_flat = y_pred_flat[valid]

# Each pixel has exactly one true and one predicted label, and no `labels` subset is
# given. So micro-averaged precision, recall and F1 all equal overall pixel accuracy;
# use average="macro" for a class-balanced view. f1_score avoids a division by zero
# when precision + recall == 0.
precision = precision_score(y_true_flat, y_pred_flat, average="micro", zero_division=0)
recall = recall_score(y_true_flat, y_pred_flat, average="micro", zero_division=0)
f1 = f1_score(y_true_flat, y_pred_flat, average="micro", zero_division=0)
print(precision, recall, f1)