In [1]:
import os
os.chdir('/src')
import mlflow
import torch
import pytorch_lightning as pl
from src.modules.downstream.classification import ClassificationModel
from src.data.downstream.datasets import DownStreamDataModule
from src.modules.utils import MLFlowLoggerCheckpointer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from src.modules.utils import parse_weights
from torch.nn import functional as F
from torchvision.models import resnet18
import boto3
from torchmetrics.functional import confusion_matrix
import seaborn as sns
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from src.data.downstream.datasets import ClassificationDataset
import matplotlib.pyplot as plt

In [2]:
classification_problem = 'multi-class'
training_scheme = 'from-scratch'
ssl_model = 'SimSiam'
strategy = 'unrestricted'
version = ''
data_dir = os.path.join('./','data','down-stream')
batch_size = 8
num_workers = 0
pin_memory = False
optimizer = 'adam'
learning_rate = 0.000001 
weight_decay = 0.01
scheduler = 'cosine'
epochs = 100
tracking_uri = 'file:///src/logs'
monitor_quantity = 'val_loss'
monitor_mode = 'min'
es_delta = 0.001
es_patience = 3
ngpus = 0
precision = 32
log_every_n = 1
ckpt_path = './epoch=61-step=17049.ckpt'

In [3]:
if classification_problem == 'binary':
    data_dir = os.path.join(data_dir,'binary')
    output_dim = 2
    print(data_dir, output_dim)

elif classification_problem =='multi-class':
    data_dir = os.path.join(data_dir,'multi-class')
    output_dim = 8
    print(data_dir,output_dim)

elif classification_problem == 'grading':
    disease = 'CNV'
    #input('please enter disease name from (CSR, MRO, GA, CNV, FMH, PMH, VMT): ')
    data_dir = os.path.join(data_dir,'grading',disease) 
    output_dim = 3
    print(data_dir,output_dim) 

./data/down-stream/multi-class 8


In [4]:
data_module = DownStreamDataModule(data_dir=data_dir,
                                   form=classification_problem,
                                   training_transforms=None,
                                   val_test_transforms=None,
                                   batch_size=batch_size,
                                   num_workers=num_workers,
                                   pin_memory=pin_memory
                                  )

[0.0, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15]


In [5]:
if training_scheme in ['linear', 'transfer-learning']:
    freeze = True
else:
    freeze = False

In [6]:
model = ClassificationModel(model=resnet18,
                            criterion=F.cross_entropy,
                            optimizer=optimizer,
                            learning_rate=learning_rate,
                            weight_decay=weight_decay,
                            scheduler=scheduler,
                            #sched_step_size=args.scheduler_step,
                            #sched_gamma=args.scheduler_gamma,
                            output_dim=output_dim,
                            freeze= freeze,
                            max_epochs=epochs
                            )

In [7]:
mlflow_logger = MLFlowLoggerCheckpointer(experiment_name=training_scheme, 
                                         tracking_uri=tracking_uri,
                                         run_name=' '.join([ssl_model, strategy, classification_problem]),
                                         tags={'training-scheme': training_scheme,
                                               'ssl-model':ssl_model,
                                               'strategy':strategy,
                                               'classification-problem': classification_problem,
                                               'Version': version,
                                              }
                                        )

In [8]:
checkpoint_callback = ModelCheckpoint(monitor=monitor_quantity, 
                                      mode= monitor_mode
                                     )

In [9]:
early_stop = EarlyStopping(monitor=monitor_quantity, 
                           min_delta=es_delta,
                           mode=monitor_mode, 
                           patience=es_patience
                          )

In [10]:
lr_logger = LearningRateMonitor(logging_interval='epoch')

In [11]:
trainer = pl.Trainer(gpus=ngpus,
                     logger=mlflow_logger, 
                     max_epochs=epochs,
                     precision=precision, 
                     log_every_n_steps=log_every_n, 
                     progress_bar_refresh_rate=1,
                     callbacks=[checkpoint_callback, lr_logger, early_stop],
                     )

  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [197]:
model = model.load_from_checkpoint(ckpt_path)

