In [1]:
from torch.utils.tensorboard import SummaryWriter
import torch 
import impaintingLib as imp

### Visualization :
# - plot_original_img
# - plot_altered_img
# - plot_res_img
# - plot_all_img
# - board_plot_last_img
# - board_loss_train
# - board_loss_test
# - full_board

### Alteration :
# - squareMask

### Models : 
# - AutoEncoder
# - UNet
# - UNetPartialConv
# - SubPixelNetwork

### Loss : 
# - torch.nn.L1Loss()
# - torch.nn.MSELoss()   (L2 loss; PyTorch has no torch.nn.L2Loss)
# - perceptual_loss
# - totalVariationLoss

# Pick the GPU when one is visible, otherwise fall back to the CPU.
# Computed once and reused below instead of re-querying CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Is GPU available ?", device.type == 'cuda')

# Seed torch's default RNG (all devices) before the model is built below,
# so weight initialisation is reproducible between runs.
torch.manual_seed(0)

# GPU memory report, in MiB. Only meaningful with CUDA: on a CPU-only
# machine the torch.cuda counters simply report 0, which is misleading.
if device.type == 'cuda':
    # Give cached-but-unused blocks back before measuring.
    torch.cuda.empty_cache()
    print("Mémoire allouée : ", torch.cuda.memory_allocated() / 1024**2)
    print("Mémoire réservé : ", torch.cuda.memory_reserved() / 1024**2)

# -------------- Parameters

# Face dataset, already split into training and test loaders.
trainloader, testloader = imp.data.getFaces()

# Run identifier: passed to Visu (and to the commented-out
# model_save/model_load calls below).
runName   = "test2"
model     = imp.model.UNet().to(device)

optimizer  = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.001)
# Losses handed to imp.process.train; presumably combined there
# (summed/weighted) -- confirm the weighting in imp.process.train.
criterions =  [imp.loss.totalVariationLoss,
               imp.loss.perceptual_loss,
               torch.nn.L1Loss()]

# Input corruption applied to each batch: square masks. min_cut/max_cut
# presumably bound the square size -- confirm in imp.mask.Alter.
alterFunc = imp.mask.Alter(min_cut=4, max_cut=60).squareMask

# -------------- Process

# --- Train
visu      = imp.utils.Visu(runName = runName)
visuFuncs = [visu.board_loss_train]

imp.process.train(model, optimizer, trainloader, criterions, epochs=1, alter=alterFunc, visuFuncs=visuFuncs)
# imp.process.model_save(model,runName)

# --- Test
# A fresh Visu (same run name) is created for the test phase.
visu      = imp.utils.Visu(runName = runName)
visuFuncs = [visu.board_loss_test,
             visu.board_plot_last_img,
             visu.plot_last_img]


# model = imp.process.model_load(model,runName)
imp.process.test(model, testloader, alter=alterFunc, visuFuncs=visuFuncs)

# -------------- Display Model

# Tensor Board Model (uncomment to export the network graph).
# Uses `device` rather than `.cuda()` so it also works on CPU, and the
# context manager guarantees the writer is closed.
# with SummaryWriter("runs/" + runName) as writer:
#     example_input, _ = next(iter(trainloader))
#     writer.add_graph(model, example_input.to(device))

Is GPU available ? True
Mémoire allouée :  57.0078125
Mémoire réservé :  62.0


  0%|          | 0/375 [00:00<?, ?it/s]

KeyboardInterrupt: 