<a href="https://colab.research.google.com/github/Jaesu26/vime/blob/main/example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VIME Example

- An example of training VIME-Self and VIME-Semi on a Google Colab GPU

## Clone Repository

In [1]:
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/Colab Notebooks/vime

/content/drive/MyDrive/Colab Notebooks/vime


In [3]:
!git clone https://github.com/Jaesu26/vime.git

Cloning into 'vime'...
remote: Enumerating objects: 150, done.
remote: Counting objects: 100% (150/150), done.
remote: Compressing objects: 100% (98/98), done.
remote: Total 150 (delta 91), reused 106 (delta 50), pack-reused 0
Receiving objects: 100% (150/150), 21.07 KiB | 1.62 MiB/s, done.
Resolving deltas: 100% (91/91), done.


In [4]:
!pip install -r vime/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting lightning>=2.0.0 (from -r vime/requirements.txt (line 3))
  Downloading lightning-2.0.2-py3-none-any.whl (1.8 MB)
Collecting arrow<3.0,>=1.2.0 (from lightning>=2.0.0->-r vime/requirements.txt (line 3))
  Downloading arrow-1.2.3-py3-none-any.whl (66 kB)
Collecting croniter<1.4.0,>=1.3.0 (from lightning>=2.0.0->-r vime/requirements.txt (line 3))
  Downloading croniter-1.3.14-py2.py3-none-any.whl (18 kB)
Collecting dateutils<2.0 (from lightning>=2.0.0->-r vime/requirements.txt (line 3))
  Downloading dateutils-0.6.12-py2.py3-none-any.whl (5.7 kB)
Collecting deepdiff<8.0,>=5.7.0 (from lightning>=2.0.0->-r vime/requirements.txt (line 3))
  Downloa

## Prepare MNIST

In [5]:
!pip install easydict

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
import logging
import os
import random
import warnings

import easydict
import numpy as np
import torch

warnings.filterwarnings("ignore") 

- Hyperparameters

In [7]:
args_self = easydict.EasyDict({
    "weights_dirpath": "./vimeself_weights",
    "max_epochs": 50,
    "batch_size": 512,
    "train_size": 0.9,
    "learning_rate": 1e-2, 
    "p_masking": 0.3,
    "alpha": 2.0,
    "log_interval": 5,
    "seed": 26,
})
args_semi = easydict.EasyDict({
    "weights_dirpath": "./vimesemi_weights",
    "num_classes": 10,
    "task_type": "multiclass",
    "max_epochs": 50,
    "labeled_batch_size": 128,
    "unlabeled_batch_size": 1024,
    "train_size": 0.9,
    "learning_rate": 1e-3, 
    "p_masking": 0.3,
    "K": 3,
    "beta": 1.0,
    "log_interval": 5,
    "seed": 26,
})
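
For context on `p_masking`: in the VIME paper, each feature is masked independently with this probability, and masked entries are replaced with values drawn from the same column of other rows. The following is a minimal NumPy sketch of that corruption step under those assumptions. It is not the repository's implementation, and the names `generate_mask` and `corrupt` are hypothetical.

```python
import numpy as np


def generate_mask(shape, p_masking, rng):
    """Sample a binary mask; each entry is 1 with probability p_masking."""
    return rng.binomial(1, p_masking, size=shape)


def corrupt(x, mask, rng):
    """Replace masked entries with values from a column-wise shuffle of x."""
    x_bar = np.empty_like(x)
    for j in range(x.shape[1]):
        # Shuffle each column independently so feature marginals are preserved.
        x_bar[:, j] = rng.permutation(x[:, j])
    x_tilde = (1 - mask) * x + mask * x_bar
    # An entry only counts as corrupted if its value actually changed.
    new_mask = (x != x_tilde).astype(int)
    return new_mask, x_tilde


rng = np.random.default_rng(26)
x = rng.random((8, 5))  # toy data: 8 rows, 5 features
mask = generate_mask(x.shape, p_masking=0.3, rng=rng)
new_mask, x_tilde = corrupt(x, mask, rng)
```

Here `new_mask` can contain fewer ones than `mask`, because a shuffled value may coincide with the original one.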

In [8]:
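def create_folder(path: str) -> None:
    """Create `path` and any missing parents; print the error instead of raising."""
    try:
        # exist_ok=True makes this a no-op when the directory already exists.
        os.makedirs(path, exist_ok=True)
    except OSError as error:
        print(error)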

In [9]:
create_folder(args_self.weights_dirpath)
create_folder(args_semi.weights_dirpath)

- Load data

In [10]:
import sklearn
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

In [11]:
mnist = fetch_openml("mnist_784")

In [12]:
data = mnist.data.values
target = mnist.target.astype(int).values

In [13]:
data = data / 255.0  # scale pixels to [0, 1]; out-of-place, so integer-typed arrays also work

In [14]:
data.shape

(70000, 784)

- Split data

In [15]:
num_labeled_data_used = 1000
unlabeled_data_rate = 0.9
seed = 26

In [16]:
X, X_test, y, y_test = train_test_split(data, target, test_size=1/7, random_state=seed, stratify=target)

In [17]:
X_labeled, X_unlabeled, y, _ = train_test_split(X, y, test_size=unlabeled_data_rate, random_state=seed, stratify=y)

In [18]:
X_labeled = X_labeled[:num_labeled_data_used]
y = y[:num_labeled_data_used]
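
As a sanity check on the three splitting steps above, the expected subset sizes can be worked out by hand. This sketch assumes that `train_test_split` rounds the test fraction up to a whole number of samples; the variable names are illustrative only.

```python
import math

n_total = 70000  # rows in mnist_784, per data.shape above

# Step 1: hold out 1/7 of the data as the test set.
n_test = math.ceil(n_total * (1 / 7))
n_train = n_total - n_test

# Step 2: treat 90% of the remaining rows as unlabeled.
n_unlabeled = math.ceil(n_train * 0.9)
n_labeled_pool = n_train - n_unlabeled

# Step 3: keep only the first 1000 rows of the labeled pool.
n_labeled = min(1000, n_labeled_pool)

print(n_test, n_unlabeled, n_labeled_pool, n_labeled)
```

Under these assumptions, 10,000 rows go to the test set, 54,000 are unlabeled, and 1,000 of the 6,000 labeled rows are kept.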

## VIME

In [19]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.metrics import accuracy_score

from vime.vime.datamodules import VIMESelfDataModule, VIMESemiDataModule
from vime.vime.lightningmodules import VIMESelf, VIMESemi

### VIME Self

- Create datamodule and model

In [20]:
dim = X_unlabeled.shape[1]

In [21]:
self_datamodule = VIMESelfDataModule(
    X_unlabeled,
    train_size=args_self.train_size,
    batch_size=args_self.batch_size,
    seed=args_self.seed,
)

In [22]:
vime_self = VIMESelf(
    in_features_list=[dim],
    out_features_list=[256],
    learning_rate=args_self.learning_rate,
    p_masking=args_self.p_masking,
    alpha=args_self.alpha,
    log_interval=args_self.log_interval,
    seed=args_self.seed,
)

INFO: Global seed set to 26
INFO:lightning.fabric.utilities.seed:Global seed set to 26


- Train VIME-Self

In [23]:
mc = ModelCheckpoint(
    dirpath=args_self.weights_dirpath,
    filename="vime_self",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
)

es = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
)
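
To make the `patience=10` setting concrete, here is a simplified sketch of patience-based early stopping in `mode="min"`. It is a toy illustration, not Lightning's `EarlyStopping` implementation (which also supports options such as `min_delta`); the helper `stopping_epoch` is hypothetical.

```python
def stopping_epoch(val_losses, patience):
    """Return the index of the epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The loss improves for three epochs, then stops improving.
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.36]
print(stopping_epoch(losses, patience=3))   # 5
print(stopping_epoch(losses, patience=10))  # None
```

With a large patience relative to `max_epochs`, training usually runs until the epoch limit instead of stopping early.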

In [24]:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    callbacks=[mc, es],
    max_epochs=args_self.max_epochs,
    num_sanity_val_steps=0,
    enable_progress_bar=False, 
    enable_model_summary=False,
    deterministic=True,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [25]:
trainer.fit(vime_self, self_datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Val Loss: 0.2607 | Val Loss_m: 0.2199 | Val Loss_r: 0.0204
Epoch 0 | Train Loss: 0.3272 | Train Loss_m: 0.2400 | Train Loss_r: 0.0436
Val Loss: 0.2429 | Val Loss_m: 0.2118 | Val Loss_r: 0.0155
Epoch 5 | Train Loss: 0.2443 | Train Loss_m: 0.2129 | Train Loss_r: 0.0157
Val Loss: 0.2318 | Val Loss_m: 0.2012 | Val Loss_r: 0.0153
Epoch 10 | Train Loss: 0.2327 | Train Loss_m: 0.2019 | Train Loss_r: 0.0154
Val Loss: 0.2255 | Val Loss_m: 0.1955 | Val Loss_r: 0.0150
Epoch 15 | Train Loss: 0.2264 | Train Loss_m: 0.1960 | Train Loss_r: 0.0152
Val Loss: 0.2225 | Val Loss_m: 0.1926 | Val Loss_r: 0.0149
Epoch 20 | Train Loss: 0.2232 | Train Loss_m: 0.1930 | Train Loss_r: 0.0151
Val Loss: 0.2208 | Val Loss_m: 0.1911 | Val Loss_r: 0.0149
Epoch 25 | Train Loss: 0.2213 | Train Loss_m: 0.1912 | Train Loss_r: 0.0150
Val Loss: 0.2194 | Val Loss_m: 0.1900 | Val Loss_r: 0.0147
Epoch 30 | Train Loss: 0.2198 | Train Loss_m: 0.1901 | Train Loss_r: 0.0148
Val Loss: 0.2188 | Val Loss_m: 0.1892 | Val Loss_r

INFO: `Trainer.fit` stopped: `max_epochs=50` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
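
The logged values are consistent with a total loss of the form `Loss = Loss_m + alpha * Loss_r`, with `alpha = 2.0` as set in `args_self`. This relation is inferred from the printed numbers and the objective in the VIME paper, not read from the repository code. A quick check against a few records from the log:

```python
alpha = 2.0  # args_self.alpha

# (total, mask loss, reconstruction loss) copied from the training log.
records = {
    "val, first record": (0.2607, 0.2199, 0.0204),
    "train, epoch 0": (0.3272, 0.2400, 0.0436),
    "train, epoch 5": (0.2443, 0.2129, 0.0157),
}

for name, (total, loss_m, loss_r) in records.items():
    recomputed = loss_m + alpha * loss_r
    # Logged values are rounded to 4 decimals, so allow a small tolerance.
    print(f"{name}: logged={total:.4f} recomputed={recomputed:.4f}")
    assert abs(total - recomputed) < 1e-3
```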


### VIME Semi

- Create datamodule and model

In [26]:
dim = X_labeled.shape[1]

In [27]:
semi_datamodule = VIMESemiDataModule(
    X_unlabeled,
    X_labeled,
    y,
    X_test,
    train_size=args_semi.train_size,
    labeled_batch_size=args_semi.labeled_batch_size,
    unlabeled_batch_size=args_semi.unlabeled_batch_size,
    seed=args_semi.seed,
)

In [28]:
best_model_path = mc.best_model_path

In [29]:
vime_self_best = VIMESelf.load_from_checkpoint(best_model_path)

INFO: Global seed set to 26
INFO:lightning.fabric.utilities.seed:Global seed set to 26


In [30]:
pretrained_encoder = vime_self_best.encoder

In [31]:
vime_semi = VIMESemi(
    pretrained_encoder=pretrained_encoder,
    in_features_list=[256, 128],
    out_features_list=[128, 64],
    num_classes=args_semi.num_classes,
    task_type=args_semi.task_type,
    learning_rate=args_semi.learning_rate,
    p_masking=args_semi.p_masking,
    K=args_semi.K,
    beta=args_semi.beta,
    log_interval=args_semi.log_interval,
    seed=args_semi.seed,
)

INFO: Global seed set to 26
INFO:lightning.fabric.utilities.seed:Global seed set to 26
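
In the VIME paper, `K` is the number of corrupted copies generated per unlabeled sample, and `beta` weights the resulting consistency loss against the supervised loss. The sketch below illustrates that idea with a toy linear model. It is not the repository's `VIMESemi` code, and `consistency_loss`, `predict`, and `corrupt_fn` are hypothetical names.

```python
import numpy as np


def consistency_loss(predict, x, corrupt_fn, K):
    """Mean squared gap between predictions on x and on K corrupted copies."""
    y_clean = predict(x)
    gaps = [np.mean((predict(corrupt_fn(x)) - y_clean) ** 2) for _ in range(K)]
    return float(np.mean(gaps))


rng = np.random.default_rng(26)
W = rng.normal(size=(5, 3))  # toy linear "model": 5 inputs, 3 outputs


def predict(x):
    return x @ W


def corrupt_fn(x):
    # Toy corruption: zero out about 30% of entries (a stand-in for VIME masking).
    return x * (rng.random(x.shape) >= 0.3)


x_unlabeled = rng.random((16, 5))
loss_u = consistency_loss(predict, x_unlabeled, corrupt_fn, K=3)
loss_s = 0.5  # placeholder supervised loss, for illustration only
beta = 1.0
total = loss_s + beta * loss_u
```

If the corruption leaves the input unchanged, the consistency term is zero and only the supervised loss remains.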


- Train VIME-Semi

In [32]:
mc = ModelCheckpoint(
    dirpath=args_semi.weights_dirpath,
    filename="vime_semi",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
)

es = EarlyStopping(
    monitor="val_loss",
    patience=10,
    mode="min",
)

In [33]:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,
    callbacks=[mc, es],
    max_epochs=args_semi.max_epochs,
    num_sanity_val_steps=0,
    enable_progress_bar=False, 
    enable_model_summary=False,
    deterministic=True,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [34]:
trainer.fit(vime_semi, semi_datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Val Loss_s: 2.1408
Epoch 0 | Train Loss: 1.9603 | Train Loss_s: 1.9093 | Train Loss_u: 0.0510
Val Loss_s: 0.8129
Epoch 5 | Train Loss: 0.9396 | Train Loss_s: 0.8622 | Train Loss_u: 0.0774
Val Loss_s: 0.6433
Epoch 10 | Train Loss: 0.6104 | Train Loss_s: 0.5061 | Train Loss_u: 0.1042
Val Loss_s: 0.6121
Epoch 15 | Train Loss: 0.5356 | Train Loss_s: 0.4106 | Train Loss_u: 0.1250
Val Loss_s: 0.5532
Epoch 20 | Train Loss: 0.4913 | Train Loss_s: 0.3541 | Train Loss_u: 0.1373
Val Loss_s: 0.5279
Epoch 25 | Train Loss: 0.3867 | Train Loss_s: 0.2485 | Train Loss_u: 0.1383
Val Loss_s: 0.5183
Epoch 30 | Train Loss: 0.3723 | Train Loss_s: 0.2312 | Train Loss_u: 0.1411
Val Loss_s: 0.5222
Epoch 35 | Train Loss: 0.4436 | Train Loss_s: 0.2998 | Train Loss_u: 0.1437
Val Loss_s: 0.5405
Epoch 40 | Train Loss: 0.3767 | Train Loss_s: 0.2293 | Train Loss_u: 0.1474
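
The training totals in this log match `Train Loss_s + beta * Train Loss_u` with `beta = 1.0` from `args_semi`, up to rounding of the printed values. As with the self-supervised stage, this is inferred from the numbers rather than taken from the repository code. A short check:

```python
beta = 1.0  # args_semi.beta

# (epoch, total, supervised loss, unsupervised loss) copied from the training log.
rows = [
    (0, 1.9603, 1.9093, 0.0510),
    (5, 0.9396, 0.8622, 0.0774),
    (10, 0.6104, 0.5061, 0.1042),
    (15, 0.5356, 0.4106, 0.1250),
]

# Largest difference between a logged total and its recomputed value.
max_gap = max(abs(total - (loss_s + beta * loss_u)) for _, total, loss_s, loss_u in rows)
print(f"largest gap between logged and recomputed totals: {max_gap:.4f}")
```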

In [35]:
pred = trainer.predict(vime_semi, semi_datamodule, ckpt_path=mc.best_model_path)

INFO: Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /content/drive/MyDrive/Colab Notebooks/vime/vimesemi_weights/vime_semi.ckpt


In [36]:
pred = np.concatenate(pred).argmax(1)

In [37]:
accuracy_score(y_test, pred)

0.915
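
For reference, the last two cells turn per-batch class scores into labels with `argmax` and then compute the fraction of labels that match `y_test`. The toy example below reproduces that calculation with plain NumPy on made-up scores; the arrays are illustrative and unrelated to the MNIST results.

```python
import numpy as np

# Toy class scores for 4 samples and 3 classes, split into two prediction batches.
batches = [
    np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]]),
    np.array([[0.3, 0.2, 0.9], [1.1, 0.4, 0.2]]),
]
y_true = np.array([0, 1, 2, 1])

# Stack the batches, then pick the highest-scoring class in each row.
pred = np.concatenate(batches).argmax(axis=1)

# Accuracy is the share of rows where the predicted class equals the true class.
accuracy = float((pred == y_true).mean())
print(pred, accuracy)  # [0 1 2 0] 0.75
```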