In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from google.colab import drive
import torch.nn.functional as F

In [None]:
# Mount Google Drive so that the checkpoint paths under /content/drive/MyDrive
# used later in this notebook can be read and written.
drive.mount('/content/drive')

Mounted at /content/drive


CNN architecture used for encoding target/context blocks

In [None]:
# Define a simple CNN encoder for context and target encoders
class CNNEncoder(nn.Module):
    """Convolutional encoder that maps an image batch to a flat embedding.

    Three Conv(3x3, padding 1) -> ReLU -> MaxPool(2) stages take the channel
    count 3 -> 32 -> 64 -> 128 while halving height and width at every stage.
    The flattened feature map is then projected by a single linear layer.

    The linear layer expects a 128 x 28 x 28 feature map, which is what a
    224 x 224 input produces after the three pooling steps.

    Args:
        embed_dim: Size of the embedding returned for each image.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        stages = []
        # (in_channels, out_channels) per stage; spatial size 224 -> 112 -> 56 -> 28.
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
        # Layers are appended in the same order as an explicit listing, so the
        # Sequential indices (and therefore the state_dict keys) are unchanged.
        self.conv_layers = nn.Sequential(*stages)
        self.fc = nn.Linear(128 * 28 * 28, embed_dim)

    def forward(self, x):
        """Encode ``x`` and return a tensor of shape (batch, embed_dim)."""
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)  # one row per image
        return self.fc(flat)




In [None]:

# The context encoder and target encoder will be the same
class ContextEncoder(CNNEncoder):
    """Encoder applied to context blocks; inherits CNNEncoder without changes."""

In [None]:


class TargetEncoder(CNNEncoder):
    """Encoder applied to target blocks; inherits CNNEncoder without changes."""

The predictor, which maps the embedding of a context block to a predicted embedding of the target block

In [None]:

# Define the predictor
class Predictor(nn.Module):
    """Two-layer MLP that predicts a target representation from a context one.

    Both linear layers keep the width at ``embed_dim``; a ReLU sits between them.

    Args:
        embed_dim: Width of the input, hidden and output vectors.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply fc1 -> ReLU -> fc2 to ``x`` and return the result."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        prediction = self.fc2(activated)
        return prediction

In [None]:

