In [1]:
# Work from the project root (one level above this notebook's directory) so the
# local experiment modules imported below and the relative `logs/` checkpoint
# paths resolve. The starting directory is remembered on first run, so
# re-running this cell always lands in the same place instead of climbing one
# more directory level each time.
import os

if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(_NOTEBOOK_START_DIR))
print(os.getcwd())

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
# --- All imports for the notebook (kept together so Restart & Run All resolves every name) ---
import os
from dotenv import load_dotenv
# Loading environment variables
# Called before the project imports below, presumably so any settings they read
# from the environment (.env) are already in place -- TODO confirm.
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

# Project-local package, presumably resolved from the working directory set in the first cell.
from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

# NOTE(review): several names above are not referenced in any visible cell
# (tp, dataclass, field, logging, BasicExperiment, BasicExperimentConfig,
# create_optimizer, rearrange, repeat, plt, pd, KFoldCohortSelectionOptions,
# PatchOptions); prune them once confirmed unused.
from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# --- Experiment configuration ---
# Center passed to LeaveOneCenterOutCohortSelectionOptions(leave_out=...) below.
LEAVE_OUT = 'JH'

# Seed every RNG the notebook relies on (DataLoader shuffling, torchvision
# random augmentations, weight init) so Restart & Run All is reproducible.
SEED = 0
import random
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Data for fine-tuning and evaluation

In [4]:
###### No support dataset ######

from vicreg_pretrain_experiment import PretrainConfig

# Pretraining config whose cohort selection holds out the LEAVE_OUT center.
loco_cohort_selection = LeaveOneCenterOutCohortSelectionOptions(leave_out=LEAVE_OUT)
config = PretrainConfig(cohort_selection_config=loco_cohort_selection)

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    """Per-item transform used by the ``ExactNCT2013RFImagePatches`` datasets below.

    Maps a dataset item to ``(patch_augs, patch, label, item)``:

    * ``patch``: the item's ``"patch"``, min-max normalized when the
      notebook-global ``config.instance_norm`` is set (read at call time),
      wrapped as a torchvision image, resized to ``self.size`` and cast to float.
    * ``label``: ``1`` if ``item["grade"] != "Benign"`` else ``0``, as a long tensor.
    * ``patch_augs``: when ``augment`` is true, ``n_augs`` independently
      augmented views of ``patch`` stacked along a new leading dim; otherwise
      the placeholder ``-1``, which the DataLoader's default collate batches
      into a tensor.
    * ``item``: a shallow copy of the input item without its ``"patch"`` key.
    """

    def __init__(self, augment=False, n_augs=2):
        """
        Args:
            augment: if True, also return ``n_augs`` augmented views of the patch.
            n_augs: number of augmented views drawn per item (default 2).
        """
        self.augment = augment
        self.n_augs = n_augs
        self.size = (256, 256)
        # Built once here rather than on every call; settings are unchanged.
        self.resize = T.Resize(self.size, antialias=True)
        # Augmentation
        self.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    def __call__(self, item):
        # Shallow copy so popping "patch" never mutates a dict the dataset may still hold.
        item = dict(item)
        patch = copy(item.pop("patch"))
        if config.instance_norm:
            patch = self._min_max_normalize(patch)
        patch = TVImage(patch)
        patch = self.resize(patch).float()

        label = torch.tensor(item["grade"] != "Benign").long()

        if self.augment:
            patch_augs = torch.stack(
                [self.transform(patch) for _ in range(self.n_augs)], dim=0
            )
            return patch_augs, patch, label, item

        return -1, patch, label, item

    @staticmethod
    def _min_max_normalize(patch):
        """Rescale ``patch`` to [0, 1]; a constant patch maps to zeros instead of NaN."""
        patch_min = patch.min()
        value_range = patch.max() - patch_min
        if value_range > 0:
            return (patch - patch_min) / value_range
        return patch - patch_min