trainer.test(model, datamodule=data_module)

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.735401451587677,
 'test_f1': 0.735401451587677,
 'test_loss': 0.8223515748977661,
 'test_prec': 0.735401451587677,
 'test_rec': 0.735401451587677,
 'test_spec': 0.9621992111206055}
--------------------------------------------------------------------------------


[{'test_loss': 0.8223515748977661,
  'test_acc': 0.735401451587677,
  'test_prec': 0.735401451587677,
  'test_rec': 0.735401451587677,
  'test_spec': 0.9621992111206055,
  'test_f1': 0.735401451587677}]

In [12]:
data_module.setup(None)


  f"DataModule.{name} has already been called, so it will not be called again. "


In [13]:
def get_predictions(model, iterator):

    model.eval()

    #images = []
    labels = []
    probs = []

    with torch.no_grad():

        for image, label in tqdm(iterator):

            image = image.cpu()

            label_pred = model.model(image)

            label_prob = F.softmax(label_pred, dim = -1)

            #images.append(image.cpu())
            labels.append(label.cpu())
            probs.append(label_prob.cpu())

    #images = torch.cat(images, dim = 0)
    labels = torch.cat(labels, dim = 0)
    probs = torch.cat(probs, dim = 0)
    pred_labels = torch.argmax(probs, 1)

    return labels, pred_labels

In [15]:
y,yhat = get_predictions(model,data_module.test_dataloader())

  0%|          | 0/1096 [00:00<?, ?it/s]

In [14]:
if classification_problem == 'binary':
    labels = ['NORMAL', 'ABNORMAL']
elif classification_problem == 'multi-class':
    labels = ['NORMAL','CNV','CSR','GA','MRO','VMT','FMH','PMH']
elif classification_problem == 'grading':
    labels = ['Mild', 'MODERATE', 'SEVERE']

In [15]:
heatmap = sns.heatmap(confusion_matrix(yhat,y,output_dim),annot=True,fmt="0000.0f",
            #xticklabels=labels,
            #yticklabels=labels,
            linecolor='black',
            linewidths=0.1,
            cmap='Greens_r',
            center=0,
           )

plt.xlabel('Actual', fontsize = 15) # x-axis label with fontsize 15
plt.ylabel('Predicted', fontsize = 15) # y-axis label with fontsize 15

NameError: name 'yhat' is not defined

In [18]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, classification_report, f1_score, roc_curve

In [19]:
confusion_matrix(yhat,y,labels=list(range(8)))

array([[545,  17,   4,  30,  84,  30,   0,   1],
       [  2,   5,   0,   2,   1,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0],
       [ 41,  20,   6,   6, 255,  36,   2,   7],
       [  1,   0,   0,   0,   0,   1,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0]])

In [20]:
accuracy_score(yhat,y)*100

73.54014598540147

In [21]:
precision_score(yhat,y,average='weighted')*100

85.66209624219925

In [22]:
recall_score(yhat,y,average='macro')*100

  _warn_prf(average, modifier, msg_start, len(result))


30.627151653638908

In [23]:
f1_score(yhat,y,average='micro')*100

73.54014598540147

In [24]:
x = classification_report(yhat,y)
print(x)

              precision    recall  f1-score   support

           0       0.93      0.77      0.84       711
           1       0.12      0.50      0.19        10
           2       0.00      0.00      0.00         0
           3       0.00      0.00      0.00         0
           4       0.75      0.68      0.72       373
           5       0.01      0.50      0.03         2
           6       0.00      0.00      0.00         0
           7       0.00      0.00      0.00         0

    accuracy                           0.74      1096
   macro avg       0.23      0.31      0.22      1096
weighted avg       0.86      0.74      0.79      1096



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
from torchmetrics.functional import accuracy, precision, recall, specificity, f1

In [None]:
average='weighted'

In [None]:
accuracy(yhat,y, num_classes=output_dim,average = average)

In [None]:
precision(yhat,y, num_classes=output_dim,average = average)

In [None]:
recall(yhat,y, num_classes=output_dim,average = average)

In [None]:
specificity(yhat,y, num_classes=output_dim,average = average)

In [None]:
f1(yhat,y, num_classes=output_dim,average = average)