In [1]:
import os
import math
import random
from sklearn.metrics import balanced_accuracy_score

import copy
from functools import wraps, partial

import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from torch.optim.lr_scheduler import StepLR
import torch.backends.cudnn as cudnn

from einops import repeat, rearrange 
from einops.layers.torch import Rearrange

import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Per the original note: the "_5_23" that follows "patho-vit" is the package
# version number.
# Original note: if the kernel dies while this notebook is running, re-run this
# cell after the restart and then jump to the cell you need. (This cell uses
# `torch`, so the import cell has to be re-run as well.)
path = "/data2/patho-vit_5_23_ccsurv/patient/low"

In [3]:
seed = 42


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Sets the ``PYTHONHASHSEED`` environment variable, seeds Python's
    ``random`` module, NumPy's global generator and PyTorch (CPU generator,
    the current CUDA device and all CUDA devices), and switches cuDNN to
    deterministic mode.

    Args:
        seed: Integer seed handed to each generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


seed_everything(seed)

In [4]:
from collections import OrderedDict
from patho_vit.vit_luci import ViT

In [5]:
batch_size = 4
# Depends on the number of GPUs; the original reference used 50. If
# "CUDA OUT OF MEMORY" is raised, reduce batch_size. The kernel also has to be
# shut down and restarted afterwards. (The rest of the original note is
# garbled; it appears to say that only the first code cell needs to be re-run
# after the restart - TODO confirm.)
num_workers = 0
# Depends on the number of CPUs; the original reference used 4. Can be
# increased gradually starting from 0.

epochs = 10000
# The default in the original reference was 50.
test_every = 1
# Run validation once every `test_every` training epochs. The default in the
# original reference was 10.

In [6]:
# Channel count handed to the ViT; the embedding width is twice that and the
# MLP hidden width four times the embedding width.
in_channels = 768
embed_dim = in_channels * 2

vit = ViT(
    image_size=96,
    patch_size=2,
    channels=in_channels,
    dim=embed_dim,
    depth=35,
    heads=12 * 2,
    mlp_dim=embed_dim * 4,
    num_classes=2,
)

In [7]:
#img =torch.randn(1, 768, 96, 96)
#from thop import profile
#flops, params =                  profile(vit, inputs =                               (img, ))  
#params

In [8]:
#for params in vit.parameters():
#    params.requires_grad = False

#vit.mlp_head = nn.Sequential(
#            nn.Linear((48 * 48 + 1) * 768 * 2, 512),
#            nn.GELU(),
#            nn.Linear(512, 512),
#            nn.GELU(),
#            nn.Linear(512, 512),
#            nn.LayerNorm(512),
#            nn.Linear(512, 3)
#        )

In [9]:
# Initialise the custom ViT from the target-encoder weights stored in a
# pre-training checkpoint.
#
# Checkpoints used in earlier experiments (kept for reference):
#   /data2/patho-vit_5_23_oc/up/13/gpvit_weight_6_24_epoch_1.pt
#   /data2/patho-vit_5_23_ccim/low/57/checkpoint_2classes_1epoch_8.pth
#       (a full DataParallel state dict; it was loaded with strict=True *after*
#       wrapping the model in nn.DataParallel)

# Load onto the CPU first: a checkpoint written from a GPU run then also opens
# on a CPU-only machine, and no extra copy of the weights is held on the GPU.
# load_state_dict() below copies the values into the model's own parameters.
weights = torch.load(
    "/data2/patho-vit_5_23_ccsurv/patient/low/57/gpvit_weight_12_23_epoch_2.pt",
    map_location="cpu",
)

# Keep only the tensors that belong to the target encoder and strip the prefix
# so that the keys line up with the bare ViT. The prefix is 22 characters long,
# i.e. exactly what the former hard-coded slice `k[22:]` removed.
target_prefix = "module.target_encoder."
new_dict = OrderedDict(
    (k[len(target_prefix):], v)
    for k, v in weights.items()
    if k.startswith(target_prefix)
)

vit.to(device)

# strict=False tolerates key mismatches, so report them instead of discarding
# the result: otherwise a wrong prefix would silently leave the model at its
# random initialisation.
load_result = vit.load_state_dict(new_dict, strict = False)
print(
    "Loaded {} target-encoder tensors (missing keys: {}, unexpected keys: {})".format(
        len(new_dict),
        len(load_result.missing_keys),
        len(load_result.unexpected_keys),
    )
)
if not new_dict:
    print("WARNING: no key in the checkpoint starts with '{}'; "
          "the ViT keeps its initial weights.".format(target_prefix))

