# Installing necessary packages

In [1]:
!pip install deeplake

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting deeplake
  Downloading deeplake-3.5.3.tar.gz (494 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.1/494.1 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting boto3 (from deeplake)
  Downloading boto3-1.26.142-py3-none-any.whl (135 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.6/135.6 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
Collecting pathos (from deeplake)
  Downloading pathos-0.3.0-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m785.3 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting humbug>=0.3.1 (from deeplake)
  Downloading humbug-0.3.1-py3-none-any.whl (15 kB)
Collecting numcodecs (from deeplake)
  Downloading numcodecs-0.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
[2K 

In [2]:
# from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import deeplake
import torch.utils.data as data
from PIL import Image, ImageFile
# from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms, models
from tqdm import tqdm
# Use the GPU when one is available; fall back to CPU so the notebook can still be imported/run elsewhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# from sampler import InfiniteSamplerWrapper

In [None]:
from google.colab import drive

# Google Drive mount point; the pretrained VGG weights are read from "<mount>/My Drive/" later on.
GDRIVE_MOUNT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT)

# Hyperparameters

In [None]:
LR = 1e-4             # base learning rate (decayed by adjust_learning_rate)
B = 4                 # batch size of both the content and the style loader
LR_DECAY = 2e-7       # inverse-time decay: lr = LR / (1 + LR_DECAY * iteration)
SCALE_FACTOR = 10     # weight of the style loss relative to the content loss
MAX_ITER = 10000      # number of training iterations
patch_size = 8        # side length of the square patches cut from the relu3_1 features
patch_stride = 8      # equal to patch_size -> non-overlapping patches (Fold exactly inverts Unfold)
C = 256               # channel dimension of relu3_1
IMAGE_SIZE = 256      # training crop size (RandomCrop(256) in train_transform)
# relu3_1 sits after two 2x max-pools, so its spatial size is IMAGE_SIZE / 4 = (64, 64).
OUTPUT_SIZE = (IMAGE_SIZE // 4, IMAGE_SIZE // 4)
# Length of one flattened patch (C * patch_size**2): the node-feature size fed to the GAT layers.
D = C * patch_size ** 2

# Load datasets and define preprocessing

In [None]:
# Content images (COCO train) and style images (WikiArt), both loaded through Deep Lake.
ds_c = deeplake.load('hub://activeloop/coco-train')
ds_s = deeplake.load('hub://activeloop/wiki-art')

# Generic 224x224 pipeline producing 3-channel tensors normalised to [-1, 1].
# NOTE(review): `tform` is not used below; the training loaders use train_transform().
tform = transforms.Compose([
    # Dataset samples are converted to PIL first so the following transforms can run.
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Replicate single-channel (grayscale) images to 3 channels; a no-op for RGB.
    transforms.Lambda(lambda img: img.expand(3, img.shape[1], img.shape[2])),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training transforms and learning-rate schedule


In [None]:
def train_transform():
    """Build the training-time preprocessing pipeline.

    Each sample is converted to PIL, resized to 512x512, randomly cropped to 256x256 and
    turned into a float tensor; single-channel images are replicated to 3 channels.
    """
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.expand(3, img.shape[1], img.shape[2])),
    ])

def adjust_learning_rate(optimizer, iteration_count):
    """Imitating the original implementation: inverse-time learning-rate decay.

    Sets every param group of ``optimizer`` to LR / (1 + LR_DECAY * iteration_count).
    """
    decayed = LR / (1.0 + LR_DECAY * iteration_count)
    for group in optimizer.param_groups:
        group['lr'] = decayed


# Essential operations: decoder, VGG encoder, graph attention and patch utilities

In [None]:
def _pad_conv(in_ch, out_ch, relu=True):
    """Reflection-pad by 1 pixel, then a 3x3 conv (size-preserving), optionally followed by ReLU."""
    layers = [nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(in_ch, out_ch, (3, 3))]
    if relu:
        layers.append(nn.ReLU())
    return layers


# Decoder: maps 256-channel feature maps back to a 3-channel image, upsampling 4x in total.
# Module order (and therefore state_dict indices 0..18) is fixed by this construction order.
decoder = nn.Sequential(
    *_pad_conv(256, 256),
    *_pad_conv(256, 128),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_pad_conv(128, 128),
    *_pad_conv(128, 64),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_pad_conv(64, 64),
    *_pad_conv(64, 3, relu=False),
)

In [None]:
# VGG-19 style encoder (reflection padding + 3x3 convs). Pretrained weights are loaded later from
# "vgg_normalised.pth" with strict=False, so the module indices here must line up with that
# checkpoint's keys -- NOTE(review): confirm against the checkpoint.
# Net slices the children as [:4] -> relu1_1, [4:11] -> relu2_1, [11:18] -> relu3_1,
# [18:31] -> relu4_1; the inference cell keeps only [:18] (up to relu3_1).
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),  # idx 0; NOTE(review): presumably the input normalisation of vgg_normalised.pth
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1 (idx 3)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1 (idx 10)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1 (idx 17): the features the graph stage works on
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1 (idx 30): last layer kept for training ([:31])
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

In [None]:
import gc


class GATv2Layer(nn.Module):
    """Multi-head GATv2 attention over dense, batched graphs.

    Args:
        in_features: size F of every input node feature vector.
        out_features: size of every output node feature vector (must be divisible by
            ``n_heads`` when ``is_concat``).
        n_heads: number of attention heads.
        is_concat: concatenate the heads if True, otherwise average them.
        dropout: dropout probability applied to the attention weights.
        leaky_relu_slop: negative slope of the LeakyReLU in the scoring function.
        share_weights: use one linear projection for both keys and queries.
    """

    def __init__(self, in_features: int, out_features: int,
                 n_heads: int, is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_slop: float = 0.2,
                 share_weights: bool = True
                 ):
        super(GATv2Layer, self).__init__()
        self.is_concat = is_concat
        self.n_heads = n_heads
        # Store a plain bool (a trailing comma here used to store a 1-tuple, which is always truthy).
        self.share_weights = share_weights

        if is_concat:
            assert out_features % n_heads == 0
            self.hidden_dim = out_features // n_heads
        else:
            self.hidden_dim = out_features

        self.Key = nn.Linear(in_features, self.hidden_dim * n_heads, bias=False)
        if share_weights:
            self.Query = self.Key
        else:
            self.Query = nn.Linear(in_features, self.hidden_dim * n_heads, bias=False)

        self.attn = nn.Linear(self.hidden_dim, 1, bias=False)
        self.activation = nn.LeakyReLU(leaky_relu_slop)
        self.softmax = nn.Softmax(dim=2)  # normalise over the neighbours j of each node i
        self.dropout = nn.Dropout(dropout)

    def forward(self, h, adj_mat):
        """Aggregate neighbour features with GATv2 attention.

        Args:
            h: (N, L, F) node features, or a ``(content, style)`` tuple of (N, Lc, F) and
               (N, Ls, F) tensors that are concatenated into L = Lc + Ls nodes.
            adj_mat: (N or 1, L or 1, L or 1) mask; a zero at [n, i, j] blocks node i from
               attending to node j.

        Returns:
            (N, L, n_heads * hidden_dim) when ``is_concat``, else (N, L, hidden_dim) averaged
            over heads. Results stay on the device of ``h`` / the layer parameters.
        """
        adj_mat = adj_mat.unsqueeze(3)  # broadcast the mask over the head axis
        if isinstance(h, tuple):
            content, style = h
            assert content.shape[0] == style.shape[0]
            h = torch.cat((content, style), dim=1)
        batch = h.shape[0]
        num_nodes = h.shape[1]
        assert not torch.isnan(h).any()

        key = self.Key(h).view(-1, num_nodes, self.n_heads, self.hidden_dim)
        query = self.Query(h).view(-1, num_nodes, self.n_heads, self.hidden_dim)
        assert not torch.isnan(key).any()
        assert not torch.isnan(query).any()

        # g_sum[n, i, j] = query[n, i] + key[n, j]. Broadcasting yields exactly the values of the
        # former repeat / repeat_interleave construction without materialising two extra
        # (N, L*L, heads, hidden_dim) copies.
        g_sum = query.unsqueeze(2) + key.unsqueeze(1)
        score = self.attn(self.activation(g_sum)).squeeze(-1)  # (N, L, L, heads)
        assert not torch.isnan(score).any()

        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == batch
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == num_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == num_nodes
        assert adj_mat.shape[3] == 1 or adj_mat.shape[3] == self.n_heads
        score = score.masked_fill(adj_mat == 0, float('-inf'))

        # A row with no allowed neighbour would softmax to NaN; the assert below catches it.
        attention = self.dropout(self.softmax(score))
        assert not torch.isnan(attention).any()

        attn_res = torch.einsum('nijh,njhf->nihf', attention, query)
        assert not torch.isnan(attn_res).any()

        if self.is_concat:
            return attn_res.reshape(batch, num_nodes, self.n_heads * self.hidden_dim)
        return attn_res.mean(dim=2)


In [None]:
def patch2feat(image, unfold):
    """Cut a feature map into patches and return them as graph node features.

    Args:
        image: feature map of shape (N, C, H, W).
        unfold: an ``nn.Unfold`` configured with the patch size and stride.

    Returns:
        Tensor of shape (N, P, C*K): one row per patch (P patches, K = k1*k2 positions
        per channel), on the same device as ``image``.
    """
    # nn.Unfold yields (N, C*K, P); swap the last two axes so patches become the node axis.
    return unfold(image).transpose(1, 2)


def feat2patch(feature, fold):
    """Inverse of ``patch2feat``: fold (N, P, C*K) node features back into a feature map.

    Args:
        feature: node features of shape (N, P, C*K).
        fold: an ``nn.Fold`` whose output size, kernel size and stride match the
            ``nn.Unfold`` used to create ``feature``.

    Returns:
        Tensor of shape (N, C, H, W) on the same device as ``feature``. ``nn.Fold`` sums
        overlapping entries, so this is an exact inverse only for non-overlapping patches
        (stride == kernel size).
    """
    return fold(feature.transpose(1, 2))


In [None]:
# x = torch.randn(2, 256, 32, 32)
# unfold = nn.Unfold(patch_size, stride = patch_stride)
# output = patch2feat(x, unfold)
# print(output.shape)
# fold = nn.Fold((32,32), 8, stride =8)
# weight = weight = torch.randn((int(C*patch_size*patch_size), int(C), int(patch_size), int(patch_size))).to(device)
# print(feat2patch(output, weight, fold).shape)

In [None]:
def _unit_rows(x, eps=1e-12):
    """L2-normalise ``x`` along dim 2; all-zero rows stay zero instead of turning into NaN."""
    return x / torch.norm(x, dim=2, keepdim=True).clamp_min(eps)


def knn(t, k):
    """Build a dense k-nearest-neighbour adjacency matrix from cosine similarity.

    Args:
        t: either a ``(content, style)`` tuple of node features shaped (N, L, F) and
           (N, P, F), or a single (N, L, F) tensor.
        k: number of neighbours; ``k + 1`` entries are selected per node.

    Returns:
        * tuple input: (N, L+P, L+P) 0/1 matrix with self-loops on the diagonal plus, for
          every node, its ``k + 1`` most similar nodes from the *other* set (content-content
          and style-style pairs are excluded from the search).
        * tensor input: (N, L, L) 0/1 matrix connecting every node to its ``k + 1`` most
          similar nodes (normally including itself, whose cosine similarity is maximal).
        The result lives on the same device as the inputs.
    """
    k += 1
    if isinstance(t, tuple):
        content, style = t
        assert style.shape[0] == content.shape[0]
        content = _unit_rows(content)
        style = _unit_rows(style)
        nc, l, _ = content.shape
        p = style.shape[1]

        nodes = torch.cat((content, style), dim=1)
        similarity = torch.matmul(nodes, nodes.transpose(1, 2))  # (N, L+P, L+P)
        # Only content<->style pairs may be selected as neighbours.
        similarity[:, l:, l:] = float('-inf')
        similarity[:, :l, :l] = float('-inf')

        _, indices = torch.topk(similarity, k, 1, True)
        indices = indices.transpose(1, 2)

        # .repeat (not .expand) gives real storage, so the in-place scatter_ is legal on any device.
        adj_matrix = torch.eye(l + p, device=nodes.device).repeat(nc, 1, 1)
        return adj_matrix.scatter_(2, indices, 1)

    content = _unit_rows(t)
    nc, l, _ = content.shape
    similarity = torch.matmul(content, content.transpose(1, 2))  # (N, L, L)
    _, indices = torch.topk(similarity, k, 1, True)
    indices = indices.transpose(1, 2)
    adj_matrix = torch.zeros(nc, l, l, device=content.device)
    return adj_matrix.scatter_(2, indices, 1)

In [None]:
# content = torch.randn((1, 2, 1)).to(device)
# style = torch.randn((1, 2, 1)).to(device)
# print(content)
# print(style)
# matrix = torch.tensor([[[0, 0, 0, 1],
#                         [0, 0, 1, 0],
#                         [0, 1, 0, 0],
#                         [1, 0, 0, 1]]])
# print(matrix.shape)
# gat = GATv2Layer(1, 1, 1).to(device)
# result = gat((content, style), matrix)
# print(result)
# print(result.shape)

In [None]:
# test = torch.randn((4, 3, 256,256))
# path2 = "/content/gdrive/My Drive/vgg_normalised.pth"
# vgg.load_state_dict(torch.load(path2), strict=False)
# vgg = nn.Sequential(*list(vgg.children())[:18])
# vgg.eval()
# output = vgg(test)
# print(output.shape)
# fold = nn.Fold(OUTPUT_SIZE, patch_size)
# unfold = nn.Unfold(patch_size, stride = patch_stride)
# output = unfold(output).transpose(1, 2) #patches of dimension: 
# print(output.shape)
# matrix = knn(output, 5)
# gat1 = GATv2Layer(16384, 16384, 4)
# output = gat1(output, matrix)
# print(output.shape)

In [None]:
def _channel_stats(feat, eps):
    """Mean and sqrt(unbiased variance + eps) over all positions of each (sample, channel), shaped (N, C)."""
    n, c = feat.size()[:2]
    flat = feat.view(n, c, -1)
    return flat.mean(dim=2), (flat.var(dim=2) + eps).sqrt()


def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel statistics of a 4-D feature map.

    Args:
        feat: tensor of shape (N, C, H, W).
        eps: added to the variance to avoid divide-by-zero downstream.

    Returns:
        ``(mean, std)``, each shaped (N, C, 1, 1) so they broadcast against ``feat``.
    """
    assert feat.dim() == 4
    mean, std = _channel_stats(feat, eps)
    return mean[..., None, None], std[..., None, None]


def calc_mean_std_(feat, eps=1e-5):
    """Same as ``calc_mean_std`` for 3-D inputs of shape (N, C, L); results are (N, C, 1)."""
    assert feat.dim() == 3
    mean, std = _channel_stats(feat, eps)
    return mean[..., None], std[..., None]


In [None]:
def adaptive_instance_normalization(content_feat, style_feat):
    """AdaIN: give every channel of ``content_feat`` the mean/std of the matching style channel.

    Both inputs are 4-D (N, C, H, W) maps (as required by calc_mean_std) with equal N and C.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    shape = content_feat.size()
    c_mean, c_std = calc_mean_std(content_feat)
    s_mean, s_std = calc_mean_std(style_feat)
    whitened = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return whitened * s_std.expand(shape) + s_mean.expand(shape)


def adaptive_patch_normalization(x, y):
    """AdaIN on 3-D (N, C, L) tensors: re-normalise ``x`` with the per-row statistics of ``y``."""
    assert (x.shape == y.shape)
    shape = x.size()
    x_mean, x_std = calc_mean_std_(x)
    y_mean, y_std = calc_mean_std_(y)
    whitened = (x - x_mean.expand(shape)) / x_std.expand(shape)
    return whitened * y_std.expand(shape) + y_mean.expand(shape)


# Build the complete network

In [None]:
class Net(nn.Module):
    """Training wrapper: frozen VGG encoder + two GATv2 layers + trainable decoder.

    ``forward`` returns the (content, style) losses of an AdaIN-style objective in which the
    relu3_1 content features are first rewired by graph attention over patch nodes.
    """

    def __init__(self, encoder, decoder,
                 gat1, gat2,
                 fold, unfold):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.gat1 = gat1  # style -> content message passing
        self.gat2 = gat2  # content -> content message passing
        self.fold = fold
        self.unfold = unfold
        # NOTE(review): pool/upsample are not used by forward(); kept for attribute compatibility.
        self.pool = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.mse_loss = nn.MSELoss()

        # The encoder is a fixed feature extractor; only the decoder and GAT layers train.
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    def encode_with_intermediate(self, input):
        """Return [relu1_1, relu2_1, relu3_1, relu4_1] features of ``input``."""
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def encode(self, input):
        """Return the relu4_1 features of ``input``."""
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        """MSE between ``input`` and a gradient-free ``target`` of the same shape."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """MSE between the per-channel means plus MSE between the per-channel stds."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, content, style, patch_size, patch_stride, k, alpha=1.0):
        """Compute (content loss, style loss) for one batch.

        Args:
            content, style: image batches fed to the encoder.
            patch_size, patch_stride: unused; patching is defined by ``self.unfold``/``self.fold``.
            k: neighbours per node for both kNN graphs.
            alpha: blend between AdaIN output (1.0) and the graph-stage features (0.0).
        """
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)  # relu1_1 .. relu4_1
        content_feats = self.encode_with_intermediate(content)  # relu1_1 .. relu4_1

        # The graph stage works on relu3_1 ([-2]) patches.
        content3_1 = content_feats[-2]
        style3_1 = style_feats[-2]
        content_patches = patch2feat(content3_1, self.unfold)
        style_patches = patch2feat(style3_1, self.unfold)
        n_content = content_patches.shape[1]

        # Style -> content message passing; keep only the content nodes of the joint graph.
        style_matrix = knn((content_patches, style_patches), k)
        updated_content_patches = self.gat1((content_patches, style_patches), style_matrix)[:, :n_content]

        # Content -> content message passing on the updated patches.
        content_matrix = knn(updated_content_patches, k)
        final_content_patches = self.gat2(updated_content_patches, content_matrix)

        # Back to an image-like (N, C, H, W) feature map.
        final_content_feat = feat2patch(final_content_patches, self.fold)

        # Feature refinement (AdaIN) and content/style blending.
        t = adaptive_instance_normalization(final_content_feat, style3_1)
        t = alpha * t + (1 - alpha) * final_content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        # The content target must not carry gradients (calc_content_loss asserts it);
        # detach() already returns a tensor with requires_grad=False.
        target = t.detach()
        loss_c = self.calc_content_loss(g_t_feats[-2], target)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s

