In [45]:
from torchvision import transforms
import timm
from torch.utils.data import  DataLoader
from torchvision.datasets import ImageFolder
import pandas as pd
import torch
from tqdm import tqdm

In [46]:
# Model zoo: Hugging Face Hub ids of the RadioDino checkpoints, loadable with
# timm.create_model(...).
BASE_CHECKPOINT = "hf_hub:Snarcy/RadioDino-b16"
SMALL_CHECKPOINT = "hf_hub:Snarcy/RadioDino-s16"
# The `_P8` name denotes the patch-size-8 small variant. It previously pointed at the
# s16 repo (a copy-paste slip that made it identical to SMALL_CHECKPOINT).
# NOTE(review): confirm this repo id exists on the Hub before relying on it.
SMALL_CHECKPOINT_P8 = "hf_hub:Snarcy/RadioDino-s8"

# Image folder read with torchvision ImageFolder (one sub-directory per class),
# and the CSV the embeddings are written to.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
DATASET_PATH="C:/Users/lucat/Documents/MEDMINST/data_split/pneumoniamnist_224/train"
OUTPUT_PATH="embeddings.csv"

In [47]:
# RadioDino preprocessing pipeline: resize to 224x224, convert to a float tensor,
# and normalize with the ImageNet channel mean/std.
# Built once here rather than on every call: preprocess() runs once per image, and
# Resize/ToTensor/Normalize hold no per-call state, so sharing one instance is safe.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def preprocess(image):
    """Preprocess a single image for a RadioDino model.

    Args:
        image: a PIL image. ImageFolder's default loader yields RGB images, which
            matches the 3-channel normalization statistics used here.

    Returns:
        A float tensor of shape (1, C, 224, 224). The leading batch dimension is
        added here, so a DataLoader over these samples produces 5-D batches
        (batch, 1, C, H, W) that callers must squeeze on axis 1.
    """
    return _PREPROCESS_TRANSFORM(image).unsqueeze(0)

# Instantiate the small RadioDino backbone from the Hub, with its pretrained weights.
model = timm.create_model(model_name=SMALL_CHECKPOINT, pretrained=True)

In [48]:
# Build the dataset (one sub-folder per class under DATASET_PATH) and a loader over it.
# shuffle must be False: the extraction step pairs the i-th embedding with the i-th
# entry of dataset.samples (dataset order). A shuffled loader would attach the wrong
# filename to every embedding. Shuffling brings no benefit for pure inference.
BATCH_SIZE = 32
dataset = ImageFolder(DATASET_PATH, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [49]:

# Run every image through the model and collect, in dataset order, one embedding
# row per image together with its label and source file path.
model.eval()  # evaluation mode (disables dropout etc.) for inference
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Filenames are read from dataset.samples, which is in dataset order, so the loader
# must yield images in that same order. A shuffling sampler would silently pair
# embeddings with the wrong files; fail loudly instead of writing a corrupt CSV.
if not isinstance(dataloader.sampler, torch.utils.data.SequentialSampler):
    raise ValueError(
        "dataloader must iterate in dataset order (shuffle=False) so that "
        "embeddings line up with dataset.samples filenames"
    )

embeddings = []
labels = []
with torch.no_grad():
    for images, label in tqdm(dataloader):
        images = images.to(device)
        # preprocess() returns a (1, C, H, W) tensor per image, so batches arrive as
        # (batch, 1, C, H, W); drop the singleton axis before the forward pass.
        if images.dim() == 5:
            images = images.squeeze(1)
        embeddings.append(model(images).cpu())
        labels.append(label)

embeddings = torch.cat(embeddings, dim=0)
labels = torch.cat(labels, dim=0)
# dataset.samples holds (path, class_index) pairs; keep only the path so the
# filename column contains file paths rather than stringified tuples.
filenames = [path for path, _ in dataloader.dataset.samples]
# One row per image, one column per embedding dimension.
df = pd.DataFrame(embeddings.numpy())

    

100%|██████████| 148/148 [00:12<00:00, 11.56it/s]


In [50]:
print("Embeddings shape: ", embeddings.shape)
print("Labels shape: ", labels.shape)
print("Filenames shape: ", len(filenames))

# ImageFolder.samples entries are (path, class_index) tuples; keep only the path
# so the CSV column holds plain file paths (no-op if they are already strings).
filenames = [f[0] if isinstance(f, tuple) else f for f in filenames]

# Every embedding row needs exactly one label and one filename before joining.
# A mismatch means the extraction step went wrong.
assert len(embeddings) == len(labels) == len(filenames) == len(df), (
    "row count mismatch between embeddings, labels and filenames"
)

df['label'] = labels.numpy()
df['filename'] = filenames
# Put label and filename first, followed by the embedding dimension columns.
df = df[['label', 'filename'] + [col for col in df.columns if col not in ['label', 'filename']]]
# Save the table without the positional index.
df.to_csv(OUTPUT_PATH, index=False)
print(f"Embeddings saved to {OUTPUT_PATH}")

Embeddings shape:  torch.Size([4708, 384])
Labels shape:  torch.Size([4708])
Filenames shape:  4708
Embeddings saved to embeddings.csv
