# V-JEPA 2: CIFAR-10 Fine-tuning with Pretrained Weights

このノートブックでは、V-JEPA 2の事前学習済みモデルを使用してCIFAR-10でファインチューニングを行う。

## 論文情報
- **Title**: V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning
- **arXiv**: 2506.09985
- **Authors**: Mido Assran et al. (Meta AI)
- **Published**: June 2025

## 実装内容
1. V-JEPA 2エンコーダ（ViT）の実装（事前学習済み重みを読み込めない場合のフォールバック）
2. 事前学習済み重みのロード
3. CIFAR-10データセットの準備（Google Colab対応）
4. メモリ効率的なLinear Probing / Full Fine-tuning
5. 学習と評価（バリデーションハング修正済み）

## 1. セットアップと依存関係のインストール

In [None]:
# Detect whether the notebook runs inside Google Colab.
# Only ImportError means "not Colab"; any other error should surface instead of
# being silently treated as a local run (the original used a bare `except:`).
try:
    import google.colab  # noqa: F401  (imported only as an environment probe)
except ImportError:
    IN_COLAB = False
    print("Running locally")
else:
    IN_COLAB = True
    print("Running in Google Colab")

# Report the PyTorch / CUDA / GPU configuration.
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install required libraries (IPython `!` shell magics: only valid inside Jupyter/Colab).
# NOTE(review): timm, einops and wandb are not referenced by any code visible in this notebook.
# NOTE(review): the Hugging Face fallback in load_pretrained_vjepa2 imports `transformers`,
# which is not installed here.
!pip install -q timm einops
!pip install -q torchvision
!pip install -q tqdm
!pip install -q wandb  # optional: for experiment logging

In [None]:
import os
import math
import copy
import inspect
import numpy as np
from typing import Optional, Tuple, List, Dict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

## 2. V-JEPA 2モデルの実装

V-JEPA 2エンコーダ（3D sin-cos位置埋め込みを用いたViT）の簡略化実装。事前学習済み重みを読み込めなかった場合のフォールバックとして使用する（predictorやマスキングは含まない）。

In [None]:
# ============================================================================
# Position Encoding Utilities
# ============================================================================

def get_3d_sincos_pos_embed(
    embed_dim: int,
    grid_size: int,
    grid_depth: int,
    cls_token: bool = False
) -> torch.Tensor:
    """
    Build fixed 3D sinusoidal position embeddings for a (T x H x W) token grid.

    The channel dimension is split into three even-sized parts (time, height,
    width). Each part gets a 1D sin/cos embedding of its axis coordinate, and the
    three parts are concatenated. ``embed_dim`` need not be divisible by 3, but it
    must be even, because three even parts cannot sum to an odd width.

    Args:
        embed_dim: Total embedding width (must be even).
        grid_size: Number of patches along each spatial axis (H == W).
        grid_depth: Number of temporal patches (T).
        cls_token: If True, prepend an all-zero row for a CLS token.

    Returns:
        float32 tensor of shape (grid_depth * grid_size**2 [+ 1], embed_dim).
        Rows are ordered (t, h, w) row-major, i.e. time-major.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(
            f"embed_dim must be even for sin/cos position embeddings, got {embed_dim}"
        )

    # Time gets the largest even width <= embed_dim / 3; height gets the largest
    # even half of the remainder; width takes the rest (even, since all others are).
    dim_t = (embed_dim // 3) // 2 * 2
    dim_h = ((embed_dim - dim_t) // 2) // 2 * 2
    dim_w = embed_dim - dim_t - dim_h

    grid_t = np.arange(grid_depth, dtype=np.float32)
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)

    # (3, T, H, W) coordinate grid -> (T*H*W, 3) rows of (t, h, w).
    grid = np.meshgrid(grid_t, grid_h, grid_w, indexing='ij')
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([3, -1]).T

    pos_embed_t = get_1d_sincos_pos_embed_from_grid(dim_t, grid[:, 0])
    pos_embed_h = get_1d_sincos_pos_embed_from_grid(dim_h, grid[:, 1])
    pos_embed_w = get_1d_sincos_pos_embed_from_grid(dim_w, grid[:, 2])

    pos_embed = np.concatenate([pos_embed_t, pos_embed_h, pos_embed_w], axis=1)

    if cls_token:
        cls_row = np.zeros([1, embed_dim], dtype=pos_embed.dtype)
        pos_embed = np.concatenate([cls_row, pos_embed], axis=0)

    return torch.from_numpy(pos_embed).float()


def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """
    1D sinusoidal position embeddings (Transformer style).

    Args:
        embed_dim: Output width; must be even (first half sin, second half cos).
        pos: Positions of any shape; flattened to M values.

    Returns:
        Array of shape (M, embed_dim) equal to ``[sin(pos * w_i), cos(pos * w_i)]``,
        where ``w_i = 10000 ** (-2i / embed_dim)`` for i in [0, embed_dim / 2).

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)  # (M, embed_dim / 2) outer product

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb


