# Advanced training
with augmentations, lr scheduling, testing on benchmarks and wandb logging

In [1]:
import torch # 1.8
import segmentation_models_pytorch as smp
from tqdm.auto import tqdm
import albumentations as A
import os
from matplotlib import pyplot as plt
import wandb
import numpy as np
import importlib

#from dlutils.utils import visualization
from dlutils.utils.utils import listdir
from dlutils.learn import traintest, metrics, predict, losses
from dlutils.learn.transforms import transforms, pipeline, augmentations, maskutils, postprocessors
from dlutils.learn.datasets import from_npfiles
from dlutils.utils.samplers.randomgridsampler import make_random_grid_sampler
from dlutils.dataadapters.npfileadapter import NpFileReader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
russia_manual_root = 'samples/'
np_file_name = '/img_mask.np'
# russia_pseudo_root = 'data/russia_pseudo/'

# get lists of files in each directory
russia_manual = os.listdir(russia_manual_root)
russia_manual = [russia_manual_root+file+np_file_name for file in russia_manual]
#russia_pseudo = os.listdir(russia_pseudo_root)
#russia_manual

In [3]:
num_workers = 4
batch_size = 10
normparams = transforms.get_normparams({'mean': (106.94118, 108.923965, 98.418015),
                                        'std': (54.646156, 51.051823, 50.47959)}, 3)
sample_size=512

aug_dict = {
            'Rotate':{'p':0.6},
            'CenterCrop':{'height': sample_size, 'width':sample_size, 'p': 1., 'always_apply': True},
            'RandomRotate90':{'p':0.5},
            'Flip':{'p':0.5},
            'RandomResizedCrop':{'height': sample_size, 'width':sample_size, 'scale':(0.8, 1.2), 'ratio':(1., 1.), 'p':0.2},
            'ColorJitter':{'brightness':0.2, 'contrast':0.2, 'saturation':0.2, 'hue':0.1, 'p':0.2},
            #'ChannelShuffle':{'p':0.1},
            'Downscale':{'p':0.1},
            'GridDistortion':{'p':0.2},
            'ISONoise':{'p':0.2},
            'ImageCompression':{'p':0.2},
            #'GaussNoise':{'p':0.2},
            'Emboss':{'p':0.2},
            'Sharpen':{'p':0.2},
            'Blur':{'p':0.1},
            'CLAHE':{'p':0.2},
            #'HistogramMatching':{'reference_images':reference, 'read_fn':lambda x: x, 'p':0.2, 'blend_ratio': (0.1, 0.9)},
            #'PixelDistributionAdaptation':{'reference_images':reference, 'read_fn':lambda x: x, 'p':0.2, 'blend_ratio': (0.1, 0.7)}
           }

loss_str = "smp.losses.JaccardLoss(mode='multiclass', from_logits=True)"
loss_fn = eval(loss_str)

val_preprocess = transforms.get_transform_to_tensor(normparams)
val_postprocess = pipeline.get_pipeline([transforms.get_transform_from_tensor(),
                                                 maskutils.get_onehot_to_labels(),
                                                 maskutils.clip_to_8bit])

mixed_precision = True
device = 'cuda'


In [4]:
model_str = 'smp.Unet(encoder_name="resnet34", encoder_weights="imagenet")'
model = eval(model_str)

In [5]:
train_files_str = 'russia_manual'
train_weights_str = "'auto'"

train_weights = eval(train_weights_str)
train_files = eval(train_files_str)
print(train_files)

# ?
# true_mask = true_mask.read()
# true_mask = torch.tensor(true_mask)

train_data = from_npfiles(train_files,
                          sample_size=int(sample_size*1.5),
                          pipeline=pipeline.get_pipeline([pipeline.get_img_mask_split(mask_channels=(3,)),
                                                          augmentations.compose_from_dict(aug_dict),
                                                          pipeline.get_dict_wrapper({'image': transforms.get_transform_to_tensor(normparams),
                                                                                     'mask': transforms.get_transform_to_tensor(dtype=torch.int64)})]))

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_sampler=make_random_grid_sampler(train_data,
                                                                                  batch_size=batch_size, 
                                                                                  length=2000,
                                                                                  weights=train_weights
                                                                                 ),
                                           num_workers=num_workers)

