In [1]:
import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.functional import mse_loss
import torchvision.transforms as transforms
from torchvision.models import vgg16, VGG16_Weights
from torchvision.models.feature_extraction import create_feature_extractor

In [2]:
# Number of styles the network is conditioned on. The lowercase name is kept
# as an alias of the constant so both spellings stay defined.
NUM_STYLE = 10
num_style = NUM_STYLE

# Per-channel (R, G, B) statistics applied to every input image; presumably the
# ImageNet statistics expected by the pretrained VGG -- TODO confirm.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

# Forward transform: (x - mean) / std.
normalize = transforms.Normalize(mean=MEAN, std=STD)

# Inverse transform: normalising with mean' = -mean/std and std' = 1/std
# yields x * std + mean, i.e. it undoes `normalize`.
denormalize = transforms.Normalize(
    mean=[-channel_mean / channel_std for channel_mean, channel_std in zip(MEAN, STD)],
    std=[1 / channel_std for channel_std in STD],
)

In [3]:
# --- data locations -------------------------------------------------------
content_path = '../Dashtoon Task/AdaIn/dataset/content/train_2.5k'
style_path = '../Dashtoon Task/MSGNet/data/artist/'

# --- optimisation settings ------------------------------------------------
iterations = 40000   # number of optimiser steps
batch_size = 2       # images per step (used for both content and style loaders)
lr = 1e-4            # Adam learning rate

# --- loss weighting (the content term has an implicit weight of 1) ---------
style_weight = 5.0
tv_weight = 1e-5

# --- inference ------------------------------------------------------------
# -1 means "render every style"; a value in [0, NUM_STYLE) selects one style.
style_index = -1

In [4]:
def get_transforms(imsize = None, cropsize = None, cencrop = False):
    """Build the preprocessing pipeline applied to a PIL image.

    Args:
        imsize: if truthy, passed to ``transforms.Resize`` as the first step.
        cropsize: if truthy, the image is cropped to this size.
        cencrop: selects ``CenterCrop`` (True) or ``RandomCrop`` (False);
            only consulted when ``cropsize`` is truthy.

    Returns:
        A ``transforms.Compose`` that always ends with ``ToTensor`` followed by
        the module-level ``normalize`` transform.
    """
    steps = []

    if imsize:
        steps.append(transforms.Resize(imsize))

    if cropsize:
        crop_cls = transforms.CenterCrop if cencrop else transforms.RandomCrop
        steps.append(crop_cls(cropsize))

    steps.extend([transforms.ToTensor(), normalize])
    return transforms.Compose(steps)

In [5]:
class ImageDataset:
    """Map-style dataset over the ``*.jpg`` files directly inside a directory.

    Files are sorted by path so that an index always refers to the same image;
    the training loop relies on this when it turns the index of a style image
    into its style code.
    """

    def __init__(self, dir_path):
        """Collect the image paths.

        Args:
            dir_path: directory to scan (``str`` or ``pathlib.Path``). Only
                files matching ``*.jpg`` at the top level are picked up.
        """
        # Path(...) makes plain strings acceptable as well as Path objects.
        self.images = sorted(Path(dir_path).glob('*.jpg'))

    def __len__(self):
        """Return the number of images found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one image.

        Returns:
            ``(image, index)`` where ``image`` is an RGB ``PIL.Image`` and
            ``index`` is the position that was requested.
        """
        # `convert` returns an independent image, so the handle opened on the
        # file can be released right away instead of lingering until garbage
        # collection.
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')
        return img, index
    

class DataProcessor:
    """Collate callable that turns ``(PIL image, index)`` pairs into a batch."""

    def __init__(self, imsize = 256, cropsize = 240, cencrop = False):
        """Build the per-image transform; arguments are forwarded to ``get_transforms``."""
        self.transforms = get_transforms(imsize=imsize,
                                         cropsize=cropsize,
                                         cencrop=cencrop)

    def __call__(self, batch):
        """Transform every image of ``batch`` and stack the results.

        Args:
            batch: iterable of ``(image, index)`` pairs.

        Returns:
            ``(inputs, indices)``: the stacked transformed images and a tuple
            with the corresponding indices, in the same order.
        """
        images, indices = zip(*batch)
        tensors = [self.transforms(image) for image in images]
        return torch.stack(tensors), indices

In [6]:
# Names of the loss-network outputs compared by the content loss and by the
# style (Gram matrix) loss respectively.
content_nodes = ['relu_3_3']
style_nodes = ['relu_1_2', 'relu_2_2', 'relu_3_3', 'relu_4_2']

