In [1]:
import torch
import os

In [2]:
# Input split folders and the output folder for the cached ViT embeddings.
train_dir, eval_dir = (
    os.path.join('dataset', 'part_two_dataset', split)
    for split in ('train_data', 'eval_data')
)
save_dir = 'part_2_vit_embeds'  # relative to the notebook's working directory
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
import torch
from torchvision import transforms, models
import numpy as np

from functools import lru_cache


@lru_cache(maxsize=1)
def _load_vit_b16_backbone(device_str):
    """Load torchvision's ViT-B/16 with its classification head removed.

    Cached so repeated calls (one per data split) reuse a single model instead
    of re-reading the weights and re-uploading them to the device each time.
    The cached model stays resident on the device until the cell is re-run.

    Args:
        device_str: device name accepted by ``torch.device`` (e.g. "cuda").

    Returns:
        The model in eval mode on ``device_str``. Because ``heads`` is replaced
        by ``Identity``, calling it returns features instead of class logits.
    """
    model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
    model.heads = torch.nn.Identity()
    return model.to(torch.device(device_str)).eval()


def get_vit_embeddings(data_dict, batch_size=64):
    """Embed every image in ``data_dict['data']`` with a pretrained ViT-B/16.

    Args:
        data_dict: mapping whose ``'data'`` entry is a sliceable sequence of
            images. Each image is assumed to be (H, W, C) RGB with pixel values
            in [0, 255] (TODO confirm against the dataset files).
        batch_size: number of images sent to the model per forward pass.

    Returns:
        np.ndarray of shape (N, model.hidden_dim) (768 for ViT-B/16): one
        embedding row per input image, in input order.

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_vit_b16_backbone(str(device))

    # ToTensor turns uint8 HWC into float CHW in [0, 1]; then resize to the
    # ViT input size and normalise with the ImageNet statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    images = data_dict['data']
    n_images = len(images)
    if n_images == 0:
        # np.vstack([]) would raise; return a correctly shaped empty result.
        return np.empty((0, model.hidden_dim), dtype=np.float32)

    report_every = 1000  # print progress roughly every this many images
    next_report = report_every
    embeddings = []

    with torch.no_grad():
        for start in range(0, n_images, batch_size):
            batch_images = images[start:start + batch_size]
            # NOTE(review): astype(np.uint8) truncates float images, so data
            # stored as floats in [0, 1] would become all zeros. This assumes
            # the data is already 0-255.
            batch_tensor = torch.stack([
                transform(np.asarray(img).astype(np.uint8)) for img in batch_images
            ]).to(device)
            embeddings.append(model(batch_tensor).cpu().numpy())

            # Count images actually processed (the last batch may be short).
            done = min(start + batch_size, n_images)
            if done >= next_report or done == n_images:
                print(f"Processed {done}/{n_images} images")
                next_report = (done // report_every + 1) * report_every

    return np.vstack(embeddings)

In [4]:
# Embed each of the 10 training splits once. An existing output file marks a
# split as finished, so an interrupted run can simply be resumed.
for j in range(10):
    path = os.path.join(save_dir, f'train_embeds_{j+1}.pt')

    if os.path.exists(path):
        continue

    train_path = os.path.join(train_dir, f'{j+1}_train_data.tar.pth')
    # weights_only=False lets torch.load unpickle arbitrary objects (presumably
    # required by this dataset's format — TODO confirm). Only load trusted files.
    t = torch.load(train_path, weights_only=False)
    embeds = get_vit_embeddings(t)

    # Save to a temp file, then rename atomically. An interrupted save then
    # cannot leave a truncated file that the existence check would skip forever.
    tmp_path = path + '.tmp'
    torch.save(embeds, tmp_path)
    os.replace(tmp_path, path)

In [None]:
# Embed each of the 10 evaluation splits once. An existing output file marks a
# split as finished, so an interrupted run can simply be resumed.
for j in range(10):
    path = os.path.join(save_dir, f'eval_embeds_{j+1}.pt')

    if os.path.exists(path):
        continue

    eval_path = os.path.join(eval_dir, f'{j+1}_eval_data.tar.pth')
    # weights_only=False lets torch.load unpickle arbitrary objects (presumably
    # required by this dataset's format — TODO confirm). Only load trusted files.
    t = torch.load(eval_path, weights_only=False)
    embeds = get_vit_embeddings(t)

    # Save to a temp file, then rename atomically. An interrupted save then
    # cannot leave a truncated file that the existence check would skip forever.
    tmp_path = path + '.tmp'
    torch.save(embeds, tmp_path)
    os.replace(tmp_path, path)