In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchsummary import summary
from model import FCN32s, UNet
from dataset import P2_DATA

# --- Training hyper-parameters ---
BATCH_SIZE = 4  # Don't use 8 -- NOTE(review): reason not recorded; presumably GPU memory
NUM_OF_WORKER = 0  # DataLoader worker processes (0 = load batches in the main process)
INPUT_SIZE = 512  # input height/width passed to the model summary

# --- Per-channel RGB normalization (the standard ImageNet mean / std) ---
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# --- Output ---
# Kept as a str with a trailing '/' because checkpoint paths are built by string concatenation.
CKPT_DIR = '../ckpt_p2/'

In [3]:
def _make_transform():
    """Return the image preprocessing pipeline: PIL image -> tensor -> per-channel normalization."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ])


# Both splits currently share the same augmentation-free preprocessing.
train_transform = _make_transform()
valid_transform = _make_transform()

# Load dataset
print("loading training dataset.....")
trainset = P2_DATA(root='../p2_data/train/', transform=train_transform)
# validset = P2_DATA(root='../p2_data/validation/', transform=valid_transform)
print("Complete image loading")

loading training dataset.....
('../p2_data/train\\0000_sat.jpg', '../p2_data/train\\0000_mask.png')
('../p2_data/train\\0001_sat.jpg', '../p2_data/train\\0001_mask.png')
('../p2_data/train\\0002_sat.jpg', '../p2_data/train\\0002_mask.png')
('../p2_data/train\\0003_sat.jpg', '../p2_data/train\\0003_mask.png')
('../p2_data/train\\0004_sat.jpg', '../p2_data/train\\0004_mask.png')
('../p2_data/train\\0005_sat.jpg', '../p2_data/train\\0005_mask.png')
('../p2_data/train\\0006_sat.jpg', '../p2_data/train\\0006_mask.png')
('../p2_data/train\\0007_sat.jpg', '../p2_data/train\\0007_mask.png')
('../p2_data/train\\0008_sat.jpg', '../p2_data/train\\0008_mask.png')
('../p2_data/train\\0009_sat.jpg', '../p2_data/train\\0009_mask.png')
('../p2_data/train\\0010_sat.jpg', '../p2_data/train\\0010_mask.png')
('../p2_data/train\\0011_sat.jpg', '../p2_data/train\\0011_mask.png')
('../p2_data/train\\0012_sat.jpg', '../p2_data/train\\0012_mask.png')
('../p2_data/train\\0013_sat.jpg', '../p2_data/train\\0013_m

In [4]:
# Seed torch's global RNG up front. The shuffling sampler draws from it when iteration
# starts, and later cells (e.g. model construction) presumably draw from it too.
torch.manual_seed(123)

# Shuffled mini-batches for training.
trainset_loader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_OF_WORKER,
)
# validset_loader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_OF_WORKER)

# Pick the compute device: CUDA when available, CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Device used:', device)

Device used: cuda


In [5]:
def save_checkpoint(checkpoint_path, model, optimizer):
    """Serialize model and optimizer state to ``checkpoint_path`` with ``torch.save``.

    The saved object is a dict with keys ``'state_dict'`` (``model.state_dict()``)
    and ``'optimizer'`` (``optimizer.state_dict()``). The parent directory of
    ``checkpoint_path`` is created if it does not exist yet, so the first save
    into a fresh checkpoint directory does not fail.

    Args:
        checkpoint_path: destination file path (str or path-like).
        model: module whose parameters/buffers are saved.
        optimizer: optimizer whose state is saved alongside the model.
    """
    import os

    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:  # a bare filename has no directory component to create
        os.makedirs(parent_dir, exist_ok=True)

    state = {'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()}
    torch.save(state, checkpoint_path)
    print('model saved to %s' % checkpoint_path)

def train_fcn32s(model, epoch, log_interval=100):
    """Train an FCN-32s segmentation model on the global ``trainset_loader``.

    Uses Adam (lr=1e-4) and ``nn.CrossEntropyLoss``. A checkpoint named
    ``p2-<ep>.pth`` is written to ``CKPT_DIR`` after every 5th epoch
    (ep = 0, 5, 10, ...) and once more after the last epoch.

    Args:
        model: network already moved to ``device``. NOTE(review): assumed to
            output per-pixel class logits, with the loader yielding per-pixel
            class-index targets, as ``nn.CrossEntropyLoss`` requires -- confirm.
        epoch: number of passes over ``trainset_loader``. If < 1, nothing is
            trained and no checkpoint is written.
        log_interval: print the loss every ``log_interval`` iterations, counted
            globally across epochs. 0 disables logging.
    """
    #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss()
    model.train()  # set training mode

    # Loop-invariant sizes used only for the progress message.
    num_batches = len(trainset_loader)
    num_samples = len(trainset_loader.dataset)

    iteration = 0
    last_ep = None
    for ep in range(epoch):
        for batch_idx, (data, target) in enumerate(trainset_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_interval and iteration % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, batch_idx * len(data), num_samples,
                    100. * batch_idx / num_batches, loss.item()))
            iteration += 1

        if ep % 5 == 0:
            save_checkpoint(CKPT_DIR + 'p2-%i.pth' % ep, model, optimizer)
        last_ep = ep

    # Save the final model into CKPT_DIR next to the periodic checkpoints.
    # Skip the save when no epoch ran, since there is nothing trained to save.
    if last_ep is not None:
        save_checkpoint(CKPT_DIR + 'p2-%i.pth' % last_ep, model, optimizer)

def train_unet(model, epoch, log_interval=100, use_amp=False):
    """Train a U-Net segmentation model on the global ``trainset_loader``.

    Uses RMSprop (lr=1e-3, momentum=0.9, weight_decay=1e-8) and
    ``nn.CrossEntropyLoss``. A checkpoint named ``p2-<ep>.pth`` is written to
    ``CKPT_DIR`` after every 5th epoch (ep = 0, 5, 10, ...) and once more after
    the last epoch.

    Args:
        model: network already moved to ``device``. NOTE(review): assumed to
            output per-pixel class logits, with the loader yielding per-pixel
            class-index targets, as ``nn.CrossEntropyLoss`` requires -- confirm.
        epoch: number of passes over ``trainset_loader``. If < 1, nothing is
            trained and no checkpoint is written.
        log_interval: print the loss every ``log_interval`` iterations, counted
            globally across epochs. 0 disables logging.
        use_amp: when True and ``device`` is CUDA, run the forward pass and loss
            under ``torch.autocast`` with a dynamic ``GradScaler``. This reduces
            activation memory, which helps with CUDA out-of-memory errors. The
            default (False) keeps plain fp32 training.
    """
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=1e-8, momentum=0.9)
    # No LR scheduler: ReduceLROnPlateau would need a validation Dice score to step on,
    # and there is no validation loop yet.
    amp_enabled = use_amp and device.type == 'cuda'
    # With enabled=False the scaler is a pass-through (scale() returns the loss unchanged).
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    criterion = nn.CrossEntropyLoss()
    model.train()  # set training mode

    # Loop-invariant sizes used only for the progress message.
    num_batches = len(trainset_loader)
    num_samples = len(trainset_loader.dataset)

    iteration = 0
    last_ep = None
    for ep in range(epoch):
        for batch_idx, (data, target) in enumerate(trainset_loader):
            data, target = data.to(device), target.to(device)
            with torch.autocast(device_type=device.type, enabled=amp_enabled):
                output = model(data)
                loss = criterion(output, target)
            optimizer.zero_grad()
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            if log_interval and iteration % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, batch_idx * len(data), num_samples,
                    100. * batch_idx / num_batches, loss.item()))
            iteration += 1

        if ep % 5 == 0:
            save_checkpoint(CKPT_DIR + 'p2-%i.pth' % ep, model, optimizer)
        last_ep = ep

    # Save the final model into CKPT_DIR next to the periodic checkpoints.
    # Skip the save when no epoch ran, since there is nothing trained to save.
    if last_ep is not None:
        save_checkpoint(CKPT_DIR + 'p2-%i.pth' % last_ep, model, optimizer)


In [6]:
# Baseline FCN-32s experiment (currently disabled).
# Uncomment to build it, print a layer summary for a 3 x INPUT_SIZE x INPUT_SIZE input, and train it.
# fcn32s = FCN32s().to(device)
# summary(fcn32s, (3, INPUT_SIZE, INPUT_SIZE))
# train_fcn32s(fcn32s, 100, log_interval=10)

In [7]:
# Build the U-Net on the selected device, then train it for 100 epochs, logging every 10 iterations.
# NOTE(review): this run ended in a CUDA out-of-memory error (see output below). Consider a
# smaller BATCH_SIZE or INPUT_SIZE, or mixed-precision training.
unet = UNet()
unet = unet.to(device)
train_unet(unet, epoch=100, log_interval=10)

RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 16.00 GiB total capacity; 13.05 GiB already allocated; 125.00 MiB free; 13.83 GiB reserved in total by PyTorch)