## Path GAN Discriminator Set Up

In [18]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchvision import models
from torchsummary import summary

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

In [69]:
class PathGAN_D(nn.Module):
    def __init__(self,
                 reduced, in_vecs = 32, lstm_actv='tanh'):
        super(PathGAN_D, self).__init__()

        self.reduced = reduced
        self.input_dim = 3 if reduced else 4

        if lstm_actv == 'tanh':
            self.lstm_actv = nn.Tanh()
        else:
            self.lstm_actv = nn.Identity()
        
        self.lstm_1 = nn.LSTM(input_size=self.input_dim, hidden_size=500, batch_first=True)
        self.bn1 = nn.BatchNorm1d(in_vecs)

        self.lstm2 = nn.LSTM(input_size=3000, hidden_size=100, batch_first=True)
        self.bn2 = nn.BatchNorm1d(in_vecs)

        self.lstm3 = nn.LSTM(input_size=100, hidden_size=100, batch_first=True)
        self.bn3 = nn.BatchNorm1d(in_vecs)

        self.lstm4 = nn.LSTM(input_size=100, hidden_size=100, batch_first=True)
        self.bn4 = nn.BatchNorm1d(in_vecs)

        self.lstm5 = nn.LSTM(input_size=100, hidden_size=1, batch_first=True)
        self.sigmoid = nn.Sigmoid()

        # Pre-trained VGG16 model
        self.vgg = models.vgg16(weights='DEFAULT')
        self.vgg_features = nn.Sequential(*list(self.vgg.features.children())[:])
        for param in self.vgg_features.parameters():
            param.requires_grad = False

        self.conv = nn.Conv2d(512, 100, kernel_size=3, stride=1, padding=0)
        self.leaky_relu = nn.LeakyReLU(0.3)
        # self.flatten = nn.Flatten()

        # if weights is not None:
        #     self.load_state_dict(torch.load(weights))


    def forward(self, x, img_input):
        # Scanpath input
        print("in:", x.shape)
        x, (h1, c1) = self.lstm_1(x)
        x = self.lstm_actv(x)
        # print(x)
        x = self.bn1(x)
        print("bn1:", x.shape)

        # Image input
        z = self.vgg_features(img_input)
        print("z1:", z.shape)
        z = self.conv(z)
        print("z2:", z.shape)
        z = self.leaky_relu(z)
        print("z3:",z.shape)
        z = z.view(z.shape[0], -1)
        print("z4 (after flatten):",z.shape)
        z = z.unsqueeze(1).repeat(1, 32, 1)
        print("z5:",z.shape)

        # Merge
        print("Before merge (x,z)", x.shape, z.shape)
        x = torch.cat([x, z], dim=-1)
        print("cat shape:", x.shape)
        x, (h2, c2) = self.lstm2(x) 
        # Not passing (h1, c1) in above line to let the initial LSTM be independent from the image input
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.bn2(x)
        print(x.shape)
        x, (h3, c3) = self.lstm3(x, (h2, c2))
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.bn3(x)
        print(x.shape)
        x, (h4, c4) = self.lstm4(x, (h3, c3))
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.bn4(x)

        print("flow encoder out shape: ", x.shape)

        # NOTE:
        # Remove these two layers. Use 1024 x 50 features for flow encoder.

        # x, _ = self.lstm5(x)
        # x = self.lstm_actv(x)
        # print(x.shape)
        # x = self.sigmoid(x)
        # print(x.shape)

        return x, (h4, c4)

In [70]:
model = PathGAN_D(reduced=True, in_vecs=32).to(device)

In [71]:
x = torch.randn(16, 32, 3).to(device) # batch_size, seq_len, feature_dim. NOTE: Seq_len can be variable.
img_input = torch.randn(16, 3, 224, 224).to(device)

out, hidden_embeddings = model(x, img_input)

in: torch.Size([16, 32, 3])
bn1: torch.Size([16, 32, 500])
z1: torch.Size([16, 512, 7, 7])
z2: torch.Size([16, 100, 5, 5])
z3: torch.Size([16, 100, 5, 5])
z4 (after flatten): torch.Size([16, 2500])
z5: torch.Size([16, 32, 2500])
Before merge (x,z) torch.Size([16, 32, 500]) torch.Size([16, 32, 2500])
cat shape: torch.Size([16, 32, 3000])
torch.Size([16, 32, 100])
torch.Size([16, 32, 100])
torch.Size([16, 32, 100])
torch.Size([16, 32, 100])
torch.Size([16, 32, 100])
flow encoder out shape:  torch.Size([16, 32, 100])


In [72]:
out.shape

torch.Size([16, 32, 100])

## Shape Reducer

### NOTE: Only this gets trained in the flow-prompt-Adapter (FP-Adapter)

In [73]:
class shapeReducerMLP(nn.Module):
    def __init__(self):
        super(shapeReducerMLP, self).__init__()
        self.flatten = nn.Flatten(start_dim=1)
        self.linear1 = nn.Linear(3200, 1600)
        self.linear2 = nn.Linear(1600, 1024)
        self.model = nn.Sequential(
            self.linear1,
            nn.LeakyReLU(0.3),
            self.linear2
        )

    def forward(self, x):
        x = self.flatten(x) # BS x 3200
        return self.model(x)

