In [4]:
%load_ext autoreload
%autoreload 2

from imports import *
from models import *
from utils import *
from data import *
from configs import CONFIGS, EXP_CODES

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


RuntimeError: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW

In [4]:
# Load a saved experiment: rebuild the datasets/dataloaders for the chosen
# config, then restore the trained weights stored under `model_code`.
config_name = 'bird2d_3dim'
config = deepcopy(CONFIGS[config_name])

(train_dataset, test_dataset,
 train_loader, test_loader) = get_dataset_dataloader(config)

model_code = "fzrjrp4m"

fix_randomness(seed=42)

model = config['model_class'](seed=100, **config['model_args']).cuda()
# NOTE: torch.load unpickles the checkpoint -- only load trusted files.
model.load_state_dict(
    torch.load('./models/model_{}.pt'.format(model_code))
)
_ = model.eval()  # inference mode; `_` swallows the module repr

  0%|          | 0/15000 [00:00<?, ?it/s]

100%|██████████| 15000/15000 [00:03<00:00, 4662.53it/s]
100%|██████████| 3000/3000 [00:00<00:00, 4673.57it/s]


In [5]:
# Persistent iterator over train_loader: every run of the next cell calls
# next(it) and advances it by one batch. Re-run this cell to start over.
it = iter(train_loader)

In [42]:
# Take the next batch from the persistent iterator (element 0 of the batch)
# and encode its (i1, i2) frames with the loaded model; i3 is not used here.
sample = next(it)[0]

i1, i2, i3 = get_i1_i2_i3(sample)
i1, i2 = i1.cuda(), i2.cuda()

with torch.no_grad():
    pred_vs = model_encoder(model, i1, i2)
    pred_vs = pred_vs.squeeze()



