In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------
# Positional Embedding
# ---------------------------------------------
class PositionalEmbedding(nn.Module):
    def __init__(self, num_patches, dim):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim))  # learnable

    def forward(self, x):
        return x + self.pos_embed[:, :x.size(1), :]

# ---------------------------------------------
# Patchify and Unpatchify
# ---------------------------------------------
def patchify(images, patch_size=4):
    # images: (B, 3, 32, 32)
    B, C, H, W = images.shape
    patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)  # (B, C*p*p, N)
    patches = patches.transpose(1, 2)  # (B, N, patch_dim)
    return patches

def unpatchify(patches, patch_size=4, img_size=32):
    # patches: (B, N, patch_dim)
    B, N, D = patches.shape
    patches = patches.transpose(1, 2)  # (B, D, N)
    H = W = img_size // patch_size
    return F.fold(patches, output_size=(H * patch_size, W * patch_size), kernel_size=patch_size, stride=patch_size)

# ---------------------------------------------
# Masked Autoencoder
# ---------------------------------------------
class MaskedAutoencoder(nn.Module):
    def __init__(self, image_size=32, patch_size=4, 
                 encoder_dim=64, decoder_dim=64, 
                 encoder_layers=4, decoder_layers=2, 
                 mask_ratio=0.75):

        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

        self.patch_dim = patch_size * patch_size * 3
        self.num_patches = (image_size // patch_size) ** 2

        # Linear patch embedding
        self.patch_embed = nn.Linear(self.patch_dim, encoder_dim)
        self.encoder_pos = PositionalEmbedding(self.num_patches, encoder_dim)

        # Encoder
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=encoder_dim, nhead=4),
            num_layers=encoder_layers
        )

        # [MASK] token
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        # Decoder
        self.decoder_pos = PositionalEmbedding(self.num_patches, decoder_dim)
        self.decoder_input_proj = nn.Linear(encoder_dim, decoder_dim)
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=4),
            num_layers=decoder_layers
        )
        self.reconstruction_head = nn.Linear(decoder_dim, self.patch_dim)

    def forward(self, images):
        B = images.size(0)
        patches = patchify(images, self.patch_size)  # (B, N, patch_dim)

        # ----- Step 1: Masking -----
        N = patches.size(1)
        num_keep = int(N * (1 - self.mask_ratio))
        idx = torch.rand(B, N).argsort(dim=1)
        idx_keep = idx[:, :num_keep]
        idx_mask = idx[:, num_keep:]

        # Select visible patches
        visible = torch.stack([patches[i, idx_keep[i]] for i in range(B)])  # (B, num_keep, D)

        # Embed + Positional Encoding
        x = self.patch_embed(visible)                   # (B, num_keep, encoder_dim)
        x = self.encoder_pos(x)
        x = self.encoder(x)                             # (B, num_keep, encoder_dim)

        # ----- Step 2: Decoder Input -----
        x = self.decoder_input_proj(x)

        # Add [MASK] tokens at masked positions
        B, _, C = x.shape
        mask_tokens = self.mask_token.expand(B, self.num_patches - x.shape[1], C)
        full_tokens = torch.zeros(B, self.num_patches, C, device=x.device)

        # scatter encoded visible tokens and mask tokens into full sequence
        for i in range(B):
            full_tokens[i, idx_keep[i]] = x[i]
            full_tokens[i, idx_mask[i]] = mask_tokens[i]

        # Add decoder positional embedding
        full_tokens = self.decoder_pos(full_tokens)

        # Decode
        decoded = self.decoder(full_tokens)  # (B, N, decoder_dim)
        reconstructed_patches = self.reconstruction_head(decoded)  # (B, N, patch_dim)

        return reconstructed_patches, patches, idx_mask

    def loss_fn(self, reconstructed, original, idx_mask):
        # MSE loss on only masked patches
        loss = 0
        for i in range(reconstructed.shape[0]):
            pred = reconstructed[i, idx_mask[i]]  # (num_masked, patch_dim)
            target = original[i, idx_mask[i]]
            loss += F.mse_loss(pred, target)
        return loss / reconstructed.shape[0]


In [23]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Filter: airplane (0), automobile (1), ship (8), truck (9)
target_classes = [0, 1, 8, 9]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip()
])

# Load full CIFAR-10 dataset
full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Filter indices
filtered_indices = [i for i, (_, label) in enumerate(full_dataset) if label in target_classes]

# Subset + DataLoader
filtered_dataset = Subset(full_dataset, filtered_indices)
unlabeled_loader = DataLoader(filtered_dataset, batch_size=64, shuffle=True)


In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