In [74]:
shape_reducer = shapeReducerMLP().to(device)

In [78]:
ans = shape_reducer(out)
ans.shape

torch.Size([16, 1024])

### Shallow Decoder for flowEncoder (PathGAN-D+ShapeReducerMLP)

In [59]:
class pathGAN_D_Decoder(nn.Module):
    def __init__(self, out_dim,  in_vecs = 32, latent_dim = 1024): # latent dim always fixed at 1024
        super(pathGAN_D_Decoder, self).__init__()

        # TODO: 
        # step 0: Get the input/output shapes right
        # step 1: Reshape BS x 1024 to BS x in_vecs x -1
        # step 2: LSTM decoding --> predict sequence of in_vecs x 3 (scanpaths) for entire BS

        self.out_dim = out_dim # Should be same as input_dim of PathGAN_D
        self.latent_dim = latent_dim

        # self.linear_init = nn.Linear(1, in_vecs)
        # self.linear_init_actv = nn.LeakyReLU(0.3)

        self.lstm1 = nn.LSTM(input_size=latent_dim, hidden_size=latent_dim // 4, batch_first=True)
        self.bn1 = nn.BatchNorm1d(in_vecs)

        self.lstm2 = nn.LSTM(input_size=latent_dim // 4, hidden_size=latent_dim // 16, batch_first=True)
        self.bn2 = nn.BatchNorm1d(in_vecs)

        self.lstm3 = nn.LSTM(input_size=latent_dim // 16, hidden_size=latent_dim // 64, batch_first=True)
        self.bn3 = nn.BatchNorm1d(in_vecs)

        self.lstm4 = nn.LSTM(input_size=latent_dim // 64, hidden_size=out_dim, batch_first=True)
        self.linear = nn.Linear(out_dim, out_dim)

        self.lstm_actv = nn.Tanh()

        # self.model = nn.Sequential(
        #     self.lstm1,
        #     self.lstm_actv,
        #     self.bn1,
        #     self.lstm2,
        #     self.lstm_actv,
        #     self.bn2,
        #     self.lstm3,
        #     self.lstm_actv,
        #     self.bn3,
        #     self.lstm4,
        #     self.lstm_actv,
        #     self.linear,
        # )


    def forward(self, x):
        
        print("Before all layers", x.shape)
        x = self.linear_init(x)
        print("Init. linear", x.shape)
        x = self.linear_init_actv(x)
        x = x.transpose(1, 0)
        print(x.shape)

        print("Starting the Sequence decoder ...")
        x, (h0, c0) = self.lstm1(x)
        print(x.shape)
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.bn1(x)
        print(x.shape)

        x, (h1, c1) = self.lstm2(x, (h0, c0))
        print(x.shape)
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.bn2(x)
        print(x.shape)

        x, (h2, c2) = self.lstm3(x, (h1, c1))
        print(x.shape)
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.bn3(x)
        print(x.shape)

        x, (h3, c3) = self.lstm4(x, (h2, c2))
        print(x.shape)
        x = self.lstm_actv(x)
        print(x.shape)
        x = self.linear(x)
        print(x.shape)

        return x

In [60]:
disc_decoder = pathGAN_D_Decoder(out_dim=3, in_vecs=32).to(device)

In [61]:
reconstructed = disc_decoder(ans.transpose(1,0)).transpose(1, 0)
reconstructed

Before all layers torch.Size([1024, 1])
Init. linear torch.Size([1024, 32])
torch.Size([32, 1024])
Start the Sequence decoder
torch.Size([32, 256])
torch.Size([32, 256])


RuntimeError: running_mean should contain 256 elements not 32

## CLIPVisionWithProjection - Shape

In [5]:
# !pip install -Uqq transformers accelerate diffusers

^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [8]:
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm
2024-07-24 17:01:31.295692: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-24 17:01:31.316828: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
class ImageProjModel(torch.nn.Module):
    """Project for Cross-Attenuation with Text Embeddings"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

In [10]:
# !mkdir models/image_encoder

In [11]:
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
image_encoder_path = "./models/image_encoder"
# ip_ckpt = "models/ip-adapter_sd15.bin"

In [12]:
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
clip_image_processor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)

In [13]:
image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )

In [14]:
image_encoder.config.projection_dim

1024

In [17]:
image_path = "test_img.png"
raw_image = Image.open(image_path)
clip_image = clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
image_embeds = image_encoder(clip_image).image_embeds
print("1", image_embeds.shape)
image_embeds_ = []
for image_embed, drop_image_embed in zip(image_embeds, [0]*len(image_embeds)):
  if drop_image_embed == 1:
    image_embeds_.append(torch.zeros_like(image_embed))
  else:
    image_embeds_.append(image_embed)

image_embeds = torch.stack(image_embeds_)
print("2", image_embeds.shape)

1 torch.Size([1, 1024])
2 torch.Size([1, 1024])