# Maps a layer index inside `vgg16(...).features` to the name under which the
# feature extractor returns that layer's output.
# NOTE(review): index 22 of torchvision's VGG16 `features` appears to be the
# ReLU following conv4_3, so the label 'relu_4_2' looks like a misnomer. It is
# kept unchanged because it is only a dictionary key shared with `style_nodes`.
return_nodes = {3: 'relu_1_2',
                8: 'relu_2_2',
                15: 'relu_3_3',
                22: 'relu_4_2'}

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on CUDA-less machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# A single collate callable (resize to 256, random 240 crop, tensor + normalise).
data_processor = DataProcessor(imsize=256, cropsize=240, cencrop=False)

# Raw PIL datasets; preprocessing happens in `data_processor` at batch time.
style_dataset = ImageDataset(dir_path=Path(style_path))
content_dataset = ImageDataset(dir_path=Path(content_path))

In [8]:
# Both loaders use the same batch size, shuffling and collate callable; only
# the underlying dataset differs.
content_dataloader = DataLoader(
    dataset=content_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_processor,
)

style_dataloader = DataLoader(
    dataset=style_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_processor,
)


In [9]:
# Only the `.features` part of the pretrained VGG16 is kept; it serves as the
# backbone of the perceptual loss network.
vgg = vgg16(
    weights=VGG16_Weights.IMAGENET1K_V1,
).features

In [10]:
# The loss network is a fixed feature extractor: exclude its weights from
# gradient computation.
for frozen_weight in vgg.parameters():
    frozen_weight.requires_grad_(False)

In [11]:
# Wrap the VGG trunk so a forward pass returns a dict holding the activations
# named in `return_nodes`, then move it to the training device.
loss_network = create_feature_extractor(vgg, return_nodes=return_nodes)
loss_network = loss_network.to(device)

In [12]:
class CIN(nn.Module):
    """Conditional instance normalisation.

    The input is instance-normalised without learned affine parameters and then
    scaled and shifted per channel. One ``(scale, offset)`` row is learned per
    style; ``style_codes`` selects (or blends) the rows used for each sample.
    """

    def __init__(self, num_style, ch):
        """
        Args:
            num_style: number of styles, i.e. rows in the parameter tables.
            ch: number of channels of the feature map being normalised.
        """
        super().__init__()
        self.normalize = nn.InstanceNorm2d(ch, affine=False)
        table_shape = (1, num_style, ch)
        # Offsets start near 0 and scales near 1, so the layer is initially
        # close to a plain instance norm.
        self.offset = nn.Parameter(0.01 * torch.randn(*table_shape))
        self.scale = nn.Parameter(1 + 0.01 * torch.randn(*table_shape))

    def forward(self, x, style_codes):
        """Normalise ``x`` and apply the style-dependent affine transform.

        Args:
            x: 4-D feature map ``(batch, channels, height, width)``.
            style_codes: per-sample style weights; must broadcast against the
                ``(1, num_style, channels)`` parameter tables (the notebook
                passes ``(batch, num_style, 1)``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, _height, _width = x.size()

        normalized = self.normalize(x)

        # Weighted sum over the style axis -> one gain/bias per (sample, channel).
        gamma = (self.scale * style_codes).sum(dim=1).view(batch, channels, 1, 1)
        beta = (self.offset * style_codes).sum(dim=1).view(batch, channels, 1, 1)

        return normalized * gamma + beta


class ConvWithCIN(nn.Module):
    """Reflection-padded convolution followed by conditional instance norm
    and an optional ReLU."""

    def __init__(self, num_style, in_ch, out_ch, stride, activation, ksize):
        """
        Args:
            num_style: number of styles handled by the CIN layer.
            in_ch: input channels of the convolution.
            out_ch: output channels of the convolution (and CIN width).
            stride: convolution stride.
            activation: ``"relu"`` or ``"linear"`` (no non-linearity).
            ksize: square kernel size; the input is reflection-padded by
                ``ksize // 2`` on every side before the convolution.

        Raises:
            ValueError: if ``activation`` is not one of the supported names.
        """
        super(ConvWithCIN, self).__init__()

        # Fail at construction time instead of leaving `self.activation` unset
        # and crashing later inside `forward`.
        if activation not in ("relu", "linear"):
            raise ValueError(
                f"unsupported activation {activation!r}; expected 'relu' or 'linear'")

        self.padding = nn.ReflectionPad2d(ksize // 2)
        self.conv = nn.Conv2d(in_ch, out_ch, ksize, stride)

        self.cin = CIN(num_style, out_ch)

        # nn.Identity (rather than a lambda) keeps the module picklable and
        # visible in the printed model structure; it adds no parameters, so
        # existing state dicts still load.
        self.activation = nn.ReLU() if activation == "relu" else nn.Identity()

    def forward(self, x, style_codes):
        """Apply padding -> conv -> CIN(style_codes) -> activation."""
        x = self.padding(x)
        x = self.conv(x)
        x = self.cin(x, style_codes)
        x = self.activation(x)

        return x


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvWithCIN`` layers (ReLU, then linear) with a skip connection."""

    def __init__(self, num_style, in_ch, out_ch):
        """``in_ch`` must equal ``out_ch`` for the skip addition in ``forward`` to work."""
        super().__init__()

        self.conv1 = ConvWithCIN(num_style, in_ch, out_ch,
                                 stride=1, activation="relu", ksize=3)
        self.conv2 = ConvWithCIN(num_style, out_ch, out_ch,
                                 stride=1, activation="linear", ksize=3)

    def forward(self, x, style_codes):
        """Return ``x + conv2(conv1(x))``, both convolutions conditioned on ``style_codes``."""
        residual = self.conv2(self.conv1(x, style_codes), style_codes)
        return x + residual


