In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix, accuracy_score

#### Q1: First Part

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import random
from sklearn.metrics import accuracy_score, confusion_matrix

# Hyperparameters....
learning_rate = 0.001
batch_size = 32
epochs = 20
num_classes = 3
patch_size = 7
hidden_dim = 128
mask_ratio = 0.5

transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

def sample_mnist(data, classes, num_samples=100):
    indices = []
    for cls in classes:
        cls_indices = [i for i, (img, label) in enumerate(data) if label == cls]
        sampled_indices = random.sample(cls_indices, num_samples)
        indices.extend(sampled_indices)
    return Subset(data, indices)

classes = [0, 1, 2]
train_subset = sample_mnist(train_data, classes)
test_subset = sample_mnist(test_data, classes)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

# ViT Block....
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.o = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        attn_scores = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        attn_probs = attn_scores.softmax(dim=-1)
        attended_values = (attn_probs @ v).transpose(1, 2).reshape(B, N, C)
        return self.o(attended_values)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.dropout(x)

class ViTBlock(nn.Module):
    def __init__(self, dim, num_heads, hidden_dim):
        super(ViTBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, hidden_dim)

    def forward(self, x):
        attn_output = self.attn(self.norm1(x))
        x = x + attn_output
        ffn_output = self.ffn(self.norm2(x))
        x = x + ffn_output
        return x

# Masking function....
def mask_tokens(images, mask_ratio=0.5):
    B, C, H, W = images.shape
    patch_size = 7
    num_patches = (H // patch_size) * (W // patch_size)
    num_masked = int(num_patches * mask_ratio)

    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.contiguous().view(B, C, -1, patch_size, patch_size)

    masked_patches = patches.clone()
    for i in range(B):
        mask_indices = random.sample(range(patches.size(2)), num_masked)
        masked_patches[i, :, mask_indices, :, :] = 0  # Set masked patches to 0...

    masked_images = masked_patches.view(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    masked_images = masked_images.permute(0, 1, 2, 4, 3, 5).contiguous().view(B, C, H, W)

    return masked_images

# MAE with ViT....
class MAE_ViT(nn.Module):
    def __init__(self, num_classes=3, dim=128, num_heads=8, hidden_dim=256, patch_size=7):
        super(MAE_ViT, self).__init__()
        self.patch_size = patch_size
        self.patch_embed = nn.Conv2d(in_channels=1, out_channels=dim, kernel_size=patch_size, stride=patch_size)
        self.transformer_blocks = nn.Sequential(
            ViTBlock(dim, num_heads, hidden_dim),
            ViTBlock(dim, num_heads, hidden_dim)
        )
        self.reconstruction_layer = nn.Linear(dim, patch_size * patch_size)
        self.classification_layer = nn.Linear(dim, num_classes)

    def forward(self, x, pretrain=True):
        B, C, H, W = x.shape
        patches = self.patch_embed(x).view(B, -1, 128)
        tokens = self.transformer_blocks(patches)
        if pretrain:
            return tokens, patches  # Return tokens and original patches for MAE loss...
        else:
            return tokens.mean(dim=1)  # Use mean pooling for classification...

def mae_pretrain(model, train_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for images, _ in train_loader:
            optimizer.zero_grad()
            masked_images = mask_tokens(images)
            tokens, patches = model(masked_images, pretrain=True)
            # Apply L2 loss on masked tokens only....
            loss = F.mse_loss(tokens, patches)  # Modify this to apply loss on masked tokens...
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MAE_ViT().to(device)
mae_pretrain(model, train_loader)

def finetune(model, train_loader, test_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images.to(device), pretrain=False)
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {total_train_loss / len(train_loader)}")

    # Test the model...
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device), pretrain=False)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)

    print(f"Test Accuracy: {acc * 100:.2f}%")
    print("Confusion Matrix:\n", cm)

finetune(model, train_loader, test_loader)


