In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torch import (Tensor, IntTensor, ones, zeros, optim, save,
                    load, randint)
from torch import float as tfloat
from torch import no_grad
from tqdm import tqdm

from src.diff_utils import (
    ConvType, DataType, channel_map, conv_map, heat_1d_loss, heat_2d_loss, device,
    num_T, num_X, num_Y, T_max, Y_max, X_max
)
from src.visual_util import (
    sampling_traj, show_img_from_tensor, pick_random_label, denoise, show_img, make_gif
)
from src.vit import VisionTransformer
from src.dataset_loader import load_dataset

In [None]:
# Hyperparameters for the VisionTransformer, collected in one place so they
# are easy to scan and tweak before the model is built.
vit_kwargs = dict(
    channels=1,
    in_H=32,
    in_W=32,
    patch_size=4,
    embedding_dim=64,
    num_layers=6,
    num_heads=8,
    proj_dim=24,
    hidden_dim=32,
    out_dim=16,
)

# Build the model and place it on the project-wide compute device.
vit = VisionTransformer(**vit_kwargs).to(device)

In [None]:
# Smoke test: run an all-zero batch of 5 frames and 5 three-element labels
# through the model, then report the output shape and display the first item.
dummy_frames = zeros(5, 1, 32, 32).to(device)
dummy_labels = zeros(5, 3).to(device)

ot = vit(dummy_frames, dummy_labels)
print(ot.shape)
show_img(ot[0])

In [None]:
# Which dataset and which convolution family this run targets.
data_type = DataType.heat_2d
conv_type = ConvType.Conv2d

# Run switches. `training` decides whether any epochs are executed;
# `define_models` is only consulted by the disabled VDM setup below.
training = True
define_models = True

# Disabled VDM encoder/decoder setup, kept verbatim for reference.
# T = 512
# if define_models:
#     e = VDM_Encoder(T, conv_type).to(device)
#     d = VDM_Decoder(T, num_channels=channel_map[conv_type][data_type], label_dim=3, conv_map=conv_map[conv_type][data_type], conv_type=conv_type).to(device)
# else:
#     e = load(f'models/{data_type.name}/model-enc.pkl', map_location=device)
#     d = load(f'models/{data_type.name}/model-dec.pkl', map_location=device)
# times = IntTensor(np.linspace(0, T, T+1, dtype=int)).to(device)

In [None]:
# Scalar training hyperparameters.
lr = 2e-3
batch_size = 32
if training:
    epochs = 500
else:
    # With training disabled the epoch loop below becomes a no-op.
    epochs = 0

# Per-step reconstruction loss and the optimiser over all ViT parameters.
base_loss_fn = nn.L1Loss()
optimizer = optim.Adam(vit.parameters(), lr=lr)

# Settings bundle handed to the project's dataset loader.
args = dict(
    lr=lr,
    optimizer=optimizer,
    batch_size=batch_size,
    epochs=epochs,
    conv_type=conv_type,
    data_type=data_type,
    training=training,
    num_T=num_T,
    num_Y=num_Y,
    num_X=num_X,
)

dataset_dict = load_dataset(args)
train_dataloader = dataset_dict['train_dataloader']
dataset = dataset_dict['train_dataset']

In [None]:
# Training loop: for every batch, pick a random starting frame, roll the model
# forward `sample_depth` steps autoregressively, and sum the L1 error between
# each predicted frame and the corresponding ground-truth frame.
losses = []
sample_depth = 5  # number of consecutive rollout steps supervised per batch

