In [None]:
import sys

# Put the BDFormer checkout first on the import path so its packages
# (model, datasets, engine, ...) win over any same-named installed ones.
# NOTE(review): absolute path looks like a Colab Google Drive mount - confirm
# before running elsewhere.
PROJECT_ROOT = '/content/drive/MyDrive/BDFormer'
sys.path.insert(0, PROJECT_ROOT)

In [1]:
# Imports and process-wide settings for the training notebook.
import torch
from torch.utils.data import DataLoader
from model.BDFormer import BDFormer
from datasets import Skin_Dataset
# Wildcard import: the training/validation/test helpers called later are not
# imported by name anywhere, so they presumably come from here or from
# `utils` below - confirm.
from engine import *
import os
import sys

# Expose only CUDA device 0 to this process.
# NOTE(review): this is assigned after torch and the project modules have been
# imported; it only has an effect if nothing above has already initialised
# CUDA - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Wildcard import: presumably the source of get_logger, set_seed,
# get_optimizer, get_scheduler, etc. - verify against utils.
from utils import *
from config_setting import setting_config_multitask as config

import warnings
# Silences every warning category for the rest of the session.
warnings.filterwarnings("ignore")

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


multi-task dataset -> isic17


In [2]:
# Prepare the run's directory layout under config.work_dir and start logging.
sys.path.append(config.work_dir + '/')

log_dir = os.path.join(config.work_dir, 'log')
checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
resume_model = os.path.join(checkpoint_dir, 'latest.pth')
outputs = os.path.join(config.work_dir, 'outputs')

# exist_ok=True replaces the check-then-create pattern: the cell can be re-run
# safely and there is no window between the existence check and the creation.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(outputs, exist_ok=True)
for sub_dir in ('pred_masks', 'pred_contours'):
    os.makedirs(os.path.join(outputs, sub_dir), exist_ok=True)

# log_dir is not created here; presumably get_logger takes care of it - confirm.
# (`global` is a no-op at notebook top level, so a plain assignment is enough.)
logger = get_logger('train', log_dir)
log_config_info(config, logger)

print('Logger and output files are created.')

Logger and output files are created.


In [3]:
# Release memory still held by the accelerator's caching allocator; only the
# CUDA and MPS backends are handled, any other device type is left alone.
backend = config.device.type
if backend == 'cuda':
    torch.cuda.empty_cache()
elif backend == 'mps':
    torch.mps.empty_cache()

# set_seed is a project helper (not visible here); presumably it seeds the
# RNGs for the given device - confirm.
set_seed(config.seed, config.device)

print(f'GPU is initalized. Device type: {config.device.type}')

GPU is initalized. Device type: mps


In [4]:
# Build the train / val / test datasets and their loaders, in that order.
# Each split is subsampled through subset_frac.

# --- train ---
train_dataset = Skin_Dataset(config, split="train", subset_frac=0.01)
train_class_weights = calculate_class_weights(train_dataset.labels, config.device)

# Options every loader shares.
shared_loader_opts = dict(pin_memory=True, num_workers=config.num_workers)

train_loader = DataLoader(
    train_dataset, batch_size=config.batch_size, shuffle=True, **shared_loader_opts)

# --- validation: fixed batch of 2, incomplete final batch dropped ---
val_dataset = Skin_Dataset(config, split="val", subset_frac=0.1)
val_class_weights = calculate_class_weights(val_dataset.labels, config.device)
val_loader = DataLoader(
    val_dataset, batch_size=2, shuffle=False, drop_last=True, **shared_loader_opts)

# --- test: one sample per batch ---
test_dataset = Skin_Dataset(config, split="test", subset_frac=0.01)
test_loader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, drop_last=True, **shared_loader_opts)

print('Datasets are loaded.')

Subsetting train set to 20 samples.
Subsetting val set to 15 samples.
Subsetting test set to 6 samples.
Datasets are loaded.


In [5]:
# Instantiate the network on the configured device and create the training
# objects (loss, optimizer, LR scheduler) from the config.
model = BDFormer(img_size=256, in_channels=3, num_classes=config.num_classes, window_size=8).to(config.device)

criterion = config.criterion
optimizer = get_optimizer(config, model)
scheduler = get_scheduler(config, optimizer)

# Best-so-far bookkeeping. Start from +inf rather than a magic number so the
# first validated epoch always counts as an improvement, however large its
# loss is (a finite sentinel would silently suppress best.pth above it).
min_loss = float('inf')
start_epoch = 1
min_epoch = 1

print('Model and training parameters are configured.')

SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:1
Model and training parameters are configured.


In [None]:
# Resume training state from the latest checkpoint, if one exists.
if os.path.exists(resume_model):
    print('#----------Resuming model and learning parameters----------#')
    # weights_only=False because the checkpoint bundles optimizer/scheduler
    # state and bookkeeping values, not just tensors. This unpickles arbitrary
    # objects, so only load checkpoints you produced yourself.
    checkpoint = torch.load(resume_model, map_location=torch.device('cpu'), weights_only=False)
    # NOTE(review): strict=False ignores missing/unexpected keys silently -
    # confirm this leniency is intended when resuming the same architecture.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    saved_epoch = checkpoint['epoch']
    # Assign rather than accumulate: with `+=`, re-running this cell in the
    # same kernel would add the saved epoch again and skip epochs.
    start_epoch = saved_epoch + 1
    min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss']

    log_info = f'resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}'
    logger.info(log_info)


# Main loop: one training pass and one validation pass per epoch. The weights
# with the lowest validation loss go to best.pth; the full training state is
# written to latest.pth after every epoch so the run can be resumed.
print('#----------Training----------#')
print('iter_num=', len(train_loader))

best_ckpt_path = os.path.join(checkpoint_dir, 'best.pth')
latest_ckpt_path = os.path.join(checkpoint_dir, 'latest.pth')

for epoch in tqdm(range(start_epoch, config.epochs + 1), ncols=70):
    train_one_epoch_multi(
        train_loader, train_class_weights, model,
        optimizer, scheduler, epoch, logger, config)

    print('#----------Validation----------#')
    loss = val_one_epoch_multi(val_loader, val_class_weights, model, epoch, logger, config)

    # New best validation loss: persist the weights, then record the result.
    if loss < min_loss:
        torch.save(model.state_dict(), best_ckpt_path)
        min_loss, min_epoch = loss, epoch

    # Everything the resume step reads back.
    snapshot = {
        'epoch': epoch,
        'min_loss': min_loss,
        'min_epoch': min_epoch,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    torch.save(snapshot, latest_ckpt_path)

# Evaluate the best checkpoint on the test split, then archive it under a
# name that records the epoch and validation loss it was saved at.
# The path is built once from checkpoint_dir so the existence check, the load
# and the rename all refer to the same file.
best_model_path = os.path.join(checkpoint_dir, 'best.pth')
if os.path.exists(best_model_path):
    print('#----------Testing----------#')
    # best.pth is written from model.state_dict() (tensors only), so the
    # restricted unpickler is sufficient and safer.
    best_weight = torch.load(best_model_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(best_weight)
    loss = test_one_epoch_multi(test_loader, model, criterion, logger, config)
    # os.replace overwrites an existing target on every platform, whereas
    # os.rename raises on Windows if the destination already exists.
    os.replace(
        best_model_path,
        os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth')
    )

#----------Training----------#
iter_num= 10


  0%|                                           | 0/1 [00:00<?, ?it/s]

multi-task dataset ->multi-task dataset -> multi-task dataset -> multi-task dataset -> isic17isic17
isic17
 isic17

train: epoch 1, iter:0, loss: 2.0458, lr: 0.00033


In [None]:
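# Freezing layers - disabled fine-tuning experiment: every line in this cell is commented out and does not run.
# NOTE(review): this draft disagrees with the active cells: it uses `model.module` where they use `model`
# directly, it passes a `scaler` that no active cell defines, and it calls the train/val helpers with a
# different argument order. Reconcile these before re-enabling.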
# for param in model.parameters():
#     param.requires_grad = False

# for name, layer in model.multi_task_MaxViT.named_children():
#     if name in ['classifier', 'seg_out_conv']:
#         for param in layer.parameters():
#             param.requires_grad = True

# resume_model = os.path.join(checkpoint_dir, 'latest_finetune.pth')

# resuming_finetune = os.path.exists(resume_model)

# if not resuming_finetune:
#     resume_model = os.path.join(checkpoint_dir, 'latest.pth')

# checkpoint = torch.load(resume_model, map_location=torch.device('cpu'), weights_only=False)
# model.module.load_state_dict(checkpoint['model_state_dict'], strict=False)
# saved_epoch = checkpoint['epoch']
# start_epoch += saved_epoch
# loss = checkpoint['loss']

# if resuming_finetune:
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#     min_loss, min_epoch = checkpoint['min_loss'], checkpoint['min_epoch']
#     log_info = f'Resuming model from {resume_model}. resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}'
#     logger.info(log_info)
# else:
#     min_loss = 0
#     log_info = f'Loading baseline model from {resume_model}. resume_epoch: {saved_epoch}, loss: {loss:.4f}'
#     logger.info(log_info)

# print('#----------Training----------#')
# print('iter_num=', len(train_loader))
# for epoch in tqdm(range(start_epoch, config.epochs + 1), ncols=70):

#     torch.cuda.empty_cache()

#     train_one_epoch_multi(
#         train_loader,
#         model,
#         optimizer,
#         scheduler,
#         epoch,
#         logger,
#         config,
#         train_class_weights,
#         scaler=scaler)

#     print('#----------Validation----------#')
#     loss = val_one_epoch_multi(
#             val_loader,
#             model,
#             epoch,
#             logger,
#             config,
#             val_class_weights)

#     # if loss < min_loss and epoch > 35:
#     if loss < min_loss:
#         torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best_finetune.pth'))
#         min_loss = loss
#         min_epoch = epoch

#     torch.save(
#         {
#             'epoch': epoch,
#             'min_loss': min_loss,
#             'min_epoch': min_epoch,
#             'loss': loss,
#             'model_state_dict': model.module.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict(),
#         }, os.path.join(checkpoint_dir, 'latest_finetune.pth'))

# if os.path.exists(os.path.join(checkpoint_dir, 'best_finetune.pth')):
#     print('#----------Testing----------#')
#     best_weight = torch.load(config.work_dir + '/checkpoints/best_finetune.pth', map_location=torch.device('cpu'))
#     model.module.load_state_dict(best_weight)
#     loss = test_one_epoch_multi(
#             test_loader,
#             model,
#             criterion,
#             logger,
#             config)
#     os.rename(
#         os.path.join(checkpoint_dir, 'best_finetune.pth'),
#         os.path.join(checkpoint_dir, f'best_finetune-epoch{min_epoch}-loss{min_loss:.4f}.pth')
#     )