In [None]:
import torch
import torch.nn as nn
import random
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTConfig


# hyperparameter

In [2]:
SEED = 42
DEVICE = torch.device("cuda:0")
DATA_DIR  = '../train_nucleus_128_with_env_15dis_cell_scale/all/'
BATCH_SIZE = 300
NUM_EPOCHS = 30
PORJECT_NAME = f'Nuspire_mouse_brain_Regression'

In [3]:
def set_seeds(seed_value=42, cuda_deterministic=False):
    """Set seeds for reproducibility."""
    random.seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:  # faster, less reproducible
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


In [None]:
set_seeds(SEED)
timestamp = "07"
folder_name = f'./{PORJECT_NAME}_{timestamp}_checkpoint'

if not os.path.exists(folder_name):
    os.makedirs(folder_name)
    # print(f"'{folder_name}'has been created.")
else:
    print(f"'{folder_name}' already exists.")

In [5]:
class ImageDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.file_list = os.listdir(data_dir)
        self.cell_expression = pd.read_csv('../processed_data/cell_expression_filtered_size_allgene.csv', index_col=0)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_dir, self.file_list[idx])
        img_index = img_name.split("/")[-1].replace('image_', '').replace('.png', '')
        image = Image.open(img_name).convert('L')
        if self.transform:
            image = self.transform(image)
        
        if img_index in self.cell_expression.index:
            target = self.cell_expression.loc[img_index].values
        else:
            target = None
        return image, target

In [6]:
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
])

In [7]:
dataset = ImageDataset(DATA_DIR, transform=transform)

In [None]:
total_size = len(dataset)
train_size = int(total_size * 0.8)
remaining_size = total_size - train_size

valid_size = int(remaining_size * 0.5)
test_size = remaining_size - valid_size

indices = list(range(total_size))
np.random.shuffle(indices)

train_indices = indices[:train_size]
remaining_indices = indices[train_size:]
valid_indices = remaining_indices[:valid_size]
test_indices = remaining_indices[valid_size:]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=4)
valid_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=valid_sampler, num_workers=4)
test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)

# print(train_indices)
# print(valid_indices)
# print(test_indices)

# model

In [9]:
config = ViTConfig.from_pretrained("/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69")

config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
config.num_labels = 347

model = ViTForImageClassification.from_pretrained(
    "/mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69",
    config=config
)

You are using a model of type vit_mae to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at /mnt/Storage/home/huayuwei/container_workspace/spCS/2.result/0.pretrain_model/V5/epoch69 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
model.to(DEVICE)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(8, 8), stride=(8, 8))
      )
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_fe

# Training

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
criterion = nn.MSELoss()

In [12]:
writer = SummaryWriter(f'./tensorboard/{PORJECT_NAME}_{timestamp}')
step1 = 0
step2 = 0
best_val_loss = 1

for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")
    model.train()
    loss_list = []
    for i, (x,l) in tqdm(enumerate(train_loader), total=len(train_loader)):
        x = x.to(DEVICE)
        l = l.to(DEVICE)
        
        optimizer.zero_grad()
        
        outputs = model(x)
        
        loss = criterion(outputs.logits, l.float())
        
        writer.add_scalar("Step/Train Loss", loss.item(),step1)
        loss_list.append(loss.item())
        
        step1+=1
        loss.backward()
        optimizer.step()
    train_loss = np.mean(loss_list)

    model.eval()
    loss_list = []
    with torch.no_grad():
         for i, (x,l) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            x = x.to(DEVICE)
            l = l.to(DEVICE)

            optimizer.zero_grad()

            outputs = model(x)

            loss = criterion(outputs.logits, l.float())
            
            writer.add_scalar("Step/Validation Loss", loss.item(),step2)

            loss_list.append(loss.item())
            step2+=1
    val_loss = np.mean(loss_list)
    
    # Save the model if the validation loss is the best we've seen so far.
    if val_loss < best_val_loss:
        torch.save(model.state_dict(), f'{folder_name}/{PORJECT_NAME}_best_model.pt')
        model.save_pretrained(f'{folder_name}/{PORJECT_NAME}_best_model')
        best_epoch=epoch
        best_val_loss = val_loss

    lr = optimizer.param_groups[0]['lr']
    writer.add_scalar("Epoch/Lr", lr, epoch)
    writer.add_scalars("Epoch/Loss",{'Train Loss':train_loss,'Validation Loss':val_loss},epoch)
    print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")


Epoch: 1/30


100%|██████████| 143/143 [04:57<00:00,  2.08s/it]
100%|██████████| 18/18 [00:13<00:00,  1.37it/s]


Epoch 0, Train Loss: 0.1923, Validation Loss: 0.1672
Epoch: 2/30


100%|██████████| 143/143 [04:59<00:00,  2.10s/it]
100%|██████████| 18/18 [00:13<00:00,  1.38it/s]


Epoch 1, Train Loss: 0.1617, Validation Loss: 0.1588
Epoch: 3/30


100%|██████████| 143/143 [04:59<00:00,  2.10s/it]
100%|██████████| 18/18 [00:13<00:00,  1.37it/s]


