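This example requires the following dependencies to be installed:
pip install lightly[timm]
The training cell additionally imports wandb and tqdm, and the optional download helper imports huggingface_hub; install them as well if they are not already available.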

In [None]:
#!pip install lightly[timm]

In [None]:
import glob
import zipfile
from pathlib import Path
from PIL import Image

from torch.utils.data import Dataset
from huggingface_hub import snapshot_download


class RawImageDataset(Dataset):
    """Dataset that loads images directly from raw files.

    The directory tree below ``root_dir`` is searched recursively for files
    whose names match ``image_extensions``. The sorted, de-duplicated list of
    matching paths (as strings) is stored in ``self.image_paths``.

    Args:
        root_dir: Directory to search; converted to ``pathlib.Path``.
        transform: Optional callable applied to each loaded RGB PIL image.
            Whatever it returns is returned by ``__getitem__``.
        image_extensions: Glob pattern(s) for file names, either a single
            string or an iterable of strings. Defaults to lower- and
            upper-case jpg/jpeg/png patterns.
    """

    def __init__(self, root_dir, transform=None, image_extensions=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        if image_extensions is None:
            image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPEG', '*.JPG', '*.PNG']
        elif isinstance(image_extensions, str):
            # A bare string would otherwise be iterated character by character.
            image_extensions = [image_extensions]

        print(f"Searching for images in: {self.root_dir}")

        # Collect into a set: overlapping patterns (or case-insensitive glob
        # matching, as on Windows) can report the same file more than once,
        # which would silently over-sample those images.
        unique_paths = set()
        for pattern in image_extensions:
            found = glob.glob(str(self.root_dir / '**' / pattern), recursive=True)
            unique_paths.update(found)
            if found:
                print(f"  Found {len(found)} {pattern} files")

        # Sorting keeps the index -> file mapping stable between runs.
        self.image_paths = sorted(unique_paths)
        print(f"Total images found: {len(self.image_paths)}")

        if not self.image_paths:
            # Show a sample of what is actually there to help debug bad paths.
            print("\nWarning: No images found. Directory structure (first 20 items):")
            for item in sorted(self.root_dir.rglob('*'))[:20]:
                print(f"  {item}")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and apply the transform.

        Returns:
            The RGB PIL image, or the result of ``self.transform`` applied to
            it when a transform was given. If the file cannot be loaded, a
            black 96x96 RGB placeholder is used instead of raising.
        """
        img_path = self.image_paths[idx]

        try:
            # The context manager closes the file handle as soon as the
            # converted copy has been produced.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: one unreadable file should not abort a
            # whole epoch, so report it and substitute a placeholder.
            print(f"Error loading {img_path}: {e}")
            img = Image.new('RGB', (96, 96), color='black')

        if self.transform is not None:
            img = self.transform(img)

        return img


def download_and_extract_dataset(repo_id, cache_dir=None, max_workers=4):
    """Download a dataset snapshot from HuggingFace and unpack its zip files.

    The snapshot is fetched with ``snapshot_download``. If the first attempt
    raises, the download is retried once with a single worker. Any ``*.zip``
    files directly inside the snapshot directory are then extracted into an
    ``extracted`` sub-directory.

    Args:
        repo_id: Dataset repository id passed to ``snapshot_download``.
        cache_dir: Cache directory passed to ``snapshot_download``.
        max_workers: Worker count for the first download attempt.

    Returns:
        ``Path`` of the ``extracted`` directory when zip files were found,
        otherwise ``Path`` of the downloaded snapshot directory.

    Raises:
        Exception: Whatever ``snapshot_download`` raises if the single-worker
            retry fails as well.
    """

    def _download(workers):
        # Single place for the download arguments so both attempts agree.
        return snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            cache_dir=cache_dir,
            max_workers=workers,
            resume_download=True,
        )

    print(f"Downloading dataset from {repo_id}...")

    try:
        local_dir = _download(max_workers)
    except Exception as e:
        # Best-effort fallback: retry once with a single worker.
        print(f"Error during download: {e}")
        print("Retrying with single worker...")
        local_dir = _download(1)
    print(f"Dataset downloaded to: {local_dir}")

    # Extract zip files if present
    local_path = Path(local_dir)
    zip_files = sorted(local_path.glob('*.zip'))

    if not zip_files:
        print("No zip files found, using directory as-is")
        return local_path

    print(f"\nFound {len(zip_files)} zip files. Extracting...")
    extract_dir = local_path / 'extracted'
    extract_dir.mkdir(exist_ok=True)

    # Without this loop the function would hand back an empty directory
    # whenever the snapshot ships its images as zip archives.
    for zip_file in zip_files:
        print(f"  Extracting {zip_file.name}...")
        try:
            with zipfile.ZipFile(zip_file, 'r') as zf:
                zf.extractall(extract_dir)
            print("    ✓ Extracted successfully")
        except (zipfile.BadZipFile, OSError) as e:
            # Keep going so one corrupt archive does not block the others.
            print(f"    ✗ Error: {e}")

    return extract_dir


def get_mae_transform():
    """Build the image transform used for MAE training.

    Returns:
        A ``lightly.transforms.MAETransform`` created with its default
        arguments.
    """
    # Imported inside the function, so lightly is loaded on the first call
    # rather than when this definition is executed.
    from lightly import transforms as lightly_transforms

    return lightly_transforms.MAETransform()

In [None]:
# MAE augmentation pipeline applied to every image the dataset returns.
transform = get_mae_transform()

# The images are read from a local folder. To fetch them from HuggingFace
# instead, replace the assignment below with:
#     data_dir = download_and_extract_dataset(
#         repo_id="tsbpp/fall2025_deeplearning",
#         cache_dir=None,
#         max_workers=4,
#     )
data_dir = Path("./data/devel")

# Index all image files below data_dir.
dataset = RawImageDataset(root_dir=data_dir, transform=transform)
print(f"\nDataset ready with {len(dataset)} images")

In [None]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn

In [None]:
from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform

In [None]:
class MAE(nn.Module):
    """Masked autoencoder built from a timm ViT and a lightly MAE decoder.

    A random subset of tokens is kept and encoded by the backbone; the
    decoder then predicts pixel values for the masked tokens.

    Args:
        vit: Vision transformer providing ``patch_embed`` and ``embed_dim``
            (a timm ViT in this file).
    """

    def __init__(self, vit):
        super().__init__()

        # Fraction of tokens hidden from the encoder.
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]

        # Encoder: the ViT wrapped so it can run on a subset of tokens.
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length

        # Decoder: a single transformer block of width 512.
        decoder_config = dict(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=512,
            decoder_depth=1,
            decoder_num_heads=16,
            mlp_ratio=4.0,
            proj_drop_rate=0.0,
            attn_drop_rate=0.0,
        )
        self.decoder = MAEDecoderTIMM(**decoder_config)

    def forward_encoder(self, images, idx_keep=None):
        """Encode ``images`` with the backbone, restricted to ``idx_keep``."""
        encoded = self.backbone.encode(images=images, idx_keep=idx_keep)
        return encoded

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        """Predict pixel values for the masked tokens.

        Args:
            x_encoded: Encoder output for the kept tokens.
            idx_keep: Indices of the tokens that were given to the encoder.
            idx_mask: Indices of the tokens that were masked out.

        Returns:
            The decoder's predictions at the ``idx_mask`` positions.
        """
        num_samples = x_encoded.shape[0]

        # Project encoder features to the decoder width.
        embedded = self.decoder.embed(x_encoded)

        # Start from a full sequence of mask tokens, then write the embedded
        # kept tokens back at their original positions.
        decoder_input = utils.repeat_token(
            self.decoder.mask_token, (num_samples, self.sequence_length)
        )
        decoder_input = utils.set_at_index(
            decoder_input, idx_keep, embedded.type_as(decoder_input)
        )

        decoded = self.decoder.decode(decoder_input)

        # Only the masked positions are turned into pixel predictions.
        masked_tokens = utils.get_at_index(decoded, idx_mask)
        return self.decoder.predict(masked_tokens)

    def forward(self, images):
        """Mask, encode and decode ``images``.

        Returns:
            Tuple ``(x_pred, target)``: the predictions for the masked tokens
            and the corresponding patches taken from ``images``.
        """
        num_samples = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(num_samples, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )

        encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(
            x_encoded=encoded, idx_keep=idx_keep, idx_mask=idx_mask
        )

        # Reconstruction targets: the image patches at the masked positions.
        # The patches carry no class token, so the indices are shifted by one.
        all_patches = utils.patchify(images, self.patch_size)
        target = utils.get_at_index(all_patches, idx_mask - 1)
        return x_pred, target

In [None]:
# ViT backbone created with the constructor's default arguments, wrapped in
# the MAE model defined above.
vit = vit_base_patch32_224()
model = MAE(vit=vit)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

In [None]:
# Number of scalar weights across all parameters of the model.
total_params = sum(map(torch.numel, model.parameters()))

# Same count, restricted to parameters that receive gradients.
trainable_params = sum(
    torch.numel(param) for param in model.parameters() if param.requires_grad
)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# MAE transform created with its default arguments.
# NOTE(review): the `dataset` built earlier in this file already received its
# own transform from get_mae_transform(); this instance appears to be needed
# only by the commented-out VOC example below — confirm before relying on it.
transform = MAETransform()
# The VOC example ignores object detection annotations by using a
# target_transform that always returns 0.

In [None]:
def target_transform(t: object) -> int:
    """Discard the annotation ``t`` and return the constant label 0."""
    return 0

In [None]:
# dataset = torchvision.datasets.VOCDetection(
#     "datasets/pascal_voc",
#     download=True,
#     transform=transform,
#     target_transform=target_transform,
# )
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

In [None]:
# Shuffled batches of 256 samples loaded by 8 worker processes; an incomplete
# final batch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    num_workers=8,
    batch_size=256,
    drop_last=True,
    shuffle=True,
)

In [None]:
# AdamW over all model parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1.5e-4)
# Mean squared error between predicted and target patches.
criterion = nn.MSELoss()

In [None]:
# print("Starting Training")
# for epoch in range(10):
#     total_loss = 0
#     for batch in dataloader:
#         views = batch[0]
#         images = views[0].to(device)  # views contains only a single view
#         predictions, targets = model(images)
#         loss = criterion(predictions, targets)
#         total_loss += loss.detach()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#     avg_loss = total_loss / len(dataloader)
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

In [None]:
import torch
from pathlib import Path
import wandb

# # ---------- Drive setup ----------
# try:
#     from google.colab import drive
#     drive.mount('/content/drive')
#     DRIVE_ROOT = Path("/content/drive/MyDrive")
#     IS_COLAB = True
#     print("✓ Running on Colab, Drive mounted.")
# except Exception:
#     DRIVE_ROOT = Path("./saved_models")
#     IS_COLAB = False
#     print("⚠️ Not on Colab, using local folder ./saved_models")

In [None]:
import time
from tqdm import tqdm

# ---------- Project / save dir ----------
PROJECT_NAME = "mae"  # wandb project AND folder name
DRIVE_ROOT = "outputs"
save_dir = Path(DRIVE_ROOT) / Path(PROJECT_NAME)
save_dir.mkdir(parents=True, exist_ok=True)
print(f"Save directory ready: {save_dir}")

# Fail early with a clear message instead of dividing by zero after the first
# epoch: with drop_last=True a dataset smaller than one batch yields no batches.
# Checked before wandb.init so no empty run is created.
if len(dataloader) == 0:
    raise RuntimeError(
        "The dataloader yields no batches; check the dataset size, "
        "batch_size and drop_last settings."
    )

# ---------- wandb init ----------

wandb.init(
    entity="lquan9",
    project=PROJECT_NAME,
    name="mae-run-1",      # change run name if you like
)

# ---------- Training loop ----------
num_epochs = 100
print("Starting Training")
global_step = 0
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    # Restart the step timer every epoch so checkpoint saving is not counted
    # as part of the next epoch's first step.
    step_start = time.perf_counter()
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # NOTE(review): batch[0] is used directly as the image tensor —
        # presumably the single view produced by the transform after
        # collation; verify against the dataset's transform output.
        views = batch[0]
        images = views.to(device)

        # Forward
        predictions, targets = model(images)
        loss = criterion(predictions, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the loss value once; it is reused for the running total and
        # for logging.
        step_loss = loss.item()
        total_loss += step_loss

        # ---- wandb STEP LOGGING ----
        # time/step_sec is the duration of this step (data loading included),
        # measured since the previous step finished.
        now = time.perf_counter()
        wandb.log(
            {
                "loss/step": step_loss,
                "time/step_sec": now - step_start,
                "step": global_step,
                "epoch": epoch,
            },
            step=global_step,
        )
        step_start = now

        global_step += 1

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

    # ---- wandb logging ----
    wandb.log({
        "loss/train": avg_loss,
        "epoch": epoch,
    })

    # ---- Save checkpoint to the save directory (always same filename) ----
    # The file is overwritten every epoch, so only the latest state is kept.
    ckpt_path = save_dir / f"{PROJECT_NAME}_latest.pt"
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "avg_loss": avg_loss,
        },
        ckpt_path,
    )
    print(f"✓ Saved checkpoint: {ckpt_path}")