In [1]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-lvd51mc6
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-lvd51mc6
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting ftfy (from clip==1.0)
  Downloading ftfy-6.2.0-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25ldone
[?25h  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369497 sha256=148de58213793cf13a6c07182382281beafd99f694276151c321baba64a107fa
  Stored in directory: /tmp/pip-ephem-wheel-cache-o507ygtg/wheels/da/2b/4c/d6691fa9597aac8bb

In [2]:
import logging
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from clip import clip
import kornia.losses as losses


from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import numpy as np
import pandas as pd

from glob import glob
import os

In [3]:
class MLP(nn.Module):
    def __init__(self, input_size=2, hidden_size=64, num_layers = 5, output_size=3):
        super(MLP, self).__init__()
        
        self.input = nn.Linear(input_size, hidden_size)
        self.main_activation = nn.ReLU()
        
        model = []
        for i in range(num_layers):
            model += [nn.Linear(hidden_size, hidden_size),
                      self.main_activation
                     ]
        model += [nn.Linear(hidden_size, output_size),nn.Tanh()]
        self.model = nn.Sequential(*model)

    def set_parameters(self, estimated_parameters):
        i = 0
        for name, param in self.named_parameters():
            if 'weight' in name:
                self.state_dict()[name].copy_(estimated_parameters[i])
                i += 1
            elif 'bias' in name:
                self.state_dict()[name].copy_(estimated_parameters[i])
                i += 1
                
    def forward(self, x):
        x = self.main_activation(self.input(x))
        x = self.model(x)
        return x
    
def get_parameter_shapes(model):
    param_shapes = {}
    for name, param in model.named_parameters():
        if 'weight' in name:
            param_shapes[name[:-7] + ' weights'] = param.shape
        elif 'bias' in name:
            param_shapes[name[:-5] + ' biases'] = param.shape
    return param_shapes

In [4]:
import torch.nn.utils.spectral_norm as spectral_norm

def get_nonspade_norm_layer(opt, norm_type='instance'):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            norm_layer = nn.SyncBatchNorm(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instanceaffine':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=True)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer

# 5 means nothing here just for initiliazation lol
norm_layer = get_nonspade_norm_layer(5)

In [5]:
class FeatureExtractionModel(nn.Module):
    def __init__(self, param_shapes,norm_layer = norm_layer,ds_repititions = 4,width = 64,num_out = 1000):
        super(FeatureExtractionModel, self).__init__()
        
        # Load the CLIP model
        self.clip, _ = clip.load("ViT-B/32", device="cpu")
        model = []
        
        #the input channels is 2 because clip embedding goes from (bs,512) --> (bs,2,16,16)
        model += [norm_layer(nn.Conv2d(2, width, 3, stride=1, padding=1)),
                  nn.ReLU(inplace=True)]
        
        # this is a learned pointwise convolution repeated 4 times, helps in learning features         
        for i in range(ds_repititions):
            model += [norm_layer(nn.Conv2d(width, width, 3, stride=1, padding=1)),
          nn.ReLU(inplace=True)]
        
        # num outputs is supposed to be the number of MLP parameters that we want to estimate          
        model += [nn.Conv2d(width, num_out, 1)]
        
        self.model = nn.Sequential(*model)

    def forward(self, x):
        # Extract features using the CLIP model
        x = self.clip.encode_image(x)
        x = x.view(-1, 2, 16, 16)
        
        x = self.model(x)
        return x

In [24]:
class Pix2PixDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.image_paths = sorted([os.path.join(data_dir, file_name) for file_name in os.listdir(data_dir) if file_name.endswith('.jpg') or file_name.endswith('.png')])
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        width, height = image.size

        # Split the image vertically into two halves
        left_half = image.crop((0, 0, width // 2, height)) 
        right_half = image.crop((width // 2, 0, width, height))

        # Apply transformations if needed
        if self.transform:
            left_half = self.transform(left_half)
            right_half = self.transform(right_half)

        # Normalize the full-resolution source and target images
        source_fullres = left_half / 255.0
        target_highres = right_half / 255.0
        
                # Create the x, y coordinates
        x = torch.linspace(-1, 1, left_half.size(1))
        y = torch.linspace(-1, 1, left_half.size(2))
        xx, yy = torch.meshgrid(x, y)

        # Concatenate the image and coordinates
        source_fullres = torch.cat([left_half, xx.unsqueeze(0), yy.unsqueeze(0)], dim=0)
#         .view(-1, 5)
#         target_highres = target_highres.view(-1,3)
        
        # Resize the source image to 224x224 and normalize for ResNet input
        source_lowres = transforms.Resize((224, 224))(left_half)
        source_lowres = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(source_lowres)

        return source_lowres, source_fullres, target_highres

In [25]:
# Define transformations if needed
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add other transformations if needed
])
batch_size = 64

# Create the dataset
dataset = Pix2PixDataset(data_dir="/kaggle/input/pix2pix-dataset/cityscapes/cityscapes/train/", transform=transform)

# Create a data loader
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
for source_lowres, source_fullres, target_highres in data_loader:
    break

In [27]:
source_lowres.shape,source_fullres.shape,target_highres.shape

(torch.Size([64, 3, 224, 224]),
 torch.Size([64, 5, 256, 256]),
 torch.Size([64, 3, 256, 256]))

In [33]:
k = 16
bs = 64
nci = source_fullres.shape[1]
print(source_fullres.shape)
# bs, 5 rgbxy, h//k=h_lr, w//k=w_lr, k, k
tiles = source_fullres.unfold(2, k, k).unfold(3, k, k)
print(tiles.shape)

h_lr = source_fullres.shape[2] // k
w_lr = source_fullres.shape[3] // k

tiles = tiles.permute(0, 2, 3, 4, 5, 1).contiguous().view(
    bs, h_lr, w_lr, int(k * k), nci)
out = tiles

torch.Size([64, 5, 256, 256])
torch.Size([64, 5, 16, 16, 16, 16])


In [34]:
out.shape

torch.Size([64, 16, 16, 256, 5])

In [11]:
# Usage example
mlp = MLP()
parameter_shapes = get_parameter_shapes(mlp)
outs = sum(p.numel() for p in mlp.parameters())
feature_extraction_model = FeatureExtractionModel(parameter_shapes,num_out = outs)

input_vector = torch.randn(1,3,224,224)
estimated_parameters = feature_extraction_model(input_vector)

100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 59.5MiB/s]


In [18]:
parameter_shapes.items()

dict_items([('input weights', torch.Size([64, 2])), ('input biases', torch.Size([64])), ('model.0 weights', torch.Size([64, 64])), ('model.0 biases', torch.Size([64])), ('model.2 weights', torch.Size([64, 64])), ('model.2 biases', torch.Size([64])), ('model.4 weights', torch.Size([64, 64])), ('model.4 biases', torch.Size([64])), ('model.6 weights', torch.Size([64, 64])), ('model.6 biases', torch.Size([64])), ('model.8 weights', torch.Size([64, 64])), ('model.8 biases', torch.Size([64])), ('model.10 weights', torch.Size([3, 64])), ('model.10 biases', torch.Size([3]))])

In [39]:
source_fullres.shape

torch.Size([64, 5, 256, 256])

In [56]:
def apply_mlp(source_fullres, lr_params, parameter_shapes, k=16):
    """
    Apply the MLP to the lowres tiles of the input image.
    
    Args:
        source_fullres (torch.Tensor): The full-resolution input image.
        lr_params (torch.Tensor): The flattened MLP parameters.
        parameter_shapes (dict): A dictionary containing the shapes of the MLP parameters.
        k (int): The tile size.
    
    Returns:
        torch.Tensor: The output of the MLP, with the same shape as the original image.
    """
    bs = source_fullres.size(0)
    h, w = source_fullres.size(2), source_fullres.size(3)
    
    # Compute the lowres height and width
    h_lr = h // k
    w_lr = w // k
    
    # Unfold the input image into tiles
    tiles = source_fullres.unfold(2, k, k).unfold(3, k, k)
    tiles = tiles.permute(0, 2, 3, 4, 5, 1).contiguous().view(bs, h_lr, w_lr, int(k * k), tiles.size(1))
    
    # Apply the MLP to the tiles
    out = tiles
    layer_idx = 0
    while f'model.{layer_idx} weights' in parameter_shapes:
        print(layer_idx)
        w_shape = parameter_shapes[f'model.{layer_idx} weights']
        b_shape = parameter_shapes[f'model.{layer_idx} biases']
#         print(w_shape)
        nci, nco = w_shape[1], w_shape[0]
#         print(lr_params)
        w_ = lr_params[:, :np.prod(w_shape)]
        b_ = lr_params[:, np.prod(w_shape):np.prod(w_shape) + np.prod(b_shape)]
        
        w_ = w_.permute(0, 2, 3, 1).view(bs, h_lr, w_lr, nci, nco)
        b_ = b_.permute(0, 2, 3, 1).view(bs, h_lr, w_lr, 1, nco)
        out = torch.matmul(out, w_) + b_
        
        if layer_idx < len(parameter_shapes) // 2 - 1:
            out = torch.nn.functional.leaky_relu(out, 0.01, inplace=True)
        else:
            out = torch.nn.functional.tanh(out)
        
        layer_idx += 1
    
    # Reorder the tiles and reshape the output
    out = out.view(bs, h_lr, w_lr, k, k, nco).permute(0, 5, 1, 3, 2, 4)
    out = out.contiguous().view(bs, nco, h, w)
    
    return out

In [69]:
21187/64**2

5.172607421875

In [66]:
def apply_mlp(source_fullres, lr_params, parameter_shapes, k=16):
    """
    Apply the MLP to the lowres tiles of the input image.
    
    Args:
        source_fullres (torch.Tensor): The full-resolution input image.
        lr_params (torch.Tensor): The flattened MLP parameters.
        parameter_shapes (dict): A dictionary containing the shapes of the MLP parameters.
        k (int): The tile size.
    
    Returns:
        torch.Tensor: The output of the MLP, with the same shape as the original image.
    """
    bs = source_fullres.size(0)
    h, w = source_fullres.size(2), source_fullres.size(3)
    
    # Compute the lowres height and width
    h_lr = h // k
    w_lr = w // k
    
    # Unfold the input image into tiles
    tiles = source_fullres.unfold(2, k, k).unfold(3, k, k)
    tiles = tiles.permute(0, 2, 3, 4, 5, 1).contiguous().view(bs, h_lr, w_lr, int(k * k), tiles.size(1))
    
    # Apply the MLP to the tiles
    out = tiles
    layer_idx = 0
    start_idx = 0
    while f'model.{layer_idx} weights' in parameter_shapes:
        w_shape = parameter_shapes[f'model.{layer_idx} weights']
        b_shape = parameter_shapes[f'model.{layer_idx} biases']
        
        nci, nco = w_shape[1], w_shape[0]
        
        w_size = np.prod(w_shape)
        b_size = np.prod(b_shape)
        
        w_ = lr_params[start_idx:start_idx+w_size]
        print(w_.shape)
        w_ = lr_params[start_idx:start_idx+w_size].view(w_shape)
        b_ = lr_params[start_idx+w_size:start_idx+w_size+b_size].view(b_shape)
        
        w_ = w_.view(bs, h_lr, w_lr, nci, nco)
        b_ = b_.view(bs, h_lr, w_lr, 1, nco)
        out = torch.matmul(out, w_) + b_
        
        if layer_idx < len(parameter_shapes) // 2 - 1:
            out = torch.nn.functional.leaky_relu(out, 0.01, inplace=True)
        else:
            out = torch.nn.functional.tanh(out)
        
        start_idx += w_size + b_size
        layer_idx += 1
    
    # Reorder the tiles and reshape the output
    out = out.view(bs, h_lr, w_lr, k, k, nco).permute(0, 5, 1, 3, 2, 4)
    out = out.contiguous().view(bs, nco, h, w)
    
    return out

In [70]:
(16**2)*(64**2)

1048576

In [72]:
np.prod(torch.Size([1, 21187, 16, 16]))/1048576

5.172607421875

In [67]:
pred = apply_mlp(source_fullres, estimated_parameters, parameter_shapes, k=16)

torch.Size([1, 21187, 16, 16])


RuntimeError: shape '[64, 64]' is invalid for input of size 5423872

In [19]:
for name,shape in parameter_shapes.items():
    print(name)
    print(shape)
    print(np.prod(shape))

input weights
torch.Size([64, 2])
128
input biases
torch.Size([64])
64
model.0 weights
torch.Size([64, 64])
4096
model.0 biases
torch.Size([64])
64
model.2 weights
torch.Size([64, 64])
4096
model.2 biases
torch.Size([64])
64
model.4 weights
torch.Size([64, 64])
4096
model.4 biases
torch.Size([64])
64
model.6 weights
torch.Size([64, 64])
4096
model.6 biases
torch.Size([64])
64
model.8 weights
torch.Size([64, 64])
4096
model.8 biases
torch.Size([64])
64
model.10 weights
torch.Size([3, 64])
192
model.10 biases
torch.Size([3])
3


In [None]:
estimated_parameters.shape

In [None]:
# Assuming you have already defined the feature_extraction_model
for param in feature_extraction_model.clip.parameters():
    param.requires_grad = False

In [None]:
def apply_mlp(source_fullres, lr_params, k=16, num_layers=None, channels=None):
    """
    Apply the MLP to the lowres tiles of the input image.
    
    Args:
        source_fullres (torch.Tensor): The full-resolution input image.
        lr_params (torch.Tensor): The flattened MLP parameters.
        k (int): The tile size.
        num_layers (int): The number of layers in the MLP.
        channels (list): The number of input and output channels for each layer.
    
    Returns:
        torch.Tensor: The output of the MLP, with the same shape as the original image.
    """
    bs = source_fullres.size(0)
    nci = source_fullres.size(1)
    
    # Compute the lowres height and width
    h_lr = source_fullres.size(2) // k
    w_lr = source_fullres.size(3) // k
    
    # Unfold the input image into tiles
    tiles = source_fullres.unfold(2, k, k).unfold(3, k, k)
    tiles = tiles.permute(0, 2, 3, 4, 5, 1).contiguous().view(bs, h_lr, w_lr, int(k * k), nci)
    
    # Apply the MLP to the tiles
    out = tiles
    for idx, nco in enumerate(channels[:-1]):
        nci = channels[idx]
        bstart, bstop = _get_bias_indices(idx, num_layers)
        wstart, wstop = _get_weight_indices(idx, num_layers)
        
        w_ = lr_params[:, wstart:wstop]
        b_ = lr_params[:, bstart:bstop]
        
        w_ = w_.permute(0, 2, 3, 1).view(bs, h_lr, w_lr, nci, nco)
        b_ = b_.permute(0, 2, 3, 1).view(bs, h_lr, w_lr, 1, nco)
        out = torch.matmul(out, w_) + b_
        
        if idx < num_layers - 1:
            out = torch.nn.functional.leaky_relu(out, 0.01, inplace=True)
        else:
            out = torch.nn.functional.tanh(out)
    
    # Reorder the tiles and reshape the output
    out = out.view(bs, h_lr, w_lr, k, k, channels[-1]).permute(0, 5, 1, 3, 2, 4)
    out = out.contiguous().view(bs, channels[-1], source_fullres.size(2), source_fullres.size(3))
    
    return out

In [None]:
criterion = nn.MSELoss()
# criterion = losses.ssim(window_size=11, reduction='mean', max_val=1.0)
optimizer = torch.optim.Adam(feature_extraction_model.parameters(), lr=0.001)

for batch in data_loader:
    source_lowres, source_fullres, target_highres = batch

    # Extract the unique parameters for each input in the batch
    outs = feature_extraction_model(source_lowres)
    for name,shape in outs.items():
        if 
    break

In [None]:
outs