In [1]:
import torch
import torch.nn as nn
from simple_UNet import UNet
import os
from hyperparameters import *
from DiceBCEloss import DiceBCELoss
import time
import torch.optim as optim
# Seed PyTorch's global RNG so that random tensors created below (and any
# randomly initialised weights) are reproducible after a kernel restart.
torch.manual_seed(0)

<torch._C.Generator at 0x111501df0>

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Build the project-local UNet for 3-channel input and a single output channel
# (3x3 kernels, padding 1, stride 1), then place it on the selected device.
model = UNet(
    in_channels=3,
    out_channels=1,
    n_class=1,
    kernel_size=3,
    padding=1,
    stride=1,
)
model = model.to(device)

In [4]:
# Smoke test: push one random batch (3 images, 3 channels, 1024x1024) through
# the network without tracking gradients, just to check the forward pass runs.
# The input is created directly on `device` because the model was moved there;
# a CPU tensor fed to a CUDA model raises a device-mismatch RuntimeError.
x = torch.randn(size=(3, 3, 1024, 1024), dtype=torch.float32, device=device)
with torch.no_grad():
    out = model(x)

In [None]:
# Show the shape of the smoke-test output produced by the forward pass above.
print(out.shape)

In [11]:
# Schedule bookkeeping derived from the hyperparameters.
# n_iters: number of whole batches per epoch (fractional remainder truncated).
n_iters = int(train_set_size / batch_size)
iterations = epochs * n_iters
# Passed to StepLR as its step_size; the scheduler is stepped once per training
# batch, so this corresponds to a decay every two epochs.
step_size = 2*n_iters

# Output directory for this run, keyed by the main hyperparameters.
# makedirs(..., exist_ok=True) creates './results_dice' and the run folder in
# one call and avoids the check-then-create race of exists() + mkdir().
save_PATH = f'./results_dice/{epochs}epochs_{lr}lr_{batch_size}batch'
os.makedirs(save_PATH, exist_ok=True)

In [6]:
# Make sure every floating-point parameter and buffer of the model is float32.
model = model.float()

# Training objective: the project-local DiceBCELoss (presumably Dice + BCE, per
# its name). nn.BCEWithLogitsLoss() was the previously tried alternative.
loss_fn = DiceBCELoss()

# Adam with the learning rate from the hyperparameters; StepLR multiplies the
# learning rate by 0.95 every `step_size` scheduler steps.
opt = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, gamma=0.95, step_size=step_size)

In [9]:
def _to_float(value):
    """Convert a scalar metric (Python number or one-element tensor) to a float."""
    return value.item() if isinstance(value, torch.Tensor) else float(value)


def _drop_channel_dim(outputs):
    """Remove the singleton channel axis without collapsing a batch of size 1."""
    if outputs.dim() == 4 and outputs.size(1) == 1:
        return outputs.squeeze(1)
    # Unknown layout: keep the original "squeeze everything" behaviour.
    return torch.squeeze(outputs)


def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs, scheduler=None):
    """Train ``model`` and evaluate it on the validation data after every epoch.

    Every epoch runs a ``'train'`` phase over ``train_dl`` (with parameter
    updates) followed by a ``'valid'`` phase over ``valid_dl`` (no gradients).

    Args:
        model: ``torch.nn.Module`` to optimise. Batches are moved to the device
            of its first parameter (CPU if it has no parameters).
        train_dl: iterable yielding ``(x, y)`` tensor batches for training.
            ``x`` must be 4-D; it is permuted with ``(0, 3, 2, 1)`` before the
            forward pass (presumably channels-last -> channels-first; confirm
            against the dataset).
        valid_dl: iterable yielding ``(x, y)`` batches for validation.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        optimizer: optimizer holding the parameters of ``model``.
        acc_fn: callable ``acc_fn(outputs, targets, n)`` returning a scalar
            (number or one-element tensor); ``n`` is the number of samples in
            the current batch.
        epochs: number of epochs to run.
        scheduler: optional learning-rate scheduler, stepped once per training
            batch. When omitted, a notebook-level ``lr_scheduler`` is used if
            one exists, which keeps existing calls working unchanged.

    Returns:
        ``(train_loss, valid_loss, accuracy)``: three lists of Python floats
        with one entry per epoch. ``accuracy`` is the validation accuracy.
    """
    if scheduler is None:
        # Backward compatibility: earlier versions read the global directly.
        scheduler = globals().get('lr_scheduler')

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    start = time.time()

    train_loss, valid_loss, accuracy = [], [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            model.train(is_train)  # training mode vs. evaluation mode
            dataloader = train_dl if is_train else valid_dl

            running_loss = 0.0
            running_acc = 0.0
            n_seen = 0
            step = 0

            for x, y in dataloader:
                x = torch.permute(x, (0, 3, 2, 1)).to(device)
                targets = y.to(device).float()
                # Use the real batch length: the last batch may be shorter
                # than the loader's nominal batch size.
                n_batch = x.size(0)
                step += 1

                if is_train:
                    optimizer.zero_grad()

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = _drop_channel_dim(model(x))
                    loss = loss_fn(outputs.float(), targets)

                if is_train:
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

                # Accumulate plain floats so no autograd history is kept alive
                # across iterations and the returned curves are plot-ready.
                acc = _to_float(acc_fn(outputs.detach(), targets, n_batch))
                loss_value = loss.item()

                running_acc += acc * n_batch
                running_loss += loss_value * n_batch
                n_seen += n_batch

                if step % 10 == 0:
                    if torch.cuda.is_available():
                        alloc_mb = torch.cuda.memory_allocated() / 1024 / 1024
                    else:
                        alloc_mb = 0.0
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss_value, acc, alloc_mb))

            # Average over the samples actually processed (guard empty loaders).
            denom = max(n_seen, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_acc / denom

            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))

            if is_train:
                train_loss.append(epoch_loss)
            else:
                valid_loss.append(epoch_loss)
                accuracy.append(epoch_acc)

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return train_loss, valid_loss, accuracy

In [10]:
# Run the full training schedule and keep the per-epoch curves for later use.
train_loss, valid_loss, accuracy = train(
    model=model,
    train_dl=train_loader,
    valid_dl=valid_loader,
    loss_fn=loss_fn,
    optimizer=opt,
    acc_fn=acc_fn,
    epochs=epochs,
)

Epoch 0/19
----------
Current step: 10  Loss: 1.256129503250122  Acc: 0.39170533418655396  AllocMem (Mb): 0.0
Current step: 20  Loss: 1.1978378295898438  Acc: 0.4773002564907074  AllocMem (Mb): 0.0
Current step: 30  Loss: 1.3189489841461182  Acc: 0.31131210923194885  AllocMem (Mb): 0.0
Current step: 40  Loss: 1.2599986791610718  Acc: 0.38519516587257385  AllocMem (Mb): 0.0
Current step: 50  Loss: 1.3528696298599243  Acc: 0.26998597383499146  AllocMem (Mb): 0.0
Current step: 60  Loss: 1.3310961723327637  Acc: 0.29310837388038635  AllocMem (Mb): 0.0
Current step: 70  Loss: 1.2778018712997437  Acc: 0.359161376953125  AllocMem (Mb): 0.0


KeyboardInterrupt: 

In [None]:
# Sanity check: number of samples in the dataset behind the training loader.
print(len(train_loader.dataset))

In [None]:
# Inspect a single batch instead of printing every tensor in the loader:
# dumping all batches floods the notebook output and bloats the saved file.
x, y = next(iter(train_loader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)