In [None]:
# Patch <-> feature-map converters for the relu3_1 maps of 256x256 crops (OUTPUT_SIZE).
fold = nn.Fold(OUTPUT_SIZE, patch_size, stride=patch_stride)
unfold = nn.Unfold(patch_size, stride=patch_stride)

# Load the pretrained VGG weights from Drive, then keep layers up to relu4_1 (children [:31]).
path2 = "/content/gdrive/My Drive/vgg_normalised.pth"
load_result = vgg.load_state_dict(torch.load(path2, map_location='cpu'), strict=False)
# strict=False silently skips keys; report layers left without pretrained weights.
if load_result.missing_keys:
    print('WARNING: VGG layers missing from checkpoint:', load_result.missing_keys)
vgg = nn.Sequential(*list(vgg.children())[:31])

# GAT node features are flattened patches of D = C * patch_size**2 values; 4 attention heads.
gat1 = GATv2Layer(D, D, 4).to(device)
gat2 = GATv2Layer(D, D, 4).to(device)

# Training

In [None]:
# Ask PyTorch's CUDA caching allocator to release unused cached memory (e.g. left over from earlier cells) before training.
torch.cuda.empty_cache()

In [None]:
network = Net(vgg, decoder,
              gat1, gat2,
              fold,
              unfold)
