In [1]:
import os 
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from style_transfer_module.stylizationModel import StylizationModel
from color_transfer.color_transfer import histogram_matching
from pipeline import Pipeline
import utils 

import matplotlib.pyplot as plt
from PIL import Image
import sys
from os.path import join, dirname

# Let PIL load image files whose data is truncated instead of raising an
# error while decoding them (some dataset images may be incomplete).
import PIL.ImageFile as ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the torch and numpy RNG seeds so DataLoader shuffling (and any other
# torch/numpy randomness) repeats across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Root folder of the content-image dataset (read with ImageFolder below).
content_path = "./coco/"

# Training batch sizes, and the side length (pixels) that content/style images
# are resized and center-cropped to by the training transforms.
content_batch_size = 2
style_batch_size = 2
content_size = 360
style_size = 360

# NOTE(review): not referenced elsewhere in this notebook; presumably one
# style-loss weight per feature layer -- confirm against the training code.
style_weight = np.array([1e5, 1e5, 1e5, 1e5])
# Backward-compatible alias for the original misspelled name.
style_weigth = style_weight

## Training part

In [3]:
def _square_crop_to_tensor(side):
    """Build a preprocessing pipeline for one training image.

    Resizes the shorter edge to `side`, center-crops to `side` x `side`,
    and converts the PIL image to a tensor.
    """
    steps = [
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


# Training-time preprocessing for content and style images.
transform_content = _square_crop_to_tensor(content_size)
transform_style = _square_crop_to_tensor(style_size)

### Create data for training

In [None]:
# Content images for training. ImageFolder expects images grouped in
# sub-folders under `content_path`.
content_dataset = datasets.ImageFolder(content_path, transform_content)

# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it
# is useless and PyTorch warns about it, so enable it only when using CUDA.
content_loader = DataLoader(
    content_dataset,
    batch_size=content_batch_size,
    num_workers=2,
    shuffle=True,
    pin_memory=(DEVICE == "cuda"),
)

In [4]:
def _make_train_style_loader(root):
    """Shuffled training DataLoader over the ImageFolder at `root`.

    Uses `transform_style`. Pinned memory is enabled only on CUDA, because it
    only helps host->GPU copies.
    """
    dataset = datasets.ImageFolder(root, transform_style)
    return DataLoader(dataset, batch_size=style_batch_size, num_workers=2,
                      shuffle=True, pin_memory=(DEVICE == "cuda"))


base_styles_path = "./base_styles/"
base_styles = _make_train_style_loader(base_styles_path)

# Optional extra style set (disabled):
#add_styles_path = "./pbn/"
#add_styles = _make_train_style_loader(add_styles_path)

test_styles_path = "./test_styles/"
test_styles = _make_train_style_loader(test_styles_path)

# NOTE: the "Create data for evaluation" cell rebinds `base_styles` and
# `test_styles` to evaluation loaders; re-run this cell before training.

### Train model

In [None]:
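# NOTE(review): `content_loader` (built above) is not used anywhere else in
# this notebook, while this call passes `base_styles` for both loader arguments.
# If the first parameter of utils.train_model is the content loader (confirm),
# it should likely be `content_loader`. Also note that the evaluation cells
# rebind `base_styles`, so re-run the training data cells before this one.
#utils.train_model(base_styles, base_styles, dropout=0.2,
#                         checkpoint="./checkpoints/checkpoint.pth", save_dir="./cp_no_rest/", lr=1e-4, sample_alpha=True, 
#                         connections=True, max_iter=80000)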

## Evaluation part

In [5]:
# Evaluation resolutions: square crop side (pixels) for style and content images.
style_eval_size = 400
content_eval_size = 500


def _to_device(tensor):
    """Move a tensor to the notebook-wide DEVICE (looked up at call time)."""
    return tensor.to(DEVICE)


def _add_batch_dim(tensor):
    """Prepend a batch axis of size 1 (C x H x W -> 1 x C x H x W)."""
    return tensor.unsqueeze(0)


def _eval_steps(side):
    """Resize shorter edge to `side`, center-crop to a square, tensorize, move to DEVICE."""
    return [
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        transforms.Lambda(_to_device),
    ]


# NOTE: moving tensors to DEVICE inside a transform means DataLoaders using it
# should keep num_workers=0 (CUDA cannot be re-initialised in forked workers).
style_eval_transform = transforms.Compose(_eval_steps(style_eval_size))

# Content images are transformed one at a time (not through a DataLoader), so
# the batch axis is added by the transform itself.
content_eval_transform = transforms.Compose(
    _eval_steps(content_eval_size) + [transforms.Lambda(_add_batch_dim)]
)

### Create data for evaluation

In [6]:
# Evaluation style sets: one image per batch, in ImageFolder order (no shuffle).
# NOTE: this rebinds `base_styles` / `test_styles`, replacing the training loaders.
base_styles_path = "./base_styles/"
base_styles = DataLoader(datasets.ImageFolder(base_styles_path, style_eval_transform),
                         batch_size=1)

# Optional extra style set (disabled):
#add_styles_path = "./pbn/"
#add_styles = DataLoader(datasets.ImageFolder(add_styles_path, style_eval_transform),
#                        batch_size=1, shuffle=True)

test_styles_path = "./test_styles/"
test_styles = DataLoader(datasets.ImageFolder(test_styles_path, style_eval_transform),
                         batch_size=1)


def _load_content_image(path):
    """Load the image at `path` as RGB and apply `content_eval_transform`.

    ImageFolder's default loader converts the style images to RGB; do the same
    here so grayscale/RGBA/CMYK content files also produce 3-channel tensors.
    The `with` block closes the file once the pixels have been decoded.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return content_eval_transform(rgb)


content_1 = _load_content_image("./content_variants/1.jpg")
content_2 = _load_content_image("./content_variants/4.jpg")
content_3 = _load_content_image("./content_variants/5.jpg")
content_list = [content_1, content_2, content_3]

### Create pipeline

In [7]:
# Build the stylization model from the saved checkpoint and wrap it in a
# Pipeline on DEVICE.
# NOTE(review): the model is built with dropout=0.2; confirm that build_model
# or Pipeline switches it to eval mode so dropout is inactive at evaluation.
checkpoint_path = "./checkpoints/checkpoint.pth"
model = utils.build_model(
    checkpoint=checkpoint_path,
    connections=True,
    dropout=0.2,
)
pipeline = Pipeline(model, DEVICE, theta=0.2)

### Evaluate pipeline on given data

In [None]:
# Run the pipeline on the content images with the base evaluation styles;
# "./output/" is presumably where utils.eval_pipeline writes results (confirm).
eval_alphas = [0.4, 1.0]
eval_betas = [1.0]
utils.eval_pipeline(
    pipeline,
    content_list,
    base_styles,
    "./output/",
    max_images=2,
    alpha_list=eval_alphas,
    beta_list=eval_betas,
)