In [357]:
# 模型
import os
import pickle
import sys
import h5py
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import cv2
import numpy as np
from tqdm import tnrange, tqdm_notebook
import models
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
import h5py
from PIL import Image

device = 'cuda:0' # device where you put your data and models
data_path = './' # the path of the 'npc_v4_data.h5' file
batch_size = 16 # the batch size of the data loader
insp_layer = 'conv3' # the middle layer extracted from alexnet, available in {'conv1', 'conv2', 'conv3', 'conv4', 'conv5'}

mse_weight = 1.0
l1_weight = 0
spa_weight = 1e-1
ch_weight = 1e-1
lap_weight = 1e-1


#定义损失函数
K = torch.tensor([
    [0,-1,0],
    [-1,4,-1],
    [0,-1,0]],dtype=torch.float)
def mse_loss(prediction, response, weight=None):
    if weight is None:
        mse_loss = torch.mean(torch.mean((prediction - response)**2, dim=1))
    else:
        mse_loss = torch.sum(weight*torch.mean((prediction - response)**2, dim=1))
    return mse_loss

def l2_norm_regularizer(W):
    return torch.mean(torch.sum(W**2))

def l1_norm_regularizer(W):
    return torch.mean(torch.sum(torch.abs(W)))

def smoothness_regularizer_2d(W_s):
    lap = torch.tensor([[0., -1., 0.], [-1., 4.0, -1.], [0., -1., 0.]])
    lap = lap.unsqueeze(0).unsqueeze(0)
    lap = lap.to(device)
    W_s = W_s.to(device)
    out_channels = W_s.size(1)
    W_lap = torch.nn.functional.conv2d(W_s.permute(1, 2, 3, 0),
                                       lap.repeat(out_channels, 1, 1, 1),
                                       stride=1, padding=1)

    penalty = torch.sqrt(torch.sum(W_lap**2, dim=[1, 2])).mean()
    return penalty

def torch_pearson(prediction, response):
    prediction = torch.from_numpy(prediction)
    response = torch.from_numpy(response)
    prediction_mean = torch.mean(prediction, dim=0)
    response_mean = torch.mean(response, dim=0)
    num = torch.sum((prediction - prediction_mean)*(response - response_mean), dim=0)
    den = torch.sqrt(torch.sum((prediction - prediction_mean)**2, dim=0) *
                     torch.sum((response - response_mean)**2, dim=0))
    return torch.mean(num * (1/den))

def explained_variance_score(prediction, response):
    prediction = torch.from_numpy(prediction)
    response = torch.from_numpy(response)
    num = torch.mean((prediction - response)**2, dim=0)
    den = torch.var(response, dim=0)
    ve = 1 - num* (1/den)
    ve_avg = torch.mean(ve)
    return ve_avg

class conv_encoder(nn.Module):

    def __init__(self, neurons, sizes, channels, reg_model_weight = None):
        super(conv_encoder, self).__init__()
        # PUT YOUR CODES HERE
        #print(sizes)
        self.neurons = neurons
        self.channels = channels
        self.px_x_conv = int(sizes[1])
        self.px_y_conv = int(sizes[0])
        self.px_conv = self.px_x_conv * self.px_y_conv

        if reg_model_weight is not None:
            ws_initial_value = torch.from_numpy(reg_model_weight['W_s'][:].reshape(self.neurons, self.px_conv)).transpose(0, 1).float()
            self.W_spatial = torch.nn.Parameter(ws_initial_value)
        else:
            self.W_spatial = torch.nn.Parameter(torch.randn(self.px_conv, neurons) * 0.001)

        if reg_model_weight is not None:
            wf_initial_value = torch.from_numpy(reg_model_weight['W_d'][:]).transpose(0, 1).float()
            self.W_features = torch.nn.Parameter(wf_initial_value)
        else:
            self.W_features = torch.nn.Parameter(torch.randn(channels, neurons) * 0.001)
        if reg_model_weight is not None:
            b_initial_value = torch.from_numpy(reg_model_weight['W_b'][:]).float()
            self.W_b = torch.nn.Parameter(b_initial_value)
        else:
            self.W_b = torch.nn.Parameter(torch.zeros(neurons))

    def forward(self, x):
        # PUT YOUR CODES HERE
        self.fts = x
        conv_flat = torch.reshape(self.fts, (-1, self.px_conv, int(self.channels), 1)) # [batch, 17 * 17, 384]
        W_spatial_flat = torch.reshape(self.W_spatial, [self.neurons, self.px_conv, 1, 1]) # [43, 17 * 17, 1, 1]
        conv_flat = conv_flat.to(device)
        W_spatial_flat = W_spatial_flat.to(device)
        h_spatial = F.conv2d(conv_flat, W_spatial_flat, stride=1, padding=0)
        h_out = torch.sum(torch.mul(h_spatial, self.W_features), dim=[1, 2])

        self.W_sploss = torch.reshape(self.W_spatial, [self.px_y_conv, self.px_x_conv, 1, self.neurons])
        return h_out + self.W_b