# Self-supervised loss (L2 Loss)
class IJEPALoss(nn.Module):
    """Mean-squared-error loss between predicted and actual target representations."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, pred_target_rep, actual_target_rep):
        """Return the MSE, averaged over all elements, between the two tensors."""
        mse = self.criterion(pred_target_rep, actual_target_rep)
        return mse

In [None]:
# Function to create context and target patches and resize them to (224, 224)
def create_context_target_blocks(image, mask_ratio=0.25, output_size=(224, 224)):
    """Cut a context crop and a target crop out of an image batch and resize both.

    The context crop is the top-left region that remains after dropping a
    margin from the bottom and right edges. The target crop is the
    bottom-right region that remains after dropping the same margin from the
    top and left edges. For ``mask_ratio`` below 0.5 the two crops overlap.

    Args:
        image: 4-D tensor laid out as (B, C, H, W).
        mask_ratio: Fraction of the height (and, independently, of the width)
            removed as margin. Must satisfy ``0 <= mask_ratio < 1``.
        output_size: (height, width) both crops are resized to.

    Returns:
        Tuple ``(context, target)`` of tensors resized to ``output_size``.

    Raises:
        ValueError: If ``mask_ratio`` is outside ``[0, 1)``; such a value would
            otherwise yield an empty or wrapped-around crop.
    """
    if not 0 <= mask_ratio < 1:
        raise ValueError(f"mask_ratio must be in [0, 1), got {mask_ratio}")

    _, _, height, width = image.shape
    # Compute the margin per axis so non-square inputs are cropped
    # proportionally instead of reusing the height-derived margin for the width.
    mask_h = int(height * mask_ratio)
    mask_w = int(width * mask_ratio)

    context = image[:, :, :height - mask_h, :width - mask_w]
    target = image[:, :, mask_h:, mask_w:]

    resize_transform = transforms.Resize(output_size)
    return resize_transform(context), resize_transform(target)


I-JEPA Pretraining function

In [None]:

# I-JEPA Pretraining function 
def pretrain_ijepa(context_encoder, target_encoder, predictor, train_loader, start_epoch=56, epochs=200, lr=0.001, patience=3):
    """Train the context encoder and predictor to match target-encoder outputs.

    For every batch, a context crop and a target crop are produced by
    ``create_context_target_blocks``. The predictor maps the context
    embedding to a predicted target embedding, and the MSE against the target
    encoder's embedding is minimised with Adam.

    Only ``context_encoder`` and ``predictor`` are optimised. The target
    encoder is never updated here, so it stays at its initial or loaded weights.

    Early stopping and checkpointing are driven by the mean *training* loss of
    each epoch. The best weights are written to a fixed Google Drive path.
    Optimizer state is not checkpointed, so a resumed run starts with a fresh
    Adam state.

    Args:
        context_encoder: Module producing embeddings for context crops.
        target_encoder: Module producing embeddings for target crops.
        predictor: Module mapping context embeddings to target embeddings.
        train_loader: Iterable yielding ``(images, labels)``; labels are ignored.
        start_epoch: Zero-based epoch to start from. Any value above 0 loads
            the checkpoint first. The default of 56 therefore requires an
            existing checkpoint; pass 0 for a fresh run.
        epochs: Exclusive upper bound of the epoch range.
        lr: Adam learning rate.
        patience: Number of consecutive non-improving epochs before stopping.
    """
    # Follow the device the model already lives on instead of a notebook global.
    first_param = next(context_encoder.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    optimizer = optim.Adam(list(context_encoder.parameters()) + list(predictor.parameters()), lr=lr)
    criterion = IJEPALoss()

    context_encoder.train()
    target_encoder.train()
    predictor.train()

    # A fresh run must accept its first epoch as the best so far.
    best_loss = float('inf')
    # Checkpoints written before 'best_loss' was stored fall back to the value
    # that used to be hard-coded for the resumed run.
    legacy_best_loss = 2.754716042266495e-05
    patience_counter = 0
    best_model_path = '/content/drive/MyDrive/best_ijepa_model_CNN_final.pth'

    # If starting from a specific epoch, load the saved state
    if start_epoch > 0:
        checkpoint = torch.load(best_model_path, map_location=torch.device('cpu'), weights_only=True)
        context_encoder.load_state_dict(checkpoint['context_encoder'])
        target_encoder.load_state_dict(checkpoint['target_encoder'])
        predictor.load_state_dict(checkpoint['predictor'])
        best_loss = checkpoint.get('best_loss', legacy_best_loss)
        print(f"Resumed training from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        running_loss = 0.0
        for images, _ in train_loader:
            images = images.to(device)

            # Crops are slices of `images`, so they are already on `device`.
            context, target = create_context_target_blocks(images)

            context_rep = context_encoder(context)
            # The target encoder's parameters are not in the optimizer, so no
            # autograd graph is needed for it. Gradients reaching the context
            # encoder and predictor are unchanged.
            with torch.no_grad():
                actual_target_rep = target_encoder(target)
            pred_target_rep = predictor(context_rep)

            loss = criterion(pred_target_rep, actual_target_rep)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss}')

        # Check for early stopping
        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            torch.save({
                'context_encoder': context_encoder.state_dict(),
                'target_encoder': target_encoder.state_dict(),
                'predictor': predictor.state_dict(),
                # Stored so that a later resume continues from the real best value.
                'best_loss': float(best_loss),
            }, best_model_path)
            print(f"Best model saved at epoch {epoch + 1} with loss {best_loss}")
        else:
            patience_counter += 1
            print(f"Patience counter: {patience_counter}")

        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

    print("I-JEPA pretraining complete!")

In [None]:
# CIFAR-10 input pipeline: upsample each image to 224x224, convert it to a
# tensor, then normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Datasets: the CIFAR-10 training split, and its test split (train=False)
# which this notebook uses for validation.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
valset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# Loaders: only the training data is shuffled.
train_loader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=valset, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:14<00:00, 12073587.78it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the three I-JEPA modules with their default sizes on that device.
context_encoder = ContextEncoder().to(device)
target_encoder = TargetEncoder().to(device)
predictor = Predictor().to(device)

In [None]:
# Run I-JEPA pretraining on the training loader only (no validation loader is
# passed); start_epoch, epochs, lr and patience keep the function's defaults.
pretrain_ijepa(
    context_encoder=context_encoder,
    target_encoder=target_encoder,
    predictor=predictor,
    train_loader=train_loader,
)

  checkpoint = torch.load(best_model_path, map_location=torch.device('cpu'))


Resumed training from epoch 56
Epoch 57/200, Loss: 2.8042292487813527e-05
Patience counter: 1
Epoch 58/200, Loss: 2.6461500049051543e-05
Best model saved at epoch 58 with loss 2.6461500049051543e-05
Epoch 59/200, Loss: 2.6581600236347217e-05
Patience counter: 1
Epoch 60/200, Loss: 2.6845710654434115e-05
Patience counter: 2
Epoch 61/200, Loss: 2.655543753032229e-05
Patience counter: 3
Early stopping triggered!
I-JEPA pretraining complete!


In [None]:
class ContextProcessor(nn.Module):
    """Small convolutional front end that turns an image batch into 512-d vectors.

    Two 3x3 convolutions (3 -> 64 -> 128 channels, each followed by ReLU) run
    at the input resolution. The result is adaptively average-pooled to
    56 x 56, flattened, and projected by a linear layer to 512 features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Input width matches the 128-channel, 56 x 56 map that the adaptive
        # pooling in forward() always produces.
        self.fc = nn.Linear(128 * 56 * 56, 512)

    def forward(self, x):
        """Return one feature row per image in ``x``."""
        feats = x
        for conv in (self.conv1, self.conv2):
            feats = F.relu(conv(feats))
        pooled = F.adaptive_avg_pool2d(feats, (56, 56))
        return self.fc(pooled.view(pooled.size(0), -1))


