In [None]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm

from torchvision.datasets import MNIST

from models import (
    ViT
)

def seed_everything(seed):
    """Seed every random number generator the notebook depends on.

    Seeds Python's ``random`` module, NumPy's global RNG, the torch CPU RNG and
    the torch CUDA RNGs (current device and all devices), and switches cuDNN
    into deterministic mode. ``cudnn.benchmark`` is left untouched.

    Note: ``PYTHONHASHSEED`` is written to the environment, but the running
    interpreter's hash seed is fixed at start-up, so this only influences
    Python processes launched afterwards.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

def get_transform():
    """Build the input pipeline applied to every MNIST image.

    Converts the image to a tensor with ``ToTensor`` and then normalizes its
    single channel with mean 0.1307 and std 0.3081 (the commonly quoted MNIST
    statistics).

    Returns:
        A ``transforms.Compose`` chaining the two steps in that order.
    """
    channel_mean = (0.1307,)
    channel_std = (0.3081,)
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

def run(device, *, epochs=5, batch_size=64, lr=1e-3, data_root='../data'):
    """Train an RQVAE model on the MNIST training split.

    Args:
        device: torch device the model (and, inside ``train``, each batch)
            is moved to.
        epochs: number of passes over the training set (default 5).
        batch_size: mini-batch size used by the DataLoader (default 64).
        lr: Adam learning rate (default 1e-3).
        data_root: directory MNIST is downloaded to / loaded from.

    Returns:
        The trained RQVAE model.
    """
    # NOTE(review): RQVAE is not imported in the visible source -- the import
    # cell only pulls ViT from `models`. Confirm RQVAE is defined or imported
    # elsewhere before running this cell.
    train_set = MNIST(root=data_root, train=True, download=True, transform=get_transform())
    # Shuffle so each epoch visits the batches in a fresh order; without it
    # every epoch replays the identical batch sequence.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # Architecture hyperparameters are passed through to RQVAE unchanged.
    rqvae = RQVAE(
        in_channels=1,
        embedding_dim=28,
        num_embeddings=128,
        hidden_dims=[16, 64, 128],
    ).to(device)

    optimizer = optim.Adam(rqvae.parameters(), lr=lr)

    for epoch in range(epochs):
        train(device, epoch, rqvae, optimizer, train_loader)

    return rqvae

def train(device, epoch, model, optimizer, train_loader):
    """Run one training epoch and report the mean per-batch loss.

    Args:
        device: torch device each batch is moved to.
        epoch: zero-based epoch index. It is only used for the progress label
            and the printed summary, where it is shown as ``epoch + 1``.
        model: module whose forward returns a 3-tuple. The first and third
            items are passed, together with the input batch, to
            ``model.loss_function``, which must return a dict containing a
            ``'loss'`` tensor.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: sized iterable yielding ``(data, target)`` pairs. The
            target is ignored.

    Returns:
        The mean loss over the epoch's batches as a Python float, or 0.0 if
        the loader yields no batches.
    """
    model.train()  # ensure any dropout/normalization layers are in training mode
    num_batches = len(train_loader)
    running_loss = 0.0

    for data, _ in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
        data = data.to(device).float()

        recon, _, commit_loss = model(data)
        loss = model.loss_function(recon, data, commit_loss)['loss']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a detached Python float. Accumulating the loss tensor
        # itself would chain every batch's autograd history into one growing
        # graph that lives for the whole epoch.
        running_loss += loss.item()

    epoch_loss = running_loss / num_batches if num_batches else 0.0
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f}\n"
    )
    return epoch_loss

In [None]:
SEED = 42
seed_everything(SEED)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# run() takes only the device; pass training overrides as keywords if needed.
run(device)