def Loss(y, pred, W_s, W_d):
    return mse_loss(y, pred) * mse_weight + \
          l2_norm_regularizer(W_s) * spa_weight + \
          smoothness_regularizer_2d(W_s) * lap_weight + \
          l2_norm_regularizer(W_d) * ch_weight


In [358]:
# image_data
root_dir = '../data/0_presented_images_800/'
resolution = 300
image_path = os.listdir(root_dir)
path_dict = {}
for j in image_path:
    key = int(j.split('_')[0])  # 刺激呈现的顺序是图像名称下划线前面的数字顺序。
    path_dict[key] = j

stim_arr = np.zeros((len(image_path), resolution, resolution, 3))
for i in range(len(image_path)):
    img_bgr = cv2.imread(os.path.join(root_dir, path_dict[i+1]))
    stim_arr[i] = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
stim_arr = stim_arr.astype('float32')

# random repeat
id = h5py.File('../data/1_L76LM_V1_S18_D155_objects/stimuli/Random_id_80_2021_10_21.mat', 'r')

images_n  = np.zeros(shape=(stim_arr.shape[0], 299, 299, 3))
for i in range(stim_arr.shape[0]):
    images_n[i] = cv2.resize(stim_arr[i], (299, 299))

idx = np.array(id['sampleidlist21']).squeeze().astype('int') - 1
print(idx)
idx, unique_idx = np.unique(idx, return_index=True)
print(idx, unique_idx, images_n.shape)

[  4  21  40  54  55  58  58  80  88  99 100 110 113 126 136 151 171 171
 175 183 200 220 271 277 283 287 309 334 356 358 361 367 389 393 400 403
 407 438 442 445 457 458 460 467 470 477 480 481 481 487 490 537 539 539
 555 589 598 600 600 602 604 609 612 615 622 634 644 650 676 679 687 692
 707 726 746 767 778 786 791 794]
[  4  21  40  54  55  58  80  88  99 100 110 113 126 136 151 171 175 183
 200 220 271 277 283 287 309 334 356 358 361 367 389 393 400 403 407 438
 442 445 457 458 460 467 470 477 480 481 487 490 537 539 555 589 598 600
 602 604 609 612 615 622 634 644 650 676 679 687 692 707 726 746 767 778
 786 791 794] [ 0  1  2  3  4  5  7  8  9 10 11 12 13 14 15 16 18 19 20 21 22 23 24 25
 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 49 50
 51 52 54 55 56 57 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
 77 78 79] (800, 299, 299, 3)


In [359]:
# neurons_data for training
mat_file = h5py.File('../data/1_L76LM_V1_S18_D155_objects/celldataS_43_Objects_11_800_80_30_40_trial_mean_normal.mat', 'r')
#[num_repetitions, num_images, num_neurons]
#print(np.array(mat_file['celldataS']).shape)
neural_n = np.transpose(np.array(mat_file['celldataS']), (2, 1, 0)).astype('float16')
neural_n = neural_n[:880, :, :]
print(neural_n.shape)
#12个trials 880张图片（其中80张是重复），114个细胞

n_images = 800
n_neurons = neural_n.shape[2]
size_imags = images_n.shape[0]
print(n_images, n_neurons, images_n.shape)
#encoder = conv_encoder(neurons, sizes, channels).to(device)

(880, 11, 43)
800 43 (800, 299, 299, 3)


In [360]:
#划分训练集和验证集，找到前800张里不重复的作为训练集，取前800张random的和后80张作为验证集，多个trials取平均值
reps = neural_n.shape[1] # trials
print('reps: ', reps)
rand_ind = np.arange(reps)
np.random.shuffle(rand_ind)

