In [1]:
import os

import torch

import src
from src.utils.CustomDataset import load_data

In [3]:
# Encoder backbone name and the identifier of the weights it is initialised
# with; both are forwarded unchanged to the model constructor.
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'

# Activation applied to the model output. Alternatives noted by the author:
# None to get raw logits, or 'softmax2d' for multiclass segmentation.
ACTIVATION = 'sigmoid'

# Prefer the GPU whenever PyTorch reports one as available, else use the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Constructor arguments for the FPN network, gathered in one place so they can
# be inspected or logged. One input channel and one output class are requested.
fpn_kwargs = {
    'encoder_name': ENCODER,
    'encoder_weights': ENCODER_WEIGHTS,
    'in_channels': 1,
    'classes': 1,
    'activation': ACTIVATION,
}

# Instantiate the model from the project-local `src` package.
model = src.FPN(**fpn_kwargs)

In [5]:
# Training objective: the combination of a Dice term and a binary
# cross-entropy term, wrapped by the project's SumOfLosses helper.
dice_loss = src.utils.losses.DiceLoss()
bce_loss = src.utils.losses.BCELoss()
loss = src.utils.base.SumOfLosses(dice_loss, bce_loss)

# Metrics reported next to the loss: IoU configured with a threshold of 0.5.
metrics = [src.utils.metrics.IoU(threshold=0.5)]

# Adam with a single parameter group covering every model parameter. The
# learning rate lives on the group, so it can be adjusted later through
# optimizer.param_groups[0]['lr'].
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 1e-4},
])

In [6]:
# Keyword arguments common to both epoch runners, so training and validation
# are guaranteed to use the same loss, metrics, device and verbosity.
shared_epoch_kwargs = dict(
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Runner for one training pass; it is the only one given the optimizer.
train_epoch = src.utils.train.TrainEpoch(
    model,
    optimizer=optimizer,
    **shared_epoch_kwargs,
)

# Runner for one validation pass over the same model.
valid_epoch = src.utils.train.ValidEpoch(
    model,
    **shared_epoch_kwargs,
)

In [7]:
# Build the train/validation loaders with the project's load_data helper.
# NOTE(review): test_size=0.3 presumably holds out 30% of the samples for
# validation and artificial_increase=20 presumably controls augmentation;
# confirm both against src.utils.CustomDataset.load_data.
train_loader, valid_loader = load_data(
    test_size=0.3,
    batch_size=1,
    img_size=256,
    dir='./data/',
    artificial_increase=20,
)

In [10]:
# --- Training-loop settings --------------------------------------------------
NUM_EPOCHS = 10            # upper bound on the number of epochs to run
PATIENCE = 6               # stop once more than this many epochs pass without improvement
LR_DECAY_EPOCH = 25        # epoch index at which the learning rate is scaled down
LR_DECAY_FACTOR = 0.56     # multiplicative factor applied to the learning rate
VALID_LOSS_KEY = 'dice_loss + bce_loss'   # key of the monitored value in the validation logs
CHECKPOINT_PATH = './checkpoint/best_model.pth'

# NOTE: with NUM_EPOCHS = 10 the loop never reaches epoch index 25, so the
# learning-rate decay below does not fire. Raise NUM_EPOCHS or lower
# LR_DECAY_EPOCH if the decay is actually wanted.

# torch.save does not create missing directories; make sure the target exists
# so a fresh checkout does not fail on the first checkpoint.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# The monitored quantity is a loss (lower is better). Starting from +inf means
# the first epoch always produces a checkpoint, whatever its loss value is.
best_valid_loss = float('inf')
epochs_without_improvement = 0

for i in range(NUM_EPOCHS):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Checkpoint whenever the validation loss improves on the best seen so far.
    valid_loss = valid_logs[VALID_LOSS_KEY]
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # NOTE(review): this pickles the whole model object rather than its
        # state_dict; kept as-is so existing code loading this file still works.
        torch.save(model, CHECKPOINT_PATH)
        epochs_without_improvement = 0
        print('Model saved!')
    else:
        epochs_without_improvement += 1

    # Scheduled one-off learning-rate reduction.
    if i == LR_DECAY_EPOCH:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= LR_DECAY_FACTOR
        print('Decrease learning rate to {:.2e}!'.format(
            optimizer.param_groups[0]['lr']))

    # Early stopping: give up after too many consecutive epochs without a
    # better validation loss.
    if epochs_without_improvement > PATIENCE:
        print('Early stopping: no improvement for {} epochs.'.format(
            epochs_without_improvement))
        break


Epoch: 0
train: 100%|██████████| 266/266 [00:19<00:00, 13.58it/s, dice_loss + bce_loss - 0.9187, iou_score - 0.1119] 
valid: 100%|██████████| 114/114 [00:09<00:00, 11.53it/s, dice_loss + bce_loss - 0.7722, iou_score - 0.1976]
Model saved!

Epoch: 1
train: 100%|██████████| 266/266 [00:16<00:00, 15.69it/s, dice_loss + bce_loss - 0.7365, iou_score - 0.221]  
valid: 100%|██████████| 114/114 [00:09<00:00, 11.51it/s, dice_loss + bce_loss - 0.5375, iou_score - 0.4106]
Model saved!

Epoch: 2
train: 100%|██████████| 266/266 [00:16<00:00, 15.86it/s, dice_loss + bce_loss - 0.6293, iou_score - 0.3066]
valid: 100%|██████████| 114/114 [00:09<00:00, 11.64it/s, dice_loss + bce_loss - 0.4963, iou_score - 0.403] 
Model saved!

Epoch: 3
train: 100%|██████████| 266/266 [00:16<00:00, 15.84it/s, dice_loss + bce_loss - 0.5207, iou_score - 0.3956]
valid: 100%|██████████| 114/114 [00:09<00:00, 11.56it/s, dice_loss + bce_loss - 0.3146, iou_score - 0.5905]
Model saved!

Epoch: 4
train: 100%|██████████| 266/266 