In [9]:
import os

import torch
import wandb

import src
from src.utils.dataset import load_data

In [10]:
# Encoder backbone and the pretrained weights it is initialised from.
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'

# Output activation of the segmentation head: None would yield raw logits,
# 'softmax2d' is the choice for multiclass segmentation.
ACTIVATION = 'sigmoid'

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
# Build the FPN segmentation network from the configuration constants:
# one input channel in, one output class out.
model = src.FPN(
    # data layout
    in_channels=1,
    classes=1,
    # encoder backbone and output head
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    activation=ACTIVATION,
)

In [12]:
# Training objective: Dice loss and BCE loss combined through SumOfLosses.
# The training loop reads the combined value from the epoch logs under the
# key 'dice_loss + bce_loss'.
loss = src.utils.base.SumOfLosses(src.utils.losses.DiceLoss(), src.utils.losses.BCELoss())

# Quality metric reported next to the loss: IoU with a 0.5 threshold.
metrics = [src.utils.metrics.IoU(threshold=0.5)]

# Adam over every model parameter, kept as a single parameter group so the
# learning rate can later be adjusted through optimizer.param_groups[0].
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 1e-4}])

In [13]:
# Runner for one training pass: it is the only one that receives the
# optimizer.
train_epoch = src.utils.train.TrainEpoch(
    model,
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Runner for one validation pass: same model, loss and metrics, no optimizer.
valid_epoch = src.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [14]:
# Build the train/validation loaders from the images under ./data/.
# NOTE(review): test_size looks like the fraction held out for validation and
# artificial_increase like an augmentation multiplier -- confirm in load_data.
train_loader, valid_loader = load_data(
    dir='./data/',
    img_size=256,
    batch_size=1,
    test_size=0.3,
    artificial_increase=20,
)

In [15]:
# Authenticate this session with Weights & Biases.
wandb.login()

# Start a run in the "Trus_images" project; the training loop below logs its
# per-epoch loss and IoU through wandb.log.
wandb.init(project="Trus_images")




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

In [16]:
# Training loop with early stopping and best-checkpoint saving.
#
# Each epoch runs one training pass and one validation pass, logs loss and IoU
# for both to wandb, and saves the model whenever the validation loss improves.
# Training stops early once the validation loss has failed to improve for more
# than PATIENCE consecutive epochs.

NUM_EPOCHS = 20
PATIENCE = 6                      # tolerated consecutive epochs without improvement
LOSS_KEY = 'dice_loss + bce_loss'  # key of the combined loss in the epoch logs
CHECKPOINT_PATH = './checkpoint/best_model.pth'

# NOTE(review): with NUM_EPOCHS = 20 the epoch index never reaches 25, so this
# decay never fires. Lower LR_DECAY_EPOCH (or raise NUM_EPOCHS) if a decay is
# actually wanted.
LR_DECAY_EPOCH = 25
LR_DECAY_FACTOR = 0.56

# Create the checkpoint directory up front so the first torch.save cannot fail
# after a full epoch of training just because the folder is missing.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# Lowest validation loss seen so far. Starting from +inf guarantees that the
# first epoch always produces a checkpoint, however large its loss is.
best_valid_loss = float('inf')
epochs_without_improvement = 0

try:
    for epoch in range(NUM_EPOCHS):
        if epochs_without_improvement > PATIENCE:
            break

        print('\nEpoch: {}'.format(epoch))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # One log call per epoch keeps the train and validation metrics on the
        # same wandb step; separate calls would each advance the step counter.
        wandb.log({
            'train/train_IoU': train_logs['iou_score'],
            'train/train_loss': train_logs[LOSS_KEY],
            'valid/valid_IoU': valid_logs['iou_score'],
            'valid/valid_loss': valid_logs[LOSS_KEY],
        })

        # Keep the checkpoint with the lowest validation loss.
        if valid_logs[LOSS_KEY] < best_valid_loss:
            best_valid_loss = valid_logs[LOSS_KEY]
            # Saves the whole model object (not just a state_dict), so loading
            # it requires the same `src` package to be importable.
            torch.save(model, CHECKPOINT_PATH)
            epochs_without_improvement = 0
            print('Model saved!')
        else:
            epochs_without_improvement += 1

        if epoch == LR_DECAY_EPOCH:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= LR_DECAY_FACTOR
            # Report the value actually applied instead of a hard-coded figure.
            print('Decrease learning rate to {:.2e}!'.format(optimizer.param_groups[0]['lr']))
except BaseException:
    # Close the run even on an interrupt or error, and mark it as failed.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()


Epoch: 0
train: 100%|██████████| 266/266 [00:22<00:00, 11.81it/s, dice_loss + bce_loss - 0.9064, iou_score - 0.1013] 
valid: 100%|██████████| 114/114 [00:09<00:00, 11.52it/s, dice_loss + bce_loss - 0.8663, iou_score - 0.04792]
Model saved!

Epoch: 1
train: 100%|██████████| 266/266 [00:19<00:00, 13.60it/s, dice_loss + bce_loss - 0.756, iou_score - 0.209]  
valid: 100%|██████████| 114/114 [00:10<00:00, 10.76it/s, dice_loss + bce_loss - 0.6866, iou_score - 0.2549]
Model saved!

Epoch: 2
train: 100%|██████████| 266/266 [00:17<00:00, 14.79it/s, dice_loss + bce_loss - 0.6291, iou_score - 0.3063]
valid: 100%|██████████| 114/114 [00:10<00:00, 10.96it/s, dice_loss + bce_loss - 0.4345, iou_score - 0.4802]
Model saved!

Epoch: 3
train: 100%|██████████| 266/266 [00:18<00:00, 14.53it/s, dice_loss + bce_loss - 0.5251, iou_score - 0.391] 
valid:   0%|          | 0/114 [00:00<?, ?it/s]