In [2]:
import torch
from torch.utils.data import DataLoader
import open_clip

from sklearn.manifold import TSNE
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import itertools
import seaborn as sns
import pandas as pd

from imagedatasets_v1_1 import ThingsDataset

# Run on the current CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# ViT-H-14 CLIP model with the 'laion2b_s32b_b79k' pretrained weights, in fp32 on `device`.
# NOTE(review): `preprocess` is not passed to ThingsDataset here — confirm the
# dataset applies CLIP-compatible resizing/normalisation itself.
vlmodel, preprocess, _ = open_clip.create_model_and_transforms(
    'ViT-H-14', pretrained='laion2b_s32b_b79k', precision='fp32', device=device
)
# open_clip hands the model back in train mode; switch to eval mode so any
# dropout / stochastic-depth layers are inactive during feature extraction.
vlmodel.eval()

In [None]:
# Root of the THINGS-EEG image set. This absolute path is machine-specific;
# change it here (once) to run the notebook elsewhere.
DATA_ROOT = '/root/workspace/wht/multimodal_brain/datasets/things-eeg-small/Image_set'
BATCH_SIZE = 128

train_dataset = ThingsDataset(data_root=DATA_ROOT, mode='train')
test_dataset = ThingsDataset(data_root=DATA_ROOT, mode='test')

# shuffle=False keeps batches (and therefore the saved features) in dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def extract_image_features(dataloader, out_path):
    """Encode every image from ``dataloader`` with ``vlmodel`` and save the result.

    Uses the notebook-level ``vlmodel`` and ``device``. Each feature vector is
    divided by its L2 norm along the last dimension before it is stored.

    Args:
        dataloader: DataLoader yielding ``(images, labels)`` batches.
        out_path: File written with ``torch.save``.

    Returns:
        The saved dict. ``'features'`` holds the per-batch CPU feature tensors
        concatenated along dim 0 (dataloader order). ``'labels'`` is a flat list
        of the batch labels in the same order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    feature_batches = []
    flat_labels = []
    with torch.no_grad():
        for images, batch_labels in tqdm(dataloader, desc=str(out_path)):
            image_features = vlmodel.encode_image(images.to(device))
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            feature_batches.append(image_features.cpu())
            # Assumes iterating a collated label batch yields one entry per
            # sample (true for default-collated lists of strings and 1-D tensors).
            flat_labels.extend(batch_labels)
    if not feature_batches:
        raise ValueError(f'{out_path}: dataloader yielded no batches')
    saved = {'features': torch.cat(feature_batches, dim=0), 'labels': flat_labels}
    torch.save(saved, out_path)
    return saved


# The loading cell further down reads both train_image_features.pt and
# test_image_features.pt, so extract both splits here rather than only test.
train_feature_dict = extract_image_features(train_dataloader, 'train_image_features.pt')
test_feature_dict = extract_image_features(test_dataloader, 'test_image_features.pt')

In [3]:
# Load the saved image features for both splits; both files must exist in the
# working directory. map_location='cpu' because the t-SNE step calls .numpy().
train_data = torch.load('train_image_features.pt', map_location='cpu')
train_features = train_data['features']
# Tag rows with split names rather than 0/1 so seaborn treats the hue as
# categorical and the legend reads "train" / "test".
train_labels = ['train'] * len(train_features)

test_data = torch.load('test_image_features.pt', map_location='cpu')
test_features = test_data['features']
test_labels = ['test'] * len(test_features)

all_features = torch.cat((train_features, test_features), dim=0)
all_labels = train_labels + test_labels
assert len(all_labels) == all_features.shape[0], 'one label per feature row expected'

In [4]:
# Embed the concatenated train+test feature matrix in two dimensions for plotting.
# random_state=0 fixes t-SNE's randomness so the layout is reproducible.
features_np = all_features.numpy()
tsne = TSNE(random_state=0, n_components=2)
features_2d = tsne.fit_transform(X=features_np)

In [None]:
# One row per embedded image: its two t-SNE coordinates plus its split label.
df = pd.DataFrame(features_2d, columns=['t-SNE Component 1', 't-SNE Component 2'])
df['Label'] = all_labels

fig, ax = plt.subplots(figsize=(10, 10))
sns.scatterplot(
    data=df,
    x='t-SNE Component 1',
    y='t-SNE Component 2',
    hue='Label',
    palette='viridis',
    alpha=0.6,
    legend='full',
    s=50,  # s controls the marker size
    ax=ax,
)
ax.set(
    title='t-SNE Visualization of Image Features',
    xlabel='t-SNE Component 1',
    ylabel='t-SNE Component 2',
)
ax.legend(title='Categories')
# High DPI only for the saved file; the inline preview renders at the default DPI.
fig.savefig('t-SNE_Visualization.png', format='png', dpi=400, bbox_inches='tight')
plt.show()