In [1]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch.nn.functional as F
from torch import nn, optim

import sys
sys.path.append("./../../")

from modules.dvae.model import DVAE
from modules.clip.model import CLIP, DVAECLIP
from config_reader import ConfigReader
from datasets.mnist_loader import MNISTData
from utilities.md_mnist_utils import LabelsInfo
from notebooks.utils import show
from modules.common_utils import latent_to_img, img_to_latent

In [2]:
data_source = MNISTData(
    img_type='md',
    root_path='/u/82/sukhoba1/unix/Desktop/TA-VQVAE/data/multi_descriptive_MNIST/',
    batch_size=8)

train_loader = data_source.get_train_loader(16)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
config_dir_path = '/u/82/sukhoba1/unix/Desktop/TA-VQVAE/configs/finished/'
config_path = config_dir_path + 'dvae_mnistmd_v256_ds2_remote.yaml'
CONFIG = ConfigReader(config_path=config_path)

In [4]:
dvae = DVAE(
    in_channels=CONFIG.in_channels,
    vocab_size=CONFIG.vocab_size,
    num_x2downsamples=CONFIG.num_x2downsamples,
    num_resids_downsample=CONFIG.num_resids_downsample,
    num_resids_bottleneck=CONFIG.num_resids_bottleneck,
    hidden_dim=CONFIG.hidden_dim,
    device=CONFIG.DEVICE)

dvae.eval()

dvae.load_model(
    root_path=CONFIG.save_model_path,
    model_name=CONFIG.save_model_name)

In [5]:
model = DVAECLIP(
    img_latent_height=32,
    img_latent_width=32,
    img_latent_channels=256,
    txt_max_length=12,
    txt_vocab_size=20,
    embed_dim=128,
    num_blocks=8,
    hidden_dim=256,
    n_attn_heads=8,
    dropout_prob=0.1,
    device=DEVICE
)

optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
iteration = 0
for epoch in range(100):
    for batch_index, (img, txt) in enumerate(train_loader):
        current_batch_size = img.size(0)

        img = img.to(DEVICE)
        txt = txt.permute(1, 0).to(DEVICE)
        labels = torch.arange(current_batch_size).to(DEVICE)
        
        imglat = img_to_latent(img, dvae)
        
        logits_per_image, logits_per_text = model(imglat, txt)

        loss_img = F.cross_entropy(logits_per_image, labels)
        loss_txt = F.cross_entropy(logits_per_text, labels)
        loss = (loss_img + loss_txt) / 2
        
        loss.backward()
        
        if iteration % 20 == 0:
            print("Epoch: {} Iter: {} Loss: {}".format(epoch, iteration, round(loss.item(), 5)))

        if (batch_index + 1) % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()

            iteration += 1    
        
        optimizer.step()
        optimizer.zero_grad()
        
        iteration += 1

Epoch: 0 Iter: 0 Loss: 2.94878
Epoch: 0 Iter: 20 Loss: 2.77642
Epoch: 0 Iter: 40 Loss: 2.77164
Epoch: 0 Iter: 60 Loss: 2.77501
Epoch: 0 Iter: 80 Loss: 2.77351
Epoch: 0 Iter: 100 Loss: 2.7735
Epoch: 0 Iter: 120 Loss: 2.76932
Epoch: 0 Iter: 140 Loss: 2.77419
Epoch: 0 Iter: 160 Loss: 2.77344
Epoch: 0 Iter: 180 Loss: 2.77136
Epoch: 0 Iter: 200 Loss: 2.772
Epoch: 0 Iter: 220 Loss: 2.7686
Epoch: 0 Iter: 240 Loss: 2.77069
Epoch: 0 Iter: 260 Loss: 2.77297
Epoch: 0 Iter: 280 Loss: 2.77108
Epoch: 0 Iter: 300 Loss: 2.77227
Epoch: 0 Iter: 320 Loss: 2.7742
Epoch: 0 Iter: 340 Loss: 2.77213
Epoch: 0 Iter: 360 Loss: 2.77284
Epoch: 0 Iter: 380 Loss: 2.76459
Epoch: 0 Iter: 400 Loss: 2.77479
Epoch: 0 Iter: 420 Loss: 2.77291
Epoch: 0 Iter: 440 Loss: 2.7727
Epoch: 0 Iter: 460 Loss: 2.77273
Epoch: 0 Iter: 480 Loss: 2.77226
Epoch: 0 Iter: 500 Loss: 2.77224
Epoch: 0 Iter: 520 Loss: 2.77198
Epoch: 0 Iter: 540 Loss: 2.77324
Epoch: 0 Iter: 560 Loss: 2.77291
Epoch: 0 Iter: 580 Loss: 2.7729
Epoch: 0 Iter: 600 Loss