Epoch 2, Train Loss: 0.1528, Validation Loss: 0.1526
Epoch: 4/30


100%|██████████| 143/143 [05:06<00:00,  2.14s/it]
100%|██████████| 18/18 [00:13<00:00,  1.33it/s]


Epoch 3, Train Loss: 0.1480, Validation Loss: 0.1482
Epoch: 5/30


100%|██████████| 143/143 [05:07<00:00,  2.15s/it]
100%|██████████| 18/18 [00:13<00:00,  1.34it/s]


Epoch 4, Train Loss: 0.1445, Validation Loss: 0.1473
Epoch: 6/30


100%|██████████| 143/143 [05:02<00:00,  2.12s/it]
100%|██████████| 18/18 [00:13<00:00,  1.37it/s]


Epoch 5, Train Loss: 0.1421, Validation Loss: 0.1455
Epoch: 7/30


100%|██████████| 143/143 [07:31<00:00,  3.16s/it]
100%|██████████| 18/18 [00:13<00:00,  1.38it/s]


Epoch 6, Train Loss: 0.1400, Validation Loss: 0.1428
Epoch: 8/30


100%|██████████| 143/143 [05:02<00:00,  2.12s/it]
100%|██████████| 18/18 [00:13<00:00,  1.38it/s]


Epoch 7, Train Loss: 0.1380, Validation Loss: 0.1426
Epoch: 9/30


100%|██████████| 143/143 [04:55<00:00,  2.06s/it]
100%|██████████| 18/18 [00:12<00:00,  1.41it/s]

Epoch 8, Train Loss: 0.1366, Validation Loss: 0.1429
Epoch: 10/30



100%|██████████| 143/143 [04:53<00:00,  2.05s/it]
100%|██████████| 18/18 [00:13<00:00,  1.38it/s]


Epoch 9, Train Loss: 0.1352, Validation Loss: 0.1408
Epoch: 11/30


100%|██████████| 143/143 [04:57<00:00,  2.08s/it]
100%|██████████| 18/18 [00:13<00:00,  1.33it/s]

Epoch 10, Train Loss: 0.1340, Validation Loss: 0.1419
Epoch: 12/30



100%|██████████| 143/143 [04:59<00:00,  2.09s/it]
100%|██████████| 18/18 [00:13<00:00,  1.38it/s]

Epoch 11, Train Loss: 0.1325, Validation Loss: 0.1418
Epoch: 13/30



100%|██████████| 143/143 [04:59<00:00,  2.09s/it]
100%|██████████| 18/18 [00:13<00:00,  1.38it/s]


Epoch 12, Train Loss: 0.1314, Validation Loss: 0.1403
Epoch: 14/30


100%|██████████| 143/143 [05:00<00:00,  2.10s/it]
100%|██████████| 18/18 [00:13<00:00,  1.37it/s]


Epoch 13, Train Loss: 0.1304, Validation Loss: 0.1395
Epoch: 15/30


100%|██████████| 143/143 [05:00<00:00,  2.10s/it]
100%|██████████| 18/18 [00:13<00:00,  1.37it/s]


Epoch 14, Train Loss: 0.1289, Validation Loss: 0.1393
Epoch: 16/30


100%|██████████| 143/143 [05:31<00:00,  2.32s/it]
100%|██████████| 18/18 [00:15<00:00,  1.17it/s]

Epoch 15, Train Loss: 0.1279, Validation Loss: 0.1398
Epoch: 17/30



100%|██████████| 143/143 [06:02<00:00,  2.53s/it]
100%|██████████| 18/18 [00:15<00:00,  1.16it/s]

Epoch 16, Train Loss: 0.1265, Validation Loss: 0.1394
Epoch: 18/30



100%|██████████| 143/143 [06:01<00:00,  2.53s/it]
100%|██████████| 18/18 [00:15<00:00,  1.15it/s]

Epoch 17, Train Loss: 0.1254, Validation Loss: 0.1403
Epoch: 19/30



100%|██████████| 143/143 [06:03<00:00,  2.54s/it]
100%|██████████| 18/18 [00:15<00:00,  1.17it/s]


Epoch 18, Train Loss: 0.1241, Validation Loss: 0.1389
Epoch: 20/30


100%|██████████| 143/143 [06:07<00:00,  2.57s/it]
100%|██████████| 18/18 [00:15<00:00,  1.13it/s]

Epoch 19, Train Loss: 0.1234, Validation Loss: 0.1399
Epoch: 21/30



100%|██████████| 143/143 [06:13<00:00,  2.61s/it]
100%|██████████| 18/18 [00:16<00:00,  1.10it/s]

Epoch 20, Train Loss: 0.1225, Validation Loss: 0.1407
Epoch: 22/30



100%|██████████| 143/143 [06:13<00:00,  2.62s/it]
100%|██████████| 18/18 [00:15<00:00,  1.14it/s]

Epoch 21, Train Loss: 0.1213, Validation Loss: 0.1395
Epoch: 23/30