# Training cohort: the LOCO selection with the train-only filters from `config` applied.
cohort_selection_options_train = copy(config.cohort_selection_config)
train_only_overrides = {
    "min_involvement": config.min_involvement_train,
    "benign_to_cancer_ratio": config.benign_to_cancer_ratio_train,
    "remove_benign_from_positive_patients": config.remove_benign_from_positive_patients_train,
}
for option_name, option_value in train_only_overrides.items():
    setattr(cohort_selection_options_train, option_name, option_value)


def _build_patch_dataset(split, augment, cohort_selection_options):
    """Build one ExactNCT2013RFImagePatches split with this notebook's Transform and patch config."""
    return ExactNCT2013RFImagePatches(
        split=split,
        transform=Transform(augment=augment),
        cohort_selection_options=cohort_selection_options,
        patch_options=config.patch_config,
        debug=config.debug,
    )


# Train patches are not augmented; test patches also carry augmented views (used by MEMO).
# No validation split is built in this notebook.
train_ds = _build_patch_dataset("train", False, cohort_selection_options_train)
test_ds = _build_patch_dataset("test", True, config.cohort_selection_config)

loader_options = {"batch_size": config.batch_size, "num_workers": 4}
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
test_loader = DataLoader(test_ds, shuffle=False, **loader_options)



Computing positions: 100%|██████████| 756/756 [00:04<00:00, 160.41it/s]
Computing positions: 100%|██████████| 616/616 [00:07<00:00, 79.16it/s]


## Model

In [5]:
from vicreg_pretrain_experiment import TimmFeatureExtractorWrapper
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

# Feature-extractor architecture comes from the pretraining config.
fe_config = config.model_config

# Create the timm backbone: single input channel, GroupNorm passed as the norm layer.
model: nn.Module = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=lambda channels: nn.GroupNorm(
        num_groups=fe_config.num_groups,
        num_channels=channels,
    ),
)

# Global average pool + flatten, kept separate from the backbone so that
# `model` maps images to flat feature vectors.
global_pool = SelectAdaptivePool2d(
    pool_type='avg',
    flatten=True,
    input_fmt='NCHW',
)

model = nn.Sequential(TimmFeatureExtractorWrapper(model), global_pool)
# 2-way head; in_features=512 must match the 'linear' weights stored in the checkpoint.
linear = nn.Linear(512, 2)

# Fine-tuned VICReg checkpoint for the held-out center. Earlier experiment runs
# live under logs/tta/{vicreg_pretrain_gn_loco, vicreg_pretrn_5e-3-20linprob_gn_loco,
# vicreg_pretrn_2048zdim_gn_loco, vicreg_pretrn_2048zdim_gn_loco2}.
CHECkPOINT_PATH = os.path.join(
    os.getcwd(),
    f'logs/tta/vicreg_1024-300finetune_1e-3lr_gn_loco/vicreg_1024-300finetune_1e-3lr_gn_loco_{LEAVE_OUT}/',
    'best_model.ckpt',
)

# Read the checkpoint file once and restore both the feature extractor and the head.
checkpoint = torch.load(CHECkPOINT_PATH)
model.load_state_dict(checkpoint['fe_model'])
linear.load_state_dict(checkpoint['linear'])
del checkpoint  # only the loaded module weights are needed from here on

linear.eval()
linear.cuda()
model.eval()
# Trailing `;` suppresses the module repr Jupyter would otherwise display.
model.cuda();

## Train using the fine-tuner

In [6]:
from models.finetuner import Fineturner

# Fine-tune the backbone together with a 512 -> 2 head on the training loader.
# NOTE(review): `Fineturner` receives the global `model`; whether it updates that
# module in place (affecting later cells that reuse `model`) is not visible here.
metric_calculator = MetricCalculator()
finetuner_model: Fineturner = Fineturner(
    model,
    512,
    2,
    metric_calculator=metric_calculator,
    log_wandb=False,
)
finetune_options = {"epochs": 10, "train_backbone": True, "lr": 1e-3}
finetuner_model.train(train_loader, **finetune_options)

train:   0%|          | 0/932 [00:00<?, ?it/s]

