In [1]:
%cd ..

/home/paulw/projects/TRUSnet-1


In [2]:
DEVICE='cuda'
ADV_LOSS_FACTOR=1
CYCLE=[100, 100]

from src.utils.metrics import ClassificationOutputCollector
from tqdm.notebook import tqdm

In [3]:
from hydra import initialize, compose

initialize(config_path='../configs')

from src.configuration import register_configs

register_configs()

#cfg = compose(overrides=['+datamodule@dm=sl_datamodule', '+split_seed=0'])
cfg = compose(overrides=['+datamodule@dm=sl_datamodule', '+split_seed=2'])

from omegaconf import OmegaConf
OmegaConf.resolve(cfg)

from hydra.utils import instantiate

cfg.dm.splits.undersample_benign_train = True 
cfg.dm.splits.undersample_benign_eval = True  
cfg.dm.loader_config.num_workers = 0
cfg.dm.loader_config.balance_classes_train = True

cfg.dm.loader_config.batch_size=64

cfg.dm.minimum_involvement = 0.4

from rich import print as pprint
pprint(OmegaConf.to_object(cfg))

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path='../configs')
  stdout_func(


In [4]:
dm = instantiate(cfg.dm)
dm.setup()

Preparing cores: 100%|██████████| 978/978 [00:00<00:00, 54906.76it/s]
Indexing Patches: 100%|██████████| 978/978 [00:27<00:00, 35.59it/s]
Preparing cores: 100%|██████████| 105/105 [00:00<00:00, 37822.22it/s]
Indexing Patches: 100%|██████████| 105/105 [00:02<00:00, 35.32it/s]
Preparing cores: 100%|██████████| 315/315 [00:00<00:00, 51285.06it/s]
Indexing Patches: 100%|██████████| 315/315 [00:08<00:00, 36.45it/s]


In [5]:
from src.modeling.optimizer_factory import OptimizerConfig, configure_optimizers
from src.modeling.shao_3_player_minimax_models import ShaoEtAlFeatureExtractor, ShaoEtAlMLP
from itertools import chain

#feature_extractor = ShaoEtAlFeatureExtractor().to(DEVICE)
from src.modeling.registry import resnet10_feature_extractor
from src.modeling.mlp import MLPClassifier
feature_extractor = resnet10_feature_extractor().to(DEVICE)
center_disc = MLPClassifier(512, 256, num_classes=5).to(DEVICE)
clf = MLPClassifier(512, 256, num_classes=2).to(DEVICE)


#center_disc = ShaoEtAlMLP(5).to(DEVICE)
#clf = ShaoEtAlMLP(5).to(DEVICE)

from torch.optim import Adam
feat_opt = Adam(feature_extractor.parameters())
clf_opt = Adam(clf.parameters())
center_disc_opt = Adam(center_disc.parameters())

In [6]:
import numpy as np 
import torch

center2label = {
    'UVA': 0, 
    'CRCEO': 1, 
    'PCC': 2, 
    'PMCC': 3, 
    'JH': 4,
}

def to_center_label(center): 
    center = np.array([center2label[center_] for center_ in center])
    return torch.tensor(center)

In [1]:
from torch.nn import functional as F


def shared_step(batch): 
    patch, pos, label, metadata = batch
    label = label.to(DEVICE)
    center = to_center_label(metadata['center']).to(DEVICE)
    feats = feature_extractor(patch.to(DEVICE))
    logits = clf(feats)
    center_logits = center_disc(feats)

    loss = F.cross_entropy(logits, label)
    center_disc_loss = F.cross_entropy(center_logits, center)

    clf_out = {
        'logits': logits, 
        'preds': logits.softmax(-1), 
        'label': label, 
        **metadata
    }
    center_clf_out = {
        'logits': center_logits,
        'preds': logits.softmax(-1),
        'label': center, 
    }

    return loss, center_disc_loss, clf_out, center_clf_out

def main_train_step(batch): 
    loss, center_disc_loss, clf_out, center_clf_out = shared_step(batch)
    
    total_loss = loss - center_disc_loss * ADV_LOSS_FACTOR
    total_loss.backward()
    
    feat_opt.step()
    feat_opt.zero_grad()
    clf_opt.step()
    clf_opt.zero_grad()

    return clf_out

def disc_train_step(batch): 
    loss, center_disc_loss, clf_out, center_clf_out = shared_step(batch)

    center_disc_loss.backward()
    center_disc_opt.step()
    center_disc_opt.zero_grad()

    return center_clf_out

def eval_step(batch):
    with torch.no_grad():
        loss, center_disc_loss, clf_out, center_clf_out = shared_step(batch)

    return clf_out, center_clf_out

TOTAL_TRAIN_STEPS = 0 
def train_disc(): 
    return TOTAL_TRAIN_STEPS % sum(CYCLE) >= CYCLE[0]

def train_epoch(): 

    collector = ClassificationOutputCollector()
    collector_disc = ClassificationOutputCollector()

    for batch in tqdm(dm.train_dataloader()): 
        if train_disc(): 
            out = main_train_step(batch)
            collector.collect_batch(out)
        else: 
            out = disc_train_step(batch)
            collector_disc.collect_batch(out)

        global TOTAL_TRAIN_STEPS
        TOTAL_TRAIN_STEPS += 1

    return collector.compute(), collector_disc.compute()

def eval_epoch(loader): 

    collector = ClassificationOutputCollector()
    collector_disc = ClassificationOutputCollector()

    for batch in tqdm(loader): 
        out, center_out = eval_step(batch)
        collector.collect_batch(out)
        collector_disc.collect_batch(out)

    return collector.compute(), collector_disc.compute()

from src.utils.metrics import patch_and_core_metrics
from torchmetrics.functional import auroc
def base_metrics(out): 
    return {
        'auroc': auroc(out['logits'], out['label'], num_classes=2)
    } 


for epoch in range(100): 
    out, disc_out = train_epoch()
    metrics = patch_and_core_metrics(out, base_metrics)
    print('train auroc: ', metrics['patch_micro_avg_auroc'])
    out, disc_out = eval_epoch(dm.test_dataloader())
    metrics = patch_and_core_metrics(out, base_metrics)
    print('test auroc: ', metrics['patch_micro_avg_auroc'])

out, _ = eval_epoch(dm.test_dataloader())

ModuleNotFoundError: No module named 'src'

In [31]:

patch_and_core_metrics(out, base_metrics)


{'patch_micro_avg_auroc': tensor(0.5035),
 'patch_CRCEO_auroc': tensor(0.5101),
 'patch_JH_auroc': tensor(0.4920),
 'patch_PCC_auroc': tensor(0.5250),
 'patch_PMCC_auroc': tensor(0.4861),
 'patch_UVA_auroc': tensor(0.5004),
 'patch_macro_avg_auroc': tensor(0.5027),
 'core_micro_avg_auroc': tensor(0.4662),
 'core_CRCEO_auroc': tensor(0.3484),
 'core_JH_auroc': tensor(0.6518),
 'core_PCC_auroc': tensor(0.5022),
 'core_PMCC_auroc': tensor(0.5614),
 'core_UVA_auroc': tensor(0.3586),
 'core_macro_avg_auroc': tensor(0.4845)}

In [19]:
out = _

In [21]:
out[0]['logits']

tensor([[ 0.8465, -0.1237],
        [ 0.8336, -0.0893],
        [ 0.8359, -0.0823],
        ...,
        [ 0.0106,  0.0077],
        [ 0.0142,  0.0091],
        [-0.0040,  0.0303]])

In [23]:
out[0]['label']

tensor([0, 1, 0,  ..., 0, 1, 0])

In [13]:
batch = next(iter(dm.train_dataloader()))
out = shared_step(batch)[3]

In [14]:
out = _

In [15]:
out

({},
 {'center_logits': tensor([[ 0.0909, -0.0101, -0.0022,  0.0700, -0.1246],
          [ 0.0321, -0.0577,  0.1129, -0.0096, -0.2043],
          [ 0.0218, -0.1328,  0.0191,  0.0185, -0.0320],
          ...,
          [ 0.7422,  0.6527,  0.4659, -0.2247, -1.3297],
          [ 0.8673,  0.6461,  0.4315, -0.2565, -1.5127],
          [ 0.7893,  0.6343,  0.4444, -0.2510, -1.3884]]),
  'center': tensor([0, 0, 0,  ..., 3, 0, 3])})

In [26]:
from src.utils.metrics import patch_and_core_metrics
from torchmetrics.functional import auroc