Epoch 1/20, Loss: 0.06600709445774555
Epoch 2/20, Loss: 0.021137029770761727
Epoch 3/20, Loss: 0.011758691165596247
Epoch 4/20, Loss: 0.007916843285784125
Epoch 5/20, Loss: 0.005966625642031431
Epoch 6/20, Loss: 0.004752375325188041
Epoch 7/20, Loss: 0.003998328046873212
Epoch 8/20, Loss: 0.0035275065572932364
Epoch 9/20, Loss: 0.003173700114712119
Epoch 10/20, Loss: 0.0028679018141701818
Epoch 11/20, Loss: 0.002574163652025163
Epoch 12/20, Loss: 0.0023936241399496795
Epoch 13/20, Loss: 0.002202636655420065
Epoch 14/20, Loss: 0.0020497368648648264
Epoch 15/20, Loss: 0.0019097160897217692
Epoch 16/20, Loss: 0.001777665107510984
Epoch 17/20, Loss: 0.001656846550758928
Epoch 18/20, Loss: 0.0015050343121401966
Epoch 19/20, Loss: 0.001409724703989923
Epoch 20/20, Loss: 0.0013280821382068097
Epoch 1/20, Train Loss: 2.6212555170059204
Epoch 2/20, Train Loss: 0.9580486059188843
Epoch 3/20, Train Loss: 0.5852215141057968
Epoch 4/20, Train Loss: 0.251011623442173
Epoch 5/20, Train Loss: 0.129252

#### Second Part

In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import random
from sklearn.metrics import accuracy_score, confusion_matrix
import torch.nn.functional as F

patch_size = 7
temperature = 0.5

# Vision Transformer (ViT) with InfoNCE loss...
class ViT_InfoNCE(nn.Module):
    def __init__(self, dim=128, num_heads=8, hidden_dim=256, num_classes=3, temperature=temperature):
        super(ViT_InfoNCE, self).__init__()
        self.patch_embed = nn.Conv2d(in_channels=1, out_channels=dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # CLS token....
        self.pos_embed = nn.Parameter(torch.zeros(1, 17, dim))  # 16 patches + 1 CLS token...
        self.transformer_blocks = nn.Sequential(
            ViTBlock(dim, num_heads, hidden_dim),
            ViTBlock(dim, num_heads, hidden_dim)
        )
        self.classifier = nn.Linear(dim, num_classes)  # For fine-tuning...
        self.temperature = temperature

    def forward(self, x, train_infoNCE=True):
        B, C, H, W = x.shape
        patches = self.patch_embed(x).view(B, -1, 128)  # [B, 16, dim]...
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, dim]...
        tokens = torch.cat((cls_tokens, patches), dim=1)  # Append CLS token....

        tokens = tokens + self.pos_embed

        x = self.transformer_blocks(tokens)

        if train_infoNCE:
            # InfoNCE loss applied to CLS token......
            return x[:, 0]
        else:
            # For classification....
            return self.classifier(x[:, 0])

def infoNCE_loss(cls_token_output, temperature=temperature):
    B = cls_token_output.size(0)
    cls_token_output = F.normalize(cls_token_output, dim=-1)  # Normalize CLS token....
    similarity_matrix = cls_token_output @ cls_token_output.T  # Cosine similarity....
    similarity_matrix = similarity_matrix / temperature

    labels = torch.arange(B).to(device)
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss

def pretrain_infoNCE(model, train_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for images, _ in train_loader:
            optimizer.zero_grad()
            cls_token_output = model(images.to(device), train_infoNCE=True)
            loss = infoNCE_loss(cls_token_output)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, InfoNCE Loss: {total_loss/len(train_loader)}")

def finetune(model, train_loader, test_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images.to(device), train_infoNCE=False)
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {total_train_loss / len(train_loader)}")

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device), train_infoNCE=False)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)

    print(f"Test Accuracy: {acc * 100:.2f}%")
    print("Confusion Matrix:\n", cm)

model = ViT_InfoNCE().to(device)
pretrain_infoNCE(model, train_loader)

finetune(model, train_loader, test_loader)

