In [6]:
import os
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
import argparse
import warnings
from datasets.data_interface import DataInterface
from dsmil import DSMIL
from train_dsmil import get_args
from easydict import EasyDict

args = EasyDict({
    "comment": "",
    "seed": 2021,
    "fp16": True,
    "amp_level": "O2",
    "precision": 16,
    "multi_gpu_mode": "dp",
    "gpus": [0],
    "epochs": 200,
    "grad_acc": 2,
    "frozen_bn": False,
    "patience": 10,
    "server": "test",
    "log_path": "logs_v2/",
    "dataset_name": "camel_data",
    "data_shuffle": False,
    "data_dir": "/mnt/nvme0n1/ICCV/ds-mil/datasets/cam-16/",
    "label_dir": "/mnt/nvme0n1/ICCV/camelyon_5_fold/",
    "fold": 0,
    "nfold": 4,
    "train_batch_size": 1,
    "train_num_workers": 8,
    "test_batch_size": 1,
    "test_num_workers": 8,
    "model_name": "TransMIL",
    "n_classes": 2,
    "opt": "lookahead_radam",
    "lr": 0.0002,
    "opt_eps": None,
    "opt_betas": None,
    "momentum": None,
    "weight_decay": 0.00001,
    "base_loss": "CrossEntropyLoss",
    "save_dir": "logs",
    "num_gpus": 1,
    "log_dir": "./logs/",
    "pretrained_model_path": "logs/DMSO/epoch=0-step=0.ckpt",
    "project_name": "DSMIL_camel",
    "logger": "wandb",
    "max_epochs": 200
})



In [7]:
DataInterface_dict = {
        "train_batch_size": 1,
        "train_num_workers": 8,
        "test_batch_size": 1,
        "test_num_workers": 8,
        "dataset_name": 'camel_data',
        "dataset_cfg": args,
    }
dm = DataInterface(**DataInterface_dict)
dm.setup()

In [11]:
model = DSMIL(
        num_classes=2,
        criterion=nn.BCEWithLogitsLoss(),
        i_class="i_class",
        log_dir=args.log_dir,
    )

In [12]:
model