100%|██████████| 143/143 [06:15<00:00,  2.62s/it]
100%|██████████| 18/18 [00:16<00:00,  1.12it/s]

Epoch 22, Train Loss: 0.1203, Validation Loss: 0.1399
Epoch: 24/30



100%|██████████| 143/143 [06:14<00:00,  2.62s/it]
100%|██████████| 18/18 [00:15<00:00,  1.14it/s]

Epoch 23, Train Loss: 0.1190, Validation Loss: 0.1400
Epoch: 25/30



100%|██████████| 143/143 [06:13<00:00,  2.61s/it]
100%|██████████| 18/18 [00:15<00:00,  1.14it/s]

Epoch 24, Train Loss: 0.1182, Validation Loss: 0.1413
Epoch: 26/30



100%|██████████| 143/143 [05:39<00:00,  2.37s/it]
100%|██████████| 18/18 [00:13<00:00,  1.34it/s]

Epoch 25, Train Loss: 0.1173, Validation Loss: 0.1407
Epoch: 27/30



100%|██████████| 143/143 [05:08<00:00,  2.16s/it]
100%|██████████| 18/18 [00:13<00:00,  1.37it/s]

Epoch 26, Train Loss: 0.1163, Validation Loss: 0.1411
Epoch: 28/30



100%|██████████| 143/143 [10:22<00:00,  4.35s/it]
100%|██████████| 18/18 [00:32<00:00,  1.78s/it]

Epoch 27, Train Loss: 0.1156, Validation Loss: 0.1412
Epoch: 29/30



100%|██████████| 143/143 [09:45<00:00,  4.09s/it]
100%|██████████| 18/18 [00:16<00:00,  1.09it/s]

Epoch 28, Train Loss: 0.1144, Validation Loss: 0.1416
Epoch: 30/30



100%|██████████| 143/143 [06:34<00:00,  2.76s/it]
100%|██████████| 18/18 [00:15<00:00,  1.14it/s]

Epoch 29, Train Loss: 0.1134, Validation Loss: 0.1412





# Test

In [13]:
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
])

In [14]:
dataset = ImageDataset(DATA_DIR, transform=transform)

In [15]:
test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler, num_workers=4)

In [16]:
model_path = f'{folder_name}/{PORJECT_NAME}_best_model.pt'
model.load_state_dict(torch.load(model_path))
model.to(DEVICE) 

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(8, 8), stride=(8, 8))
      )
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_fe

In [17]:
model.eval()
true_labels = []
predicted_outputs = []

with torch.no_grad():
    for i, (x, l) in tqdm(enumerate(test_loader), total=len(test_loader)):
        x = x.to(DEVICE)
        l = l.to(DEVICE)

        outputs = model(x)

        # Collect true labels and predicted outputs
        true_labels.append(l.cpu())
        predicted_outputs.append(outputs.logits.cpu())
        
    true_labels = torch.cat(true_labels).numpy()
    predicted_outputs = torch.cat(predicted_outputs).numpy()  

100%|██████████| 18/18 [00:15<00:00,  1.14it/s]


In [18]:
np.save(f'{PORJECT_NAME}_{timestamp}_all_outputs.npy', predicted_outputs)
np.save(f'{PORJECT_NAME}_{timestamp}_all_targets.npy', true_labels)

In [19]:
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error, explained_variance_score

n_samples, n_features = true_labels.shape

results = {metric: [] for metric in ['MSE',
                                    #  'RMSE',
                                    #  'MAE', 
                                    #  'MAPE', 
                                    #  'R_squared', 
                                    #  'Explained_Variance',
                                     'Pearson']}

for i in range(n_samples):
    mse = mean_squared_error(true_labels[i, :], predicted_outputs[i, :])
    # rmse = np.sqrt(mse)
    # mae = mean_absolute_error(true_labels[i, :], predicted_outputs[i, :])
    # mape = mean_absolute_percentage_error(true_labels[i, :], predicted_outputs[i, :])
    # r2 = r2_score(true_labels[i, :], predicted_outputs[i, :])
    # explained_var = explained_variance_score(true_labels[i, :], predicted_outputs[i, :])
    pcc, _ = pearsonr(true_labels[i, :], predicted_outputs[i, :])

    results['MSE'].append(mse)
    # results['RMSE'].append(rmse)
    # results['MAE'].append(mae)
    # results['MAPE'].append(mape)
    # results['R_squared'].append(r2)
    # results['Explained_Variance'].append(explained_var)
    results['Pearson'].append(pcc)

for metric in results:
    results[metric] = np.array(results[metric])

for metric in results:
    print(f"{metric}: {results[metric]}")

for metric in results:
    print(f"{metric} - Mean: {np.mean(results[metric]):.4f}, Std: {np.std(results[metric]):.4f}")

MSE: [0.13185952 0.09007648 0.16135108 ... 0.12363786 0.18178999 0.25879715]
Pearson: [0.77104155 0.82377504 0.70091003 ... 0.77836889 0.64078275 0.49941442]
MSE - Mean: 0.1386, Std: 0.0514
Pearson - Mean: 0.7380, Std: 0.1021
