# import

In [1]:
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
from transformers import AutoModel, CLIPProcessor
from models import WhereIsFeatures
from dataset import FolderData, COCOCaptionData
import tensorflow as tf
from torch.utils.data import DataLoader
from timm.scheduler.cosine_lr import CosineLRScheduler
from torch import nn
import torch
from matplotlib import pyplot as plt
from torchvision import transforms
from PIL import Image
import numpy as np
from torchvision import datasets

2023-05-13 11:25:04.167608: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-13 11:25:04.468762: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Helper functions (zero-shot accuracy metrics)

In [2]:
def accuracy(text_embeds, image_embeds, labels):
    """Zero-shot top-1 accuracy of image embeddings against text prompts.

    Every image embedding is scored against every text embedding by dot
    product, the scores are softmax-normalised across the prompts, and the
    index of the best-scoring prompt is compared with ``labels``.

    Args:
        text_embeds: 2-D tensor with one row per text prompt.
        image_embeds: 2-D tensor with one row per image, sharing the feature
            dimension of ``text_embeds``.
        labels: tensor holding the target prompt index of each image.

    Returns:
        0-dim float tensor with the fraction of images predicted correctly.
    """
    logits_per_text = torch.matmul(text_embeds, image_embeds.t())
    # Transpose so that rows are images and columns are prompts.
    logits_per_image = logits_per_text.t()
    prompt_probs = logits_per_image.softmax(dim=1)
    predicted = prompt_probs.argmax(1)
    is_correct = predicted == labels
    return is_correct.float().mean()

def clip_tanh_accuracy(text_embeds, image_embeds, labels):
    """Zero-shot top-1 accuracy after re-centring both embeddings with ``x*2-1``.

    Both inputs are mapped through ``x*2-1`` (which sends the interval
    [0, 1] onto [-1, 1]) before the dot-product scoring, softmax over the
    prompts and argmax comparison with ``labels``.
    NOTE(review): presumably meant for sigmoid-squashed embeddings, since
    ``2*sigmoid(z)-1 == tanh(z/2)`` -- confirm against the callers.

    Args:
        text_embeds: 2-D tensor with one row per text prompt.
        image_embeds: 2-D tensor with one row per image, sharing the feature
            dimension of ``text_embeds``.
        labels: tensor holding the target prompt index of each image.

    Returns:
        0-dim float tensor with the fraction of images predicted correctly.
    """
    centered_text = text_embeds*2-1
    centered_images = image_embeds*2-1
    logits_per_text = torch.matmul(centered_text, centered_images.t())
    # Transpose so that rows are images and columns are prompts.
    logits_per_image = logits_per_text.t()
    prompt_probs = logits_per_image.softmax(dim=1)
    predicted = prompt_probs.argmax(1)
    is_correct = predicted == labels
    return is_correct.float().mean()

# params

In [3]:
# Run-wide settings.
# Prefer the GPU, but fall back to the CPU so the notebook can still execute
# on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_epochs = 5      # epochs (and cosine-schedule length) for the buffer stages
warmup = 4        # NOTE(review): not referenced by any cell visible in this notebook
num_workers = 4   # DataLoader worker processes
batch_size = 16   # batch size shared by the train and test loaders

# data

In [4]:
# Dataset locations. The defaults are the original machine-specific paths;
# set COCO_CAPTIONS_JSON / DOGS_VS_CATS_DIR in the environment to run the
# notebook elsewhere without editing this cell.
train_src = os.environ.get(
    'COCO_CAPTIONS_JSON',
    '/home/palm/data/coco/annotations/annotations/captions_train2017.json',
)
test_src = os.environ.get(
    'DOGS_VS_CATS_DIR',
    '/home/palm/data/dogs-vs-cats/train',
)

In [5]:
# Caption dataset for training and an image-folder dataset for evaluation.
# Both loaders use the same batch size, shuffling and worker count.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)

train_dataset = COCOCaptionData(train_src)
val_dataset = FolderData(test_src, size=224, mul=1)

train_loader = DataLoader(train_dataset, **loader_kwargs)
test_loader = DataLoader(val_dataset, **loader_kwargs)

