In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms as T
from PIL import Image
import torchvision.models as models
from torchvision.utils import save_image

In [7]:
# Pretrained VGG19 from torchvision, loaded with the library's default weights.
pretrained_weights = models.VGG19_Weights.DEFAULT
vgg_model = models.vgg19(weights=pretrained_weights)

In [8]:
# Show the deepest part of the convolutional stack (index 22 onward) to help
# choose which layer indices to extract features from.
deep_layers = vgg_model.features[22:]
print(deep_layers)

Sequential(
  (22): ReLU(inplace=True)
  (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): ReLU(inplace=True)
  (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): ReLU(inplace=True)
  (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace=True)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): ReLU(inplace=True)
  (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (33): ReLU(inplace=True)
  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): ReLU(inplace=True)
  (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


In [9]:
class VGG(nn.Module):
    """Feature extractor returning intermediate activations of a VGG network.

    ``model.features`` is truncated just past the deepest requested layer, and
    ``forward`` collects the output of every layer whose index is listed in
    ``chosen_features``.

    Args:
        model: A VGG-style model exposing a ``features`` ``nn.Sequential``.
            The truncated stack shares its layer modules with ``model``.
        chosen_features: Indices into ``model.features`` whose outputs are
            returned. Defaults to ``(0, 5, 10, 19, 28)``; for torchvision's
            VGG19 these are presumably the first conv layer of each of the five
            blocks (conv1_1 ... conv5_1) -- verify against the printed model.

    Raises:
        ValueError: If ``chosen_features`` is empty or contains an index that
            is outside ``model.features``.

    Note:
        torchvision's VGG uses ``ReLU(inplace=True)``. When a chosen layer is
        immediately followed by such a ReLU inside the kept stack, the ReLU
        overwrites the tensor that was collected, so the returned activation is
        effectively post-ReLU. The deepest chosen layer is the last layer kept,
        so its output is returned without any following activation.
    """

    def __init__(self, model, chosen_features=(0, 5, 10, 19, 28)):
        super(VGG, self).__init__()

        chosen = sorted(set(chosen_features))
        if not chosen:
            raise ValueError("chosen_features must contain at least one layer index")
        if chosen[0] < 0 or chosen[-1] >= len(model.features):
            raise ValueError(
                f"chosen_features must be in [0, {len(model.features) - 1}], got {chosen}"
            )

        # Only the layers up to (and including) the deepest requested one are
        # needed; anything after it would be computed and thrown away.
        self.model = model.features[: chosen[-1] + 1]
        self.chosen_features = chosen

    def forward(self, x):
        """Run ``x`` through the truncated stack.

        Returns:
            A list with the output of each chosen layer, in layer order.
        """
        wanted = set(self.chosen_features)  # O(1) membership per layer
        output_features = []
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i in wanted:
                output_features.append(x)

        return output_features

In [10]:
# Preprocessing pipeline:
#   1. resize every image to a fixed 300x300, so content, style and generated
#      feature maps all share one spatial shape;
#   2. convert the PIL image to a float CHW tensor;
#   3. normalize each channel with the usual ImageNet mean/std.
preprocessing_steps = [
    T.Resize((300, 300)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = T.Compose(preprocessing_steps)

In [11]:
def load_image(path):
    """Load the image at ``path`` as a preprocessed batch of one.

    The image is converted to RGB first: the ``Normalize`` step in
    ``transform`` has exactly three per-channel statistics, so grayscale,
    RGBA or CMYK files would otherwise fail. The ``with`` block closes the
    underlying file once the pixels have been read.

    Returns:
        ``transform(image)`` with a leading batch dimension added.
    """
    with Image.open(path) as img:
        return transform(img.convert("RGB")).unsqueeze(0)


style_image = load_image('./images/picaso_style.jpeg')
content_image = load_image('./images/catO.jpg')

In [12]:
# The image being optimized starts as a copy of the content image. Starting
# from random noise is an alternative:
#   torch.randn(content_image.shape, requires_grad=True)
# content_image does not require grad, so the copy is a fresh leaf tensor.
# Flagging it as requiring grad lets the optimizer update its pixels.
generated_image = torch.clone(content_image)
generated_image.requires_grad_()

In [13]:
# Freeze the network: only ``generated_image`` is optimized. Without this,
# every ``backward()`` also computes gradients for all VGG weights, and those
# gradients keep accumulating, because ``optimizer.zero_grad()`` only clears
# the tensors registered with the optimizer.
# Note: the truncated stack shares its layers with ``vgg_model``, so
# ``vgg_model``'s feature layers are frozen too.
model = VGG(vgg_model).eval().requires_grad_(False)

steps = 1001      # number of optimization iterations
alpha = 1         # weight of the content loss
beta = .5         # weight of the style (Gram-matrix) loss
optimizer = optim.Adam([generated_image], lr= .005)

In [14]:
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization: return ``tensor * std + mean``.

    Args:
        tensor: Normalized image tensor whose channel axis is third from the
            end, e.g. ``(C, H, W)`` or ``(N, C, H, W)``. The number of
            channels must match ``len(mean)``.
        mean: Per-channel means used for normalization (defaults to the same
            ImageNet values as the ``Normalize`` transform in this notebook).
        std: Per-channel standard deviations used for normalization.

    Returns:
        A new tensor on the same device as ``tensor``. Values are not clamped.
        If ``tensor`` requires grad, the result is part of its autograd graph,
        so call this under ``torch.no_grad()`` when only saving or displaying.
    """
    # Build the statistics on the input's device so GPU tensors work. The
    # dtype is left at the default so type promotion matches plain arithmetic.
    mean = torch.as_tensor(mean, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean

### Scope of usage

* `torch.no_grad()`: affects every operation executed inside the `with torch.no_grad():` block. No computation graph is recorded for any tensor or operation within that context.
* `.detach()`: affects only the tensor it is called on. It returns a tensor that shares the same data but is cut off from the computation graph. Operations on the detached tensor alone do not track gradients, but the graph built by earlier operations on the original tensor still exists.

### When to use

* `torch.no_grad()`: use it when no gradients are needed at all, e.g. during inference or when computing values that are never backpropagated through. The style/content target-feature extraction in this notebook is one example.
* `.detach()`: use it to exclude specific tensors from gradient tracking while other parts of the computation still need gradients. Typical cases are freezing part of a model during training, or reusing an output in a new computation without backpropagating through the part of the model that produced it.

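### Memory efficiency

* `torch.no_grad()` never builds a graph for the operations inside its block, so nothing computed there keeps activations around for a backward pass. It is the more efficient choice when a whole batch of operations needs no gradients, because it disables tracking for the entire scope.
* `.detach()` only stops tracking from the point of detachment onward. The graph (and its saved activations) for the operations that produced the original tensor is still kept as long as that tensor is referenced. It is useful when you need gradients for some parts of the graph but want to stop tracking for certain tensors.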

In [21]:
# The content and style targets never change, so compute them once without
# recording a graph. (``model`` returns a *list* of tensors, so calling
# ``.detach()`` on its result directly would not work.)
with torch.no_grad():
    content_features = model(content_image)
    style_features = model(style_image)

    # The style Gram matrices depend only on the fixed style features, so they
    # are loop-invariant: compute them once here instead of at every step.
    style_grams = []
    for style_feature in style_features:
        # Feature maps are (1, C, H, W); the view below assumes batch size 1.
        _, channel, height, width = style_feature.shape
        style_flat = style_feature.view(channel, height * width)
        style_grams.append(style_flat.mm(style_flat.t()))

for step in range(steps):
    generated_features = model(generated_image)

    content_loss = 0
    style_loss = 0
    for generated_feature, content_feature, style_gram in zip(
        generated_features, content_features, style_grams
    ):
        _, channel, height, width = generated_feature.shape

        # Content loss: mean squared difference of the raw feature maps.
        content_loss += torch.mean(torch.square(generated_feature - content_feature))

        # Style loss: mean squared difference of the (unnormalized) C x C
        # Gram matrices.
        generated_flat = generated_feature.view(channel, height * width)
        generated_gram = generated_flat.mm(generated_flat.t())
        style_loss += torch.mean(torch.square(generated_gram - style_gram))

    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f"step {step}: total loss {total_loss.item():.6g}")
        # Saving is not part of the optimization, so avoid recording a graph.
        with torch.no_grad():
            save_image(denormalize(generated_image), "./pic_cat.jpg")

tensor(4.2645e+08, grad_fn=<AddBackward0>)
tensor(77107296., grad_fn=<AddBackward0>)
tensor(35573192., grad_fn=<AddBackward0>)
tensor(19159724., grad_fn=<AddBackward0>)
tensor(12029507., grad_fn=<AddBackward0>)
tensor(8390893., grad_fn=<AddBackward0>)
