In [13]:
from torch import nn 
import torch 
import torch.nn.functional as F
import numpy as np
import pandas as pd
import sys 
import os 
from pathlib import Path
from tqdm import tqdm

sys.path.append("../")

In [7]:

class Args:
    """Plain attribute container standing in for parsed command-line arguments."""


args = Args()
args.batch_size = 64  # batch size shared by the train/val/test DataLoaders

In [9]:
# --- Datasets ---
print("Loading the dataset...")
from src.data.registry import (
    exact_patches_sl_tuffc_ndl,
    exact_patches_sl_tuffc_prostate,
)

# NOTE(review): training draws from the ``_prostate`` patch set while validation
# and test draw from ``_ndl``; confirm this train/eval mismatch is intentional.
train_dataset = exact_patches_sl_tuffc_prostate("train")
val_dataset, test_dataset = (
    exact_patches_sl_tuffc_ndl(split) for split in ("val", "test")
)

# --- Model ---
print("Loading the model...")
from src.modeling.registry import (
    create_model,
    vicreg_resnet10_pretrn_allcntrs_noPrst_ndl_crop,
)


def _backbone_and_head():
    """Return a new backbone from ``vicreg_resnet10_pretrn_allcntrs_noPrst_ndl_crop(2)``
    together with a freshly initialised ``nn.Linear(512, 2)`` head.

    Assumes the backbone emits 512-dim feature vectors — TODO confirm.
    """
    backbone = vicreg_resnet10_pretrn_allcntrs_noPrst_ndl_crop(2).backbone
    head = nn.Linear(512, 2)
    return backbone, head


# Classifier 1: backbone followed by its linear head.
backbone1, fc1 = _backbone_and_head()
clf1 = nn.Sequential(backbone1, fc1)

# NOTE(review): backbone2/fc2 are not used by any cell visible here; confirm they
# are still needed before paying for a second model construction.
backbone2, fc2 = _backbone_and_head()

# --- DataLoaders ---
print("Creating the dataloaders...")
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np

NUM_WORKERS = 4  # DataLoader worker processes for every split

# Inverse-frequency sample weights: with the sampler's default replacement=True,
# each class is drawn equally often in expectation.
labels = np.array(train_dataset.labels).astype("int")
# np.bincount counts all classes in one vectorized pass. The previous
# ``sum(labels == label)`` iterated the array element by element in Python.
# Indexing the counts by label value also stays correct for non-negative labels
# that are not contiguous 0..K-1; there, indexing a per-unique-class list raised
# IndexError or mis-assigned weights.
class_counts = np.bincount(labels)
weight_for_classes = 1.0 / class_counts[np.unique(labels)]  # per class, sorted label order
weights = 1.0 / class_counts[labels]  # one weight per training sample
# logging.info(f"Weights for classes {weights}")

train_sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=len(train_dataset),
)
train_loader = DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=NUM_WORKERS
)
test_loader = DataLoader(
    test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=NUM_WORKERS
)

Loading the dataset...


Preparing cores: 100%|██████████| 528/528 [00:00<00:00, 87692.74it/s]
Indexing Patches: 100%|██████████| 528/528 [00:00<00:00, 61144.50it/s]
Preparing cores: 100%|██████████| 130/130 [00:00<00:00, 60456.76it/s]
Indexing Patches: 100%|██████████| 130/130 [00:00<00:00, 57693.31it/s]
Preparing cores: 100%|██████████| 138/138 [00:00<00:00, 89669.09it/s]
Indexing Patches: 100%|██████████| 138/138 [00:00<00:00, 60735.99it/s]


Loading the model...
Creating the dataloaders...


In [11]:
# Linear probe: the optimizer below only updates fc1. Freezing the backbone means
# backward() stops at the head instead of traversing the whole backbone. The
# backbone's .grad would otherwise keep accumulating, because opt.zero_grad()
# only clears fc1's parameters. Updates to fc1 are unchanged.
backbone1.requires_grad_(False)
# NOTE(review): clf1.train() still puts any BatchNorm layers in the backbone into
# train mode, so their running statistics keep updating even with frozen weights.
# Call backbone1.eval() after clf1.train() if that is unwanted.
clf1.cuda()

opt = torch.optim.Adam(fc1.parameters(), lr=1e-4)

In [14]:
def eval(model, loader):
    """Run ``model`` over ``loader`` and compute patch- and core-level ROC AUC.

    The name shadows Python's builtin ``eval`` in the notebook namespace; it is
    kept so existing calls keep working.

    Args:
        model: classifier returning per-class logits; a softmax over dim 1 turns
            them into probabilities. Column ``y_hat_1`` is read below, so it
            presumably has 2 classes — TODO confirm.
        loader: iterable of ``(X, y, info)`` batches. ``info`` is a mapping that is
            merged into each accumulated record and must supply ``core_specifier``.

    Returns:
        ``(metrics, df)``. ``metrics`` has keys ``"patch_auc"`` and ``"core_auc"``.
        ``df`` is the per-patch dataframe from ``DictConcatenation.compute``.
    """
    from sklearn.metrics import roc_auc_score
    from src.utils.accumulators import DictConcatenation

    acc = DictConcatenation()
    was_training = model.training
    model.eval()
    try:
        # Evaluation never calls backward(), so skip building autograd graphs.
        with torch.no_grad():
            for X, y, info in loader:
                X = X.cuda()
                y = y.cuda()
                y_hat = model(X).softmax(dim=1)
                acc({"y": y, "y_hat": y_hat, **info})
    finally:
        # Restore the caller's mode so a later training loop does not silently
        # run in eval mode.
        model.train(was_training)

    # Assumes DictConcatenation expands y_hat into y_hat_0 / y_hat_1 columns — TODO confirm.
    df = acc.compute(out_fmt="dataframe")
    metrics = {}
    metrics["patch_auc"] = roc_auc_score(df.y, df.y_hat_1)
    # Core level: mean patch probability vs. mean patch label per core. This
    # presumably relies on all patches in a core sharing one label, so the mean is
    # 0 or 1 — verify; roc_auc_score rejects continuous targets.
    core_pred = df.groupby("core_specifier").y_hat_1.mean()
    core_label = df.groupby("core_specifier").y.mean()
    metrics["core_auc"] = roc_auc_score(core_label, core_pred)

    return metrics, df

# One training pass (one epoch) over train_loader.
# Set train mode explicitly so this cell does not depend on whatever mode an
# earlier cell (e.g. an eval() call) left clf1 in.
clf1.train()
for batch in tqdm(train_loader):
    X, y, _ = batch
    X = X.cuda()
    y = y.cuda()
    opt.zero_grad()
    y_hat = clf1(X)  # raw logits; F.cross_entropy applies log-softmax itself
    loss = F.cross_entropy(y_hat, y)
    loss.backward()
    opt.step()


 30%|██▉       | 955/3206 [01:25<03:20, 11.22it/s]


KeyboardInterrupt: 