In [None]:
import os

import numpy as np
import torch
import torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt

from models.nn import ShiftedConv
from models.convnext import ConvNeXt
import cpc

In [None]:
GPU = 0  # index of the CUDA device to use when CUDA is available
n_test = 100  # number of samples (taken from the start of sorted_mnist) to embed
seed = 1
crop_size = 280  # side length for RandomCrop; the data cell upscales digits 10x
batch_size = 32
version = 10  # checkpoint version to load
checkpoint_dir = f"checkpoints/version_{version}"
latent_dims = 2  # size of the encoder output / AR-model channels

# torch.manual_seed seeds the CPU generator and all CUDA generators.
torch.manual_seed(seed)
# Fall back to CPU so the notebook still runs end to end without a GPU;
# previously every later `.to(device)` raised on CPU-only machines.
device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu")

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over an indexable, sized collection of samples.

    Item ``idx`` is ``images[idx]``, passed through ``transform`` when a
    (truthy) transform was supplied.

    Args:
        images: indexable container of samples (e.g. a tensor whose first
            axis enumerates samples).
        transform: optional callable applied to each sample on access.
    """

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform
        # The length is captured once, at construction time.
        self.n_images = len(images)

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        sample = self.images[idx]
        return self.transform(sample) if self.transform else sample

In [None]:
# Raw MNIST training split (no transforms); downloaded into ./data on first use.
mnist = torchvision.datasets.MNIST(root="./data", train=True, download=True)

In [None]:
def build_digit_sequences(dataset, n_sequences, scale_factor=10, window=5):
    """Build standardized sequences of digits 0..9 and split them into windows.

    Sequence k contains one image of each digit 0..9, in digit order. The
    image for digit d is the k-th image counted from the END of that digit's
    index list, so no image is used twice. Images are upscaled with
    nearest-neighbour interpolation by ``scale_factor``. Each 10-digit
    sequence is cut into non-overlapping windows of ``window`` consecutive
    digits (``Tensor.unfold`` drops a trailing partial window), and every
    window becomes one sample. The whole tensor is finally standardized in
    place with its global mean and (unbiased) std.

    Args:
        dataset: object exposing ``data`` (uint8 images) and ``targets``
            (integer labels) with matching first axes, e.g. torchvision MNIST.
        n_sequences: maximum number of 10-digit sequences; it is capped by the
            size of the rarest digit class.
        scale_factor: upscaling factor passed to ``interpolate``.
        window: number of consecutive digits per output sample.

    Returns:
        float32 tensor of shape
        ``(n * (10 // window), 1, window, H * scale_factor, W * scale_factor)``.
        Consecutive samples are consecutive windows of the same sequence.

    Raises:
        ValueError: if ``n_sequences < 1`` or some digit class is empty.
    """
    n_digits = 10
    if n_sequences < 1:
        raise ValueError("n_sequences must be at least 1")

    # Reversed per-class index lists: position k is the k-th image from the
    # end of the class, matching the original pop-from-the-end order.
    class_indices = [
        (dataset.targets == d).nonzero()[:, 0].flip(0) for d in range(n_digits)
    ]
    n_available = min(len(idx) for idx in class_indices)
    if n_available == 0:
        raise ValueError("every digit class needs at least one image")
    n = min(n_sequences, n_available)

    # (n, 10) index grid: row k is sequence k, column d is digit d.
    index_grid = torch.stack([idx[:n] for idx in class_indices], dim=1)
    # Gather the raw uint8 images directly (no per-image PIL round trip) and
    # upscale them in a single call. The 10 digits act as channels, which
    # nearest-neighbour interpolation treats independently.
    upscaled = torch.nn.functional.interpolate(
        dataset.data[index_grid], scale_factor=scale_factor
    )
    # (n, 1, 10, H, W): singleton channel axis, digits along the time axis.
    sequences = upscaled.unsqueeze(1).to(torch.float)

    # (n, 1, 10, H, W) -> (n, 1, 10 // window, H, W, window)
    #   -> (n, 10 // window, 1, H, W, window) -> (N, 1, H, W, window)
    #   -> (N, 1, window, H, W)
    samples = (
        sequences.unfold(2, window, window)
        .movedim(2, 1)
        .flatten(0, 1)
        .movedim(4, 2)
    )

    # Standardize in place. The tensor is ~3 GB for 1000 sequences at
    # scale_factor=10, so an out-of-place version would double peak memory.
    samples -= samples.mean()
    samples /= samples.std()
    return samples


n_images = 1000  # number of 10-digit sequences to build (each yields 10 // 5 = 2 samples)

sorted_mnist = build_digit_sequences(mnist, n_images)
print(sorted_mnist.shape)

n_time = sorted_mnist.shape[2]

In [None]:
# Crop to `crop_size`. The data cell upscales digits 10x, so with
# crop_size == 280 this presumably covers the whole image.
transform = torchvision.transforms.RandomCrop(crop_size)

# Only the first `n_test` samples are embedded. shuffle=False keeps the
# batches in dataset order.
test_set = Dataset(sorted_mnist[:n_test], transform=transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Frame encoder with a `latent_dims`-sized head, followed by an AR model over
# the time axis (kernel spans all `n_time` steps).
encoder = ConvNeXt(
    in_chans=1,
    num_classes=latent_dims,
    dims=[8, 16, 32, 64],
)
ar_model = ShiftedConv(
    in_channels=latent_dims,
    out_channels=latent_dims,
    kernel_size=n_time,
)
# One linear head per step offset (n_time - 1 of them); presumably used to
# predict future latents from the context. Not used during embedding below.
query_weights = torch.nn.ModuleList(
    [torch.nn.Linear(latent_dims, latent_dims) for _ in range(n_time - 1)]
)

In [None]:
def _load_checkpoint(module, filename):
    """Load ``checkpoint_dir/filename`` into ``module`` with strict key matching.

    ``map_location="cpu"`` keeps loading independent of the GPU the
    checkpoint was saved on. Without it, tensors are restored onto their
    original device, which fails if that device does not exist here.
    ``load_state_dict`` then copies the values into the module's parameters.
    """
    state_dict = torch.load(
        os.path.join(checkpoint_dir, filename),
        map_location="cpu",
        weights_only=True,
    )
    return module.load_state_dict(state_dict)


for target_module, checkpoint_file in (
    (encoder, "encoder.pt"),
    (ar_model, "ar_model.pt"),
    (query_weights, "query_weights.pt"),
):
    _load_checkpoint(target_module, checkpoint_file)

In [None]:
def test(test_loader,
         encoder,
         ar_model,
         query_weights,
):
    """Run the trained CPC models over every batch of ``test_loader``.

    All three modules are moved to the notebook-level ``device`` and switched
    to eval mode. ``query_weights`` does not take part in the forward pass
    here; it is only moved and put in eval mode alongside the others.

    Args:
        test_loader: iterable of input batches (first axis is the batch).
        encoder: encoder passed to ``cpc.forward``.
        ar_model: autoregressive model passed to ``cpc.forward``.
        query_weights: prediction heads (unused while embedding).

    Returns:
        ``(latents_list, context_list)``: one entry per batch, in loader
        order, exactly as returned by ``cpc.forward``.
    """
    modules = [m.to(device) for m in (encoder, ar_model, query_weights)]
    for m in modules:
        m.eval()
    encoder, ar_model, query_weights = modules

    all_latents, all_context = [], []
    with torch.no_grad():
        progress = tqdm(test_loader, bar_format="Predicting {l_bar}{bar}{r_bar}")
        for batch in progress:
            n_in_batch = batch.shape[0]
            batch = batch.to(device)
            context, latents = cpc.forward(
                batch,
                batch_size=n_in_batch,
                encoder=encoder,
                ar_model=ar_model,
            )
            all_latents.append(latents)
            all_context.append(context)
    return all_latents, all_context

In [None]:
# Embed the test samples, then join the per-batch outputs along the batch
# axis and move them to host memory for analysis.
embeddings = test(
    test_loader=test_loader,
    encoder=encoder,
    ar_model=ar_model,
    query_weights=query_weights,
)
latents, context = (torch.cat(parts, dim=0).cpu() for parts in embeddings)

In [None]:
# Show the first two (normalized) samples of sorted_mnist: one sample per
# row, one time step per column. Uses n_time instead of a hard-coded 5 so the
# grid follows the window length chosen in the data cell.
n_rows = 2
fig, ax = plt.subplots(n_rows, n_time, squeeze=False)
for row in range(n_rows):
    for t in range(n_time):
        ax[row, t].imshow(sorted_mnist[row, 0, t])
        ax[row, t].set_title(f"t={t}", fontsize=8)
        ax[row, t].axis("off")
fig.suptitle("First two samples of sorted_mnist (rows) over time (columns)")
plt.show()