In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# !pip install imgaug

In [3]:
from argparse import ArgumentParser
from data import TGSTransform, TGSSaltDataset, collate_mask_fn
import matplotlib.pyplot as plt
from train import LitUNet
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn
import torch.nn.functional as F
from utils import calculate_mAP
import pandas as pd
import os

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# Prepare Data

In [4]:
# Notebook-wide configuration.
# img_size_ori: presumably the native edge length of the raw images — TODO confirm.
# root_ds: directory handed to TGSSaltDataset as the dataset root.
img_size_ori = 101
root_ds = 'dataset'

In [5]:
# The train/validation split is read from two CSV files in the working directory.
# (An earlier, now-disabled variant read column 0 of <root_ds>/train.csv and drew
# a random 80/20 split with torch.utils.data.random_split.)
train_df, val_df = (pd.read_csv(csv_name) for csv_name in ('train.csv', 'val.csv'))

# Augmentation is switched on for training only; the depth input is unused in both.
train_ds = TGSSaltDataset(
    root_ds,
    train_df,
    transforms=TGSTransform(augment=True, use_depth=False),
)
val_ds = TGSSaltDataset(
    root_ds,
    val_df,
    transforms=TGSTransform(augment=False, use_depth=False),
)

# Sanity check: number of samples in each split.
print(len(train_ds), len(val_ds))

3200 800


In [6]:
# Fetch the first training sample and print the image and mask shapes.
img, mask = train_ds[0]
print(img.shape, mask.shape)

torch.Size([1, 128, 128]) torch.Size([1, 128, 128])


In [7]:
# Both loaders share batch size, worker count and the custom collate function;
# only the training loader shuffles its samples.
loader_kwargs = dict(batch_size=16, num_workers=4, collate_fn=collate_mask_fn)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, **loader_kwargs)

# Training

In [8]:
# TensorBoard logger writing under logs/unet; the run version is left unset
# (pass version='19' to pin it, as an earlier revision of this cell did).
tt_logger = TensorBoardLogger(save_dir='logs', name='unet')

# Checkpoints go into a 'ckpt' folder inside this run's log directory.
checkpoint_dir = os.path.join(tt_logger.log_dir, 'ckpt')

# Keep only the single best checkpoint, ranked by the highest 'metrics_mAP';
# no extra "last" checkpoint is written.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor='metrics_mAP',
    mode='max',
    save_top_k=1,
    save_last=False,
    verbose=False,
)

In [9]:
def parse_args(args=None):
    """Parse trainer, model and program options into a single namespace.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv``.

    Returns:
        argparse.Namespace holding the Trainer flags, the flags added by
        ``LitUNet.add_model_specific_args`` and ``--seed`` (default 42).
    """
    # Each helper returns the parser it extended, so the calls can be nested.
    cli = LitUNet.add_model_specific_args(
        pl.Trainer.add_argparse_args(ArgumentParser())
    )
    cli.add_argument('--seed', type=int, default=42)
    namespace = cli.parse_args(args)
    return namespace

def main(args):
    """Seed the RNGs, then build the model and its Trainer.

    Relies on the notebook-level ``tt_logger`` and ``checkpoint_callback``
    objects created in the logging/checkpoint cell.

    Args:
        args: Namespace produced by ``parse_args``. All entries are forwarded
            to ``LitUNet`` as keyword arguments and to
            ``pl.Trainer.from_argparse_args``.

    Returns:
        tuple: ``(model, trainer)``. Training is not started here.
    """
    pl.seed_everything(args.seed)
    model = LitUNet(**vars(args))
    # Attach the configured ModelCheckpoint so the best 'metrics_mAP' weights
    # are actually written to disk. Passing checkpoint_callback=False (as this
    # function used to) turns checkpointing off and leaves the callback unused.
    # NOTE(review): assumes LitUNet logs 'metrics_mAP' during validation — confirm.
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=tt_logger,
        callbacks=[checkpoint_callback],
    )
    return model, trainer

In [10]:
# add PROGRAM level args
program_args = """
      --seed 42
      """.split()