# ============================================================================
# Patch Embedding
# ============================================================================

class PatchEmbed3D(nn.Module):
    """
    Cut a video into non-overlapping tubelets and project each one to ``embed_dim``.

    A single Conv3d whose kernel equals its stride, (tubelet, patch, patch),
    performs both the tubelet extraction and the linear projection.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        tubelet_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 768
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.grid_size = img_size // patch_size

        tubelet = (tubelet_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=tubelet, stride=tubelet)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """
        Args:
            x: video batch of shape (B, C, T, H, W).

        Returns:
            tokens of shape (B, T'*H'*W', embed_dim) and the grid size (T', H', W').
        """
        _b, _c, _t, _h, _w = x.shape  # enforce a 5-D (B, C, T, H, W) input
        feat = self.proj(x)  # (B, D, T', H', W')
        grid_thw = (feat.shape[2], feat.shape[3], feat.shape[4])
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid_thw


# ============================================================================
# Transformer Components
# ============================================================================

class Attention(nn.Module):
    """
    Multi-head self-attention over a token sequence.

    Uses ``F.scaled_dot_product_attention`` when available and an explicit
    softmax(q @ k^T * scale) @ v otherwise. Both paths honour ``qk_scale``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0
    ):
        """
        Args:
            dim: Token width; must be divisible by ``num_heads`` (required by the
                per-head reshape in ``forward``).
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused q/k/v projection has a bias.
            qk_scale: Multiplier applied to q @ k^T; defaults to head_dim ** -0.5.
            attn_drop: Dropout probability on attention weights.
            proj_drop: Dropout probability after the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # `is not None` (rather than `or`) so any explicitly passed qk_scale is honoured.
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map tokens of shape (B, N, C) to attended tokens of shape (B, N, C)."""
        B, N, C = x.shape

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if hasattr(F, 'scaled_dot_product_attention'):
            # SDPA always uses 1/sqrt(head_dim) by default. Pre-scale q so the
            # effective factor equals self.scale; this is skipped for the default
            # scale, so the common path is unchanged.
            default_scale = q.size(-1) ** -0.5
            if self.scale != default_scale:
                q = q * (self.scale / default_scale)
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0
    ):
        """
        Args:
            in_features: Input width.
            hidden_features: Hidden width; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            drop: Dropout probability applied after each linear layer.
        """
        super().__init__()
        width_out = in_features if not out_features else out_features
        width_hidden = in_features if not hidden_features else hidden_features

        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Block(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), followed by x + MLP(LN(x))."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0
    ):
        """
        Args:
            dim: Token width.
            num_heads: Attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim`` (truncated to int).
            qkv_bias: Bias on the fused q/k/v projection.
            qk_scale: Optional override of the attention scale.
            drop: Dropout after the attention projection and inside the MLP.
            attn_drop: Dropout on attention weights.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


# ============================================================================
# Vision Transformer
# ============================================================================

class VisionTransformer(nn.Module):
    """
    ViT encoder used as the local V-JEPA 2 backbone.

    Pipeline: 3D tubelet embedding -> optional CLS token -> fixed 3D sin-cos
    position embedding -> ``depth`` pre-norm transformer blocks -> LayerNorm.

    The position table is precomputed for exactly
    ``(num_frames // tubelet_size) * (img_size // patch_size) ** 2`` patch tokens,
    so inputs must produce that many tokens.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        tubelet_size: int = 2,
        in_channels: int = 3,
        num_frames: int = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        use_cls_token: bool = False
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.use_cls_token = use_cls_token

        self.patch_embed = PatchEmbed3D(
            img_size=img_size,
            patch_size=patch_size,
            tubelet_size=tubelet_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        grid_size = img_size // patch_size
        grid_depth = num_frames // tubelet_size

        # Learnable CLS token prepended to the patch tokens (only if requested).
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

        # Fixed (non-learnable) sin-cos table, registered as a buffer so it moves
        # with .to(device) and is stored in the state_dict.
        pos_embed = get_3d_sincos_pos_embed(
            embed_dim,
            grid_size,
            grid_depth,
            cls_token=use_cls_token
        )
        self.register_buffer('pos_embed', pos_embed.unsqueeze(0))

        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialise CLS token, patch projection, Linear and LayerNorm layers."""
        if self.cls_token is not None:
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Treat the Conv3d kernel as a (out, in*kt*kh*kw) matrix, as for a Linear layer.
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        self.apply(self._init_layer_weights)

    def _init_layer_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor, return_all_tokens: bool = False) -> torch.Tensor:
        """
        Encode a video batch.

        Args:
            x: Video tensor of shape (B, C, T, H, W).
            return_all_tokens: If True, return the full normalised token sequence.

        Returns:
            (B, embed_dim) CLS features when the model has a CLS token and
            ``return_all_tokens`` is False; otherwise (B, N, embed_dim) tokens.

        Raises:
            ValueError: if the input yields a token count different from the
                precomputed position embedding.
        """
        B = x.shape[0]

        x, _ = self.patch_embed(x)

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)

        # Fail with an actionable message instead of an opaque broadcasting error.
        if x.shape[1] != self.pos_embed.shape[1]:
            raise ValueError(
                f"Input produced {x.shape[1]} tokens but the position embedding has "
                f"{self.pos_embed.shape[1]}; expected video input of shape "
                f"(B, C, {self.num_frames}, {self.img_size}, {self.img_size})"
            )

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        if return_all_tokens or self.cls_token is None:
            return x
        return x[:, 0]


print("V-JEPA 2 model components defined")

## 3. V-JEPA 2モデルのビルド

ViT-L / ViT-H / ViT-g構成に対応（このノートブックではViT-Lを使用）

In [None]:
def build_vjepa2_encoder(model_size='vitl', num_frames=16):
    """
    Build a randomly initialised V-JEPA 2 encoder.

    The model takes 224px input with patch 16 and tubelet 2, and has a CLS token.

    Args:
        model_size: 'vitl' (~300M), 'vith' (~600M), or 'vitg' (~1B)
        num_frames: Number of input frames; the position table is sized for this.

    Returns:
        A ``VisionTransformer`` instance.

    Raises:
        ValueError: if ``model_size`` is not one of the sizes above.
    """
    configs = {
        'vitl': {
            'embed_dim': 1024,
            'depth': 24,
            'num_heads': 16,
            'mlp_ratio': 4.0,
        },
        'vith': {
            'embed_dim': 1280,
            'depth': 32,
            'num_heads': 16,
            'mlp_ratio': 4.0,
        },
        'vitg': {
            'embed_dim': 1408,
            'depth': 40,
            'num_heads': 16,
            # ViT-g uses an MLP width of 6144 = 1408 * 48/11 rather than 4x;
            # int(1408 * (48 / 11)) evaluates to 6144 in IEEE double arithmetic.
            'mlp_ratio': 48 / 11,
        }
    }

    if model_size not in configs:
        raise ValueError(f"Unknown model size: {model_size}")

    config = configs[model_size]

    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        tubelet_size=2,
        in_channels=3,
        num_frames=num_frames,
        embed_dim=config['embed_dim'],
        depth=config['depth'],
        num_heads=config['num_heads'],
        mlp_ratio=config['mlp_ratio'],
        use_cls_token=True
    )

    return model


MODEL_SIZE = 'vitl'
NUM_FRAMES = 4

# Randomly initialised local encoder; it is the fallback whenever pretrained
# weights cannot be turned into a model in the next cell.
encoder = build_vjepa2_encoder(MODEL_SIZE, NUM_FRAMES)

total_params = sum(map(torch.Tensor.numel, encoder.parameters()))
trainable_params = sum(p.numel() for p in filter(lambda p: p.requires_grad, encoder.parameters()))

print(
    f"Model: V-JEPA 2 {MODEL_SIZE.upper()}",
    f"Total parameters: {total_params:,}",
    f"Trainable parameters: {trainable_params:,}",
    f"Number of frames: {NUM_FRAMES}",
    sep="\n",
)
print(f"Number of frames: {NUM_FRAMES}")

## 4. 事前学習済み重みのダウンロードとロード

Meta AIの公式リポジトリから事前学習済みモデルをダウンロード

In [None]:
import urllib.request
import os

def load_pretrained_vjepa2(model_size='vitl'):
    """
    Load pretrained V-JEPA 2 weights, trying three sources in order.

    1. PyTorch Hub (``facebookresearch/vjepa2``): returns whatever the hub entry
       point returns. The caller also handles a tuple.
    2. Hugging Face ``transformers.AutoModel``: returns the model.
    3. Direct download of the raw checkpoint: returns the loaded checkpoint
       object, NOT a model; building a model from it is left to the caller.

    Available models:
    - vitl: ViT-L/16 (300M params, 256 resolution)
    - vith: ViT-H/16 (600M params, 256 resolution)
    - vitg: ViT-g/16 (1B params, 256 resolution)
    - vitg_384: ViT-g/16 (1B params, 384 resolution)

    Args:
        model_size: One of the keys listed above.

    Returns:
        The first successful result, or None if ``model_size`` is unknown or
        every source failed. Failures are printed, not raised.
    """

    hub_model_mapping = {
        'vitl': 'vjepa2_vit_large',
        'vith': 'vjepa2_vit_huge',
        'vitg': 'vjepa2_vit_giant',
        'vitg_384': 'vjepa2_vit_giant_384',
    }

    hf_model_mapping = {
        'vitl': 'facebook/vjepa2-vitl-fpc64-256',
        'vith': 'facebook/vjepa2-vith-fpc64-256',
        'vitg': 'facebook/vjepa2-vitg-fpc64-256',
        'vitg_384': 'facebook/vjepa2-vitg-fpc64-384',
    }

    direct_urls = {
        'vitl': 'https://dl.fbaipublicfiles.com/vjepa2/vitl.pt',
        'vith': 'https://dl.fbaipublicfiles.com/vjepa2/vith.pt',
        'vitg': 'https://dl.fbaipublicfiles.com/vjepa2/vitg.pt',
        'vitg_384': 'https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt',
    }

    if model_size not in hub_model_mapping:
        print(f"Unknown model size: {model_size}")
        print(f"Available sizes: {list(hub_model_mapping.keys())}")
        return None

    # 1) PyTorch Hub
    try:
        print(f"Loading V-JEPA 2 {model_size.upper()} from PyTorch Hub...")
        hub_model_name = hub_model_mapping[model_size]

        encoder = torch.hub.load(
            'facebookresearch/vjepa2',
            hub_model_name,
            trust_repo=True
        )

        print(f"Successfully loaded {hub_model_name}")
        return encoder

    except Exception as e:
        print(f"PyTorch Hub loading failed: {e}")
        print("Trying alternative methods...")

    # 2) Hugging Face
    try:
        print(f"Loading V-JEPA 2 {model_size.upper()} from Hugging Face...")
        from transformers import AutoModel

        hf_model_name = hf_model_mapping[model_size]
        # NOTE(review): trust_remote_code lets the Hub repo execute code locally;
        # keep it only for trusted repositories.
        encoder = AutoModel.from_pretrained(hf_model_name, trust_remote_code=True)

        print(f"Successfully loaded {hf_model_name}")
        return encoder

    except ImportError:
        print("transformers not installed. Install with: pip install transformers")
    except Exception as e:
        print(f"Hugging Face loading failed: {e}")

    # 3) Direct download of the raw checkpoint
    try:
        print(f"Downloading V-JEPA 2 {model_size.upper()} weights...")
        os.makedirs('./pretrained_weights', exist_ok=True)

        url = direct_urls[model_size]
        save_path = f'./pretrained_weights/{model_size}_vjepa2.pt'

        if not os.path.exists(save_path):
            print(f"Downloading from {url}...")
            # Download under a temporary name and rename only after success, so an
            # interrupted download never leaves a truncated file that a later run
            # would treat as a valid cached checkpoint.
            tmp_path = save_path + '.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, save_path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            print(f"Downloaded to {save_path}")
        else:
            print(f"Found cached weights at {save_path}")

        # weights_only=True refuses to execute arbitrary pickled code from the
        # downloaded file. If the checkpoint needs other types, or torch is too old
        # for this argument, the error lands in the handler below and None is returned.
        checkpoint = torch.load(save_path, map_location='cpu', weights_only=True)

        print("Direct download successful, but requires manual model construction")
        print("Please use PyTorch Hub or Hugging Face for automatic loading")
        return checkpoint

    except Exception as e:
        print(f"Direct download failed: {e}")

    return None


print("=" * 70, "Loading V-JEPA 2 Pretrained Weights", "=" * 70, sep="\n")

print(f"\nModel: V-JEPA 2 {MODEL_SIZE.upper()}")
print(f"Parameters: ~{['300M', '600M', '1B', '1B'][['vitl', 'vith', 'vitg', 'vitg_384'].index(MODEL_SIZE)]}")
print(f"Frames: {NUM_FRAMES}\n")

pretrained_result = load_pretrained_vjepa2(MODEL_SIZE)

# Normalise whatever the loader returned into an nn.Module. Any result that is
# not usable as a model falls back to the randomly initialised `encoder`.
if pretrained_result is None:
    print(
        "\n" + "=" * 70,
        "WARNING: Could not load pretrained weights",
        "=" * 70,
        "Proceeding with random initialization",
        "Note: Results will be significantly worse without pretrained weights",
        "\nTroubleshooting:",
        "1. Check internet connection",
        "2. Install transformers: pip install transformers",
        "3. Try running: git clone https://github.com/facebookresearch/vjepa2",
        "=" * 70,
        sep="\n",
    )
    pretrained_encoder = encoder
elif isinstance(pretrained_result, tuple):
    # Hub entry points may return several objects; the encoder is taken to be first.
    print("\nReceived tuple from pretrained loader, extracting model...")
    pretrained_encoder = pretrained_result[0] if pretrained_result else encoder
elif isinstance(pretrained_result, dict):
    # A raw checkpoint: mapping it onto the local encoder is not implemented.
    print(
        "\nReceived checkpoint dict, using encoder with random weights",
        "(Manual state_dict loading required - not implemented)",
        sep="\n",
    )
    pretrained_encoder = encoder
elif hasattr(pretrained_result, 'parameters'):
    pretrained_encoder = pretrained_result
    print(
        "\n" + "=" * 70,
        "SUCCESS: Using official V-JEPA 2 pretrained weights",
        "=" * 70,
        f"Model: {MODEL_SIZE.upper()}",
        "Source: Meta AI / FAIR",
        "Pretrained on: Video datasets",
        "=" * 70,
        sep="\n",
    )
else:
    print(
        f"\nUnexpected type from pretrained loader: {type(pretrained_result)}",
        "Using encoder with random weights",
        sep="\n",
    )
    pretrained_encoder = encoder

encoder = pretrained_encoder

print("\nEncoder ready for CIFAR-10 fine-tuning")

## 5. 分類ヘッドの追加（CIFAR-10: 10クラス）

事前学習済みエンコーダに分類ヘッドを追加

In [None]:
class ImageNetClassifier(nn.Module):
    """
    Pretrained video encoder + linear classification head for image classification.

    Despite the name, the head size is set by ``num_classes`` (10 for CIFAR-10 here).
    Still images (B, C, H, W) are turned into a static clip by repeating them
    ``num_frames`` times along a new time axis before encoding.
    """

    def __init__(
        self,
        encoder: nn.Module,
        num_classes: int = 10,
        freeze_encoder: bool = True,
        use_video_frames: bool = False,
        num_frames: int = 1
    ):
        """
        Args:
            encoder: Video encoder. It must expose ``embed_dim`` or a Hugging
                Face-style ``config.hidden_size``.
            num_classes: Output classes of the linear head.
            freeze_encoder: If True, encoder parameters get ``requires_grad=False``
                (linear probing).
            use_video_frames: Stored only; not used by ``forward``.
            num_frames: Times a still image is repeated along the time axis.

        Raises:
            AttributeError: if the encoder's feature width cannot be determined.
        """
        super().__init__()
        self.encoder = encoder
        self.num_classes = num_classes
        self.use_video_frames = use_video_frames
        self.num_frames = num_frames

        forward_sig = inspect.signature(encoder.forward)
        self.supports_return_all_tokens = 'return_all_tokens' in forward_sig.parameters
        # Token 0 is only a summary token if the encoder really has a CLS token;
        # otherwise it is just the first patch and pooling must average instead.
        self.has_cls_token = getattr(encoder, 'cls_token', None) is not None

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
            print("Encoder frozen (Linear Probing mode)")
        else:
            print("Encoder unfrozen (Full Fine-tuning mode)")

        embed_dim = self._infer_embed_dim(encoder)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

    @staticmethod
    def _infer_embed_dim(encoder: nn.Module) -> int:
        """Return the encoder feature width from ``embed_dim`` or ``config.hidden_size``."""
        embed_dim = getattr(encoder, 'embed_dim', None)
        if embed_dim is None:
            embed_dim = getattr(getattr(encoder, 'config', None), 'hidden_size', None)
        if embed_dim is None:
            raise AttributeError(
                "Cannot infer encoder embedding dimension: expected `embed_dim` "
                "or `config.hidden_size` on the encoder"
            )
        return embed_dim

    def _pool(self, features) -> torch.Tensor:
        """Reduce the encoder output to one (B, D) feature vector per sample."""
        # Hugging Face models return a ModelOutput; some hub models return tuples.
        if hasattr(features, 'last_hidden_state'):
            features = features.last_hidden_state
        elif isinstance(features, (tuple, list)):
            features = features[0]
        if features.dim() == 3:
            # (B, N, D) token sequence: CLS token if present, else mean over tokens.
            features = features[:, 0] if self.has_cls_token else features.mean(dim=1)
        return features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Images (B, C, H, W) or videos (B, C, T, H, W).

        Returns:
            Logits of shape (B, num_classes).
        """
        if x.dim() == 4:
            x = x.unsqueeze(2).repeat(1, 1, self.num_frames, 1, 1)

        # NOTE(review): Hugging Face V-JEPA 2 models may expect a different input
        # layout/keyword than (B, C, T, H, W) passed positionally; TODO confirm.
        if self.supports_return_all_tokens:
            features = self.encoder(x, return_all_tokens=False)
        else:
            features = self.encoder(x)

        logits = self.head(self._pool(features))

        return logits


NUM_CLASSES = 10
FREEZE_ENCODER = True

# Choose the device first so the classifier can be moved there right after construction.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = ImageNetClassifier(
    encoder=pretrained_encoder,
    num_classes=NUM_CLASSES,
    freeze_encoder=FREEZE_ENCODER,
    num_frames=NUM_FRAMES,
).to(device)

total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(p.numel() for p in filter(lambda p: p.requires_grad, model.parameters()))

print(
    f"\nTotal parameters: {total_params:,}",
    f"Trainable parameters: {trainable_params:,}",
    f"Frozen parameters: {total_params - trainable_params:,}",
    f"Target classes: {NUM_CLASSES}",
    sep="\n",
)
print(f"Target classes: {NUM_CLASSES}")

## 6. CIFAR-10データセットの準備

CIFAR-10データセットをロードして前処理（Google Colab最適化済み）

In [None]:
from torchvision.datasets import CIFAR10

def get_transforms(is_training=True):
    """
    Build the CIFAR-10 preprocessing pipeline.

    Training: random 32px crop with 4px padding, random horizontal flip, resize
    to 224, then colour jitter. Evaluation: resize to 224 only. Both pipelines
    finish with ToTensor and ImageNet mean/std normalisation.
    """
    if is_training:
        spatial_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(224),
            transforms.ColorJitter(0.4, 0.4, 0.4),
        ]
    else:
        spatial_steps = [transforms.Resize(224)]

    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]
    return transforms.Compose(spatial_steps + to_normalized_tensor)


