In [1]:
from patcher import extract_patches, separate_patches, add_positional_encoding
from preprocess import preprocess_image, get_image_paths
from predictor import MLP_Predictor
from encoder import load_vit_encoder, extract_patch_embeddings

import os
import torch

In [2]:
folder_path = "jepa_Images\\Dog"  # Example folder

image_paths = get_image_paths(folder_path)

print(f"Found {len(image_paths)} images: {image_paths}")

Found 9753 images: ['jepa_Images\\Dog\\0.jpg', 'jepa_Images\\Dog\\1.jpg', 'jepa_Images\\Dog\\10.jpg', 'jepa_Images\\Dog\\100.jpg', 'jepa_Images\\Dog\\1000.jpg', 'jepa_Images\\Dog\\10001.jpg', 'jepa_Images\\Dog\\10003.jpg', 'jepa_Images\\Dog\\10004.jpg', 'jepa_Images\\Dog\\10005.jpg', 'jepa_Images\\Dog\\10008.jpg', 'jepa_Images\\Dog\\10009.jpg', 'jepa_Images\\Dog\\1001.jpg', 'jepa_Images\\Dog\\10010.jpg', 'jepa_Images\\Dog\\10011.jpg', 'jepa_Images\\Dog\\10012.jpg', 'jepa_Images\\Dog\\10013.jpg', 'jepa_Images\\Dog\\10014.jpg', 'jepa_Images\\Dog\\10015.jpg', 'jepa_Images\\Dog\\10016.jpg', 'jepa_Images\\Dog\\10017.jpg', 'jepa_Images\\Dog\\10018.jpg', 'jepa_Images\\Dog\\10019.jpg', 'jepa_Images\\Dog\\10021.jpg', 'jepa_Images\\Dog\\10022.jpg', 'jepa_Images\\Dog\\10023.jpg', 'jepa_Images\\Dog\\10024.jpg', 'jepa_Images\\Dog\\10026.jpg', 'jepa_Images\\Dog\\10027.jpg', 'jepa_Images\\Dog\\10028.jpg', 'jepa_Images\\Dog\\10030.jpg', 'jepa_Images\\Dog\\10031.jpg', 'jepa_Images\\Dog\\10032.jpg', 'je

In [10]:
# Assuming the rest of your code is the same:
context_encoder = load_vit_encoder(pretrained=True)  # Pretrained ViT encoder for context
target_encoder = load_vit_encoder(pretrained=True)   # Pretrained ViT encoder for target

predictor = MLP_Predictor(input_dim=1000, hidden_dim=512, output_dim=1000)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(predictor.parameters(), lr=1e-4)

num_epochs = 50
processed_count = 0

In [11]:
# Loop through images for training
for epoch in range(num_epochs):
    predictor.train()

    total_loss = 0

    # Loop through each image
    for image_path in image_paths:

        # Step 1: Preprocess image
        preprocessed_image = preprocess_image(image_path)
   
        if preprocessed_image is None:
            #print("Image skipped due to size constraints.")
            continue  
        
        else:
            
            # Step 2: Extract patches
            patches = extract_patches(preprocessed_image)

            # Step 3: Separate context and target patches
            context_patches, target_patches, target_positions, context_positions = separate_patches(patches)

            # Step 4: Extract embeddings
            context_embeddings, target_embeddings = extract_patch_embeddings(context_patches, target_patches, context_encoder, target_encoder)

            # Step 5: Add positional encoding
            context_embeddings_with_pos = add_positional_encoding(context_embeddings, context_positions)
            target_embeddings_with_pos = add_positional_encoding(target_embeddings, target_positions)

            # Step 6: Forward pass through the predictor (context to predict target)
            predictions = predictor(context_embeddings_with_pos)  # Shape: [num_context_patches, d_model]

            # Slice predictions to match the target patches
            predictions_for_targets = predictions[:target_embeddings_with_pos.size(0)]  # Shape: [batch_size, output_dim]

            # Step 7: Loss calculation
            loss = criterion(predictions_for_targets, target_embeddings_with_pos)

            # Step 8: Backward pass and optimization
            loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += loss.item()

            processed_count += 1
            if processed_count % 10 == 0:
                print(f"Processed {processed_count} images...")

    # Logging the loss every 5 epochs
    if epoch % 5 == 0:
        print(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss / len(image_paths)}")


Processed 10 images...
Processed 20 images...
Processed 30 images...
Processed 40 images...
Processed 50 images...
Processed 60 images...
Processed 70 images...
Processed 80 images...
Processed 90 images...
Processed 100 images...
Processed 110 images...
Processed 120 images...


KeyboardInterrupt: 

In [None]:
# Save models and embeddings after training
torch.save(context_encoder.state_dict(), "final_context_encoder.pth")
torch.save(predictor.state_dict(), "final_predictor.pth")
torch.save(feature_embeddings, "all_features.pt")

In [None]:
# Function to get image paths from a folder
def get_image_paths(folder_path, file_extension="*.jpg"):
    # Use glob to get all image paths
    return glob.glob(os.path.join(folder_path, file_extension))

# Folder containing the images
folder_path = "C:\\Users\\Manush\\Documents\\PythonCode\\JEPA\\jepa_Images\\Dog"  # Example folder

# Get all image paths
image_paths = get_image_paths(folder_path)

# Check the paths (optional)
print(f"Found {len(image_paths)} images: {image_paths}")

In [None]:
# image_path = "C:\\Users\\Manush\\Documents\\PythonCode\\JEPA\\jepa_Images\\Dog\\77.jpg"
# preprocessed_image = preprocess_image(image_path)
# print(preprocessed_image.shape) 

In [None]:
# Initialize dataset and dataloader
image_folder = "C:\\Users\\Manush\\Documents\\PythonCode\\JEPA\\jepa_Images\\Dog"
dataset = JEPAImageDataset(image_folder, preprocess_image, extract_patches)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# Define models
context_encoder = load_vit_encoder(pretrained=True)  # Load pretrained ViT encoder for context
target_encoder = load_vit_encoder(pretrained=True)   # Load pretrained ViT encoder for target
predictor = MLP_Predictor(input_dim=1000, hidden_dim=512, output_dim=1000)

In [None]:
# Loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(predictor.parameters(), lr=1e-4)

In [None]:
# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    total_loss = 0
    for patches in dataloader:
        context_patches, target_patches, target_positions, context_positions = separate_patches(patches)

        # Extract embeddings
        context_embeddings, target_embeddings = extract_patch_embeddings(context_patches, target_patches, context_encoder, target_encoder)

        # Add positional encoding
        context_embeddings_with_pos = add_positional_encoding(context_embeddings, context_positions)
        target_embeddings_with_pos = add_positional_encoding(target_embeddings, target_positions)

        # Predictor forward pass
        predictions = predictor(context_embeddings_with_pos)  # Shape: [num_context_patches, d_model]
        predictions_for_targets = predictions[:target_embeddings_with_pos.size(0)]

        # Compute loss
        loss = criterion(predictions_for_targets, target_embeddings_with_pos)
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Logging per epoch
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}")