DSMIL(
  (criterion): BCEWithLogitsLoss()
  (model): MILNet(
    (i_classifier): IClassifier(
      (fc): Linear(in_features=512, out_features=1, bias=True)
    )
    (b_classifier): BClassifier(
      (q): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): Tanh()
      )
      (v): Identity()
      (fcc): Conv1d(1, 1, kernel_size=(512,), stride=(1,))
    )
  )
  (acc): BinaryAccuracy()
  (auc): BinaryAUROC()
  (F1): BinaryF1Score()
  (precision_metric): BinaryPrecision()
  (recall): BinaryRecall()
  (train_metrics): MetricCollection(
    (BinaryAccuracy): BinaryAccuracy()
    (BinaryF1Score): BinaryF1Score()
    (BinaryPrecision): BinaryPrecision()
    (BinaryRecall): BinaryRecall(),
    prefix=train_
  )
  (valid_metrics): MetricCollection(
    (BinaryAccuracy): BinaryAccuracy()
    (BinaryF1Score): BinaryF1Score()
    (BinaryPrecision): BinaryPrecision()
    

In [15]:
model = model.load_from_checkpoint("/home/mvries/Documents/GitHub/dsmil-devries/logs/DSMIL_camel/fold_0/DSMIL_camel/3vrrselm/checkpoints/epoch=114-step=24840.ckpt").cuda()

In [16]:
model.eval()

DSMIL(
  (criterion): CrossEntropyLoss()
  (model): MILNet(
    (i_classifier): IClassifier(
      (fc): Linear(in_features=512, out_features=1, bias=True)
    )
    (b_classifier): BClassifier(
      (q): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): Tanh()
      )
      (v): Identity()
      (fcc): Conv1d(1, 1, kernel_size=(512,), stride=(1,))
    )
  )
  (acc): BinaryAccuracy()
  (auc): BinaryAUROC()
  (F1): BinaryF1Score()
  (precision_metric): BinaryPrecision()
  (recall): BinaryRecall()
  (train_metrics): MetricCollection(
    (BinaryAccuracy): BinaryAccuracy()
    (BinaryF1Score): BinaryF1Score()
    (BinaryPrecision): BinaryPrecision()
    (BinaryRecall): BinaryRecall(),
    prefix=train_
  )
  (valid_metrics): MetricCollection(
    (BinaryAccuracy): BinaryAccuracy()
    (BinaryF1Score): BinaryF1Score()
    (BinaryPrecision): BinaryPrecision()
    (

In [38]:
import torch
i = 0
count_true = 0
for data in dm.val_dataloader():
    i += 1
    output = model(torch.squeeze(data[0]).float().cuda())
    if (torch.sigmoid(output[1]) > 0.5) == data[1].cuda():
        count_true+=1

In [36]:
count_true/i

0.9027777777777778

In [37]:
(torch.sigmoid(output[1]) > 0.5) == data[1].cuda()

tensor([[True]], device='cuda:0')

In [10]:
for data in dm.train_dataloader():
    print(data)

[tensor([[[1.2608, 0.5533, 0.5364,  ..., 0.5037, 0.4116, 0.9201],
         [1.2487, 0.5578, 0.5369,  ..., 0.4875, 0.4907, 0.9150],
         [1.0561, 0.7574, 0.4928,  ..., 0.6080, 0.6566, 0.7421],
         ...,
         [1.2819, 0.5539, 0.5520,  ..., 0.4908, 0.4418, 0.9748],
         [1.2670, 0.5139, 0.5128,  ..., 0.4098, 0.4546, 1.0091],
         [1.1942, 0.6413, 0.5155,  ..., 0.5930, 0.5632, 0.9028]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2760, 0.5455, 0.5218,  ..., 0.4877, 0.4571, 0.9107],
         [1.2244, 0.6283, 0.5444,  ..., 0.6075, 0.4660, 0.7879],
         [1.2691, 0.4996, 0.5831,  ..., 0.4467, 0.4329, 1.0279],
         ...,
         [1.2287, 0.5358, 0.5477,  ..., 0.4940, 0.5026, 0.9783],
         [1.1484, 0.6883, 0.5325,  ..., 0.7620, 0.5725, 0.7528],
         [1.2581, 0.5309, 0.6097,  ..., 0.4774, 0.4451, 0.9975]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2689, 0.4892, 0.5486,  ..., 0.4072, 0.4136, 1.0272],
         [1.2832, 0.4547, 0.5630,  .

[tensor([[[1.2792, 0.5046, 0.5523,  ..., 0.4363, 0.4358, 1.0525],
         [1.2597, 0.5396, 0.5618,  ..., 0.5707, 0.4498, 0.9319],
         [0.9732, 0.9971, 0.5797,  ..., 0.9164, 0.8920, 0.6026],
         ...,
         [1.2060, 0.6178, 0.5221,  ..., 0.6533, 0.5442, 0.8331],
         [1.0819, 0.7465, 0.5305,  ..., 0.8526, 0.7215, 0.6567],
         [0.4301, 1.2841, 0.5817,  ..., 1.1744, 1.0618, 0.3762]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2460, 0.6010, 0.5098,  ..., 0.4947, 0.4672, 0.9676],
         [1.2863, 0.4604, 0.5206,  ..., 0.4063, 0.4214, 1.0977],
         [1.2893, 0.5180, 0.5273,  ..., 0.4416, 0.4605, 1.0402],
         ...,
         [1.0343, 1.0309, 0.6680,  ..., 0.9484, 0.9066, 0.6083],
         [1.1810, 0.6826, 0.5553,  ..., 0.6787, 0.5875, 0.7789],
         [1.2360, 0.5555, 0.5344,  ..., 0.4955, 0.4341, 0.9581]]],
       dtype=torch.float64), tensor([0])]
[tensor([[[1.2182, 0.5660, 0.5436,  ..., 0.5402, 0.5111, 0.8498],
         [1.2488, 0.5780, 0.5230,  .

[tensor([[[1.2241, 0.5527, 0.5189,  ..., 0.4685, 0.5671, 0.9778],
         [1.2752, 0.5172, 0.5039,  ..., 0.4554, 0.4220, 0.9536],
         [1.2593, 0.5376, 0.4923,  ..., 0.4530, 0.4441, 1.0178],
         ...,
         [1.2701, 0.5613, 0.5680,  ..., 0.5151, 0.4279, 0.9281],
         [1.2460, 0.5322, 0.5248,  ..., 0.4810, 0.4944, 1.0378],
         [1.2877, 0.5701, 0.4934,  ..., 0.4720, 0.4262, 0.9730]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2003, 0.6775, 0.5375,  ..., 0.7249, 0.6654, 0.8334],
         [1.2755, 0.5317, 0.5404,  ..., 0.4510, 0.4231, 0.9228],
         [1.2254, 0.6126, 0.5260,  ..., 0.5518, 0.5089, 0.9465],
         ...,
         [1.2943, 0.4508, 0.5229,  ..., 0.3869, 0.4181, 1.0490],
         [1.2684, 0.5517, 0.5072,  ..., 0.4798, 0.4295, 0.9217],
         [1.2653, 0.5448, 0.5603,  ..., 0.4971, 0.4787, 0.9515]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2299, 0.5120, 0.4811,  ..., 0.4319, 0.4966, 1.0387],
         [1.1937, 0.6441, 0.5691,  .

[tensor([[[1.1084, 1.0934, 0.7239,  ..., 1.0043, 0.9127, 0.6638],
         [1.1334, 1.0354, 0.7784,  ..., 0.9951, 0.8987, 0.6855],
         [1.1136, 1.0731, 0.7177,  ..., 0.9665, 0.9508, 0.6031],
         ...,
         [1.1901, 0.9838, 0.7269,  ..., 0.9985, 0.8442, 0.7114],
         [1.1641, 1.0333, 0.7623,  ..., 0.9845, 0.9507, 0.6695],
         [1.0834, 0.9092, 0.6218,  ..., 0.9456, 0.7904, 0.6752]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2696, 0.5526, 0.5673,  ..., 0.4823, 0.4685, 1.0265],
         [1.1798, 0.6649, 0.5355,  ..., 0.8640, 0.6506, 0.7505],
         [1.2655, 0.5194, 0.5639,  ..., 0.4679, 0.4450, 0.9775],
         ...,
         [1.2245, 0.5673, 0.5782,  ..., 0.5763, 0.5291, 0.9090],
         [1.2709, 0.5414, 0.5297,  ..., 0.4596, 0.4396, 0.9516],
         [1.2637, 0.5325, 0.5261,  ..., 0.4678, 0.4441, 0.9189]]],
       dtype=torch.float64), tensor([0])]
[tensor([[[1.2347, 0.5890, 0.5304,  ..., 0.5006, 0.4688, 0.9678],
         [1.2924, 0.4867, 0.4685,  .

[tensor([[[1.2513, 0.5512, 0.5353,  ..., 0.4878, 0.4741, 0.9358],
         [1.1729, 0.6073, 0.5164,  ..., 0.7593, 0.5939, 0.7662],
         [1.2646, 0.5448, 0.5465,  ..., 0.4500, 0.4351, 1.0618],
         ...,
         [1.1905, 0.6385, 0.5110,  ..., 0.5421, 0.5321, 0.9369],
         [1.2826, 0.5338, 0.5541,  ..., 0.4991, 0.4567, 0.9410],
         [1.2193, 0.6819, 0.5487,  ..., 0.6966, 0.5157, 0.8856]]],
       dtype=torch.float64), tensor([0])]
[tensor([[[1.1705, 0.6140, 0.5135,  ..., 0.6598, 0.7133, 0.7668],
         [1.2728, 0.5556, 0.4945,  ..., 0.5302, 0.4754, 0.8927],
         [1.2822, 0.5449, 0.5722,  ..., 0.5421, 0.4477, 0.8832],
         ...,
         [1.1838, 0.6231, 0.4606,  ..., 0.6347, 0.5962, 0.8419],
         [1.2354, 0.6172, 0.5479,  ..., 0.6410, 0.5019, 0.9471],
         [0.5016, 1.2447, 0.5624,  ..., 1.1154, 1.0311, 0.4031]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2216, 0.6372, 0.5740,  ..., 0.5402, 0.5224, 0.9343],
         [1.2767, 0.5576, 0.5818,  .

[tensor([[[1.1084, 1.0934, 0.7239,  ..., 1.0043, 0.9127, 0.6638],
         [1.1334, 1.0354, 0.7784,  ..., 0.9951, 0.8987, 0.6855],
         [1.1136, 1.0731, 0.7177,  ..., 0.9665, 0.9508, 0.6031],
         ...,
         [1.1901, 0.9838, 0.7269,  ..., 0.9985, 0.8442, 0.7114],
         [1.1641, 1.0333, 0.7623,  ..., 0.9845, 0.9507, 0.6695],
         [1.0834, 0.9092, 0.6218,  ..., 0.9456, 0.7904, 0.6752]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2643, 0.5507, 0.5616,  ..., 0.4680, 0.4369, 0.9972],
         [1.2777, 0.5575, 0.5385,  ..., 0.4943, 0.4201, 1.0343],
         [1.2825, 0.4782, 0.5374,  ..., 0.3884, 0.4217, 1.0829],
         ...,
         [1.2221, 0.5442, 0.5103,  ..., 0.4806, 0.5135, 0.9348],
         [1.2616, 0.4802, 0.5245,  ..., 0.3917, 0.4094, 1.0595],
         [1.2732, 0.4741, 0.5372,  ..., 0.3884, 0.4118, 1.0204]]],
       dtype=torch.float64), tensor([0])]
[tensor([[[1.2608, 0.5204, 0.5443,  ..., 0.5011, 0.4753, 0.9317],
         [1.1193, 0.5150, 0.4968,  .

[tensor([[[1.1084, 1.0934, 0.7239,  ..., 1.0043, 0.9127, 0.6638],
         [1.1334, 1.0354, 0.7784,  ..., 0.9951, 0.8987, 0.6855],
         [1.1136, 1.0731, 0.7177,  ..., 0.9665, 0.9508, 0.6031],
         ...,
         [1.1901, 0.9838, 0.7269,  ..., 0.9985, 0.8442, 0.7114],
         [1.1641, 1.0333, 0.7623,  ..., 0.9845, 0.9507, 0.6695],
         [1.0834, 0.9092, 0.6218,  ..., 0.9456, 0.7904, 0.6752]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2682, 0.5357, 0.5882,  ..., 0.4888, 0.4358, 1.0229],
         [1.1944, 0.6757, 0.5285,  ..., 0.6856, 0.6126, 0.7760],
         [1.2651, 0.5722, 0.5028,  ..., 0.5399, 0.4823, 0.8689],
         ...,
         [1.2480, 0.5679, 0.5523,  ..., 0.5718, 0.4779, 0.8055],
         [1.2187, 0.6065, 0.5011,  ..., 0.5957, 0.5565, 0.7939],
         [1.2650, 0.5266, 0.5671,  ..., 0.4968, 0.4616, 0.9929]]],
       dtype=torch.float64), tensor([0])]
[tensor([[[1.1998, 0.5711, 0.4886,  ..., 0.5425, 0.5257, 0.8731],
         [1.3005, 0.4831, 0.5111,  .

[tensor([[[1.2500, 0.4682, 0.5103,  ..., 0.3971, 0.4139, 1.0053],
         [1.2748, 0.4842, 0.5303,  ..., 0.4689, 0.4803, 0.9248],
         [1.2769, 0.5103, 0.5360,  ..., 0.4151, 0.4067, 1.0736],
         ...,
         [1.2752, 0.5580, 0.5279,  ..., 0.4478, 0.4122, 0.9267],
         [1.1965, 0.5910, 0.4911,  ..., 0.5075, 0.5145, 0.9712],
         [1.2752, 0.4647, 0.5223,  ..., 0.4010, 0.4047, 1.0458]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2517, 0.5348, 0.4937,  ..., 0.5263, 0.4835, 0.8951],
         [1.2380, 0.5695, 0.5055,  ..., 0.5614, 0.5283, 0.9409],
         [0.8762, 1.1566, 0.6006,  ..., 0.9832, 0.9379, 0.5298],
         ...,
         [1.2472, 0.5182, 0.5716,  ..., 0.4349, 0.4350, 1.0772],
         [1.2601, 0.5042, 0.5068,  ..., 0.4613, 0.4447, 0.9358],
         [1.2328, 0.5059, 0.5256,  ..., 0.4456, 0.4282, 1.0588]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.1084, 1.0934, 0.7239,  ..., 1.0043, 0.9127, 0.6638],
         [1.1334, 1.0354, 0.7784,  .

[tensor([[[1.2896, 0.4643, 0.5655,  ..., 0.3854, 0.4088, 1.1028],
         [1.2820, 0.5290, 0.5351,  ..., 0.4260, 0.4264, 1.0865],
         [1.2486, 0.5530, 0.5184,  ..., 0.5153, 0.4762, 1.0012],
         ...,
         [1.2898, 0.5231, 0.5499,  ..., 0.4681, 0.4047, 0.9898],
         [1.2853, 0.5076, 0.5821,  ..., 0.4401, 0.4080, 1.1033],
         [1.2508, 0.5214, 0.5411,  ..., 0.5105, 0.4328, 0.9365]]],
       dtype=torch.float64), tensor([0])]
[tensor([[[1.2784, 0.4822, 0.5542,  ..., 0.4089, 0.4279, 1.0672],
         [1.2802, 0.5689, 0.5765,  ..., 0.5666, 0.4723, 1.0145],
         [1.2304, 0.5834, 0.4900,  ..., 0.5159, 0.5094, 0.8637],
         ...,
         [1.2662, 0.5441, 0.4926,  ..., 0.4894, 0.4503, 0.9305],
         [1.2738, 0.5725, 0.5108,  ..., 0.4693, 0.4203, 0.9165],
         [1.2657, 0.5492, 0.5102,  ..., 0.4833, 0.4314, 0.9853]]],
       dtype=torch.float64), tensor([1])]
[tensor([[[1.2686, 0.5702, 0.4819,  ..., 0.4752, 0.4485, 0.8901],
         [1.2064, 0.5632, 0.5046,  .

In [20]:
data[0]

tensor([[[1.2589, 0.5502, 0.5706,  ..., 0.5225, 0.4761, 0.9234],
         [1.2650, 0.5683, 0.5202,  ..., 0.4841, 0.4636, 0.9398],
         [1.2757, 0.5228, 0.5014,  ..., 0.4918, 0.4288, 0.9374],
         ...,
         [1.1141, 0.8348, 0.4990,  ..., 0.8527, 0.7392, 0.6979],
         [1.2817, 0.5285, 0.5456,  ..., 0.5063, 0.4361, 1.0552],
         [1.2682, 0.5480, 0.5330,  ..., 0.4816, 0.4192, 1.0176]]],
       dtype=torch.float64)