BATCH_SIZE = 64
NUM_WORKERS = 2

print("Loading CIFAR-10 dataset...")
train_dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=get_transforms(is_training=True)
)
val_dataset = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=get_transforms(is_training=False)
)

# Shared DataLoader settings:
# - persistent_workers needs worker processes. DataLoader raises ValueError if it
#   is True while num_workers == 0, so derive it from NUM_WORKERS.
# - pinned host memory only speeds up host-to-GPU copies, so enable it only with CUDA.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=NUM_WORKERS > 0,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

print(f"CIFAR-10 dataset loaded")
print(f"Train samples: {len(train_dataset):,}")
print(f"Val samples: {len(val_dataset):,}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Number of workers: {NUM_WORKERS}")

## 7. 学習設定

In [None]:
EPOCHS = 1
LEARNING_RATE = 0.001 if FREEZE_ENCODER else 0.0001  # larger LR when only the head trains
WEIGHT_DECAY = 0.0001
WARMUP_EPOCHS = 5

criterion = nn.CrossEntropyLoss()

# All parameters are registered. Frozen encoder weights never receive gradients,
# and AdamW skips parameters whose .grad is None.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

if WARMUP_EPOCHS >= EPOCHS:
    print(f"Warning: WARMUP_EPOCHS ({WARMUP_EPOCHS}) >= EPOCHS ({EPOCHS}); "
          "the cosine decay phase will never run")