In [41]:
# Scatter the predicted vectors at index 1 along dim 0 of pred_vs
# (slice 1:2 then merge dims 0 and 1).
plotly_scatter(
    pred_vs[1:2, :].flatten(start_dim=0, end_dim=1).cpu().numpy(),
    colors=None,
    title="",
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLSTMCell(nn.Module):
    """One step of a convolutional LSTM.

    The input and the previous hidden state are concatenated along the channel
    axis (dim 1) and a single 2-D convolution produces the four gate
    pre-activations (input, forget, output, candidate).

    Args:
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden and cell states.
        kernel_size: int or (kh, kw) kernel size of the gate convolution; with
            odd sizes the padding keeps the spatial size unchanged.
        bias: whether the gate convolution uses a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # "same" padding per spatial dimension (exact only for odd kernels).
        if isinstance(kernel_size, (tuple, list)):
            self.padding = tuple(k // 2 for k in kernel_size)
        else:
            self.padding = kernel_size // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: batched (B, input_dim, H, W) tensor.
            cur_state: tuple ``(h_cur, c_cur)``, each (B, hidden_dim, H, W).

        Returns:
            ``(h_next, c_next)``, each (B, hidden_dim, H, W) for odd kernels.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next


class Encoder(nn.Module):
    """Three ReLU conv layers (strides 1, 2, 2).

    Maps (B, input_dim, H, W) to (B, hidden_dims[2], ceil(ceil(H/2)/2),
    ceil(ceil(W/2)/2)), i.e. H/4 x W/4 when H and W are divisible by 4.
    """

    def __init__(self, input_dim, hidden_dims):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dims[0], kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden_dims[1], hidden_dims[2], kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x


class Generator(nn.Module):
    """Motion/content video predictor.

    Frames before the last observed one are encoded by ``motion_encoder`` and
    folded into a ConvLSTM state. The last observed frame is encoded by
    ``content_encoder`` and fed as the next ConvLSTM input. The model then
    rolls out ``future_frames`` steps, decoding each hidden state back to
    image space (Tanh output, values in [-1, 1]). The decoder upsamples by 4,
    matching the encoder, so H and W should be divisible by 4.
    """

    def __init__(self, input_dim, hidden_dims, num_frames, future_frames, convlstm_kernel_size):
        super(Generator, self).__init__()
        self.motion_encoder = Encoder(input_dim, hidden_dims)
        self.content_encoder = Encoder(input_dim, hidden_dims)
        self.convlstm = ConvLSTMCell(input_dim=hidden_dims[-1], hidden_dim=hidden_dims[-1], kernel_size=convlstm_kernel_size, bias=True)
        self.num_frames = num_frames
        self.future_frames = future_frames

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[1], kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_dims[1], hidden_dims[0], kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(hidden_dims[0], input_dim, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Predict the frames that follow the observed ones.

        Args:
            x: one sequence laid out as (seq_len, input_dim, H, W) with
               seq_len >= num_frames. x[:num_frames - 1] are motion frames and
               x[num_frames - 1] is the content (last observed) frame.

        Returns:
            Tensor of shape (future_frames, input_dim, H, W).

        Raises:
            ValueError: if x is not 4-D or holds fewer than num_frames frames.
        """
        if x.dim() != 4:
            raise ValueError('expected x of shape (seq_len, C, H, W), got {}'.format(tuple(x.shape)))
        if self.num_frames < 1 or x.size(0) < self.num_frames:
            raise ValueError('need at least num_frames={} frames, got {}'.format(self.num_frames, x.size(0)))

        # Keep a leading batch dim of 1: indexing x[num_frames - 1] would give an
        # unbatched 3-D tensor that cannot be concatenated with a 4-D state.
        content_encoded = self.content_encoder(x[self.num_frames - 1:self.num_frames])

        h = torch.zeros_like(content_encoded)
        c = torch.zeros_like(content_encoded)

        # Fold the motion frames into the recurrent state one step at a time
        # (their batch of num_frames - 1 encodings cannot serve directly as a
        # single-sequence hidden state).
        if self.num_frames > 1:
            motion_encoded = self.motion_encoder(x[:self.num_frames - 1])
            for t in range(motion_encoded.size(0)):
                h, c = self.convlstm(motion_encoded[t:t + 1], (h, c))

        h, c = self.convlstm(content_encoded, (h, c))

        outputs = []
        for _ in range(self.future_frames):
            h, c = self.convlstm(h, (h, c))
            outputs.append(self.decoder(h))

        # Each output is (1, C, H, W); concatenating gives (future_frames, C, H, W).
        return torch.cat(outputs, dim=0)


# Example dimensions for the video model: sequence length T and frame size H x W.
T = 20
H = W = 64
input_dim = 1                  # grayscale frames -> a single channel
hidden_dims = [64, 128, 256]   # channels after Encoder conv1 / conv2 / conv3


In [None]:
config_name = 'freq'
config = deepcopy(CONFIGS[config_name])

# NOTE: rebinds the dataset/loader globals created by earlier cells.
(train_dataset, test_dataset,
 train_loader, test_loader) = get_dataset_dataloader(config)

In [None]:
# `frames_tensor` is expected to be a stack of grayscale frames, e.g. an
# N x 12 x 12 CPU tensor or array.

def display_video(frames_tensor, fps=10, vmin=None, vmax=None):
    """Render a stack of grayscale frames as an inline HTML/JS animation.

    Args:
        frames_tensor: object with ``.shape == (num_frames, height, width)``
            where ``frames_tensor[i]`` is one 2-D frame.
        fps: playback speed in frames per second.
        vmin, vmax: fixed grayscale limits. ``None`` (the default) takes the
            limit from each frame's own min/max, as plain ``imshow`` does.

    Returns:
        An IPython ``HTML`` object containing the animation.
    """
    num_frames, height, width = frames_tensor.shape

    fig = plt.figure(figsize=(width / 20, height / 20))
    ax = plt.axes()
    ax.axis('off')

    # Create the image artist once and only swap its data per frame; calling
    # ax.imshow inside update() stacks a new image on the axes every frame.
    image = ax.imshow(frames_tensor[0], cmap='gray', vmin=vmin, vmax=vmax)

    def update(frame):
        data = frames_tensor[frame]
        image.set_data(data)
        # Any limit not fixed by the caller follows the current frame.
        lo = float(data.min()) if vmin is None else vmin
        hi = float(data.max()) if vmax is None else vmax
        image.set_clim(lo, hi)
        return (image,)

    anim = animation.FuncAnimation(fig, update, frames=num_frames, interval=1000 / fps)

    # Close the static figure so only the animation is displayed.
    plt.close(fig)
    return HTML(anim.to_jshtml())

In [None]:
# Animate the frames of training sample 100 (singleton dim 1 dropped).
display_video(frames_tensor=train_dataset[100][0].squeeze(dim=1))

In [None]:
# Cosine dissimilarity of the frame pairs (i1, i2) and (i1, i3) for every
# training sequence, collected together with the matching ground-truth vs.
# The three_*/two_*/four_* lists are not filled here; they stay empty and are
# kept only so those names remain defined.
one_two, three_two, one_three, three_one, two_four, four_two = [], [], [], [], [], []
one_two_vs, three_two_vs, one_three_vs, three_one_vs, two_four_vs, four_two_vs = [], [], [], [], [], []

for data, vs in train_dataset:
    i1, i2, i3 = get_i1_i2_i3(data)

    # One flattened vector per row: every dim after dim 0 is flattened.
    i1 = i1.squeeze().flatten(start_dim=1, end_dim=-1)
    i2 = i2.squeeze().flatten(start_dim=1, end_dim=-1)
    i3 = i3.squeeze().flatten(start_dim=1, end_dim=-1)

    one_two.append(cosine_dissim(i1, i2))
    one_two_vs.append(vs[::2, :])
    one_three.append(cosine_dissim(i1, i3))
    # NOTE(review): assumes i1 -> i3 spans twice the i1 -> i2 step, hence 2 * vs -- confirm.
    one_three_vs.append(2 * vs[::2, :])

one_two = torch.cat(one_two).numpy()
one_two_vs = torch.cat(one_two_vs).numpy()
one_three = torch.cat(one_three).numpy()
one_three_vs = torch.cat(one_three_vs).numpy()

In [None]:
# Ground-truth vs of the (i1, i2) pairs whose frames are (nearly) identical
# under cosine dissimilarity.
NEAR_ZERO_DISSIM = 1e-4
scatter = one_two_vs[one_two < NEAR_ZERO_DISSIM]

# NOTE(review): `*scatter.T` passes one positional array per column of vs.
# With 2 columns this is an x/y scatter; plt.scatter would read a 3rd column
# as marker sizes -- confirm vs is 2-D for this config.
plt.scatter(*scatter.T)

# Fraction of pairs below the threshold.
print(scatter.shape[0] / len(one_two))
if scatter.shape[0]:
    print(scatter[:, 0].max())
else:
    # .max() on an empty selection raises ValueError.
    print('no pairs below threshold')

### Walking along the manifold of Gaussian blob images

In [None]:
# Restore the FC_FC model trained in run lzhu4svj and switch it to eval mode.
# NOTE: torch.load unpickles the checkpoint -- only load trusted files.
model = FC_FC(seed=2)
model = model.cuda()
model.load_state_dict(torch.load('./models/model_lzhu4svj.pt'))
_ = model.eval()

In [None]:
# Deep-copy the config as the other loading cells do, so anything
# get_dataset_dataloader changes in it cannot leak back into CONFIGS.
(train_dataset, test_dataset,
 train_loader, test_loader) = get_dataset_dataloader(deepcopy(CONFIGS['fc_fc_3dim_gauss_loops']))

In [None]:
steps = 100  # how many times the inferred v is re-applied


def _to_unit_range(t):
    """Rescale `t` from [t.min(), t.max()] to (0, 1) with convert_range."""
    return convert_range(t, (t.min(), t.max()), (0, 1))


# Walk along the manifold: infer v from one random frame pair (i1, i2), then
# apply it repeatedly to the model's own (rescaled) predictions.
with torch.no_grad():
    rand_ind = torch.randint(0, len(train_dataset), size=(1,)).item()
    i1, i2, i3 = get_i1_i2_i3(train_dataset[rand_ind][0])

    i1 = _to_unit_range(i1.cuda())
    i2 = _to_unit_range(i2.cuda())

    # Pick one position along dim 1 and feed that single pair to the model.
    rand_ind = torch.randint(0, i1.shape[1], size=(1,)).item()
    v, pred_i3 = model(i1[0][rand_ind], i2[0][rand_ind])

    pred_i3s = [_to_unit_range(pred_i3)]
    for _ in range(steps):
        pred_i3 = _to_unit_range(model.get_i3(v, pred_i3s[-1]))
        pred_i3s.append(pred_i3)
    # (To walk the other way, apply model.get_i3(-v, ...) instead.)

    pred_i3s = torch.stack(pred_i3s).squeeze().cpu()


def display_video(frames_tensor, fps=10, vmin=0, vmax=1.0):
    """Render a stack of grayscale frames as an inline HTML/JS animation.

    Args:
        frames_tensor: object with ``.shape == (num_frames, height, width)``
            where ``frames_tensor[i]`` is one 2-D frame.
        fps: playback speed in frames per second.
        vmin, vmax: grayscale limits (default [0, 1]). Pass ``None`` to take
            that limit from each frame's own min/max instead.

    Returns:
        An IPython ``HTML`` object containing the animation.
    """
    num_frames, height, width = frames_tensor.shape

    fig = plt.figure(figsize=(width / 20, height / 20))
    ax = plt.axes()
    ax.axis('off')

    # Create the image artist once and only swap its data per frame; calling
    # ax.imshow inside update() stacks a new image on the axes every frame.
    image = ax.imshow(frames_tensor[0], cmap='gray', vmin=vmin, vmax=vmax)

    def update(frame):
        data = frames_tensor[frame]
        image.set_data(data)
        lo = float(data.min()) if vmin is None else vmin
        hi = float(data.max()) if vmax is None else vmax
        image.set_clim(lo, hi)
        return (image,)

    anim = animation.FuncAnimation(fig, update, frames=num_frames, interval=1000 / fps)

    # Close the static figure so only the animation is displayed.
    plt.close(fig)
    return HTML(anim.to_jshtml())


# pred_i3s holds steps + 1 rolled-out frames, presumably (N, H, W) after squeeze -- confirm.
display_video(frames_tensor=pred_i3s, fps=10)

In [None]:
# Autoregressive rollout: start from a random frame pair, then repeatedly feed
# the model its last two frames, recording every predicted frame and v.
with torch.no_grad():
    rand_ind = torch.randint(0, len(train_dataset), size=(1,)).item()
    i1, i2, i3 = get_i1_i2_i3(train_dataset[rand_ind][0])
    i1, i2 = i1.cuda(), i2.cuda()

    rand_ind = torch.randint(0, i1.shape[1], size=(1,)).item()
    v, pred_i3 = model(i1[0][rand_ind], i2[0][rand_ind])

    pred_vs, i1s, i2s = [v], [i2[0][rand_ind]], [pred_i3]

    for _ in range(100):
        prev_frame, cur_frame = i1s[-1], i2s[-1]
        v, pred_i3 = model(prev_frame, cur_frame)
        i1s.append(cur_frame)
        i2s.append(pred_i3)
        pred_vs.append(v)

    i2s = torch.stack(i2s).squeeze().cpu()


# Animate the frames produced by the rollout above (i2s). The previous version
# re-displayed the stale pred_i3s from the manifold-walk cell. display_video is
# the function defined in that earlier cell (identical to the copy that was here).
# NOTE(review): i2s is not min-max rescaled, but display_video clips to [0, 1]
# by default -- confirm the model's outputs already lie in that range.
display_video(i2s, fps=10)

In [None]:
# Trajectory of the v predicted at each rollout step.
v_traj = torch.stack(pred_vs).squeeze().cpu().numpy()

if v_traj.ndim == 2 and v_traj.shape[1] == 3:
    # plt.plot(x, y, z) would draw (x, y) plus a separate z-vs-index line, so
    # plot 3-D latents on a 3-D axis instead.
    ax = plt.figure().add_subplot(projection='3d')
    ax.plot(*v_traj.T);
else:
    plt.plot(*v_traj.T);