In [6]:
# Tokenizer/pre-processor matching the CLIP checkpoint used for the model below.
clip_checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(clip_checkpoint)

In [7]:
# One text prompt per entry of the test directory, in sorted name order.
# NOTE(review): assumes FolderData numbers its classes in the same sorted
# order, so that prompt index i corresponds to label i -- confirm in dataset.py.
test_texts = [f'a photo of a {folder}' for folder in sorted(os.listdir(test_src))]
test_inputs = processor(text=test_texts, return_tensors="pt", padding=True)
test_texts

['a photo of a cat', 'a photo of a dog']

# modules

In [21]:
# Loss and squashing function shared by every training/evaluation cell below.
mse, sigmoid = nn.MSELoss(), nn.Sigmoid()

# Pre-trained CLIP, moved to the target device and frozen: none of its
# parameters receive gradients.
clip = AutoModel.from_pretrained('openai/clip-vit-base-patch32').to(device)
clip.requires_grad_(False)
vision_model, visual_projection, text_projection = (
    clip.vision_model,
    clip.visual_projection,
    clip.text_projection,
)

# Class-prompt embeddings: element [1] of the text model output, projected
# and squashed through the sigmoid, computed once and reused for evaluation.
test_prompts = sigmoid(text_projection(clip.text_model(**test_inputs.to(device))[1]))

# Trainable model. NOTE(review): the meaning of the 1500 argument is defined
# by models.WhereIsFeatures, which is not visible here.
model = WhereIsFeatures(1500).to(device)

# autoencoder: encoder/decoder

In [22]:
# Stage 1: train model.encode / model.decode to reconstruct sigmoid-squashed
# CLIP text features, then evaluate on image features after every epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# t_initial is given in epochs, so the schedule is epoch-based and is advanced
# once per epoch at the bottom of the loop. Previously it was configured with
# t_in_epochs=False and never stepped, so the cosine decay had no effect.
schedule = CosineLRScheduler(optimizer,
                             t_initial=5,
                             t_mul=1,
                             lr_min=5e-5,
                             decay_rate=0.1,
                             cycle_limit=1,
                             t_in_epochs=True,
                             noise_range_t=None,
                             )
steps_per_epoch = 3000  # upper bound on training batches per epoch
for epoch in range(2):
    print('Epoch:', epoch + 1)
    model.train()
    # Size the bar to what will really run, in case the loader is shorter than the cap.
    progbar = tf.keras.utils.Progbar(min(steps_per_epoch, len(train_loader)))
    for idx, (captions, _) in enumerate(train_loader):
        # CLIP is frozen: build the regression targets without tracking gradients.
        with torch.no_grad():
            captions = processor(text=captions, return_tensors="pt", padding=True)
            features = clip.text_model(**captions.to(device))
            features = text_projection(features[1])
            features = sigmoid(features)

        x = model.encode(features)
        recon = model.decode(x)
        loss = mse(recon, features)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progbar.update(idx + 1, [('loss', loss.item())])
        if idx + 1 >= steps_per_epoch:
            break

    model.eval()
    progbar = tf.keras.utils.Progbar(len(test_loader))
    with torch.no_grad():
        for idx, (image, _, cls) in enumerate(test_loader):
            image = image.to(device)
            cls = cls.to(device)
            features = vision_model(image)['pooler_output']
            features = visual_projection(features)
            features = sigmoid(features)
            # Accuracy of the raw (squashed) CLIP image features.
            std_acc = accuracy(test_prompts, features, cls)
            # Accuracy after a round trip through the autoencoder.
            x = model.encode(features)
            recon = model.decode(x)
            recon_acc = accuracy(test_prompts, recon, cls)
            loss = mse(recon, features)
            printlog = [('loss', loss.item()),
                        ('std_acc', std_acc.item()),
                        ('recon_acc', recon_acc.item()),
                        ]
            progbar.update(idx + 1, printlog)

    # Move the cosine schedule on to the next epoch.
    schedule.step(epoch + 1)


Epoch: 1
Epoch: 2


# autoencoder: buffer nowhere