# Cosine decay over the post-warmup epochs. EPOCHS - WARMUP_EPOCHS can be zero or
# negative (it is -4 with the values above), which is not a valid period, so clamp to >= 1.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max(1, EPOCHS - WARMUP_EPOCHS)
)


def warmup_lr_scheduler(optimizer, warmup_epochs, base_lr):
    """
    Linear warmup built on LambdaLR.

    Epoch e (0-based) uses (e + 1) / warmup_epochs times the optimizer's initial
    LR, then 1.0 afterwards. ``base_lr`` is unused and kept for API compatibility.
    """
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 1.0
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


warmup_scheduler = warmup_lr_scheduler(optimizer, WARMUP_EPOCHS, LEARNING_RATE)

print("Training Configuration:")
print(f"Epochs: {EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Weight decay: {WEIGHT_DECAY}")
print(f"Warmup epochs: {WARMUP_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Dataset: CIFAR-10")
print(f"Classes: {NUM_CLASSES}")
print(f"Training mode: {'Linear Probing' if FREEZE_ENCODER else 'Full Fine-tuning'}")

## 8. 学習・評価関数

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """
    Run one optimisation pass over ``train_loader``.

    Cached CUDA memory is released every 50 batches to keep peak usage down.

    Returns:
        (mean of per-batch losses, accuracy in percent over all samples seen)
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for step, (inputs, targets) in enumerate(progress):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predicted = logits.max(1).indices
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        if step % 10 == 0:
            progress.set_postfix({
                'loss': f"{batch_loss.item():.4f}",
                'acc': f"{100.*n_correct/n_seen:.2f}%",
            })

        if step % 50 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

    return loss_sum / len(train_loader), 100. * n_correct / n_seen


@torch.no_grad()
def evaluate(model, val_loader, criterion, device, epoch):
    """
    Score ``model`` on ``val_loader`` without tracking gradients.

    Top-5 accuracy is accumulated only when the dataset exposes a ``classes``
    list with at least 10 entries; otherwise it is reported as 0.0. To keep
    memory flat, cached CUDA memory is released every 20 batches and the Python
    garbage collector runs every 50 batches.

    Returns:
        (mean of per-batch losses, top-1 accuracy %, top-5 accuracy % or 0.0)
    """
    import gc

    model.eval()

    dataset = val_loader.dataset
    compute_top5 = hasattr(dataset, 'classes') and len(dataset.classes) >= 10

    loss_sum = 0.0
    hits_top1 = 0
    hits_top5 = 0
    n_seen = 0

    progress = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)

    for step, (inputs, targets) in enumerate(progress):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        loss_sum += batch_loss.item()

        hits_top1 += logits.max(1).indices.eq(targets).sum().item()
        n_seen += targets.size(0)

        if compute_top5:
            top5_idx = logits.topk(5, dim=1).indices  # (B, 5)
            hits_top5 += top5_idx.eq(targets.unsqueeze(1)).sum().item()

        if step % 5 == 0:
            postfix = {
                'loss': f"{batch_loss.item():.4f}",
                'acc': f"{100.*hits_top1/n_seen:.2f}%",
            }
            if compute_top5:
                postfix['top5'] = f"{100.*hits_top5/n_seen:.2f}%"
            progress.set_postfix(postfix)

        if step % 20 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

        if step % 50 == 0:
            gc.collect()

    avg_loss = loss_sum / len(val_loader)
    accuracy_top1 = 100. * hits_top1 / n_seen
    accuracy_top5 = 100. * hits_top5 / n_seen if compute_top5 else 0.0

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return avg_loss, accuracy_top1, accuracy_top5


print("Training functions defined")

## 9. 学習実行

In [None]:
import gc

# Per-epoch metrics; consumed by the plotting and checkpoint cells that follow.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc_top1': [],
    'val_acc_top5': [],
    'lr': []
}

# -inf (not 0.0) guarantees the first evaluated epoch is saved as "best", so
# 'vjepa2_cifar10_best.pth' exists whenever at least one epoch has run.
best_val_acc = float('-inf')
best_epoch = 0

print("\n" + "="*70)
print("Starting Training")
print("="*70)

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 70)

    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )

    val_loss, val_acc_top1, val_acc_top5 = evaluate(
        model, val_loader, criterion, device, epoch
    )

    # Linear warmup for the first WARMUP_EPOCHS epochs, cosine decay afterwards.
    if epoch < WARMUP_EPOCHS:
        warmup_scheduler.step()
    else:
        scheduler.step()

    # Read after stepping, so this is the LR the *next* epoch will use.
    current_lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc_top1'].append(val_acc_top1)
    history['val_acc_top5'].append(val_acc_top5)
    history['lr'].append(current_lr)

    print(f"\nResults:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f} | Val Top-1: {val_acc_top1:.2f}%")
    if val_acc_top5 > 0:
        print(f"  Val Top-5: {val_acc_top5:.2f}%")
    print(f"  Learning Rate: {current_lr:.6f}")

    if torch.cuda.is_available():
        print(f"  GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.max_memory_allocated()/1e9:.2f}GB")

    if val_acc_top1 > best_val_acc:
        best_val_acc = val_acc_top1
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc_top1': val_acc_top1,
            'val_acc_top5': val_acc_top5,
        }, 'vjepa2_cifar10_best.pth')
        print(f"  Best model saved (Val Acc: {best_val_acc:.2f}%)")

    if (epoch + 1) % 5 == 0:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print("  Memory cleanup performed")

    print("-" * 70)

print("\n" + "="*70)
print("Training Complete")
print("="*70)
print(f"Best Validation Accuracy: {best_val_acc:.2f}% (Epoch {best_epoch+1})")

## 10. 結果の可視化

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# x positions are 1-based epoch numbers, matching the "Epoch k/N" training logs.
epoch_axis = list(range(1, len(history['train_loss']) + 1))
# marker='o' keeps short runs visible: with EPOCHS = 1 every series is a single
# point, which a marker-less line plot draws as nothing.

axes[0, 0].plot(epoch_axis, history['train_loss'], marker='o', label='Train Loss')
axes[0, 0].plot(epoch_axis, history['val_loss'], marker='o', label='Val Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

axes[0, 1].plot(epoch_axis, history['train_acc'], marker='o', label='Train Acc')
axes[0, 1].plot(epoch_axis, history['val_acc_top1'], marker='o', label='Val Acc')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True)

if any(acc > 0 for acc in history['val_acc_top5']):
    axes[1, 0].plot(epoch_axis, history['val_acc_top5'], marker='o',
                    label='Val Top-5 Acc', color='green')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Accuracy (%)')
    axes[1, 0].set_title('Top-5 Accuracy')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
else:
    # evaluate() only computes top-5 when the dataset lists >= 10 classes.
    axes[1, 0].text(0.5, 0.5, 'Top-5 not computed\n(needs >= 10 dataset classes)',
                    ha='center', va='center', transform=axes[1, 0].transAxes)
    axes[1, 0].set_title('Top-5 Accuracy')

axes[1, 1].plot(epoch_axis, history['lr'], marker='o', color='orange')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('Learning Rate Schedule')
axes[1, 1].set_yscale('log')
axes[1, 1].grid(True)

plt.tight_layout()
plt.savefig('vjepa2_cifar10_training.png', dpi=300, bbox_inches='tight')
plt.show()

print("Training curves saved to 'vjepa2_cifar10_training.png'")

## 11. モデルの保存とロード

In [None]:
# Save the final-epoch weights together with training history and configuration.
torch.save({
    'epoch': EPOCHS,  # number of completed epochs (the best checkpoint stores a 0-based index)
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'history': history,
    'config': {
        'model_size': MODEL_SIZE,
        'num_classes': NUM_CLASSES,
        'freeze_encoder': FREEZE_ENCODER,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
        'dataset': 'CIFAR-10'
    }
}, 'vjepa2_cifar10_final.pth')

print("Final model saved to 'vjepa2_cifar10_final.pth'")

# Restore the best-validation weights, if that checkpoint was written.
if os.path.exists('vjepa2_cifar10_best.pth'):
    # map_location lets the checkpoint load even when the runtime's device changed
    # (e.g. saved on GPU, reloaded after a restart on a CPU-only runtime).
    checkpoint = torch.load('vjepa2_cifar10_best.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"\nBest model loaded")
    print(f"Epoch: {checkpoint['epoch']+1}")
    print(f"Val Acc: {checkpoint['val_acc_top1']:.2f}%")
    if checkpoint['val_acc_top5'] > 0:
        print(f"Val Top-5 Acc: {checkpoint['val_acc_top5']:.2f}%")
else:
    print("\nBest checkpoint 'vjepa2_cifar10_best.pth' not found; keeping final-epoch weights")

## 12. 推論テスト

In [None]:
@torch.no_grad()
def predict(model, image_path, transform, device):
    """
    Predict the most likely classes for a single image file.

    Args:
        model: Classifier returning (1, num_classes) logits for a (1, C, H, W) batch.
        image_path: Path to an image readable by PIL.
        transform: Callable mapping a PIL RGB image to a (C, H, W) tensor.
        device: Device to run inference on.

    Returns:
        (class indices, probabilities) as numpy arrays for the top
        ``min(5, num_classes)`` classes, most likely first.
    """
    from PIL import Image

    model.eval()

    # The context manager closes the file handle as soon as the image is decoded.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    outputs = model(image_tensor)
    probabilities = F.softmax(outputs, dim=1)

    # Take k from the model output rather than the global NUM_CLASSES, so the
    # function works for any classifier.
    k = min(5, probabilities.size(1))
    top5_prob, top5_idx = torch.topk(probabilities, k, dim=1)

    return top5_idx[0].cpu().numpy(), top5_prob[0].cpu().numpy()


test_transform = get_transforms(is_training=False)

# Classify the first validation image; val_dataset already applies the eval transform.
sample_idx = 0
sample_image, sample_label = val_dataset[sample_idx]

model.eval()
with torch.no_grad():
    sample_image_batch = sample_image.unsqueeze(0).to(device)
    outputs = model(sample_image_batch)
    probabilities = outputs.softmax(dim=1)
    top5_prob, top5_idx = probabilities.topk(min(5, NUM_CLASSES), dim=1)

print("\nPrediction Results:")
print(f"True Label: {sample_label} ({CIFAR10_CLASSES[sample_label]})")
print(f"\nTop-{min(5, NUM_CLASSES)} Predictions:")
for i, (idx, prob) in enumerate(zip(top5_idx[0].cpu().numpy(), top5_prob[0].cpu().numpy())):
    print(f"  {i+1}. Class {idx} ({CIFAR10_CLASSES[idx]}): {prob*100:.2f}%")

print("\n" + "=" * 70, "V-JEPA 2 CIFAR-10 Fine-tuning Complete", "=" * 70, sep="\n")