In [None]:
# Imports

import os
import re
from argparse import ArgumentParser

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

# PyTorch Lightning
import pytorch_lightning as pl
import seaborn as sns

# PyTorch
import torch
from torch import Tensor, nn, optim
import torch.nn.functional as F
import torch.utils.data as data
import torchtext as tt
from torchtext.vocab import build_vocab_from_iterator

import torchmetrics.functional as metrics

from tqdm.notebook import tqdm
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from nltk.tokenize import RegexpTokenizer
import wandb

# Default to the CPU; replaced by the first CUDA device below when one is available.
DEVICE = torch.device("cpu")

if torch.cuda.is_available():
    # Make cuDNN reproducible: request deterministic algorithms and turn off the
    # benchmark autotuner. (The flag is spelled `deterministic`; a misspelled
    # attribute is accepted silently and has no effect.)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    DEVICE = torch.device("cuda:0")

# Plotting
plt.set_cmap("cividis")
#%matplotlib inline
# NOTE(review): IPython.display.set_matplotlib_formats is reported as deprecated in
# newer IPython releases in favour of matplotlib_inline.backend_inline — confirm
# against the installed version before upgrading.
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10).
# Overridable through the PATH_DATASETS environment variable.
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved.
# Overridable through the PATH_CHECKPOINT environment variable.
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/")

# Fix the global seed so runs are repeatable
pl.seed_everything(42)

# Report which device was selected
print('CUDA:', torch.cuda.is_available())
print("Device:", DEVICE)

In [None]:
# Restore the module from a checkpoint file (the path is a placeholder).
# NOTE(review): MyLightningModule and x are not defined in the visible part of
# this file — they must be provided before this cell is run.
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")

# disable randomness, dropout, etc...
model.eval()

# predict with the model; eval() alone does not stop gradient tracking, so run the
# forward pass under no_grad to avoid building an autograd graph during inference
with torch.no_grad():
    y_hat = model(x)