network.train()
network.to(device)

content_tf = train_transform()
style_tf = train_transform()

# Content (COCO) and style (WikiArt) loaders; the style labels are not needed.
c_loader = ds_c.pytorch(batch_size=B, transform={'images': content_tf}, shuffle=False)
s_loader = ds_s.pytorch(batch_size=B, transform={'images': style_tf, 'labels': None}, shuffle=False)
c_iter = iter(c_loader)
s_iter = iter(s_loader)


def next_images(loader, it):
    """Return ``(images on device, iterator)``, restarting ``loader`` once ``it`` is exhausted."""
    try:
        batch = next(it)
    except StopIteration:
        it = iter(loader)
        batch = next(it)
    return batch['images'].to(device), it


# Only the decoder and the two GAT layers are optimised; the encoder is frozen inside Net.
optimizer = torch.optim.Adam([{'params': network.decoder.parameters(), 'lr': LR},
                              {'params': network.gat1.parameters(), 'lr': LR},
                              {'params': network.gat2.parameters(), 'lr': LR}])

progress = tqdm(range(MAX_ITER))
for i in progress:
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images, c_iter = next_images(c_loader, c_iter)
    style_images, s_iter = next_images(s_loader, s_iter)
    loss_c, loss_s = network(content_images, style_images, patch_size, patch_stride, k=3)
    loss_s = SCALE_FACTOR * loss_s
    loss = loss_c + loss_s
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:  # occasional .item() calls avoid a GPU sync on every iteration
        progress.set_postfix(loss_c=loss_c.item(), loss_s=loss_s.item())


