In [2]:
import torch
import torch.nn as nn
from simple_UNet import UNet
import os
from hyperparameters import *
# from loss_function import DiceLoss
import time
import torch.optim as optim
torch.manual_seed(0)

<torch._C.Generator at 0x28e346e8890>

In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
model = UNet(in_channels=3,
            out_channels=1,
            n_class=1,
            kernel_size=3,
            padding=1,
            stride=1).to(device)

In [None]:
x = torch.randn(size=(3,3,1024,1024), dtype=torch.float32)
with torch.no_grad():
    out = model(x)

In [None]:
print(out.shape)

In [5]:
n_iters = int(train_set_size / batch_size)
iterations = epochs * n_iters
step_size = 2*n_iters
if not os.path.exists('./results'):
    os.mkdir('./results')
save_PATH = f'./results/{epochs}epochs_{lr}lr_{batch_size}batch'
if not os.path.exists(save_PATH):
    os.mkdir(save_PATH)

In [6]:
model = model.float()
loss_fn = nn.BCEWithLogitsLoss()
opt = torch.optim.Adam(model.parameters(), lr)
lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=0.95)

In [7]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs):
    start = time.time()

    train_loss, valid_loss, accuracy = [], [], []

    best_acc = 0.0

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train(True)  # Set trainind mode = true
                dataloader = train_dl
            else:
                model.train(False)  # Set model to evaluate mode
                dataloader = valid_dl

            running_loss = 0.0
            running_acc = 0.0

            step = 0

            # iterate over data
            for x, y in dataloader:
                x = torch.permute(x, (0, 3, 2, 1))
                step += 1

                # forward pass
                if phase == 'train':
                    # zero the gradients
                    optimizer.zero_grad()
                    outputs = model(x)
                    outputs = torch.squeeze(outputs)
                    y = y.to(torch.float64)
                    loss = loss_fn(outputs, y)

                    # the backward pass frees the graph memory, so there is no 
                    # need for torch.no_grad in this training pass
                    loss.backward()
                    optimizer.step()
                    lr_scheduler.step()

                else:
                    with torch.no_grad():
                        outputs = model(x)
                        outputs = torch.squeeze(outputs)
                        loss = loss_fn(outputs, y)

                # stats - whatever is the phase
                acc = acc_fn(outputs, y, batch_size)

                running_acc  += acc*dataloader.batch_size
                running_loss += loss*dataloader.batch_size 

                if step % 10 == 0:
                    # clear_output(wait=True)
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss, acc, torch.cuda.memory_allocated()/1024/1024))
                    # print(torch.cuda.memory_summary())

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_acc / len(dataloader.dataset)

            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))

            train_loss.append(epoch_loss) if phase=='train' else valid_loss.append(epoch_loss)

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))    
    
    return train_loss, valid_loss, accuracy

In [8]:
train_loss, valid_loss, accuracy = train(model, train_loader, valid_loader, loss_fn, opt, acc_fn, epochs)

Epoch 0/19
----------
Current step: 10  Loss: 0.7036934812373147  Acc: 0.39170533418655396  AllocMem (Mb): 0.0
Current step: 20  Loss: 0.6957467269258985  Acc: 0.4773002564907074  AllocMem (Mb): 0.0
Current step: 30  Loss: 0.7070352376940321  Acc: 0.31131210923194885  AllocMem (Mb): 0.0
Current step: 40  Loss: 0.7005226345828902  Acc: 0.38519516587257385  AllocMem (Mb): 0.0
Current step: 50  Loss: 0.7051764840386567  Acc: 0.26998597383499146  AllocMem (Mb): 0.0
Current step: 60  Loss: 0.7020407322783768  Acc: 0.29310837388038635  AllocMem (Mb): 0.0
Current step: 70  Loss: 0.6979523009094919  Acc: 0.359161376953125  AllocMem (Mb): 0.0
Current step: 80  Loss: 0.6952284341750499  Acc: 0.4180564880371094  AllocMem (Mb): 0.0
Current step: 90  Loss: 0.6955208891478847  Acc: 0.35359877347946167  AllocMem (Mb): 0.0
Current step: 100  Loss: 0.693933338435686  Acc: 0.3967247009277344  AllocMem (Mb): 0.0
Current step: 110  Loss: 0.6930694219036001  Acc: 0.5777999758720398  AllocMem (Mb): 0.0
trai

Current step: 30  Loss: 0.63811804647612  Acc: 0.6886879205703735  AllocMem (Mb): 0.0
Current step: 40  Loss: 0.6671070437529124  Acc: 0.6148048639297485  AllocMem (Mb): 0.0
Current step: 50  Loss: 0.6205195939628538  Acc: 0.7300140261650085  AllocMem (Mb): 0.0
Current step: 60  Loss: 0.6294214964326329  Acc: 0.706891655921936  AllocMem (Mb): 0.0
Current step: 70  Loss: 0.6562089606923109  Acc: 0.640838623046875  AllocMem (Mb): 0.0
Current step: 80  Loss: 0.6805176802780579  Acc: 0.5819435119628906  AllocMem (Mb): 0.0
Current step: 90  Loss: 0.6536180587218496  Acc: 0.6464012265205383  AllocMem (Mb): 0.0
Current step: 100  Loss: 0.6716620330607839  Acc: 0.6032752990722656  AllocMem (Mb): 0.0
Current step: 110  Loss: 0.6824959144572403  Acc: 0.5777999758720398  AllocMem (Mb): 0.0
train Loss: 0.6501 Acc: 0.6566556917519129
Current step: 10  Loss: 0.6819446682929993  Acc: 0.5792831182479858  AllocMem (Mb): 0.0
Current step: 20  Loss: 0.6486272215843201  Acc: 0.6570907831192017  AllocMem (

Current step: 50  Loss: 0.6014870281376716  Acc: 0.7300140261650085  AllocMem (Mb): 0.0
Current step: 60  Loss: 0.6147271456977705  Acc: 0.706891655921936  AllocMem (Mb): 0.0
Current step: 70  Loss: 0.6529342320391152  Acc: 0.640838623046875  AllocMem (Mb): 0.0
Current step: 80  Loss: 0.6871248438205839  Acc: 0.5819435119628906  AllocMem (Mb): 0.0
Current step: 90  Loss: 0.6497015093095342  Acc: 0.6464012265205383  AllocMem (Mb): 0.0
Current step: 100  Loss: 0.674795587146491  Acc: 0.6032752990722656  AllocMem (Mb): 0.0
Current step: 110  Loss: 0.6896830193748429  Acc: 0.5777999758720398  AllocMem (Mb): 0.0
train Loss: 0.6438 Acc: 0.6566556917519129
Current step: 10  Loss: 0.6888530254364014  Acc: 0.5792831182479858  AllocMem (Mb): 0.0
Current step: 20  Loss: 0.6434537172317505  Acc: 0.6570907831192017  AllocMem (Mb): 0.0
Current step: 30  Loss: 0.6091980338096619  Acc: 0.7157996892929077  AllocMem (Mb): 0.0
valid Loss: 0.6312 Acc: 0.6514777348499106
Epoch 13/19
----------
Current step

KeyboardInterrupt: 

In [None]:
print(len(train_loader.dataset))

In [None]:
for x,y in train_loader:
    print(x)