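# Imports

Project helpers come in via star imports from `lightai` and the local `functional`, `model`, `transform`, `dataset` and `metric` modules.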

In [1]:
# Reload edited project modules (model, transform, dataset, ...) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [2]:
from lightai.learner import *
from lightai.dataloader import *
import torchvision.models as models
from functional import *
from model import *
from transform import *
from dataset import *
from metric import *

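# Data

Full (`inputs/train`, `inputs/val`) and sample (`inputs/sample/...`) datasets and loaders. Training sets use the composed `tsfm`; validation sets use only `to_unit_float`.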

In [None]:
# Presumably per-image depth metadata (judging by the file name) — confirm.
# NOTE(review): `depth_csv` is not used anywhere else in this notebook; remove it or feed it to the model.
depth_csv = pd.read_csv('inputs/depths.csv')

In [14]:
# Training-time transform: Hflip, then CropRandom, then to_unit_float (validation datasets use only to_unit_float).
# NOTE(review): the meaning of the 0.5 arguments is defined in the project `transform` module —
# presumably probabilities / crop fractions; confirm.
tsfm = compose([Hflip(0.5),CropRandom(0.5,0.5),to_unit_float])

In [5]:
# Batch size shared by every train/val BatchSampler below (the test loader uses its own size).
bs = 256

In [None]:
# Full-size datasets: the training set uses `tsfm`, the validation set only `to_unit_float`.
trn_ds = ImageDataset('inputs/train', tsfm=tsfm)
val_ds = ImageDataset('inputs/val', tsfm=to_unit_float)
# Samplers, then loaders, built in train -> val order.
trn_sampler, val_sampler = [BatchSampler(ds, bs) for ds in (trn_ds, val_ds)]
trn_dl, val_dl = [DataLoader(smp, batch_tsfm=batch_tsfm) for smp in (trn_sampler, val_sampler)]

In [15]:
# Small sample split for fast iteration; same transforms and loader settings as the full data.
sp_trn_ds = ImageDataset('inputs/sample/train', tsfm=tsfm)
sp_val_ds = ImageDataset('inputs/sample/val', tsfm=to_unit_float)
sp_trn_sampler, sp_val_sampler = [BatchSampler(ds, bs) for ds in (sp_trn_ds, sp_val_ds)]
sp_trn_dl, sp_val_dl = [DataLoader(smp, batch_tsfm=batch_tsfm) for smp in (sp_trn_sampler, sp_val_sampler)]

In [None]:
# The sample training images with no transform, for visual comparison with `sp_trn_ds`.
no_tsfm_ds = ImageDataset('inputs/sample/train',tsfm=None)

In [None]:
# First sample training item as loaded, without any transform.
# The trailing ';' suppresses the AxesImage/Text repr so only the figure is shown.
plt.imshow(no_tsfm_ds[0][0])
plt.title('sample train[0]: no transform');

In [None]:
# The same item after the training transform `tsfm`.
# The trailing ';' suppresses the AxesImage/Text repr so only the figure is shown.
plt.imshow(sp_trn_ds[0][0])
plt.title('sample train[0]: with tsfm');

In [None]:
# Shape of element 0 of the first transformed sample item (presumably the image array — confirm).
sp_trn_ds[0][0].shape

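# Train

Train on the sample split in two stages: first after `freeze_to(-1)`, then after `unfreeze()` with a two-element learning-rate array.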

In [32]:
# torchvision ResNet-18 with pretrained weights (positional `True`), trimmed by model_cut(..., -2) and wrapped
# in the project's Dynamic module, which also receives the dataset (presumably to size its output — confirm).
resnet18 = models.resnet18(True)
model = Dynamic(model_cut(resnet18,-2), sp_trn_ds).cuda()
# Learner on the sample loaders with BCE-with-logits loss; `Score` is treated as higher-is-better (small_better=False).
learner = Learner(sp_trn_dl, sp_val_dl, model, nn.BCEWithLogitsLoss(), metric=Score, small_better=False)

In [33]:
# Freeze everything except the last layer group before the first fit (per lightai's freeze_to semantics — confirm).
learner.freeze_to(-1)

In [34]:
# Weight decay and base learning rate reused by the lr_find / fit calls below.
wd = 1e-6
lr = 0.3

In [None]:
# Exploratory learning-rate range test for the frozen model (not executed in the saved session).
learner.lr_find(n_epochs=5,wds=wd)

In [35]:
# Train the frozen model at the single base learning rate.
# NOTE(review): clr_params=[20,2,0.01,0.1] are presumably cyclical-LR schedule settings defined by lightai —
# confirm and give them names in the config cell.
# NOTE(review): the saved output below shows an Epoch bar with max=16 and 8 logged epochs, which does not
# match n_epochs=32 — the output is stale; re-run before relying on it.
learner.fit(n_epochs=32, lrs=lr,wds=wd,clr_params=[20,2,0.01,0.1])

HBox(children=(IntProgress(value=0, description='Epoch', max=16), HTML(value='')))

   epoch    trn_loss   val_loss     score                                                                                    
     1      0.651240   0.628934   0.375000  
     2      0.621132   0.574791   0.375000                                                                                   
     3      0.598241   0.628720   0.375000                                                                                   
     4      0.596789   0.607573   0.375000                                                                                   
     5      0.590795   0.574312   0.375000                                                                                   
     6      0.586429   0.596422   0.375000                                                                                   
     7      0.582896   0.724770   0.000000                                                                                   
     8      0.574214   0.583840   0.293750                               

In [36]:
# Make all layer groups trainable for the fine-tuning stage.
learner.unfreeze()

In [None]:
# Exploratory LR range test after unfreezing (not executed in the saved session);
# start_lrs presumably gives one starting rate per layer group — confirm.
learner.lr_find(n_epochs=5,start_lrs=[1e-7,1e-6],wds=wd)

In [37]:
# Two learning rates, lr/10 and lr — presumably for earlier vs later layer groups; confirm group order in lightai.
lrs = np.array([lr/10,lr])

In [38]:
# Fine-tune with both rates scaled down 10x (i.e. [0.003, 0.03] for lr=0.3), same wd and clr_params as before.
# NOTE(review): the saved output logs only 8 of 32 epochs — it looks truncated or interrupted.
learner.fit(n_epochs=32, lrs=lrs/10, wds=wd, clr_params=[20,2,0.01,0.1])

HBox(children=(IntProgress(value=0, description='Epoch', max=32), HTML(value='')))

   epoch    trn_loss   val_loss     score                                                                                    
     1      0.432122   0.490752   0.375000  
     2      0.428570   0.586099   0.375000                                                                                   
     3      0.417920   0.504456   0.468750                                                                                   
     4      0.408395   0.471157   0.481250                                                                                   
     5      0.403731   0.446566   0.484375                                                                                   
     6      0.404022   0.473287   0.479375                                                                                   
     7      0.400638   0.515204   0.463750                                                                                   
     8      0.404920   0.481827   0.467500                               

In [None]:
# Plot the loss history recorded by the learner.
learner.recorder.plot_loss()

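# Submit

Restore a saved checkpoint, pick the probability threshold with the best validation score, and write run-length-encoded test masks to `submit.csv`.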

In [None]:
# Rebuild the architecture and restore trained weights from a checkpoint.
# No pretrained download: load_state_dict is strict by default, so every state_dict entry is replaced
# from the checkpoint and an ImageNet initialisation would be discarded anyway.
# NOTE(review): `Dynamic` is built from the full `trn_ds`, which only exists if the (un-run) full-data
# loader cell above has been executed — run it first on a fresh kernel.
CKPT_PATH = 'model/0.7885647177696228'  # file name appears to encode the checkpoint's validation score
resnet18 = models.resnet18()
model = Dynamic(model_cut(resnet18, -2), trn_ds).cuda()
model.load_state_dict(torch.load(CKPT_PATH))

In [None]:
%%time
# Evaluate the restored model on the full validation loader; presumably one score per candidate threshold — confirm.
thresholds,scores = thres_score(model, val_dl)

In [None]:
# Best validation score over all candidate thresholds.
max(scores)

In [None]:
# Pick the threshold that was actually evaluated at the best score. The previous `argmax / 100`
# silently assumed `thresholds` is the grid 0.00, 0.01, 0.02, ... and ignored the returned values.
best_idx = int(np.argmax(scores))
best_thres = float(thresholds[best_idx])

In [None]:
# Probability threshold that will be used for the submission masks.
best_thres

In [None]:
class TestDataset():
    """Unlabelled test images, yielding ``(image, id)`` pairs for the submission.

    Args:
        tsfm: optional callable applied to each loaded PIL image; if falsy the raw image is returned.
        img_dir: directory holding the test images (defaults to the original hard-coded path).
    """
    def __init__(self, tsfm, img_dir='inputs/test/images'):
        img_path = Path(img_dir)
        # iterdir() order is filesystem-dependent; sorting makes index -> file stable across runs.
        self.img = sorted(img_path.iterdir())
        self.tsfm = tsfm

    def __getitem__(self, idx):
        path = self.img[idx]
        # Copy the pixel data and close the file; a bare Image.open keeps the handle open,
        # which leaks descriptors when iterating thousands of test images.
        with Image.open(path) as im:
            img = im.copy()
        if self.tsfm:
            img = self.tsfm(img)
        # stem strips only the final extension, so ids containing dots are not truncated.
        name = path.stem
        return img, name

    def __len__(self):
        return len(self.img)

In [None]:
def get_test_data(bs=512):
    """Build the test-set pipeline used for the submission.

    Args:
        bs: batch size of the test BatchSampler (default 512, as before).

    Returns:
        ``(test_ds, test_sampler, test_dl)``.
    """
    def batch_tsfm(x):
        # Assumes `x` is a collated (images, names) pair with images shaped (B, H, W, C) — TODO confirm
        # against lightai's DataLoader collation.
        img, name = x
        img = np.transpose(img, axes=[0, 3, 1, 2])  # -> (B, C, H, W)
        # Subtract the per-channel mean of THIS batch, so predictions depend on batch composition.
        # NOTE(review): confirm this matches the normalisation used by the training `batch_tsfm`.
        # Out-of-place so the collated batch array is not mutated through the transpose view.
        img = img - img.mean(axis=(0, 2, 3), keepdims=True)
        return img, name

    def tsfm(img):
        # PIL image -> float32 array scaled to [0, 1].
        return np.asarray(img).astype(np.float32) / 255

    test_ds = TestDataset(tsfm=tsfm)
    test_sampler = BatchSampler(test_ds, bs)
    # Pass batch_tsfm by keyword, like every other DataLoader in this notebook.
    test_dl = DataLoader(test_sampler, batch_tsfm=batch_tsfm)
    return test_ds, test_sampler, test_dl

In [None]:
%%time
submit = pd.read_csv('inputs/sample_submission.csv')
test_ds,test_sampler,test_dl = get_test_data()
thres = best_thres
with torch.no_grad():
    model.eval()
    for img, name in test_dl:
        mask = torch.sigmoid(model(T(img)))
        mask = mask > thres
        mask = mask.cpu().numpy()
        for n,m in zip(name,mask):
            m = rl_enc(m)
            submit.loc[submit['id']==n,'rle_mask'] = m

In [None]:
# Write the submission file without the DataFrame index column.
submit.to_csv('submit.csv',index=False)