for epoch in range(epochs):
    vit.train()
    total_loss = 0.0
    num_samples = 0

    for inp, label in tqdm(train_dataloader):
        optimizer.zero_grad()

        # Inputs and conditioning labels must live on the same device/dtype
        # as the model.
        inp = inp.to(device, dtype=tfloat)
        label = label.to(device, dtype=tfloat)

        # Upper bound leaves room for `sample_depth` target frames after the
        # starting frame along axis 1.
        rand_idx = np.random.randint(inp.shape[1] - sample_depth)

        loss = 0
        x_t_i = inp[:, rand_idx:rand_idx + 1]
        for step in range(sample_depth):
            # Feed the previous prediction back in (autoregressive rollout).
            x_t_i = vit(x_t_i, label)
            target = inp[:, rand_idx + step + 1:rand_idx + step + 2]
            loss += base_loss_fn(x_t_i, target)

        # if data_type == DataType.heat_1d:
        #     loss += heat_1d_loss(x_t_i, label, 0, conv_type, 1)
        # elif data_type == DataType.heat_2d:
        #     print(x_t_i.shape)
        #     loss += heat_2d_loss(x_t_i, label, 0, conv_type, 1)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the
        # epoch average.
        total_loss += loss.item() * inp.shape[0]
        num_samples += inp.shape[0]

    # Per-sample mean: the sum above is weighted by batch size, so it must be
    # divided by the number of samples, not the number of batches.
    total_loss /= max(num_samples, 1)
    losses.append(total_loss)

    # Per-epoch preview: loss curve, a random input frame from the first
    # training item, its true successor, and the model's one-step prediction.
    num_images = 3
    x, cond = dataset[0]
    which = np.random.randint(x.shape[0] - 1)

    vit.eval()
    with no_grad():  # preview only; no autograd graph needed
        x_dev = x.to(device, dtype=tfloat)
        cond_dev = cond.to(device, dtype=tfloat)
        next_ = vit(x_dev[None, which:which + 1, :, :], cond_dev[None, :])
    vit.train()

    plt.subplot(1, num_images + 1, 1)
    plt.plot(losses)
    plt.subplot(1, num_images + 1, 2)
    show_img_from_tensor(x[which:which + 1])
    plt.subplot(1, num_images + 1, 3)
    show_img_from_tensor(x[which + 1:which + 2])
    plt.subplot(1, num_images + 1, 4)
    show_img_from_tensor(next_[0])
    plt.show()

    print(f"{total_loss=}\nEpoch {epoch+1} finished.")

    # Checkpoint after every epoch; create the target directory on first use.
    # NOTE(review): this pickles the whole module object rather than a
    # state_dict; kept as-is so existing loaders keep working.
    checkpoint_path = f'vit_models/{data_type.name}/vit.pkl'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    save(vit, checkpoint_path)

In [None]:
# Inference: take the first test item, roll the model forward fully
# autoregressively from its first frame, and export both the reference
# sequence and the generated sequence as GIFs.
x_test_i, y_test_i = dataset_dict['test_dataset'][0]
if data_type == DataType.mnist:
    base = Tensor(x_test_i[None,:,:]).to(device, dtype=tfloat)
    condition = Tensor(y_test_i).to(device, dtype=tfloat)
    # base = Tensor(x_test).to(device, dtype=tfloat)[655]
    # x_T, eps = e(base, times[T:T+1])
    x_T = zeros((1,1,32,32))
elif data_type == DataType.heat_1d:
    base = Tensor(x_test_i).to(device, dtype=tfloat)
    condition = Tensor(y_test_i).to(device, dtype=tfloat)
    x_T = zeros((1,32,1,32))
elif data_type == DataType.heat_2d:
    base = Tensor(x_test_i).to(device, dtype=tfloat)
    condition = Tensor(y_test_i).to(device, dtype=tfloat)
    # condition[0] = 1.0
    x_T = zeros((1,num_T,32,32))

# NOTE(review): x_T is not used by the ViT rollout below; it appears to be
# left over from the disabled denoising path. Kept in case later cells read it.
if conv_type == ConvType.Conv3d:
    x_T = x_T[None,:,:,:,:]

print(x_test_i.shape)

# Generate three times as many frames as the reference sequence contains.
# The buffer size and the loop bound share one value so they cannot drift
# apart when the sequence length changes.
rollout_factor = 3
num_frames = x_test_i.shape[0] * rollout_factor
ou = zeros(num_frames, 32, 32).to(device)
ou[0] = x_test_i[0]

# Inference only: disable autograd so the in-place writes into `ou` do not
# build an ever-growing graph, and run in eval mode. The previous mode is
# restored afterwards so re-running training is unaffected.
was_training = vit.training
vit.eval()
try:
    with no_grad():
        for i in range(1, num_frames):
            # ou[i] = vit(x_test_i[None,i-1:i])[0] # Uses reference for next frame prediction
            ou[i] = vit(ou[None,i-1:i], condition[None,:])[0] # Completely autoregressive
finally:
    vit.train(was_training)
# x_T = (denoise(d, x_T, T, times, condition.reshape(1,3)) + 1) / 2 # remap from [-1, 1] to [0, 1]
# show_img(x_T[0,0,-1].reshape(1,32,32))

# NOTE(review): the generated sequence has rollout_factor * as many frames as
# `base`, yet both calls pass num_T -- confirm how make_gif uses that argument.
make_gif('visuals/xVIT-1200func-96x32x32-BASE-TRANSFORMER.gif', base.cpu().detach().numpy(), num_T, 8)
make_gif('visuals/xVIT-1200func-96x32x32-GENERATED-TRANSFORMER.gif', ou.cpu().detach().numpy(), num_T, 8)