train: 100%|██████████| 932/932 [00:43<00:00, 21.26it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.43it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.34it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.42it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.04it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.47it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 23.92it/s]
train: 100%|██████████| 932/932 [00:37<00:00, 24.56it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.24it/s]
train: 100%|██████████| 932/932 [00:38<00:00, 24.22it/s]


In [7]:
# Score the fine-tuned model on the test loader, starting from a cleared calculator.
desc = "test"
finetuner_model.metric_calculator.reset()
finetuner_model.validate(test_loader, desc=desc)
# Expose the populated calculator under the name the metrics cell reads.
metric_calculator = finetuner_model.metric_calculator

test: 100%|██████████| 726/726 [01:16<00:00,  9.44it/s]


## Train a linear probe on extracted representations

### Extract training-set representations

In [None]:
from models.linear_prob import LinearProb

loader = train_loader
desc = "train"
# Fresh calculator, handed to the LinearProb cell below.
metric_calculator = MetricCalculator()

all_reprs_labels_metadata_train = []  # (reprs, labels, meta_data) per batch, tensors left on GPU
all_reprs = []   # per-batch numpy features, concatenated below
all_labels = []  # per-batch numpy labels, concatenated below

# Pure feature extraction: eval mode plus no_grad, so no autograd graph is built
# for any training batch.
model.eval()
with torch.no_grad():
    for batch in tqdm(loader, desc=desc):
        # deepcopy kept as in the original; presumably it detaches the stored
        # meta_data from DataLoader worker memory -- TODO confirm it is needed.
        images_augs, images, labels, meta_data = deepcopy(batch)
        images = images.cuda()
        labels = labels.cuda()

        reprs = model(images)
        all_reprs.append(reprs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_reprs_labels_metadata_train.append((reprs, labels, meta_data))

all_reprs = np.concatenate(all_reprs, axis=0)
all_labels = np.concatenate(all_labels, axis=0)


train:   0%|          | 0/932 [00:00<?, ?it/s]

### scikit-learn logistic regression (disabled)

In [7]:
# from sklearn.linear_model import LogisticRegression

# LR = LogisticRegression(solver='lbfgs', max_iter=1000, multi_class='multinomial')
# LR.fit(all_reprs, all_labels)

# # Assuming your input features have the same dimension as the scikit-learn model
# input_features = LR.coef_.shape[1]  # Replace with the actual number of features
# linear_prob = nn.Linear(input_features, 1) # Binary classification (1 output unit)

# # Step 4: Assign the weights and bias from scikit-learn model to PyTorch model
# with torch.no_grad():  # Disable gradient computation for this operation
#     linear_prob.weight.data = torch.from_numpy(LR.coef_).float()
#     linear_prob.bias.data = torch.from_numpy(LR.intercept_).float()

# linear_prob.cuda()

### Linear probe

In [8]:
# Train a 512 -> 2 linear probe on the cached training representations
# (wandb logging disabled via log_wandb=False).
linear_prob: LinearProb = LinearProb(
    512, 2, metric_calculator=metric_calculator, log_wandb=False
)
probe_options = {"epochs": 15, "lr": 5e-3}
linear_prob.train(all_reprs_labels_metadata_train, **probe_options)

train_linear_prob: 100%|██████████| 932/932 [00:21<00:00, 42.44it/s] 
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 572.63it/s]
train_linear_prob: 100%|██████████| 932/932 [00:02<00:00, 445.22it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 597.71it/s]
train_linear_prob: 100%|██████████| 932/932 [00:02<00:00, 450.82it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 595.35it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 537.73it/s]
train_linear_prob: 100%|██████████| 932/932 [00:02<00:00, 431.97it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 647.99it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 601.65it/s]
train_linear_prob: 100%|██████████| 932/932 [00:02<00:00, 464.58it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 548.07it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 668.93it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 666.74it/s]
train_linear_prob: 1

## MEMO test-time adaptation on the fine-tuned model

In [6]:
from memo_experiment import batched_marginal_entropy

loader = test_loader
adapt_to_test = True
# MEMO hyper-parameters: SGD steps and learning rate used on every test batch.
MEMO_STEPS = 4
MEMO_LR = 5e-4

