In [1]:
#!gpustat

In [2]:
import os
import pandas as pd
from PIL import Image
import numpy as np

import torch

# Pin this kernel to a single physical GPU. CUDA only honours this variable if
# it is set before the first CUDA call, so keep it in the first code cell.
GPU_ID = "7"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

In [3]:
def convert_models_to_fp32(model):
    """Cast all parameters of ``model`` to float32 in place.

    Every parameter's ``.data`` is replaced by its float32 version; if the
    parameter already holds a gradient, the gradient's ``.data`` is cast as
    well. Parameters without a gradient keep ``grad`` as ``None``.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. The model is modified in place.
    """
    for param in model.parameters():
        param.data = param.data.float()
        grad = param.grad
        if grad is None:
            continue
        grad.data = grad.data.float()

In [4]:
from torch.utils.data import TensorDataset, DataLoader

from few_shot_clip_utils import CLIPClassifier
from clip_utils import ImageDataset, FinetuneDataModule



dataset_name = "mimic-cxr"
model_name = "densenet"  # one of "ViT-B/16", "ViT-L/14", "densenet"
mode = "train_mlp_norm"  # one of "freeze", "train_norm", "train_full", "train_mlp_norm", "torchxrayvision"

# do not change
use_augs = True
max_epochs = 20
precision = 16
use_pos_weight = True


def resolve_hparams(model_name, mode):
    """Derive the training hyper-parameters for a ``(model_name, mode)`` pair.

    Previously the "torchxrayvision" preset was assigned first and then
    unconditionally overwritten by the generic per-model defaults, so it never
    took effect. The two sources of settings are now mutually exclusive.

    Args:
        model_name: requested backbone ("ViT-B/16", "ViT-L/14" or "densenet").
        mode: requested fine-tuning mode. "torchxrayvision" selects a
            self-contained preset that also forces the densenet backbone.

    Returns:
        dict with the keys "model_name", "mode", "lr", "weight_decay" and
        "batch_size". For the "torchxrayvision" preset it additionally holds
        "rot_aug", "shift_aug", "scale_aug", "img_value_scale", "input_res",
        "lr_scale_when_plateau" and "lr_patience".
    """
    if mode == "torchxrayvision":
        # some pre-set settings
        # TODO: check code of torchxrayvision for default learning rate and
        # for num epochs until decay
        hparams = {
            "model_name": "densenet",
            "lr": 0.001,
            "weight_decay": 0.0,
            "batch_size": 64,
            "rot_aug": 45,
            "shift_aug": 0.15,
            "scale_aug": 0.1,
            "img_value_scale": (1024, 1024),
            "input_res": (224, 224),
            "lr_scale_when_plateau": 0.1,
            "lr_patience": 5,
        }
    else:
        hparams = {
            "model_name": model_name,
            "lr": 1e-6,
            "weight_decay": 0.2,  # decay 0.05 in convnext, 0.2 in CLIP training
            # max 32 for single GPU CL on VitB16, 16 for ViT-L/14
            "batch_size": 8 if model_name == "ViT-L/14" else 32,
        }

    # densenet runs always use mode "full", whatever mode was requested.
    hparams["mode"] = "full" if hparams["model_name"] == "densenet" else mode
    return hparams


hparams = resolve_hparams(model_name, mode)
# Publish the resolved settings under the module-level names the following
# cells read: model_name, mode, lr, weight_decay, batch_size (plus the
# preset-only extras when the "torchxrayvision" preset is active).
globals().update(hparams)

In [5]:
#convert_models_to_fp32(model)
# Number of output labels; hard-coded here because the data module (which
# could provide it) is only built at the end of this cell.
num_labels = 14  # data_module.num_labels

if model_name == "densenet_224":
    import torchxrayvision as xrv
    model = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch")
    labels_to_remove = ["No finding", "Support devices", "Pleural other"]
    print(FinetuneDataModule.feature_names)
    # This branch never defined `transform`, so the FinetuneDataModule call
    # below either hit a NameError (fresh kernel) or silently reused a
    # transform left over from an earlier run with another model. Fail loudly
    # until a matching transform is added. `labels_to_remove` is not used yet.
    raise NotImplementedError(
        "model_name='densenet_224' is not wired up yet: "
        "no image transform is defined for the torchxrayvision DenseNet"
    )