['samples/001/img_mask.np', 'samples/002/img_mask.np', 'samples/003/img_mask.np', 'samples/004/img_mask.np', 'samples/005/img_mask.np', 'samples/006/img_mask.np', 'samples/008/img_mask.np', 'samples/009/img_mask.np', 'samples/010/img_mask.np', 'samples/011/img_mask.np', 'samples/012/img_mask.np', 'samples/013/img_mask.np', 'samples/014/img_mask.np', 'samples/015/img_mask.np', 'samples/016/img_mask.np', 'samples/017/img_mask.np', 'samples/018/img_mask.np', 'samples/019/img_mask.np', 'samples/020/img_mask.np', 'samples/021/img_mask.np', 'samples/022/img_mask.np', 'samples/023/img_mask.np', 'samples/024/img_mask.np', 'samples/027/img_mask.np', 'samples/028/img_mask.np', 'samples/029/img_mask.np', 'samples/030/img_mask.np', 'samples/031/img_mask.np', 'samples/033/img_mask.np', 'samples/034/img_mask.np', 'samples/038/img_mask.np', 'samples/039/img_mask.np', 'samples/040/img_mask.np', 'samples/041/img_mask.np', 'samples/042/img_mask.np', 'samples/043/img_mask.np', 'samples/044/img_mask.np', 



In [6]:
epochs = 20 #40
pbar = tqdm(total=traintest.infer_train_length(train_loader, epochs))
lr = 5e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
schedulers = traintest.make_schedulers(optimizer,
                                       warmup_params ={'start_factor': 0.4,  'total_iters': 3},
                                       exp_params = {'gamma': 0.98})


wandb.init(project='swrc',
           config={"model": model_str,
                   #'params': total_params, 
                   'lr':lr, 
                   'batch_size': batch_size,
                   'sample_size': sample_size,
                   'aug': str(aug_dict),
                   'data': train_files_str,
                   'mixed_precision': mixed_precision,
                   'train_weights': train_weights_str,
                   'loss': loss_str})

  0%|          | 0/40000 [00:00<?, ?it/s]

In [32]:
benchmark_root = 'samples_benchmarks/'
benchmark_files =['026','035','045']
#ru_benchmarks = ['Mytishi', 'Ufa', 'Kolomna']

benchmarks = [NpFileReader(os.path.join(benchmark_root, f)+'/img_mask.np') for f in benchmark_files]

max_f1, max_rt = 0, 0

# cut some references for HistogramMatching
reference_sample_size=1024
reference = (benchmarks[0][:3, 3800:3800+reference_sample_size, 3400:3400+reference_sample_size].transpose(1, 2, 0),
             benchmarks[1][:3, :reference_sample_size, :reference_sample_size].transpose(1, 2, 0),
             benchmarks[2][:3, 2048:2048+reference_sample_size, 2048:2048+reference_sample_size].transpose(1, 2, 0)
            #,  benchmarks[3][:3, :reference_sample_size, :reference_sample_size].transpose(1, 2, 0),
            #  benchmarks[4][:3, :reference_sample_size, :reference_sample_size].transpose(1, 2, 0),
            #  benchmarks[5][:3, 1024:1024+reference_sample_size, 1024:1024+reference_sample_size].transpose(1, 2, 0),
            #  benchmarks[6][:3, 1024:1024+reference_sample_size, 1024:1024+reference_sample_size].transpose(1, 2, 0),
             )

benchmarks_weights = np.array([np.sum(b.shape[1:]) for b in benchmarks])
benchmarks_weights = benchmarks_weights / np.max(benchmarks_weights)

for epoch in range(epochs):
    log = dict()
    train_loss = traintest.train(train_loader, model, loss_fn, 
                                 optimizer, mixed_precision=mixed_precision, #accumulate_loss=2,
                                 #loss_handler_type='dict', 
                                 schedulers=schedulers, pbar=pbar, device=device)
    log['Train'] = train_loss
            
    tp = fp = fn = rtp = rfp = rfn = 0
    rt_ru = []
    sw_ru = []
    sh_ru = []
    wl_ru = []
            
    for b_name, b in zip(benchmark_files, benchmarks):
        rgb = b[:3]
        gt = postprocessors.separate_instances(b[3][0], 3, 4)
        pred = predict.predict(model, rgb, preprocess=val_preprocess, sample_size=1024,
                               postprocess=val_postprocess, device=device)[0]  
    
                
        pred = postprocessors.separate_instances(pred, 3, 4)
                
        scores = metrics.get_iou_multiclass(pred, gt, n_classes=4)
        for sc, cn in zip(scores[1:], ('sh', 'wl', 'rt')):
            log[f'{b_name}_{cn}'] = sc
                    
        _tp, _fp, _fn = metrics.objectwise_stats_raster((pred==3).astype(np.uint8), (gt==3).astype(np.uint8))
        tp += _tp
        fp += _fp
        fn += _fn
        log[f'{b_name}_prec'] = metrics.precision(_tp, _fp)
        log[f'{b_name}_rec'] = metrics.recall(_tp, _fn)
        log[f'{b_name}_f1'] = metrics.f1_score(log[f'{b_name}_prec'], log[f'{b_name}_rec'])
        
        #if b_name in ru_benchmarks:
        rtp += _tp
        rfp += _fp
        rfn += _fn
        rt_ru.append(scores[3])
        sh_ru.append(scores[1])
        wl_ru.append(scores[2])
        sw_ru.append((scores[1] + scores[2]) / 2)
                        
    log['Avg_prec'] = metrics.precision(tp, fp)
    log['Avg_rec'] = metrics.recall(tp, fn)
    log['Avg_f1'] = metrics.f1_score(log['Avg_prec'], log['Avg_rec'])
            
    log['Avg_prec_ru'] = metrics.precision(rtp, rfp)
    log['Avg_rec_ru'] = metrics.recall(rtp, rfn)
    log['Avg_f1_ru'] = metrics.f1_score(log['Avg_prec_ru'], log['Avg_rec_ru'])
            
    log['Avg_rt_ru'] = float(np.average(rt_ru, weights=benchmarks_weights[:3]))
    log['Avg_sw_ru'] = float(np.average(sw_ru, weights=benchmarks_weights[:3]))
    log['Avg_sh_ru'] = float(np.average(sh_ru, weights=benchmarks_weights[:3]))
    log['Avg_wl_ru'] = float(np.average(wl_ru, weights=benchmarks_weights[:3]))
            
    log['Avg_sw'] = metrics.avg_class_score(log, ('sh','wl'), weights=benchmarks_weights)
    log['Avg_rt'] = metrics.avg_class_score(log, ('rt',), weights=benchmarks_weights)

    # if log['Avg_f1_ru'] > max_f1 or log['Avg_rt_ru'] > max_rt:
    #    torch.save(model.module.state_dict(),
    #               f"unet_resnet34/Unet_{np.round(log['Avg_f1_ru'], 3)}_{np.round(log['Avg_rt_ru'], 3)}_{np.round(log['Avg_sw_ru'], 3)}.pth")

    wandb.log(log)
wandb.finish()

TypeError: cannot pickle '_io.FileIO' object

In [32]:
torch.save(model.module.state_dict(), f"unet_resnet34/Unet.pth")

In [36]:
type(train_loader)

torch.utils.data.dataloader.DataLoader

In [38]:
len(train_loader)

2000

In [19]:
from torchvision import transforms as T

model = smp.Unet(encoder_name='mit_b3', encoder_weights = None, classes=5, activation=None)
model.load_state_dict(torch.load('models/Unet_0.7052_0.8498_0.7048.pth'))

class ModelTrace(torch.nn.Module):
    def __init__(self, model, transform, inv_transform, device):
        super().__init__()
        self.model = model
        self.model.eval()
        self.model.to(device)
        self.transform = transform
        self.inv_transform = inv_transform

    @torch.no_grad()
    def forward(self, x):
        x = self.transform(x.to(torch.float))
        x = self.model(x)
        x = self.inv_transform(x)
        return x

normparams = {'mean': (106.9, 108.9, 98.4), 'std': (54.6, 51.0, 50.5)}
tr = T.Normalize(**normparams)
inv_tr = lambda x: (torch.max(x, dim=1, keepdim=False)[1]).to(torch.uint8)
trace = ModelTrace(model, tr, inv_tr, 'cpu')
model_tr = torch.jit.trace(trace, torch.rand(1, 3, 1024, 1024))
model_tr(torch.rand(1, 3, 1024, 1024)).shape

  if h % output_stride != 0 or w % output_stride != 0:


torch.Size([1, 1024, 1024])

In [20]:
model_tr.save('models/model.pt')

In [12]:
next(iter(train_loader))

TypeError: cannot pickle '_io.FileIO' object

In [20]:
train_data[0,[0,0]]

{'mask': tensor([[[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]]]),
 'image': tensor([[[[-5.8451e-01,  1.0764e-03,  3.6707e-01,  ...,  3.7675e-02,
            -2.0022e-01,  2.7557e-01],
           [-1.6276e+00, -1.5178e+00, -1.2433e+00,  ...,  2.3897e-01,
            -6.2111e-01, -8.7730e-01],
           [-1.7923e+00, -1.5727e+00, -1.5178e+00,  ...,  3.1217e-01,
            -5.8451e-01, -9.1390e-01],
           ...,
           [-5.3822e-02, -3.8321e-01, -1.7191e+00,  ..., -4.0151e-01,
            -1.0054e+00,  9.2574e-02],
           [ 2.2153e+00,  1.7395e+00,  1.8127e+00,  ...,  1.4747e-01,
             5.5006e-01,  1.6846e+00],
           [ 2.1055e+00, -1.9021e+00,  1.6114e+00,  ...,  5.5006e-01,
             1.1722e+00,  1.9774e+00]],
 
          [[-6.0574e-01,  6.0253e-02,  3.1490e-01,  ...,  9.9429e-02