In [1]:
import torch
import torch.nn as nn
import wandb
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
import torch.optim as optim
import os

from Model.Vanila_UNet import VanilaUNet
from Model.BatchNormalized_ver import BN_VanilaUNet
from Dataset import ISBI
from utils import Random_processing
from torchinfo import summary
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from tqdm.auto import tqdm

class set_seed:
    """Seed the Python and PyTorch random number generators.

    The seed is applied as soon as the object is constructed, so a bare
    ``set_seed(1234)`` statement is sufficient. ``forward()`` stays available
    to re-apply the same seed later.

    Args:
        seed: Value passed to ``random.seed`` and ``torch.manual_seed``.
    """

    def __init__(self, seed):
        self.seed = seed
        # Apply the seed immediately; storing it alone leaves every RNG
        # unseeded unless the caller remembers to invoke forward().
        self.forward()

    def forward(self):
        """Apply ``self.seed`` to the ``random`` and ``torch`` generators."""
        random.seed(self.seed)
        torch.manual_seed(self.seed)


# Run configuration used throughout the notebook (it is also uploaded to
# WandB as the run config).
args = dict(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    max_epochs=50,
    train_batch_size=4,
    valid_batch_size=4,
    init_lr=1e-2,
    model_name='bn_vanilaUnet',
)


# Apply the seed now: forward() is the method that seeds `random` and `torch`.
set_seed(1234).forward()

  from .autonotebook import tqdm as notebook_tqdm


<__main__.set_seed at 0x7fe03fbad670>

### Running WandB

