In [1]:
# Automatically reload edited local modules (e.g. data, train) before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Run from the directory one level above this notebook so the relative paths used below
# (e.g. 'tiles_inline.csv', 'dataset/...', 'logs') resolve.
# Resolve the target only once: re-running a bare os.chdir('..') would keep climbing
# one directory per execution.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('..')
os.chdir(PROJECT_ROOT)

In [3]:
from argparse import ArgumentParser
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from PIL import Image
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split
from torchvision.transforms.functional import to_tensor

In [4]:
from data import NetherlandsF3DS, NetherlandsF3Transform, collate_fn
from train import LitClassification

In [5]:
root_ds = 'dataset/tiles_inlines/tiles_inlines'
SPLIT_SEED = 42  # keep in sync with --seed below so the split is reproducible

df = pd.read_csv('tiles_inline.csv')
total_sample = len(df.index)
print('Number of rows:', total_sample)

# 80/20 train/validation split over row indices.
train_size = int(0.8 * total_sample)
# pl.seed_everything only runs later (inside main), so seed the split explicitly;
# otherwise the train/val partition changes on every kernel restart.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_index, val_index = random_split(
    range(total_sample),
    [train_size, total_sample - train_size],
    generator=split_generator,
)

# reset_index() keeps the original row label as an 'index' column.
train_df = df.loc[list(train_index)].reset_index()
val_df = df.loc[list(val_index)].reset_index()
assert len(train_df) + len(val_df) == total_sample

# Augmentation is enabled for the training set only.
train_ds = NetherlandsF3DS(root_ds, train_df, transforms=NetherlandsF3Transform(augment=True, use_depth=False))
val_ds = NetherlandsF3DS(root_ds, val_df, transforms=NetherlandsF3Transform(augment=False, use_depth=False))

print(len(train_ds), len(val_ds))

Numper of rows: 94720
75776 18944


In [6]:
# Sanity-check one training sample: print the sizes of the input tensor and its target.
first_sample = train_ds[0]
img, mask = first_sample
print(*(t.size() for t in (img, mask)))

torch.Size([1, 25, 64]) torch.Size([1])


In [7]:
# Both loaders share batching/worker/collation settings; only training data is shuffled.
loader_kwargs = dict(batch_size=32, num_workers=4, collate_fn=collate_fn)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, **loader_kwargs)

In [8]:
# TensorBoard logger writing under logs/classification/; pass version=... to reuse a run.
tt_logger = TensorBoardLogger(save_dir='logs', name='classification')

# Keep only the single best checkpoint by the logged 'metrics_mAP' (higher is better),
# stored in a 'ckpt' folder inside the logger's run directory.
# NOTE(review): this callback is not passed to the Trainer (main() uses
# checkpoint_callback=False), so it does not write checkpoints as the notebook stands.
checkpoint_dir = os.path.join(tt_logger.log_dir, 'ckpt')
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor='metrics_mAP',
    mode='max',
    save_top_k=1,
    save_last=False,
    verbose=False,
)

In [9]:
def parse_args(args=None):
    """Parse Trainer, model and program arguments into one namespace.

    Args:
        args: list of CLI tokens; when None, argparse falls back to sys.argv.

    Returns:
        argparse.Namespace with every pl.Trainer option, every
        LitClassification-specific option, and ``seed`` (default 42).
    """
    parser = ArgumentParser()
    for register in (pl.Trainer.add_argparse_args,
                     LitClassification.add_model_specific_args):
        parser = register(parser)
    parser.add_argument('--seed', type=int, default=42)
    return parser.parse_args(args)

def main(args):
    """Seed everything, then build the model and the Trainer from parsed args.

    Args:
        args: namespace returned by parse_args().

    Returns:
        (model, trainer) tuple; training is started separately with trainer.fit.
    """
    pl.seed_everything(args.seed)
    hparams = vars(args)
    model = LitClassification(**hparams)
    # Uses the module-level tt_logger. checkpoint_callback=False turns off Lightning's
    # checkpointing, so the ModelCheckpoint built earlier is not used by this Trainer.
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=tt_logger,
        checkpoint_callback=False,
    )
    return model, trainer

In [10]:
# add PROGRAM level args
# Program-level args (consumed by parse_args / main).
program_args = """
    --seed 42
    """.split()

# Model hyper-parameters, registered by LitClassification.add_model_specific_args.
model_args = """
    --name_model resunet
    --num_down_stage 4
    --num_filter1 16
    --bilinear n
    --lr 5e-4
    --momentum 0.9
    --weight_decay 5e-4
    """.split()

# Trainer options registered by pl.Trainer.add_argparse_args (--gpus, --num_nodes, ...).
# Use exact flag names rather than relying on argparse prefix abbreviation.
# To resume training, add: --resume_from_checkpoint original_sgd_logs/bs_32/last.ckpt
trainer_args = """
    --max_epochs 10
    --gpus 1
    --progress_bar_refresh_rate 20
    --num_sanity_val_steps 0
    """.split()

args = parse_args(program_args + model_args + trainer_args)

In [11]:
# Seed RNGs and construct the LightningModule and Trainer from the parsed args.
model, trainer = main(args)

Global seed set to 42
GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [12]:
# Train the model on train_dl, using val_dl for validation.
trainer.fit(model, train_dl, val_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | encoder    | ResEncoder       | 22.0 M
1 | classifier | Sequential       | 5.1 K 
2 | criterion  | CrossEntropyLoss | 0     
------------------------------------------------
22.0 M    Trainable params
0         Non-trainable params
22.0 M    Total params
87.919    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



1

In [13]:
# Export only the encoder weights.
# Training ran with --gpus 1, so the parameters may still live on the GPU. Move a copy
# to CPU so the file loads with torch.load on CPU-only machines without map_location.
# state_dict() returns a fresh dict, so replacing its values does not touch the model,
# and the dict's own metadata is kept.
ENCODER_WEIGHTS_PATH = 'pretrain34.pth'
encoder_state = model.encoder.state_dict()
for key, tensor in encoder_state.items():
    encoder_state[key] = tensor.cpu()
torch.save(encoder_state, ENCODER_WEIGHTS_PATH)