Linear probing for the downstream task (classification)

In [None]:
class LinearProbingClassifier(nn.Module):
    """Classifier head on top of a ContextProcessor followed by a predictor.

    The input is passed through a freshly constructed ``ContextProcessor``,
    then through ``predictor``, and finally through a linear layer ``fc`` that
    produces ``num_classes`` logits.

    ``fc`` is created in ``__init__`` whenever the predictor's output width is
    known. It is then part of the module from the start, so ``.to(device)``,
    ``state_dict()`` / ``load_state_dict()`` and optimizers all see it. If the
    width cannot be determined, ``fc`` stays ``None`` and is created on the
    first forward pass.

    This class does not freeze any parameters; which ones are trained is
    decided by the optimizer the caller builds.

    Args:
        predictor: Module applied to the ContextProcessor output.
        num_classes: Number of output logits.
        embed_dim: Width of the predictor output. When ``None``, it is read
            from ``predictor.fc2.out_features`` if that attribute exists (the
            layout of the ``Predictor`` class in this notebook).
    """

    def __init__(self, predictor, num_classes=10, embed_dim=None):
        super().__init__()
        self.context_processor = ContextProcessor()
        self.predictor = predictor
        self.num_classes = num_classes

        if embed_dim is None:
            final_layer = getattr(predictor, "fc2", None)
            embed_dim = getattr(final_layer, "out_features", None)

        # Eager construction when the width is known; None keeps the lazy path.
        self.fc = nn.Linear(embed_dim, num_classes) if embed_dim is not None else None

    def forward(self, context):
        """Return class logits of shape (batch, num_classes) for ``context``."""
        processed_context = self.context_processor(context)
        pred_target_rep = self.predictor(processed_context)

        if self.fc is None:
            # Fallback for predictors whose output width was not known at
            # construction time: size the head from the first batch.
            self.fc = nn.Linear(pred_target_rep.size(1), self.num_classes).to(context.device)

        return self.fc(pred_target_rep)


# Build the linear-probing classifier around the current `predictor` and move it
# to `device`.
# NOTE: the classifier keeps a reference to whichever `predictor` object is bound
# when this line runs; rebinding the name `predictor` afterwards does not change
# the module held by the classifier.
linear_probing_model = LinearProbingClassifier(predictor).to(device)