Epoch 1/20, InfoNCE Loss: 1.9629568696022033
Epoch 2/20, InfoNCE Loss: 1.7736137509346008
Epoch 3/20, InfoNCE Loss: 1.7321180701255798
Epoch 4/20, InfoNCE Loss: 1.7227694511413574
Epoch 5/20, InfoNCE Loss: 1.6957173466682434
Epoch 6/20, InfoNCE Loss: 1.6895554900169372
Epoch 7/20, InfoNCE Loss: 1.6965280771255493
Epoch 8/20, InfoNCE Loss: 1.6996080219745635
Epoch 9/20, InfoNCE Loss: 1.67894247174263
Epoch 10/20, InfoNCE Loss: 1.661945289373398
Epoch 11/20, InfoNCE Loss: 1.6459826350212097
Epoch 12/20, InfoNCE Loss: 1.6488050758838653
Epoch 13/20, InfoNCE Loss: 1.6480465352535247
Epoch 14/20, InfoNCE Loss: 1.6484094440937043
Epoch 15/20, InfoNCE Loss: 1.6637329041957856
Epoch 16/20, InfoNCE Loss: 1.649157840013504
Epoch 17/20, InfoNCE Loss: 1.655005806684494
Epoch 18/20, InfoNCE Loss: 1.6447826504707337
Epoch 19/20, InfoNCE Loss: 1.6483118951320648
Epoch 20/20, InfoNCE Loss: 1.6523225009441376
Epoch 1/20, Train Loss: 0.3856382817029953
Epoch 2/20, Train Loss: 0.0511058266973123
Epoch 3/

#### Third Part

In [19]:
temperature = 0.5
mae_weight = 0.8
infoNCE_weight = 0.2

# Vision Transformer (ViT) with combined MAE and InfoNCE loss.....
class ViT_MAE_InfoNCE(nn.Module):
    def __init__(self, dim=128, num_heads=8, hidden_dim=256, num_classes=3, temperature=0.5):
        super(ViT_MAE_InfoNCE, self).__init__()
        self.patch_embed = nn.Conv2d(in_channels=1, out_channels=dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 17, dim))
        self.transformer_blocks = nn.Sequential(
            ViTBlock(dim, num_heads, hidden_dim),
            ViTBlock(dim, num_heads, hidden_dim)
        )
        self.decoder = nn.Linear(dim, dim)  # For MAE, output same dimension as patch embeddings....
        self.classifier = nn.Linear(dim, num_classes)  # For fine-tuning....
        self.temperature = temperature

    def forward(self, x, mask_ratio=0.5, pretrain=True):
        B, C, H, W = x.shape
        patches = self.patch_embed(x).view(B, -1, 128)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)

        tokens = tokens + self.pos_embed

        # Pass through transformer....
        x = self.transformer_blocks(tokens)

        if pretrain:
            # Masking for MAE..
            num_patches = tokens.shape[1] - 1
            num_masked = int(mask_ratio * num_patches)
            mask_indices = torch.randperm(num_patches)[:num_masked]

            # Masked patches for MAE...
            masked_patches = patches.clone()
            masked_patches[:, mask_indices, :] = 0  # Set masked patches to 0...

            # MAE Loss: Reconstruction of masked patches...
            decoded_patches = self.decoder(x[:, 1:])  # Decode to patch embeddings....
            mae_loss = F.mse_loss(decoded_patches[:, mask_indices], patches[:, mask_indices])

            # InfoNCE Loss: Applied to CLS token...
            cls_token_output = x[:, 0]  # Extract CLS token...
            infoNCE_loss = self.compute_infoNCE_loss(cls_token_output)

            # Weighted combined loss...
            combined_loss = mae_weight * mae_loss + infoNCE_weight * infoNCE_loss
            return combined_loss
        else:
            # For classification
            return self.classifier(x[:, 0])

    def compute_infoNCE_loss(self, cls_token_output):
        B = cls_token_output.size(0)
        cls_token_output = F.normalize(cls_token_output, dim=-1)
        similarity_matrix = cls_token_output @ cls_token_output.T
        similarity_matrix = similarity_matrix / self.temperature

        labels = torch.arange(B).to(device)
        loss = F.cross_entropy(similarity_matrix, labels)
        return loss

def pretrain_MAE_InfoNCE(model, train_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for images, _ in train_loader:
            optimizer.zero_grad()
            loss = model(images.to(device), pretrain=True)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Combined Loss (MAE + InfoNCE): {total_loss/len(train_loader)}")

def finetune(model, train_loader, test_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images.to(device), pretrain=False)
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {total_train_loss / len(train_loader)}")

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device), pretrain=False)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)

    print(f"Test Accuracy: {acc * 100:.2f}%")
    print("Confusion Matrix:\n", cm)

