In [None]:
import torch
import torchvision

from models.nn import ShiftedConv
from models.convnext import ConvNeXt
import train

In [None]:
# Experiment configuration. Everything a reader might want to tune lives here.
GPU = 0  # index of the CUDA device to use (ignored when CUDA is unavailable)
max_epochs = 100  # forwarded to train.train as max_epochs
patience = 10  # forwarded to train.train as patience
seed = 1  # seed for torch's global random number generator
crop_size = 280  # size handed to torchvision's RandomCrop
batch_size = 32  # batch size of both data loaders
checkpoint_dir = "checkpoints/"  # forwarded to train.train as checkpoint_dir
latent_dims = 2  # encoder output size and AR-model channel count

# Seed torch's RNG before any model or data loader is created.
torch.manual_seed(seed)

# Use the requested GPU when CUDA is present; otherwise fall back to the CPU so
# the notebook still runs end to end instead of failing later inside training.
device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu")

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping an indexable collection of images.

    Args:
        images: Any object that supports ``len()`` and integer indexing
            (for example a tensor whose first axis enumerates the samples).
        transform: Optional callable applied to each item when it is fetched.
            ``None`` (the default) returns items unchanged.
    """

    def __init__(self, images, transform=None):
        self.images = images
        # Cached once; the wrapped collection is assumed not to change length.
        self.n_images = len(images)
        self.transform = transform

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return self.n_images

    def __getitem__(self, idx):
        """Return item ``idx``, passed through ``transform`` if one was given."""
        image = self.images[idx]
        # Compare against None instead of relying on truthiness: a valid
        # callable can be falsy (e.g. an empty ``nn.Sequential`` has length 0)
        # and would otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
# Load the MNIST dataset rooted at ./data, downloading it if necessary.
mnist = torchvision.datasets.MNIST(root="./data", download=True)

In [None]:
# For each digit label 0-9, collect the 1-D tensor of dataset indices that carry that label.
classes = []
for i in range(10):
    # nonzero() yields one row per match; column 0 is the position in mnist.targets.
    current_class = (mnist.targets == i).nonzero()
    classes.append(current_class[:, 0])

# Upper bound on the number of ten-digit sequences to build.
n_images = 1000

# Build sequences of ten upscaled digits, one per label in label order (0, 1, ..., 9).
# Indices are consumed from the end of each class list, so no sample is used twice.
sorted_mnist = []
classes_ok = True
while classes_ok and len(sorted_mnist) < n_images:
    time_series = []
    for i in range(10):
        # Last remaining sample of digit i; element [0] of the dataset item is used as the image.
        img = mnist[classes[i][-1].item()][0]
        # Convert to a tensor and prepend a leading axis.
        # Assumes pil_to_tensor yields (C, H, W), giving (1, C, H, W) here — TODO confirm.
        img = torchvision.transforms.functional.pil_to_tensor(img)[None]
        # Upscale by a factor of 10 (interpolate's default mode).
        img = torch.nn.functional.interpolate(img, scale_factor=10)
        time_series.append(img)
        # Remove the index that was just used.
        classes[i] = classes[i][:-1]
        # Once any class is exhausted, finish the current sequence and then stop.
        if len(classes[i]) == 0:
            classes_ok = False
    # Stack the ten digits along a new axis 2 (the "time" axis): (1, C, 10, H, W)
    # under the shape assumption above.
    time_series = torch.stack(time_series, dim=2)
    sorted_mnist.append(time_series)
# Concatenate all sequences along axis 0 and convert to float: (N, C, 10, H, W).
sorted_mnist = torch.cat(sorted_mnist, dim=0).to(torch.float)

# Cut each 10-step sequence into two non-overlapping 5-step windows and turn each
# window into its own sample. After the flatten, samples 2k and 2k+1 are the first
# and second window of sequence k.
sorted_mnist = sorted_mnist.unfold(2, 5, 5)  # windows of size 5, step 5 -> (N, C, 2, H, W, 5)
sorted_mnist = sorted_mnist.movedim(2, 1)  # window axis next to the batch axis -> (N, 2, C, H, W, 5)
sorted_mnist = sorted_mnist.flatten(0, 1)  # merge batch and window axes -> (2N, C, H, W, 5)
sorted_mnist = sorted_mnist.movedim(4, 2)  # time axis back to position 2 -> (2N, C, 5, H, W)
print(sorted_mnist.shape)

# Standardise in place with the mean and standard deviation of the whole tensor.
sorted_mnist -= sorted_mnist.mean()
sorted_mnist /= sorted_mnist.std()

# Length of axis 2 (the time axis) of every sample.
n_time = sorted_mnist.shape[2]

In [None]:
# Random crop applied to every sample on access, for both training and validation.
transform = torchvision.transforms.RandomCrop(crop_size)

# Fraction of samples, taken from the front of sorted_mnist, used for training;
# the remainder becomes the validation set. The split index is computed once
# from train_split so that changing the fraction actually changes the split.
train_split = 0.9
n_train = int(len(sorted_mnist) * train_split)
train_set = sorted_mnist[:n_train]
val_set = sorted_mnist[n_train:]

train_set = Dataset(train_set, transform=transform)
val_set = Dataset(val_set, transform=transform)

# Training batches are reshuffled each epoch; validation order is fixed.
# drop_last=True discards a trailing incomplete batch in both loaders.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=True,
)

In [None]:
# Encoder with a single input channel and latent_dims outputs.
encoder = ConvNeXt(
    in_chans=1,
    num_classes=latent_dims,
    dims=[8, 16, 32, 64],
)

# Model over the latent sequence; its kernel size equals the number of time steps.
ar_model = ShiftedConv(
    in_channels=latent_dims,
    out_channels=latent_dims,
    kernel_size=n_time,
)

# n_time - 1 independent linear maps from latent space to latent space.
# How they are applied is determined by train.train.
query_weights = torch.nn.ModuleList(
    [torch.nn.Linear(latent_dims, latent_dims) for _ in range(n_time - 1)]
)

In [None]:
# Gather the parameters of all three components into one flat list, in the
# order encoder, AR model, query weights, and optimise them jointly.
parameters = [
    param
    for module in (encoder, ar_model, query_weights)
    for param in module.parameters()
]
optimiser = torch.optim.AdamW(parameters, lr=4e-3)

In [None]:
# Keyword arguments for the training routine, collected in one place so the
# call itself stays short.
train_kwargs = dict(
    encoder=encoder,
    ar_model=ar_model,
    query_weights=query_weights,
    n_time=n_time,
    optimiser=optimiser,
    max_epochs=max_epochs,
    device=device,
    checkpoint_dir=checkpoint_dir,
    patience=patience,
)

# Run training with the training and validation loaders as positional arguments.
train.train(train_loader, val_loader, **train_kwargs)