In [1]:
import os

# PyTorch decides whether unsupported MPS ops may fall back to the CPU when its
# native library is loaded, so this flag must be exported BEFORE `torch` (or
# anything that imports it: avalanche, torchvision, pytorch_lightning, ...) is
# imported. Setting it after the imports has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import lovely_tensors as lt
import pytorch_lightning as pl
import shutup
import torch
from avalanche.benchmarks import SplitMNIST, SplitCIFAR10
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Global notebook setup: silence warnings, fix the RNG seed for reproducible
# runs, and enable lovely_tensors' patched tensor printing.
shutup.please()
pl.seed_everything(42)
lt.monkey_patch()

Global seed set to 42


In [2]:
# NOTE(review): `seed` is not consumed anywhere in this cell; it is kept so
# that any other cell relying on the name keeps working.
seed = 42

# Single source of truth for where CIFAR-10 is stored on disk.
dataset_root = './datasets'

# Every dataset below uses the same preprocessing, so it is defined once.
# Normalize subtracts 0.5 per channel and divides by 1.0, i.e. it only shifts
# the values produced by ToTensor.
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
])

# Continual-learning benchmark: CIFAR-10 split into 5 experiences, with task
# ids returned and no shuffling requested.
benchmark = SplitCIFAR10(
    n_experiences=5,
    return_task_id=True,
    shuffle=False,
    dataset_root=dataset_root,
    train_transform=cifar_transform,
    eval_transform=cifar_transform,
)