model = ViT_MAE_InfoNCE().to(device)
pretrain_MAE_InfoNCE(model, train_loader)
finetune(model, train_loader, test_loader)


Epoch 1/20, Combined Loss (MAE + InfoNCE): 0.44800223410129547
Epoch 2/20, Combined Loss (MAE + InfoNCE): 0.38868720531463624
Epoch 3/20, Combined Loss (MAE + InfoNCE): 0.3696760326623917
Epoch 4/20, Combined Loss (MAE + InfoNCE): 0.3638882577419281
Epoch 5/20, Combined Loss (MAE + InfoNCE): 0.35507997423410415
Epoch 6/20, Combined Loss (MAE + InfoNCE): 0.3482196643948555
Epoch 7/20, Combined Loss (MAE + InfoNCE): 0.34474830478429797
Epoch 8/20, Combined Loss (MAE + InfoNCE): 0.3458032086491585
Epoch 9/20, Combined Loss (MAE + InfoNCE): 0.3455981820821762
Epoch 10/20, Combined Loss (MAE + InfoNCE): 0.3439487859606743
Epoch 11/20, Combined Loss (MAE + InfoNCE): 0.34252172112464907
Epoch 12/20, Combined Loss (MAE + InfoNCE): 0.34050036370754244
Epoch 13/20, Combined Loss (MAE + InfoNCE): 0.3399395927786827
Epoch 14/20, Combined Loss (MAE + InfoNCE): 0.33940733373165133
Epoch 15/20, Combined Loss (MAE + InfoNCE): 0.33965698182582854
Epoch 16/20, Combined Loss (MAE + InfoNCE): 0.3383944734

#### Fourth Part

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import random
from sklearn.metrics import accuracy_score, confusion_matrix

learning_rate = 0.001
batch_size = 16
epochs = 20
num_classes = 3
patch_size = 4  # Smaller patch size for video...
hidden_dim = 128
mask_ratio = 0.8  # 80% tokens are discarded...
temperature = 0.07

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

def sample_mnist(data, classes, num_samples=100):
    indices = []
    for cls in classes:
        cls_indices = [i for i, (img, label) in enumerate(data) if label == cls]
        sampled_indices = random.sample(cls_indices, num_samples)
        indices.extend(sampled_indices)
    return Subset(data, indices)

classes = [0, 1, 2]
train_subset = sample_mnist(train_data, classes)
test_subset = sample_mnist(test_data, classes)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

