In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

import os,sys

from modules.UNet import *
from modules.DataSet import *
from modules.Losses import *

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split

import numpy as np

In [2]:
def _build_transforms(resize_to=None):
    """Build the ``[image_transform, label_transform]`` pair handed to the dataset.

    Images: optional resize, ``ToTensor`` and ``Normalize(0.5, 0.5)``.
    Labels: optional resize only (no ``ToTensor``, as in the original pipeline);
    presumably ``UltraSoundDataSet2`` converts labels itself -- TODO confirm.

    Args:
        resize_to: size accepted by ``transforms.Resize``, or ``None`` to keep
            the native resolution.
    """
    image_steps = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
    label_steps = []
    if resize_to is not None:
        # Label masks must be resized with nearest-neighbour: the default
        # bilinear filter invents intermediate values along object boundaries.
        if hasattr(transforms, "InterpolationMode"):
            nearest = transforms.InterpolationMode.NEAREST
        else:  # torchvision < 0.9 expects PIL's integer constant (0 == NEAREST)
            nearest = 0
        image_steps.insert(0, transforms.Resize(resize_to))
        label_steps.append(transforms.Resize(resize_to, interpolation=nearest))
    return [transforms.Compose(image_steps), transforms.Compose(label_steps)]


def _evaluate(net, loader, device):
    """Return the mean per-batch ``DiceLoss`` of ``net`` over ``loader`` as a float.

    Batches are cast exactly like in the training loop so both losses are
    comparable. Runs under ``torch.no_grad()`` and leaves ``net`` in eval mode;
    the caller switches it back to train mode. Returns ``None`` when ``loader``
    yields no batches (e.g. an empty validation split).
    """
    net.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            images = imgs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)
            total += DiceLoss(net(images), labels).item()
            n_batches += 1
    return total / n_batches if n_batches else None


def train(net, device, val_per=0.1, epochs=10, batch_size=10, resize_to=None,
          log_every=10, val_every=50, data_dir=None):
    """Train ``net`` with ``DiceLoss`` on a random train/validation split of the ultrasound data.

    Uses Adam with default hyper-parameters and a ``ReduceLROnPlateau``
    scheduler (patience 10) driven by the mean validation loss.

    Args:
        net: model trained in place; called as ``net(images)``. It is not moved
            to ``device`` here -- the caller must have done that already.
        device: device every batch is moved to.
        val_per: fraction of the dataset held out for validation.
        epochs: number of passes over the training split.
        batch_size: batch size for both the training and validation loaders.
        resize_to: optional size for ``transforms.Resize``; ``None`` keeps the
            native resolution.
        log_every: print the mean training loss every this many optimizer steps.
        val_every: evaluate on the validation split and step the scheduler every
            this many optimizer steps (skipped when the split is empty).
        data_dir: dataset root; defaults to the notebook-global ``root_dir``.
    """
    if data_dir is None:
        data_dir = root_dir  # notebook-level global defined in the setup cell

    dataSet = UltraSoundDataSet2(data_dir, _build_transforms(resize_to))
    nTrain = int(len(dataSet) * (1 - val_per))
    nValid = len(dataSet) - nTrain
    trainSet, validSet = random_split(dataSet, [nTrain, nValid])

    train_loader = DataLoader(trainSet, batch_size=batch_size, shuffle=True, num_workers=4)
    # Order is irrelevant for an averaged validation loss, so don't shuffle.
    valid_loader = DataLoader(validSet, batch_size=batch_size, shuffle=False, num_workers=4)

    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10)

    np.set_printoptions(precision=2)
    running_loss_seg = 0.0
    step = 0  # optimizer steps taken so far, counted across epochs

    for epoch in range(epochs):
        net.train()
        for imgs, labels in train_loader:
            images = imgs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)

            seg_loss = DiceLoss(net(images), labels)
            optimizer.zero_grad()
            seg_loss.backward()
            optimizer.step()

            running_loss_seg += seg_loss.item()
            step += 1

            # Checked after the increment, so every window holds exactly
            # `log_every` batches and `step` is the true number of steps.
            if step % log_every == 0:
                print()
                print('[%d, %5d] loss: %.3f' % (epoch + 1, step, running_loss_seg / log_every))
                running_loss_seg = 0.0

            if step % val_every == 0:
                val_loss = _evaluate(net, valid_loader, device)
                net.train()
                if val_loss is not None:
                    print('[%d, %5d] validation loss: %.3f' % (epoch + 1, step, val_loss))
                    scheduler.step(val_loss)

In [3]:
# Dataset root. Set the VESSEL_DATA_DIR environment variable to point elsewhere
# instead of editing the notebook; the default is the original author's layout.
root_dir = os.path.expanduser(
    os.environ.get("VESSEL_DATA_DIR", "~/projects/ma/data/phantomDataset/vessel_subset")
)

# Seed torch's global RNG early so weight initialisation, the train/validation
# random_split and loader shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet = UNet(init_features=64).to(device)

In [4]:
# Interrupting the kernel (Ctrl-C / "Interrupt") stops training early while
# keeping the kernel and the partially trained `unet` usable. sys.exit() here
# would not stop anything useful in a notebook: IPython merely reports the
# resulting SystemExit as an exception.
try:
    train(unet, device, val_per=0.1, epochs=3, batch_size=2, resize_to=None)
except KeyboardInterrupt:
    print('Training interrupted; `unet` keeps the weights reached so far.')


[1,    10] loss: 0.813

[1,    20] loss: 0.698

[2,    30] loss: 0.616

[2,    40] loss: 0.524

[3,    50] loss: 0.425
[3,    50] validation loss: 0.381

[3,    60] loss: 0.334