data_y_train = np.concatenate((np.delete(neural_n[: 800, :, :], idx, 0), neural_n[880:, :, :]), 0).mean(1)
temp = neural_n
#print('temp shape:', temp.shape, 'idx shape:', idx.shape, 'temp[idx] shape:', temp[idx].shape)
print(temp[:800][idx].shape)
data_y_val = np.concatenate((temp[:800][idx], temp[800:880][unique_idx]), 1)
data_y_val = np.mean(data_y_val, 1)
print(data_y_train.shape)
print(data_y_val.shape)

#
# data_x = images_n[:, np.newaxis].astype(np.float16)
# print('images_n', images_n.shape)
# data_x = data_x / 255 # (640, 1, 299, 299)
# data_x = np.tile(data_x, [1, 3, 1, 1])
# print('data_x', data_x.shape)
# data_x_train = data_x[:576]
# data_x_val = data_x[576:]as indices must be

reps:  11
(75, 11, 43)
(725, 43)
(75, 43)


In [361]:
# image_data for training
print(images_n.shape)
#data_x = images_n[:, np.newaxis].astype(np.float16)
data_x = images_n.astype('float')
imagenet_mean = np.expand_dims(np.array([123.68, 116.779, 103.939]), (0,1,2))
data_x = (data_x - imagenet_mean) / 128.
print(data_x.shape)
#data_x = data_x / 255 # (800, 1, 299, 299)

#data_x = np.tile(data_x, [1, 3, 1, 1])
data_x_train = np.delete(images_n, idx, 0)
data_x_val = images_n[idx]

data_x = np.transpose(data_x, (0, 3, 1, 2))
data_x_train = np.transpose(data_x_train, (0, 3, 1, 2))
data_x_val = np.transpose(data_x_val, (0, 3, 1, 2))
print(data_x.shape, data_x_train.shape, data_x_val.shape)


(800, 299, 299, 3)
(800, 299, 299, 3)
(800, 3, 299, 299) (725, 3, 299, 299) (75, 3, 299, 299)


In [362]:
#设置dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_x, data_y):
        self.data_x = data_x
        self.data_y = data_y
    def __getitem__(self, index):
        return index, self.data_x[index], self.data_y[index]
    def __len__(self):
        return self.data_x.shape[0]



dataset_train = Dataset(data_x_train, data_y_train)
dataset_val = Dataset(data_x_val, data_y_val)

loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle = True)
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle = True)

print(f'val: {data_x_val.shape}, {data_y_val.shape}')
# for i,(x,y) in enumerate(loader_val):
#     print(i, x.shape, y.shape)

val: (75, 3, 299, 299), (75, 43)


In [363]:
alexnet = models.alexnet(pretrained=True)

#
alexnet.to(device)
alexnet.eval()
for param in alexnet.parameters():
    param.requires_grad_(False)

x = torch.from_numpy(data_x[0:1]).float().to(device)
print("x:", x.shape)
fmap = alexnet(x, layer=insp_layer)

neurons = data_y_train.shape[1]
sizes = fmap.shape[2:]
print("fmap: ", fmap.shape)
print("size: ", sizes)
channels = fmap.shape[1]
print(neurons, sizes)
w_s = nn.Parameter(torch.randn(size=(neurons,) + sizes))
print(w_s.shape)


x: torch.Size([1, 3, 299, 299])
fmap:  torch.Size([1, 384, 17, 17])
size:  torch.Size([17, 17])
43 torch.Size([17, 17])
torch.Size([43, 17, 17])


In [364]:
print(sizes)
encoder = conv_encoder(neurons, sizes, channels).to(device)

torch.Size([17, 17])


In [365]:
def train_model(encoder, optimizer):
    losses = []
    encoder.train()
    for i,(z, x,y) in enumerate(loader_train):
        optimizer.zero_grad()
        x = x.float().to(device)
        y = y.float().to(device)
        fmap = alexnet(x,layer = insp_layer)
        out = encoder(fmap) #输出结果
#         print(f'L_e = {l_e} , L_2 = {l_2} , L_l = {l_l}')
        loss = Loss(y, out, encoder.W_sploss, encoder.W_features)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