class UpsamleBlock(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 ``ConvWithCIN`` with ReLU."""

    def __init__(self, num_style, in_ch, out_ch):
        """
        Args:
            num_style: number of styles handled by the CIN layer.
            in_ch: channels of the incoming feature map.
            out_ch: channels produced by the convolution.
        """
        super().__init__()
        self.conv = ConvWithCIN(num_style, in_ch, out_ch,
                                stride=1, activation="relu", ksize=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, style_codes):
        """Upsample ``x`` first, then convolve the enlarged map."""
        enlarged = self.upsample(x)
        return self.conv(enlarged, style_codes)


class StyleTransferNetwork(nn.Module):
    """Feed-forward image transformation network conditioned on a style code.

    Layout: three convolutions (the last two with stride 2), three residual
    blocks, two 2x upsampling blocks and a final 9x9 convolution without
    non-linearity. Every layer uses conditional instance normalisation.

    The residual blocks are named ``residual1``, ``residual3`` and
    ``residual5``; the gaps in the numbering are kept because the attribute
    names are part of the state-dict keys.
    """

    def __init__(self, num_style=10):
        """
        Args:
            num_style: number of styles the network can render.
        """
        super().__init__()

        # Encoder: 3 -> 32 -> 64 -> 128 channels.
        self.conv1 = ConvWithCIN(num_style, 3, 32, stride=1, activation='relu', ksize=9)
        self.conv2 = ConvWithCIN(num_style, 32, 64, stride=2, activation='relu', ksize=3)
        self.conv3 = ConvWithCIN(num_style, 64, 128, stride=2, activation='relu', ksize=3)

        # Bottleneck.
        self.residual1 = ResidualBlock(num_style, 128, 128)
        self.residual3 = ResidualBlock(num_style, 128, 128)
        self.residual5 = ResidualBlock(num_style, 128, 128)

        # Decoder: 128 -> 64 -> 32 -> 3 channels.
        self.upsampling1 = UpsamleBlock(num_style, 128, 64)
        self.upsampling2 = UpsamleBlock(num_style, 64, 32)

        self.conv4 = ConvWithCIN(num_style, 32, 3, stride=1, activation='linear', ksize=9)

    def forward(self, x, style_codes):
        """Run ``x`` through every stage in order, passing ``style_codes`` to each."""
        stages = (
            self.conv1, self.conv2, self.conv3,
            self.residual1, self.residual3, self.residual5,
            self.upsampling1, self.upsampling2,
            self.conv4,
        )
        for stage in stages:
            x = stage(x, style_codes)
        return x

In [13]:
def calc_content_loss(features, targets, nodes):
    """Sum the MSE between ``features[node]`` and ``targets[node]`` over ``nodes``.

    Args:
        features: mapping from node name to the activations of the generated image.
        targets: mapping from node name to the activations of the content image.
        nodes: names of the entries to compare.

    Returns:
        The summed loss (the integer 0 when ``nodes`` is empty).
    """
    return sum(mse_loss(features[name], targets[name]) for name in nodes)


def gram(x):
    """Return the batched Gram matrix of a 4-D feature map.

    Args:
        x: tensor of shape ``(batch, channels, height, width)``.

    Returns:
        Tensor of shape ``(batch, channels, channels)``: channel-by-channel
        inner products over all spatial positions, divided by ``height * width``.
    """
    batch, channels, height, width = x.size()
    positions = height * width
    flat = x.reshape(batch, channels, positions)
    products = torch.bmm(flat, flat.permute(0, 2, 1))
    return products / positions


def calc_style_loss(features, targets, nodes):
    """Sum the MSE between the Gram matrices of ``features`` and ``targets``.

    Args:
        features: mapping from node name to the activations of the generated image.
        targets: mapping from node name to the activations of the style image.
        nodes: names of the entries to compare.

    Returns:
        The summed loss (the integer 0 when ``nodes`` is empty).
    """
    per_node = (
        mse_loss(gram(features[name]), gram(targets[name]))
        for name in nodes
    )
    return sum(per_node)


def calc_tv_loss(x):
    """Total-variation penalty of a 4-D image batch.

    Adds the mean absolute difference between neighbouring elements along the
    last axis to the mean absolute difference along the second-to-last axis.
    """
    along_last = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs().mean()
    along_second_last = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs().mean()
    return along_last + along_second_last

In [14]:
# Size the per-style parameter tables from the shared constant rather than the
# constructor default, so the model always matches the style codes built from
# NUM_STYLE in the training and inference cells.
model = StyleTransferNetwork(num_style=NUM_STYLE)

In [15]:
# Per-iteration history of every loss term, filled in by the training loop.
losses = {name: [] for name in ('content', 'style', 'tv', 'total')}

# Move the network to the training device and put it in training mode before
# handing its parameters to the optimiser.
model = model.to(device)
model.train()

optimizer = Adam(model.parameters(), lr=lr)

In [28]:
import time


def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a new pass whenever it is exhausted."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Without this guard an empty loader would spin forever.
            raise ValueError("data loader produced no batches")


# One persistent stream per loader. Calling `next(iter(loader))` on every step
# would build a brand-new iterator each time and only ever consume its first
# batch.
content_batches = _endless_batches(content_dataloader)
style_batches = _endless_batches(style_dataloader)

for i in range(1, 1+iterations):
    start_time = time.time()

    content_images, _ = next(content_batches)
    style_images, style_indices = next(style_batches)

    # The final batch of a pass may be short; trim both batches to a common
    # size so that every tensor below agrees on the batch dimension.
    n = min(content_images.size(0), style_images.size(0))
    content_images = content_images[:n]
    style_images = style_images[:n]
    style_indices = style_indices[:n]

    # One-hot style codes of shape (n, NUM_STYLE, 1): the position of a style
    # image in the style dataset is used as its style id.
    style_codes = torch.zeros(n, NUM_STYLE, 1)
    style_codes[torch.arange(n), torch.tensor(style_indices)] = 1

    content_images = content_images.to(device).contiguous()
    style_images = style_images.to(device).contiguous()
    style_codes = style_codes.to(device).contiguous()

    output_images = model(content_images, style_codes)

    content_features = loss_network(content_images)
    style_features = loss_network(style_images)
    output_features = loss_network(output_images)

    style_loss = calc_style_loss(output_features,
                                 style_features,
                                 style_nodes)
    content_loss = calc_content_loss(output_features,
                                     content_features,
                                     content_nodes)
    tv_loss = calc_tv_loss(output_images)

    total_loss = content_loss \
        + style_loss * style_weight \
        + tv_loss * tv_weight

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    losses['content'].append(content_loss.item())
    losses['style'].append(style_loss.item())
    losses['tv'].append(tv_loss.item())
    losses['total'].append(total_loss.item())

    if i % 100 == 0:
        # Duration of this single iteration (start_time is reset every step).
        time_taken = time.time() - start_time
        log = f"iter.: {i}, time taken: {time_taken:.2f} seconds"
        for k, v in losses.items():
            # Average over the most recent (at most 50) recorded values.
            recent = v[-50:]
            avg = sum(recent) / len(recent)
            log += f", {k}: {avg:1.4f}"
        print(log)

    if i % 5000 == 0:
        torch.save({"state_dict": model.state_dict()}, f"model_{i}.ckpt")

    # NOTE: the per-iteration torch.cuda.empty_cache() call was dropped. It
    # only returns cached blocks to the driver, which forces them to be
    # re-allocated on the next step and slows training down.

torch.save({"state_dict": model.state_dict()}, "model.ckpt")

iter.: 100, time taken: 0.28 seconds, content: 13.8616, style: 2.6055, tv: 0.3848, total: 26.8890
iter.: 200, time taken: 0.28 seconds, content: 14.5003, style: 2.8403, tv: 0.4104, total: 28.7019
iter.: 300, time taken: 0.30 seconds, content: 13.5037, style: 2.5887, tv: 0.3987, total: 26.4471
iter.: 400, time taken: 0.29 seconds, content: 15.0221, style: 2.7741, tv: 0.4148, total: 28.8924
iter.: 500, time taken: 0.31 seconds, content: 14.2237, style: 2.5694, tv: 0.4073, total: 27.0707
iter.: 600, time taken: 0.30 seconds, content: 14.8373, style: 2.5156, tv: 0.4166, total: 27.4153
iter.: 700, time taken: 0.29 seconds, content: 13.2620, style: 2.4252, tv: 0.3937, total: 25.3880
iter.: 800, time taken: 0.32 seconds, content: 14.0699, style: 2.2461, tv: 0.4120, total: 25.3002
iter.: 900, time taken: 0.31 seconds, content: 13.6203, style: 2.2394, tv: 0.4051, total: 24.8170
iter.: 1000, time taken: 0.28 seconds, content: 13.4120, style: 2.0773, tv: 0.4007, total: 23.7987
iter.: 1100, time t

iter.: 8400, time taken: 0.31 seconds, content: 10.7585, style: 1.1577, tv: 0.4299, total: 16.5470
iter.: 8500, time taken: 0.33 seconds, content: 10.2867, style: 1.1268, tv: 0.4106, total: 15.9205
iter.: 8600, time taken: 0.38 seconds, content: 10.5280, style: 1.0737, tv: 0.4457, total: 15.8967
iter.: 8700, time taken: 0.39 seconds, content: 10.5497, style: 1.1165, tv: 0.4415, total: 16.1323
iter.: 8800, time taken: 0.33 seconds, content: 10.4605, style: 1.1347, tv: 0.4417, total: 16.1340
iter.: 8900, time taken: 0.29 seconds, content: 9.9834, style: 1.1078, tv: 0.4330, total: 15.5224
iter.: 9000, time taken: 0.28 seconds, content: 10.6456, style: 1.1477, tv: 0.4524, total: 16.3842
iter.: 9100, time taken: 0.32 seconds, content: 10.5545, style: 1.0983, tv: 0.4399, total: 16.0458
iter.: 9200, time taken: 0.30 seconds, content: 10.4754, style: 1.0803, tv: 0.4522, total: 15.8768
iter.: 9300, time taken: 0.31 seconds, content: 10.3013, style: 1.0841, tv: 0.4221, total: 15.7219
iter.: 9400

iter.: 16700, time taken: 0.36 seconds, content: 9.6144, style: 0.9969, tv: 0.4412, total: 14.5989
iter.: 16800, time taken: 0.32 seconds, content: 9.2526, style: 1.0186, tv: 0.4429, total: 14.3458
iter.: 16900, time taken: 0.36 seconds, content: 9.4632, style: 0.9598, tv: 0.4610, total: 14.2620
iter.: 17000, time taken: 0.34 seconds, content: 9.5786, style: 1.0207, tv: 0.4706, total: 14.6819
iter.: 17100, time taken: 0.32 seconds, content: 9.1406, style: 0.9854, tv: 0.4339, total: 14.0676
iter.: 17200, time taken: 0.33 seconds, content: 9.3071, style: 1.0019, tv: 0.4508, total: 14.3168
iter.: 17300, time taken: 0.33 seconds, content: 9.0900, style: 1.0024, tv: 0.4590, total: 14.1019
iter.: 17400, time taken: 0.39 seconds, content: 8.9859, style: 0.9571, tv: 0.4416, total: 13.7713
iter.: 17500, time taken: 0.34 seconds, content: 8.9581, style: 0.9717, tv: 0.4404, total: 13.8164
iter.: 17600, time taken: 0.34 seconds, content: 9.2065, style: 1.0250, tv: 0.4651, total: 14.3316
iter.: 177

iter.: 25000, time taken: 0.35 seconds, content: 8.6821, style: 0.9591, tv: 0.4742, total: 13.4778
iter.: 25100, time taken: 0.34 seconds, content: 8.2014, style: 0.9461, tv: 0.4321, total: 12.9319
iter.: 25200, time taken: 0.35 seconds, content: 8.5525, style: 0.9650, tv: 0.4652, total: 13.3777
iter.: 25300, time taken: 0.36 seconds, content: 8.5831, style: 0.9301, tv: 0.4595, total: 13.2336


KeyboardInterrupt: 

In [30]:
# Load onto the CPU regardless of the device the checkpoint was saved from:
# without `map_location`, a checkpoint written on a GPU cannot be read on a
# CUDA-less machine. The model restored from it is also built on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
check_point = torch.load('trained_models/model_25000.ckpt', map_location='cpu')

In [31]:
# Rebuild the network with the same number of styles used everywhere else in
# the notebook (the per-style parameter shapes must match the checkpoint),
# restore the trained weights and switch to inference mode.
model = StyleTransferNetwork(num_style=NUM_STYLE)
model.load_state_dict(check_point['state_dict'])
model.eval()

StyleTransferNetwork(
  (conv1): ConvWithCIN(
    (padding): ReflectionPad2d((4, 4, 4, 4))
    (conv): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
    (cin): CIN(
      (normalize): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (activation): ReLU()
  )
  (conv2): ConvWithCIN(
    (padding): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (cin): CIN(
      (normalize): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (activation): ReLU()
  )
  (conv3): ConvWithCIN(
    (padding): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
    (cin): CIN(
      (normalize): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
    (activation): ReLU()
  )
  (residual1): ResidualBlock(
    (conv1): ConvWithCIN(
      (padding): ReflectionPad2d((1, 1, 1, 1))
      (conv): Co

In [32]:
# Image used to try out the trained network; it is taken from the same
# directory as `content_path`.
test_image_path = '../Dashtoon Task/AdaIn/dataset/content/train_2.5k/COCO_train2014_000000255419.jpg'

In [41]:
import torchvision

In [42]:
def imload(path, imsize=None, cropsize=None, cencrop=False):
    """Load one image file as a normalised batch of size 1.

    Args:
        path: location of the image file.
        imsize, cropsize, cencrop: forwarded to ``get_transforms``.

    Returns:
        The transformed image with a leading batch dimension added.
    """
    transformer = get_transforms(imsize=imsize,
                                 cropsize=cropsize,
                                 cencrop=cencrop)
    # `convert` yields an independent image, so the file handle can be closed
    # immediately instead of staying open until garbage collection.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    return transformer(image).unsqueeze(0)


def imsave(image, save_path):
    """Write a normalised image (or batch of images) to ``save_path``.

    The input is first arranged into a single grid, then the input
    normalisation is undone, values are clamped to [0, 1] and the result is
    saved.
    """
    grid = torchvision.utils.make_grid(image)
    restored = denormalize(grid)
    restored.clamp_(0.0, 1.0)
    torchvision.utils.save_image(restored, save_path)

In [43]:
# Style to render: a value in [0, NUM_STYLE) picks one style, -1 renders all.
style_index = 0

# Content image, resized with imsize=256 (no crop).
content_image = imload(test_image_path, imsize=256)

In [44]:
test_image_path = '../Dashtoon Task/AdaIn/dataset/content/train_2.5k/COCO_train2014_000000255419.jpg'


content_image = imload(test_image_path, imsize=256)
style_index = 0

# for all styles: one copy of the content image per style, identity matrix as codes
if style_index == -1:
    style_code = torch.eye(NUM_STYLE).unsqueeze(-1)
    content_image = content_image.repeat(NUM_STYLE, 1, 1, 1)

# for specific style: a single one-hot code
elif style_index in range(NUM_STYLE):
    style_code = torch.zeros(1, NUM_STYLE, 1)
    style_code[:, style_index, :] = 1

# anything else would leave `style_code` undefined (or stale from a previous run)
else:
    raise ValueError(
        f"style_index must be -1 or in [0, {NUM_STYLE}), got {style_index}")

# Inference only: skip building the autograd graph.
with torch.no_grad():
    stylized_image = model(content_image, style_code)
imsave(stylized_image, 'stylized_images.jpg')

In [45]:
# Re-run the stylisation with the current `content_image` / `style_code`.
# Inference only, so the autograd graph is not needed.
with torch.no_grad():
    stylized_image = model(content_image, style_code)
imsave(stylized_image, 'stylized_images.jpg')