<a href="https://colab.research.google.com/github/Jaesu26/vime/blob/main/examples/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VIME Example

`-` An example of training VIME-Self and VIME-Semi on MNIST with a Google Colab GPU

## Install VIME

In [1]:
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/Colab Notebooks/vime

/content/drive/MyDrive/Colab Notebooks/vime


In [3]:
!pip install git+https://github.com/Jaesu26/vime.git

Collecting git+https://github.com/Jaesu26/vime.git
  Cloning https://github.com/Jaesu26/vime.git to /tmp/pip-req-build-s62c5uoq
  Running command git clone --filter=blob:none --quiet https://github.com/Jaesu26/vime.git /tmp/pip-req-build-s62c5uoq
  Resolved https://github.com/Jaesu26/vime.git to commit 23de7979363720f8811ea7cdc1f82dc802bbb828
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting torchmetrics>=0.11.3 (from vime==0.0.1)
  Downloading torchmetrics-1.4.2-py3-none-any.whl.metadata (19 kB)
Collecting lightning>=2.0.0 (from vime==0.0.1)
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning>=2.0.0->vime==0.0.1)
  Downloading lightning_utilities-0.11.7-py3-none-any.whl.metadata (5.2 kB)
Collecting pytorch-lightning (from lightning>=2.0.0->vime==0.0.1)
  Downloading pytorch_lightning-

## Prepare MNIST

In [4]:
import os
import random
import warnings

import easydict
import numpy as np
import torch
import torch.nn as nn

warnings.filterwarnings("ignore")

- Hyperparameters

In [5]:
args_mlp = easydict.EasyDict({
    "weights_dirpath": "./mlp_weights",
    "num_classes": 10,
    "max_epochs": 100,
    "batch_size": 64,
    "train_size": 0.9,
    "lr": 1e-3,
    "log_interval": 5,
    "seed": 26,
})
args_self = easydict.EasyDict({
    "weights_dirpath": "./vimeself_weights",
    "max_epochs": 10,
    "batch_size": 512,
    "train_size": 0.9,
    "lr": 1e-2,
    "p_masking": 0.3,
    "alpha": 2.0,
    "log_interval": 5,
    "seed": 26,
})
args_semi = easydict.EasyDict({
    "weights_dirpath": "./vimesemi_weights",
    "num_classes": 10,
    "supervised_criterion": nn.CrossEntropyLoss(),
    "max_epochs": 100,
    "labeled_batch_size": 64,
    "unlabeled_batch_size": 512,
    "train_size": 0.9,
    "lr": 1e-3,
    "p_masking": 0.3,
    "K": 3,
    "beta": 1.0,
    "log_interval": 5,
    "seed": 26,
})

In [6]:
def create_folder(path: str) -> None:
    """Create `path` (including missing parents); do nothing if it already exists."""
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as error:
        print(error)

In [7]:
create_folder(args_mlp.weights_dirpath)
create_folder(args_self.weights_dirpath)
create_folder(args_semi.weights_dirpath)

- Load data

In [8]:
import sklearn
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

In [9]:
mnist = fetch_openml("mnist_784")

In [10]:
data = mnist.data.values
target = mnist.target.astype(int).values

In [11]:
data = data / 255.0

In [12]:
data.shape

(70000, 784)

- Split data

In [13]:
num_labeled_data_used = 1000  # labeled rows kept for supervised training
unlabeled_data_rate = 0.9  # fraction of the non-test data treated as unlabeled
seed = 26

In [14]:
X, X_test, y, y_test = train_test_split(data, target, test_size=1/7, random_state=seed, stratify=target)

In [15]:
X_labeled, X_unlabeled, y, _ = train_test_split(X, y, test_size=unlabeled_data_rate, random_state=seed, stratify=y)

In [16]:
X_labeled = X_labeled[:num_labeled_data_used]
y = y[:num_labeled_data_used]
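
As a rough sanity check on the split above, the expected subset sizes can be worked out by hand. The sketch below uses integer arithmetic only; the variable names are illustrative, and the exact counts returned by `train_test_split` may differ by one because of its internal rounding.

```python
# Back-of-the-envelope subset sizes for the split above (illustrative names).
n_total = 70_000                      # rows in mnist_784, per data.shape
n_test = n_total // 7                 # test_size=1/7
n_train = n_total - n_test            # rows left after removing the test set
n_unlabeled = n_train * 9 // 10       # unlabeled_data_rate = 0.9
n_labeled_pool = n_train - n_unlabeled
n_labeled_used = min(1_000, n_labeled_pool)  # num_labeled_data_used

print(n_test, n_unlabeled, n_labeled_pool, n_labeled_used)
```

Under these assumptions, only a small fraction of the labeled pool is actually used, while the unlabeled set is far larger, which is the low-label setting VIME targets.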

## Supervised Model

In [17]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import accuracy_score
from vime.datamodules import LabeledDataModule
from vime.lightningmodules import MLPClassifier

- Create datamodule and model

In [18]:
dim = X_labeled.shape[1]

In [19]:
labeled_datamodule = LabeledDataModule(
    X_labeled, y, X_test,
    train_size=args_mlp.train_size,
    batch_size=args_mlp.batch_size,
    seed=args_mlp.seed,
)

In [20]:
mlp_classifier = MLPClassifier(
    input_dim=dim,
    hidden_dims=[256, 128, 64],
    num_classes=args_mlp.num_classes,
    lr=args_mlp.lr,
    log_interval=args_mlp.log_interval,
    seed=args_mlp.seed,
)

INFO: Seed set to 26
INFO:lightning.fabric.utilities.seed:Seed set to 26


- Train supervised model

In [21]:
checkpoint = ModelCheckpoint(
    dirpath=args_mlp.weights_dirpath,
    filename="mlp",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
)
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=20,
    mode="min",
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    callbacks=[checkpoint, early_stop],
    max_epochs=args_mlp.max_epochs,
    num_sanity_val_steps=0,
    enable_progress_bar=False,
    enable_model_summary=False,
    deterministic=True,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
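
The `patience` argument of `EarlyStopping` controls how many validation checks without improvement are tolerated before training stops. The following is a minimal plain-Python sketch of that general rule, not Lightning's implementation; `early_stop_check` and the loss values are hypothetical.

```python
def early_stop_check(val_losses, patience):
    """Return the 1-based validation check at which training would stop, or None."""
    best = float("inf")
    wait = 0  # consecutive checks without a new best loss
    for check, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return check
    return None


# Made-up validation losses, for illustration only.
losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.63]
print(early_stop_check(losses, patience=3))
print(early_stop_check(losses, patience=5))
```

With a larger `patience`, short plateaus in `val_loss` no longer end training early, at the cost of more epochs.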


In [22]:
trainer.fit(mlp_classifier, labeled_datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Val Loss: 1.9018 | Val Macro Acc: 0.6945
Epoch 1 | Train Loss: 1.5351  Val Loss: 0.7285 | Val Macro Acc: 0.8152
Epoch 5 | Train Loss: 0.3352  Val Loss: 0.6142 | Val Macro Acc: 0.8197
Epoch 10 | Train Loss: 0.1090  Val Loss: 0.5879 | Val Macro Acc: 0.8208
Epoch 15 | Train Loss: 0.1373  Val Loss: 0.5575 | Val Macro Acc: 0.8271
Epoch 20 | Train Loss: 0.0902  Val Loss: 0.6153 | Val Macro Acc: 0.8274
Epoch 25 | Train Loss: 0.0700  Val Loss: 0.5352 | Val Macro Acc: 0.8063
Epoch 30 | Train Loss: 0.1725  Val Loss: 0.5334 | Val Macro Acc: 0.8146
Epoch 35 | Train Loss: 0.1954  Val Loss: 0.6503 | Val Macro Acc: 0.7774
Epoch 40 | Train Loss: 0.0954  Val Loss: 0.6027 | Val Macro Acc: 0.7896
Epoch 45 | Train Loss: 0.0734  Val Loss: 0.5819 | Val Macro Acc: 0.8015
Epoch 50 | Train Loss: 0.0413  Val Loss: 0.5844 | Val Macro Acc: 0.8086
Epoch 55 | Train Loss: 0.0333  

- Test supervised model

In [23]:
pred = trainer.predict(mlp_classifier, labeled_datamodule, ckpt_path="best")
pred = np.concatenate(pred).argmax(1)

INFO: Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp-v16.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp-v16.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp-v16.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp-v16.ckpt


In [24]:
accuracy_score(y_test, pred)

0.8817

## VIME

In [25]:
from vime import VIMESelf, VIMESelfDataModule, VIMESemi, VIMESemiDataModule

### VIME Self

- Create datamodule and model

In [26]:
dim = X_unlabeled.shape[1]

In [27]:
self_datamodule = VIMESelfDataModule(
    X_unlabeled,
    train_size=args_self.train_size,
    batch_size=args_self.batch_size,
    seed=args_self.seed,
)

In [28]:
vime_self = VIMESelf(
    input_dim=dim,
    hidden_dims=[256, 128],
    lr=args_self.lr,
    p_masking=args_self.p_masking,
    alpha=args_self.alpha,
    log_interval=args_self.log_interval,
    seed=args_self.seed,
)

INFO: Seed set to 26
INFO:lightning.fabric.utilities.seed:Seed set to 26
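
The `p_masking` argument refers to the corruption step of the VIME pretext task. The sketch below follows the procedure described in the VIME paper (a Bernoulli mask, column-wise shuffling, and a recomputed mask) in plain Python. It is not the `vime` package's implementation, and the function names `mask_generator` and `pretext_generator` as written here are assumptions for illustration.

```python
import random


def mask_generator(p_masking, n_rows, n_cols, rng):
    """Draw a binary mask; each entry is 1 with probability p_masking."""
    return [[1 if rng.random() < p_masking else 0 for _ in range(n_cols)]
            for _ in range(n_rows)]


def pretext_generator(mask, x, rng):
    """Replace masked entries with values from the same column of other rows."""
    n_rows, n_cols = len(x), len(x[0])
    # Shuffle each column independently to sample from its empirical marginal.
    x_bar = [row[:] for row in x]
    for j in range(n_cols):
        column = [x[i][j] for i in range(n_rows)]
        rng.shuffle(column)
        for i in range(n_rows):
            x_bar[i][j] = column[i]
    x_tilde = [[x_bar[i][j] if mask[i][j] else x[i][j] for j in range(n_cols)]
               for i in range(n_rows)]
    # A masked entry can receive its own value back, so recompute the mask.
    mask_new = [[int(x[i][j] != x_tilde[i][j]) for j in range(n_cols)]
                for i in range(n_rows)]
    return mask_new, x_tilde


rng = random.Random(26)
x = [[float(i * 4 + j) for j in range(4)] for i in range(6)]  # toy 6x4 table
mask = mask_generator(0.3, len(x), len(x[0]), rng)
mask_new, x_tilde = pretext_generator(mask, x, rng)
```

The encoder is then trained to recover both `mask_new` (mask estimation) and `x` (feature reconstruction) from `x_tilde`.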


- Train VIME Self

In [29]:
checkpoint = ModelCheckpoint(
    dirpath=args_self.weights_dirpath,
    filename="vime_self",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
)
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    callbacks=[checkpoint, early_stop],
    max_epochs=args_self.max_epochs,
    num_sanity_val_steps=0,
    enable_progress_bar=False,
    enable_model_summary=False,
    deterministic=True,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [30]:
trainer.fit(vime_self, self_datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Val Loss: 0.1394 | Val Loss_m: 0.0923 | Val Loss_r: 0.0235
Epoch 1 | Train Loss: 0.3249 | Train Loss_m: 0.2473 | Train Loss_r: 0.0388  Val Loss: 0.1125 | Val Loss_m: 0.0813 | Val Loss_r: 0.0156
Epoch 5 | Train Loss: 0.2509 | Train Loss_m: 0.2172 | Train Loss_r: 0.0169  

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


Val Loss: 0.1093 | Val Loss_m: 0.0797 | Val Loss_r: 0.0148
Epoch 10 | Train Loss: 0.2461 | Train Loss_m: 0.2149 | Train Loss_r: 0.0156  
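
The logged totals above appear consistent with a combined objective of the form `Loss = Loss_m + alpha * Loss_r`, with `alpha = 2.0` from `args_self`. This is an inference from the printed numbers, not a statement about the package internals; the short check below recomputes the totals from the values in the log.

```python
alpha = 2.0  # args_self.alpha

# (total, mask loss, reconstruction loss), copied from the log above in order
logged = [
    (0.1394, 0.0923, 0.0235),  # first Val line
    (0.3249, 0.2473, 0.0388),  # Epoch 1, train
    (0.1125, 0.0813, 0.0156),  # second Val line
    (0.2509, 0.2172, 0.0169),  # Epoch 5, train
    (0.1093, 0.0797, 0.0148),  # third Val line
    (0.2461, 0.2149, 0.0156),  # Epoch 10, train
]

# Difference between each logged total and the recomputed total.
residuals = [total - (loss_m + alpha * loss_r) for total, loss_m, loss_r in logged]
for residual in residuals:
    print(f"{residual:+.4f}")
```

The residuals stay within what rounding each logged value to four decimals would explain.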

In [31]:
best_model_path = checkpoint.best_model_path

In [32]:
vime_self_best = VIMESelf.load_from_checkpoint(best_model_path)

INFO: Seed set to 26
INFO:lightning.fabric.utilities.seed:Seed set to 26


In [33]:
pretrained_encoder = vime_self_best.encoder

- Train supervised model from pretrained encoder

In [34]:
with torch.no_grad():
    Z = pretrained_encoder(torch.tensor(X_labeled, dtype=torch.float32).cuda())
    Z_test = pretrained_encoder(torch.tensor(X_test, dtype=torch.float32).cuda())

In [35]:
Z = Z.cpu().numpy()
Z_test = Z_test.cpu().numpy()

In [36]:
labeled_datamodule_from_unsupervised = LabeledDataModule(
    Z, y, Z_test,
    train_size=args_mlp.train_size,
    batch_size=args_mlp.batch_size,
    seed=args_mlp.seed,
)

In [37]:
mlp_classifier = MLPClassifier(
    input_dim=Z.shape[1],  # width of the encoder output (128 here)
    hidden_dims=[128, 64],
    num_classes=args_mlp.num_classes,
    lr=args_mlp.lr,
    seed=args_mlp.seed,
)

INFO: Seed set to 26
INFO:lightning.fabric.utilities.seed:Seed set to 26


In [38]:
checkpoint = ModelCheckpoint(
    dirpath=args_mlp.weights_dirpath,
    filename="mlp_from_unsupervised",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
)
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    callbacks=[checkpoint, early_stop],
    max_epochs=args_mlp.max_epochs,
    num_sanity_val_steps=0,
    enable_progress_bar=False,
    enable_model_summary=False,
    deterministic=True,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [39]:
trainer.fit(mlp_classifier, labeled_datamodule_from_unsupervised)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Val Loss: 1.9042 | Val Macro Acc: 0.6532
Epoch 1 | Train Loss: 1.7589  Val Loss: 0.5341 | Val Macro Acc: 0.8254
Epoch 10 | Train Loss: 0.3040  Val Loss: 0.4334 | Val Macro Acc: 0.8854
Epoch 20 | Train Loss: 0.1056  Val Loss: 0.5201 | Val Macro Acc: 0.8090
Epoch 30 | Train Loss: 0.0625  

- Test VIME Self

In [40]:
pred = trainer.predict(mlp_classifier, labeled_datamodule_from_unsupervised, ckpt_path="best")
pred = np.concatenate(pred).argmax(1)

INFO: Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp_from_unsupervised-v9.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp_from_unsupervised-v9.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp_from_unsupervised-v9.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/mlp_weights/mlp_from_unsupervised-v9.ckpt


In [41]:
accuracy_score(y_test, pred)

0.9114

### VIME Semi

- Create datamodule and model

In [42]:
dim = X_labeled.shape[1]

In [43]:
semi_datamodule = VIMESemiDataModule(
    X_unlabeled,
    X_labeled,
    y,
    X_test,
    train_size=args_semi.train_size,
    labeled_batch_size=args_semi.labeled_batch_size,
    unlabeled_batch_size=args_semi.unlabeled_batch_size,
    seed=args_semi.seed,
)

In [44]:
vime_semi = VIMESemi(
    pretrained_encoder=pretrained_encoder,
    hidden_dims=[256, 128, 64],
    num_classes=args_semi.num_classes,
    supervised_criterion=args_semi.supervised_criterion,
    lr=args_semi.lr,
    p_masking=args_semi.p_masking,
    K=args_semi.K,
    beta=args_semi.beta,
    log_interval=args_semi.log_interval,
    seed=args_semi.seed,
)

INFO: Seed set to 26
INFO:lightning.fabric.utilities.seed:Seed set to 26
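
The `K` and `beta` arguments relate to the consistency term of the semi-supervised objective. As described in the VIME paper, each unlabeled sample is corrupted `K` times, the predictions on the corrupted views are pulled toward the prediction on the original sample, and the result is weighted by `beta`. The plain-Python sketch below illustrates that idea under those assumptions; `consistency_loss`, `semi_objective`, and the toy predictions are hypothetical, and the `vime` package may compute the term differently.

```python
def consistency_loss(pred_original, preds_corrupted):
    """Mean squared gap between the prediction on x and on each corrupted view."""
    total, count = 0.0, 0
    for pred_view in preds_corrupted:  # one entry per corrupted view (K in total)
        for p, q in zip(pred_original, pred_view):
            total += (p - q) ** 2
            count += 1
    return total / count


def semi_objective(loss_s, loss_u, beta):
    """Supervised loss plus the beta-weighted consistency loss."""
    return loss_s + beta * loss_u


# Toy class probabilities for one sample and K = 3 corrupted views.
pred_original = [0.7, 0.2, 0.1]
preds_corrupted = [[0.6, 0.3, 0.1], [0.7, 0.1, 0.2], [0.5, 0.3, 0.2]]

loss_u = consistency_loss(pred_original, preds_corrupted)
total = semi_objective(loss_s=0.5, loss_u=loss_u, beta=1.0)
print(f"loss_u={loss_u:.4f}  total={total:.4f}")
```

A larger `beta` puts more weight on agreement across corrupted views relative to fitting the labeled data.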


- Train VIME Semi

In [45]:
checkpoint = ModelCheckpoint(
    dirpath=args_semi.weights_dirpath,
    filename="vime_semi",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
)
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=20,
    mode="min",
)

In [46]:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    callbacks=[checkpoint, early_stop],
    max_epochs=args_semi.max_epochs,
    num_sanity_val_steps=0,
    enable_progress_bar=False,
    enable_model_summary=False,
    deterministic=True,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [47]:
trainer.fit(vime_semi, semi_datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Val Loss_s: 1.2112
Epoch 1 | Train Loss: 1.4965 | Train Loss_s: 1.4658 | Train Loss_u: 0.0307  Val Loss_s: 0.5357
Epoch 5 | Train Loss: 0.4626 | Train Loss_s: 0.3815 | Train Loss_u: 0.0811  Val Loss_s: 0.4713
Epoch 10 | Train Loss: 0.3003 | Train Loss_s: 0.1893 | Train Loss_u: 0.1110  Val Loss_s: 0.4894
Epoch 15 | Train Loss: 0.3164 | Train Loss_s: 0.2001 | Train Loss_u: 0.1163  Val Loss_s: 0.4844
Epoch 20 | Train Loss: 0.3252 | Train Loss_s: 0.2040 | Train Loss_u: 0.1213  Val Loss_s: 0.4324
Epoch 25 | Train Loss: 0.2737 | Train Loss_s: 0.1510 | Train Loss_u: 0.1228  Val Loss_s: 0.4307
Epoch 30 | Train Loss: 0.2563 | Train Loss_s: 0.1239 | Train Loss_u: 0.1324  Val Loss_s: 0.4537
Epoch 35 | Train Loss: 0.2373 | Train Loss_s: 0.1132 | Train Loss_u: 0.1241  Val Loss_s: 0.3816
Epoch 40 | Train Loss: 0.3556 | Train Loss_s: 0.2321 | Train Loss_u: 0.1234  Val Loss_s: 0.4391
Epoch 45 | Train Loss: 0.2350 | Train Loss_s: 0.1156 | Train Loss_u: 0.1194  Val Loss_s: 0.4732
Epoch 50 | Train Loss: 
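
The training lines above appear consistent with `Train Loss = Train Loss_s + beta * Train Loss_u`, with `beta = 1.0` from `args_semi`. This is an inference from the printed numbers rather than a claim about the package internals; the check below recomputes the totals for the complete lines of the log.

```python
beta = 1.0  # args_semi.beta

# (epoch, train loss, supervised loss, unsupervised loss), copied from the log above
logged = [
    (1, 1.4965, 1.4658, 0.0307),
    (5, 0.4626, 0.3815, 0.0811),
    (10, 0.3003, 0.1893, 0.1110),
    (15, 0.3164, 0.2001, 0.1163),
    (20, 0.3252, 0.2040, 0.1213),
    (25, 0.2737, 0.1510, 0.1228),
    (30, 0.2563, 0.1239, 0.1324),
    (35, 0.2373, 0.1132, 0.1241),
    (40, 0.3556, 0.2321, 0.1234),
    (45, 0.2350, 0.1156, 0.1194),
]

# Largest absolute gap between a logged total and the recomputed total.
max_gap = max(abs(total - (loss_s + beta * loss_u))
              for _, total, loss_s, loss_u in logged)
print(f"largest gap: {max_gap:.4f}")
```

Note that `Train Loss_u` rises over the first epochs while `Train Loss_s` falls, so the total is dominated by the supervised term early on and by a mix of both later.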

- Test VIME Semi

In [48]:
pred = trainer.predict(vime_semi, semi_datamodule, ckpt_path="best")
pred = np.concatenate(pred).argmax(1)

INFO: Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi-v6.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi-v6.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi-v6.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi-v6.ckpt


In [49]:
accuracy_score(y_test, pred)

0.9185
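
To put the three test accuracies side by side, the short snippet below converts them to error rates and computes the relative error reduction against the supervised-only MLP. The numbers are copied from the cells above and come from a single run with one seed, so they are best read as indicative rather than conclusive.

```python
# Test accuracies copied from the cells above.
accuracy = {
    "MLP (supervised only)": 0.8817,
    "MLP on VIME Self features": 0.9114,
    "VIME Semi": 0.9185,
}

baseline_error = 1.0 - accuracy["MLP (supervised only)"]
# Fraction of the baseline's test error removed by each model.
reductions = {
    name: (baseline_error - (1.0 - acc)) / baseline_error
    for name, acc in accuracy.items()
}

for name, acc in accuracy.items():
    print(f"{name:<28} error={1.0 - acc:.4f}  "
          f"relative error reduction={reductions[name]:.1%}")
```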