In [1]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

from data_loaders.data_module import ChestDataModule

from transforms.finetuning import ChestTrainTransforms, ChestValTransforms
from transforms.pretraining import Moco2TrainTransforms, Moco2ValTransforms
from models.baseline import BaseLineClassifier
from models import get_model

import torch
# Fix the global seed before any model or data objects are built
# (Lightning echoes "Global seed set to 1234" in the cell output).
seed_everything(1234)

Global seed set to 1234


1234

In [2]:
# Baseline experiment setup: classifier, W&B logger, checkpointing and trainer.

# Two-class classifier wrapping the "resnet18" model returned by get_model
# (requested with pretrained=True).
# b1/b2 are presumably the Adam beta coefficients (the run name mentions Adam) — TODO confirm.
classifier = BaseLineClassifier(
    get_model("resnet18", pretrained=True),
    num_classes=2,
    linear=False,
    learning_rate=3e-6,
    b1=0.9,
    b2=0.999,
)

wandb_logger = WandbLogger(name='baseline_NL_pneumo_0.1_Adam_3e-6', project='thesis')

# The checkpoint callback monitors 'val_loss' and writes under logs/baseline/pneumo/.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='logs/baseline/pneumo/',
    filename='resnet-NL-0.1-Adam-3e_6-{epoch:02d}-{val_loss:.4f}',
)