# # #Loading decoder to drive
# # decoder_name = 'decoder.pt'
# # path = F"/content/gdrive/My Drive/{decoder_name}" 
# # torch.save(decoder.state_dict(), path)
# # #loading GAT layers to drive
# # gat1_name = 'gat1.pt'
# # path = F"/content/gdrive/My Drive/{gat1_name}" 
# # torch.save(gat1.state_dict(), path)

# # gat2_name = 'gat2.pt'
# # path = F"/content/gdrive/My Drive/{gat2_name}" 
# # torch.save(gat2.state_dict(), path)

# # #Loading parameters of weight:
# # weight_name = 'weight.pt'
# # path = F"/content/gdrive/My Drive/{weight_name}" 
# # torch.save(conv_weight.state_dict(), path)

# #     writer.add_scalar('loss_content', loss_c.item(), i + 1)
# #     writer.add_scalar('loss_style', loss_s.item(), i + 1)
# # writer.close()

In [None]:
# x = torch.randn((1, 256, 64, 64)).to(device)
# x_patch = patch2feat(x, unfold)
# reconstructed_x = feat2patch(x_patch, fold)
# (x - reconstructed_x).abs().max()

# Inference

In [None]:
def graph_part(content, style, k=5):
    """Run the trained graph stage on relu3_1 features at inference time.

    Mirrors the graph part of ``Net.forward``. It runs style->content attention with ``gat1``,
    then content->content attention with ``gat2`` on the updated patches. The GAT output is
    then folded back into a feature map.

    Args:
        content, style: relu3_1 feature maps whose spatial size matches the global
            ``unfold`` / ``fold`` (presumably (N, C, H, W) -- confirm for new inputs).
        k: neighbours per node for both kNN graphs. NOTE(review): the training loop calls
            Net with k=3; the default 5 is kept for backward compatibility.

    Returns:
        Feature map assembled from the GAT output.
    """
    assert not torch.isnan(content).any()
    assert not torch.isnan(style).any()
    content_patches = patch2feat(content, unfold)
    style_patches = patch2feat(style, unfold)
    assert not torch.isnan(content_patches).any()
    assert not torch.isnan(style_patches).any()
    n_content = content_patches.shape[1]

    # Style -> content message passing; keep only the content nodes.
    style_matrix = knn((content_patches, style_patches), k)
    updated_content_patches = gat1((content_patches, style_patches), style_matrix)[:, :n_content]
    assert not torch.isnan(updated_content_patches).any()

    # Content -> content message passing, with the graph built from the *updated* patches
    # exactly as during training.
    content_matrix = knn(updated_content_patches, k)
    final_content_patches = gat2(updated_content_patches, content_matrix)

    # Fold the GAT output (folding the input patches would discard both attention layers).
    return feat2patch(final_content_patches, fold)

