In [None]:
!pip install timm fastai wandb -Uqq

[K     |████████████████████████████████| 431 kB 4.2 MB/s 
[K     |████████████████████████████████| 189 kB 45.8 MB/s 
[K     |████████████████████████████████| 1.7 MB 40.0 MB/s 
[K     |████████████████████████████████| 55 kB 3.8 MB/s 
[K     |████████████████████████████████| 181 kB 45.9 MB/s 
[K     |████████████████████████████████| 144 kB 48.6 MB/s 
[K     |████████████████████████████████| 63 kB 1.8 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
# Mount Google Drive so the project folder under MyDrive is reachable from Colab.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive/


In [None]:
# Move into the project folder on the mounted Drive.
# NOTE(review): absolute, user-specific path — adjust to your own Drive layout.
%cd /content/drive/MyDrive/colab_notebooks/algovera/nCight/

/content/drive/MyDrive/colab_notebooks/algovera/nCight


In [None]:
# Clone the project repository into the current Drive folder.
# NOTE(review): git refuses to clone into an existing non-empty directory, so on a
# re-run this line is expected to error once the folder exists — guard or skip it.
!git clone https://github.com/AlgoveraAI/freelance-medical-image-classification.git

Cloning into 'freelance-medical-image-classification'...
remote: Enumerating objects: 18, done.[K
remote: Counting objects: 100% (18/18), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 18 (delta 4), reused 11 (delta 2), pack-reused 0[K
Unpacking objects: 100% (18/18), done.


In [None]:
#%cd freelance-medical-image-classification/
#!git init
#!git checkout -b dev-arshy

/content/drive/MyDrive/colab_notebooks/algovera/nCight/freelance-medical-image-classification


In [None]:
# Read the W&B API key from a local file (keep this file out of version control).
# strip() removes a trailing newline/whitespace so it is not sent as part of the key;
# the `with` block closes the file, so no explicit close() is needed.
with open('wandbkey.text', 'r') as f:
    wandbk = f.read().strip()

In [None]:
#export
import json
import os
import random
from fastai.vision.all import *
from fastai.vision.learner import cnn_learner, create_head, create_body, num_features_model, default_split, has_pool_type, apply_init, _update_first_layer
from fastai.callback.wandb import WandbCallback
from sklearn.model_selection import StratifiedKFold
from random import sample
from timm import create_model
import wandb

SEED = 101

# Seed Python's `random` explicitly (get_train_test samples patients with it),
# then let fastai's set_seed seed the rest with reproducible mode on.
random.seed(SEED)
set_seed(SEED, reproducible=True)

# Authenticate to Weights & Biases with the key read earlier (cell output: True).
wandb.login(key=wandbk)

[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
#export
# Default run configuration (later cells redefine `config` per experiment).
config = {
    'local':True,
    'bs': 8,
    'epochs':25,
    'freeze_epochs':2,
    'lr':1e-3,
    'model':'efficientnet_b3', #if fastai pass fastai callable function; if timm pass arch name
    'timm':True,
    'pretrained':True
}


# Registry of supported timm architectures, keyed by the name used in `config['model']`.
#   arch           -> name passed to timm.create_model (must match the key)
#   is_transformer -> True: use timm's own classifier; False: fastai body + head
#   size           -> fixed input resolution required by transformer archs; archs
#                     without it fall back to the default size in setup_train.
TimmConfig= {
    #non-transformer
    'efficientnet_b0':{'arch': 'efficientnet_b0', 'is_transformer':False},
    'efficientnet_b3':{'arch': 'efficientnet_b3', 'is_transformer':False},
    'tf_efficientnet_b0_ap':{'arch':'tf_efficientnet_b0_ap', 'is_transformer':False},
    # previously pointed at tf_efficientnet_b0_ap by mistake
    'tf_efficientnet_b3_ap':{'arch':'tf_efficientnet_b3_ap', 'is_transformer':False},
    'tf_efficientnetv2_xl_in21ft1k':{'arch':'tf_efficientnetv2_xl_in21ft1k', 'is_transformer':False},
    'ssl_resnext50_32x4d':{'arch':'ssl_resnext50_32x4d', 'is_transformer':False},
    'ssl_resnext101_32x4d':{'arch':'ssl_resnext101_32x4d', 'is_transformer':False},

    #transformer
    'swin_base_patch4_window7_224':{'arch':'swin_base_patch4_window7_224', 'is_transformer':True, 'size':224},
    'swin_base_patch4_window7_224_in22k':{'arch':'swin_base_patch4_window7_224_in22k', 'is_transformer':True, 'size':224},
    'beit_base_patch16_224':{'arch':'beit_base_patch16_224', 'is_transformer':True, 'size':224},
    'beit_base_patch16_224_in22k':{'arch':'beit_base_patch16_224_in22k', 'is_transformer':True, 'size':224},
    'deit_base_distilled_patch16_224':{'arch':'deit_base_distilled_patch16_224', 'is_transformer':True, 'size':224},
    # the _384 variant is built for 384px inputs; 224 did not match the model
    'deit_base_distilled_patch16_384':{'arch':'deit_base_distilled_patch16_384', 'is_transformer':True, 'size':384}
}


def get_input(local=False):
    """Return the dataset root path.

    local=True  -> Path('./data') relative to the working directory.
    local=False -> parse the JSON list in env var DIDS and return
                   /data/inputs/<first DID>/0; if DIDS is unset or empty, print a
                   message and return None.
    """
    if local:
        print("Reading local medicaldata directory.")
        # Root directory for dataset
        return Path('./data')

    raw_dids = os.getenv('DIDS', None)
    if not raw_dids:
        print("No DIDs found in environment. Aborting.")
        return None

    did_list = json.loads(raw_dids)
    print('cwd', os.getcwd())

    first_did = did_list[0]
    return Path(f'/data/inputs/{first_did}/0')  # 0 for metadata service


def get_label(fn):
    """Return the 'Scope_type' label from the JSON sidecar next to image `fn`.

    The sidecar has the image's path with its extension replaced by `.json`
    (e.g. `x/img.png` -> `x/img.json`). Any image extension is accepted, including
    upper-case ones such as `.JPG`; previously anything other than
    .jpeg/.png/.jpg crashed with UnboundLocalError.

    Args:
        fn: path (str or Path) of the image file.
    Returns:
        The value stored under 'Scope_type'.
    Raises:
        FileNotFoundError: if the sidecar JSON does not exist.
        KeyError: if the JSON has no 'Scope_type' key.
    """
    label_path = Path(fn).with_suffix('.json')
    with open(label_path, 'r') as tmp:
        meta = json.load(tmp)

    return meta['Scope_type']


def get_patient(fn):
    """Build a patient identifier from the directory structure of `fn`.

    Splits the path string on '/' and space-joins components -5, -4 and -3, i.e.
    the three directories above the file's immediate parent. Shorter paths yield
    fewer (possibly zero) components.
    """
    path_parts = str(fn).split('/')
    patient_parts = path_parts[-5:-2]
    return ' '.join(patient_parts)


def get_train_test(df, train_pct=0.8):
    """Mark a patient-level random validation split on `df` (mutated in place).

    A random `train_pct` fraction (rounded down) of the unique `patient_id`s is kept
    for training; every row of the remaining patients gets `is_valid = True`, so no
    patient's images appear in both splits. Rows of training patients keep their
    existing `is_valid` value.

    Sampling uses the global `random` module, so results depend on `random.seed`.

    Args:
        df: DataFrame with a 'patient_id' column.
        train_pct: fraction of patients to keep for training (default 0.8).
    Returns:
        The same DataFrame object.
    """
    ids = list(df['patient_id'].unique())
    # set membership: the previous list lookup made this O(n_patients^2)
    train_ids = set(random.sample(ids, int(len(ids) * train_pct)))
    test_ids = [id_ for id_ in ids if id_ not in train_ids]
    df.loc[df['patient_id'].isin(test_ids), 'is_valid'] = True

    return df


def get_df(local=True):
    """Build the image index DataFrame with a patient-level train/valid split.

    Columns: 'fns' (image path), 'label' (from each image's JSON sidecar),
    'patient_id' (from the directory structure) and 'is_valid' (split flag).

    Args:
        local: passed through to `get_input`; previously ignored (always True).
    Raises:
        ValueError: if `get_input` returns no input path.
        FileNotFoundError: if no images are found under the input path.
    """
    print("Preparing df.")
    filename = get_input(local)
    if filename is None:
        raise ValueError("No input path available: use local=True or set the DIDS env var.")

    image_fns = get_image_files(filename)
    if len(image_fns) == 0:
        raise FileNotFoundError(f"No image files found under {filename}")

    df = pd.DataFrame(list(image_fns), columns=['fns'])

    df['label'] = df['fns'].apply(get_label)
    df['patient_id'] = df['fns'].apply(get_patient)
    df['is_valid'] = False

    return get_train_test(df)


def setup_dataloaders(df, bs, size=512, augs=None):
    """Create fastai DataLoaders from the image index DataFrame.

    Images are read from column 'fns' and labels from 'label'; ColSplitter() with
    its default column splits on the 'is_valid' flag set by get_df. Every item is
    resized to `size`, and `augs` (or the default list below, used when `augs` is
    falsy) is applied as batch transforms via setup_aug_tfms.

    Args:
        df: DataFrame produced by get_df.
        bs: batch size.
        size: Resize target for item transforms.
        augs: optional list of fastai augmentation transforms.
    """
    print("Setting up dls")
    # NOTE(review): Hue() and Saturation() appear twice in the default list, so each
    # is applied twice per batch. Kept as-is — confirm whether this is intentional.
    augs = augs or [
        Brightness(),
        Contrast(),
        Hue(),
        Saturation(),
        DeterministicDihedral(),
        Hue(),
        Saturation(),
        RandomErasing(max_count=3),
    ]

    datablock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=ColReader('fns'),
        get_y=ColReader('label'),
        splitter=ColSplitter(),
        item_tfms=Resize(size),
        batch_tfms=setup_aug_tfms(augs),
    )
    return datablock.dataloaders(df, bs=bs)

def get_timm_model(
    arch:str, 
    transformer:bool=None,
    pretrained=True, 
    cut=None, 
    n_in=3,
    n_out=2,
):
    """Create a classification model from any architecture in the `timm` library.

    Transformer archs (`transformer` truthy): timm's own model is returned with an
    `n_out`-way classifier.

    Other archs: the timm backbone is created without classifier or global pooling,
    its first layer is adapted to `n_in` channels, and it is truncated before
    child `cut`. A fastai head with Kaiming-normal init is then added, giving
    `nn.Sequential(body, head)`.

    Args:
        arch: timm architecture name.
        transformer: use timm's model and classifier directly if truthy.
        pretrained: load pretrained weights.
        cut: number of backbone children to keep; if None, cut at the last child
            that contains a pooling layer.
        n_in: number of input channels.
        n_out: number of output classes (default 2, the value previously hard-coded).
    Returns:
        nn.Module
    """
    if transformer:
        return create_model(arch,
                            pretrained=pretrained,
                            num_classes=n_out,
                            in_chans=n_in)

    body = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(body, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(body.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    body = nn.Sequential(*list(body.children())[:cut])
    # `body` is already a Sequential; no need to re-wrap it to probe the feature count
    nf = num_features_model(body)
    head = create_head(nf, n_out)
    model = nn.Sequential(body, head)
    apply_init(model[1], nn.init.kaiming_normal_)

    return model
        

def get_learner_lr(dls,
    model, # Model arch
    timm:bool=False, # True if using timm model
    pretrained:bool=True, # Use pretrained backbone
    ):
    """Create a fastai Learner (metric: accuracy) for `dls`.

    - timm=False: `model` is a fastai arch callable (e.g. resnet34) given to cnn_learner.
    - timm=True: `model` is a key of TimmConfig; the network comes from get_timm_model.

    Despite the name, no LR finder is run; the learner is returned untrained.
    """
    print("Setting up learner.")

    if not timm:
        return cnn_learner(
            dls,
            model,
            pretrained=pretrained,
            metrics=accuracy
        )

    model_config = TimmConfig[model]
    is_transformer = model_config['is_transformer']
    net = get_timm_model(model_config['arch'],
                         is_transformer,
                         pretrained,
                         )
    # Non-transformer timm models are nn.Sequential(body, head); default_split puts
    # body and head in separate parameter groups so learner.freeze() really freezes
    # the backbone. Without a splitter there is a single group, and freeze() left
    # every layer trainable during the "frozen" epochs.
    # Transformer models are raw timm modules that default_split cannot index, so
    # they keep the default single group; freeze() remains a no-op for them.
    split_kwargs = {} if is_transformer else {'splitter': default_split}
    return Learner(
        dls,
        net,
        metrics=accuracy,
        **split_kwargs
    )

def setup_train(
    local,
    bs,
    epochs,
    freeze_epochs,
    lr,
    model,
    timm,
    pretrained
):
    """Build data and learner, train frozen then unfrozen, and log to W&B.

    Trains `freeze_epochs` one-cycle epochs after learner.freeze(), then `epochs`
    after learner.unfreeze(), both at lr_max=`lr`. The same SaveModelCallback
    (fname = '<model>_<freeze_epochs>_<epochs>') tracks the best validation loss
    in both phases. Afterwards a validation confusion matrix is logged and the
    W&B run is finished.

    Args mirror the keys of `config`. `model` is a TimmConfig key when timm=True,
    otherwise a fastai arch callable.
    Returns:
        The trained fastai Learner.
    """
    df = get_df(local)

    # Only string arch names can be TimmConfig keys; archs without a 'size' entry
    # (non-transformers, fastai callables) use 256. Replaces a bare `except:` that
    # also hid any unrelated error.
    model_cfg = TimmConfig.get(model, {}) if isinstance(model, str) else {}
    size = model_cfg.get('size', 256)

    dls = setup_dataloaders(df, bs, size)

    learner = get_learner_lr(
                      dls=dls, 
                      model=model, 
                      timm=timm,
                      pretrained=pretrained)
    
    model_name = model if isinstance(model, str) else model.__name__
    run_name = f'{model_name}_{freeze_epochs}_{epochs}'
    sbm = SaveModelCallback(fname=run_name)

    wandb.init(project="algovera_ncight_kneeshoulder", 
               name=run_name)

    def _phase_cbs():
        # Fresh callbacks per fit; the SaveModelCallback instance is shared on purpose.
        return [GradientAccumulation(16),
                GradientClip(),
                WandbCallback(log_preds=False),
                sbm]

    learner.freeze()
    learner.fit_one_cycle(freeze_epochs, lr_max=lr, cbs=_phase_cbs())

    learner.unfreeze()
    learner.fit_one_cycle(epochs, lr_max=lr, cbs=_phase_cbs())
    
    preds, targs = learner.get_preds(dl=learner.dls.valid)
    cm = wandb.plot.confusion_matrix(
        y_true=targs.numpy(),
        preds=torch.argmax(preds, 1).numpy(),
        class_names=list(learner.dls.vocab))
    wandb.log({"conf_mat": cm})

    # Close the run here so the next setup_train call starts a clean, separate run
    # instead of the previous run being finalized inside the next cell's output.
    wandb.finish()

    return learner

In [None]:
#fastai resnet34
# Baseline: fastai's resnet34 via cnn_learner. The config keys match
# setup_train's parameter names one-to-one, so the dict is unpacked directly.
config = {
    'local':True,
    'bs': 8,
    'epochs':25,
    'freeze_epochs':2,
    'lr':1e-3,
    'model':resnet34, #if fastai pass fastai callable function; if timm pass arch name
    'timm':False,
    'pretrained':True
}
learner = setup_train(**config)

Preparing df.
Reading local medicaldata directory.
Setting up dls
Setting up learner.


VBox(children=(Label(value=' 332.54MB of 332.54MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=…

0,1
accuracy,▁██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇██
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▆▁▂▃▄▆▇████▇▇▇▆▅▅▄▄▃▃▂▂▁▁▁
mom_0,█▃██▆▅▃▂▁▁▁▁▂▂▂▃▄▄▅▅▆▆▇▇███
raw_loss,▂▂▁▁▄▆██▅▃▅▄▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,▃▃▁▁▃▅▇████▇▇▇▇▆▆▆▆▅▅▅▅▅▅▅▅
valid_loss,▂▁▁▆▆██▂▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.5
epoch,27.0
eps_0,1e-05
lr_0,1e-05
mom_0,0.9493
raw_loss,0.66623
sqr_mom_0,0.99
train_loss,0.90613
valid_loss,0.75652
wd_0,0.01


epoch,train_loss,valid_loss,accuracy,time
0,1.643651,2.964331,0.5,00:01
1,1.493677,2.739938,0.5,00:01


Better model found at epoch 0 with valid_loss value: 2.9643309116363525.
Better model found at epoch 1 with valid_loss value: 2.739938259124756.


epoch,train_loss,valid_loss,accuracy,time
0,1.242838,2.321101,0.5,00:02
1,1.017436,2.223137,0.5,00:01
2,0.997334,2.171421,0.5,00:01
3,0.976277,2.964754,0.5,00:01
4,0.949704,2.877976,0.5,00:01
5,0.818263,2.606045,0.5,00:01
6,0.789156,2.493492,0.5,00:01
7,0.72981,1.243196,0.7,00:02
8,0.706348,1.353082,0.7,00:01
9,0.630736,0.924228,0.7,00:01


Better model found at epoch 0 with valid_loss value: 2.3211007118225098.
Better model found at epoch 1 with valid_loss value: 2.2231366634368896.
Better model found at epoch 2 with valid_loss value: 2.1714212894439697.
Better model found at epoch 7 with valid_loss value: 1.2431957721710205.
Better model found at epoch 9 with valid_loss value: 0.9242278337478638.
Better model found at epoch 11 with valid_loss value: 0.7239178419113159.
Better model found at epoch 12 with valid_loss value: 0.7233232259750366.
Better model found at epoch 13 with valid_loss value: 0.7227859497070312.
Better model found at epoch 15 with valid_loss value: 0.5150192975997925.
Better model found at epoch 23 with valid_loss value: 0.5022369623184204.
Better model found at epoch 24 with valid_loss value: 0.47461119294166565.


In [None]:
#timm non-transformer
# timm CNN backbone with a fastai head (see get_timm_model). The config keys
# match setup_train's parameter names one-to-one.
config = {
    'local':True,
    'bs': 8,
    'epochs':25,
    'freeze_epochs':2,
    'lr':1e-3,
    'model':'efficientnet_b3', #if fastai pass fastai callable function; if timm pass arch name
    'timm':True,
    'pretrained':True
}

learner = setup_train(**config)

Preparing df.
Reading local medicaldata directory.
Setting up dls
Setting up learner.


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_ra2-cf984f9c.pth
[34m[1mwandb[0m: Currently logged in as: [33marshy[0m (use `wandb login --relogin` to force relogin)


epoch,train_loss,valid_loss,accuracy,time
0,1.030818,2.2e-05,1.0,00:09
1,0.827489,0.013753,1.0,00:03


Better model found at epoch 0 with valid_loss value: 2.214802589151077e-05.


epoch,train_loss,valid_loss,accuracy,time
0,0.345403,0.001176,1.0,00:03
1,0.726613,0.015658,1.0,00:03
2,0.738459,0.059143,1.0,00:03
3,0.931538,0.056423,1.0,00:03
4,0.963151,0.077981,1.0,00:02
5,1.006117,0.11896,1.0,00:03
6,1.113811,0.158997,1.0,00:03
7,1.015884,0.132504,1.0,00:03
8,0.939737,0.262234,0.8,00:03
9,0.959773,0.191914,1.0,00:03


Better model found at epoch 0 with valid_loss value: 0.0011758357286453247.


In [None]:
#timm transformer
# timm transformer with its own classifier; input size comes from TimmConfig.
# The config keys match setup_train's parameter names one-to-one.
config = {
    'local':True,
    'bs': 8,
    'epochs':25,
    'freeze_epochs':2,
    'lr':1e-3,
    'model':'swin_base_patch4_window7_224', #if fastai pass fastai callable function; if timm pass arch name
    'timm':True,
    'pretrained':True
}

learner = setup_train(**config)

Preparing df.
Reading local medicaldata directory.
Setting up dls
Setting up learner.


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth" to /root/.cache/torch/hub/checkpoints/swin_base_patch4_window7_224_22kto1k.pth


VBox(children=(Label(value=' 94.76MB of 94.76MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
accuracy,▁▁▁▁▁▁▃▃▁▃▃▃▆▅▆▅▆██▅▅▅▅▆▆▆▆
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇██
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▆▁▂▃▄▆▇████▇▇▇▆▅▅▄▄▃▃▂▂▁▁▁
mom_0,█▃██▆▅▃▂▁▁▁▁▂▂▂▃▄▄▅▅▆▆▇▇███
raw_loss,█▆▄▅▂▄▃▄▁▂▃▂▁▃▁▂▁▁▂▁▁▁▁▁▄▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
valid_loss,█▇▅▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.8
epoch,27.0
eps_0,1e-05
lr_0,1e-05
mom_0,0.9493
raw_loss,0.0064
sqr_mom_0,0.99
train_loss,0.3258
valid_loss,0.49876
wd_0,0.01


epoch,train_loss,valid_loss,accuracy,time
0,0.699094,0.73283,0.5,00:02
1,0.686557,0.58717,0.8,00:02


Better model found at epoch 0 with valid_loss value: 0.7328304052352905.
Better model found at epoch 1 with valid_loss value: 0.5871695280075073.


epoch,train_loss,valid_loss,accuracy,time
0,0.491708,0.58717,0.8,00:03
1,0.537913,1.47859,0.5,00:02
2,0.766828,1.47859,0.5,00:02
3,0.951958,1.868269,0.5,00:02
4,1.146876,1.868269,0.5,00:03
5,1.286144,0.814035,0.5,00:02
6,1.286975,0.814035,0.5,00:02
7,1.231394,1.024518,0.5,00:02
8,1.240213,1.024518,0.5,00:02
9,1.216538,0.752027,0.5,00:02


Better model found at epoch 0 with valid_loss value: 0.5871695280075073.