model_args = """
    --name_model resunet
    --num_down_stage 4
    --num_filter1 16
    --bilinear n
    --lr 3e-4
    --momentum 0.9
    --weight_decay 5e-4
    """.split()
 
# add all the available trainer options to argparse
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
#     --resume_from_checkpoint original_sgd_logs/bs_32/last.ckpt
trainer_args = """
    --max_epoch 50
    --gpus 1
    --progress_bar_refresh_rate 20
    --num_sanity_val_steps 0
""".split()
args = parse_args(program_args + model_args + trainer_args)

In [11]:
# Seed everything and build the model and Trainer from the parsed arguments;
# training itself is started in a later cell.
model, trainer = main(args)

Global seed set to 42
GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [12]:
# Display the underlying network (model.model) to inspect its layer structure.
model.model

Res34Unet(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (encoder): ResEncoder(
    (conv1): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (encode2): Sequential(
      (0): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [13]:
# Initialise the encoder from pretrained weights stored in pretrain34.pth.
# map_location='cpu' makes loading independent of the device the file was saved
# from; load_state_dict then copies the values into the encoder's own parameters.
# NOTE(review): torch.load unpickles the file — only load weights from a trusted source.
model.model.encoder.load_state_dict(torch.load('pretrain34.pth', map_location='cpu'))

<All keys matched successfully>

In [14]:
# Turn on the cuDNN benchmark flag before training starts.
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell.
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

In [None]:
# NOTE: this overrides the epoch limit passed in trainer_args (50); the run
# below is capped at 100 epochs instead.
trainer.max_epochs = 100
# Train on train_dl and validate on val_dl.
trainer.fit(model, train_dl, val_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type        | Params
------------------------------------------
0 | model     | Res34Unet   | 35.2 M
1 | criterion | DiceBCELoss | 0     
------------------------------------------
35.2 M    Trainable params
0         Non-trainable params
35.2 M    Total params
140.811   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [116]:
# Scratch cell: display the tensor bound to `input`.
# NOTE(review): `input` is not assigned in any cell visible in this notebook and
# shadows the Python builtin; on a fresh kernel this name would not be a tensor.
input

tensor([[[[2, 9, 7, 1, 0, 8, 6, 1, 6, 6, 3, 3, 4, 5, 7, 2, 3],
          [7, 4, 8, 1, 6, 7, 0, 1, 3, 5, 7, 7, 5, 3, 0, 2, 0],
          [2, 4, 4, 9, 7, 6, 1, 1, 5, 2, 3, 4, 6, 7, 7, 2, 7],
          [5, 2, 0, 1, 8, 1, 1, 7, 5, 2, 9, 4, 1, 2, 5, 7, 1],
          [3, 9, 7, 4, 7, 7, 8, 9, 6, 9, 6, 8, 2, 2, 3, 4, 5],
          [6, 5, 6, 3, 0, 8, 0, 0, 4, 4, 6, 9, 8, 3, 3, 7, 5],
          [1, 4, 3, 6, 2, 5, 5, 3, 5, 7, 6, 0, 6, 6, 6, 3, 9],
          [5, 6, 1, 5, 5, 6, 4, 4, 1, 1, 9, 1, 8, 4, 1, 3, 7],
          [1, 7, 4, 1, 3, 8, 7, 4, 9, 6, 8, 0, 2, 6, 7, 3, 7],
          [6, 3, 6, 4, 4, 2, 0, 0, 4, 5, 1, 4, 1, 3, 7, 9, 8],
          [3, 4, 2, 5, 9, 0, 1, 7, 2, 8, 0, 2, 8, 0, 7, 0, 8],
          [5, 1, 5, 0, 3, 2, 7, 4, 4, 4, 5, 5, 3, 4, 3, 1, 9],
          [3, 4, 1, 3, 4, 9, 2, 9, 1, 7, 5, 2, 7, 5, 2, 5, 4],
          [1, 3, 9, 3, 6, 3, 6, 4, 7, 9, 1, 6, 6, 8, 1, 1, 1],
          [8, 8, 9, 4, 1, 8, 3, 5, 8, 0, 2, 6, 2, 3, 8, 6, 8],
          [0, 3, 8, 7, 5, 3, 5, 5, 1, 3, 6, 1, 4, 2, 3,

In [118]:
# Sliding-window extraction: unfold dim 2 (window kh, step dh), then dim 3
# (window kw, step dw), reshape so each window is a flat vector of 9 values on
# the last axis, and keep element 4 of every window.
# NOTE(review): kh, dh, kw, dw are not defined in the visible cells — presumably
# 3x3 windows with stride 2 (index 4 would then be the window centre); confirm.
input.unfold(2, kh, dh).unfold(3, kw, dw).contiguous().view(1, 2, 8, 8, 9)[:, :, :, :, 4]

tensor([[[[4, 1, 7, 1, 5, 7, 3, 2],
          [2, 1, 1, 7, 2, 4, 2, 7],
          [5, 3, 8, 0, 4, 9, 3, 7],
          [6, 5, 6, 4, 1, 1, 4, 3],
          [3, 4, 2, 0, 5, 4, 3, 9],
          [1, 0, 2, 4, 4, 5, 4, 1],
          [3, 3, 3, 4, 9, 6, 8, 1],
          [3, 7, 3, 5, 3, 1, 2, 0]],

         [[3, 4, 0, 4, 1, 4, 8, 3],
          [6, 6, 1, 9, 1, 5, 6, 5],
          [5, 7, 3, 0, 1, 0, 7, 3],
          [1, 8, 3, 5, 0, 0, 5, 9],
          [7, 7, 8, 6, 5, 8, 1, 4],
          [9, 4, 0, 5, 1, 6, 8, 0],
          [9, 2, 4, 3, 9, 1, 7, 6],
          [8, 7, 3, 0, 0, 0, 3, 4]]]])

In [85]:
# Variant on a single slice: unfold dim 2, take index [0, 0], unfold dim 1 of
# that slice, reshape to (2, -1, 8, 4) and keep index 0 of the last axis.
# NOTE(review): this cell ran at In [85], earlier than the `input` display at
# In [116], so `input` and kh/dh/kw/dw may have held different values — verify
# before relying on the output shown.
input.unfold(2, kh, dh)[0, 0].unfold(1, kw, dw).contiguous().view(2, -1, 8 ,4)[:, :, :, 0]

tensor([[[6, 0, 1, 9, 5, 2, 3, 4],
         [0, 6, 6, 4, 3, 7, 0, 2],
         [5, 2, 8, 1, 1, 1, 2, 0],
         [3, 5, 4, 6, 3, 0, 3, 8]],

        [[1, 1, 1, 0, 3, 8, 4, 7],
         [2, 8, 8, 5, 3, 2, 8, 6],
         [0, 7, 5, 4, 9, 4, 7, 9],
         [0, 9, 7, 4, 7, 5, 0, 5]]])

In [102]:
# View the tensor's storage as a 3x3 matrix with strides (1, 1): entry [i, j]
# reads the element at offset i + j, so each row is the previous row shifted by
# one element (overlapping windows, as the output shows).
torch.as_strided(input, (3, 3), (1, 1,))

tensor([[0, 9, 7],
        [9, 7, 4],
        [7, 4, 5]])

In [134]:
from modules.pool import AttentionPooling

In [135]:
# Instantiate the pooling module with a single positional argument of 2.
# NOTE(review): the meaning of this argument is not visible here — check modules/pool.py.
att = AttentionPooling(2)

In [137]:
# Run the pooling module on the scratch tensor and keep the result for inspection.
w = att(input)

In [142]:
# w is indexed as a 5-D tensor here: select index 4 on the last axis and report
# the shape of the remaining four axes.
w[:, :, :, :, 4].size()

torch.Size([1, 2, 9, 9])

In [144]:
# Select index 0 on axes 2 and 3 and report the shape of what is left.
w[:, :, 0, 0].size()

torch.Size([1, 2, 9])