elif model_name == "densenet":
    import torchvision
    import torchvision.transforms as TF
    # get model and re-init out layer
    model = torchvision.models.densenet121(pretrained=True)
    # Read the feature width from the existing head instead of hard-coding 1024.
    model.classifier = torch.nn.Linear(model.classifier.in_features, num_labels)
    model.num_labels = num_labels
    # create transform (ImageNet channel statistics, matching the pretrained weights)
    normalize = TF.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    densenet_size = 256
    transform = TF.Compose([
        TF.Resize(size=densenet_size,
                  interpolation=TF.InterpolationMode.BILINEAR),
        TF.CenterCrop(size=(densenet_size, densenet_size)),
        TF.ToTensor(),
        normalize,
    ])
else:
    from clip_utils import load_clip
    # Load CLIP on the CPU; `load_clip` also returns the transform to use.
    clip_base_model, transform, clip_name = load_clip(model_name, device="cpu")
    model = CLIPClassifier(clip_base_model, mode, num_labels)

data_module = FinetuneDataModule(model, transform,
                                 dataset_name=dataset_name, mode=mode, use_augs=use_augs,
                                 batch_size=batch_size)
# Attach the label names reported by the data module to the model.
model.label_names = data_module.label_names

In [6]:
#path = data_module.train_ds.paths[0]
#print(path)
#from PIL import Image
#pil_img = Image.open(path)
#pil_img

In [7]:
import pytorch_lightning
from pytorch_lightning.loggers import WandbLogger

# Weights & Biases logger for this run (project "early_tests", no model upload).
wandb_logger = WandbLogger(name=None, save_dir=None, offline=False, id=None,
                           anonymous=None, version=None, project="early_tests",
                           log_model=False, experiment=None, prefix='')

# Record every setting that defines the run. The optimisation settings
# (lr, weight_decay, max_epochs, precision) were previously missing, which
# made runs impossible to reproduce from the W&B record alone.
wandb_logger.log_hyperparams({"mode": mode,
                              "dataset_name": dataset_name,
                              "use_augs": use_augs,
                              "use_pos_weight": use_pos_weight,
                              "batch_size": batch_size,
                              "model_name": model_name,
                              "lr": lr,
                              "weight_decay": weight_decay,
                              "max_epochs": max_epochs,
                              "precision": precision,
                              })


trainer = pytorch_lightning.Trainer(val_check_interval=300,  # validate every 300 training batches
                                    precision=precision,
                                    logger=wandb_logger,
                                    max_epochs=max_epochs,
                                    # one GPU if CUDA is available, otherwise CPU
                                    gpus=int(torch.cuda.is_available()),
                                    #overfit_batches=1, 
                                    benchmark=True,
                                    num_sanity_val_steps=0,  # skip the pre-training sanity validation
                                    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: antonius (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.12.10 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [8]:
import importlib
import few_shot_clip_utils
importlib.reload(few_shot_clip_utils)
from few_shot_clip_utils import LitCLIP

# Positive-label statistics from the data module (presumably one fraction per
# label - TODO confirm); passed on to the Lightning wrapper together with the
# `use_pos_weight` switch.
pos_fraction = data_module.pos_fraction

# Keyword settings for the Lightning wrapper, collected in one place.
lit_kwargs = dict(
    learning_rate=lr,
    steps_per_epoch=data_module.steps_per_epoch,
    weight_decay=weight_decay,
    pos_fraction=pos_fraction,
    use_pos_weight=use_pos_weight,
)
lit_model = LitCLIP(model, max_epochs, **lit_kwargs)

Pos weight:  tensor([  4.8115,   4.8980,  24.9163,   9.4579,  36.8328,  48.8305,  34.1711,
          3.9655,   1.6122,   3.9360, 110.1615,  13.4747,  25.3513,   3.5054],
       dtype=torch.float64)


In [None]:
# Manual steps kept for debugging (prepare/setup the data module by hand, or
# run validation only); disabled for normal training runs.
#data_module.prepare_data()
#data_module.setup(stage="fit")
#trainer.validate(model)
# Start training `lit_model` on the data supplied by `data_module`, using the
# trainer configured above.
trainer.fit(lit_model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [7]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | DenseNet          | 7.0 M 
1 | loss_func | BCEWithLogitsLoss | 0     
------------------------------------------------
7.0 M     Trainable params
0         Non-trainable params
7.0 M     Total params
13.936    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# TODO: compute all metrics in a uniform way