# Imports and notebook setup

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from lightai.imps import *
from lightai.learner import *
from lightai.dataloader import *
import torchvision.models as models
from functional import *
from model import *
from transform import *
from dataset import *
from metric import *

# Data: datasets and loaders

In [5]:
# Train and validation images go through the same transform (to_np).
trn_tsfm = val_tsfm = to_np

In [15]:
bs: int = 128  # mini-batch size shared by every BatchSampler in this notebook

In [55]:
# Full train/val splits. val_ds additionally gets the img_hflip TTA transform,
# which thres_score later undoes with tta_hflip.
trn_ds = ImageDataset('inputs/gray/train', tsfm=trn_tsfm, train=True)
val_ds = ImageDataset('inputs/gray/val', tsfm=val_tsfm, train=False,
                      tta_tsfms=[img_hflip])
trn_sampler, val_sampler = BatchSampler(trn_ds, bs), BatchSampler(val_ds, bs)
trn_dl, val_dl = DataLoader(trn_sampler), DataLoader(val_sampler)

In [6]:
# Small sample subset for fast iteration. The `train` flags mirror the
# full-data cell; without them both sample datasets silently use whatever
# ImageDataset's default is, so train and val could end up in the same mode.
sp_trn_ds = ImageDataset('inputs/gray/sample/train', tsfm=trn_tsfm, train=True)
sp_val_ds = ImageDataset('inputs/gray/sample/val', tsfm=val_tsfm, train=False)
sp_trn_sampler = BatchSampler(sp_trn_ds, bs)
sp_val_sampler = BatchSampler(sp_val_ds, bs)
sp_trn_dl = DataLoader(sp_trn_sampler)
sp_val_dl = DataLoader(sp_val_sampler)

# Training (on the sample subset)

In [7]:
# ResNet-18 backbone (no pretrained weights requested), cut at -2 and wrapped
# by Dynamic, trained on the sample loaders.
resnet = resnet18(False)
model = Dynamic(model_cut(resnet, -2), sp_trn_ds).cuda()
layer_opt = LayerOptimizer(model)
learner = Learner(
    sp_trn_dl,
    sp_val_dl,
    model,
    nn.BCEWithLogitsLoss(),  # takes raw logits; the sigmoid is applied inside the loss
    layer_opt,
    metric=Score,
    small_better=False,      # higher Score is better
)

In [8]:
wd = 1e-6  # weight decay, passed as `wds` to learner.fit
lr = 0.3   # base learning rate
# Two learning rates, the first 10x smaller; presumably one per layer group
# of LayerOptimizer — TODO confirm.
lrs = np.array((lr / 10, lr))

In [11]:
# Exploratory run with a single scalar LR; the recorded run was interrupted
# manually (KeyboardInterrupt) during the second epoch.
learner.fit(
    n_epochs=32,
    lrs=lr,
    wds=wd,
    clr_params=[50, 5, 0.01, 0],  # cyclic-LR settings; meaning defined by lightai's Learner.fit
)

HBox(children=(IntProgress(value=0, description='Epoch', max=32), HTML(value='')))

   epoch    trn_loss   val_loss     score                                                                                    
     1      0.291405   0.314589   0.491875  
 35%|##########################9                                                  | 7/20 [00:02<00:04,  3.03it/s, loss=0.275]


KeyboardInterrupt: 

In [9]:
# Continue training with the two-element `lrs` scaled down 10x.
# In the recorded output val_loss rises after epoch 1 while trn_loss keeps
# falling — likely overfitting the sample subset.
learner.fit(
    n_epochs=16,
    lrs=lrs / 10,
    wds=wd,
    clr_params=[10, 2, 0.01, 0.1],  # cyclic-LR settings; meaning defined by Learner.fit
)

HBox(children=(IntProgress(value=0, description='Epoch', max=16), HTML(value='')))

   epoch    trn_loss   val_loss     score                                                                                    
     1      0.117367   0.146025   0.722250  
     2      0.115460   0.147501   0.726375                                                                                   
     3      0.113761   0.150519   0.723625                                                                                   
     4      0.111485   0.153450   0.717375                                                                                   
     5      0.109699   0.152763   0.720500                                                                                   
     6      0.105671   0.153596   0.728750                                                                                   
     7      0.102321   0.183058   0.721250                                                                                   
     8      0.100521   0.181586   0.696125                               

In [None]:
# Unfreeze the model before the next fit (lightai Learner.unfreeze;
# presumably makes all layer groups trainable — confirm).
learner.unfreeze()

In [None]:
# Train after unfreeze() with the same learning rates as the previous fit.
learner.fit(
    n_epochs=8,
    lrs=lrs / 10,
    wds=wd,
    clr_params=[20, 2, 0.01, 0.1],  # cyclic-LR settings; meaning defined by Learner.fit
)