model = MaskedAutoencoder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for epoch in range(10):
    print(f"Epoch {epoch+1} started")
    for i, (images, _) in enumerate(unlabeled_loader):
        print(f"Batch {i+1} loaded")
        images = images.to(device)
        reconstructed, original, idx_mask = model(images)
        loss = model.loss_fn(reconstructed, original, idx_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")


Using: cpu
Epoch 1 started
Batch 1 loaded
Batch 2 loaded
Batch 3 loaded
Batch 4 loaded
Batch 5 loaded
Batch 6 loaded
Batch 7 loaded
Batch 8 loaded
Batch 9 loaded
Batch 10 loaded
Batch 11 loaded
Batch 12 loaded
Batch 13 loaded
Batch 14 loaded
Batch 15 loaded
Batch 16 loaded
Batch 17 loaded
Batch 18 loaded
Batch 19 loaded
Batch 20 loaded
Batch 21 loaded
Batch 22 loaded
Batch 23 loaded
Batch 24 loaded
Batch 25 loaded
Batch 26 loaded
Batch 27 loaded
Batch 28 loaded
Batch 29 loaded
Batch 30 loaded
Batch 31 loaded
Batch 32 loaded
Batch 33 loaded
Batch 34 loaded
Batch 35 loaded
Batch 36 loaded
Batch 37 loaded
Batch 38 loaded
Batch 39 loaded
Batch 40 loaded
Batch 41 loaded
Batch 42 loaded
Batch 43 loaded
Batch 44 loaded
Batch 45 loaded
Batch 46 loaded
Batch 47 loaded
Batch 48 loaded
Batch 49 loaded
Batch 50 loaded
Batch 51 loaded
Batch 52 loaded
Batch 53 loaded
Batch 54 loaded
Batch 55 loaded
Batch 56 loaded
Batch 57 loaded
Batch 58 loaded
Batch 59 loaded
Batch 60 loaded
Batch 61 loaded
Batch 

In [25]:
from collections import Counter
labels = [label for _, label in full_dataset]
filtered_labels = [labels[i] for i in filtered_indices]
print(Counter(filtered_labels))  # Should only contain 0,1,8,9


Counter({9: 5000, 1: 5000, 8: 5000, 0: 5000})


In [26]:
torch.save(model.encoder.state_dict(), "vit_encoder_pretrained.pth")


In [27]:
torch.save(model.state_dict(), "mae_full_model.pth")


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

target_classes = [0, 1, 8, 9]
label_map = {0: 0, 1: 1, 8: 2, 9: 3}  # remap labels to 0–3

Using device: cpu


In [5]:
class FourClassCIFAR(datasets.CIFAR10):
    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        if label in label_map:
            return img, label_map[label]
        else:
            return None

In [6]:
class MaskedAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_dim = 4 * 4 * 3
        self.patch_embed = nn.Linear(self.patch_dim, 64)
        self.encoder_pos = nn.Identity()  # Optional if saved model has this
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
            num_layers=4
        )

    def forward(self, x):
        return self.encoder(x)

In [17]:
encoder = MaskedAutoencoder()
encoder.load_state_dict(torch.load(r"C:\Users\Laptop\Downloads\Project\Models\mae_full_model.pth", map_location=device))
encoder = encoder.to(device)

# Build classifier
classifier = ViTClassifier(encoder).to(device)

RuntimeError: Error(s) in loading state_dict for MaskedAutoencoder:
	Unexpected key(s) in state_dict: "mask_token", "decoder_pos.pos_embed", "decoder_input_proj.weight", "decoder_input_proj.bias", "decoder.layers.0.self_attn.in_proj_weight", "decoder.layers.0.self_attn.in_proj_bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.linear1.weight", "decoder.layers.0.linear1.bias", "decoder.layers.0.linear2.weight", "decoder.layers.0.linear2.bias", "decoder.layers.0.norm1.weight", "decoder.layers.0.norm1.bias", "decoder.layers.0.norm2.weight", "decoder.layers.0.norm2.bias", "decoder.layers.1.self_attn.in_proj_weight", "decoder.layers.1.self_attn.in_proj_bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.linear1.weight", "decoder.layers.1.linear1.bias", "decoder.layers.1.linear2.weight", "decoder.layers.1.linear2.bias", "decoder.layers.1.norm1.weight", "decoder.layers.1.norm1.bias", "decoder.layers.1.norm2.weight", "decoder.layers.1.norm2.bias", "reconstruction_head.weight", "reconstruction_head.bias", "encoder_pos.pos_embed". 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from time import time

# ----------------------------
# Device Setup
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

# ----------------------------
# Positional Embedding Module
# ----------------------------
class PositionalEmbedding(nn.Module):
    def __init__(self, num_patches, dim):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim))

    def forward(self, x):
        if x.size(1) != self.pos_embed.size(1):
            raise RuntimeError(f"Shape mismatch: input tokens = {x.size(1)}, pos_embed = {self.pos_embed.size(1)}")
        return x + self.pos_embed