# NOTE: re-running this cell wraps the model in DataParallel a second time;
# re-create `vit` first when re-executing.
vit = nn.DataParallel(vit)

In [10]:
class layer3dataset(torch.utils.data.Dataset):
    """Dataset of pre-computed feature maps listed in a CSV library file.

    Per the original note, the samples are the images reconstructed from the
    second layer, of size (96, 96, 768) - TODO confirm against the data.

    The CSV must provide three columns:
      * ``pathid`` - pathology id of the row (the original note says one row
        per patient);
      * ``images`` - path that is handed to ``torch.load`` to read the sample;
      * ``labels`` - label value of the row.
    """

    def __init__(self, libraryfile = "", transform=None, subsample=-1):
        """Read the library CSV and remember the preprocessing options.

        Args:
            libraryfile: Path of the CSV file described in the class docstring.
            transform: Optional callable applied to the loaded array; it must
                return a tensor that ``F.interpolate`` accepts after a batch
                dimension has been added.
            subsample: Target height and width for resizing, or -1 to return
                the stored array untouched.
        """
        library = pd.read_csv(libraryfile)
        self.pathid = library["pathid"]
        self.images = library["images"]
        self.labels = library["labels"]

        self.subsample = subsample
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``.

        NOTE(review): the transform is applied only when ``subsample`` is not
        -1 *and* a transform was given; in every other combination the raw
        NumPy array is returned unchanged. Kept as-is for compatibility.
        """
        image = np.array(torch.load(self.images[index]))

        if self.subsample != -1 and self.transform is not None:
            image = self.transform(image)
            # F.interpolate works on batched input: add a batch axis, resize
            # to (subsample, subsample), then drop the batch axis again.
            image = F.interpolate(
                image.unsqueeze(0), size=(self.subsample, self.subsample)
            ).squeeze(0)
            image = image.to(torch.float32)

        return image, self.labels[index]

    def __len__(self):
        """Number of rows in the library file."""
        return len(self.pathid)

In [11]:
# Three-channel mean/std normalisation. It is defined here but not passed to
# any of the datasets built below.
normalize = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)

# Library CSVs: training, internal validation and external validation.
#train_lib_0 = "/data2/patho-vit_5_23_ccim/low/files_label/store_4_lung_train.db"
train_lib_1 = "/data2/patho-vit_5_23_ccsurv/patient/low/files_label/store_5_96_train.csv"
valid_lib = "/data2/patho-vit_5_23_ccsurv/patient/low/files_label/store_5_96_valid.csv"
valid_lib_1 = "/data2/patho-vit_5_23_ccsurv/patient/low/files_label/store_5_96_exva.csv"

transform = transforms.ToTensor()


def _make_split(library_csv, shuffle):
    """Build the dataset for one library CSV together with its DataLoader."""
    split_dataset = layer3dataset(library_csv, transform = transform, subsample = 96)
    split_loader = torch.utils.data.DataLoader(
        split_dataset,
        batch_size = batch_size,
        shuffle = shuffle,
        num_workers = num_workers,
    )
    return split_dataset, split_loader


# Only the training split is shuffled.
train_dataset_1, train_loader = _make_split(train_lib_1, shuffle = True)
valid_dataset, valid_loader = _make_split(valid_lib, shuffle = False)
valid_dataset_1, valid_loader_1 = _make_split(valid_lib_1, shuffle = False)

In [12]:
# Let cuDNN benchmark kernels and pick the fastest one per input shape.
# NOTE(review): seed_everything() sets cudnn.deterministic = True; leaving the
# autotuner on can still make runs differ from each other. Set this to False
# if exactly repeatable runs matter more than speed.
cudnn.benchmark = True

# ---- Loss function and optimiser -------------------------------------------
# Class weights for the negative / positive class. The original author notes
# that these values were picked arbitrarily.
#w = torch.Tensor([0.22, 0.365, 0.415])
w = torch.Tensor([0.1,0.9])
criterion = nn.CrossEntropyLoss(w).to(device)
optimizer = torch.optim.AdamW(vit.parameters(), lr=3e-4, weight_decay = 0.05, eps = 1e-4, betas = [0.9, 0.95])
# Original notes: the reference used lr=1e-4; lucidrains used 3e-5 together
# with scheduler = StepLR(optimizer, step_size=1, gamma=0.7).

# ---- Output locations ---------------------------------------------------------
_OUT_DIR = "/data2/patho-vit_5_23_ccsurv/patient/low/8"
_CONV_PATH = os.path.join(_OUT_DIR, "convergence.csv")
_INVAL_DIR = os.path.join(_OUT_DIR, "pred_real_inval")
_EXVAL_DIR = os.path.join(_OUT_DIR, "pred_real_exval")
# Print the running batch loss every this many training steps.
_LOG_EVERY = 186

# Create the output folders up front so that a missing directory cannot abort
# the run after the first epoch has already been trained.
for _out_dir in (_INVAL_DIR, _EXVAL_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Start a fresh convergence log.
with open(_CONV_PATH, "w") as fconv:
    fconv.write("epoch, metric, value\n")


def _append_metrics(epoch_no, rows):
    """Append one ``epoch, metric, value`` line per (name, value) pair."""
    with open(_CONV_PATH, "a") as log_file:
        for name, value in rows:
            log_file.write("{}, {}, {:.4f}\n".format(epoch_no, name, value))


def _rate(hits, total):
    """Return hits / total as a float, or NaN when ``total`` is zero."""
    return float(hits) / float(total) if total else float("nan")


def _evaluate(model, loader, loss_fn=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Args:
        model: Network returning ``(logits, extra)`` for a batch of images.
        loader: Iterable yielding ``(images, labels)`` batches.
        loss_fn: Optional criterion; when given, the mean of the per-batch
            losses is reported under ``"loss"``.

    Returns:
        dict with ``loss`` (mean of batch losses, 0.0 without ``loss_fn``),
        ``acc`` (fraction of correctly classified samples), ``ba`` (balanced
        accuracy), ``sensi`` / ``speci`` (recall of class 1 / class 0) and the
        raw ``pred`` / ``real`` label arrays.
    """
    was_training = model.training
    model.eval()  # turn off train-only behaviour (e.g. dropout) while scoring
    loss_sum = 0.0
    n_batches = 0
    pred, real = [], []
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs, _ = model(images)
                if loss_fn is not None:
                    loss_sum += loss_fn(outputs, labels.long()).item()
                n_batches += 1
                pred.extend(outputs.argmax(dim=-1).cpu().numpy())
                real.extend(labels.cpu().numpy())
    finally:
        # Restore whatever mode the caller was in.
        model.train(was_training)

    pred = np.array(pred)
    real = np.array(real)
    correct = np.equal(pred, real)
    return {
        "loss": loss_sum / max(n_batches, 1),
        "acc": float(correct.mean()) if correct.size else float("nan"),
        "ba": balanced_accuracy_score(real, pred),
        "sensi": _rate(np.logical_and(pred == 1, correct).sum(), (real == 1).sum()),
        "speci": _rate(np.logical_and(pred == 0, correct).sum(), (real == 0).sum()),
        "pred": pred,
        "real": real,
    }