In [None]:
def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Stylise ``content`` with ``style``: encoder -> graph stage -> AdaIN -> decoder.

    Args:
        vgg: encoder whose output is fed to graph_part (sliced to relu3_1 in the inference cell).
        decoder: maps the stylised features back to an image.
        content, style: image batches.
        alpha: 1.0 uses the fully stylised features, 0.0 the graph-stage content features.
        interpolation_weights: optional weights ``w_i``. When given, the per-sample AdaIN
            results are mixed as ``sum_i w_i * AdaIN[i]`` and only the first content item is kept.

    Returns:
        Decoder output for the blended features.
    """
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    content_f = graph_part(content_f, style_f)
    if interpolation_weights:
        base_feat = adaptive_instance_normalization(content_f, style_f)
        # Accumulator allocated directly on content_f's device/dtype (shape (1, C, H, W)).
        feat = torch.zeros((1,) + tuple(content_f.shape[1:]),
                           device=content_f.device, dtype=content_f.dtype)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)

In [None]:
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from google.colab import files
# Put every trained module into inference mode; the GAT layers apply attention dropout
# (p=0.6 by default) that must be disabled too.
decoder.eval()
decoder.to(device)
gat1.eval()
gat2.eval()

vgg.eval()
path2 = "/content/gdrive/My Drive/vgg_normalised.pth"
vgg.load_state_dict(torch.load(path2, map_location='cpu'), strict=False)
# Keep children [:18] (up to relu3_1): graph_part operates on relu3_1 patches.
vgg = nn.Sequential(*list(vgg.children())[:18])
vgg.to(device)

# Upload the content image first, then the style image.
uploaded = files.upload()
image_filename1 = next(iter(uploaded))
uploaded = files.upload()
image_filename2 = next(iter(uploaded))

# convert('RGB') drops alpha channels / palettes so every image has exactly 3 channels
# (a 4-channel RGBA tensor cannot be expanded to 3 channels below).
image1 = Image.open(image_filename1).convert('RGB')
image2 = Image.open(image_filename2).convert('RGB')

# Same geometry as training crops (256x256), no normalisation (as in train_transform).
preprocess = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.expand(3, x.shape[1], x.shape[2]))
])

input_tensor1 = preprocess(image1)
input_tensor2 = preprocess(image2)

In [None]:
 # Concatenate along the batch dimension

# Assuming you have already defined and loaded your model
content = input_tensor1.unsqueeze(0).to(device)
style = input_tensor2.unsqueeze(0).to(device)
assert style is not None


output_image = style_transfer(vgg, decoder, content, style, alpha = 0.8).detach().to('cpu')
output_image = output_image.squeeze(0).permute(1, 2, 0)
# Show the input and output images
fig, axs = plt.subplots(1, 3)
axs[0].imshow(image1)
axs[0].set_title("image1")
axs[0].axis("off")

axs[1].imshow(image2)
axs[1].set_title("image2")
axs[1].axis("off")

axs[2].imshow(output_image)
axs[2].set_title("Output Image")
axs[2].axis("off")

plt.show()


# Pyramid features (experimental; not integrated into the model because of GPU memory limits)

In [None]:
import torchvision.models as models
import torch
def _extract_patches(feat, patch_h, patch_w):
    """Split (N, C, H, W) into non-overlapping patch_h x patch_w patches -> (N, P, C*patch_h*patch_w)."""
    return torch.nn.functional.unfold(feat, kernel_size=(patch_h, patch_w),
                                      stride=(patch_h, patch_w)).transpose(1, 2)


def pyramid_feature(image, model=None):  # we use a 4 by 4 grid
    """Concatenate per-patch features from two VGG-19 depths over a 4x4 patch grid.

    With a 224x224 input, 14x14 patches of the first map and 7x7 patches of the second both
    give 16 patches; their flattened features are concatenated. This yields
    (N, 16, 256*14*14 + 512*7*7) = (N, 16, 75264).

    NOTE(review): in torchvision's vgg19.features, features[:18] ends at relu3_4 and
    features[18:27] at relu4_4 (the old comments said relu3_1 / relu4_1) -- confirm which
    layers are intended.

    Args:
        image: image batch (N, 3, H, W).
        model: optional preloaded torchvision VGG-19; when omitted, one with pretrained
            weights is loaded on every call.

    Returns:
        Tensor of concatenated patch features, (N, P, F1 + F2).
    """
    if model is None:
        model = models.vgg19(pretrained=True)
    enc_1 = model.features[:18]
    enc_2 = model.features[18:27]
    feat_1 = enc_1(image)
    feat_2 = enc_2(feat_1)
    feat_1 = _extract_patches(feat_1, 14, 14)
    feat_2 = _extract_patches(feat_2, 7, 7)
    return torch.cat([feat_1, feat_2], dim=2)


print(pyramid_feature(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 16, 75264])