# ----------------------------
# MAE Encoder
# ----------------------------
class MaskedAutoencoder(nn.Module):
    def __init__(self, image_size=32, patch_size=4, encoder_dim=64, decoder_dim=64,
                 encoder_layers=4, decoder_layers=2, mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.patch_dim = patch_size * patch_size * 3
        self.num_patches = (image_size // patch_size) ** 2

        self.patch_embed = nn.Linear(self.patch_dim, encoder_dim)
        self.encoder_pos = PositionalEmbedding(self.num_patches + 1, encoder_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=encoder_dim, nhead=4, batch_first=True),
            num_layers=encoder_layers
        )

# ----------------------------
# ViT Classifier
# ----------------------------
class ViTClassifier(nn.Module):
    def __init__(self, encoder, num_classes=4):
        super().__init__()
        self.encoder = encoder
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 64))
        self.norm = nn.LayerNorm(64)
        self.head = nn.Linear(64, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = F.unfold(x, kernel_size=4, stride=4).transpose(1, 2)  # (B, 64, 48)
        x = self.encoder.patch_embed(x)                           # (B, 64, 64)
        cls = self.cls_token.expand(B, -1, -1)                    # (B, 1, 64)
        x = torch.cat((cls, x), dim=1)                            # (B, 65, 64)
        x = self.encoder.encoder_pos(x)
        x = self.encoder.encoder(x)
        x = self.norm(x[:, 0])
        return self.head(x)

# ----------------------------
# Load MAE encoder weights only
# ----------------------------
full_state_dict = torch.load(r"C:\Users\Laptop\Downloads\Project\Models\mae_full_model.pth", map_location=device)

encoder_state_dict = {
    k: v for k, v in full_state_dict.items()
    if k.startswith("patch_embed") or k.startswith("encoder.") or k.startswith("encoder_pos")
}

for k in list(encoder_state_dict.keys()):
    if "pos_embed" in k:
        print(f"Skipping: {k} due to shape mismatch")
        del encoder_state_dict[k]

encoder = MaskedAutoencoder()
encoder.load_state_dict(encoder_state_dict, strict=False)
encoder = encoder.to(device)

classifier = ViTClassifier(encoder).to(device)

# ----------------------------
# Freeze / Unfreeze helpers
# ----------------------------
def freeze_encoder(model):
    for param in model.encoder.parameters():
        param.requires_grad = False

def unfreeze_encoder(model):
    for param in model.encoder.parameters():
        param.requires_grad = True

# ----------------------------
# Custom 4-Class CIFAR-10 Dataset
# ----------------------------
target_classes = [0, 1, 8, 9]
label_map = {0: 0, 1: 1, 8: 2, 9: 3}

class FourClassCIFAR(datasets.CIFAR10):
    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        if label in label_map:
            return img, label_map[label]
        else:
            return None

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip()
])

full_dataset = FourClassCIFAR(root="./data", train=True, download=True, transform=transform)
filtered_dataset = [x for x in full_dataset if x is not None]
train_loader = DataLoader(filtered_dataset, batch_size=64, shuffle=True)

# ----------------------------
# Train the classifier with progress logging
# ----------------------------
freeze_encoder(classifier)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, classifier.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    print(f"\n🟢 Epoch {epoch+1}/10")
    if epoch == 5:
        print("🔓 Unfreezing encoder...")
        unfreeze_encoder(classifier)
        optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-5)

    classifier.train()
    total, correct, running_loss = 0, 0, 0.0
    start = time()

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = classifier(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        running_loss += loss.item()

        if (batch_idx + 1) % 20 == 0:
            print(f"  ⏳ Batch {batch_idx+1}/{len(train_loader)} - Loss: {loss.item():.4f}")

    acc = 100 * correct / total
    avg_loss = running_loss / len(train_loader)
    print(f"✅ Epoch {epoch+1} complete | Loss: {avg_loss:.4f} | Accuracy: {acc:.2f}% | Time: {time() - start:.1f}s")


Using: cpu
Skipping: encoder_pos.pos_embed due to shape mismatch

🟢 Epoch 1/10
  ⏳ Batch 20/313 - Loss: 1.4577
  ⏳ Batch 40/313 - Loss: 1.5614
  ⏳ Batch 60/313 - Loss: 1.4244
  ⏳ Batch 80/313 - Loss: 1.5019
  ⏳ Batch 100/313 - Loss: 1.4094
  ⏳ Batch 120/313 - Loss: 1.4101
  ⏳ Batch 140/313 - Loss: 1.4035
  ⏳ Batch 160/313 - Loss: 1.4045
  ⏳ Batch 180/313 - Loss: 1.3640
  ⏳ Batch 200/313 - Loss: 1.3956
  ⏳ Batch 220/313 - Loss: 1.4111
  ⏳ Batch 240/313 - Loss: 1.4024
  ⏳ Batch 260/313 - Loss: 1.4120
  ⏳ Batch 280/313 - Loss: 1.4072
  ⏳ Batch 300/313 - Loss: 1.4092
✅ Epoch 1 complete | Loss: 1.4263 | Accuracy: 24.75% | Time: 207.4s

🟢 Epoch 2/10
  ⏳ Batch 20/313 - Loss: 1.3731
  ⏳ Batch 40/313 - Loss: 1.4234
  ⏳ Batch 60/313 - Loss: 1.4068
  ⏳ Batch 80/313 - Loss: 1.4000
  ⏳ Batch 100/313 - Loss: 1.3949
  ⏳ Batch 120/313 - Loss: 1.3919
  ⏳ Batch 140/313 - Loss: 1.4009
  ⏳ Batch 160/313 - Loss: 1.3833
  ⏳ Batch 180/313 - Loss: 1.3807
  ⏳ Batch 200/313 - Loss: 1.3820
  ⏳ Batch 220/313 - Lo