In [None]:
!pip install torch torchvision Pillow -q




In [None]:
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import numpy as np

In [None]:
# Smoke test: allocate a small random tensor on the GPU when PyTorch can see
# one, otherwise on the CPU, so this cell also runs on CPU-only kernels.
print(torch.rand(2, 3, device='cuda' if torch.cuda.is_available() else 'cpu'))

tensor([[0.9027, 0.1280, 0.0288],
        [0.7828, 0.9218, 0.0314]], device='cuda:0')


In [None]:


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_image(image_path, transform=None, max_size=None):
    """Load an image file, optionally resize it, and optionally turn it into a batch tensor.

    Args:
        image_path: Path of the image file to open.
        transform: Optional callable applied to the PIL image (e.g. a
            ``transforms.Compose`` ending in ``ToTensor``/``Normalize``). Its
            result gets a leading batch dimension via ``unsqueeze(0)`` and is
            moved to the module-level ``device``.
        max_size: If truthy, the image is rescaled (up or down) so its longer
            side equals ``max_size`` while keeping the aspect ratio.

    Returns:
        A ``(1, ...)`` tensor on ``device`` when ``transform`` is given,
        otherwise the (possibly resized) RGB PIL image.
    """
    # Force 3-channel RGB: RGBA / palette / grayscale files (common for PNGs)
    # would otherwise produce tensors with the wrong channel count for a
    # 3-channel Normalize and for VGG. ``with`` closes the file handle;
    # convert() loads the pixels, so the copy outlives the context.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')

    if max_size:
        width, height = image.size
        # Integer arithmetic avoids float round-off (e.g. 255 instead of 256);
        # max(1, ...) keeps very thin images from collapsing to 0 pixels.
        if width > height:
            new_width = int(max_size)
            new_height = max(1, int(height * max_size // width))
        else:
            new_width = max(1, int(width * max_size // height))
            new_height = int(max_size)

        image = image.resize((new_width, new_height), Image.LANCZOS)

    if transform is None:
        # A PIL image has no ``.to``; return it as-is.
        return image

    return transform(image).unsqueeze(0).to(device)

class VGGNet(nn.Module):
    """Frozen VGG-19 feature extractor returning activations of selected layers.

    ``forward`` runs the input through ``vgg19().features`` and collects the
    output of every child whose name (its index as a string) is in
    ``self.select``. In torchvision's VGG-19 ``features`` layout, '0', '5',
    '10', '19', '28' are the first convolution of each of the five conv blocks.

    NOTE: torchvision's VGG ReLUs are ``inplace=True``, so each collected conv
    output is overwritten by the ReLU that follows it; the returned tensors are
    therefore effectively post-ReLU activations.
    """

    def __init__(self):
        super().__init__()
        self.select = ['0', '5', '10', '19', '28']
        # torchvision >= 0.13 replaced ``pretrained=True`` with the ``weights``
        # enum (the flag is deprecated); fall back for older releases.
        if hasattr(models, 'VGG19_Weights'):
            backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        self.vgg = backbone.features
        # Used purely as a fixed feature extractor: freezing the weights stops
        # autograd from computing and storing their gradients on every
        # backward pass. Gradients w.r.t. the *input* still flow.
        self.vgg.requires_grad_(False)

    def forward(self, x):
        """Return the activations of the layers named in ``self.select``, in network order."""
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features


# ImageNet channel statistics that torchvision's pretrained VGG weights expect.
# The same values must be used to normalise inputs and to undo it on save.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _gram_matrix(feature):
    """Return the (c, c) Gram matrix of a (1, c, h, w) feature map.

    Assumes batch size 1 (``load_image`` adds a single leading batch dim).
    """
    _, c, _, _ = feature.size()
    flat = feature.view(c, -1)
    return torch.mm(flat, flat.t())


def _save_image(tensor, path, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo ``Normalize(mean, std)`` on a (1, 3, H, W) tensor and save it as an 8-bit image."""
    # Normalize(-m/s, 1/s) is the exact inverse of Normalize(m, s).
    denorm = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1.0 / s for s in std],
    )
    # detach() first so no autograd graph is built while post-processing.
    img = denorm(tensor.detach().squeeze(0)).clamp(0, 1)
    array = (img.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
    Image.fromarray(array).save(path)


def main(content_path, style_path, max_size=400, total_step=4000, log_step=80, sample_step=800, style_weight=100, lr=0.001):
    """Run neural style transfer and save intermediate results.

    A copy of the content image is optimised with Adam. The content loss is the
    mean squared difference between its VGG features and the content image's
    features. The style loss compares the Gram matrices of its features with
    the style image's Gram matrices.

    Args:
        content_path: Path of the content image; the optimisation starts from it.
        style_path: Path of the style image.
        max_size: Both images are resized so their longer side is ``max_size``.
        total_step: Number of optimisation steps.
        log_step: Print the losses every ``log_step`` steps.
        sample_step: Save ``output-<step>.png`` every ``sample_step`` steps.
        style_weight: Multiplier on the style loss in the total loss.
        lr: Adam learning rate.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # The style image keeps its own aspect ratio: Gram matrices are (c, c)
    # regardless of spatial size, so it need not match the content dimensions.
    content = load_image(content_path, transform, max_size=max_size)
    style = load_image(style_path, transform, max_size=max_size)

    target = content.clone().requires_grad_(True)

    optimizer = torch.optim.Adam([target], lr=lr, betas=(0.5, 0.999))
    # Freeze VGG so backward() only computes gradients w.r.t. ``target``.
    vgg = VGGNet().to(device).eval().requires_grad_(False)

    # The content and style images never change, so compute their features
    # (and the style Gram matrices) once instead of on every step.
    with torch.no_grad():
        content_features = vgg(content)
        style_grams = [_gram_matrix(f) for f in vgg(style)]

    for step in range(total_step):
        target_features = vgg(target)

        content_loss = 0
        style_loss = 0
        for f_target, f_content, gram_style in zip(target_features, content_features, style_grams):
            content_loss += torch.mean((f_target - f_content) ** 2)

            # Normalise by the target feature map's size, as before.
            _, c, h, w = f_target.size()
            gram_target = _gram_matrix(f_target)
            style_loss += torch.mean((gram_target - gram_style) ** 2) / (c * h * w)

        loss = content_loss + style_weight * style_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % log_step == 0:
            print('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'.format(
                step + 1, total_step, content_loss.item(), style_loss.item()))

        if (step + 1) % sample_step == 0:
            _save_image(target, 'output-{}.png'.format(step + 1))

# Input images for this run. They live in Colab's default upload folder
# (/content); point these at your own files when running elsewhere.
content_path, style_path = (
    '/content/12472_Comp_E1_ImageA.jpg',
    '/content/download-_1_.png',
)

main(content_path=content_path, style_path=style_path)



Step [80/4000], Content Loss: 11.0362, Style Loss: 206.3892
Step [160/4000], Content Loss: 16.5164, Style Loss: 133.2332
Step [240/4000], Content Loss: 19.9567, Style Loss: 97.7935
Step [320/4000], Content Loss: 22.5299, Style Loss: 76.2136
Step [400/4000], Content Loss: 24.5207, Style Loss: 61.5953
Step [480/4000], Content Loss: 26.1519, Style Loss: 51.0329
Step [560/4000], Content Loss: 27.5391, Style Loss: 43.0825
Step [640/4000], Content Loss: 28.7737, Style Loss: 36.9377
Step [720/4000], Content Loss: 29.8577, Style Loss: 32.0678
Step [800/4000], Content Loss: 30.8445, Style Loss: 28.1310
Step [880/4000], Content Loss: 31.7401, Style Loss: 24.8980
Step [960/4000], Content Loss: 32.5444, Style Loss: 22.2054
Step [1040/4000], Content Loss: 33.2919, Style Loss: 19.9351
Step [1120/4000], Content Loss: 33.9825, Style Loss: 18.0039
Step [1200/4000], Content Loss: 34.6085, Style Loss: 16.3488
Step [1280/4000], Content Loss: 35.1890, Style Loss: 14.9166
Step [1360/4000], Content Loss: 35.