Document for Model training



Importing packages

In [2]:
import matplotlib.pyplot as plt
from synDataFunctionality.TreeLib import Tree, genTree
import numpy as np
from synDataFunctionality.genInputFromLabel import labelToInput


Setting Constants

In [3]:
startX = 5
startY = 360
startAngle = 0
starWidth = 15
stopWidth = 3
startLength = 20
bifurcProb = 0.3


Make some trees 

In [4]:
from synDataFunctionality.saveSynData import genSynDat
import torch.utils.data as td
from DataLoaders import SynData

Make and save some synthetic data

In [5]:
lst = [startX, startY, starWidth, startLength, startAngle, stopWidth]
num = 9

# make num samples
genSynDat("SynDat/SynInput", "SynDat/SynLabel", lst, (736, 736), num)


Test that we can make a dataSet and it outputs data as expected

In [6]:
SynDataSet = SynData("SynDat/synInput", "SynDat/synLabel")

#Test we can retrieve data from Dataset
test, lab = SynDataSet[0]
print(test.shape)
print(lab.shape)

torch.Size([1, 736, 736])
torch.Size([1, 736, 736])


Visualize synthetic generated data

In [7]:
import torchvision
"""
#test that dataloader works, and show images
test_loader = td.DataLoader(SynDataSet, batch_size=2)

testIter = iter(test_loader)
for i in range(len(testIter)):
    imgs, labs = testIter.next()
    grid = torchvision.utils.make_grid(imgs) #.numpy()[0] hack to show tensor in plt
    plt.imshow(grid.numpy()[0], cmap="gray", vmin=0, vmax=255)
    plt.show()
    lab_grid = torchvision.utils.make_grid(labs)
    plt.imshow(lab_grid.numpy()[0], cmap="gray", vmin=0, vmax=1)
    plt.show()"""

'\n#test that dataloader works, and show images\ntest_loader = td.DataLoader(SynDataSet, batch_size=2)\n\ntestIter = iter(test_loader)\nfor i in range(len(testIter)):\n    imgs, labs = testIter.next()\n    grid = torchvision.utils.make_grid(imgs) #.numpy()[0] hack to show tensor in plt\n    plt.imshow(grid.numpy()[0], cmap="gray", vmin=0, vmax=255)\n    plt.show()\n    lab_grid = torchvision.utils.make_grid(labs)\n    plt.imshow(lab_grid.numpy()[0], cmap="gray", vmin=0, vmax=1)\n    plt.show()'

Check that transformations can be applied to dataSet

In [8]:
#Test transformations work correctly on dataloader:
from torchvision.transforms import RandomVerticalFlip, RandomHorizontalFlip, ColorJitter, CenterCrop, Normalize
from torchvision.transforms.functional import rotate
import torchvision

t_both = [RandomHorizontalFlip(p=0.5),  RandomVerticalFlip(p=0.5)]
#Maybe normalize imgs automatically in dataSet

In [9]:
TransDataSet = SynData("SynDat/synInput", "SynDat/synLabel", transforms_both=t_both)#, transforms_train=t_dat)


In [10]:
import torchvision
import numpy as np
"""
#test that dataloader works, and show images
trans_loader = td.DataLoader(TransDataSet, batch_size=2)

transIter = iter(trans_loader)
for i in range(len(transIter)):
    imgs, labs = transIter.next()
    grid = torchvision.utils.make_grid(imgs).numpy()[0]
    #print(np.amax(np.array(imgs)))
    #print(np.amin(np.array(imgs)))
    plt.imshow(grid+1, cmap="gray", vmin=0, vmax=255)
    plt.show()
    lab_grid = torchvision.utils.make_grid(labs).numpy()[0]
    plt.imshow(lab_grid, cmap="gray", vmin=0, vmax=1)
    plt.show()"""

'\n#test that dataloader works, and show images\ntrans_loader = td.DataLoader(TransDataSet, batch_size=2)\n\ntransIter = iter(trans_loader)\nfor i in range(len(transIter)):\n    imgs, labs = transIter.next()\n    grid = torchvision.utils.make_grid(imgs).numpy()[0]\n    #print(np.amax(np.array(imgs)))\n    #print(np.amin(np.array(imgs)))\n    plt.imshow(grid+1, cmap="gray", vmin=0, vmax=255)\n    plt.show()\n    lab_grid = torchvision.utils.make_grid(labs).numpy()[0]\n    plt.imshow(lab_grid, cmap="gray", vmin=0, vmax=1)\n    plt.show()'

Experiment with Basic Unet (to test whether it works)

Import packages and files for Unet and training

In [11]:
from Unet.UNetBasic import UnetBasic
import torch
import torch.optim as optim
import torch.nn as nn
import torch.cuda


In [12]:
#make Unet

net1 = UnetBasic()

#Try to use cuda machine

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#check which device we are on

print(device)


cpu


Make training and testing images. Apply random transformations to them:

In [13]:
#note, we comment our genSynDat, as otherwise we kill previous generated data.

#Make 300 samples for training data:
#genSynDat("SynDat/SynInput", "SynDat/SynLabel", lst, (736, 736), 300)
trainingData = SynData("SynDat/SynInputTrain", "SynDat/SynLabelTrain", t_both)
#Make 50 samples as test data:
#genSynDat("SynDat/SynInputTest", "SynDat/SynLabelTest", lst, (736, 736), 50)
testData = SynData("SynDat/SynInputTest", "SynDat/SynLabelTest", t_both)

trainLoader = td.DataLoader(trainingData, shuffle=True)
testLoader = td.DataLoader(testData, shuffle=True)

Training our basic Unet

In [14]:
from trainingFunctionality import trainLoop

net1.to(device) #For now is cpu, but will hopefully be CUDA later
net1 = net1.float() #hack that should remove float error

#Adam for now
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.BCELoss()
criterion.to(device)

#Training, pray for me
trainLoss, testLoss , net = trainLoop(net1, optimizer, criterion, device, 2, trainLoader, testLoader, print_interv=5)

[1,     5] loss: 0.131
[1,    10] loss: 0.115


KeyboardInterrupt: 

Kill all files created in this session, such to diminish clutter

In [18]:
#kill generated files such that we diminish clutter
from synDataFunctionality.saveSynData import order_66

#order_66("SynDat/synInputTrain", "SynDat/synLabelTrain")