# Plain (non-split) CIFAR-10 train/test sets with the same preprocessing.
train_dataset = datasets.CIFAR10(
    root=dataset_root,
    train=True,
    transform=cifar_transform,
)
test_dataset = datasets.CIFAR10(
    root=dataset_root,
    train=False,
    transform=cifar_transform,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
import wandb
import argparse
from train_utils import get_device, add_arguments, get_wandb_params
from src.vq_vae.init_scrips import get_model
from src.utils.train_script import overwrite_config_with_args, parse_arguments
from configparser import ConfigParser
from src.vq_vae.configuration.config import TrainConfig

# --- Base configuration -----------------------------------------------------
config_path = "../src/vq_vae/configuration/train.ini"

ini_config = ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed; fail loudly instead of training with an empty
# configuration.
if not ini_config.read(config_path):
    raise FileNotFoundError(f"Could not read training config: {config_path}")

parser = argparse.ArgumentParser(description="Model trainer")
parser = add_arguments(parser)

# --- Arguments: parse, then apply the notebook-specific overrides -----------
# NOTE(review): parse_arguments is a project helper; confirm it copes with the
# argv that a Jupyter kernel is started with.
args = parse_arguments(parser)
args.accelerator = "mps"
args.train_logger = "int"
args.evaluation_logger = "int"
args.max_epochs = 60
args.min_epochs = 60
args.num_workers = 0
args.regularization_dropout = 0.2
args.regularization_lambda = 0.01
args.learning_rate = 0.0001
args.batch_size = 256
args.best_model_prefix = "artifacts"
args.num_random_noise = 1000
args.model = "vq-vae"

config = TrainConfig.construct_typed_config(ini_config)
overwrite_config_with_args(args, config)

# --- Weights & Biases parameters --------------------------------------------
# W&B is considered active when either logger is "wandb" or a run id was
# supplied. bool() keeps the flag a real boolean instead of leaking the value
# of args.run_id.
is_using_wandb = bool(
    config.train_logger == "wandb"
    or config.evaluation_logger == "wandb"
    or args.run_id
)
if is_using_wandb:
    wandb_params = get_wandb_params(args, config)

    # An explicit experiment name wins; otherwise derive one from the
    # regularization settings.
    run_name = args.experiment_name or (
        f"RI-0."
        f"RN-{config.num_random_noise}."
        f"Dr-{config.regularization_dropout}."
        f"Wd-{config.regularization_lambda}."
    )
    # wandb.run is None until a run has been initialised; only rename a run
    # that actually exists instead of raising AttributeError on None.
    if wandb.run is not None:
        wandb.run.name = run_name
    wandb_params["name"] = run_name
else:
    wandb_params = None

In [4]:
from avalanche.training.plugins import EvaluationPlugin
from train_utils import get_loggers
from src.avalanche.strategies import NaivePytorchLightning

# Pick the device from the training config, then construct the model from the
# same config and that device. Both helpers are project code (train_utils and
# src.vq_vae.init_scrips); the order matters because get_model takes `device`.
device = get_device(config)
vq_vae_model = get_model(config, device)

In [5]:
from src.vq_vae.init_scrips import get_evaluation_plugin, get_callbacks

# Only the first experience of each stream is used in this notebook.
train_experience = next(iter(benchmark.train_stream))
test_experience = next(iter(benchmark.test_stream))

# Loggers for the strategy and for the evaluation plugin.
# NOTE(review): `eval_plugin_loggers` is never handed to the EvaluationPlugin
# below, which is also created without metrics - confirm this is intentional.
cl_strategy_logger, eval_plugin_loggers = get_loggers(config, vq_vae_model, wandb_params)
evaluation_plugin = EvaluationPlugin(suppress_warnings=True)

# Continual-learning strategy, configured from `config`.
cl_strategy = NaivePytorchLightning(
    # model, optimisation and evaluation
    model=vq_vae_model,
    optimizer=vq_vae_model.configure_optimizers(),
    criterion=vq_vae_model.criterion,
    evaluator=evaluation_plugin,
    callbacks=get_callbacks(config),
    plugins=[],
    # hardware
    device=device,
    accelerator=config.accelerator,
    devices=config.devices,
    # mini-batches and data loading
    train_mb_size=config.batch_size,
    eval_mb_size=config.batch_size,
    train_mb_num_workers=config.num_workers,
    accumulate_grad_batches=config.accumulate_grad_batches,
    # schedule
    train_epochs=config.max_epochs,
    max_epochs=config.max_epochs,
    min_epochs=config.min_epochs,
    validate_every_n=config.validate_every_n,
    # logging and checkpoints
    train_logger=cl_strategy_logger,
    initial_resume_from=args.resume_from,
    best_model_path_prefix=config.best_model_prefix,
)

# Train on the first experience, using the matching test experience as the
# evaluation stream.
cl_strategy.train(train_experience, [test_experience])

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

{}

In [None]:

# Evaluate the trained model on the test data of the first experience.
# NOTE: this rebinds `test_dataset`, which an earlier cell set to the full
# CIFAR-10 test split.
test_dataset = test_experience.dataset
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

vq_vae_model.eval()
losses = []  # one Python float per batch

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # Batches may carry extra fields (e.g. task ids); only inputs and
        # targets are needed here.
        x, y, *_ = batch
        x, y = x.to(vq_vae_model.device), y.to(vq_vae_model.device)

        # Call the module itself rather than .forward() so that registered
        # hooks run as usual.
        vq_loss, x_recon, quantized, _, perplexity, logits = vq_vae_model(x)
        _, reconstruction_loss, clf_loss, clf_acc, _ = vq_vae_model.criterion(
            (vq_loss, x_recon, quantized, x, perplexity, logits), y
        )
        loss = vq_loss + reconstruction_loss
        # Store a plain float: avoids keeping device tensors alive and avoids
        # building a tensor out of a list of tensors afterwards.
        losses.append(float(loss))

# Unweighted mean of the per-batch losses.
avg_test_loss = torch.tensor(losses).mean()

In [None]:
from matplotlib import pyplot as plt

# `x_recon` holds the reconstructions of the last evaluated batch and still
# lives on the model's device; matplotlib needs host memory, so move it to the
# CPU first. permute(1, 2, 0) moves the first axis last (assumes a
# channel-first image - TODO confirm), + 0.5 undoes the Normalize(mean=0.5,
# std=1.0) preprocessing, and the clamp keeps values in imshow's valid [0, 1]
# range for float RGB data.
recon_image = (x_recon[0].detach().cpu().permute(1, 2, 0) + 0.5).clamp(0.0, 1.0)
plt.imshow(recon_image.numpy());

In [None]:
# Print each parameter next to its name so the output can be attributed to a
# layer.
for name, param in vq_vae_model.named_parameters():
    print(name, param.data)