KeyboardInterrupt: 

In [4]:
model = CLIP(
    img_height=128,
    img_width=128,
    img_channels=3,
    patch_height=8,
    patch_width=8,
    txt_max_length=12,
    txt_vocab_size=20,
    embed_dim=128,
    num_blocks=8,
    hidden_dim=256,
    n_attn_heads=8,
    dropout_prob=0.1,
    device=DEVICE
)

optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
iteration = 0
for epoch in range(100):
    for batch_index, (img, txt) in enumerate(train_loader):
        current_batch_size = img.size(0)

        img = img.to(DEVICE)
        txt = txt.permute(1, 0).to(DEVICE)
        labels = torch.arange(current_batch_size).to(DEVICE)
        
        logits_per_image, logits_per_text = model(img, txt)

        loss_img = F.cross_entropy(logits_per_image, labels)
        loss_txt = F.cross_entropy(logits_per_text, labels)
        loss = (loss_img + loss_txt) / 2
        
        loss.backward()
        
        if iteration % 100 == 0:
            print("Epoch: {} Iter: {} Loss: {}".format(epoch, iteration, round(loss.item(), 5)))

        optimizer.step()
        optimizer.zero_grad()
        
        iteration += 1

Epoch: 0 Iter: 0 Loss: 4.27204
Epoch: 0 Iter: 100 Loss: 3.95311
Epoch: 0 Iter: 200 Loss: 3.38472
Epoch: 0 Iter: 300 Loss: 2.76777
Epoch: 0 Iter: 400 Loss: 1.80431
Epoch: 0 Iter: 500 Loss: 0.95452


KeyboardInterrupt: 

In [None]:
img, txt = next(iter(train_loader))

img = img.to(DEVICE)
txt = txt.permute(1, 0).to(DEVICE)

In [None]:
logits_per_image, logits_per_text = model(img, txt)

In [None]:
F.softmax(logits_per_image, dim=0)

In [None]:
F.softmax(logits_per_text, dim=0)

In [None]:
data_source = MNISTData(
    img_type='md',
    root_path='/u/82/sukhoba1/unix/Desktop/TA-VQVAE/data/multi_descriptive_MNIST/',
    batch_size=8)

train_loader = data_source.get_train_loader(3)

In [None]:
img, txt = next(iter(train_loader))

img = img.to(DEVICE)
txt = txt.to(DEVICE)

In [None]:
current_batch_size = 3

img = img.repeat_interleave(current_batch_size, dim=0)
txt = txt.repeat(current_batch_size, 1)
match_labels = torch.eye(current_batch_size).flatten()

In [None]:
show(img, plot_grid=True, figsize=(14,14))

In [None]:
txt

In [None]:
match_labels

In [None]:
img_model = ImgEncoder(
    img_height=128,
    img_width=128,
    img_channels=3,
    patch_height=8,
    patch_width=8,
    embed_dim=128,
    num_blocks=8,
    hidden_dim=256,
    n_attn_heads=8,
    dropout_prob=0.1,
    device=DEVICE
)

txt_model = TxtEncoder(
    txt_max_length=12,
    txt_vocab_size=20,
    embed_dim=128,
    num_blocks=8,
    hidden_dim=256,
    n_attn_heads=8,
    dropout_prob=0.1,
    device=DEVICE
)

In [None]:
img_model(img).shape

In [None]:
txt_model(txt).shape

In [None]:
config_dir_path = '/u/82/sukhoba1/unix/Desktop/TA-VQVAE/configs/'
config_path = config_dir_path + 'matcher_mnistmd_v256_ds3.yaml'
CONFIG = ConfigReader(config_path=config_path)

CONFIG.BATCH_SIZE = 128

CONFIG.print_config_info()

In [None]:
data_source = MNISTData(
    img_type=CONFIG.dataset_type,
    root_path=CONFIG.root_path,
    batch_size=CONFIG.BATCH_SIZE)

train_loader = data_source.get_train_loader(8)

In [None]:
# dvae = DVAE(
#     in_channels=CONFIG.in_channels,
#     vocab_size=CONFIG.vocab_size,
#     num_x2downsamples=CONFIG.num_x2downsamples,
#     num_resids_downsample=CONFIG.num_resids_downsample,
#     num_resids_bottleneck=CONFIG.num_resids_bottleneck,
#     hidden_dim=CONFIG.hidden_dim,
#     device=CONFIG.DEVICE)

