In [1]:
import torch
import argparse
import os
import torch.nn as nn
import uuid
import faiss
import json
from utils.EEGDataset import EEGDataset
import time
from torch.autograd import Variable
from torchvision import transforms, datasets
import torch.nn.functional as F


In [2]:
# Define the LSTM model
class LSTMModel(nn.Module):
    """Stacked LSTM that maps a sequence to a single feature vector.

    The sequence is run through ``nn.LSTM`` (``batch_first=True``). The output of
    the last time step is then projected by a linear layer to ``out_features``
    dimensions. In this notebook that width is set to the DINOv2 embedding
    size, so the model can be distilled against image features.

    Args:
        input_size: Number of features per time step (last dim of the input).
        hidden_size: LSTM hidden-state width.
        channels: Unused; kept so existing call sites keep working.
        n_layers: Number of stacked LSTM layers.
        out_features: Width of the returned vector.
    """

    def __init__(self, input_size, hidden_size, channels=128, n_layers=2, out_features=384):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.n_layer = n_layers
        self.input_size = input_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """Project the last-time-step LSTM output.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)`` (batch-first).

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        batch_size = x.size(0)
        # Zero initial (h0, c0) built directly on the input's device and dtype.
        # This works on CPU and on any CUDA device, and inside torch.no_grad(),
        # without the removed Variable/volatile API.
        state_shape = (self.n_layer, batch_size, self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        # Forward the LSTM and keep only the output of the final time step.
        lstm_out, _ = self.lstm(x, (h0, c0))
        return self.fc(lstm_out[:, -1, :])

In [3]:
SUBJECT = 1
BATCH_SIZE = 8
learning_rate = 0.0001
EPOCHS = 50
SaveModelOnEveryEPOCH = 100
SEED = 42
EEG_DATASET_PATH = "./data/eeg/eeg_signals_raw_with_mean_std.pth"
EEG_DATASET_SPLIT = "./data/eeg/block_splits_by_image_all.pth"

# EEG time window passed to EEGDataset as time_low / time_high.
TIME_LOW = 20
TIME_HIGH = 480

# Per-time-step feature count fed to the LSTM (its input_size).
# NOTE(review): presumably the number of EEG channels; confirm against EEGDataset's eeg layout.
LSTM_INPUT_FEATURES = 128
# LSTM hidden width; currently equal to the cropped window length (TIME_HIGH - TIME_LOW = 460).
LSTM_HIDDEN_SIZE = 460
selectedDataset = "imagenet40"

# Seed torch's RNG so that model initialisation is reproducible across runs.
torch.manual_seed(SEED)

transform_image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256, antialias=True),
    transforms.CenterCrop(224),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def make_eeg_dataset(subset):
    """Build one EEGDataset split with the notebook's shared configuration.

    Args:
        subset: split name understood by EEGDataset (e.g. "train", "val").
    """
    return EEGDataset(subset=subset,
                      eeg_signals_path=EEG_DATASET_PATH,
                      eeg_splits_path=EEG_DATASET_SPLIT,
                      subject=SUBJECT,
                      time_low=TIME_LOW,
                      time_high=TIME_HIGH,
                      exclude_subjects=[],
                      convert_image_to_tensor=False,
                      apply_channel_wise_norm=True,
                      # (sic) keyword name as defined by EEGDataset.
                      preprocessin_fn=transform_image)


dataset = make_eeg_dataset("train")
val_dataset = make_eeg_dataset("val")

{39: 'Egyptian_cat', 35: 'African_elephant', 0: 'sorrel', 21: 'capuchin', 8: 'giant_panda', 12: 'German_shepherd', 7: 'revolver', 30: 'grand_piano', 36: 'airliner', 10: 'canoe', 20: 'missile', 6: 'mountain_bike', 37: 'electric_locomotive', 24: 'convertible', 25: 'folding_chair', 22: 'pool_table', 32: 'banana', 28: 'electric_guitar', 9: 'daisy', 3: 'anemone_fish', 34: 'digital_watch', 38: 'radio_telescope', 17: 'desktop_computer', 14: "jack-o'-lantern", 11: 'lycaenid', 2: 'iron', 4: 'espresso_maker', 31: 'mountain_tent', 26: 'pajama', 13: 'running_shoe', 16: 'golf_ball', 23: 'mailbag', 18: 'broom', 27: 'mitten', 15: 'cellular_telephone', 1: 'parachute', 19: 'pizza', 29: 'reflex_camera', 33: 'bolete', 5: 'coffee_mug'}
Transforming data to channel wise norm across labels
Transforming data to channel wise norm across labels (done)
{39: 'Egyptian_cat', 35: 'African_elephant', 0: 'sorrel', 21: 'capuchin', 8: 'giant_panda', 12: 'German_shepherd', 7: 'revolver', 30: 'grand_piano', 36: 'airline

In [4]:
def initDinoV2Model(model="dinov2_vits14"):
    """Load a DINOv2 backbone from torch.hub.

    Args:
        model: torch.hub entry-point name in ``facebookresearch/dinov2``
            (e.g. ``"dinov2_vits14"``).

    Returns:
        The loaded ``nn.Module``.
    """
    return torch.hub.load("facebookresearch/dinov2", model)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DINOv2 is the frozen teacher; only the LSTM is optimised.
# requires_grad_(False) stops autograd from recording graphs through the
# teacher's parameters, even for forward calls made outside torch.no_grad().
dinov2_model = initDinoV2Model(model="dinov2_vits14").to(device).eval().requires_grad_(False)

Using cache found in C:\Users\ASUS/.cache\torch\hub\facebookresearch_dinov2_main


In [5]:
from utils import utils

class FLAGS:
    """Attribute bag handed to ``utils.init_distributed_mode`` in place of parsed CLI args."""

    # NOTE(review): presumably forwarded to torch.distributed as the init method
    # ("env://" = read rendezvous settings from environment variables).
    dist_url = "env://"
    local_rank = 0
    num_workers = 4


utils.init_distributed_mode(FLAGS)

Will run the code on one GPU.
| distributed init (rank 0): env://


In [None]:
def _make_loader(ds):
    """Sequential (unshuffled) DataLoader over ``ds`` using the configured batch size."""
    return torch.utils.data.DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        pin_memory=True,
        drop_last=False,
    )


data_loader_train = _make_loader(dataset)
data_loader_val = _make_loader(val_dataset)

# NOTE(review): presumably caches DINOv2 image features on each dataset.
# Batches later yield them as the 5th item; see utils.EEGDataset.extract_features.
dataset.extract_features(model=dinov2_model, data_loader=data_loader_train, replace_eeg=False)
val_dataset.extract_features(model=dinov2_model, data_loader=data_loader_val, replace_eeg=False)


In [None]:
eeg, label, image, i, image_features = next(iter(data_loader_train))
# The teacher is only probed for its output width; no_grad avoids building an autograd graph.
with torch.no_grad():
    outs = dinov2_model(image.to(device))
features_length = outs.size(-1)
outs.size()

In [None]:
# Display the 5th item of the first training batch.
# NOTE(review): presumably the DINOv2 features cached by extract_features; confirm.
image_features

In [None]:
# Student network; its output width matches the teacher's embedding size (features_length).
model = LSTMModel(
    out_features=features_length,
    input_size=LSTM_INPUT_FEATURES,
    hidden_size=LSTM_HIDDEN_SIZE,
    channels=128,
)
model

In [None]:
# Module.to moves parameters in place and returns the same module, so rebinding keeps `model` unchanged.
model = model.to(device)
model

In [None]:
# Smoke test: one forward pass of the (untrained) student on the sample batch.
with torch.no_grad():
    output = model(eeg.to(device))
# The student must emit one teacher-sized vector per sample.
assert output.shape == (eeg.size(0), features_length), output.shape
output.size()

In [None]:
import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    """Mean squared error between student (LSTM) and teacher features.

    This is the mean of the squared element-wise differences (MSE), not the
    Euclidean distance. No square root is taken, and the average runs over
    every element, including the batch dimension.
    """

    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, y_lstm, y_resnet):
        """Return ``mean((y_lstm - y_resnet) ** 2)`` as a 0-d tensor.

        The two inputs must be broadcast-compatible.
        """
        diff = y_lstm - y_resnet
        return torch.mean(torch.square(diff))

In [None]:
import numpy as np

In [None]:
class Paramters:
    """Distillation hyper-parameters read by ``loss_fn_kd``.

    alpha: weight of the soft-target (KL) term. ``1 - alpha`` weights the
        optional hard-label cross-entropy term.
    temperature: softmax temperature ``T`` applied to both teacher and student.
    """
    alpha = 1
    temperature = 0.977


# Correctly spelled alias; ``Paramters`` is kept so existing callers keep working.
Parameters = Paramters


def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """Knowledge-distillation loss (Hinton et al., 2015).

        loss = alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
               + (1 - alpha) * CE(student, labels)    # only when labels is given

    The KL term uses ``reduction="batchmean"`` (sum over dim 1, mean over the
    batch), which matches the mathematical definition of KL divergence. The
    ``"mean"`` reduction also divides by the feature dimension, and PyTorch
    warns about it on every call. The ``T^2`` factor keeps gradient magnitudes
    comparable across temperatures.

    Args:
        outputs: student logits/features, 2-D with softmax taken over dim 1.
        labels: integer class targets for the cross-entropy term, or ``None``
            to use the distillation term only.
        teacher_outputs: teacher logits/features, same shape as ``outputs``.
        params: object with ``alpha`` and ``temperature`` attributes.

    Returns:
        0-d loss tensor.
    """
    alpha = params.alpha
    T = params.temperature

    # kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    kd_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (alpha * T * T)

    # With alpha == 1 the hard-label term has zero weight, so skip it entirely.
    if labels is not None and alpha != 1:
        kd_loss = kd_loss + F.cross_entropy(outputs, labels) * (1.0 - alpha)

    return kd_loss

In [None]:
# Distil the DINOv2 image features into the EEG LSTM.
# The epoch count and learning rate come from the config cell
# (previously hard-coded as 40 epochs and lr=0.0001).
opt = torch.optim.Adam(lr=learning_rate, params=model.parameters())

# Shuffled view of the training set for SGD. data_loader_train stays
# sequential because it was also used for feature extraction.
train_loader_shuffled = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)


def _teacher_targets(image_features):
    """Convert a batch of cached teacher features into a tensor on ``device``."""
    return torch.from_numpy(np.array(image_features)).to(device)


def _train_one_epoch(loader):
    """Run one optimisation pass over ``loader``; return the per-batch losses."""
    model.train()
    losses = []
    for eeg_batch, _label, _image, _idx, image_features in loader:
        targets = _teacher_targets(image_features)
        opt.zero_grad()
        lstm_output = model(eeg_batch.to(device))
        loss = loss_fn_kd(outputs=lstm_output, labels=None, teacher_outputs=targets, params=Paramters)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return np.array(losses)


@torch.no_grad()
def _evaluate(loader):
    """Compute per-batch distillation losses over ``loader`` without updating the model."""
    model.eval()
    losses = []
    for eeg_batch, _label, _image, _idx, image_features in loader:
        targets = _teacher_targets(image_features)
        lstm_output = model(eeg_batch.to(device))
        loss = loss_fn_kd(outputs=lstm_output, labels=None, teacher_outputs=targets, params=Paramters)
        losses.append(loss.item())
    return np.array(losses)


epoch_losses = []
val_epoch_losses = []
for epoch in range(EPOCHS):
    epoch_loss = _train_one_epoch(train_loader_shuffled).mean()
    val_epoch_loss = _evaluate(data_loader_val).mean()
    epoch_losses.append(epoch_loss)
    val_epoch_losses.append(val_epoch_loss)

    print(f"EPOCH {epoch} train_loss: {epoch_loss} val_loss: {val_epoch_loss}")

In [None]:
MODEL_SAVE_PATH = "output/lstm_dinov2_distilled.pt"
# torch.save does not create parent directories; make sure output/ exists on a fresh checkout.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
# Saves the whole pickled module, so loading requires LSTMModel to be importable.
# model.state_dict() would be the more portable artifact.
torch.save(model, MODEL_SAVE_PATH)

In [None]:
lstm_features = []
lstm_features_labels = []

# NOTE(review): despite its name this uses the *train* split (unchanged).
# Switch subset to the held-out split if the embedding should be inspected on unseen data.
test_dataset = EEGDataset(subset="train",
                          eeg_signals_path=EEG_DATASET_PATH,
                          eeg_splits_path=EEG_DATASET_SPLIT,
                          subject=SUBJECT,
                          time_low=20,
                          time_high=480,
                          exclude_subjects=[],
                          convert_image_to_tensor=False,
                          # Must match the normalisation the model was trained with
                          # (the training and validation datasets enable it).
                          apply_channel_wise_norm=True,
                          preprocessin_fn=transform_image)

model.eval()

# One sample at a time: without extract_features the per-item extras may not
# be collatable, so no DataLoader is used here.
with torch.no_grad():
    for idx in range(len(test_dataset)):
        eeg, label, image, i, image_features = test_dataset[idx]
        lstm_output = model(eeg.unsqueeze(0).to(device))
        lstm_features.append(lstm_output.cpu().numpy())
        lstm_features_labels.append(label)

In [None]:
# Pull the integer class id out of each per-sample label dict.
lstm_features_labels_int = [sample_label["ClassId"] for sample_label in lstm_features_labels]

In [None]:
# Stack the per-sample (1, out_features) outputs into an (N, out_features) matrix for t-SNE.
lstm_features = np.asarray(lstm_features).reshape(len(lstm_features_labels_int), -1)
lstm_features.shape

In [None]:
import matplotlib.patches as mpatches 
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# 3-D t-SNE of the LSTM features; random_state makes the layout reproducible.
# NOTE(review): learning_rate=0.1 is far below scikit-learn's usual range
# (roughly 10-1000, or "auto"); with 300 iterations the layout may stay close to the PCA init.
_tsne_kwargs = dict(n_components=3, perplexity=40, init="pca", learning_rate=0.1, random_state=0)
try:
    # scikit-learn >= 1.5 names the iteration budget ``max_iter``.
    _tsne = TSNE(max_iter=300, **_tsne_kwargs)
except TypeError:
    # Older releases only accept ``n_iter``.
    _tsne = TSNE(n_iter=300, **_tsne_kwargs)
X_tsne_sample0_time_RAW = _tsne.fit_transform(lstm_features)

unique_labels = sorted(set(lstm_features_labels_int))
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap("hsv", len(unique_labels))
# Colour each class by its rank among the sorted ids, so non-contiguous ids still map correctly.
label_to_color = {lab: cmap(rank) for rank, lab in enumerate(unique_labels)}

handles = [mpatches.Patch(color=label_to_color[lab], label=f'Class {lab}') for lab in unique_labels]
cmaps = [label_to_color[lab] for lab in unique_labels]
gen_colors = [label_to_color[lab] for lab in lstm_features_labels_int]

In [None]:
fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": "3d"})

ax.set_title("EEG data")
ax.view_init(azim=60, elev=30)
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
ax.set_zlabel("t-SNE 3")
# Report the actual number of embedded points instead of a hard-coded count.
_ = ax.text2D(0.8, 0.05, s=f"n_samples={X_tsne_sample0_time_RAW.shape[0]}", transform=ax.transAxes)

ax.scatter(X_tsne_sample0_time_RAW[:, 0], X_tsne_sample0_time_RAW[:, 1], X_tsne_sample0_time_RAW[:, 2],
           c=gen_colors, s=30, alpha=0.8)
ax.legend(handles=handles, loc="best", fontsize=13, bbox_to_anchor=(1.2, 0.1), fancybox=True, ncol=5)
plt.show()