In [1]:
import os

# The project packages (data, model, utils) live in ../../code relative to this notebook.
# Remember the notebook's own directory on the first run and always resolve against it,
# so re-running this cell does not walk two further directories up each time.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.normpath(os.path.join(NOTEBOOK_DIR, '..', '..', 'code')))

In [2]:
# Reload edited modules (e.g. data, model, utils from the code directory) before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import os
from argparse import ArgumentParser
import numpy as np
import pandas as pd
import cv2

from data.datamodule import SaltDM
from utils.metrics import cal_mAP, cal_mIoU
from model import get_model
from model.lovasz_losses import lovasz_hinge, lovasz_hinge2
from model.layer import DiceBCELoss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import resize

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from utils import str2bool
import timm

In [4]:
class Lit(pl.LightningModule):
    """Lightning wrapper around a timm ``tf_efficientnet_b4_ns`` with a single output logit.

    The backbone takes one input channel. All constructor ``**kwargs`` are stored via
    ``save_hyperparameters()``. ``optimizer``, ``max_lr``, ``min_lr``, ``momentum``,
    ``weight_decay`` and ``save_pred`` are later read from ``self.hparams``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # in_chans=1 makes timm build a one-channel stem. When loading pretrained weights,
        # timm folds the RGB stem kernels into that channel. Swapping in a fresh conv_stem
        # by hand threw those weights away and also replaced the tf_* port's TF-style
        # 'same' padding with symmetric padding.
        self.salt = timm.create_model('tf_efficientnet_b4_ns', pretrained=True, num_classes=1, in_chans=1)
        self.criterion = F.binary_cross_entropy_with_logits

    def forward(self, x):
        """Return raw backbone logits (one per sample, since num_classes=1)."""
        return self.salt(x)

    def _step_with_loss(self, batch, batch_idx):
        """Compute the BCE-with-logits loss for an ``(inputs, target)`` batch.

        Returns:
            ``(loss, logit)``, where ``logit`` is flattened to shape ``(B,)``.
        """
        inputs, target = batch
        # Use view(-1) rather than squeeze(): squeeze() turns a batch of one into a 0-d
        # tensor, which then no longer matches the target's shape.
        logit = self(inputs).view(-1)
        loss = self.criterion(logit, target.view_as(logit).type_as(logit))
        return loss, logit

    @staticmethod
    def _accuracy(logit, batch):
        """Fraction of samples where ``logit > 0`` agrees with ``target > 0``.

        Both sides are flattened to ``(B,)``. Comparing a (B,1) tensor with a (B,) tensor
        would broadcast to a (B,B) matrix and give a meaningless score.
        """
        target = batch[1].view(-1) > 0
        return ((logit.view(-1) > 0) == target).float().mean()

    def training_step(self, batch, batch_idx):
        loss, logit = self._step_with_loss(batch, batch_idx)
        self.log('Loss/train', loss.item(), on_step=False, on_epoch=True, prog_bar=False, logger=True)
        # Logged under the existing 'Metrics_mAP' key, but the value is plain binary accuracy.
        precision = self._accuracy(logit, batch)
        self.log('Metrics_mAP/train', precision.item(), on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logit = self._step_with_loss(batch, batch_idx)
        precision = self._accuracy(logit, batch)
        self.log('Loss/val', loss.item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('Metrics_mAP/val', precision.item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Flip TTA: average sigmoid outputs for the images and for their last-axis flip.

        Assumes the test batch is the image tensor itself (no labels).
        """
        images = batch
        preds_null = torch.sigmoid(self(images)).squeeze(1).detach()
        preds_flip = torch.sigmoid(self(images.flip(-1))).squeeze(1).detach()
        # Only spatial outputs (B, H, W) must be mirrored back. For a (B,) vector of
        # classifier probabilities the last axis is the batch axis, and flipping it would
        # pair each sample with another sample's prediction.
        if preds_flip.dim() >= 3:
            preds_flip = preds_flip.flip(-1)
        return (preds_flip + preds_null) / 2

    def test_epoch_end(self, outputs):
        """Concatenate per-batch predictions and ``np.save`` them to ``hparams.save_pred``.

        Per-sample predictions of shape (N,) are saved as-is. Per-pixel maps of shape
        (N, H, W) are resized to 101x101 first. Calling cv2.resize on the scalar
        predictions of a classifier would fail.
        """
        preds = torch.cat(outputs, dim=0).cpu().numpy()
        if preds.ndim == 3:
            preds = np.stack([cv2.resize(p, dsize=(101, 101)) for p in preds]).astype(np.float32)
        np.save(self.hparams.save_pred, preds)

    def configure_optimizers(self):
        """Build the optimizer chosen by ``hparams.optimizer`` and an epoch-wise cosine schedule.

        The schedule uses warm restarts and anneals toward ``hparams.min_lr``.

        Raises:
            ValueError: if ``hparams.optimizer`` is neither ``'sgd'`` nor ``'adamw'``.
        """
        hp = self.hparams
        if hp.optimizer == 'sgd':
            optimizer = torch.optim.SGD(self.salt.parameters(),
                                        lr=hp.max_lr,
                                        momentum=hp.momentum,
                                        weight_decay=hp.weight_decay)
        elif hp.optimizer == 'adamw':
            # AdamW has no momentum argument; hp.momentum is ignored here.
            optimizer = torch.optim.AdamW(self.salt.parameters(),
                                          lr=hp.max_lr,
                                          weight_decay=hp.weight_decay)
        else:
            raise ValueError('wrong optimizer option')

        # NOTE(review): the restart period T_0=20 epochs is hard-coded. The `snapshot_size`
        # hparam looks intended for this but is not read anywhere.
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, 20, T_mult=1, eta_min=hp.min_lr),
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [lr_scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with this module's hyperparameters.

        NOTE(review): ``--model`` is accepted but ``__init__`` always builds
        ``tf_efficientnet_b4_ns``.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--model', default='res34v5', type=str, help='Model version')
        parser.add_argument('--optimizer', default='sgd', type=str, help='Optimizer')
        parser.add_argument('--snapshot_size', default=50, type=int, help='Number epochs per snapshot')
        parser.add_argument('--max_lr', default=0.01, type=float, help='max learning rate')
        parser.add_argument('--min_lr', default=0.001, type=float, help='min learning rate')
        parser.add_argument('--momentum', default=0.9, type=float, help='momentum for SGD')
        parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for SGD')
        parser.add_argument('--save_pred', default='../predictions/', type=str, help='prediction save space')
        return parser

In [5]:
def parse_args(args=None):
    """Parse CLI-style options for this experiment.

    The parser is layered: a base ``--seed`` option, then every ``pl.Trainer`` argument,
    then ``Lit``'s model hyperparameters. ``args=None`` makes argparse read ``sys.argv``.
    """
    base_parser = ArgumentParser()
    base_parser.add_argument('--seed', type=int, default=42)
    trainer_parser = pl.Trainer.add_argparse_args(base_parser)
    full_parser = Lit.add_model_specific_args(trainer_parser)
    return full_parser.parse_args(args)

In [6]:
# Options for this run, parsed by the same parser used for command-line training.
# NOTE(review): `--model` is parsed but Lit always builds tf_efficientnet_b4_ns, and
# `--snapshot_size` is not read by the LR scheduler.
cli_args = """
--seed 42
--model effunet_b4
--optimizer adamw
--snapshot_size 50
--max_lr 1e-3
--min_lr 1e-7
--momentum 0.9
--weight_decay 1e-4
--max_epochs 100
--gpus 1
--progress_bar_refresh_rate 20
--num_sanity_val_steps 2
""".split()
args = parse_args(cli_args)

# Seed the Python, NumPy and torch RNGs before the model and data loaders are created.
pl.seed_everything(args.seed)

In [7]:
from data.dataset import TGSSaltDatasetClassify
from data.transforms import TGSTransform
from torch.utils.data import DataLoader

In [8]:
# Build the Trainer from the parsed options (gpus, max_epochs, progress bar, sanity steps).
# A bare pl.Trainer(gpus=1) silently ignored everything except the GPU count.
trainer = pl.Trainer.from_argparse_args(args)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [9]:
# Pass every parsed option to Lit as a keyword argument; Lit.__init__ records them via save_hyperparameters().
model = Lit(**args.__dict__)

In [10]:
# Hold out fold 0 for validation and train on all remaining folds.
# reset_index() turns the 'id' index back into a regular column for the datasets.
df = pd.read_csv(os.path.join('../dataset', 'folds.csv'), index_col='id')
is_val_fold = df['fold'] == 0
tdf = df.loc[~is_val_fold].reset_index()
vdf = df.loc[is_val_fold].reset_index()

In [11]:
# Training split uses TGSTransform(1) and validation uses TGSTransform(0).
# Presumably this toggles augmentation on/off; confirm against data.transforms.
tds = TGSSaltDatasetClassify(
    '../dataset',
    tdf,
    transforms=TGSTransform(1),
)
vds = TGSSaltDatasetClassify(
    '../dataset',
    vdf,
    transforms=TGSTransform(0),
)

In [12]:
# Both loaders share batching/worker settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=4, pin_memory=True)
tdl = DataLoader(tds, shuffle=True, **loader_kwargs)
vdl = DataLoader(vds, shuffle=False, **loader_kwargs)

In [13]:
# Train `model`, using tdl as the training loader and vdl as the validation loader.
trainer.fit(model, tdl, vdl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type         | Params
--------------------------------------
0 | salt | EfficientNet | 17.5 M
--------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
70.198    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



1

In [None]:
# Smoke-test the model on one validation batch (`dl` was never defined; use vdl).
# Run in eval mode without gradients, with the inputs placed on the model's device.
images, _ = next(iter(vdl))
model.eval()
with torch.no_grad():
    sample_logits = model(images.to(model.device))
sample_logits