# dvae.eval()
# dvae.load_model(
#     root_path=CONFIG.vae_model_path,
#     model_name=CONFIG.vae_model_name)

In [None]:
model = TrMatcher(
    img_height=128,
    img_width=128,
    img_channels=3,
    patch_height=8,
    patch_width=8,
    txt_max_length=12,
    txt_vocab_size=20,
    embed_dim=128,
    num_blocks=8,
    hidden_dim=256,
    n_attn_heads=8,
    dropout_prob=0.1,
    tr_norm_first=False,
    out_dim=1,
    sigmoid_output=True)

# model = TrMatcher(
#     img_height=16,
#     img_width=16,
#     img_embed_dim=CONFIG.vocab_size,
#     txt_max_length=12,
#     txt_vocab_size=20,
#     embed_dim=64,
#     num_blocks=8,
#     hidden_dim=256,
#     n_attn_heads=4,
#     dropout_prob=0.1,
#     out_dim=1,
#     sigmoid_output=True)

model.to(CONFIG.DEVICE)

optimizer = optim.Adam(model.parameters(), lr=CONFIG.LR)

In [None]:
iteration = 0
for epoch in range(CONFIG.NUM_EPOCHS):
    for batch_index, (img, txt) in enumerate(train_loader):
        current_batch_size = img.size(0)

        img = img.repeat_interleave(current_batch_size, dim=0)
        txt = txt.repeat(current_batch_size, 1)
        match_labels = torch.eye(current_batch_size).flatten()

        img = img.to(CONFIG.DEVICE)
        txt = txt.permute(1, 0).to(CONFIG.DEVICE)
        match_labels = match_labels.to(CONFIG.DEVICE)

        #with torch.no_grad():
        #    latent = dvae.ng_q_encode(img)
        #b, emb, h, w = latent.size()
        #x = latent.view(b, emb, -1).permute(2, 0, 1)
        
        pred_labels = model(img, txt, average_cls_token=False)

        loss = F.binary_cross_entropy(pred_labels, match_labels)
        loss.backward()
        
        if iteration % 50 == 0:
            print("Epoch: {} Iter: {} Loss: {}".format(epoch, iteration, round(loss.item(), 5)))

        optimizer.step()
        optimizer.zero_grad()
        
        iteration += 1

In [None]:
pred_labels

In [None]:
# iteration = 0
# for epoch in range(CONFIG.NUM_EPOCHS):
#     for batch_index, (img, txt) in enumerate(train_loader):
#         current_batch_size = img.size(0)
#         n_true = current_batch_size // 2
#         true_txt = txt[:n_true, :]
#         false_txt = txt[n_true:, :]
#         false_txt = torch.cat((false_txt[[-1], :], false_txt[:-1, :]), dim=0)
#         txt = torch.cat((true_txt, false_txt), dim=0)

#         match_labels = torch.zeros(current_batch_size)
#         match_labels[:n_true] = 1.0

#         img = img.to(CONFIG.DEVICE)
#         txt = txt.permute(1, 0).to(CONFIG.DEVICE)
#         match_labels = match_labels.to(CONFIG.DEVICE)

#         #with torch.no_grad():
#         #    latent = dvae.ng_q_encode(img)
#         #b, emb, h, w = latent.size()
#         #x = latent.view(b, emb, -1).permute(2, 0, 1)
        
#         pred_labels = model(img, txt)

#         loss = F.binary_cross_entropy(pred_labels, match_labels)
#         loss.backward()
        
#         if iteration % 100 == 0:
#             print("Epoch: {} Iter: {} Loss: {}".format(epoch, iteration, round(loss.item(), 5)))

#         optimizer.step()
#         optimizer.zero_grad()
        
#         iteration += 1

In [None]:
img, txt = next(iter(train_loader))

In [None]:
img_all = img.repeat_interleave(4, dim=0)

img_all.shape

In [None]:
show(img_all, plot_grid=True, figsize=(14,14))

In [None]:
txt.repeat(4, 1)

In [None]:
torch.eye(4).flatten()

In [None]:
current_batch_size = img.size(0)
n_true = current_batch_size // 2
true_txt = txt[:n_true, :]
false_txt = txt[n_true:, :]
false_txt = torch.cat((false_txt[[-1], :], false_txt[:-1, :]), dim=0)
txt_new = torch.cat((true_txt, false_txt), dim=0)

In [None]:
txt

In [None]:
txt_new