# import

In [None]:
# Reload the IPython autoreload extension and switch it to mode 2, so that
# edited modules are re-imported automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
from lightai.train import *
from salt.resnet import *
from salt.unet import *
from salt.transform import *
from salt.dataset import *
from salt.metric import *
from salt.file_op import *
from salt.crit import *
from salt.visualize import *
from salt.predict import *
from salt.callback import *
from salt.evaluate import *

# prepare

In [None]:
# Individual augmentation building blocks. The classes come from the project
# star-imports; the meaning of their positional arguments is not visible here.
distort = Distort(5, 5, 5)
intensity = MyColorJitter(brightness=0.1)
crop = CropRandom(0.7)

# Keyword arguments shared by every affine transform below.
param = dict(degrees=0, resample=Image.BICUBIC)
zoom = MyRandomAffine(scale=[0.5, 2], **param)
# zoom_in / zoom_out are defined but not referenced by the pipeline built in this cell.
zoom_in = MyRandomAffine(scale=[1, 2], **param)
zoom_out = MyRandomAffine(scale=[0.5, 1], **param)
shift = MyRandomAffine(translate=[0.5, 0.5], **param)
shear = MyRandomAffine(shear=45, **param)

# Choice among the heavier augmentations; `ps` presumably holds the
# per-option selection probabilities -- confirm in MyRandomChoice.
shape_choice = MyRandomChoice(
    [distort, zoom, crop, shear],
    ps=[0.45, 0.225, 0.225, 0.1],
)

# Training-time transform: flip, shift, intensity jitter and the choice above,
# each paired with an entry of `ps` (presumably its apply probability -- confirm
# in MyRandomApply).
trn_tsfm = MyRandomApply(
    [sample_hflip, shift, intensity, shape_choice],
    ps=[0.5, 0.5, 0.5, 0.5],
)

In [None]:
# Training hyperparameters read by the training cell:
#   bs -- batch size handed to the BatchSamplers
#   lr -- learning rate for optim.SGD
#   wd -- weight_decay for optim.SGD
bs, lr, wd = 16, 0.01, 5e-6
# Both values are forwarded to the Dynamic model constructor.
drop = linear_drop = 0

# train

In [None]:
# Fold selection: the CSV splits are read from inputs/<file_loc>/<k>/.
file_loc, k = 'sample', 0
print(f'fold {k}')

# Datasets: the training set gets the random augmentation pipeline followed by
# to_np; the validation set only gets to_np plus the TTA list [None, hflip].
trn_ds = CsvDataset(
    f'inputs/{file_loc}/{k}/trn.csv',
    tsfm=MyCompose(trn_tsfm, to_np),
)
val_ds = CsvDataset(
    f'inputs/{file_loc}/{k}/val.csv',
    tsfm=MyCompose(to_np),
    tta_tsfms=[None, hflip],
)

# Loaders: random order with the last partial batch dropped for training,
# sequential order with the last partial batch kept for validation.
trn_sampler = BatchSampler(RandomSampler(trn_ds), bs, drop_last=True)
trn_dl = DataLoader(trn_ds, trn_sampler)
val_sampler = BatchSampler(SequentialSampler(val_ds), bs, drop_last=False)
val_dl = DataLoader(val_ds, val_sampler)

# Logging is disabled. To enable it, create a writer instead, e.g.
#     writer = SummaryWriter('runs/step_lr')
writer = None

# Model, optimizer and loss. The optimizer must be created after the model
# because it is built from model.parameters().
model = Dynamic(resnet34, trn_ds, drop, linear_drop, writer=writer).cuda()
optimizer = optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=wd,
)
loss_fn = Crit([1, 0.5, 0.05])

# Every metric receives the same reverse-TTA list; the list object `metrics`
# is shared between the evaluator and the learner.
reverse_ttas = [None, hflip]
metrics = [
    metric_cls(reverse_ttas)
    for metric_cls in (NoSaltScore, HasSaltScore, EmptyTP, EmptyFP, Score)
]
evaluator = Evaluator(
    val_dl=val_dl,
    metrics=metrics,
    model=model,
    loss_fn=loss_fn,
)

# Callback that checkpoints under the name 'best'. small_better=False is
# passed, so presumably a larger score counts as better -- confirm in
# SaveBestModel. A periodic saver was also tried:
#     sv_period = SavePeriodically(period=5)
sv_best = SaveBestModel(
    model=model,
    optimizer=optimizer,
    small_better=False,
    name='best',
)
learner = Learner(
    model=model,
    trn_dl=trn_dl,
    optimizer=optimizer,
    evaluator=evaluator,
    loss_fn=loss_fn,
    callbacks=[sv_best],
    metrics=metrics,
)

