In [39]:
from Uni.model import CLIPModel_simple
import torch
import argparse

model_path = "./save/best128.pt"
print(model_path)


def get_args():
    parser = argparse.ArgumentParser(description='BRAIN CLIP')
    
    parser.add_argument('--save_path', type=str, default='./save', help='')
    parser.add_argument('--embedding_dim', type=int, default=256, help='')
    parser.add_argument('--projection_dim', type=int, default=128, help='')
    parser.add_argument('--dropout', type=float, default=0.1, help='')
    parser.add_argument('--temperature', type=float, default=1.0, help='')
    parser.add_argument('--vote', default=True, type=bool, help='')
    parser.add_argument('--topk', type=int, default=3, help='')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='')
    
    return parser.parse_args(args=[])

args = get_args()


model = CLIPModel_simple(args).to(args.device)
state_dict = torch.load(model_path)
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace('module.', '')  # remove the prefix 'module.'
    new_key = new_key.replace('well', 'spot')  # for compatibility with prior naming
    new_state_dict[new_key] = state_dict[key]

model.load_state_dict(new_state_dict)
model

./Uni/save/best128.pt


CLIPModel_simple(
  (image_encoder): ImageEncoder_linear(
    (model): Linear(in_features=128, out_features=256, bias=True)
  )
  (snabel_encoder): SnabelEncoder_linear(
    (model): Linear(in_features=234, out_features=256, bias=True)
  )
  (image_projection): EmbedHead(
    (projection): Linear(in_features=256, out_features=128, bias=True)
    (gelu): GELU(approximate='none')
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (snabel_projection): EmbedHead(
    (projection): Linear(in_features=256, out_features=128, bias=True)
    (gelu): GELU(approximate='none')
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
)

In [40]:

def get_img_embeddings(img, model):
    model.eval()
    image_features = model.image_encoder(img.cuda())
    image_embeddings = model.image_projection(image_features)

    return image_embeddings


def get_snp_embeddings(snp, model):
    model.eval()
    snp_features = model.snabel_encoder(snp.cuda())
    snp_embeddings = model.snabel_projection(snp_features)

    return snp_embeddings

In [41]:
import torch
# make your SNP or IMG as input
snp = torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
        1., 0., 0., 2., 1., 0., 0., 0., 0., 2., 1., 1., 2., 0., 1., 1., 0., 0.,
        2., 0., 2., 1., 0., 0., 2., 1., 1., 0., 0., 2., 0., 2., 2., 2., 1., 1.,
        0., 1., 2., 2., 2., 0., 2., 2., 0., 0., 2., 2., 1., 0., 1., 0., 0., 1.,
        1., 2., 1., 0., 2., 2., 1., 0., 0., 2., 2., 2., 1., 0., 1., 0., 2., 2.,
        0., 1., 2., 2., 2., 0., 2., 2., 1., 0., 2., 0., 0., 0., 1., 2., 2., 0.,
        2., 2., 2., 0., 0., 2., 1., 2., 2., 0., 0., 2., 0., 0., 1., 1., 1., 0.,
        1., 2., 2., 0., 1., 1., 1., 0., 2., 0., 1., 1., 0., 2., 2., 0., 1., 2.,
        0., 2., 0., 0., 2., 0., 2., 2., 1., 0., 2., 0., 1., 2., 0., 0., 0., 0.,
        1., 1., 0., 2., 1., 0., 1., 2., 1., 2., 2., 1., 1., 0., 2., 0., 0., 0.,
        2., 0., 2., 1., 1., 1., 2., 1., 0., 1., 2., 0., 2., 0., 2., 2., 0., 0.,
        1., 1., 2., 0., 2., 0., 2., 2., 0., 2., 1., 0., 2., 2., 2., 0., 0., 2.,
        2., 2., 1., 0., 0., 2., 0., 1., 0., 2., 1., 0., 0., 0., 0., 0., 2., 0.])

if len(snp)>234:
    snp = snp[:234]
while len(snp)<234:
    snp.append(99)  # 99 as pad
    
snp = snp.unsqueeze(0)