# ---- Training and validation --------------------------------------------------
# Debugging note from the original author: "TypeError: cannot convert the series
# to <class 'float'>" means the generated label library has rows with a missing
# label value (inspect the xlsx of the same name); the slide
# "202017947.A3.bif" had no label and was simply removed.

# Placeholder for best-model tracking; nothing below updates it (a checkpoint
# is written after every epoch instead).
best_ba = 0

total_step_1 = len(train_loader)

for epoch in range(epochs):
    torch.cuda.empty_cache()
    vit.train()

    # Plain Python numbers: accumulating the loss tensor itself would keep a
    # reference to every batch's autograd graph for the whole epoch.
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs, _ = vit(images)
        loss = criterion(outputs, labels.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_correct += (outputs.argmax(dim=-1) == labels).sum().item()
        n_seen += labels.size(0)

        if (i + 1) % _LOG_EVERY == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.10f}"
                 .format(epoch+1, epochs, i+1, total_step_1, loss.item()))

    # Loss: mean of the per-batch losses. Accuracy: correct samples over all
    # samples seen, so a short final batch is not over-weighted.
    epoch_loss = running_loss / max(total_step_1, 1)
    epoch_accuracy = n_correct / max(n_seen, 1)

    print("Epoch [{}/{}], Loss_cc: {:.4f}, Acc_cc: {:.4f}".format(epoch+1, epochs, epoch_loss, epoch_accuracy))
    _append_metrics(epoch+1, [("loss", epoch_loss), ("acc", epoch_accuracy)])

    run_eval = (epoch+1) % test_every == 0
    if run_eval:
        # Internal validation set.
        val_result = _evaluate(vit, valid_loader, criterion)
        epoch_val_loss = val_result["loss"]
        epoch_val_accuracy = val_result["acc"]
        ba = val_result["ba"]
        pd.DataFrame({"prediction": val_result["pred"], "labels": val_result["real"]}).to_csv(
            "/data2/patho-vit_5_23_ccsurv/patient/low/8/pred_real_inval/pred_{}.csv".format(epoch+1)
        )
        _append_metrics(epoch+1, [
            ("epoch_val_loss", epoch_val_loss),
            ("val_acc", epoch_val_accuracy),
            ("ba", ba),
            ("sensi", val_result["sensi"]),
            ("speci", val_result["speci"]),
        ])

        # External validation set (no loss is logged for it).
        test_result = _evaluate(vit, valid_loader_1)
        epoch_test_accuracy = test_result["acc"]
        ba_test = test_result["ba"]
        pd.DataFrame({"prediction": test_result["pred"], "labels": test_result["real"]}).to_csv(
            "/data2/patho-vit_5_23_ccsurv/patient/low/8/pred_real_exval/pred_{}.csv".format(epoch+1)
        )
        _append_metrics(epoch+1, [
            ("test_acc", epoch_test_accuracy),
            ("ba_test", ba_test),
            ("sensi_test", test_result["sensi"]),
            ("speci_test", test_result["speci"]),
        ])

    # NOTE: one full checkpoint is written per epoch.
    torch.save(vit.state_dict(), os.path.join("/data2/patho-vit_5_23_ccsurv/patient/low/8/checkpoint_2classes_{}.pth".format(epoch+1)))

    # Report validation figures only for epochs in which they were computed;
    # otherwise these names would be undefined or left over from an earlier epoch.
    if run_eval:
        print(
            f"Epoch : {epoch+1} - val_loss: {epoch_val_loss:.4f}, val_acc: {epoch_val_accuracy:.4f} - ba: {ba:.4f}\n"
        )
        print(
            f"Epoch : {epoch+1} - test_acc: {epoch_test_accuracy:.4f} - ba_test: {ba_test:.4f}\n"
        )

