In [1]:
# Make Google Drive visible inside the Colab VM; the project folder lives under MyDrive.
import google.colab.drive as drive

drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
# Work from the project folder: the cells below import local modules (model, dataset,
# myutils, ...) and read/write relative paths such as ./data, ./logs and ./trained_models.
%cd '/content/gdrive/MyDrive/unet_crack'
! ls

/content/gdrive/MyDrive/unet_crack
data	    Dice_coeff_loss.py	logs		  model.py    trained_models
dataset.py  focalloss.py	lovasz_losses.py  myutils.py  unet_code.ipynb


## **Train**

In [None]:
import logging
import copy
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

from myutils import transform, cal_iou, onehot
from model import Unet
from dataset import myDataset

import lovasz_losses as L
from Dice_coeff_loss import dice_loss
from focalloss import FocalLoss

# ---- Training configuration ----
batch_size = 2
# One entry per (loss, optimizer) run, consumed in the nested-loop order below;
# needs 8 entries when optimlist also contains 'sgd'.
num_epochs = [100, 100, 100, 100]
num_workers = 2
lr = 0.0001

losslist = ['focal', 'bce', 'dice', 'lovasz']
optimlist = ['adam']   # add 'sgd' to also train with SGD (it uses 10 * lr)
iflog = True           # write per-epoch metrics to ./logs/logger_unet.log

train_dataset = myDataset('./data/train', transform=transform)
val_dataset = myDataset('./data/val', transform=transform)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=1,
                        shuffle=False)
criterion = nn.BCELoss()
focallos = FocalLoss(gamma=2)

# Fail fast instead of hitting an IndexError on num_epochs after hours of training.
n_runs = len(losslist) * len(optimlist)
if len(num_epochs) < n_runs:
    raise ValueError(f'num_epochs has {len(num_epochs)} entries but '
                     f'{n_runs} (loss, optimizer) runs are configured')

if iflog:
    # basicConfig is a no-op once the root logger has handlers, so configure it once here.
    logging.basicConfig(filename='./logs/logger_unet.log', level=logging.INFO)

epoidx = -1
for los in losslist:
    for opt in optimlist:
        start = time.time()
        print(los, opt)
        # Re-seed so every (loss, optimizer) run starts from the same RNG state
        # (same initial weights and shuffling sequence).
        torch.manual_seed(77)
        torch.cuda.manual_seed(77)
        unet = Unet(3, 2).cuda()
        history = []
        if 'adam' in opt:
            optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
        elif 'sgd' in opt:
            optimizer = torch.optim.SGD(unet.parameters(), lr=10*lr, momentum=0.9)
        else:
            raise ValueError(f'unknown optimizer: {opt!r}')

        epoidx += 1
        for epoch in range(num_epochs[epoidx]):
            # cal_iou (from myutils, not shown here) may switch the model to eval
            # mode; make sure every epoch actually trains in train mode.
            unet.train()
            totalloss = 0.0
            for images, masks in train_loader:
                images = images.cuda()
                masks = masks.cuda()
                outputs = unet(images)
                if 'bce' in los:
                    loss = criterion(outputs, onehot(masks))
                elif 'dice' in los:
                    loss = dice_loss(outputs, onehot(masks))
                elif 'lovasz' in los:
                    loss = L.lovasz_hinge(outputs, onehot(masks))
                elif 'focal' in los:
                    # Focal loss takes integer class targets rather than one-hot masks.
                    loss = focallos(outputs, masks.long())
                else:
                    raise ValueError(f'unknown loss: {los!r}')

                # Accumulate a Python float: summing the loss tensor itself would chain
                # every batch's autograd graph into `totalloss` for the whole epoch.
                totalloss += loss.item() * images.size(0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # End-of-epoch metrics on the full train and validation sets.
            train_loss = totalloss / len(train_dataset)
            train_pa, train_mpa, train_miou, train_fwiou = \
                cal_iou(unet, train_dataset)
            val_pa, val_mpa, val_miou, val_fwiou = \
                cal_iou(unet, val_dataset)
            history.append([train_loss,
                            train_pa, train_mpa, train_miou, train_fwiou,
                            val_pa, val_mpa, val_miou, val_fwiou])
            if iflog:
                logging.info('%s %s epoch %d: loss=%.4f train_miou=%s val_miou=%s',
                             opt, los, epoch + 1, train_loss, train_miou, val_miou)

            # Checkpoint every third epoch (1, 4, ..., 100) once validation mIoU is good enough.
            if epoch % 3 == 0 and val_miou > 0.8:
                torch.save(unet.state_dict(),
                           './trained_models/unet_'+opt+'_'+los+'_'+str(epoch+1)+'.pkl')

            # Re-save the history every epoch so partial results survive an interrupted run.
            history_np = np.array(history)
            np.save('./logs/unet_'+opt+'_'+los+'.npy', history_np)
        end = time.time()
        print((end-start)/60)

## **Predict**

In [4]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

from myutils import transform
from model import Unet

# Seed the RNGs for reproducibility.
torch.cuda.manual_seed(777)
torch.manual_seed(777)

img_folder = './data/test/images/'
pred_folder = './data/test/unet_pred/'
img_dir = os.listdir(img_folder)
img_list = sorted(os.path.join(img_folder, k) for k in img_dir)

unet = Unet(3, 2).cuda()
unet.load_state_dict(torch.load('trained_models/unet_adam_focal_100.pkl'))
# Inference mode: BatchNorm/Dropout layers (if Unet has any) must not use
# training-time behaviour such as per-batch statistics.
unet.eval()

# save() fails if the output directory does not exist yet.
os.makedirs(pred_folder, exist_ok=True)

with torch.no_grad():
    for file in img_list:
        # Close the source file; resize() returns an independent image.
        with Image.open(file) as src:
            img = src.resize((512, 512))
        img = transform(img).cuda().unsqueeze(0)
        pred = unet(img)
        # NOTE(review): argmin (not argmax) over the class axis. Assuming the model
        # outputs 2 channels (Unet(3, 2)), this writes 255 where channel 0 scores
        # higher than channel 1 — confirm this output polarity is intended.
        pred = torch.argmin(pred, 1)
        pred = pred.squeeze().cpu().numpy()
        pred = np.uint8(pred * 255)
        pred_img = Image.fromarray(pred)
        # NOTE(review): output keeps the input's extension but is always PNG-encoded.
        img_name = os.path.join(pred_folder, 'y_' + os.path.basename(file))
        pred_img.save(img_name, 'png')