# Video MAE Model
class VideoMAE(nn.Module):
    def __init__(self, dim=128, num_heads=8, hidden_dim=256, num_classes=3, patch_size=4, mask_ratio=0.8):
        super(VideoMAE, self).__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # CLS token...

        # Compute the correct number of patches from the image dimensions...
        self.num_patches_per_frame = (28 // patch_size) * (28 // patch_size)  # MNIST 28x28 images...
        self.num_patches = 3 * self.num_patches_per_frame  # 3 frames
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))  # Positional embeddings for all tokens + CLS token...

        self.patch_embed = nn.Conv2d(in_channels=1, out_channels=dim, kernel_size=patch_size, stride=patch_size)
        self.transformer_blocks = nn.Sequential(
            ViTBlock(dim, num_heads, hidden_dim),
            ViTBlock(dim, num_heads, hidden_dim)
        )
        self.decoder = nn.Linear(dim, dim)  # Ensure the decoder output matches the patch embedding size (128 dimensions)....
        self.classifier = nn.Linear(dim, num_classes)  # Classifier for fine-tuning...

    def forward(self, frames, mask_ratio=0.8, pretrain=True):
        B, T, C, H, W = frames.shape  # B: batch size, T: number of frames, C: channels, H/W: height/width....
        patches = []

        # Tokenize each frame
        for t in range(T):
            patch = self.patch_embed(frames[:, t]).view(B, -1, 128)  # [B, num_patches_per_frame, dim]....
            patches.append(patch)

        patches = torch.cat(patches, dim=1)  # Combine tokens from all 3 frames [B, num_patches, dim]...
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, dim]...
        tokens = torch.cat((cls_tokens, patches), dim=1)  # [B, num_patches + 1, dim]....

        # Ensure that positional embeddings are correctly matched to tokens
        tokens = tokens + self.pos_embed[:, :tokens.size(1), :]

        if pretrain:
            # Randomly shuffle and mask tokens...
            num_tokens = tokens.shape[1] - 1  # Exclude CLS token...
            num_masked = int(mask_ratio * num_tokens)
            shuffle_indices = torch.randperm(num_tokens)
            mask_indices = shuffle_indices[:num_masked]
            remaining_indices = shuffle_indices[num_masked:]  # Keep the remaining tokens

            masked_tokens = tokens.clone()
            masked_tokens[:, mask_indices, :] = 0

            # Encoder...
            encoded_tokens = self.transformer_blocks(masked_tokens)

            # Inverse shuffling before decoder
            unmasked_tokens = encoded_tokens[:, 1:].clone()  # Remove CLS token
            unmasked_tokens[:, remaining_indices, :] = encoded_tokens[:, 1:][:, remaining_indices, :]  # Inverse shuffle

            # Decoder: Reconstruct masked tokens....
            decoded_tokens = self.decoder(unmasked_tokens)  # Match original patch embedding size (128)...

            l1_loss = F.l1_loss(decoded_tokens[:, mask_indices], patches[:, mask_indices])
            return l1_loss

        else:
            # For classification during fine-tuning....
            encoded_tokens = self.transformer_blocks(tokens)
            return self.classifier(encoded_tokens[:, 0])


# Pretraining with Video MAE
def pretrain_video_MAE(model, train_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for images, _ in train_loader:
            # Simulate 3 frames by duplicating image tensor 3 times...
            frames = images.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # [B, 3, C, H, W]...
            optimizer.zero_grad()
            loss = model(frames.to(device), pretrain=True)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Pretraining Loss: {total_loss/len(train_loader)}")


def finetune_video_MAE(model, train_loader, test_loader, num_epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        for images, labels in train_loader:
            # Simulate 3 frames...
            frames = images.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # [B, 3, C, H, W]....
            optimizer.zero_grad()
            outputs = model(frames.to(device), pretrain=False)
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {total_train_loss / len(train_loader)}")

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            frames = images.unsqueeze(1).repeat(1, 3, 1, 1, 1)  # [B, 3, C, H, W]...
            outputs = model(frames.to(device), pretrain=False)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)

    print(f"Test Accuracy: {acc * 100:.2f}%")
    print("Confusion Matrix:\n", cm)

model = VideoMAE().to(device)
pretrain_video_MAE(model, train_loader)
finetune_video_MAE(model, train_loader, test_loader)



Epoch 1/20, Pretraining Loss: 0.18979522978004656
Epoch 2/20, Pretraining Loss: 0.13230857723637632
Epoch 3/20, Pretraining Loss: 0.10349317914561222
Epoch 4/20, Pretraining Loss: 0.0842678668467622
Epoch 5/20, Pretraining Loss: 0.06977845002946101
Epoch 6/20, Pretraining Loss: 0.058081918249004764
Epoch 7/20, Pretraining Loss: 0.048894419089743964
Epoch 8/20, Pretraining Loss: 0.041238336774863694
Epoch 9/20, Pretraining Loss: 0.035019256566700185
Epoch 10/20, Pretraining Loss: 0.02993740249229105
Epoch 11/20, Pretraining Loss: 0.02602270980806727
Epoch 12/20, Pretraining Loss: 0.02256613754128155
Epoch 13/20, Pretraining Loss: 0.019423122762849455
Epoch 14/20, Pretraining Loss: 0.01706473298959042
Epoch 15/20, Pretraining Loss: 0.015256943132140134
Epoch 16/20, Pretraining Loss: 0.013761960852303003
Epoch 17/20, Pretraining Loss: 0.012073563607899766
Epoch 18/20, Pretraining Loss: 0.010817916867764373
Epoch 19/20, Pretraining Loss: 0.009934318938145512
Epoch 20/20, Pretraining Loss: 