metric_calculator = MetricCalculator()
desc = "test"

# NOTE(review): adaptation starts from the globals `model`/`linear` restored from the
# checkpoint. If the "Train using finetuner" cells ran earlier in the same kernel,
# confirm Fineturner did not modify `model` in place before trusting these numbers.
for batch in tqdm(loader, desc=desc):
    images_augs, images, labels, meta_data = deepcopy(batch)
    images = images.cuda()
    labels = labels.cuda()

    # Episodic adaptation: each batch starts from fresh copies of the source model,
    # so updates made for one batch never carry over to the next.
    adaptation_fe_model = deepcopy(model)
    adaptation_head_model = deepcopy(linear)

    if adapt_to_test:
        # Requires the augmenting test Transform: images_augs is then assumed to be
        # (batch, n_augs, C, H, W). Read the shape only here, so the non-adaptive
        # path also works with the `-1` placeholder views.
        batch_size, aug_size = images_augs.shape[0], images_augs.shape[1]
        flat_augs = images_augs.reshape(-1, *images_augs.shape[2:]).cuda()
        params = [
            {"params": adaptation_head_model.parameters()},
            {"params": adaptation_fe_model.parameters()},
        ]
        optimizer = optim.SGD(params, lr=MEMO_LR)
        for _ in range(MEMO_STEPS):
            optimizer.zero_grad()
            reprs = adaptation_fe_model(flat_augs)
            outputs = adaptation_head_model(reprs).reshape(batch_size, aug_size, -1)
            loss, _ = batched_marginal_entropy(outputs)
            loss.mean().backward()
            optimizer.step()

    # Evaluate the (possibly adapted) model on the un-augmented patches.
    with torch.no_grad():
        logits = adaptation_head_model(adaptation_fe_model(images))

    metric_calculator.update(
        batch_meta_data=meta_data,
        probs=nn.functional.softmax(logits, dim=-1).cpu(),
        labels=labels.cpu(),
    )

test:   0%|          | 0/726 [00:00<?, ?it/s]

## Compute metrics

In [7]:
# Log metrics every epoch
metrics = metric_calculator.get_metrics()

# Update best score
(
    best_score_updated,
    best_score
    ) = metric_calculator.update_best_score(metrics, desc)

best_score_updated = copy(best_score_updated)
best_score = copy(best_score)
        
# Log metrics
metrics_dict = {
    f"{desc}/{key}": value for key, value in metrics.items()
    }
metrics_dict.update(best_score) if desc == "val" else None 


# wandb.log(
#     metrics_dict,
#     )
metrics_dict

{'test/patch_auroc': tensor(0.6703),
 'test/patch_accuracy': tensor(0.6315),
 'test/all_inv_patch_auroc': tensor(0.6181),
 'test/all_inv_patch_accuracy': tensor(0.6244),
 'test/core_auroc': tensor(0.7580),
 'test/core_accuracy': tensor(0.7155),
 'test/all_inv_core_auroc': tensor(0.6740),
 'test/all_inv_core_accuracy': tensor(0.7023)}

## Log with wandb

In [8]:
import wandb

# Group is shared across held-out centers; the run name appends the center.
# NOTE(review): "4it" presumably mirrors the 4 MEMO adaptation steps above -- keep in sync.
group = "offline_4it-memo_vicreg_1024fintn_gn_loco"
name = group + f"_{LEAVE_OUT}"
wandb.init(project="tta", entity="mahdigilany", name=name, group=group)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mmahdigilany[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Log the final metrics as a single step (epoch 0) and close the run.
# A new dict is built instead of mutating `metrics_dict`, so the metrics cell's
# displayed output stays accurate.
wandb.log({**metrics_dict, "epoch": 0})
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.7023
test/all_inv_core_auroc,0.67404
test/all_inv_patch_accuracy,0.62441
test/all_inv_patch_auroc,0.6181
test/core_accuracy,0.71549
test/core_auroc,0.75803
test/patch_accuracy,0.63154
test/patch_auroc,0.67031
