In [1]:
import os
import sys
# Make the notebook's working directory importable so that project-local
# imports in later cells can be resolved from it.
_working_dir = os.getcwd()
if _working_dir not in sys.path:
    sys.path.append(_working_dir)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.barlow_twins import SwinBarlowTwins
from model.build_swin_vit import build_model
from config import get_config
import os
import sys

In [3]:
import random
import numpy as np
import torch.backends.cudnn as cudnn

# Load the experiment configuration and make sure its output directory exists.
config = get_config()
os.makedirs(config.OUTPUT, exist_ok=True)

# Seed every random number generator used here (PyTorch CPU and CUDA, NumPy,
# and the standard library) with the configured seed, in that order.
seed = config.SEED
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    seed_rng(seed)

# Turn on cuDNN benchmark mode.
cudnn.benchmark = True

# Write the dumped configuration into the output directory for later reference.
path = os.path.join(config.OUTPUT, f"config_{config.NAME}.json")
with open(path, "w") as f:
    f.write(config.dump())

print(f"Full config saved to {path}")

Full config saved to output\config_swin_transformer_barlow_twins_cxr.json


In [None]:
from sched import scheduler
import wandb
import torch
from data.build_dataset import get_nih, get_chexpert
import os
from data.dataset import NIH_Dataset, CheX_Dataset
from model.build_swin_vit import build_swin_vit
from model.BoostSwinTransformer import BoostSwin
from .utils import *
from torch.cuda.amp import GradScaler, autocast
import time
from optimizer import get_optim
from scheduler import get_scheduler
from torch.utils.data import DataLoader
from .train import train
from .eval import validate

In [None]:
def train(params):
    """Run a complete training session tracked as one Weights & Biases run.

    Builds the data loaders, model, optimizer, loss, optional LR scheduler and
    AMP gradient scaler from the run configuration, optionally resumes from a
    checkpoint, then trains, validates and tests once per epoch, saving a
    checkpoint after every epoch. Finally the checkpoint files found in the
    output directory are logged to W&B as a model artifact.

    Args:
        params: Configuration handed to ``wandb.init(config=...)``. This
            function reads ``predictions``, ``DATASET``, ``DEVICE``,
            ``PRETRAINED``, ``SCHEDULER_NAME``, ``AMP_ENABLE``, ``RESUME``,
            ``START_EPOCH``, ``EPOCHS`` and ``OUTPUT`` from it; the helpers
            it calls may read further keys.

    Raises:
        ValueError: If ``DATASET`` is neither ``"NIH"`` nor ``"chexpert"``.
    """
    # This function has the same name as the per-epoch trainer imported at
    # module level (``from .train import train``) and therefore shadows it;
    # a bare ``train(...)`` call below would recurse into this function.
    # Re-import the per-epoch routine under an unambiguous local alias.
    # NOTE(review): the relative form mirrors the module-level import; it
    # requires this code to run as part of a package - confirm for notebooks.
    from .train import train as train_one_epoch

    with wandb.init(project="swin-twins", entity="asekhri", job_type="train", config=params) as run:
        config = wandb.config
        os.makedirs(config.predictions, exist_ok=True)

        if config.DATASET == "NIH":
            train_loader, valid_loader, test_loader = get_nih(config, NIH_Dataset)
        elif config.DATASET == "chexpert":
            train_loader, valid_loader, test_loader = get_chexpert(config, CheX_Dataset)
        else:
            raise ValueError("Dataset not supported")

        # model
        model = BoostSwin(config).to(config.DEVICE)

        # load pretrained
        if config.PRETRAINED:
            load_pretrained(config, model)

        # optimizer
        optimizer = get_optim(config, model)

        # criterion
        criterion = nn.CrossEntropyLoss().to(config.DEVICE)

        # scheduler
        if config.SCHEDULER_NAME:
            lr_scheduler = get_scheduler(config, optimizer)
        else:
            lr_scheduler = None

        # scaler
        scaler = GradScaler(enabled=config.AMP_ENABLE)

        # Check if we have a checkpoint
        best_auc = 0
        start_epoch = config.START_EPOCH
        if config.RESUME:
            start_epoch, best_auc = load_checkpoint(config, model, optimizer, lr_scheduler, scaler)
            # wandb.config refuses to overwrite an existing key through plain
            # attribute assignment outside Jupyter; record the resumed epoch
            # through the explicit update API instead.
            config.update({"START_EPOCH": start_epoch}, allow_val_change=True)
            print(f"Resuming from epoch {start_epoch} with best auc {best_auc}")

        # Train artifacts
        artifact = wandb.Artifact("proposed-method", type="model", description="boost swin T with SSL", metadata=dict(config))

        for epoch in range(start_epoch, config.EPOCHS):
            # Train the model for one epoch. Pass the LR scheduler built above,
            # not the unrelated ``sched.scheduler`` class imported at module level.
            train_loss, train_acc, train_auc = train_one_epoch(
                config, train_loader, model, criterion, optimizer, lr_scheduler, scaler, epoch
            )

            # validate the model for one epoch
            valid_loss, valid_acc, valid_auc = validate(config, model, valid_loader, criterion, epoch)

            # NOTE(review): stepping with the validation loss presumes a
            # metric-driven scheduler (e.g. ReduceLROnPlateau) - confirm.
            if config.SCHEDULER_NAME:
                lr_scheduler.step(valid_loss)

            # test the model for one epoch
            test_loss, test_acc, test_auc = validate(config, model, test_loader, criterion, epoch)

            # save checkpoint
            save_checkpoint(config, model, optimizer, test_auc, lr_scheduler, scaler, epoch)

            # save checkpoint if best
            # NOTE(review): the best model is selected on the *test* AUC, which
            # biases the reported test score; consider using valid_auc instead.
            if test_auc > best_auc:
                best_auc = test_auc
                save_checkpoint(config, model, optimizer, test_auc, lr_scheduler, scaler, epoch, is_best=True)

        # Add the checkpoint files to the artifact. Either file can be absent
        # (no epoch ran, or no epoch improved on the resumed best AUC), so only
        # attach what actually exists instead of failing at the very end.
        for file_name, artifact_name in (("checkpoint.pth", "ckp.pth"), ("model_best.pth", "best.pth")):
            file_path = os.path.join(config.OUTPUT, file_name)
            if os.path.isfile(file_path):
                artifact.add_file(file_path, name=artifact_name)
            else:
                print(f"Checkpoint file {file_path} not found; not added to the artifact")
        run.log_artifact(artifact)
    # The ``with`` block above finishes the run on exit, so no explicit
    # ``wandb.finish()`` call is needed here.