In [1]:
import impaintingLib as imp
import re

from torch.utils.tensorboard import SummaryWriter
import torch 

# seed == 0 -> random run (data is shuffled and torch is left unseeded);
# any other value -> "deterministic" run seeded with that value.
seed = 1

shuffle = seed == 0
# Scenario line pattern: "model=<m1>,<m2>,... loss=<l1>,<w>*<l2>,...".
# Each group captures one whole whitespace-free, comma-separated list.
# A single `\S+` per list is deliberate: wrapping it in a repeated group
# captures the same text on a match but backtracks exponentially on a
# malformed line, which hangs the notebook instead of failing.
matcher = r"model=(\S+) loss=(\S+)"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainloader, testloader = imp.data.getFaces(batch_size=32,shuffle=shuffle)
# Alteration applied to the images; alternatives that have been used here
# on the same Alter object: .downScale, .squareMask
alterFunc = imp.mask.Alter(min_cut=15, max_cut=45, seed=seed).irregularMask

if not shuffle :
    torch.manual_seed(seed)

def modelMatch(model,in_channels):
    """Build the network designated by ``model`` and move it to ``device``.

    The name is matched by substring, tested in this order:
      * ``"autoenco"`` -> ``imp.model.AutoEncoder``
      * ``"unet"``     -> ``imp.model.UNet``; ``"dilat"`` in the name selects
        ``convType="dilated"`` (default ``"conv2d"``) and ``"partial"``
        selects ``netType="partial"`` (default ``"default"``)
      * ``"pixel"``    -> ``imp.model.SubPixelNetwork``

    Args:
        model: scenario model name (``parse`` passes it lower-cased).
        in_channels: forwarded as the first positional argument of the
            model constructor.

    Returns:
        The constructed model, already moved to the module-level ``device``.

    Raises:
        ValueError: if ``model`` contains none of the recognised keywords.
    """
    if "autoenco" in model :
        return imp.model.AutoEncoder(in_channels).to(device)

    if "unet" in model :
        convType = "dilated" if "dilat" in model else "conv2d"
        netType  = "partial" if "partial" in model else "default"
        return imp.model.UNet(in_channels, convType=convType, netType=netType).to(device)

    if "pixel" in model :
        return imp.model.SubPixelNetwork(in_channels).to(device)

    # Fail with an explicit error: printing and falling through to a
    # `return` of an unassigned local only produced an UnboundLocalError.
    raise ValueError("ERREUR : AUCUN MODELE RECONNUE DANS {}".format(model))

def lossMatch(loss):
    """Return the loss callable designated by ``loss``.

    The name is matched by substring, tested in this order:
      * ``"l1"``       -> a new ``torch.nn.L1Loss()`` instance
      * ``"l2"``       -> a new ``torch.nn.MSELoss()`` instance
      * ``"vgg"``      -> ``imp.loss.perceptualVGG``
      * ``"ae"``       -> ``imp.loss.perceptualAE``
      * ``"totalvar"`` -> ``imp.loss.totalVariation``

    Args:
        loss: scenario loss name (``parse`` passes it lower-cased).

    Returns:
        The loss object selected by the first matching keyword.

    Raises:
        ValueError: if ``loss`` contains none of the recognised keywords.
    """
    if "l1" in loss :
        return torch.nn.L1Loss()
    if "l2" in loss :
        return torch.nn.MSELoss()
    if "vgg" in loss :
        return imp.loss.perceptualVGG
    if "ae" in loss :
        return imp.loss.perceptualAE
    if "totalvar" in loss :
        return imp.loss.totalVariation

    # Fail with an explicit error: printing and falling through to a
    # `return` of an unassigned local only produced an UnboundLocalError.
    raise ValueError("ERREUR : AUCUNE LOSS RECONNUE DANS {}".format(loss))