In [None]:
epochs = 50

# Active schedule: StepLR with step_size=20 and gamma=0.5, wrapped for the
# learner. How often LRSchedWrapper steps the scheduler is not visible here
# (presumably once per epoch -- confirm).
step_lr = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
sched = LRSchedWrapper(step_lr)

# Alternative schedules tried earlier, kept for reference:
# sched = ReduceOnPlateau(optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=5, verbose=True))
# sched = LRSchedWrapper(optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.5))
# phase1 = np.linspace(1/20, 1, num=10, endpoint=False)
# phase2 = np.linspace(1, 1/20, num=40)
# phase = np.concatenate([phase1, phase2])
# sched = LRSchedWrapper(optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: phase[epoch]))

# Train for `epochs` epochs with the schedule selected above.
learner.fit(epochs, sched=sched)

In [None]:
# Close the logging writer if one was created. The training setup cell sets
# `writer` to None, so this is a no-op unless that assignment is changed.
if writer:
    writer.close()

In [None]:
# Rebuild the network and restore weights from the 'model/best' checkpoint
# (presumably the file written by the SaveBestModel callback with name='best'
# -- confirm the save location in the callback). The checkpoint is indexed
# with the 'model' key to obtain the state dict.
models = []
for i in range(1):
    model = Dynamic(resnet34, trn_ds, drop, linear_drop, writer=writer).cuda()
    checkpoint = torch.load('model/best')
    model.load_state_dict(checkpoint['model'])
    models.append(model)

In [None]:
# Score every model in `models` on a validation split and print the mean score.
# NOTE(review): the path below hard-codes fold 2 of 'sample', whereas the
# training cell reads inputs/{file_loc}/{k}/ with k = 0 -- confirm this is the
# intended split for these weights.
ss = []
val_ds = CsvDataset(f'inputs/sample/2/val.csv',tsfm=MyCompose(to_np), tta_tsfms=[None,hflip])
# NOTE(review): the training cell uses BatchSampler(SequentialSampler(ds), bs, drop_last=...)
# and DataLoader(ds, sampler); the call shapes here differ (dataset passed straight
# to BatchSampler, DataLoader given only the sampler plus n_worker) -- confirm
# which form matches the current library API.
val_sampler = BatchSampler(val_ds, bs)
val_dl = DataLoader(val_sampler, n_worker=4)
for model in models:
    # Each model is scored individually, wrapped in a one-element list, with
    # the same [None, hflip] list that the dataset received as tta_tsfms.
    s = val_score([model], val_dl, [None,hflip])
    ss.append(s)
print(np.array(ss).mean())

In [None]:
# Display the individual per-model scores held in `ss`.
ss

# submit

In [None]:
# Build one validation loader per fold (0-4) from inputs/all/<fold>/val.csv.
# Each dataset uses to_np directly as its transform and [None, hflip] as its
# TTA list; batches are formed by BatchSampler with a size of 128.
test_val_dls = []
for i in range(5):
    test_val_ds = CsvDataset(
        f'inputs/all/{i}/val.csv',
        tsfm=to_np,
        tta_tsfms=[None, hflip],
    )
    test_val_sampler = BatchSampler(test_val_ds, 128)
    test_val_dls.append(DataLoader(test_val_sampler, n_worker=4))

In [None]:
# Load one trained network per fold (0-4) from 'model/256ep lovasz loss/'.
# This cell relies on `trn_ds` created in the training setup cell.
# NOTE(review): an instantiated resnet18(drop=0) is passed to Dynamic here,
# while the training cell passes `resnet34` uncalled -- confirm Dynamic
# accepts both forms.
models = []
for i in range(5):
    backbone = resnet18(drop=0)
    model = Dynamic(backbone, trn_ds, 0, 0).cuda()
    checkpoint = torch.load(f'model/256ep lovasz loss/fold{i}')
    # The checkpoint stores the state dict under the 'model' key.
    model.load_state_dict(checkpoint['model'])
    models.append(model)

In [None]:
%%time
# Score each fold's model on the validation loader built for the same fold
# index (models and loaders are paired positionally by zip), then print the mean.
ss = []
for model, val_dl in zip(models, test_val_dls):
    # NOTE(review): the earlier validation cell calls val_score with a list
    # ([model]); here a bare model is passed -- confirm which form val_score expects.
    s = val_score(model, val_dl, [None, hflip])
    ss.append(s)
print(np.array(ss).mean())

In [None]:
# Display the individual per-fold scores held in `ss`.
ss

In [None]:
%%time
# Run test-set prediction with the list of loaded fold models. What
# predict_test writes or returns is defined in the project code and is not
# visible in this notebook.
predict_test(models)