In [23]:
# Stage 2 ("buffer nowhere"): train through model.where(x, False) on encoder
# outputs of CLIP text features. The encoder runs under no_grad, so only the
# parameters used inside model.where receive gradients in this stage.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# warmup_t and t_initial are given in epochs, so the schedule is epoch-based
# and is advanced once per epoch at the bottom of the loop. Previously it was
# never stepped: timm applies warmup_lr_init to the optimizer at construction,
# which left the whole stage training at 1e-5 instead of following the
# warm-up + cosine curve.
schedule = CosineLRScheduler(optimizer,
                             warmup_t=1,
                             warmup_lr_init=1e-5,
                             t_initial=n_epochs,
                             t_mul=1,
                             lr_min=5e-5,
                             decay_rate=0.1,
                             cycle_limit=1,
                             t_in_epochs=True,
                             noise_range_t=None,
                             )
steps_per_epoch = 3000  # upper bound on training batches per epoch
for epoch in range(n_epochs):
    print('Epoch:', epoch + 1)
    model.train()
    # Size the bar to what will really run, in case the loader is shorter than the cap.
    progbar = tf.keras.utils.Progbar(min(steps_per_epoch, len(train_loader)))
    for idx, (captions, _) in enumerate(train_loader):
        # Frozen CLIP and the encoder are evaluated without gradient tracking.
        with torch.no_grad():
            captions = processor(text=captions, return_tensors="pt", padding=True)
            features = clip.text_model(**captions.to(device))
            features = text_projection(features[1])
            features = sigmoid(features)
            x = model.encode(features)

        # model.where returns three tensors; the first is regressed onto the third.
        # NOTE(review): their exact meaning and the flag's semantics live in
        # models.WhereIsFeatures, which is not visible here.
        x, ecd, gt = model.where(x, False)
        loss = mse(x, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progbar.update(idx + 1, [('loss', loss.item())])
        if idx + 1 >= steps_per_epoch:
            break

    model.eval()
    progbar = tf.keras.utils.Progbar(len(test_loader))
    with torch.no_grad():
        for idx, (image, _, cls) in enumerate(test_loader):
            image = image.to(device)
            cls = cls.to(device)
            features = vision_model(image)['pooler_output']
            features = visual_projection(features)
            features = sigmoid(features)
            # Accuracy of the raw (squashed) CLIP image features.
            std_acc = accuracy(test_prompts, features, cls)
            # Compare prompts and images through slice [:, 0] of model.where's second output.
            prompts_ecd = model.encode(test_prompts)
            _, prompts_ecd, _ = model.where(prompts_ecd, False)
            x = model.encode(features)
            x, ecd, gt = model.where(x, False)
            buffer_acc = accuracy(prompts_ecd[:, 0], ecd[:, 0], cls)
            # Accuracy after decoding the first output of model.where.
            recon = model.decode(x)
            recon_acc = accuracy(test_prompts, recon, cls)
            loss = mse(x, gt)
            printlog = [('loss', loss.item()),
                        ('std_acc', std_acc.item()),
                        ('recon_acc', recon_acc.item()),
                        ('buffer_acc', buffer_acc.item()),
                        ]
            progbar.update(idx + 1, printlog)

    # Move the warm-up/cosine schedule on to the next epoch.
    schedule.step(epoch + 1)


Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5


# autoencoder: buffer where

In [27]:
# Stage 3 ("buffer where"): same procedure as the previous stage but with
# model.where(x, True). The encoder runs under no_grad, so only the parameters
# used inside model.where receive gradients in this stage.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# warmup_t and t_initial are given in epochs, so the schedule is epoch-based
# and is advanced once per epoch at the bottom of the loop. Previously it was
# never stepped: timm applies warmup_lr_init to the optimizer at construction,
# which left the whole stage training at 1e-5 instead of following the
# warm-up + cosine curve.
schedule = CosineLRScheduler(optimizer,
                             warmup_t=1,
                             warmup_lr_init=1e-5,
                             t_initial=n_epochs,
                             t_mul=1,
                             lr_min=5e-5,
                             decay_rate=0.1,
                             cycle_limit=1,
                             t_in_epochs=True,
                             noise_range_t=None,
                             )
steps_per_epoch = 3000  # upper bound on training batches per epoch
for epoch in range(n_epochs):
    print('Epoch:', epoch + 1)
    model.train()
    # Size the bar to what will really run, in case the loader is shorter than the cap.
    progbar = tf.keras.utils.Progbar(min(steps_per_epoch, len(train_loader)))
    for idx, (captions, _) in enumerate(train_loader):
        # Frozen CLIP and the encoder are evaluated without gradient tracking.
        with torch.no_grad():
            captions = processor(text=captions, return_tensors="pt", padding=True)
            features = clip.text_model(**captions.to(device))
            features = text_projection(features[1])
            features = sigmoid(features)
            x = model.encode(features)

        # model.where returns three tensors; the first is regressed onto the third.
        # NOTE(review): their exact meaning and the flag's semantics live in
        # models.WhereIsFeatures, which is not visible here.
        x, ecd, gt = model.where(x, True)
        loss = mse(x, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progbar.update(idx + 1, [('loss', loss.item())])
        if idx + 1 >= steps_per_epoch:
            break

    model.eval()
    progbar = tf.keras.utils.Progbar(len(test_loader))
    with torch.no_grad():
        for idx, (image, _, cls) in enumerate(test_loader):
            image = image.to(device)
            cls = cls.to(device)
            features = vision_model(image)['pooler_output']
            features = visual_projection(features)
            features = sigmoid(features)
            # Accuracy of the raw (squashed) CLIP image features.
            std_acc = accuracy(test_prompts, features, cls)
            # Compare prompts and images through slice [:, 0] of model.where's second output.
            prompts_ecd = model.encode(test_prompts)
            _, prompts_ecd, _ = model.where(prompts_ecd, True)
            x = model.encode(features)
            x, ecd, gt = model.where(x, True)
            buffer_acc = accuracy(prompts_ecd[:, 0], ecd[:, 0], cls)
            # Same comparison after re-centring with x*2-1, via the shared helper.
            buffer_acc_tanh = clip_tanh_accuracy(prompts_ecd[:, 0], ecd[:, 0], cls)
            loss = mse(x, gt)
            printlog = [('loss', loss.item()),
                        ('std_acc', std_acc.item()),
                        ('sigmoid_acc', buffer_acc.item()),
                        ('tanh_acc', buffer_acc_tanh.item()),
                        ]
            progbar.update(idx + 1, printlog)

    # Move the warm-up/cosine schedule on to the next epoch.
    schedule.step(epoch + 1)


Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5


# tanh eval

In [25]:
# Stand-alone evaluation pass over the test loader using model.where(..., True);
# the buffer accuracy is computed on embeddings re-centred with x*2-1.
model.eval()
progbar = tf.keras.utils.Progbar(len(test_loader))
with torch.no_grad():
    for idx, (image, _, cls) in enumerate(test_loader):
        image = image.to(device)
        cls = cls.to(device)
        features = vision_model(image)['pooler_output']
        features = visual_projection(features)
        features = sigmoid(features)
        # Accuracy of the raw (squashed) CLIP image features.
        std_acc = accuracy(test_prompts, features, cls)
        # Compare prompts and images through slice [:, 0] of model.where's second output.
        prompts_ecd = model.encode(test_prompts)
        _, prompts_ecd, _ = model.where(prompts_ecd, True)
        x = model.encode(features)
        x, ecd, gt = model.where(x, True)
        # Use the shared helper instead of repeating the x*2-1 re-centring inline.
        buffer_acc = clip_tanh_accuracy(prompts_ecd[:, 0], ecd[:, 0], cls)
        # Accuracy after decoding the first output of model.where.
        recon = model.decode(x)
        recon_acc = accuracy(test_prompts, recon, cls)
        loss = mse(x, gt)
        printlog = [('loss', loss.item()),
                    ('std_acc', std_acc.item()),
                    ('recon_acc', recon_acc.item()),
                    ('buffer_acc', buffer_acc.item()),
                    ]
        progbar.update(idx + 1, printlog)