def parse(expeName = 'basic'):
    """Run every scenario listed in ``./routines/<expeName>.txt``.

    File layout, as consumed here:
      * Line 1: ``<epochs>[,save][,test]``. ``save`` stores the trained
        models with ``imp.process.model_save`` (and is forwarded to
        ``imp.utils.Visu``); ``test`` skips training and loads the models
        with ``imp.process.model_load`` instead.
      * Every later line that is non-empty and does not start with ``#`` is
        one scenario and must match the module-level ``matcher`` pattern:
        ``model=<name>,<name>,... loss=<spec>,<spec>,...`` where a loss spec
        is ``<name>`` (weight 1) or ``<integer weight>*<name>``.

    Every scenario ends with ``imp.process.test`` on ``testloader``.

    Args:
        expeName: base name (without extension) of the routine file.

    Raises:
        ValueError: if a scenario line does not match ``matcher``; ``int``
            raises the same type for a non-integer epoch count or weight.
    """
    count = 0
    path = "./routines/" + expeName + ".txt"
    doSave = False
    onlyTest = False

    with open(path) as file:
        lines = file.readlines()

    for i,line in enumerate(lines):
        line = line.strip()

        # First line: epoch count followed by optional flags.
        if i < 1 :
            # Strip every token so "11, save" is read like "11,save".
            lineSplit = [token.strip() for token in line.split(",")]
            epoch = int(lineSplit[0])
            doSave = "save" in lineSplit
            onlyTest = "test" in lineSplit
            continue

        # Remaining lines: skip blanks and "#" comments, the rest are scenarios.
        if not line or line[0] == "#" :
            continue

        found = re.search(matcher, line)
        if found is None :
            # Report the offending line instead of failing on `None.groups()`.
            raise ValueError(
                "{}: line {} is not a valid scenario "
                "(expected 'model=... loss=...'): {!r}".format(path, i + 1, line)
            )
        model_list,loss_list = found.groups()
        model_list = model_list.split(",")
        loss_list  = loss_list.split(",")

        # Each entry of `losses` is a list [weight, loss callable].
        losses = []
        for spec in loss_list :
            parts = spec.split("*")
            if len(parts) < 2 :
                parts = ["1"] + parts  # no explicit weight -> weight 1

            parts[0] = int(parts[0])
            parts[1] = lossMatch(parts[1].lower())
            losses.append(parts)

        # Every model is built with 3 input channels. (A disabled variant
        # gave the first model of the list 4 input channels.)
        models = [modelMatch(model.lower(),3) for model in model_list]

        runName = "{} {} Epoch{}".format("-".join(map(str, models)), " + ".join(loss_list),epoch)
        # Only the parameters of the last model are handed to the optimizer.
        # NOTE(review): presumably the earlier models stay frozen on purpose — confirm.
        optimizer  = torch.optim.Adam(models[-1].parameters(), lr=1e-3, weight_decay=0.001)
        visu       = imp.utils.Visu(runName = runName, expeName=expeName, save=doSave)

        print("------ Scénario {} : {}".format(count,runName))
        count += 1

        # Train
        if not onlyTest :
            print("- Train")
            # Available hooks, currently all disabled:
            #   visu.board_loss_train, visu.plot_altered_img, visu.plot_res_img
            visuFuncs = []

            # NOTE(review): warm start hard-coded for a scenario with at least
            # two models (the checkpoint names suggest a UNet followed by a
            # pixel-shuffle network) — confirm before running other scenarios;
            # a single-model scenario raises IndexError here.
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            models[0].load_state_dict(torch.load('modelSave/unet_3channels.pth', map_location=device))
            models[1].load_state_dict(torch.load('modelSave/pixelshuffle_3channels.pth', map_location=device))

            imp.process.train(models, optimizer, trainloader, losses, epochs=epoch, alter=alterFunc, visuFuncs=visuFuncs)

            if doSave:
                imp.process.model_save(models,runName)

        # Test-only run: replace the freshly built models by the saved ones.
        if onlyTest :
            models = imp.process.model_load(models,runName)

        print("- Test")

        # Other available hooks: visu.board_loss_test, visu.board_plot_img
        visuFuncs = [visu.plot_altered_img,
                     visu.plot_res_img
                    ]
        imp.process.test(models, testloader, alter=alterFunc, visuFuncs=visuFuncs)

In [2]:
# Run every scenario described in ./routines/best_irregular.txt.
parse("best_irregular")

------ Scénario 0 : UNet(partial conv2d)-PixelShuffle1 L1 + 500*perceptualAE + perceptualVGG + totalvar Epoch11
- Test


  0%|          | 0/39 [00:00<?, ?it/s]