Epoch [1/10000], Step [186/359] Loss: 0.0753776804
Epoch [1/10000], Loss_cc: 0.6925, Acc_cc: 0.8210
Epoch : 1 - val_loss: 0.6543, val_acc: 0.8333 - ba: 0.5000

Epoch : 1 - test_acc: 0.9127 - ba_test: 0.5120

Epoch [2/10000], Step [186/359] Loss: 0.5931120515
Epoch [2/10000], Loss_cc: 0.6789, Acc_cc: 0.8043
Epoch : 2 - val_loss: 0.6286, val_acc: 0.8301 - ba: 0.4980

Epoch : 2 - test_acc: 0.9245 - ba_test: 0.5000

Epoch [3/10000], Step [186/359] Loss: 0.3658606112
Epoch [3/10000], Loss_cc: 0.5671, Acc_cc: 0.8691
Epoch : 3 - val_loss: 0.6236, val_acc: 0.7059 - ba: 0.5529

Epoch : 3 - test_acc: 0.7343 - ba_test: 0.5254

Epoch [4/10000], Step [186/359] Loss: 0.0651621968
Epoch [4/10000], Loss_cc: 0.5495, Acc_cc: 0.8712
Epoch : 4 - val_loss: 0.7507, val_acc: 0.8301 - ba: 0.4980

Epoch : 4 - test_acc: 0.9245 - ba_test: 0.5000

Epoch [5/10000], Step [186/359] Loss: 0.1076546088
Epoch [5/10000], Loss_cc: 0.4995, Acc_cc: 0.8802
Epoch : 5 - val_loss: 0.6480, val_acc: 0.7859 - ba: 0.5382

Epoch : 

KeyboardInterrupt: 