In [None]:
!git clone https://github.com/WiraDKP/neural_style_transfer.git
cd neural_style_transfer

In [None]:
import torch
from torch import optim

# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

device  # display the selected device as the cell output

# Import Data

In [None]:
from src.utils import load_image
from torchvision import transforms

In [None]:
# Input images and preprocessing settings. The content image is also the starting
# point of the optimization, so its path is named only once.
CONTENT_PATH = "data/content/petani.jpg"
STYLE_PATH = "data/style/style1.jpg"
IMAGE_SIZE = 300  # int size for Resize: the shorter image side is scaled to this length

# Standard ImageNet per-channel mean/std normalization.
# NOTE(review): presumably chosen to match the pretrained VGG19 inside NeuralStyleTransfer — confirm.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

content = load_image(CONTENT_PATH, transform).to(device)
style = load_image(STYLE_PATH, transform).to(device)

# The generated image starts as an exact copy of the content image and is the only
# tensor that gets optimized. Cloning avoids a second disk read and guarantees the
# starting point cannot drift from `content` if the path above is edited.
output = content.detach().clone().requires_grad_(True)

# Training Preparation: Model, Criterion, Optimizer (MCO)

In [None]:
from src.model import NeuralStyleTransfer
from src.criterion import criterion

In [None]:
LEARNING_RATE = 0.001

# The feature extractor is used only as a fixed function of its input:
#   - .to(device): put the weights on the same device as the images loaded above
#   - .eval(): BatchNorm uses its stored statistics instead of per-batch ones
#   - requires_grad_(False): freeze the weights so only `output` receives gradients
# Each call is a no-op if NeuralStyleTransfer already does this internally.
# NOTE(review): assumes NeuralStyleTransfer is a torch.nn.Module (it is called like one below).
model = NeuralStyleTransfer().to(device).eval()
model.requires_grad_(False)

# AdamW optimizes the pixels of the generated image, not the network.
# NOTE(review): an Adam step moves each (normalized) pixel by roughly LEARNING_RATE,
# so a short run at 0.001 barely changes the image — tune LEARNING_RATE / max_epochs.
optimizer = optim.AdamW([output], lr=LEARNING_RATE)

# Training

### Extract Features

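Architecture of the VGG19-with-BatchNorm feature extractor, listed by module index. The `layers=[...]` arguments in the cells below refer to indices in `features`; `0`, `7`, `14`, `27` and `40` are the first layers of its five convolutional blocks.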
```
(features)
      0-5   ConvBnReLU() x2
        6   MaxPool2d()
     7-12   ConvBnReLU() x2
       13   MaxPool2d()
    14-25   ConvBnReLU() x4
       26   MaxPool2d()
    27-38   ConvBnReLU() x4
       39   MaxPool2d()
    40-51   ConvBnReLU() x4
       52   MaxPool2d()
       
(avgpool)       
            AdaptiveAvgPool2d()

(classifier)
      0-2   LinearBlock()
      3-5   LinearBlock()
        6   Linear()
```

In [None]:
# Target features stay fixed for the whole optimization, so compute them once with
# autograd disabled. Otherwise, if any model weight still requires grad, each
# loss.backward() in the training loop would traverse (and free) this graph, and the
# second iteration would fail with "Trying to backward through the graph a second time".
with torch.no_grad():
    # Content target: a single deep layer ("40", start of the fifth conv block of `features`).
    content_features = model(content, layers=["40"])
    # Style targets: the first layer of each of the five conv blocks of `features`.
    style_features = model(style, layers=["0", "7", "14", "27", "40"])

## Training Loop

In [None]:
from src.utils import draw_styled_image

In [None]:
max_epochs = 10  # each "epoch" is one optimizer step on the single generated image
log_every = 5    # print the loss and draw the current image every `log_every` steps
# NOTE(review): presumably criterion matches these against both the content layer ("40")
# and the style layers computed above — keep the lists in sync.
output_layers = ["0", "7", "14", "27", "40"]

for epoch in range(1, max_epochs + 1):
    # Clear gradients first: if an earlier run of this cell was interrupted after
    # backward(), stale gradients on `output` would otherwise leak into this step.
    optimizer.zero_grad()

    output_features = model(output, layers=output_layers)
    loss = criterion(content_features, style_features, output_features)
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0:
        print(f"Epoch: {epoch:5} | Loss: {loss.item():.5f}")
        draw_styled_image(output)