In [None]:
import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import yaml

from data_loader import TrainDataModule, get_all_test_dataloaders
from evaluate import Evaluator
from model.model import get_model

# autoreload imported modules
%reload_ext autoreload
%matplotlib inline

In [None]:
# Experiment configuration: seed, data split directory, image size,
# batch size and number of epochs are all read from this YAML file.
config_path = './configs/ae_config.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Seed the global RNGs (via Lightning) so runs are reproducible.
pl.seed_everything(config['seed'])

In [None]:
# Build the training data module described by the config.
train_data_module = TrainDataModule(
    split_dir=config['split_dir'],
    target_size=config['target_size'],
    batch_size=config['batch_size'],
)

# Grab one training batch to sanity-check the input pipeline.
train_batch = next(iter(train_data_module.train_dataloader()))

# Basic statistics; .item() prints plain numbers rather than tensor reprs.
print(f"Batch shape: {train_batch.shape}")
print(f"Batch min: {train_batch.min().item()}")
print(f"Batch max: {train_batch.max().item()}")

# Plot up to 5 images. The batch may hold fewer than 5 samples (small
# batch_size), so clamp. squeeze=False keeps `ax` 2-D even for one column.
n_show = min(5, len(train_batch))
fig, ax = plt.subplots(1, n_show, figsize=(15, 5), squeeze=False)
for i, axis in enumerate(ax[0]):
    axis.imshow(train_batch[i].squeeze(), cmap='gray')
    axis.axis('off')
fig.suptitle('Training samples')
plt.show()

In [None]:
# Launch TensorBoard inline. The trainer below uses TensorBoardLogger(save_dir='./');
# its runs are expected under ./lightning_logs/ (the logger's default name) — confirm
# if the logger configuration changes.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Instantiate the model described by the config.
model = get_model(config)

# Log to both TensorBoard and CSV; both loggers write under the current directory.
loggers = [
    pl.loggers.TensorBoardLogger(save_dir='./'),
    pl.loggers.CSVLogger(save_dir='./'),
]
trainer = pl.Trainer(max_epochs=config['num_epochs'], logger=loggers)

In [None]:
# Train; the train/val dataloaders are supplied by the data module.
trainer.fit(model=model, datamodule=train_data_module)

In [None]:
# Visual check on validation data: inputs, reconstructions and anomaly maps.
val_batch = next(iter(train_data_module.val_dataloader()))

# fit() may leave the module in training mode; switch to eval mode so layers
# such as dropout / batch-norm behave as they should at inference time.
model.eval()
with torch.inference_mode():
    results = model.detect_anomaly(val_batch)
    reconstructions = results['reconstruction']
    anomaly_map = results['anomaly_map']

# One row per view, each with its colour map; every panel gets its row's title.
rows = (
    ('Original', val_batch, 'gray'),
    ('Reconstruction', reconstructions, 'gray'),
    ('Anomaly map', anomaly_map, 'plasma'),
)
# Clamp to the batch size so small batches don't raise IndexError;
# squeeze=False keeps `ax` 2-D even when only one column is shown.
n_show = min(5, len(val_batch))
fig, ax = plt.subplots(len(rows), n_show, figsize=(15, 7), squeeze=False)
for row, (title, images, cmap) in enumerate(rows):
    for col in range(n_show):
        ax[row][col].imshow(images[col].squeeze(), cmap=cmap)
        ax[row][col].set_title(title)
        ax[row][col].axis('off')
plt.show()

In [None]:
# Build the test dataloaders; presumably a mapping keyed by pathology name
# (it is indexed that way when plotting test examples).
test_dataloaders = get_all_test_dataloaders(
    config['split_dir'],
    config['target_size'],
    config['batch_size'],
)

In [None]:
# Show one test example per pathology. Columns are diseases; rows are the three
# tensors each test batch unpacks into (inputs, pos_labels, neg_masks).
diseases = ['absent_septum', 'edema', 'enlarged_ventricles', 'mass', 'dural']
# Use the same sample position for every disease. The old code indexed the
# sample with the disease index, which fails when a batch has <= i samples.
sample_idx = 0
fig, ax = plt.subplots(3, len(diseases), figsize=(15, 5), squeeze=False)
for col, disease in enumerate(diseases):
    inputs, pos_labels, neg_masks = next(iter(test_dataloaders[disease]))
    for row, images in enumerate((inputs, pos_labels, neg_masks)):
        ax[row][col].imshow(images[sample_idx].squeeze(), cmap='gray')
        ax[row][col].axis('off')
    ax[0][col].set_title(disease)
plt.show()

In [None]:
# Evaluate the trained model on all pathology test sets (Evaluator is imported
# with the other project modules at the top of the notebook).
# Put the model in eval mode first; this is harmless if Evaluator also does it.
model.eval()
evaluator = Evaluator(model, model.device, test_dataloaders)
metrics, fig_metrics, fig_example = evaluator.evaluate()

In [None]:
# Display the F1 figure produced by the evaluator (last expression renders it).
f1_figure = fig_metrics['F1']
f1_figure

In [None]:
# Display the example figure returned by Evaluator.evaluate().
fig_example