img = torch.tensor([  670.2590,    87.7632,   226.5341,   436.2836,   134.7943,   186.3737,
         -172.9206, -1513.7437,  -193.6193,    28.0608,    11.5967,   449.3401,
          449.2867,   208.9356,   -61.2091,  -120.9889,   267.1925,  -126.9980,
          954.8094,  -609.8949,   786.1028,  -176.6902,   104.8324,    60.3755,
          425.5132,  -261.9039,   282.5353,  -122.8213,   756.0443,  -270.4344,
          -34.8492,  -834.0070,   122.5162,   407.1969,  1160.2214,  -428.9731,
         -631.9886,    97.9889,  -106.9545,   -36.9184,   867.3173,  -318.7118,
         -320.3081,  -320.8904,   845.5027,  1566.8265,   283.6511,   187.8840,
         -336.6803,   -81.8993,  -638.3338,   477.3666,   541.8198,   -50.5883,
         -202.5780,  -486.9493,   142.1280,    18.4567,  -132.8069,   -79.4203,
         -430.0923,   180.3173,  1016.8868,  -297.6619,   284.1148,   -17.6610,
          539.2893,   186.5835,    91.3026,   -36.7723,   178.6299,  -294.6639,
         -175.4978,   413.0392,   148.5926,  -548.8325,   -92.3212,   179.4273,
        -2874.2920,  3197.5642,  -230.1659,  -338.6367,    58.8648,  -307.3961,
         -402.1805,  -325.3324,   364.9382,   288.9287,   516.7051,  -125.3412,
           43.9312,   829.7002,   105.6240,   -15.1332,   896.2142,   845.3306,
         -689.6951,  -328.3130,   361.1234,   444.5173,   -88.8024,   362.8145,
          263.4659,    64.9619,  -232.2993,   -25.8410,   231.1041,   191.6732,
          268.9848,   849.3340,   378.3295,  -388.5266, -1046.9761,  -499.5015,
         -439.0554,  -323.1657,  -202.1560,   -30.4386,  -315.0648,  -613.0076,
          190.9714,  1115.7202,  -494.9961,   387.6808,   629.5889,     3.4351,
          465.5319,  -177.0818])

img = img.unsqueeze(0)
snp.shape
img.shape

torch.Size([1, 128])

In [42]:
img_embeddings = get_img_embeddings(img, model)
snp_embeddings = get_snp_embeddings(snp, model)
img_embeddings.shape
# snp_embeddings.shape

torch.Size([1, 128])

In [47]:
import numpy as np
# query = np.load("./embeddings/img_embeddings_query.npy")
# query = snp_embeddings
query = img_embeddings
refer = np.load("./embeddings/img_embeddings_refer.npy")
label_refer = np.load("./embeddings/labels_refer.npy")
# label_query = np.load("./embeddings/labels_query.npy")
if refer.shape[1]!=128:
    refer = refer.T
    
query.shape
refer.shape

(7671, 128)

In [48]:
import torch.nn.functional as F
def find_matches(refer_embeddings, query_embeddings, topk):
    # find the closest matches
    refer_embeddings = torch.tensor(refer_embeddings).cuda()
    query_embeddings = torch.tensor(query_embeddings).cuda()
    query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
    refer_embeddings = F.normalize(refer_embeddings, p=2, dim=-1)
    dot_similarity = query_embeddings @ refer_embeddings.T
    print(dot_similarity.shape)
    _, indices = torch.topk(dot_similarity.squeeze(0), k=topk)

    return indices.cpu().numpy()

In [66]:
index = find_matches(refer,query,topk=args.topk)
print(index)
if args.vote==True:
    # The voting mechanism compares the returned index samples
    # and their corresponding labels to make a vote.
    print('voting')
    if np.sum(label_refer[index] == 0)>np.sum(label_refer[index] == 1):
        pred = 0
        print('Not CN (AD or EMCI or LMCI or MCI or SMC or Patient)')
    else:
        pred = 1
        print('CN')

else:
    # Directly select the label of the index sample with the highest probability as the result.
    print('Max')
    pred = label_refer[index[0]]
    if pred == 0.0:
        print('Not CN (AD or EMCI or LMCI or MCI or SMC or Patient)')
    elif pred == 1.0:
        print('CN')

torch.Size([1, 7671])
[4512 3898  914]
voting
Not CN (AD or EMCI or LMCI or MCI or SMC or Patient)


  """