trainer = pl.Trainer(
    gpus=1,
    deterministic=True,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

# Move the model to the GPU up front when CUDA is present; otherwise keep it as is.
classifier = classifier.cuda() if torch.cuda.is_available() else classifier

GPU available: True, used: True
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [3]:
# Data for the pneumonia baseline: a single dataset, with train_fraction=0.1
# (the later fit output shows the training set going from 5216 to 521 samples).
data_module = ChestDataModule(
    ds_list=["chest_xray_pneumonia"],
    batch_size=16,
    num_workers=2,
    balanced=True,
    train_fraction=0.1,
)

# Fine-tuning transforms, both built with height=256.
data_module.train_transforms = ChestTrainTransforms(height=256)
data_module.val_transforms = ChestValTransforms(height=256)

Loaded datasets: chest_xray_pneumonia


In [4]:
# Rebind `classifier` to a model restored from a saved checkpoint file; the
# classifier object built earlier is no longer reachable through this name.
# NOTE(review): the checkpoint lives under logs/baseline/vinbigdata/ while the data
# module in this notebook loads "chest_xray_pneumonia" — confirm this pairing is intended.
classifier = BaseLineClassifier.load_from_checkpoint("logs/baseline/vinbigdata/resnet-full-NL-adam-3e-5-epoch=09-val_loss=0.1518.ckpt")

In [5]:
# Run the test loop on the restored classifier, using the data module's
# validation dataloader as the test set.
trainer.test(classifier, test_dataloaders=data_module.val_dataloader())

Before sampling length:  3000
After sampling length:  3000
Creating balanced dataloader


[34m[1mwandb[0m: Currently logged in as: [33mgenvekt[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9415, device='cuda:0'),
 'test_loss': tensor(0.1526, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_loss': 0.15255430340766907, 'test_acc': 0.9415107369422913}]

In [4]:
# Train the classifier with the data module configured above; progress is
# reported to the W&B logger attached to the trainer.
trainer.fit(classifier, data_module)

[34m[1mwandb[0m: Currently logged in as: [33mgenvekt[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.18 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name     | Type     | Params
--------------------------------------
0 | model    | ResNet   | 11.2 M
1 | accuracy | Accuracy | 0     

  | Name     | Type     | Params
--------------------------------------
0 | model    | ResNet   | 11.2 M
1 | accuracy | Accuracy | 0     


Before sampling length:  16
After sampling length:  16




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Before sampling length:  5216
After sampling length:  521
Creating balanced dataloader




HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…






1

In [2]:
from models.pretraining.moco import ModifiedMocoV2


In [3]:
# Data for MoCo pre-training: the "chest14" dataset with no train_fraction given
# (the fit output shows the training length unchanged by sampling).
data_module = ChestDataModule(
    ds_list=["chest14"],
    batch_size=16,
    num_workers=2,
    balanced=True,
)

data_module.train_transforms = Moco2TrainTransforms(height=256)
# Use the dedicated validation transforms. The original cell assigned
# Moco2TrainTransforms here as well, leaving the imported Moco2ValTransforms unused,
# so validation ran through the training transform pipeline.
# NOTE(review): assumes Moco2ValTransforms accepts the same `height` argument and
# yields the same output structure as the train transforms — confirm in transforms.pretraining.
data_module.val_transforms = Moco2ValTransforms(height=256)

Loaded datasets: chest14


In [4]:
# MoCo v2 pre-training setup: model, W&B logger, checkpointing and trainer.

# The model is given the data module and the same batch_size/num_workers values
# that the data module itself was built with.
moco = ModifiedMocoV2(
    pretrained=False,
    base_encoder="resnet18",
    num_negatives=65536,
    linear=True,
    batch_size=16,
    num_workers=2,
    datamodule=data_module,
    learning_rate=1e-4,
)

wandb_logger = WandbLogger(name='moco_linear_nopretrain_45_rotation', project='thesis')

# The checkpoint callback monitors 'val_loss' and writes under logs/pretraining/moco/.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='logs/pretraining/moco/',
    filename='moco_linear_nopretrain_45_rotation-{epoch:02d}-{val_loss:.4f}',
)

trainer = pl.Trainer(
    gpus=1,
    deterministic=True,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

if torch.cuda.is_available():
    # Rebind `moco`, not `classifier`: the original assigned the result to
    # `classifier`, which silently replaced (or mislabelled) the classifier
    # variable with the MoCo model. Module.cuda() returns the module itself,
    # so `moco` still refers to the same, now GPU-resident, object.
    moco = moco.cuda()

GPU available: True, used: True
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [5]:
# Sanity check: display the `linear` attribute of the MoCo model (the output shows True).
moco.linear

True

In [None]:
# Run MoCo pre-training with the chest14 data module.
trainer.fit(moco, data_module)

[34m[1mwandb[0m: Currently logged in as: [33mgenvekt[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.12 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name      | Type   | Params
-------------------------------------
0 | encoder_q | ResNet | 11.2 M
1 | encoder_k | ResNet | 11.2 M

  | Name      | Type   | Params
-------------------------------------
0 | encoder_q | ResNet | 11.2 M
1 | encoder_k | ResNet | 11.2 M


Before sampling length:  17305
After sampling length:  17305




HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Before sampling length:  69219
After sampling length:  69219




HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

In [7]:
# Rebind `moco` to a model restored from a saved pre-training checkpoint.
# NOTE(review): this filename starts with "resnet-linear_45_rotation", which does not
# match the "moco_linear_nopretrain_45_rotation-..." pattern configured for the
# ModelCheckpoint in this notebook — confirm it is the intended run.
moco = ModifiedMocoV2.load_from_checkpoint("logs/pretraining/moco/resnet-linear_45_rotation-epoch=07-val_loss=5.1505.ckpt")

In [19]:
# Fine-tuning setup: classifier on top of the MoCo query encoder, plus a fresh
# W&B logger, checkpoint callback and trainer for this run.

# The backbone is the `encoder_q` submodule of the restored MoCo model.
classifier = BaseLineClassifier(
    moco.encoder_q,
    num_classes=2,
    linear=False,
    learning_rate=3e-5,
    b1=0.9,
    b2=0.999,
)

wandb_logger = WandbLogger(name='finetune_resnet_nonlinear_chest14_01_Adam_3e-5', project='thesis')

# The checkpoint callback monitors 'val_loss' and writes under logs/finetune/moco/.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='logs/finetune/moco/',
    filename='resnet-chest14-01-nonlinear-adam-3e-5-{epoch:02d}-{val_loss:.4f}',
)

trainer = pl.Trainer(
    gpus=1,
    deterministic=True,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

# Move the model to the GPU up front when CUDA is present; otherwise keep it as is.
classifier = classifier.cuda() if torch.cuda.is_available() else classifier

GPU available: True, used: True
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [20]:
# Data for fine-tuning: "chest14" with train_fraction=0.1
# (the later fit output shows the training set going from 69219 to 6921 samples).
data_module = ChestDataModule(
    ds_list=["chest14"],
    batch_size=16,
    num_workers=2,
    balanced=True,
    train_fraction=0.1,
)

# Fine-tuning transforms, both built with height=256.
data_module.train_transforms = ChestTrainTransforms(height=256)
data_module.val_transforms = ChestValTransforms(height=256)

Loaded datasets: chest14


In [21]:
# Fine-tune the classifier (built on the MoCo query encoder) on the chest14 data module.
trainer.fit(classifier, data_module)



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_step_loss,1.10939
train_step_acc,100.0
train_step_acc5,100.0
epoch,3.0
_step,13199.0
_runtime,3921.0
_timestamp,1608902590.0
val_loss,1.0578
val_acc,99.94218
val_acc5,99.99422


0,1
train_step_loss,██▇▅▅▄▄▃▃▂▃▂▃▃▂▃▂▂▁▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train_step_acc,▂▁▂▅▅▅▆▆██▇█████████████████████████████
train_step_acc5,▁▁▄▆▆▇▇█████████████████████████████████
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆▆▆▆█
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇█
_timestamp,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇█
val_loss,█▃▁
val_acc,▁██
val_acc5,▁██


[34m[1mwandb[0m: wandb version 0.10.12 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name     | Type     | Params
--------------------------------------
0 | model    | ResNet   | 11.2 M
1 | accuracy | Accuracy | 0     

  | Name     | Type     | Params
--------------------------------------
0 | model    | ResNet   | 11.2 M
1 | accuracy | Accuracy | 0     


Before sampling length:  17305
After sampling length:  17305


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

Before sampling length:  69219
After sampling length:  6921


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [1]:
import pandas as pd

In [2]:
# Load the CheXpert metadata table; later cells rely on its "Phase" column
# to separate training rows from validation rows.
df = pd.read_csv("datasets/chexpert_5.csv")

In [3]:
# Separate the rows by their "Phase" label.
df_train = df.loc[df["Phase"] == "train"]
df_val = df.loc[df["Phase"] == "val"]

# Keep a 10% subsample of the training rows (count rounded down); the fixed
# random_state makes the selection repeatable. Validation rows are kept in full.
df_train = df_train.sample(n=int(len(df_train) * 0.1), random_state=123456)

In [4]:
# Show how many training rows remain after subsampling (the output shows 13450).
len(df_train)

13450

In [5]:
# Stack the subsampled training rows on top of the untouched validation rows.
# The original row index labels are kept (no ignore_index).
df_full = pd.concat([df_train, df_val])

In [6]:
# Display the combined frame for a visual check.
df_full

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,...,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,patient,Phase
15507,32564,32716,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,33,Frontal,AP,,,0,...,0,,0,,1,,,,patient07974,train
5516,79855,80238,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,53,Frontal,PA,1.0,,0,...,0,,0,0.0,0,,,,patient19289,train
102775,194042,195048,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,58,Frontal,AP,,,0,...,0,,0,0.0,1,,,1.0,patient47486,train
104386,150870,151630,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Female,69,Frontal,AP,,,0,...,1,,0,,0,,,1.0,patient35810,train
114016,66039,66366,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Female,60,Frontal,AP,,0.0,0,...,1,,0,1.0,0,,,1.0,patient15965,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
168130,211104,212199,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,84,Frontal,AP,,,0,...,0,,0,,1,,,,patient56458,val
168131,26387,26489,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,54,Frontal,PA,,,0,...,1,,1,,1,,,1.0,patient06465,val
168132,26389,26491,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,55,Frontal,PA,,,0,...,1,,1,,0,,,,patient06465,val
168133,26391,26493,/new_data/CheXpert/CheXpert-v1.0/train/patient...,Male,55,Frontal,PA,,,0,...,0,,0,,0,1.0,,,patient06465,val


In [7]:
# Persist the reduced split (10% train + full val).
# index=False stops pandas from writing the row index as an extra unnamed column.
# The frame displayed above already carries "Unnamed: 0", "Unnamed: 0.1" and
# "Unnamed: 0.2", which appear to be left over from earlier save/load round-trips
# that wrote the index; without this flag each round-trip adds another one.
df_full.to_csv("datasets/chexpert_5_01.csv", index=False)