In [None]:
# Plot the loss history kept by learner.recorder.
learner.recorder.plot_loss()

# Submission: threshold selection and test-set inference

In [58]:
# Rebuild the architecture (backbone cut at -2, Dynamic wrapper built from
# trn_ds) and restore the saved weights for inference.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
resnet = resnet18(False)
model = Dynamic(model_cut(resnet, -2), trn_ds).cuda()
model.load_state_dict(torch.load('model/no_tta')['state_dict'])

In [59]:
%%time
thresholds,scores = thres_score(model, val_dl, [None,tta_hflip])

Wall time: 7.67 s


In [60]:
# Best validation score returned by thres_score (displayed as the cell output).
max(scores)

0.81112504

In [61]:
# Pick the threshold with the highest validation score. Look it up in the
# `thresholds` returned by thres_score rather than re-deriving it as index/100,
# which silently depends on thres_score's internal threshold grid.
# NOTE(review): assumes `thresholds` is aligned element-wise with `scores` — confirm.
best_idx = int(np.argmax(scores))
best_thres = float(thresholds[best_idx])

In [62]:
# Show the selected threshold.
best_thres

0.49

In [66]:
class TestDataset():
    """Unlabelled test images for inference, optionally expanded with TTA views.

    Each item is a list of ``[image, name]`` pairs: the original image first,
    then one pair per callable in ``tta_tsfms`` (in order). ``name`` is the
    file name without its final extension and is identical for every pair.
    """

    def __init__(self, tsfm=None, tta_tsfms=None, img_path='inputs/gray/test/images'):
        """
        Args:
            tsfm: optional callable applied to every view (original and TTA),
                after the TTA callables have run.
            tta_tsfms: optional list of callables, each taking the PIL image
                and returning an augmented view.
            img_path: directory containing the test images; the default is the
                path this notebook has always used.
        """
        # Sorted so item order is deterministic; Path.iterdir() order is arbitrary.
        self.img = sorted(Path(img_path).iterdir())
        self.tsfm = tsfm
        self.tta_tsfms = tta_tsfms

    def __getitem__(self, idx):
        path = self.img[idx]
        # copy() forces the pixels into memory, so the file handle is closed
        # here instead of lingering until the lazy image is garbage-collected.
        with Image.open(path) as fh:
            original = fh.copy()
        views = [original]
        if self.tta_tsfms:
            views.extend(t(original) for t in self.tta_tsfms)
        if self.tsfm:
            views = [self.tsfm(view) for view in views]
        # stem strips only the final suffix, so ids containing dots stay intact.
        name = path.stem
        return [[view, name] for view in views]

    def __len__(self):
        return len(self.img)

In [69]:
def get_test_data(batch_size=None):
    """Build the test dataset, sampler and loader used for submission inference.

    Each test item holds the original image plus a horizontally flipped copy
    (TTA). Both are converted to float32 arrays divided by 255, with a leading
    axis added.

    Args:
        batch_size: images per batch. Defaults to the notebook-level ``bs``
            (the previous, implicit behaviour).

    Returns:
        Tuple ``(test_ds, test_sampler, test_dl)``.
    """
    if batch_size is None:
        batch_size = bs

    def tsfm(img):
        # Scale to [0, 1] (assuming 8-bit pixels) and add a leading axis, so a
        # single-channel (H, W) image becomes (1, H, W).
        img = np.asarray(img).astype(np.float32) / 255
        return np.expand_dims(img, 0)

    def img_hflip(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)

    test_ds = TestDataset(tsfm=tsfm, tta_tsfms=[img_hflip])
    test_sampler = BatchSampler(test_ds, batch_size)
    test_dl = DataLoader(test_sampler)
    return test_ds, test_sampler, test_dl

In [70]:
%%time
submit = pd.read_csv('inputs/sample_submission.csv')
test_ds,test_sampler,test_dl = get_test_data()
thres = best_thres
reverse_tta = [None,tta_hflip]
with torch.no_grad():
    model.eval()
    for batch in test_dl:
        predicts = []
        assert len(batch) == len(reverse_tta)
        for [img, name],f in zip(batch,reverse_tta):
            predict = torch.sigmoid(model(T(img)))
            if f:
                predict = f(predict)
            predicts.append(predict)
        predict = torch.stack(predicts).mean(dim=0)
        predict = predict > thres
        predict = predict.cpu().numpy()
        for n,m in zip(name,predict):
            m = rl_enc(m)
            submit.loc[submit['id']==n,'rle_mask'] = m

Wall time: 4min 24s


In [71]:
# Write the submission CSV without the DataFrame's index column.
submit.to_csv('submit.csv',index=False)