def _run_probe_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)``.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0

    # Autograd is only needed while training.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Only the context crop is used for classification.
            context, _ = create_context_target_blocks(images)

            outputs = model(context)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return loss_sum / len(loader), 100 * correct / total


def train_linear_probing(model, train_loader, val_loader, epochs=100, lr=0.1, patience=5):
    """Train only ``model.fc`` with Adam and early stopping on validation loss.

    Each epoch runs one training pass and one validation pass. Whenever the
    validation loss improves, the whole ``model.state_dict()`` is saved to a
    fixed Google Drive path. Training stops after ``patience`` consecutive
    epochs without improvement. The model is left in eval mode on return.

    Args:
        model: Module exposing an ``fc`` head. If ``fc`` is still ``None``, one
            batch is pushed through the model first so it can create the head.
        train_loader: Iterable of ``(images, labels)`` used for optimisation.
        val_loader: Iterable of ``(images, labels)`` used for early stopping.
        epochs: Maximum number of epochs.
        lr: Adam learning rate for the ``fc`` parameters.
        patience: Non-improving epochs tolerated before stopping.
    """
    # Follow the device the model already lives on instead of a notebook global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    if getattr(model, 'fc', None) is None:
        # One gradient-free batch so a lazily built head exists before the
        # optimizer is constructed.
        with torch.no_grad():
            for images, _ in train_loader:
                context, _ = create_context_target_blocks(images.to(device))
                model(context)
                break

    optimizer = optim.Adam(model.fc.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_loss = float('inf')
    patience_counter = 0
    best_model_path = '/content/drive/MyDrive/best_linear_probing_model_cnn_final.pth'

    for epoch in range(epochs):
        avg_loss, accuracy = _run_probe_epoch(model, train_loader, criterion, device, optimizer)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

        avg_val_loss, val_accuracy = _run_probe_epoch(model, val_loader, criterion, device)
        print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        # Check for early stopping
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved at epoch {epoch + 1} with loss {best_loss:.4f}")
        else:
            patience_counter += 1
            print(f"Patience counter: {patience_counter}")

        if patience_counter >= patience:
            print("Early stopping triggered!")
            break

    print("Linear probing training complete!")

# Load the best pretraining checkpoint into a fresh predictor.
predictor = Predictor().to(device)

checkpoint_path = '/content/drive/MyDrive/best_ijepa_model_CNN_final.pth'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
predictor.load_state_dict(checkpoint['predictor'])

# `linear_probing_model` was built around the predictor object that existed
# before the name was rebound above, so without this assignment the weights just
# loaded would never be used. Attach the loaded predictor to the model.
linear_probing_model.predictor = predictor

# Train the linear probing model on top of the checkpointed predictor.
train_linear_probing(linear_probing_model, train_loader, val_loader)

print("Training process is complete.")


Files already downloaded and verified


  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Epoch 1/100, Loss: 1.7926, Accuracy: 36.24%
Validation Loss: 1.7442, Validation Accuracy: 38.29%
Best model saved at epoch 1 with validation loss 1.7442
Epoch 2/100, Loss: 1.6764, Accuracy: 40.79%
Validation Loss: 1.6903, Validation Accuracy: 40.61%
Best model saved at epoch 2 with validation loss 1.6903
Epoch 3/100, Loss: 1.6577, Accuracy: 41.52%
Validation Loss: 1.6670, Validation Accuracy: 40.66%
Best model saved at epoch 3 with validation loss 1.6670
Epoch 4/100, Loss: 1.6397, Accuracy: 42.37%
Validation Loss: 1.7051, Validation Accuracy: 40.93%
Epoch 5/100, Loss: 1.6273, Accuracy: 42.57%
Validation Loss: 1.6900, Validation Accuracy: 40.87%
Epoch 6/100, Loss: 1.6246, Accuracy: 42.65%
Validation Loss: 1.6840, Validation Accuracy: 40.50%
Epoch 7/100, Loss: 1.6150, Accuracy: 43.02%
Validation Loss: 1.7064, Validation Accuracy: 41.50%
Epoch 8/100, Loss: 1.6115, Accuracy: 43.31%
Validation Loss: 1.6862, Validation Accuracy: 40.52%
Epoch 9/100, Loss: 1.6119, Accuracy: 43.40%
Validation L