# Save models and embeddings
torch.save(context_encoder.state_dict(), "context_encoder.pth")
torch.save(predictor.state_dict(), "predictor.pth")


In [None]:
patches = extract_patches(preprocessed_image)
print(patches.shape)  # Should output torch.Size([16, 3, 100, 100])

In [None]:
# Step 2: Separate into context and target patches

context_patches, target_patches, target_positions, context_positions = separate_patches(patches)
print(context_patches.shape) 
print(target_patches.shape) 
print(target_positions) 
print(context_positions) 

In [None]:
context_encoder = load_vit_encoder(pretrained=True)  # Load pretrained ViT encoder for context
target_encoder = load_vit_encoder(pretrained=True)   # Load pretrained ViT encoder for target

context_embeddings, target_embeddings = extract_patch_embeddings(context_patches, target_patches, context_encoder, target_encoder)

In [None]:
# Add positional encoding
print("Context embeddings")
context_embeddings_with_pos = add_positional_encoding(context_embeddings, context_positions)

print("\n\nTarget embeddings")
target_embeddings_with_pos = add_positional_encoding(target_embeddings, target_positions)

print("\n\nSummary")
print("Context embeddings with positional encoding:", context_embeddings_with_pos.shape)
print("Target embeddings with positional encoding:", target_embeddings_with_pos.shape)

In [None]:
predictor = MLP_Predictor(input_dim=1000, hidden_dim=512, output_dim=1000)

criterion = torch.nn.MSELoss()

optimizer = torch.optim.AdamW(predictor.parameters(), lr=1e-4)

In [None]:
# Assume context_encoder, predictor, context_embeddings_with_pos, and target_embeddings_with_pos are defined
num_epochs = 50
for epoch in range(num_epochs):

    predictor.train()

    # Forward pass through context encoder
    predictions = predictor(context_embeddings_with_pos)  # Shape: [num_context_patches, d_model]
    
    #print(target_embeddings_with_pos.size(0))

    # Slice predictions to match the target patches
    predictions_for_targets = predictions[:target_embeddings_with_pos.size(0)]  # Shape: [batch_size, output_dim]

    # Loss calculation
    loss = criterion(predictions_for_targets, target_embeddings_with_pos)

    # Backward pass and optimization
    loss.backward(retain_graph=True)
    optimizer.step()

    # Logging
    if epoch % 5 == 0:
        print(f"Epoch {epoch}/{num_epochs}, Loss: {loss.item()}")


In [None]:
from torchviz import make_dot

# Get the computational graph for the loss
dot = make_dot(loss, params=dict(predictor.named_parameters()))

# Render the graph to a file or display it
dot.render("graph", format="png")  # Saves as graph.png
#dot.view("graph")  # Opens the graph in your default viewer

In [None]:
def print_grad(grad):
    print(grad)
    
# Hook to track gradients at a specific layer (e.g., first layer)
for name, param in predictor.named_parameters():
    param.register_hook(print_grad)

In [None]:
torch.save(context_encoder.state_dict(), "context_encoder.pth")

torch.save(predictor.state_dict(), "predictor.pth")

torch.save(feature_embeddings, "features.pt")

In [None]:
# context_encoder.eval()
# with torch.no_grad():
#     embeddings = context_encoder(image_patches_with_pos) 


# from sklearn.decomposition import PCA
# pca = PCA(n_components=2)
# reduced_embeddings = pca.fit_transform(embeddings.numpy())