In [2]:
# Start the WandB run with its display name and hyper-parameters supplied up
# front. Calling `wandb.run.save()` without arguments is deprecated (run
# attributes are persisted automatically), so it is no longer invoked.
wandb.init(
    project='ISBI Semantic Segmentation-2nd',
    name=args['model_name'],
    config=args,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mduriankim[0m ([33mdurian[0m). Use [1m`wandb login --relogin`[0m to force relogin




### Define Dataset, DataLoader and Model

In [3]:
# Build the 'train' and 'valid' datasets, each with its own Random_processing() transform.
ds_train, ds_valid = (ISBI(split, Random_processing()) for split in ('train', 'valid'))

# Only the training loader shuffles; batch sizes come from the run config.
dl_train = DataLoader(dataset=ds_train, shuffle=True, batch_size=args['train_batch_size'])
dl_valid = DataLoader(dataset=ds_valid, shuffle=False, batch_size=args['valid_batch_size'])

# NOTE(review): args['model_name'] is 'bn_vanilaUnet' but the plain VanilaUNet
# is instantiated here — confirm which architecture is intended.
model = VanilaUNet(in_channels=1, out_channels=2)
model = model.to(args['device'])

### Define Loss function and Optimizer

In [4]:
class DiceLoss(nn.Module):
    """Class-weighted soft Dice loss computed from raw logits.

    For every sample and class the soft Dice score is
    ``(2 * |P ∩ T| + smooth) / (|P| + |T| + smooth)``, where ``P`` are the
    softmax probabilities and ``T`` the one-hot targets. The per-class scores
    are combined with a weighted average (weights normalised by their sum),
    averaged over the batch, and the loss is ``1 - score``.

    Args:
        num_classes: Number of classes; must match dim 1 of the logits.
        weights: Optional per-class weights (1-D tensor or sequence of length
            ``num_classes``). Defaults to equal weights. Must not sum to zero.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Raises:
        ValueError: If ``weights`` does not contain ``num_classes`` entries.
    """

    def __init__(self, num_classes, weights=None, smooth=1e-7):
        super(DiceLoss, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        if weights is None:
            self.weights = torch.ones(num_classes)
        else:
            self.weights = torch.as_tensor(weights, dtype=torch.float32).flatten()
        if self.weights.numel() != num_classes:
            raise ValueError(
                f"expected {num_classes} class weights, got {self.weights.numel()}"
            )

    def forward(self, inputs, targets):
        """Return the scalar Dice loss.

        Args:
            inputs: Logits; softmax is taken over dim 1 and the two trailing
                dims are summed over, i.e. a 4-D ``(B, C, H, W)`` tensor.
            targets: Integer class labels shaped ``(B, H, W)`` or
                ``(B, 1, H, W)`` (any dtype castable to ``long``).
        """
        probs = torch.softmax(inputs, dim=1)

        # Remove only an explicit singleton channel axis. A bare squeeze()
        # would also drop the batch axis for a batch of one and break the
        # permute below.
        if targets.dim() == inputs.dim() and targets.size(1) == 1:
            targets = targets.squeeze(1)
        one_hot = F.one_hot(targets.long(), num_classes=self.num_classes)
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        intersection = (probs * one_hot).sum((2, 3))
        cardinality = probs.sum((2, 3)) + one_hot.sum((2, 3))
        dice = (2. * intersection + self.smooth) / (cardinality + self.smooth)  # (B, C)

        # Shape the weights as (1, C) so each class column of `dice` is scaled
        # by its own weight. The earlier (1, C, 1, 1) shape broadcast against
        # (B, C) into an outer product, which made the weights class-agnostic.
        class_weights = self.weights.to(device=dice.device, dtype=dice.dtype).view(1, -1)
        weighted_dice = (dice * class_weights).sum(dim=1) / class_weights.sum()

        return 1 - torch.mean(weighted_dice)


class ce_loss(nn.Module):
    """Cross-entropy loss for label maps with an optional singleton channel axis."""

    def __init__(self):
        super(ce_loss, self).__init__()

    def forward(self, preds, targets):
        """Return the mean cross-entropy between logits and integer labels.

        Args:
            preds: Logits with the class axis at dim 1.
            targets: Class indices, either without a channel axis or with a
                singleton channel axis at dim 1; cast to ``long`` here.
        """
        # Drop only the channel axis. A bare squeeze() would also remove the
        # batch axis when the batch contains a single sample.
        if targets.dim() == preds.dim() and targets.size(1) == 1:
            targets = targets.squeeze(1)
        return F.cross_entropy(preds, targets.long())
    

# Training objective: class-weighted Dice loss.
# (Alternative: loss_fn = ce_loss() for plain cross-entropy.)
loss_fn = DiceLoss(
    weights=torch.tensor([0.7, 0.3]),
    num_classes=2,
)

# SGD with momentum 0.99; StepLR multiplies the learning rate by 0.1 every
# 10 calls to scheduler.step().
optimizer = optim.SGD(params=model.parameters(), momentum=.99, lr=args['init_lr'])
scheduler = StepLR(optimizer=optimizer, gamma=0.1, step_size=10)

### Create log path

In [5]:
# Output layout for this run: Model/<model_name>/log and its best_model subfolder.
log_path = f"Model/{args['model_name']}/log"
best_model_path = log_path + '/best_model'

# exist_ok=True creates the whole directory tree only when it is missing,
# replacing the manual exists()/pass/else check.
os.makedirs(best_model_path, exist_ok=True)

# Take the first validation batch as a fixed sample for the per-epoch preview figure.
test_sample = next(iter(dl_valid))
test_imgs = test_sample['image']
test_lbls = test_sample['label']
test_oris = test_sample['origin']

# Move the channel axis last for plt.imshow (assumes (C, H, W) tensors — TODO confirm).
view_img = test_oris[0].permute(1, 2, 0)
# Presumably undoes a normalisation with std 0.1662 / mean 0.491 and rescales
# to 0-255 — verify against the Dataset implementation.
view_img = (view_img * .1662 + .491) * 255.0
view_img = view_img.to(torch.int).detach().cpu().numpy()

view_lbl = test_lbls[0].permute(1, 2, 0)

# Only the network input needs to live on the compute device.
test_imgs = test_imgs.to(args['device'])


### Define train & validation code


In [6]:
class engine():
    """Stateless helpers that push a single batch through the model.

    Despite the ``one_epoch`` suffix, each method handles exactly one batch
    (``samples``); the caller loops over the DataLoader. The methods are
    static, so they can be called on the class or on an instance.
    """

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(model.parameters(), None)
        return torch.device('cpu') if first_param is None else first_param.device

    @staticmethod
    def model_train_one_epoch(model, samples, optimizer, loss_fn):
        """Run one optimisation step on a batch and return its loss.

        Args:
            model: Network to train; switched to train mode.
            samples: Mapping with an ``'image'`` and a ``'label'`` tensor.
            optimizer: Optimizer whose gradients are reset and stepped.
            loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar tensor.

        Returns:
            The batch loss as a Python float.
        """
        imgs, lbls = samples['image'], samples['label']
        model.train()

        # Place the batch wherever the model lives instead of relying on a
        # notebook-level config dict.
        device = engine._model_device(model)
        imgs = imgs.to(device)
        lbls = lbls.to(device)
        preds = model(imgs)

        loss = loss_fn(preds, lbls)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @staticmethod
    @torch.no_grad()
    def model_valid_one_epoch(model, samples, loss_fn):
        """Evaluate one batch without gradient tracking and return its loss.

        Args:
            model: Network to evaluate; switched to eval mode.
            samples: Mapping with an ``'image'`` and a ``'label'`` tensor.
            loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar tensor.

        Returns:
            The batch loss as a Python float.
        """
        imgs, lbls = samples['image'], samples['label']
        model.eval()

        device = engine._model_device(model)
        imgs = imgs.to(device)
        lbls = lbls.to(device)

        preds = model(imgs)

        loss = loss_fn(preds, lbls)

        return loss.item()

def run():
    """Train and validate for ``args['max_epochs']`` epochs.

    Per epoch this:
      1. runs one optimisation step per training batch,
      2. evaluates every validation batch,
      3. saves the weights to ``best_model.pth`` when the mean validation
         loss improves,
      4. advances the learning-rate scheduler,
      5. logs the mean train/validation loss to WandB, and
      6. writes a three-panel preview figure (input, ground truth,
         prediction) for the fixed sample to ``log_path``.

    Uses the notebook-level objects ``model``, ``optimizer``, ``scheduler``,
    ``loss_fn``, ``dl_train``, ``dl_valid``, ``test_imgs``, ``view_img``,
    ``view_lbl``, ``log_path`` and ``best_model_path``.
    """

    def epoch_mean(batch_losses):
        # Mean over batches; NaN for an empty loader so it can never be
        # mistaken for a new best loss.
        return sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')

    best_val_loss = 1e9

    for epoch in range(args['max_epochs']):
        print('--------------------------------')
        print(f'      Epoch : {epoch}')
        print('--------------------------------')

        train_losses = [
            engine.model_train_one_epoch(model, data, optimizer, loss_fn)
            for data in tqdm(dl_train, total=len(dl_train))
        ]
        valid_losses = [
            engine.model_valid_one_epoch(model, data, loss_fn)
            for data in tqdm(dl_valid, total=len(dl_valid))
        ]

        # Report and select on the epoch average rather than on whichever
        # batch happened to come last.
        train_loss = epoch_mean(train_losses)
        valid_loss = epoch_mean(valid_losses)

        if valid_loss < best_val_loss:
            torch.save(model.state_dict(), os.path.join(best_model_path, 'best_model.pth'))
            best_val_loss = valid_loss

        # The StepLR schedule only takes effect if it is stepped once per epoch.
        scheduler.step()

        # Inference on the fixed preview batch
        with torch.no_grad():
            model.eval()

            view_preds = model(test_imgs)
            view_preds = torch.argmax(view_preds, dim=1)
            view_pred = view_preds[0].to(torch.int)
            view_pred = view_pred.detach().cpu()

        wandb.log(
            {
                'epoch' : epoch,
                'train loss' : train_loss,
                'valid_loss' : valid_loss,
            }
        )

        # Save the preview figure; close it explicitly so figures do not
        # accumulate in memory across epochs.
        fig, axes = plt.subplots(1, 3)
        panels = (
            (view_img, 'origin tile'),
            (view_lbl, 'Ground Truth'),
            (view_pred, 'Prediction'),
        )
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        fig.savefig(os.path.join(log_path, f'Inference_{epoch}'))
        plt.close(fig)
        print()

# Start the full training / validation loop defined by run().
run()

--------------------------------
      Epoch : 0
--------------------------------


  0%|          | 0/27 [00:00<?, ?it/s]

: 