#         print(f'iteration {i}, train loss: {losses[-1]}')

    return losses

def validate_model(encoder):
    encoder.eval()
    y_pred = []
    y_true = []
    losses = []
    for i,(z, x,y) in enumerate(loader_val):
        x = x.float().to(device)
        y = y.float().to(device)
        fmap = alexnet(x,layer = insp_layer)
        out = encoder(fmap)
        y_pred.append(out)
        y_true.append(y)
        loss = Loss(y, out, encoder.W_sploss, encoder.W_features)
        losses.append(loss.item())
    y_pred = torch.cat(y_pred)
    y_true = torch.cat(y_true)
    y_pred = y_pred.detach().cpu().numpy()
    y_true = y_true.detach().cpu().numpy()
    #print(y_pred.shape, y_true.shape)
    ev = explained_variance_score(y_true, y_pred)
    pcc = torch_pearson(y_pred, y_true)
    return pcc, ev,sum(losses)/len(losses)
    #return explained_variance,sum(losses)/len(losses)

"""
    You need to define the conv_encoder() class and train the encoder.
    The code of alexnet has been slightly modified from the torchvision, for convenience
    of extracting the middle layers.

    Example:
        >>> x = x.to(device) # x is a batch of images
        >>> x = transform(x)
        >>> fmap = alexnet(x, layer=insp_layer)
        >>> out= encoder(fmap)
        >>> ...
"""

'\n    You need to define the conv_encoder() class and train the encoder.\n    The code of alexnet has been slightly modified from the torchvision, for convenience\n    of extracting the middle layers.\n\n    Example:\n        >>> x = x.to(device) # x is a batch of images\n        >>> x = transform(x)\n        >>> fmap = alexnet(x, layer=insp_layer)\n        >>> out= encoder(fmap)\n        >>> ...\n'

In [366]:
# losses_train = []
# losses_val = []
# EVs = []

losses_train = []
losses_val = []
EVs = []
pccs = []

In [367]:
lr = 1e-3
optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
#optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)

In [368]:
epoches = 2600
best_loss = 1e100
not_improve = 0
endure = 10
for epoch in tqdm_notebook(range(epoches)):
    losses_train += train_model(encoder,optimizer)
    pcc, ev,loss = validate_model(encoder)
    #ev,loss = validate_model(encoder)
    EVs.append(ev)
    pccs.append(pcc)
    losses_val.append(loss)
    train_loss = sum(losses_train[-10:])/10
    if train_loss < best_loss - 1e-5:
        not_improve = 0
    else:
        not_improve += 1
    if epoch % 1 == 0:
        print(f'epoch {epoch}, EV = {ev}, val loss = {loss} , train loss {sum(losses_train[-10:])/10}, pcc = {pcc}')
        #print(f'epoch {epoch}, EV = {ev}, val loss = {loss} , train loss {sum(losses_train[-10:])/10}')
    if not_improve == endure:
        print("Early stopping!")

    #scheduler.step()



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for epoch in tqdm_notebook(range(epoches)):


  0%|          | 0/2600 [00:00<?, ?it/s]

epoch 0, EV = -1.6689602136611938, val loss = 1062.84052734375 , train loss 2139.6330322265626, pcc = 0.01875501498579979
epoch 1, EV = -0.29372459650039673, val loss = 92.51400909423828 , train loss 91.10951766967773, pcc = -0.0008818791247904301
epoch 2, EV = -0.07251837849617004, val loss = 39.53854522705078 , train loss 45.471507453918456, pcc = -0.0051062386482954025
epoch 3, EV = -0.034622158855199814, val loss = 22.58813018798828 , train loss 20.077053165435792, pcc = -0.014521627686917782
epoch 4, EV = -0.020709669217467308, val loss = 17.121028327941893 , train loss 11.994826221466065, pcc = -0.021886995062232018
epoch 5, EV = -0.014937559142708778, val loss = 12.553859901428222 , train loss 9.963322401046753, pcc = -0.02844182588160038
epoch 6, EV = -0.017082292586565018, val loss = 9.915342140197755 , train loss 10.018606424331665, pcc = -0.035698335617780685
epoch 7, EV = -0.014861378818750381, val loss = 8.606915950775146 , train loss 6.853611326217651, pcc = -0.0404434241