In [None]:
!git clone https://github.com/WiraDKP/neural_style_transfer.git

In [None]:
cd neural_style_transfer

In [None]:
import torch
from torch import optim

# Run on the first CUDA GPU when available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

device

# Import Data

In [None]:
from src.utils import load_image
from torchvision import transforms

In [None]:
# Standard ImageNet channel statistics, as expected by torchvision's pretrained models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

CONTENT_PATH = "data/content/najwa.jpg"
STYLE_PATH = "data/style/barli.jpg"

transform = transforms.Compose([
    transforms.Resize(300),  # int size: the shorter edge is resized to 300 px
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

content = load_image(CONTENT_PATH, transform).to(device)
style = load_image(STYLE_PATH, transform).to(device)

# The image being optimized starts as a copy of the content image.
# Cloning the already-loaded tensor avoids reading and decoding the same file
# a second time. It also yields a fresh leaf tensor, so it can require grad
# and be updated by the optimizer without touching `content`.
output = content.clone().requires_grad_(True)

# Training Preparation -> MCO (Model, Criterion, Optimizer)

In [None]:
from src.model import NeuralStyleTransfer
from src.criterion import criterion

In [None]:
# Feature extractor (presumably the VGG16-based network described below), placed on `device`.
model = NeuralStyleTransfer()
model = model.to(device)

# The optimizer receives only `output`, so it updates the image pixels, not the network weights.
optimizer = optim.AdamW(params=[output], lr=5e-2)

# Training

## Extract Features

VGG16 feature-extractor architecture
```
(features)
      0-3   ConvReLU() x2
        4   MaxPool2d()
      5-8   ConvReLU() x2
        9   MaxPool2d()
    10-15   ConvReLU() x3
       16   MaxPool2d()
    17-22   ConvReLU() x3
       23   MaxPool2d()
    24-29   ConvReLU() x3
       30   MaxPool2d()
       
(avgpool)       
            AdaptiveAvgPool2d()

(classifier)
      0-2   LinearBlock()
      3-5   LinearBlock()
        6   Linear()
```

Recommended layers
```
["4", "8", "13", "20", "27"]
```

In [None]:
# The content/style targets never change during optimization, so extract them
# once without autograd tracking. This saves memory and ensures that
# loss.backward() in the training loop never tries to traverse (and free) a
# graph through these targets.
# Only the first two of the recommended layers ("4", "8") are used here.
with torch.no_grad():
    content_features = model(content, layers=["4", "8"])
    style_features = model(style, layers=["4", "8"])

## Training Loop

In [None]:
from src.utils import draw_styled_image

In [None]:
max_epochs = 2500
for epoch in range(1, max_epochs + 1):
    # Forward pass: features of the image currently being optimized.
    output_features = model(output, layers=["4", "8"])

    # NOTE(review): output_features fills both output slots, presumably because
    # content and style use the same layer list here; confirm against the
    # signature of src.criterion.criterion.
    loss = criterion(
        content_features,
        style_features,
        output_features,
        output_features,
        style_weight=1e6,
    )

    # Backward pass, then update the pixels of `output` and reset its gradient.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Report progress and render the current image every 100 epochs.
    if epoch % 100 != 0:
        continue
    print(f"Epoch: {epoch:5} | Loss: {loss.item():.5f}")
    _ = draw_styled_image(output)